Skip to content

Commit

Permalink
Option for asynchronous writes in TarWriter.
Browse files Browse the repository at this point in the history
  • Loading branch information
NikoOinonen committed Mar 25, 2024
1 parent b0b5b2a commit 9e9cc8f
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 38 deletions.
99 changes: 69 additions & 30 deletions mlspm/data_generation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import io
import multiprocessing as mp
import os
import queue
import tarfile
import time
from multiprocessing.shared_memory import SharedMemory
from os import PathLike
from pathlib import Path
from typing import Optional, TypedDict
import warnings

import numpy as np
from PIL import Image
Expand All @@ -19,65 +21,87 @@ class TarWriter:
:meth:`add_sample`.
Each tar file has a maximum number of samples, and whenever that maximum is reached, a new tar file is created.
The generated tar files are named as ``{base_name}_{n}.tar`` and saved into the specified folder. The current tar file
handle is always available in the attribute :attr:`ft`, and is automatically closed when the context ends.
The generated tar files are named as ``{base_name}_{n}.tar`` and saved into the specified folder.
Arguments:
base_path: Path to directory where tar files are saved.
base_name: Base name for output tar files. The number of the tar file is appended to the name.
max_count: Maximum number of samples per tar file.
png_compress_level: Compression level 1-9 for saved png images. Larger value for smaller file size but slower
write speed.
async_write: Write tar files asynchronously in a parallel process.
"""

def __init__(self, base_path: PathLike = "./", base_name: str = "", max_count: int = 100, png_compress_level=4):
def __init__(self, base_path: PathLike = "./", base_name: str = "", max_count: int = 100, async_write=True):
self.base_path = Path(base_path)
self.base_name = base_name
self.max_count = max_count
self.png_compress_level = png_compress_level
self.async_write = async_write

def __enter__(self):
self.sample_count = 0
self.total_count = 0
self.tar_count = 0
self.ft = self._get_tar_file()
if self.async_write:
self._launch_write_process()
else:
self._ft = self._get_tar_file()
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
self.ft.close()
if self.async_write:
self._event_done.set()
if not self._event_tar_close.wait(60):
warnings.warn("Write process did not respond within timeout period. Last tar file may not have been closed properly.")
else:
self._ft.close()

def _launch_write_process(self):
    """
    Create the queue and events shared with the writer process, then start that process.

    The process runs :meth:`_write_async`. The queue holds at most one pending sample.
    """
    # The shared objects are created before the process is started so that they already exist on ``self``
    # when ``_write_async`` begins running.
    self._event_tar_close = mp.Event()  # Set by the writer after it has closed its tar file
    self._event_done = mp.Event()  # Set in __exit__ to tell the writer that no more samples are coming
    self._q = mp.Queue(maxsize=1)
    mp.Process(target=self._write_async).start()

def _write_async(self):
    """
    Entry point of the writer process: take samples off the queue and write them to tar files.

    Runs until ``_event_done`` has been set and a subsequent poll of the queue finds it empty. The current tar
    file is closed and ``_event_tar_close`` is set on every exit path, including errors, so that a waiting
    ``__exit__`` is released. Errors are re-raised after the cleanup instead of being discarded.
    """
    poll_interval = 0.1  # Seconds to block on the queue before re-checking the done flag

    try:
        # Opening the first tar file is inside the outer ``try`` so that a failure here still sets the event.
        self._ft = self._get_tar_file()
        try:
            while True:
                # Read the flag *before* polling. A sample that was enqueued before the flag was set then gets
                # a full poll interval to arrive, instead of racing against an instantaneous emptiness check.
                done = self._event_done.is_set()
                try:
                    # A blocking get with a timeout keeps the process idle while waiting, rather than spinning.
                    sample = self._q.get(timeout=poll_interval)
                except queue.Empty:
                    if done:
                        break
                    continue
                self._add_sample(*sample)
        finally:
            # _add_sample may have rotated to a new tar file, so close whichever one is current.
            self._ft.close()
    finally:
        self._event_tar_close.set()

def _get_tar_file(self):
    """
    Create and open the tar file for the current value of ``tar_count``.

    The file is created at ``{base_path}/{base_name}_{tar_count}.tar`` in GNU tar format.

    Returns:
        Open handle to the new, empty tar file.

    Raises:
        RuntimeError: If something already exists at the target path.
    """
    file_path = self.base_path / f"{self.base_name}_{self.tar_count}.tar"
    try:
        # Exclusive-create mode makes the existence check and the file creation a single step, so a file that
        # appears between a separate check and the open cannot be silently truncated.
        return tarfile.open(file_path, "x", format=tarfile.GNU_FORMAT)
    except FileExistsError:
        raise RuntimeError(f"Tar file already exists at `{file_path}`") from None

def add_sample(self, X: list[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarray] = None, comment_str: str = ""):
"""
Add a sample to the current tar file.
Arguments:
X: AFM images. Each list item corresponds to an AFM tip and is an array of shape (nx, ny, nz).
xyzs: Atom coordinates and elements. Each row is one atom and is of the form [x, y, z, element].
Y: Image descriptors. Each list item is one descriptor and is an array of shape (nx, ny).
comment_str: Comment line (second line) to add to the xyz file.
"""
def _add_sample(self, X, xyzs, Y, comment_str):

if self.sample_count >= self.max_count:
self.tar_count += 1
self.sample_count = 0
self.ft.close()
self.ft = self._get_tar_file()
self._ft.close()
self._ft = self._get_tar_file()

# Write AFM images
for i, x in enumerate(X):
for j in range(x.shape[-1]):
xj = x[:, :, j]
xj = ((xj - xj.min()) / np.ptp(xj) * (2**8 - 1)).astype(np.uint8) # Convert range to 0-255 integers
img_bytes = io.BytesIO()
Image.fromarray(xj.T[::-1], mode="L").save(img_bytes, "png", compress_level=self.png_compress_level)
Image.fromarray(xj.T[::-1], mode="L").save(img_bytes, "png")
img_bytes.seek(0) # Return stream to start so that addfile can read it correctly
self.ft.addfile(get_tarinfo(f"{self.total_count}.{j:02d}.{i}.png", img_bytes), img_bytes)
self._ft.addfile(get_tarinfo(f"{self.total_count}.{j:02d}.{i}.png", img_bytes), img_bytes)
img_bytes.close()

# Write xyz file
Expand All @@ -89,7 +113,7 @@ def add_sample(self, X: list[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarr
xyz_bytes.write(bytearray(f"{xyz[i]:10.8f}\t", "utf-8"))
xyz_bytes.write(bytearray("\n", "utf-8"))
xyz_bytes.seek(0) # Return stream to start so that addfile can read it correctly
self.ft.addfile(get_tarinfo(f"{self.total_count}.xyz", xyz_bytes), xyz_bytes)
self._ft.addfile(get_tarinfo(f"{self.total_count}.xyz", xyz_bytes), xyz_bytes)
xyz_bytes.close()

# Write image descriptors (if any)
Expand All @@ -98,12 +122,27 @@ def add_sample(self, X: list[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarr
img_bytes = io.BytesIO()
np.save(img_bytes, y.astype(np.float32))
img_bytes.seek(0) # Return stream to start so that addfile can read it correctly
self.ft.addfile(get_tarinfo(f"{self.total_count}.desc_{i}.npy", img_bytes), img_bytes)
self._ft.addfile(get_tarinfo(f"{self.total_count}.desc_{i}.npy", img_bytes), img_bytes)
img_bytes.close()

self.sample_count += 1
self.total_count += 1

def add_sample(self, X: list[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarray] = None, comment_str: str = ""):
    """
    Add a sample to the current tar file.

    With ``async_write`` enabled, the sample is put on the queue read by the writer process and this call
    blocks for at most 60 seconds while the queue is full. Otherwise the sample is written directly.

    Arguments:
        X: AFM images. Each list item corresponds to an AFM tip and is an array of shape (nx, ny, nz).
        xyzs: Atom coordinates and elements. Each row is one atom and is of the form [x, y, z, element].
        Y: Image descriptors. Each list item is one descriptor and is an array of shape (nx, ny).
        comment_str: Comment line (second line) to add to the xyz file.

    Raises:
        queue.Full: In asynchronous mode, if the queue does not free up within the timeout.
    """
    if not self.async_write:
        self._add_sample(X, xyzs, Y, comment_str)
        return

    timeout = 60  # Seconds to wait for room on the queue
    # NOTE(review): multiprocessing queues serialize items in a background thread, so arrays that the caller
    # modifies in place right after this call might be written in their modified state - confirm that callers
    # do not reuse these buffers.
    try:
        self._q.put((X, xyzs, Y, comment_str), block=True, timeout=timeout)
    except queue.Full:
        # Same exception type as before, but with an explanation instead of an empty message.
        raise queue.Full(
            f"Could not hand the sample to the write process within {timeout} seconds. "
            "The write process may have stalled or stopped."
        ) from None


def get_tarinfo(fname: str, file_bytes: io.BytesIO):
info = tarfile.TarInfo(fname)
Expand All @@ -128,13 +167,12 @@ class TarSampleList(TypedDict, total=False):

class TarDataGenerator:
"""
Iterable that loads data from tar archives with data saved in npz format for generating samples
with the GeneratorAFMTrainer in ppafm.
Iterable that loads data from tar archives with data saved in npz format for generating samples with ``GeneratorAFMTrainer``
in *ppafm*.
The npz files should contain the following entries:
- ``'data'``: An array containing the potential/density on a 3D grid. The potential is assumed to be in
units of eV and density in units of e/Å^3.
- ``'data'``: An array containing the potential/density on a 3D grid.
- ``'origin'``: Lattice origin in 3D space as an array of shape ``(3,)``.
- ``'lattice'``: Lattice vectors as an array of shape ``(3, 3)``, where the rows are the vectors.
- ``'xyz'``: Atom xyz coordinates as an array of shape ``(n_atoms, 3)``.
Expand All @@ -148,8 +186,9 @@ class TarDataGenerator:
- ``'rho_sample'``: Sample electron density if the sample dict contained ``rho``, or ``None`` otherwise.
- ``'rot'``: Rotation matrix.
Note: it is recommended to use ``multiprocessing.set_start_method('spawn')`` when using the :class:`TarDataGenerator`.
Otherwise a lot of warnings about leaked memory objects may be thrown on exit.
Note:
It is recommended to use ``multiprocessing.set_start_method('spawn')`` when using the :class:`TarDataGenerator`.
Otherwise a lot of warnings about leaked memory objects may be thrown on exit.
Arguments:
samples: List of sample dicts as :class:`TarSampleList`. File paths should be relative to ``base_path``.
Expand Down
5 changes: 1 addition & 4 deletions papers/asd-afm/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,7 @@ def on_sample_start(self):
if i % 100 == 0:
elapsed = time.perf_counter() - start_gen
eta = elapsed / (i + 1) * (len(generator) - i)
print(
f"{mode} sample {i}/{len(generator)}, writing to `{tar_writer.ft.name}`, "
f"Elapsed: {elapsed:.2f}s, ETA: {eta:.2f}s"
)
print(f"{mode} sample {i}/{len(generator)}, Elapsed: {elapsed:.2f}s, ETA: {eta:.2f}s")

print(f"Done with {mode} - Elapsed time: {time.perf_counter() - start_gen:.1f}")

Expand Down
5 changes: 1 addition & 4 deletions papers/ed-afm/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,7 @@ def handle_distance(self):
if i % 100 == 0:
elapsed = time.perf_counter() - start_gen
eta = elapsed / (i + 1) * (len(generator) - i)
print(
f"{mode} sample {i}/{len(generator)}, writing to `{tar_writer.ft.name}`, "
f"Elapsed: {elapsed:.2f}s, ETA: {eta:.2f}s"
)
print(f"{mode} sample {i}/{len(generator)}, Elapsed: {elapsed:.2f}s, ETA: {eta:.2f}s")

print(f"Done with {mode} - Elapsed time: {time.perf_counter() - start_gen:.1f}")

Expand Down

0 comments on commit 9e9cc8f

Please sign in to comment.