Skip to content

Commit

Permalink
Pipeline: cleanup and improve docs (#443)
Browse files Browse the repository at this point in the history
* add documentation

* remove parameter binary_output because it was not used

* remove parameter "shuffle" from dataloader_params because it would collide with fixed shuffle=False
  • Loading branch information
ArneBinder authored Jan 13, 2025
1 parent 45f1469 commit 600d4b7
Showing 1 changed file with 42 additions and 7 deletions.
49 changes: 42 additions & 7 deletions src/pytorch_ie/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,15 @@ class Pipeline:
Pipeline supports running on CPU or GPU through the device argument (see below).
Some pipeline, like for instance :class:`~transformers.FeatureExtractionPipeline` (:obj:`'feature-extraction'` )
output large tensor object as nested-lists. In order to avoid dumping such large structure as textual data we
provide the :obj:`binary_output` constructor argument. If set to :obj:`True`, the output will be stored in the
pickle format.
Args:
model (:class:`~pytorch_ie.PyTorchIEModel`):
The deep learning model to use for the pipeline.
taskmodule (:class:`~pytorch_ie.TaskModule`): The taskmodule to use for encoding
and decoding the documents.
device (:obj:`Union[int, str]`, `optional`, defaults to :obj:`"cpu"`):
The device to run the pipeline on. This can be a CPU device (:obj:`"cpu"`), a GPU
device (:obj:`"cuda"`) or a specific GPU device (:obj:`"cuda:X"`, where :obj:`X`
is the index of the GPU).
"""

default_input_names = None
Expand All @@ -50,15 +55,13 @@ def __init__(
model: PyTorchIEModel,
taskmodule: TaskModule,
device: Union[int, str] = "cpu",
binary_output: bool = False,
**kwargs,
):
self.taskmodule = taskmodule
device_str = (
("cpu" if device < 0 else f"cuda:{device}") if isinstance(device, int) else device
)
self.device = torch.device(device_str)
self.binary_output = binary_output

# Module.to() returns just self, but moved to the device. This is not correctly
# reflected in typing of PyTorch.
Expand Down Expand Up @@ -192,7 +195,7 @@ def _sanitize_parameters(
forward_parameters[p_name] = pipeline_parameters[p_name]

# set dataloader parameters
for p_name in ["batch_size", "num_workers", "shuffle"]:
for p_name in ["batch_size", "num_workers"]:
if p_name in pipeline_parameters:
dataloader_params[p_name] = pipeline_parameters[p_name]

Expand Down Expand Up @@ -299,6 +302,38 @@ def __call__(
*args,
**kwargs,
) -> Union[Document, Sequence[Document]]:
"""
The __call__ method is the entry point for the pipeline. It will run the pipeline workflow in the following
order:
1. Encode the documents
2. Run the model forward pass(es) on the encodings
3. Combine the model outputs with the inputs encodings and integrate them back into the documents
Args:
documents (:obj:`Union[Document, Sequence[Document]]`): The documents to process. If a single document is
passed, the output will be a single document. If a list of documents is passed, the output will be a
list of documents.
document_batch_size (:obj:`int`, `optional`): The batch size to use for encoding the documents with the
taskmodule. If not provided, the default batch size of the taskmodule will be used.
show_progress_bar (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to show a progress bar
during inference.
fast_dev_run (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to run a fast development
run. If set to :obj:`True`, only the first two model inputs will be processed.
batch_size (:obj:`int`, `optional`, defaults to :obj:`1`): The batch size to use for the dataloader. If not
provided, a batch size of 1 will be used.
num_workers (:obj:`int`, `optional`, defaults to :obj:`8`): The number of workers to use for the dataloader.
If not provided, 8 workers will be used.
inplace (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to modify the input documents
in place. Requires the input to be a mutable sequence of documents or a single document.
Note that all the arguments except `documents` can be set in the `__init__` method and/or overridden in the
`__call__` method.
Returns:
:obj:`Union[Document, Sequence[Document]]`: The processed documents. If a single document was passed, a
single document will be returned. If a list of documents was passed, a list of documents will be returned.
"""
if args:
logger.warning(f"Ignoring args : {args}")
(
Expand Down

0 comments on commit 600d4b7

Please sign in to comment.