Skip to content

Commit c1ffbce

Browse files
committed
fix(cli)🐛: Fix synthetic data generation by patching alphabet handling
- Introduced a monkey patch for the `generate_random_sequences` function to ensure the custom alphabet is used. - Removed the alphabet parameter from task-specific parameters to prevent conflicts. - Ensured the original function is restored after dataset generation or upon encountering errors.
1 parent 75336e1 commit c1ffbce

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,5 @@ docs/cli.md
152152
message_log.db
153153
catboost_info/*
154154
examples/output/*
155+
*.pkl
156+
data/

fast_seqfunc/cli.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,29 @@ def generate_synthetic(
342342
if task != "length_dependent":
343343
task_params["length"] = sequence_length
344344

345+
# We need to patch the generate_random_sequences function to use our alphabet
346+
# This approach uses monkey patching to avoid having to modify all task functions
347+
original_generate_random_sequences = synthetic.generate_random_sequences
348+
349+
def patched_generate_random_sequences(*args, **kwargs):
350+
"""
351+
Patched version of `generate_random_sequences` that uses a custom alphabet.
352+
353+
This function overrides the alphabet parameter with our custom alphabet while
354+
preserving all other parameters passed to the original function.
355+
356+
:param args: Positional arguments to pass to the original function
357+
:param kwargs: Keyword arguments to pass to the original function
358+
:return: Result from the original generate_random_sequences function
359+
"""
360+
# Override the alphabet parameter with our custom alphabet,
361+
# but keep other parameters
362+
kwargs["alphabet"] = alphabet
363+
return original_generate_random_sequences(*args, **kwargs)
364+
365+
# Replace the function temporarily
366+
synthetic.generate_random_sequences = patched_generate_random_sequences
367+
345368
# Add task-specific parameters based on the task type
346369
if task == "motif_position":
347370
# Use custom motif if provided
@@ -395,9 +418,6 @@ def generate_synthetic(
395418
task_params["min_length"] = min_length
396419
task_params["max_length"] = max_length
397420

398-
# Add alphabet parameter to all tasks
399-
task_params["alphabet"] = alphabet
400-
401421
# Validate the task
402422
valid_tasks = [
403423
"g_count",
@@ -417,6 +437,11 @@ def generate_synthetic(
417437
)
418438
raise typer.Exit(1)
419439

440+
# The task functions don't directly accept an alphabet parameter
441+
# so we need to remove it from task_params
442+
if "alphabet" in task_params:
443+
del task_params["alphabet"]
444+
420445
# Generate the dataset
421446
try:
422447
df = synthetic.generate_dataset_by_task(
@@ -433,6 +458,8 @@ def generate_synthetic(
433458
output_path = output_dir / f"{file_prefix}{task}_data.csv"
434459
df.to_csv(output_path, index=False)
435460
logger.info(f"Saved full dataset to {output_path}")
461+
# Restore original function
462+
synthetic.generate_random_sequences = original_generate_random_sequences
436463
return
437464

438465
# Validate split ratios
@@ -489,6 +516,9 @@ def generate_synthetic(
489516
except Exception as e:
490517
logger.error(f"Error generating synthetic data: {e}")
491518
raise typer.Exit(1)
519+
finally:
520+
# Make sure to restore the original function even if an error occurs
521+
synthetic.generate_random_sequences = original_generate_random_sequences
492522

493523

494524
@app.command()

0 commit comments

Comments
 (0)