Skip to content

Commit adcbc2c

Browse files
committed
improve LinkMatrixBuilder
1 parent 14a6fd2 commit adcbc2c

File tree

1 file changed

+147
-17
lines changed

1 file changed

+147
-17
lines changed

src/field/link_matrix_builder.rs

Lines changed: 147 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,79 @@
11
use std::num::NonZeroUsize;
22

33
#[cfg(feature = "serde-serialize")]
4-
use serde::Serialize;
4+
use serde::{Deserialize, Serialize};
55

66
use super::LinkMatrix;
77
use crate::lattice::LatticeCyclique;
88
use crate::CMatrix3;
99

10-
#[derive(Debug, PartialEq)]
10+
#[non_exhaustive]
11+
#[derive(Debug, PartialEq, Clone)]
1112
#[cfg_attr(feature = "serde-serialize", derive(Serialize))]
12-
enum LinkMatrixBuilderType<'a, 'lat, Rng: rand::Rng + ?Sized, const D: usize> {
13-
Generated(&'lat LatticeCyclique<D>, GenType<'a, Rng>),
13+
enum LinkMatrixBuilderType<'rng, 'lat, Rng: rand::Rng + ?Sized, const D: usize> {
14+
/// Generate data procedurally
15+
Generated(&'lat LatticeCyclique<D>, GenType<'rng, Rng>),
16+
/// Data already existing
1417
Data(Vec<CMatrix3>),
1518
}
1619

20+
/// Type of generation
21+
#[non_exhaustive]
1722
#[derive(Debug, PartialEq)]
18-
#[cfg_attr(feature = "serde-serialize", derive(Serialize))]
19-
enum GenType<'a, Rng: rand::Rng + ?Sized> {
23+
#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
24+
enum GenType<'rng, Rng: rand::Rng + ?Sized> {
25+
/// Cold generation all ellements are set to the default
2026
Cold,
21-
Hot(&'a mut Rng),
27+
/// Random deterministe
28+
#[cfg_attr(feature = "serde-serialize", serde(skip_deserializing))]
29+
HotDeterministe(&'rng mut Rng),
30+
/// Random deterministe but own the RNG (for instance the result of `clone`)
31+
HotDeterministeOwned(Box<Rng>),
32+
/// Random threaded (non deterministe)
2233
HotThreaded(NonZeroUsize),
2334
}
2435

25-
impl<'a, 'lat, Rng: rand::Rng + ?Sized, const D: usize> LinkMatrixBuilderType<'a, 'lat, Rng, D> {
36+
impl<'rng, Rng: rand::Rng + Clone + ?Sized> Clone for GenType<'rng, Rng> {
37+
fn clone(&self) -> Self {
38+
match self {
39+
Self::Cold => Self::Cold,
40+
Self::HotDeterministe(rng_ref) => {
41+
Self::HotDeterministeOwned(Box::new((*rng_ref).clone()))
42+
}
43+
Self::HotDeterministeOwned(rng_box) => Self::HotDeterministeOwned(rng_box.clone()),
44+
Self::HotThreaded(n) => Self::HotThreaded(*n),
45+
}
46+
}
47+
}
48+
49+
impl<'rng, 'lat, Rng: rand::Rng + ?Sized, const D: usize>
50+
LinkMatrixBuilderType<'rng, 'lat, Rng, D>
51+
{
2652
pub fn into_link_matrix(self) -> LinkMatrix {
2753
match self {
2854
Self::Data(data) => LinkMatrix::new(data),
2955
Self::Generated(l, gen_type) => match gen_type {
3056
GenType::Cold => LinkMatrix::new_cold(l),
31-
GenType::Hot(rng) => LinkMatrix::new_deterministe(l, rng),
57+
GenType::HotDeterministe(rng) => LinkMatrix::new_deterministe(l, rng),
3258
// the unwrap is safe because n is non zero
3359
// there is a possibility to panic in a thread but very unlikly
3460
// (either something break in this API or in thread_rng())
61+
GenType::HotDeterministeOwned(mut rng_box) => {
62+
LinkMatrix::new_deterministe(l, &mut rng_box)
63+
}
3564
GenType::HotThreaded(n) => LinkMatrix::new_random_threaded(l, n.get()).unwrap(),
3665
},
3766
}
3867
}
3968
}
4069

41-
#[derive(Debug, PartialEq)]
70+
#[derive(Debug, PartialEq, Clone)]
4271
#[cfg_attr(feature = "serde-serialize", derive(Serialize))]
43-
pub struct LinkMatrixBuilder<'a, 'lat, Rng: rand::Rng + ?Sized, const D: usize> {
44-
builder_type: LinkMatrixBuilderType<'a, 'lat, Rng, D>,
72+
pub struct LinkMatrixBuilder<'rng, 'lat, Rng: rand::Rng + ?Sized, const D: usize> {
73+
builder_type: LinkMatrixBuilderType<'rng, 'lat, Rng, D>,
4574
}
4675

47-
impl<'a, 'lat, Rng: rand::Rng + ?Sized, const D: usize> LinkMatrixBuilder<'a, 'lat, Rng, D> {
76+
impl<'rng, 'lat, Rng: rand::Rng + ?Sized, const D: usize> LinkMatrixBuilder<'rng, 'lat, Rng, D> {
4877
pub fn new_from_data(data: Vec<CMatrix3>) -> Self {
4978
Self {
5079
builder_type: LinkMatrixBuilderType::Data(data),
@@ -57,7 +86,7 @@ impl<'a, 'lat, Rng: rand::Rng + ?Sized, const D: usize> LinkMatrixBuilder<'a, 'l
5786
}
5887
}
5988

60-
pub fn set_cold(&mut self) -> &mut Self {
89+
pub fn set_cold(mut self) -> Self {
6190
match self.builder_type {
6291
LinkMatrixBuilderType::Data(_) => {}
6392
LinkMatrixBuilderType::Generated(l, _) => {
@@ -67,17 +96,18 @@ impl<'a, 'lat, Rng: rand::Rng + ?Sized, const D: usize> LinkMatrixBuilder<'a, 'l
6796
self
6897
}
6998

70-
pub fn set_hot(&mut self, rng: &'a mut Rng) -> &mut Self {
99+
pub fn set_hot_deterministe(mut self, rng: &'rng mut Rng) -> Self {
71100
match self.builder_type {
72101
LinkMatrixBuilderType::Data(_) => {}
73102
LinkMatrixBuilderType::Generated(l, _) => {
74-
self.builder_type = LinkMatrixBuilderType::Generated(l, GenType::Hot(rng));
103+
self.builder_type =
104+
LinkMatrixBuilderType::Generated(l, GenType::HotDeterministe(rng));
75105
}
76106
}
77107
self
78108
}
79109

80-
pub fn set_hot_threaded(&mut self, number_of_threads: NonZeroUsize) -> &mut Self {
110+
pub fn set_hot_threaded(mut self, number_of_threads: NonZeroUsize) -> Self {
81111
match self.builder_type {
82112
LinkMatrixBuilderType::Data(_) => {}
83113
LinkMatrixBuilderType::Generated(l, _) => {
@@ -92,3 +122,103 @@ impl<'a, 'lat, Rng: rand::Rng + ?Sized, const D: usize> LinkMatrixBuilder<'a, 'l
92122
self.builder_type.into_link_matrix()
93123
}
94124
}
125+
126+
#[doc(hidden)]
127+
impl<'rng, 'lat, Rng: rand::Rng + ?Sized, const D: usize>
128+
From<LinkMatrixBuilderType<'rng, 'lat, Rng, D>> for LinkMatrixBuilder<'rng, 'lat, Rng, D>
129+
{
130+
fn from(builder_type: LinkMatrixBuilderType<'rng, 'lat, Rng, D>) -> Self {
131+
Self { builder_type }
132+
}
133+
}
134+
135+
impl<'rng, 'lat, Rng: rand::Rng + ?Sized, const D: usize>
136+
From<LinkMatrixBuilder<'rng, 'lat, Rng, D>> for LinkMatrix
137+
{
138+
fn from(builder: LinkMatrixBuilder<'rng, 'lat, Rng, D>) -> Self {
139+
builder.build()
140+
}
141+
}
142+
143+
#[cfg(test)]
144+
mod test {
145+
use std::num::NonZeroUsize;
146+
147+
use rand::rngs::StdRng;
148+
use rand::SeedableRng;
149+
150+
use super::*;
151+
use crate::error::LatticeInitializationError;
152+
153+
const SEED_RNG: u64 = 0x45_78_93_f4_4a_b0_67_f0;
154+
155+
#[test]
156+
fn builder() -> Result<(), LatticeInitializationError> {
157+
let lattice = LatticeCyclique::<3>::new(1_f64, 10)?;
158+
let m = LinkMatrixBuilder::<'_, '_, rand::rngs::ThreadRng, 3>::new_generated(&lattice)
159+
.set_cold()
160+
.build();
161+
assert_eq!(m, LinkMatrix::new_cold(&lattice));
162+
163+
let mut rng = StdRng::seed_from_u64(SEED_RNG);
164+
let builder = LinkMatrixBuilder::<'_, '_, _, 3>::new_generated(&lattice)
165+
.set_hot_deterministe(&mut rng);
166+
let m = builder.clone().build();
167+
assert_eq!(m, builder.build());
168+
let _ = LinkMatrixBuilder::<'_, '_, rand::rngs::ThreadRng, 3>::new_generated(&lattice)
169+
.set_hot_threaded(NonZeroUsize::new(rayon::current_num_threads().min(1)).unwrap())
170+
.build();
171+
assert!(LinkMatrixBuilder::<'_, '_, _, 3>::new_from_data(vec![])
172+
.set_cold()
173+
.set_hot_deterministe(&mut rng)
174+
.set_hot_threaded(NonZeroUsize::new(1).unwrap())
175+
.build()
176+
.is_empty());
177+
assert_eq!(
178+
LinkMatrixBuilder::<'_, '_, rand::rngs::ThreadRng, 3>::new_from_data(
179+
vec![CMatrix3::identity(); 5]
180+
)
181+
.build()
182+
.as_ref(),
183+
vec![CMatrix3::identity(); 5]
184+
);
185+
assert_eq!(
186+
LinkMatrix::from(
187+
LinkMatrixBuilder::<'_, '_, rand::rngs::ThreadRng, 3>::new_from_data(
188+
vec![CMatrix3::identity(); 100]
189+
)
190+
)
191+
.as_ref(),
192+
vec![CMatrix3::identity(); 100]
193+
);
194+
Ok(())
195+
}
196+
197+
#[test]
198+
fn gen_type() {
199+
let mut rng = StdRng::seed_from_u64(SEED_RNG);
200+
assert_eq!(
201+
GenType::<'_, StdRng>::Cold.clone(),
202+
GenType::<'_, StdRng>::Cold
203+
);
204+
assert_eq!(
205+
GenType::HotDeterministeOwned(Box::new(rng.clone())).clone(),
206+
GenType::HotDeterministeOwned(Box::new(rng.clone()))
207+
);
208+
assert_eq!(
209+
GenType::<'_, StdRng>::HotThreaded(NonZeroUsize::new(1).unwrap()).clone(),
210+
GenType::<'_, StdRng>::HotThreaded(NonZeroUsize::new(1).unwrap())
211+
);
212+
let gen_type = GenType::HotDeterministe(&mut rng);
213+
assert_ne!(gen_type.clone(), gen_type);
214+
}
215+
216+
#[test]
217+
fn trait_misc() {
218+
let builder_type = LinkMatrixBuilderType::<'_, '_, StdRng, 10>::Data(vec![]);
219+
assert_eq!(
220+
LinkMatrixBuilder::from(builder_type.clone()).builder_type,
221+
builder_type
222+
);
223+
}
224+
}

0 commit comments

Comments
 (0)