Commit

Merge pull request #126 from PySpur-Dev/feat/cascading-failures
Feat/cascading failures
srijanpatel authored Jan 28, 2025
2 parents 3b61b53 + ee67cb4 commit 316a4bd
Showing 7 changed files with 101 additions and 17 deletions.
16 changes: 15 additions & 1 deletion backend/app/api/run_management.py
@@ -4,7 +4,8 @@

from ..schemas.run_schemas import RunResponseSchema
from ..database import get_db
from ..models.run_model import RunModel
from ..models.run_model import RunModel, RunStatus
from ..models.task_model import TaskStatus

router = APIRouter()

@@ -46,6 +47,19 @@ def get_run(run_id: str, db: Session = Depends(get_db)):
@router.get("/{run_id}/status/", response_model=RunResponseSchema)
def get_run_status(run_id: str, db: Session = Depends(get_db)):
    run = db.query(RunModel).filter(RunModel.id == run_id).first()
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")
    if run.status != RunStatus.FAILED:
        failed_tasks = [task for task in run.tasks if task.status == TaskStatus.FAILED]
        running_and_pending_tasks = [
            task
            for task in run.tasks
            if task.status in [TaskStatus.PENDING, TaskStatus.RUNNING]
        ]
        if failed_tasks and len(running_and_pending_tasks) == 0:
            run.status = RunStatus.FAILED
            db.commit()
            db.refresh(run)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")
    return run
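
The status endpoint above derives failure lazily: a run is flipped to FAILED only once at least one task has failed and no task is still PENDING or RUNNING, so runs with in-flight work keep reporting their current status. A minimal standalone sketch of that rule, using plain status strings and a hypothetical derive_run_failed helper rather than the project's SQLAlchemy models:

from typing import Iterable


def derive_run_failed(task_statuses: Iterable[str]) -> bool:
    """Mirror of the check in get_run_status: fail the run only when at
    least one task failed and nothing is left pending or running."""
    statuses = list(task_statuses)
    has_failed = any(status == "FAILED" for status in statuses)
    still_active = any(status in ("PENDING", "RUNNING") for status in statuses)
    return has_failed and not still_active


assert derive_run_failed(["COMPLETED", "FAILED"]) is True
assert derive_run_failed(["FAILED", "RUNNING"]) is False  # wait for in-flight tasks
assert derive_run_failed(["COMPLETED", "COMPLETED"]) is False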
7 changes: 6 additions & 1 deletion backend/app/api/workflow_run.py
@@ -419,7 +419,12 @@ def list_runs(
        failed_tasks = [
            task for task in run.tasks if task.status == TaskStatus.FAILED
        ]
        if failed_tasks:
        running_and_pending_tasks = [
            task
            for task in run.tasks
            if task.status in [TaskStatus.PENDING, TaskStatus.RUNNING]
        ]
        if failed_tasks and len(running_and_pending_tasks) == 0:
            run.status = RunStatus.FAILED
            db.commit()
            db.refresh(run)
72 changes: 62 additions & 10 deletions backend/app/execution/workflow_executor.py
@@ -16,6 +16,14 @@
from .workflow_execution_context import WorkflowExecutionContext


class UpstreamFailure(Exception):
    pass


class UnconnectedNode(Exception):
    pass


class WorkflowExecutor:
    """
    Handles the execution of a workflow.
@@ -42,6 +50,7 @@ def __init__(
        self._node_tasks: Dict[str, asyncio.Task[Optional[BaseNodeOutput]]] = {}
        self._initial_inputs: Dict[str, Dict[str, Any]] = {}
        self._outputs: Dict[str, Optional[BaseNodeOutput]] = {}
        self._failed_nodes: Set[str] = set()
        self._build_node_dict()
        self._build_dependencies()

@@ -150,21 +159,39 @@ async def _execute_node(self, node_id: str) -> Optional[BaseNodeOutput]:
        if node_id in self._outputs:
            return self._outputs[node_id]

        # Wait for dependencies
        # Check if any predecessor nodes failed
        dependency_ids = self._dependencies.get(node_id, set())

        # Wait for dependencies
        predecessor_outputs: List[Optional[BaseNodeOutput]] = []
        if dependency_ids:
            predecessor_outputs = await asyncio.gather(
                *(
                    self._get_async_task_for_node_execution(dep_id)
                    for dep_id in dependency_ids
            try:
                predecessor_outputs = await asyncio.gather(
                    *(
                        self._get_async_task_for_node_execution(dep_id)
                        for dep_id in dependency_ids
                    ),
                )
            )
            except Exception as e:
                raise UpstreamFailure(
                    f"Node {node_id} skipped due to upstream failure"
                )

        if any(dep_id in self._failed_nodes for dep_id in dependency_ids):
            print(f"Node {node_id} skipped due to upstream failure")
            self._failed_nodes.add(node_id)
            raise UpstreamFailure(f"Node {node_id} skipped due to upstream failure")

        if node.node_type != "CoalesceNode" and any(
            [output is None for output in predecessor_outputs]
        ):
            self._outputs[node_id] = None
            if self.task_recorder:
                self.task_recorder.update_task(
                    node_id=node_id,
                    status=TaskStatus.CANCELED,
                    end_time=datetime.now(),
                )
            return None

        # Get source handles mapping
@@ -189,7 +216,7 @@ async def _execute_node(self, node_id: str) -> Optional[BaseNodeOutput]:
            if self.task_recorder:
                self.task_recorder.update_task(
                    node_id=node_id,
                    status=TaskStatus.PENDING,
                    status=TaskStatus.CANCELED,
                    end_time=datetime.now(),
                )
            return None
@@ -225,7 +252,7 @@ async def _execute_node(self, node_id: str) -> Optional[BaseNodeOutput]:
            # If node_input is empty, return None
            if not node_input:
                self._outputs[node_id] = None
                return None
                raise UnconnectedNode(f"Node {node_id} has no input")

            node_instance = NodeFactory.create_node(
                node_name=node.title, node_type_name=node.node_type, config=node.config
@@ -255,6 +282,17 @@ async def _execute_node(self, node_id: str) -> Optional[BaseNodeOutput]:
            # Store output
            self._outputs[node_id] = output
            return output
        except UpstreamFailure as e:
            self._failed_nodes.add(node_id)
            self._outputs[node_id] = None
            if self.task_recorder:
                self.task_recorder.update_task(
                    node_id=node_id,
                    status=TaskStatus.CANCELED,
                    end_time=datetime.now(),
                    error="Upstream failure",
                )
            raise e
        except Exception as e:
            error_msg = (
                f"Node execution failed:\n"
@@ -265,6 +303,7 @@ async def _execute_node(self, node_id: str) -> Optional[BaseNodeOutput]:
f"Error: {str(e)}"
)
print(error_msg)
self._failed_nodes.add(node_id)
if self.task_recorder:
self.task_recorder.update_task(
node_id=node_id,
@@ -317,12 +356,25 @@ async def run(
            if node.parent_id:
                nodes_to_run.discard(node.id)

        # drop outputs for nodes that need to be run
        for node_id in nodes_to_run:
            self._outputs.pop(node_id, None)

        # Start tasks for all nodes
        for node_id in nodes_to_run:
            self._get_async_task_for_node_execution(node_id)

        # Wait for all tasks to complete
        await asyncio.gather(*self._node_tasks.values())
        # Wait for all tasks to complete, but don't propagate exceptions
        results = await asyncio.gather(
            *self._node_tasks.values(), return_exceptions=True
        )

        # Process results to handle any exceptions
        for node_id, result in zip(self._node_tasks.keys(), results):
            if isinstance(result, Exception):
                print(f"Node {node_id} failed with error: {str(result)}")
                self._failed_nodes.add(node_id)
                self._outputs[node_id] = None

        # return the non-None outputs
        return {
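Taken together, the executor changes make failures cascade instead of aborting the whole run: a node that raises is recorded in _failed_nodes, its downstream nodes raise UpstreamFailure and are marked CANCELED, and run() gathers tasks with return_exceptions=True so independent branches keep executing. A stripped-down sketch of that pattern, assuming toy zero-argument coroutines in place of PySpur's node classes (run_dag and its helpers are illustrative names, not project APIs):

import asyncio
from typing import Any, Awaitable, Callable, Dict, Set


class UpstreamFailure(Exception):
    """Raised when a node is skipped because a predecessor failed."""


async def run_dag(
    nodes: Dict[str, Callable[[], Awaitable[Any]]],
    deps: Dict[str, Set[str]],
) -> Dict[str, Any]:
    failed: Set[str] = set()
    outputs: Dict[str, Any] = {}
    tasks: Dict[str, asyncio.Task] = {}

    async def execute(node_id: str) -> Any:
        # Wait for predecessors; their exceptions are swallowed here because
        # the failure is tracked in `failed` instead.
        for dep_id in deps.get(node_id, set()):
            try:
                await tasks[dep_id]
            except Exception:
                pass
        if any(dep_id in failed for dep_id in deps.get(node_id, set())):
            failed.add(node_id)
            raise UpstreamFailure(f"Node {node_id} skipped due to upstream failure")
        try:
            outputs[node_id] = await nodes[node_id]()
            return outputs[node_id]
        except Exception:
            failed.add(node_id)
            raise

    # Create every task up front; none starts until the current coroutine yields.
    for node_id in nodes:
        tasks[node_id] = asyncio.create_task(execute(node_id))

    # return_exceptions=True keeps one failing branch from cancelling the rest.
    await asyncio.gather(*tasks.values(), return_exceptions=True)
    return outputs


async def demo() -> None:
    async def ok() -> str:
        return "ok"

    async def boom() -> str:
        raise RuntimeError("boom")

    # "c" depends on the failing "b", so it is skipped; "a" still completes.
    print(await run_dag({"a": ok, "b": boom, "c": ok}, {"c": {"b"}}))  # {'a': 'ok'}


asyncio.run(demo())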
2 changes: 1 addition & 1 deletion frontend/src/components/nodes/DynamicNode.tsx
@@ -353,7 +353,7 @@ const DynamicNode: React.FC<DynamicNodeProps> = ({
                    {renderHandles()}
                </div>
                {nodeData.error && <NodeErrorDisplay error={nodeData.error} />}
                {displayOutput && <NodeOutputDisplay key="output-display" output={nodeData.run} />}
                {displayOutput && <NodeOutputDisplay key={`output-display-${id}`} output={nodeData.run} />}
            </BaseNode>
        </div>
        <NodeOutputModal
2 changes: 1 addition & 1 deletion frontend/src/components/nodes/InputNode.tsx
@@ -336,7 +336,7 @@ const InputNode: React.FC<InputNodeProps> = ({ id, data, readOnly = false, ...pr
                        {renderWorkflowInputs()}
                        {renderAddField()}
                    </div>
                    {data.run && <NodeOutputDisplay output={data.run} />}
                    {data.run && <NodeOutputDisplay output={data.run} key={`output-${id}`} />}
                </div>
            </BaseNode>
        </div>
2 changes: 1 addition & 1 deletion frontend/src/components/nodes/logic/RouterNode.tsx
@@ -700,7 +700,7 @@ export const RouterNode: React.FC<RouterNodeProps> = ({ id, data, readOnly = fal
                    </div>
                ))}
            </div>
            <NodeOutputDisplay output={data.run} />
            <NodeOutputDisplay output={data.run} key={`output-${id}`} />
        </BaseNode>
    )
}
17 changes: 15 additions & 2 deletions frontend/src/hooks/useWorkflowExecution.ts
@@ -105,13 +105,26 @@ export const useWorkflowExecution = ({ onAlert }: UseWorkflowExecutionProps) =>
                clearInterval(currentStatusInterval)
                onAlert('Workflow run completed.', 'success')
            }
            if (statusResponse.status === 'FAILED' || tasks.some((task) => task.status === 'FAILED')) {
            if (
                statusResponse.status === 'FAILED' ||
                (tasks.some((task) => task.status === 'FAILED') &&
                    !tasks.some((task) => task.status === 'RUNNING' || task.status === 'PENDING'))
            ) {
                setIsRunning(false)
                setCompletionPercentage(0)
                // Clear all intervals
                statusIntervals.current.forEach((interval) => clearInterval(interval))
                clearInterval(currentStatusInterval)
                onAlert('Workflow run failed.', 'danger')

                // Check if some tasks succeeded while others failed
                if (
                    tasks.some((task) => task.status === 'COMPLETED') &&
                    tasks.some((task) => task.status === 'FAILED')
                ) {
                    onAlert('Workflow ran with some failed tasks.', 'warning')
                } else {
                    onAlert('Workflow run failed.', 'danger')
                }
                return
            }
        } catch (error) {
