Make MultipleSourceLoader lazy and fix its use of fusion #1602

Merged: 4 commits, Feb 12, 2025
Changes from all commits
26 changes: 16 additions & 10 deletions src/unitxt/loaders.py
@@ -177,6 +177,9 @@ def process(self) -> MultiStream:
         self._maybe_set_classification_policy()
         return self.add_data_classification(self.load_data())

+    def get_splits(self):
+        return list(self().keys())
+
 class LazyLoader(Loader):
     split: Optional[str] = NonPositionalField(default=None)
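
For orientation (not part of the diff): per the lines added above, the base Loader.get_splits simply calls the loader and lists the keys of the resulting MultiStream, so it is an eager fallback; lazy subclasses are expected to override it, which is what MultipleSourceLoader does further down. A minimal sketch of that default behavior, assuming the API shown in this diff and using a LoadCSV source like the one in the updated test (the CSV path is hypothetical):

    # Sketch only: the default get_splits is equivalent to listing the split
    # names of the MultiStream the loader produces.
    from unitxt.loaders import LoadCSV

    loader = LoadCSV(files={"test": "some_local_file.csv"})  # hypothetical path
    print(loader.get_splits())    # ["test"]
    print(list(loader().keys()))  # ["test"], which is what the base method computes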

@@ -742,7 +745,7 @@ def load_iterables(self):
         return dataset


-class MultipleSourceLoader(Loader):
+class MultipleSourceLoader(LazyLoader):
     """Allows loading data from multiple sources, potentially mixing different types of loaders.

     Args:
@@ -766,20 +769,23 @@ class MultipleSourceLoader(Loader):

     sources: List[Loader]

     # MultipleSourceLoader uses the data classification from its source loaders,
     # so it only needs to be added here if explicitly requested to override.
     def add_data_classification(self, multi_stream: MultiStream) -> MultiStream:
         if self.data_classification_policy is None:
             return multi_stream
         return super().add_data_classification(multi_stream)

-    def load_iterables(self):
-        pass
-
-    def load_data(self):
-        return FixedFusion(
-            subsets=self.sources, max_instances_per_subset=self.get_limit()
-        ).process()
+    def get_splits(self):
+        splits = []
+        for loader in self.sources:
+            splits.extend(loader.get_splits())
+        return list(set(splits))
+
+    def split_generator(self, split: str) -> Generator[Any, None, None]:
+        yield from FixedFusion(
+            subsets=self.sources,
+            max_instances_per_subset=self.get_limit(),
+            include_splits=[split],
+        )()[split]


 class LoadFromDictionary(Loader):
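
To see the change end to end, here is a usage sketch assuming the API shown in this diff (class, field, and method names come from the diff; the CSV paths are hypothetical). get_splits now only asks each source loader for its split names, and iterating one split drives split_generator, which builds a FixedFusion restricted to that split via include_splits:

    from unitxt.loaders import LoadCSV, MultipleSourceLoader

    loader = MultipleSourceLoader(
        sources=[
            LoadCSV(files={"test": "first.csv"}),        # hypothetical path
            LoadCSV(files={"test": "second.csv"}),       # hypothetical path
            LoadCSV(files={"demos_pool": "extra.csv"}),  # hypothetical path
        ]
    )

    # Union of the split names declared by the sources.
    print(loader.get_splits())  # e.g. ["test", "demos_pool"] (order not guaranteed)

    # Fusion is restricted to the requested split; only instances from sources
    # that declare a "test" split are yielded here.
    for instance in loader()["test"]:
        print(instance)

This is the behavior the updated test below pins down: sources that share a split name are concatenated into a single stream, while a split that only one source provides ("demos_pool") simply appears alongside it.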
2 changes: 2 additions & 0 deletions tests/library/test_loaders.py
@@ -303,9 +303,11 @@ def test_multiple_source_loader(self):
             sources=[
                 LoadCSV(files={"test": files["train"]}),
                 LoadCSV(files={"test": files["test"]}),
+                LoadCSV(files={"demos_pool": files["train"]}),
             ]
         )
         ms = loader()
+        self.assertSetEqual(set(ms.keys()), {"demos_pool", "test"})
         assert len(dfs["test"]) + len(dfs["train"]) == len(list(ms["test"]))

     def test_load_from_dictionary(self):
4 changes: 2 additions & 2 deletions utils/.secrets.baseline
@@ -151,7 +151,7 @@
         "filename": "src/unitxt/loaders.py",
         "hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
         "is_verified": false,
-        "line_number": 592,
+        "line_number": 595,
         "is_secret": false
       }
     ],
@@ -184,5 +184,5 @@
       }
     ]
   },
-  "generated_at": "2025-02-10T14:10:51Z"
+  "generated_at": "2025-02-12T09:37:42Z"
 }