Skip to content

Commit

Permalink
fix: fix type checking in utils.load_pipeline()
Browse files Browse the repository at this point in the history
  • Loading branch information
pwwang committed Aug 11, 2024
1 parent 66d885a commit 200cddc
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 16 deletions.
26 changes: 10 additions & 16 deletions pipen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,15 +667,15 @@ async def load_pipeline(
try:
if isinstance(obj, str):
obj = _get_obj_from_spec(obj)
if isinstance(obj, Pipen) or (
isinstance(obj, type) and issubclass(obj, (Pipen, Proc, ProcGroup))
):
pass
else:
raise TypeError(
"Expected a Pipen, Proc, ProcGroup class, or a Pipen object, "
f"got {type(obj)}"
)
if isinstance(obj, Pipen) or (
isinstance(obj, type) and issubclass(obj, (Pipen, Proc, ProcGroup))
):
pass
else:
raise TypeError(
"Expected a Pipen, Proc, ProcGroup class, or a Pipen object, "
f"got {type(obj)}"
)

pipeline = obj
if isinstance(obj, type) and issubclass(obj, Proc):
Expand All @@ -689,15 +689,9 @@ async def load_pipeline(
# Avoid "pipeline" to be used as pipeline name by varname
(pipeline, ) = (obj(**kwargs), ) # type: ignore

else: # obj is a Pipen instance
elif isinstance(obj, Pipen):
pipeline._kwargs.update(kwargs)

if not isinstance(pipeline, Pipen):
raise TypeError(
"Expected a Pipen, Proc or ProcGroup class, "
f"got {type(pipeline)}"
)

# Initialize the pipeline so that the arguments definied by
# other plugins (i.e. pipen-args) to take in place.
pipeline.workdir = Path(pipeline.config.workdir).joinpath(
Expand Down
11 changes: 11 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,14 @@ def create_dead_link(path):
link.unlink()
link.symlink_to(target)
target.unlink()


# for load_pipeline tests
pipeline = Pipen(
name=f"simple_pipeline_{Pipen.PIPELINE_COUNT + 1}",
desc="No description",
loglevel="debug",
cache=True,
workdir=gettempdir() + "/.pipen",
outdir=gettempdir() + f"/pipen_simple_{Pipen.PIPELINE_COUNT}",
).set_starts(SimpleProc)
7 changes: 7 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,13 @@ class P1(Proc):
assert len(pipeline.procs) == 1


@pytest.mark.forked
@pytest.mark.asyncio
async def test_load_pipeline_pipen_object(tmp_path):
p = await load_pipeline(f"{HERE}/helpers.py:pipeline", a=1)
assert p._kwargs["a"] == 1


@pytest.mark.forked
# To avoid: Another plugin named simpleplugin has already been registered.
@pytest.mark.asyncio
Expand Down

0 comments on commit 200cddc

Please sign in to comment.