Skip to content

Commit 477cfcc

Browse files
authored
Merge pull request #1057 from parea-ai/fix-deepcopy-fails-in-esp
fix: don't duplicate inputs in experiment
2 parents 9df911a + e42afc6 commit 477cfcc

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

parea/experiment/experiment.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import os
77
from collections import defaultdict
88
from concurrent.futures import ThreadPoolExecutor
9-
from copy import deepcopy
109
from functools import partial
1110
from urllib.parse import quote
1211

@@ -119,7 +118,11 @@ async def experiment(
119118
len_test_cases = len(data) if isinstance(data, list) else 0
120119

121120
if n_trials > 1:
122-
data = duplicate_dicts(data, n_trials)
121+
try:
122+
data = duplicate_dicts(data, n_trials)
123+
except TypeError as e:
124+
logger.error(f"Error duplicating input data. You need to manually duplicate input data and set n_trials=1. \n", exc_info=e)
125+
raise e
123126
len_test_cases = len(data) if isinstance(data, list) else 0
124127
print(f"Running {n_trials} trials of the experiment \n")
125128

@@ -129,14 +132,12 @@ async def experiment(
129132

130133
async def limit_concurrency(sample):
131134
async with sem:
132-
sample_copy = deepcopy(sample)
133-
target = sample_copy.pop("target", None)
134-
return await func(_parea_target_field=target, **sample_copy)
135+
kwargs = {"_parea_target_field": sample.get("target", None), **{k: v for k, v in sample.items() if k != "target"}}
136+
return await func(**kwargs)
135137

136138
def limit_concurrency_sync(sample):
137-
sample_copy = deepcopy(sample)
138-
target = sample_copy.pop("target", None)
139-
return func(_parea_target_field=target, **sample_copy)
139+
kwargs = {"_parea_target_field": sample.get("target", None), **{k: v for k, v in sample.items() if k != "target"}}
140+
return func(**kwargs)
140141

141142
if inspect.iscoroutinefunction(func):
142143
tasks = [asyncio.ensure_future(limit_concurrency(sample)) for sample in data]
@@ -158,7 +159,7 @@ def limit_concurrency_sync(sample):
158159
import traceback
159160

160161
traceback.print_exc()
161-
print(f"\nExperiment stopped due to an error (note you can deactivate this behavior by setting stop_on_error=False): {str(e)}\n")
162+
logger.error(f"\nExperiment stopped due to an error (note you can deactivate this behavior by setting stop_on_error=False): {str(e)}\n", exc_info=e)
162163
for task in tasks:
163164
task.cancel()
164165
else:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
66
[tool.poetry]
77
name = "parea-ai"
88
packages = [{ include = "parea" }]
9-
version = "0.2.205"
9+
version = "0.2.206"
1010
description = "Parea python sdk"
1111
readme = "README.md"
1212
authors = ["joel-parea-ai <[email protected]>"]

0 commit comments

Comments
 (0)