Skip to content

Commit

Permalink
draft for scenario duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
Toan Quach committed Dec 27, 2024
1 parent 2e143bb commit 658f02f
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 8 deletions.
30 changes: 23 additions & 7 deletions taipy/core/data/_data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ class _DataManager(_Manager[DataNode], _VersionMixin):
_EVENT_ENTITY_TYPE = EventEntityType.DATA_NODE
_repository: _DataFSRepository

@classmethod
def _get_owner_id(
cls, scope, cycle_id, scenario_id
) -> Union[Optional[SequenceId], Optional[ScenarioId], Optional[CycleId]]:
if scope == Scope.SCENARIO:
return scenario_id
elif scope == Scope.CYCLE:
return cycle_id
else:
return None

@classmethod
def _bulk_get_or_create(
cls,
Expand All @@ -48,13 +59,7 @@ def _bulk_get_or_create(
dn_configs_and_owner_id = []
for dn_config in data_node_configs:
scope = dn_config.scope
owner_id: Union[Optional[SequenceId], Optional[ScenarioId], Optional[CycleId]]
if scope == Scope.SCENARIO:
owner_id = scenario_id
elif scope == Scope.CYCLE:
owner_id = cycle_id
else:
owner_id = None
owner_id = cls._get_owner_id(scope, cycle_id, scenario_id)
dn_configs_and_owner_id.append((dn_config, owner_id))

data_nodes = cls._repository._get_by_configs_and_owner_ids(
Expand Down Expand Up @@ -174,3 +179,14 @@ def _get_by_config_id(cls, config_id: str, version_number: Optional[str] = None)
for fil in filters:
fil.update({"config_id": config_id})
return cls._repository._load_all(filters)

@classmethod
def _clone(
cls, dn: DataNode, cycle_id: Optional[CycleId] = None, scenario_id: Optional[ScenarioId] = None
) -> DataNode:
dn.id = dn._new_id(dn._config_id)
dn._owner_id = cls._get_owner_id(dn._scope, cycle_id, scenario_id)
dn._parent_ids = set()
cls._set(dn)
# dn._clone_data()
return dn
42 changes: 42 additions & 0 deletions taipy/core/scenario/_scenario_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,45 @@ def _get_by_config_id(cls, config_id: str, version_number: Optional[str] = None)
for fil in filters:
fil.update({"config_id": config_id})
return cls._repository._load_all(filters)

@classmethod
def _clone(cls, scenario: Scenario) -> Scenario:
"""
Clone a scenario.
Arguments:
scenario (Scenario): The scenario to clone.
Returns:
Scenario: The cloned scenario.
"""
scenario.id = scenario._new_id(scenario.config_id)
# TODO: update sequences

# Clone tasks and data nodes
_task_manager = _TaskManagerFactory._build_manager()
_data_manager = _DataManagerFactory._build_manager()

cloned_tasks = set()
for task in scenario.tasks.values():
cloned_tasks.add(_task_manager._clone(task, None, scenario.id))
scenario._tasks = cloned_tasks

cloned_additional_data_nodes = set()
for data_node in scenario.additional_data_nodes.values():
cloned_additional_data_nodes.add(_data_manager._clone(data_node, None, scenario.id))
scenario._additional_data_nodes = cloned_additional_data_nodes

for task in cloned_tasks:
if scenario.id not in task._parent_ids:
task._parent_ids.update([scenario.id])
_task_manager._set(task)

for dn in cloned_additional_data_nodes:
if scenario.id not in dn._parent_ids:
dn._parent_ids.update([scenario.id])
_data_manager._set(dn)

cls._set(scenario)

return scenario
12 changes: 12 additions & 0 deletions taipy/core/task/_task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,15 @@ def _get_by_config_id(cls, config_id: str, version_number: Optional[str] = None)
for fil in filters:
fil.update({"config_id": config_id})
return cls._repository._load_all(filters)

@classmethod
def _clone(cls, task: Task, cycle_id: Optional[CycleId] = None, scenario_id: Optional[ScenarioId] = None) -> Task:
data_manager = _DataManagerFactory._build_manager()
inputs = [data_manager._clone(i, cycle_id, scenario_id) for i in task.input.values()]
outputs = [data_manager._clone(o, cycle_id, scenario_id) for o in task.output.values()]
task.id = task._new_id(task.config_id)
task._parent_ids = set()
for dn in set(inputs + outputs):
dn._parent_ids.update([task.id])
cls._set(task)
return task
7 changes: 6 additions & 1 deletion taipy/core/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(
skippable: bool = False,
) -> None:
self._config_id = _validate_id(config_id)
self.id = id or TaskId(self.__ID_SEPARATOR.join([self._ID_PREFIX, self.config_id, str(uuid.uuid4())]))
self.id = id or self._new_id(config_id)
self._owner_id = owner_id
self._parent_ids = parent_ids or set()
self._input = {dn.config_id: dn for dn in input or []}
Expand All @@ -127,6 +127,11 @@ def __init__(
self._properties = _Properties(self, **properties)
self._init_done = True

@staticmethod
def _new_id(config_id: str) -> TaskId:
"""Generate a unique task identifier."""
return TaskId(Task.__ID_SEPARATOR.join([Task._ID_PREFIX, config_id, str(uuid.uuid4())]))

def __hash__(self) -> int:
return hash(self.id)

Expand Down

0 comments on commit 658f02f

Please sign in to comment.