Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/datacustomcode/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@ class DataAccessLayerCalls(pydantic.BaseModel):
def validate_access_layer(self) -> DataAccessLayerCalls:
if self.read_dlo and self.read_dmo:
raise ValueError("Cannot read from DLO and DMO in the same file.")
if len(self.write_to_dlo) > 1 or len(self.write_to_dmo) > 1:
raise ValueError(
"Cannot write to more than one DLO or DMO in the same file."
)
if not self.read_dlo and not self.read_dmo:
raise ValueError("Must read from at least one DLO or DMO.")
if self.read_dlo and self.write_to_dmo:
Expand Down
22 changes: 16 additions & 6 deletions tests/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def test_read_dmo_write_dlo_throws_error(self):
finally:
os.unlink(temp_path)

def test_invalid_multiple_writes(self):
def test_multiple_writes(self):
"""Test scanning a file with multiple write operations."""
content = textwrap.dedent(
"""
Expand All @@ -258,15 +258,25 @@ def test_invalid_multiple_writes(self):
# Read from DLO
df = client.read_dlo("input_dlo")

# Write to multiple DLOs - invalid
client.write_to_dlo("output_dlo_1", df, "overwrite")
client.write_to_dlo("output_dlo_2", df, "overwrite")
# Transform data for different outputs
df_filtered = df.filter(df.col > 10)
df_aggregated = df.groupBy("category").agg({"value": "sum"})

# Write to multiple DLOs
client.write_to_dlo("output_filtered", df_filtered, "overwrite")
client.write_to_dlo("output_aggregated", df_aggregated, "overwrite")
"""
)
temp_path = create_test_script(content)
try:
with pytest.raises(ValueError, match="Cannot write to more than one DLO"):
scan_file(temp_path)
result = scan_file(temp_path)
assert "input_dlo" in result.read_dlo
assert "output_filtered" in result.write_to_dlo
assert "output_aggregated" in result.write_to_dlo
assert len(result.read_dlo) == 1
assert len(result.write_to_dlo) == 2
assert len(result.read_dmo) == 0
assert len(result.write_to_dmo) == 0
finally:
os.unlink(temp_path)

Expand Down