Skip to content

Commit

Permalink
Update potential/density buffers without reallocation.
Browse files Browse the repository at this point in the history
  • Loading branch information
NikoOinonen committed Mar 22, 2024
1 parent a64a559 commit fd88880
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions mlspm/data_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ def __init__(self, samples: list[TarSampleList], base_path: PathLike = "./", n_p
self.samples = samples
self.base_path = Path(base_path)
self.n_proc = n_proc
self.pot = None
self.rho = None

def __len__(self) -> int:
"""Total number of samples (including rotations)"""
Expand Down Expand Up @@ -292,24 +294,40 @@ def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: m

def _get_queue_sample(self):

if self._timings:
t0 = time.perf_counter()

i_proc, sample_id_pot, sample_id_rho, pot_shape, rho_shape, lvec_pot, lvec_rho, xyzs, Zs, rots = self.q.get(timeout=200)

if self._timings:
t1 = time.perf_counter()

shm_pot = mp.shared_memory.SharedMemory(sample_id_pot)
pot = np.ndarray(pot_shape, dtype=np.float32, buffer=shm_pot.buf)
pot = HartreePotential(pot, lvec_pot)
# This starts a copy to the OpenCL device. Better to start it here so that buffer preparation is instant during the simulation.
pot.cl_array
if self.pot is None:
self.pot = HartreePotential(pot, lvec_pot)
else:
self.pot.update_array(pot, lvec_pot)

if self._timings:
t2 = time.perf_counter()

if sample_id_rho is not None:
shm_rho = mp.shared_memory.SharedMemory(sample_id_rho)
rho = np.ndarray(rho_shape, dtype=np.float32, buffer=shm_rho.buf)
rho = ElectronDensity(rho, lvec_rho)
rho.cl_array
if self.rho is None:
self.rho = ElectronDensity(rho, lvec_rho)
else:
self.rho.update_array(pot, lvec_pot)
else:
shm_rho = None
rho = None

return i_proc, xyzs, Zs, rots, pot, shm_pot, rho, shm_rho, sample_id_pot
if self._timings:
t3 = time.perf_counter()
print(f"[Main, receive data, id {sample_id_pot}] Queue / Pot / Rho: " f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f}")

return i_proc, xyzs, Zs, rots, self.pot, shm_pot, self.rho, shm_rho, sample_id_pot

def _yield_samples(self):

Expand Down

0 comments on commit fd88880

Please sign in to comment.