Skip to content

Commit

Permalink
Merge pull request #43 from ahuang11/add_progress
Browse files Browse the repository at this point in the history
Add progress bar
  • Loading branch information
ahuang11 authored Apr 16, 2024
2 parents 83c42bf + 979c148 commit ba8c00c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
10 changes: 9 additions & 1 deletion streamjoy/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,14 @@ def default_holoviews_renderer(
kwargs["toolbar"] = None
elif backend == "matplotlib":
kwargs["cbar_extend"] = kwargs.get("cbar_extend", "both")
hv_obj.opts(**kwargs)

if isinstance(hv_obj, hv.Overlay):
for hv_el in hv_obj:
try:
hv_el.opts(**kwargs)
except Exception:
pass
else:
hv_obj.opts(**kwargs)

return hv_obj
14 changes: 11 additions & 3 deletions streamjoy/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import imageio.v3 as iio
import numpy as np
import param
from dask.diagnostics import ProgressBar
from dask.distributed import Client, Future, fire_and_forget
from imageio.core.v3_plugin_api import PluginV3
from PIL import Image, ImageDraw
Expand Down Expand Up @@ -168,6 +169,11 @@ class MediaStream(param.Parameterized):
doc="The number of threads to use per worker.",
)

show_progress = param.Boolean(
default=True,
doc="Whether to show the progress bar when rendering.",
)

scratch_dir = param.Path(
doc="The directory to use for temporary files.", check_exists=False
)
Expand Down Expand Up @@ -201,6 +207,7 @@ def __init__(self, resources: list[Any] | None = None, **params) -> None:
params["_tbd_kwargs"][param_key] = params.pop(param_key)

super().__init__(**params)
self._progress_bar = ProgressBar(minimum=3 if self.show_progress else np.inf)

@classmethod
def from_numpy(
Expand Down Expand Up @@ -416,7 +423,7 @@ def _render_images(
f"got {resources=!r}."
) from exc

if renderer:
if renderer and not renderer.__name__.startswith("default"):
try:
iterable_0 = [iterable[0] for iterable in renderer_iterables]
renderer(resource_0, *iterable_0, **renderer_kwargs)
Expand Down Expand Up @@ -446,7 +453,8 @@ def _render_images(
renderer(resource, *iterable, **renderer_kwargs)
for resource, *iterable in zip_longest(resources, *renderer_iterables)
]
resources = dask.compute(jobs, scheduler="threads")[0]
with self._progress_bar:
resources = dask.compute(jobs, scheduler="threads")[0]
resource_0 = _utils.get_result(_utils.get_first(resources))

is_like_image = isinstance(resource_0, np.ndarray) and resource_0.ndim == 3
Expand Down Expand Up @@ -784,7 +792,7 @@ def _write_images(
buf.init_video_stream(self.codec, **init_kwargs)

if "crf" in write_kwargs:
buf._video_stream.options = {'crf': str(write_kwargs.pop("crf"))}
buf._video_stream.options = {"crf": str(write_kwargs.pop("crf"))}

intro_frame = self._create_intro(images)
self._prepend_intro(buf, intro_frame, **write_kwargs)
Expand Down

0 comments on commit ba8c00c

Please sign in to comment.