diff --git a/docs/project/images/dag_architecture.png b/docs/project/images/dag_architecture.png
new file mode 100644
index 00000000..9fb602de
Binary files /dev/null and b/docs/project/images/dag_architecture.png differ
diff --git a/docs/project/user_guide/DAG_client_side_architecture.md b/docs/project/user_guide/DAG_client_side_architecture.md
new file mode 100644
index 00000000..994ee6a6
--- /dev/null
+++ b/docs/project/user_guide/DAG_client_side_architecture.md
@@ -0,0 +1,96 @@
+# DAG Client-side Architecture (DDD Overview)
+
+This document provides the architectural context for the client-side DAG pipeline.
+It is intended to be read before `DAG_client_side_implementation.md`, which dives into
+the threaded adapter internals.
+
+## Bounded context and language
+The bounded context here is DAG execution on HTC Grid. The ubiquitous language is:
+- DAG (aggregate) and DAG nodes (entities)
+- dependency/precedence edges
+- task status: pending, submitted, completed, failed
+- ready tasks: nodes whose dependencies are completed
+- task definition: payload sent to the grid connector
+
+## Layered view (DDD + ports/adapters)
+
+### Domain layer
+- `BaseDAG` defines the core contract for readiness, completion, and task payload creation.
+- `NxHtcDagContainer` is the NetworkX-backed aggregate that stores DAG structure and status.
+- Domain rules include: acyclic graph, required node attributes, and state transitions.
+- Task payloads are derived from node metadata (currently `worker_arguments`).
+
+### Application layer
+- `HTCDagScheduler` is the application service that drives the processing loop:
+  ready -> submit -> poll -> complete.
+- The CLI in `examples/client/python/dag_client.py` is the composition root:
+  it loads config, builds the DAG, wires the scheduler and adapter, and runs.
+
+### Infrastructure layer
+- `GridConnectorDagAdapter` is the threaded adapter that talks to grid connectors.
+- `GridConnectorFactory` creates per-thread connector instances from config.
+- Grid connectors (`AWSConnector`, `MockGridConnector`) implement `send()` and `get_results()`.
+- Input adapters (`BusinessDagLoader`, `DAGGenerator`) translate external data into a DAG.
+
+## Ports and adapters mapping
+- Port: `BaseDAG` (domain) -> Adapter: `NxHtcDagContainer` (NetworkX).
+- Port: grid connector interface (`send` / `get_results`) -> Adapters: concrete connectors.
+- `GridConnectorDagAdapter` bridges the scheduler (application) with connectors (infra).
+
+## DAG class diagram
+
+```mermaid
+classDiagram
+    class BaseDAG {
+        <<abstract>>
+        +get_nodes_with_resolved_dependencies()
+        +mark_node_completed(node_id)
+        +is_dag_completed()
+        +get_node_by_id(node_id)
+        +build_grid_task(node_id)
+    }
+
+    class NxHtcDagContainer {
+        +get_ready_tasks()
+        +mark_task_submitted(task_id)
+        +mark_task_complete(task_id)
+        +is_dag_complete()
+        +get_completed_task_count()
+        +get_total_task_count()
+    }
+
+    class HTCDagScheduler
+    class GridConnectorDagAdapter
+    class BusinessDagLoader
+    class DAGGenerator
+    class nx.DiGraph
+
+    BaseDAG <|-- NxHtcDagContainer
+    NxHtcDagContainer o-- nx.DiGraph
+    HTCDagScheduler --> BaseDAG : uses
+    GridConnectorDagAdapter --> BaseDAG : builds tasks
+    BusinessDagLoader ..> nx.DiGraph : loads
+    DAGGenerator ..> nx.DiGraph : generates
+```
+
+## Execution flow (high level)
+1. Load or generate a DAG.
+2. Wrap it in `NxHtcDagContainer` (domain aggregate).
+3. Scheduler finds ready tasks and enqueues them to the adapter.
+4. Adapter workers submit batches and poll results from connectors.
+5. 
Scheduler marks completed nodes and repeats until the DAG completes. + +## Current coupling notes +- `HTCDagScheduler` calls helpers on `NxHtcDagContainer` (for example, + `get_ready_tasks`, `mark_task_complete`, `get_completed_task_count`), so + alternative `BaseDAG` implementations must provide compatible methods or be wrapped. +- `BusinessDagLoader` sets default `task_type` and `worker_arguments` if missing, + which matches the task definition shape used by the adapter. + +## Extension points +- New DAG backend: implement `BaseDAG` for another graph library. +- New grid connector: implement the connector interface and wire it in the factory. +- New input format: add a loader that produces a NetworkX DAG or a `BaseDAG` implementation. + +## Related documentation +- `docs/project/user_guide/DAG_client_side_implementation.md` diff --git a/docs/project/user_guide/DAG_client_side_implementation.md b/docs/project/user_guide/DAG_client_side_implementation.md new file mode 100644 index 00000000..6d0d5582 --- /dev/null +++ b/docs/project/user_guide/DAG_client_side_implementation.md @@ -0,0 +1,311 @@ +# Threaded Connector Adapter - Current Implementation + +For the architectural overview (DDD and component boundaries), see +`docs/project/user_guide/DAG_client_side_architecture.md`. + +## Summary +This document describes the current `GridConnectorDagAdapter` implementation: a threaded +adapter that keeps long-lived connectors, uses in-memory queues, and lets the scheduler +remain lightweight. + +## What problem it solves +- Keeps a fixed set of worker threads and connectors alive for the lifetime of the adapter. +- Lets the scheduler thread do only two cheap operations: enqueue ready tasks and drain + completed task IDs. +- Moves network latency in `send()` and `get_results()` off the scheduler thread. + +## Where it fits +- The scheduler (`utils/dag/schedulers/htc_dag_scheduler.py`) calls: + - `adapter.submit_tasks(ready_nodes, dag_container)` + - `adapter.poll_completed()` and then `dag_container.mark_task_complete(node_id)` +- The adapter creates connectors via `utils/dag/grid_connector_factory.py` + (one connector per worker thread). +- The DAG container (`NxHtcDagContainer`) builds task definitions and tracks status. + +## Goals +- Start a configured number of threads when the adapter is initialized; threads persist + until shutdown/destruction. +- Provide two thread-safe queues: + - **Incoming tasks queue**: work items to be submitted to the grid. + - **Completed results queue**: completed DAG task IDs ready to be consumed. +- Make `submit_tasks()` non-blocking and cheap: it should enqueue tasks (with locking). +- Make worker threads responsible for: + - batching and submitting newly queued tasks to the grid + - polling for completion of tasks previously submitted by that worker + - pushing completed DAG IDs into the results queue +- Make `poll_completed()` non-blocking and cheap: it should drain/return items from + the results queue. + +## Non-goals +- Changing the scheduler loop semantics (`HTCDagScheduler`) beyond swapping adapter + implementation. +- Changing connector APIs; the adapter must continue to use `send()` and `get_results()` + as provided by the current connector implementations. +- Perfect delivery guarantees across process crashes (in-memory queues only). + +## Public API (compatibility target) +The adapter is a drop-in replacement for the existing adapter used by the scheduler. 
+ +- `submit_tasks(node_ids: List[str], dag_container: BaseDAG) -> Dict[Any, Dict[str, str]]` + - Enqueues work and returns quickly. + - Return value can remain `{}` for compatibility (scheduler ignores it today). +- `poll_completed() -> List[Hashable]` + - Drains completed DAG IDs from the results queue and returns them (possibly empty). +- `active_count() -> int` + - Number of in-flight grid sessions across all workers (one session per `send()` call). +- `get_errors() -> List[WorkerError]` + - Returns a snapshot of recent worker errors (bounded by `errors_maxlen`). +- `raise_if_failed() -> None` + - Raises if a worker crashed or all workers stopped. +- `shutdown(wait: bool = True, timeout: Optional[float] = None) -> None` + - Signals workers to stop and optionally joins them. + - Prefer explicit shutdown over relying on `__del__`. + +## Configuration +All keys are optional; defaults shown below. + +- `num_connector_threads` (int, default `2`): number of long-lived worker threads/connectors. + Must be > 0. +- `max_dequeue_per_loop` (int, default `100`): max number of queued tasks a worker will pick up + and submit per loop iteration. Must be > 0. +- `poll_timeout_sec` (float, default `0.01`): passed to `connector.get_results(..., timeout_sec=...)`. + Must be >= 0. +- `poll_interval_sec` (float, default `0.1`): wait/sleep interval when there is no immediate work; + prevents busy looping. Must be >= 0. +- `retry_attempts` (int, default `3`): submission retry attempts. Must be > 0. +- `max_poll_failures` (int, default `3`): consecutive `get_results()` failures before remaining + tasks are requeued. Must be > 0. +- `errors_maxlen` (int, default `1000`): max number of stored `WorkerError` entries. Must be >= 0. +- `use_mock_grid` (bool) and connector-specific settings for per-thread connector construction + (via `GridConnectorFactory`). + +## Core state and data model + +### Incoming task queue (`self._tasks`) +- Type: `collections.deque[QueuedTask]` +- Producer: scheduler thread via `submit_tasks()` +- Consumers: worker threads via `_worker_main_loop()` + +Each `QueuedTask` contains: +- `node_id`: DAG node ID +- `task_definition`: payload for `connector.send([...])` + +### Completed results queue (`self._results`) +- Type: `collections.deque[Hashable]` +- Producer: worker threads (after polling finished tasks) +- Consumer: scheduler thread via `poll_completed()` + +Items are DAG node IDs ready to be marked completed by the scheduler. + +### De-duplication set (`self._queued_or_inflight`) +`submit_tasks()` deduplicates by keeping a set of node IDs that are either: +- currently queued, or +- in-flight on the grid (submitted but not yet returned via `poll_completed()`) + +The set is cleared for a node only when `poll_completed()` drains that node from +`self._results`. 
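+
+A minimal sketch of this guard (illustrative only; names mirror the adapter internals
+described above, and error handling is omitted):
+
+```python
+import threading
+from collections import deque
+
+
+class DedupTaskQueue:
+    """Accepts a node ID only if it is not already queued or in-flight."""
+
+    def __init__(self):
+        self._cond = threading.Condition()
+        self._tasks = deque()
+        self._queued_or_inflight = set()
+
+    def enqueue(self, node_ids):
+        accepted = []
+        with self._cond:
+            for node_id in node_ids:
+                if node_id in self._queued_or_inflight:
+                    continue  # already queued or in-flight: skip
+                self._queued_or_inflight.add(node_id)
+                self._tasks.append(node_id)
+                accepted.append(node_id)
+            if accepted:
+                self._cond.notify_all()  # wake idle workers
+        return accepted
+
+    def clear_inflight(self, node_ids):
+        # Called when poll_completed() drains results: from this point the
+        # node may be resubmitted.
+        with self._cond:
+            for node_id in node_ids:
+                self._queued_or_inflight.discard(node_id)
+```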
+ +### Per-worker active sessions (`active_sessions`) +Each worker thread maintains its own dictionary: + +- `active_sessions: Dict[str, SessionState]` + +Each `SessionState` represents one grid submission (one `connector.send(...)` call): +- `session_id`: session identifier (from connector response, or a local fallback) +- `submission_handle`: full response object from `send()` (passed back to `get_results()`) +- `grid_to_dag`: maps grid task ID -> DAG node ID +- `remaining_grid_task_ids`: set of grid task IDs not yet emitted to the results queue +- `task_definitions_by_node`: stored task payloads, used for requeueing +- `poll_failures`: consecutive poll failures for this session + +`remaining_grid_task_ids` is important because many connectors return cumulative `finished` +lists. Without this set, the same finished task could be emitted multiple times. + +## Threading model + +### Lifecycle +1. Adapter `__init__`: + - Create shared queues (incoming + results). + - Create locks/conditions. + - Create a stop signal (`threading.Event`). + - Start `num_connector_threads` threads; each thread creates/authenticates its own + connector instance and enters the worker loop. +2. Adapter `shutdown()`: + - Set stop event. + - Notify any waiting workers. + - Join threads (optional). + +### Shared synchronization +The current implementation uses `collections.deque` with explicit `threading.Condition` +objects so that: +- `submit_tasks()` can enqueue multiple items and notify workers once. +- workers can atomically dequeue up to `max_dequeue_per_loop` items. + +## Worker loop (required behavior) +Each worker thread runs a loop until shutdown: + +1. **Pick up tasks**: dequeue up to `max_dequeue_per_loop` tasks from the incoming tasks queue. +2. **Submit new tasks**: + - Submit the dequeued tasks as a single `connector.send(task_definitions)` call. + - Extract `session_id` + `task_ids` (if present) to build `grid_to_dag` mapping. + - Store `SessionState` in the worker's `active_sessions`. + - Apply retry logic (`retry_attempts`); on final failure, requeue tasks. +3. **Poll for completions** (only sessions submitted by this worker): + - Iterate the worker's `active_sessions`. + - Call `connector.get_results(submission_handle, timeout_sec=poll_timeout_sec)`. + - Extract finished grid task IDs. + - Map finished grid IDs to DAG IDs and push them into the completed results queue. + - Remove completed grid IDs from `remaining_grid_task_ids`. + - When a session has no remaining grid tasks, remove it from `active_sessions`. + - If `get_results()` fails repeatedly (`max_poll_failures`), requeue remaining tasks + and drop the session. +4. **Idle control**: + - If there were no new tasks and there are no active sessions, wait on a condition + variable for new work (with a timeout). + - If there are active sessions but no new tasks, sleep `poll_interval_sec` between + polls to avoid busy spinning. 
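+
+The batched dequeue in step 1 can be sketched as follows (illustrative; the adapter
+inlines this logic under `self._tasks_condition`, and `dequeue_up_to` in the
+pseudocode below stands in for it):
+
+```python
+def dequeue_up_to(cond, tasks, max_items, timeout):
+    """Atomically take up to max_items from the shared deque.
+
+    Waits on the condition variable when there is nothing to do, so idle
+    workers do not busy-spin; submit_tasks() and shutdown() notify it.
+    """
+    with cond:
+        if not tasks:
+            cond.wait(timeout=timeout)
+        batch = []
+        while tasks and len(batch) < max_items:
+            batch.append(tasks.popleft())
+        return batch
+```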
+
+### Pseudocode
+```python
+def worker_main(thread_id: int) -> None:
+    connector = create_and_authenticate_connector(thread_id)
+    active_sessions: dict[str, SessionState] = {}
+
+    while not stop_event.is_set():
+        new_items = dequeue_up_to(max_dequeue_per_loop)
+
+        if new_items:
+            for attempt in range(retry_attempts):
+                try:
+                    handle = connector.send([i.task_definition for i in new_items])
+                    session = build_session_state(handle, new_items)
+                    active_sessions[session.session_id] = session
+                    break
+                except Exception:
+                    if attempt == retry_attempts - 1:
+                        record_error(thread_id, new_items)
+                        requeue_items(new_items)
+                    else:
+                        sleep(1.0)
+
+        for session_id, session in list(active_sessions.items()):
+            try:
+                results = connector.get_results(session.submission_handle, timeout_sec=poll_timeout_sec)
+                session.poll_failures = 0
+            except Exception:
+                session.poll_failures += 1
+                if session.poll_failures >= max_poll_failures:
+                    requeue_remaining(session)
+                    del active_sessions[session_id]
+                continue
+            finished = extract_finished(results)
+            for grid_task_id in finished:
+                node_id = session.grid_to_dag.get(grid_task_id)
+                if node_id is not None and grid_task_id in session.remaining_grid_task_ids:
+                    session.remaining_grid_task_ids.remove(grid_task_id)
+                    results_queue_push(node_id)
+            if not session.remaining_grid_task_ids:
+                del active_sessions[session_id]
+
+        if not new_items and not active_sessions:
+            wait_for_work_or_timeout(poll_interval_sec)
+        elif active_sessions and poll_interval_sec:
+            sleep(poll_interval_sec)
+```
+
+## `submit_tasks()` behavior
+`submit_tasks()` should:
+- Deduplicate: do not enqueue the same DAG task ID more than once.
+- Update DAG status early enough to prevent re-queuing by the scheduler on subsequent
+  iterations.
+- Recommended: call `dag_container.mark_task_submitted(node_id)` in `submit_tasks()`
+  (single-threaded from scheduler).
+- Build `task_definition` (recommended) and enqueue items under a lock.
+- Notify a condition variable to wake sleeping workers.
+- Note: `poll_completed()` removes completed node IDs from the de-dup set.
+
+Notes:
+- If the DAG container is not thread-safe for status updates, keep all DAG status
+  transitions on the scheduler thread.
+- Queues are unbounded: `submit_tasks()` enqueues all deduped tasks and workers push
+  all completed results until they are drained via `poll_completed()`.
+
+## `poll_completed()` behavior
+`poll_completed()` should:
+- Acquire the results queue lock.
+- Drain all currently available completed DAG IDs (or up to a configured max) and return
+  them.
+- Return quickly; do not call the grid connector from `poll_completed()`.
+- Clear the de-dup guard for returned node IDs.
+
+## Error handling and observability
+Background threads can fail; the implementation surfaces failures via:
+
+- A thread-safe, bounded `errors` deque (`errors_maxlen`) storing `WorkerError` records.
+- Logging with thread id and stage for `send`, `get_results`, and other exceptions.
+- `raise_if_failed()` raising when a worker crashes (`worker_crash`) or when all workers
+  stop.
+- `send()` failures are retried; when retries are exhausted, tasks are requeued.
+- `get_results()` failures are tracked per session; after `max_poll_failures`, remaining
+  tasks are requeued.
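+
+For example, a composition root might surface these signals like this (sketch;
+mirrors the pattern used by `examples/client/python/dag_client.py`):
+
+```python
+def run_dag(scheduler, adapter, logger):
+    """Drive the scheduler and report background worker failures."""
+    try:
+        return scheduler.run()
+    except RuntimeError:
+        # raise_if_failed() propagated a worker crash; log recent errors.
+        for err in adapter.get_errors():
+            logger.error("worker %s failed at stage %s: %s",
+                         err.thread_id, err.stage, err.error)
+        raise
+    finally:
+        adapter.shutdown(wait=True, timeout=10.0)
+```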
+ +## Sequence diagram (submit -> run -> poll -> complete) + +```mermaid +sequenceDiagram + participant S as Scheduler (HTCDagScheduler) + participant D as DAG Container (BaseDAG) + participant A as GridConnectorDagAdapter + participant W as Worker Thread (per connector) + participant C as Grid Connector (AWSConnector/MockGridConnector) + + loop scheduler iterations + S->>D: get_ready_tasks() + S->>A: submit_tasks(node_ids, D) + A->>D: build_grid_task(node_id) x N + A->>A: enqueue QueuedTask x N\n+ notify workers + A->>D: mark_task_submitted(node_id) x N + + W->>A: dequeue up to max_dequeue_per_loop + W->>C: send(task_definitions) + C-->>W: submission_handle(session_id, task_ids) + W->>W: create SessionState\n(grid_to_dag, remaining_grid_task_ids) + + loop poll until complete + W->>C: get_results(handle, timeout) + C-->>W: results(finished=[...])\n(cumulative) + W->>A: append completed node_id(s)\ninto results queue + end + + S->>A: poll_completed() + A-->>S: [completed node_id, ...] + S->>D: mark_task_complete(node_id) x K + end + + Note over S,A: App shutdown + S->>A: shutdown() + A-->>W: stop_event + notify_all() +``` + +## Metrics (not implemented) +Metrics are not tracked in the current adapter. If needed, consider adding: +- Queue sizes: incoming depth, results depth. +- Submission metrics: total tasks enqueued/submitted, retries, failures. +- Poll metrics: poll duration, finished tasks per poll. +- Active session count per worker and total. + +## Integration notes (with `HTCDagScheduler`) +- Scheduler loop remains the same: + - `submit_tasks(ready_nodes, dag_container)` becomes an enqueue operation. + - `poll_completed()` becomes a drain of completed IDs produced by workers. +- Ensure tasks are marked `submitted` before they can appear as ready again. + +## Testing strategy (recommended) +- Unit tests with a deterministic mock connector: + - enqueue tasks, verify workers submit in batches + - simulate results and verify `poll_completed()` returns expected DAG IDs + - verify dedupe behavior in `submit_tasks()` + - verify `shutdown()` terminates threads promptly +- Concurrency tests: + - multiple `submit_tasks()` calls while polling is active + - high-throughput enqueue + completion drain without deadlocks diff --git a/examples/client/python/business_dag_loader.py b/examples/client/python/business_dag_loader.py new file mode 100644 index 00000000..c7d5d03a --- /dev/null +++ b/examples/client/python/business_dag_loader.py @@ -0,0 +1,113 @@ +# Copyright 2024 Amazon.com, Inc. or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# Licensed under the Apache License, Version 2.0 https://aws.amazon.com/apache-2-0/ + +import argparse +import json +import logging +from pathlib import Path + +import networkx as nx +from networkx.readwrite import json_graph + + +class BusinessDagLoader: + """Loads a DAG file into a NetworkX DiGraph. + + Currently supports NetworkX node-link JSON. 
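+
+    Example (illustrative):
+
+        loader = BusinessDagLoader(copies=2)
+        dag = loader.load("data/sample_dag.json")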
+ """ + + def __init__(self, copies: int = 1) -> None: + if copies < 1: + raise ValueError("copies must be >= 1") + self._copies = copies + self._logger = logging.getLogger(self.__class__.__name__) + + def load(self, path: str) -> nx.DiGraph: + data = json.loads(Path(path).read_text(encoding="utf-8")) + return self.loads(data) + + def loads(self, data: object) -> nx.DiGraph: + if not isinstance(data, dict): + raise ValueError("Invalid JSON: expected an object at the root") + + if "edges" in data: + edges_key = "edges" + elif "links" in data: + edges_key = "links" + else: + raise ValueError("Invalid node-link JSON: expected 'edges' or 'links' key") + + graph = json_graph.node_link_graph(data, link=edges_key) + + digraph = nx.DiGraph(graph) + for _, node_data in digraph.nodes(data=True): + node_data.setdefault("task_type", "compute") + node_data.setdefault("worker_arguments", ["1000", "1", "1"]) + + if self._copies > 1: + return self._replicate_dag(digraph, copies=self._copies) + + return digraph + + def _replicate_dag(self, nx_dag: nx.DiGraph, copies: int, super_root: str = "SUPER_ROOT") -> nx.DiGraph: + """ + Create a new DAG composed of `copies` disjoint copies of the input DAG, all + attached beneath a new super-root node. + + Each copy's nodes are prefixed with an index (`0_`, `1_`, ...) to keep names unique. + The super-root is connected to one root of each copy (root with the highest out-degree). + """ + if copies < 1: + raise ValueError("copies must be >= 1") + if nx_dag.number_of_nodes() == 0: + raise ValueError("input DAG is empty") + + roots = [n for n in nx_dag.nodes() if nx_dag.in_degree(n) == 0] + + if not roots: + raise ValueError("input DAG has no roots (graph may contain cycles)") + + # Choose the root with the highest out_degree. + first_root = max(roots, key=lambda n: nx_dag.out_degree(n)) + self._logger.debug( + "Selected root with max out_degree: %s (out_degree=%s)", + first_root, + nx_dag.out_degree(first_root), + ) + + combined = nx.DiGraph() + combined.add_node(super_root) + + combined.nodes[super_root]["task_type"] = "compute" + combined.nodes[super_root]["worker_arguments"] = ["1", "1", "1"] + + for idx in range(copies): + mapping = {node: f"{idx}_{node}" for node in nx_dag.nodes()} + copy_graph = nx.relabel_nodes(nx_dag, mapping, copy=True) + combined.update(copy_graph) + combined.add_edge(super_root, mapping[first_root]) + + return combined + + +def main() -> int: + parser = argparse.ArgumentParser(description="Load a NetworkX node-link JSON file into a DiGraph.") + parser.add_argument("path", nargs="?", default="morse_trie_networkx.json", help="Path to node-link JSON file") + args = parser.parse_args() + + path = Path(args.path) + graph = BusinessDagLoader().load(path) + + print(f"Loaded graph type: {type(graph).__name__}") + print(f"Nodes: {graph.number_of_nodes()}, Edges: {graph.number_of_edges()}") + if "" in graph: + print("Root node '' present: yes") + else: + print("Root node '' present: no") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/client/python/config/config.json b/examples/client/python/config/config.json new file mode 100644 index 00000000..071b7ca8 --- /dev/null +++ b/examples/client/python/config/config.json @@ -0,0 +1,21 @@ +{ + "use_mock_grid": false, + "polling_interval_seconds": 2, + "max_dequeue_per_loop": 500, + "task_timeout_sec": 300, + "max_concurrent_submissions": 5, + "retry_attempts": 3, + "num_connector_threads": 2, + "log_level": "INFO", + "show_dag_visualization": false, + 
"mock_tasks_sleep_time_ms": 100, + "poll_timeout_sec": 0.01, + "poll_interval_sec": 0.1, + + + "htc_grid": { + "client_config_file": "/etc/agent/Agent_config.tfvars.json", + "username": "", + "password": "" + } +} diff --git a/examples/client/python/dag_client.py b/examples/client/python/dag_client.py new file mode 100644 index 00000000..71d69242 --- /dev/null +++ b/examples/client/python/dag_client.py @@ -0,0 +1,124 @@ +# Copyright 2024 Amazon.com, Inc. or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# Licensed under the Apache License, Version 2.0 https://aws.amazon.com/apache-2-0/ + +# HTC-Client-DAG: Main client application for processing DAGs using HTC Grid. + +import argparse +import json +import logging +import os +import sys +import pickle + +from utils.dag.base.dag_generator import DAGGenerator +from utils.dag.base.nx_htc_dag_container import NxHtcDagContainer +from business_dag_loader import BusinessDagLoader + +from utils.dag.schedulers.htc_dag_scheduler import HTCDagScheduler +from utils.dag.adapters.grid_connector_dag_adapter import GridConnectorDagAdapter + + +def setup_logging(log_level: str = "DEBUG") -> logging.Logger: + """Setup logging configuration and return the module logger.""" + logging.basicConfig( + level=getattr(logging, log_level.upper(), logging.DEBUG), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], + ) + lg = logging.getLogger("HTCClientDAG") + + return lg + + +logger = logging.getLogger("HTCClientDAG") + + +def load_config(config_file: str) -> dict: + if not config_file: + raise ValueError("config_file is required to load configuration") + if not os.path.exists(config_file): + raise FileNotFoundError(f"Config file {config_file} not found") + try: + with open(config_file, "r", encoding="utf-8") as file: + return json.load(file) + except Exception as exc: + raise ValueError(f"Failed to load config file {config_file}: {exc}") from exc + + +def main() -> int: + + parser = argparse.ArgumentParser(description="HTC-Client-DAG: Process DAGs using HTC Grid") + parser.add_argument("--config", default="config/config.json", help="Path to configuration file") + parser.add_argument("--mock", action="store_true", help="Use mock grid connector") + parser.add_argument("--generate", action="store_true", help="Generate test DAG.") + parser.add_argument("--depth", type=int, default=3, help="DAG depth for generation.") + parser.add_argument("--breadth", type=int, default=2, help="DAG breadth for generation.") + parser.add_argument("--dag-file", help="Load DAG from a file, format must match business dag loader.") + parser.add_argument("--visualize", action="store_true", help="Show DAG visualization during execution (for small scale tests only).") + parser.add_argument("--copies", type=int, default=1, help="Replicates input DAG from the file for large scale tests.") + + args = parser.parse_args() + + try: + + config = load_config(args.config) + + + global logger + logger = setup_logging(config.get("log_level", "INFO")) + logger.info("Configuration loaded from %s", args.config) + + if args.visualize: + config["show_dag_visualization"] = True + + if args.mock: + config["use_mock_grid"] = True + logger.info("Using mock grid connector for local testing") + + nx_dag = None + if args.generate: + nx_dag = DAGGenerator(config).generate_dag(args.depth, args.breadth) + elif args.dag_file: + nx_dag = BusinessDagLoader(copies=args.copies).load(args.dag_file) + else: + raise ValueError("Must specify either --generate 
or --dag-file") + + + + logger.info("DAG STATISTICS:") + logger.info("Nodes: %s, Edges: %s", nx_dag.number_of_nodes(), nx_dag.number_of_edges()) + data = pickle.dumps(nx_dag) + logger.info( + "Pickled NetworkX DiGraph size: %.2f MB (%s bytes)", + len(data) / 1_000_000, + len(data), + ) + + nx_dag_container = NxHtcDagContainer(nx_dag) + + adapter = GridConnectorDagAdapter(config, logger) + + try: + scheduler = HTCDagScheduler( + config=config, + dag_container=nx_dag_container, + grid_connector_adapter=adapter, + ) + + # Process DAG + success = scheduler.run() + return 0 if success else 1 + finally: + try: + adapter.shutdown(wait=True, timeout=float(config.get("shutdown_timeout_sec", 10.0))) + except Exception: + logger.exception("Failed to shutdown threaded adapter cleanly") + + except Exception as e: + logger.error(f"Application failed: {str(e)}", exc_info=True) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/client/python/data/sample_dag.json b/examples/client/python/data/sample_dag.json new file mode 100644 index 00000000..d7ad9d83 --- /dev/null +++ b/examples/client/python/data/sample_dag.json @@ -0,0 +1,487 @@ +{ + "directed": true, + "multigraph": false, + "graph": {}, + "nodes": [ + { + "layer": 0, + "pos": [ + 0.0, + -0.0 + ], + "id": "" + }, + { + "layer": 1, + "pos": [ + 0.5, + -1.0 + ], + "id": "t" + }, + { + "layer": 1, + "pos": [ + -0.5, + -1.0 + ], + "id": "e" + }, + { + "layer": 2, + "pos": [ + 0.5, + -2.0 + ], + "id": "m" + }, + { + "layer": 2, + "pos": [ + 1.5, + -2.0 + ], + "id": "n" + }, + { + "layer": 2, + "pos": [ + -1.5, + -2.0 + ], + "id": "a" + }, + { + "layer": 2, + "pos": [ + -0.5, + -2.0 + ], + "id": "i" + }, + { + "layer": 3, + "pos": [ + -0.5, + -3.0 + ], + "id": "o" + }, + { + "layer": 3, + "pos": [ + -2.5, + -3.0 + ], + "id": "g" + }, + { + "layer": 3, + "pos": [ + -1.5, + -3.0 + ], + "id": "k" + }, + { + "layer": 3, + "pos": [ + -3.5, + -3.0 + ], + "id": "d" + }, + { + "layer": 3, + "pos": [ + 3.5, + -3.0 + ], + "id": "w" + }, + { + "layer": 3, + "pos": [ + 0.5, + -3.0 + ], + "id": "r" + }, + { + "layer": 3, + "pos": [ + 2.5, + -3.0 + ], + "id": "u" + }, + { + "layer": 3, + "pos": [ + 1.5, + -3.0 + ], + "id": "s" + }, + { + "layer": 4, + "pos": [ + 1.5, + -4.0 + ], + "id": "q" + }, + { + "layer": 4, + "pos": [ + 5.5, + -4.0 + ], + "id": "z" + }, + { + "layer": 4, + "pos": [ + 4.5, + -4.0 + ], + "id": "y" + }, + { + "layer": 4, + "pos": [ + -4.5, + -4.0 + ], + "id": "c" + }, + { + "layer": 4, + "pos": [ + 3.5, + -4.0 + ], + "id": "x" + }, + { + "layer": 4, + "pos": [ + -5.5, + -4.0 + ], + "id": "b" + }, + { + "layer": 4, + "pos": [ + -1.5, + -4.0 + ], + "id": "j" + }, + { + "layer": 4, + "pos": [ + 0.5, + -4.0 + ], + "id": "p" + }, + { + "layer": 4, + "pos": [ + -0.5, + -4.0 + ], + "id": "l" + }, + { + "layer": 4, + "pos": [ + -3.5, + -4.0 + ], + "id": "f" + }, + { + "layer": 4, + "pos": [ + 2.5, + -4.0 + ], + "id": "v" + }, + { + "layer": 4, + "pos": [ + -2.5, + -4.0 + ], + "id": "h" + } + ], + "edges": [ + { + "char": "—", + "source": "", + "target": "t" + }, + { + "char": "•", + "source": "", + "target": "e" + }, + { + "char": "—", + "source": "t", + "target": "m" + }, + { + "char": "•", + "source": "t", + "target": "n" + }, + { + "char": "—", + "source": "e", + "target": "a" + }, + { + "char": "•", + "source": "e", + "target": "i" + }, + { + "char": "—", + "source": "m", + "target": "o" + }, + { + "char": "•", + "source": "m", + "target": "g" + }, + { + "char": "—", + "source": "n", + "target": "k" + }, + { + "char": 
"•", + "source": "n", + "target": "d" + }, + { + "char": "—", + "source": "a", + "target": "w" + }, + { + "char": "•", + "source": "a", + "target": "r" + }, + { + "char": "—", + "source": "i", + "target": "u" + }, + { + "char": "•", + "source": "i", + "target": "s" + }, + { + "char": "—", + "source": "g", + "target": "q" + }, + { + "char": "•", + "source": "g", + "target": "z" + }, + { + "char": "—", + "source": "k", + "target": "y" + }, + { + "char": "•", + "source": "k", + "target": "c" + }, + { + "char": "—", + "source": "d", + "target": "x" + }, + { + "char": "•", + "source": "d", + "target": "b" + }, + { + "char": "—", + "source": "w", + "target": "j" + }, + { + "char": "•", + "source": "w", + "target": "p" + }, + { + "char": "•", + "source": "r", + "target": "l" + }, + { + "char": "•", + "source": "u", + "target": "f" + }, + { + "char": "—", + "source": "s", + "target": "v" + }, + { + "char": "•", + "source": "s", + "target": "h" + } + ], + "links": [ + { + "char": "—", + "source": "", + "target": "t" + }, + { + "char": "•", + "source": "", + "target": "e" + }, + { + "char": "—", + "source": "t", + "target": "m" + }, + { + "char": "•", + "source": "t", + "target": "n" + }, + { + "char": "—", + "source": "e", + "target": "a" + }, + { + "char": "•", + "source": "e", + "target": "i" + }, + { + "char": "—", + "source": "m", + "target": "o" + }, + { + "char": "•", + "source": "m", + "target": "g" + }, + { + "char": "—", + "source": "n", + "target": "k" + }, + { + "char": "•", + "source": "n", + "target": "d" + }, + { + "char": "—", + "source": "a", + "target": "w" + }, + { + "char": "•", + "source": "a", + "target": "r" + }, + { + "char": "—", + "source": "i", + "target": "u" + }, + { + "char": "•", + "source": "i", + "target": "s" + }, + { + "char": "—", + "source": "g", + "target": "q" + }, + { + "char": "•", + "source": "g", + "target": "z" + }, + { + "char": "—", + "source": "k", + "target": "y" + }, + { + "char": "•", + "source": "k", + "target": "c" + }, + { + "char": "—", + "source": "d", + "target": "x" + }, + { + "char": "•", + "source": "d", + "target": "b" + }, + { + "char": "—", + "source": "w", + "target": "j" + }, + { + "char": "•", + "source": "w", + "target": "p" + }, + { + "char": "•", + "source": "r", + "target": "l" + }, + { + "char": "•", + "source": "u", + "target": "f" + }, + { + "char": "—", + "source": "s", + "target": "v" + }, + { + "char": "•", + "source": "s", + "target": "h" + } + ] +} \ No newline at end of file diff --git a/examples/submissions/k8s_jobs/Dockerfile.Submitter b/examples/submissions/k8s_jobs/Dockerfile.Submitter index ca33573b..4ea5252b 100644 --- a/examples/submissions/k8s_jobs/Dockerfile.Submitter +++ b/examples/submissions/k8s_jobs/Dockerfile.Submitter @@ -1,5 +1,5 @@ ARG HTCGRID_ECR_REPO -FROM ${HTCGRID_ECR_REPO}/ecr-public/docker/library/python:3.8-slim +FROM ${HTCGRID_ECR_REPO}/ecr-public/docker/library/python:3.9-slim # Run as user nobody:nogroup ARG USER=65534 #nobody @@ -11,12 +11,22 @@ RUN apt-get update && \ gcc && \ mkdir -p /dist/python/ /app/py_connector +# Copy Python wheels and application code COPY ./dist/python/* /dist/python/ -COPY ./examples/client/python/* /app/py_connector/ +COPY ./examples/client/python/ /app/py_connector/ WORKDIR /app/py_connector + +# Install packages: wheels first, then requirements RUN pip install --no-cache-dir --upgrade pip && \ + pip install --no-cache-dir /dist/python/*.whl && \ pip install --no-cache-dir -r requirements.txt && \ + # Create temporary directories for multiprocessing + mkdir -p 
/app/tmp /tmp && \ + chmod 1777 /app/tmp /tmp && \ chown -R $USER:$GROUP /app /dist -USER ${USER}:${GROUP} +# Set TMPDIR environment variable to use our writable temp directory +ENV TMPDIR=/app/tmp + +USER ${USER}:${GROUP} \ No newline at end of file diff --git a/examples/submissions/k8s_jobs/Makefile b/examples/submissions/k8s_jobs/Makefile index baee4ef1..7927c9eb 100644 --- a/examples/submissions/k8s_jobs/Makefile +++ b/examples/submissions/k8s_jobs/Makefile @@ -33,5 +33,9 @@ generated: mkdir -p $(GENERATED) && cat cancel-one-long-task-test.yaml.tpl | sed "s/{{account_id}}/$(ACCOUNT_ID)/;s/{{region}}/$(REGION)/;s/{{image_name}}/$(SUBMITTER_IMAGE_NAME)/;s/{{image_tag}}/$(TAG)/" > $(GENERATED)/cancel-one-long-task-test.yaml + mkdir -p $(GENERATED) && cat dag-workload-generated.yaml.tpl | sed "s/{{account_id}}/$(ACCOUNT_ID)/;s/{{region}}/$(REGION)/;s/{{image_name}}/$(SUBMITTER_IMAGE_NAME)/;s/{{image_tag}}/$(TAG)/" > $(GENERATED)/dag-workload-generated.yaml + + mkdir -p $(GENERATED) && cat dag-workload-from-file.yaml.tpl | sed "s/{{account_id}}/$(ACCOUNT_ID)/;s/{{region}}/$(REGION)/;s/{{image_name}}/$(SUBMITTER_IMAGE_NAME)/;s/{{image_tag}}/$(TAG)/" > $(GENERATED)/dag-workload-from-file.yaml + clean: rm -rf $(GENERATED) diff --git a/examples/submissions/k8s_jobs/dag-workload-from-file.yaml.tpl b/examples/submissions/k8s_jobs/dag-workload-from-file.yaml.tpl new file mode 100644 index 00000000..283beaa3 --- /dev/null +++ b/examples/submissions/k8s_jobs/dag-workload-from-file.yaml.tpl @@ -0,0 +1,55 @@ +apiVersion: batch/v1 +kind: Job +metadata: + name: dag-workload-from-file + annotations: + seccomp.security.alpha.kubernetes.io/pod: "runtime/default" +spec: + template: + spec: + automountServiceAccountToken: false + securityContext: + runAsNonRoot: true + seccompProfile: + type: RuntimeDefault + containers: + - name: generator + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + runAsNonRoot: true + seccompProfile: + type: RuntimeDefault + capabilities: + drop: + - NET_RAW + - ALL + image: {{account_id}}.dkr.ecr.{{region}}.amazonaws.com/{{image_name}}:{{image_tag}} + imagePullPolicy: Always + resources: + limits: + cpu: 8000m + memory: 12000Mi + requests: + cpu: 8000m + memory: 12000Mi + command: ["python3", "./dag_client.py", "--dag-file", "./data/sample_dag.json"] + volumeMounts: + - name: agent-config-volume + mountPath: /etc/agent + env: + - name: INTRA_VPC + value: "1" + restartPolicy: Never + nodeSelector: + htc/node-type: core + tolerations: + - effect: NoSchedule + key: htc/node-type + operator: Equal + value: core + volumes: + - name: agent-config-volume + configMap: + name: agent-configmap + backoffLimit: 0 diff --git a/examples/submissions/k8s_jobs/dag-workload-generated.yaml.tpl b/examples/submissions/k8s_jobs/dag-workload-generated.yaml.tpl new file mode 100644 index 00000000..5df57453 --- /dev/null +++ b/examples/submissions/k8s_jobs/dag-workload-generated.yaml.tpl @@ -0,0 +1,55 @@ +apiVersion: batch/v1 +kind: Job +metadata: + name: dag-workload-generated + annotations: + seccomp.security.alpha.kubernetes.io/pod: "runtime/default" +spec: + template: + spec: + automountServiceAccountToken: false + securityContext: + runAsNonRoot: true + seccompProfile: + type: RuntimeDefault + containers: + - name: generator + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + runAsNonRoot: true + seccompProfile: + type: RuntimeDefault + capabilities: + drop: + - NET_RAW + - ALL + image: 
{{account_id}}.dkr.ecr.{{region}}.amazonaws.com/{{image_name}}:{{image_tag}} + imagePullPolicy: Always + resources: + limits: + cpu: 8000m + memory: 12000Mi + requests: + cpu: 8000m + memory: 12000Mi + command: ["python3", "./dag_client.py", "--generate", "--depth", "3", "--breadth", "2"] + volumeMounts: + - name: agent-config-volume + mountPath: /etc/agent + env: + - name: INTRA_VPC + value: "1" + restartPolicy: Never + nodeSelector: + htc/node-type: core + tolerations: + - effect: NoSchedule + key: htc/node-type + operator: Equal + value: core + volumes: + - name: agent-config-volume + configMap: + name: agent-configmap + backoffLimit: 0 diff --git a/source/client/python/api-v0.1/api/connector.py b/source/client/python/api-v0.1/api/connector.py index ba94c4bc..8b84c3c5 100644 --- a/source/client/python/api-v0.1/api/connector.py +++ b/source/client/python/api-v0.1/api/connector.py @@ -15,6 +15,7 @@ from utils.state_table_common import TASK_STATE_FINISHED from warrant_lite import WarrantLite from apscheduler.schedulers.background import BackgroundScheduler +from api.connector_interface import GridConnectorInterface if os.environ.get("INTRA_VPC"): from privateapi import Configuration, ApiClient, ApiException @@ -51,7 +52,7 @@ def get_safe_session_id(): return str(uuid.uuid1()) -class AWSConnector: +class AWSConnector(GridConnectorInterface): """This class implements the API for managing jobs""" in_out_manager = None diff --git a/source/client/python/api-v0.1/api/connector_interface.py b/source/client/python/api-v0.1/api/connector_interface.py new file mode 100644 index 00000000..099be7a1 --- /dev/null +++ b/source/client/python/api-v0.1/api/connector_interface.py @@ -0,0 +1,48 @@ +# Copyright 2024 Amazon.com, Inc. or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# Licensed under the Apache License, Version 2.0 https://aws.amazon.com/apache-2-0/ + +from abc import ABC, abstractmethod +from typing import Dict, List, Any, Optional + +class GridConnectorInterface(ABC): + """Abstract interface for grid connectors""" + + @abstractmethod + def init(self, config: Dict[str, Any]) -> None: + """Initialize the grid connector + + Args: + config: Configuration dictionary + """ + pass + + @abstractmethod + def authenticate(self) -> None: + """Authenticate with the grid""" + pass + + @abstractmethod + def send(self, task_vector: List[Dict[str, Any]]) -> str: + """Submit a vector of tasks to the grid + + Args: + task_vector: List of task definitions + + Returns: + Submission response object for tracking + """ + pass + + @abstractmethod + def get_results(self, submission_resp: str, timeout_sec: Optional[int] = None) -> Optional[Dict[str, Any]]: + """Get results for a specific submission + + Args: + submission_resp: Submission response object from send() + timeout_sec: Maximum wait time in seconds (0 for non-blocking check) + + Returns: + Results if complete, None if still in progress + """ + pass diff --git a/source/client/python/api-v0.1/api/mock_connector.py b/source/client/python/api-v0.1/api/mock_connector.py new file mode 100644 index 00000000..1d4ddbce --- /dev/null +++ b/source/client/python/api-v0.1/api/mock_connector.py @@ -0,0 +1,191 @@ +# Copyright 2024 Amazon.com, Inc. or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 +# Licensed under the Apache License, Version 2.0 https://aws.amazon.com/apache-2-0/ + +import time +import uuid +import logging +from typing import Dict, List, Any, Optional +from api.connector_interface import GridConnectorInterface + +logger = logging.getLogger("MockGridConnector") + +class MockGridConnector(GridConnectorInterface): + """Mock implementation of the Grid Connector for local testing""" + + def __init__(self): + """Initialize the mock grid connector""" + self.active_submissions = {} # Mapping of submission ID to task status + self.config = None + logger.info("Mock Grid Connector initialized") + + def init(self, config: Dict[str, Any]) -> None: + """Initialize the grid connector + + Args: + config: Configuration dictionary + """ + self.config = config + logger.info("Mock Grid Connector configuration loaded") + + def authenticate(self) -> None: + """Authenticate with the grid""" + # Mock implementation - always succeeds + logger.info("Mock authentication successful") + + def send(self, task_vector: List[Dict[str, Any]]) -> object: + """Submit a vector of tasks to the grid + + Args: + task_vector: List of task definitions + + Returns: + Mock PostSubmitResponse object for tracking + """ + # Generate a unique submission ID + session_id = str(uuid.uuid4()) + + # Generate task IDs in HTC Grid format + task_ids = [f"{session_id}_{i}" for i in range(len(task_vector))] + + submitted_at = time.monotonic() + durations_sec: List[float] = [] + for task in task_vector: + try: + worker_args = task.get("worker_arguments", []) + sleep_ms = int(worker_args[0]) if worker_args else 0 + durations_sec.append(max(0.0, sleep_ms / 1000.0)) + except Exception: + durations_sec.append(0.0) + + # Model a grid-style submission: tasks run "in parallel", and the session completes + # when the longest-running task completes. 
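+        # Illustrative example: two tasks with worker_arguments ["100", ...] and
+        # ["250", ...] submitted at t0 yield task_complete_at = [t0 + 0.10, t0 + 0.25],
+        # so the session completes at t0 + 0.25 (the max).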
+ task_complete_at = [submitted_at + d for d in durations_sec] + session_complete_at = max(task_complete_at) if task_complete_at else submitted_at + + # Store submission info + self.active_submissions[session_id] = { + "status": "processing", + "tasks": task_vector, + "task_ids": task_ids, + "results": None, + "thread": None, + "submitted_at": submitted_at, + "task_complete_at": task_complete_at, + "session_complete_at": session_complete_at, + } + + # Create mock PostSubmitResponse object + class MockPostSubmitResponse: + def __init__(self, session_id, task_ids): + self.session_id = session_id + self.task_ids = task_ids + + def get(self, key, default=None): + """Make it dict-like for compatibility""" + if key == 'session_id': + return self.session_id + elif key == 'task_ids': + return self.task_ids + return default + + def __getitem__(self, key): + """Support dict-like access""" + if key == 'session_id': + return self.session_id + elif key == 'task_ids': + return self.task_ids + raise KeyError(key) + + def __contains__(self, key): + """Support 'in' operator""" + return key in ['session_id', 'task_ids'] + + def __str__(self): + return f"{{'session_id': '{self.session_id}', 'task_ids': {self.task_ids}}}" + + response = MockPostSubmitResponse(session_id, task_ids) + logger.debug(f"Mock submission created: {session_id} with {len(task_vector)} tasks") + return response + + def get_results(self, submission_dict: Dict[str, Any], timeout_sec: Optional[int] = None) -> Optional[object]: + """Get results for a specific submission + + Args: + submission_dict: Dictionary with session_id and task_ids + timeout_sec: Maximum wait time in seconds (0 for non-blocking check) + + Returns: + Mock GetResponse object if complete, None if still in progress + """ + session_id = submission_dict.get('session_id') + if not session_id or session_id not in self.active_submissions: + logger.warning(f"Unknown session ID: {session_id}") + return None + + submission = self.active_submissions[session_id] + now = time.monotonic() + + all_task_ids: List[str] = list(submission.get("task_ids") or []) + + finished_task_ids: List[str] = [] + if submission.get("status") == "complete": + finished_task_ids = all_task_ids + else: + task_complete_at = submission.get("task_complete_at") + if isinstance(task_complete_at, list) and len(task_complete_at) == len(all_task_ids): + finished_task_ids = [ + task_id + for task_id, complete_at in zip(all_task_ids, task_complete_at) + if now >= float(complete_at) + ] + + if len(finished_task_ids) == len(all_task_ids) and all_task_ids: + submission["status"] = "complete" + + # Always return cumulative "finished so far" in exact HTC Grid format. 
+ return { + "cancelled": [], + "cancelled_OUTPUT": [], + "failed": [], + "failed_OUTPUT": [], + "finished": finished_task_ids, + "finished_OUTPUT": ["mock_output"] * len(finished_task_ids), + "metadata": {"tasks_in_response": len(finished_task_ids)}, + } + + def _process_tasks(self, submission_id: str, task_vector: List[Dict[str, Any]], task_ids: List[str]) -> None: + """Process tasks in the background + + Args: + submission_id: Submission ID + task_vector: List of task definitions + task_ids: List of task IDs + """ + results = {} + + for i, task in enumerate(task_vector): + # Extract sleep duration from first parameter + try: + sleep_ms = int(task["worker_arguments"][0]) + + # Mark task as completed + results[f"task_{i}"] = { + "status": "completed", + "result": f"Mock result for task {i}", + "worker_arguments": task["worker_arguments"] + } + + except (KeyError, IndexError, ValueError) as e: + logger.error(f"Error processing mock task {i}: {str(e)}") + results[f"task_{i}"] = { + "status": "failed", + "error": str(e) + } + + # Update submission status + self.active_submissions[submission_id]["status"] = "complete" + self.active_submissions[submission_id]["results"] = results + self.active_submissions[submission_id]["session_complete_at"] = time.monotonic() + + logger.info(f"Mock submission {submission_id} processing completed") diff --git a/source/client/python/utils/setup.py b/source/client/python/utils/setup.py index 010d31b6..414db9c0 100644 --- a/source/client/python/utils/setup.py +++ b/source/client/python/utils/setup.py @@ -1,4 +1,4 @@ -# Copyright 2024 Amazon.com, Inc. or its affiliates. +# Copyright 2024 Amazon.com, Inc. or its affiliates. # SPDX-License-Identifier: Apache-2.0 # Licensed under the Apache License, Version 2.0 https://aws.amazon.com/apache-2-0/ @@ -15,7 +15,7 @@ description="Utilities for testing and profiling the HTC-grid", long_description=long_description, long_description_content_type="text/markdown", - packages=setuptools.find_packages(), + packages=setuptools.find_namespace_packages(include=['utils*']), classifiers=[ "Programming Language :: Python :: 3", "Operating System :: OS Independent", diff --git a/source/client/python/utils/utils/dag/__init__.py b/source/client/python/utils/utils/dag/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/source/client/python/utils/utils/dag/adapters/__init__.py b/source/client/python/utils/utils/dag/adapters/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/source/client/python/utils/utils/dag/adapters/grid_connector_dag_adapter.py b/source/client/python/utils/utils/dag/adapters/grid_connector_dag_adapter.py new file mode 100644 index 00000000..c441288f --- /dev/null +++ b/source/client/python/utils/utils/dag/adapters/grid_connector_dag_adapter.py @@ -0,0 +1,485 @@ +""" +Threaded Grid Connector DAG Adapter. + +Implements a long-lived worker-thread model with: +- an incoming tasks queue (enqueue-only from submit_tasks) +- a completed results queue (drain-only from poll_completed) + +See docs/THREADED_CONNECTOR_ADAPTER_SPEC.md for the intended behavior. 
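+
+Typical wiring (illustrative sketch; `dag_client.py` is the real composition root):
+
+    adapter = GridConnectorDagAdapter(config, logger)
+    adapter.submit_tasks(ready_node_ids, dag_container)   # enqueue, returns quickly
+    completed = adapter.poll_completed()                  # drain finished node IDs
+    adapter.shutdown(wait=True)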
+""" + +from __future__ import annotations + +import json +import logging +import threading +import time +from collections import deque +from dataclasses import dataclass +from typing import Any, Deque, Dict, Hashable, List, Optional, Set + +from utils.dag.grid_connector_factory import BaseGridConnectorFactory, GridConnectorFactory +from utils.dag.base.base_dag import BaseDAG + + +@dataclass(frozen=True) +class QueuedTask: + node_id: Hashable + task_definition: Dict[str, Any] + + +@dataclass +class SessionState: + session_id: str + + # full response from the connector + submission_handle: Any + + # Mapping from grid task IDs (what the connector returns in results["finished"]) back to + # DAG node IDs (what the scheduler understands). + grid_to_dag: Dict[str, Hashable] + + # The same grid task ID can appear on every get_results() call. + # remaining_grid_task_ids lets the worker ignore already-processed IDs + # and avoid pushing the same node_id into _results + remaining_grid_task_ids: Set[str] + + # Mapping from DAG node IDs back to their task definitions for requeueing. + task_definitions_by_node: Dict[Hashable, Dict[str, Any]] + + poll_failures: int = 0 + + +@dataclass(frozen=True) +class WorkerError: + thread_id: int + stage: str + error: Exception + when_ts: float + + +class GridConnectorDagAdapter: + """Threaded connector adapter skeleton based on THREADED_CONNECTOR_ADAPTER_SPEC.md.""" + + def __init__( + self, + config: Dict[str, Any], + logger: logging.Logger, + connector_factory: Optional[BaseGridConnectorFactory] = None, + ) -> None: + self._logger = logger + self.config = config + self._connector_factory: BaseGridConnectorFactory = connector_factory or GridConnectorFactory( + config=self.config, + ) + + self.retry_attempts = int(config.get("retry_attempts", 3)) + if self.retry_attempts <= 0: + raise ValueError(f"retry_attempts must be > 0 (got {self.retry_attempts})") + + self.num_worker_threads = int(config.get("num_connector_threads", 2)) + if self.num_worker_threads <= 0: + raise ValueError(f"num_worker_threads must be > 0 (got {self.num_worker_threads})") + self.max_dequeue_per_loop = int(config.get("max_dequeue_per_loop", 100)) + if self.max_dequeue_per_loop <= 0: + raise ValueError(f"max_dequeue_per_loop must be > 0 (got {self.max_dequeue_per_loop})") + self.poll_timeout_sec = float(config.get("poll_timeout_sec", 0.01)) + if self.poll_timeout_sec < 0: + raise ValueError(f"poll_timeout_sec must be >= 0 (got {self.poll_timeout_sec})") + self.poll_interval_sec = float(config.get("poll_interval_sec", 0.1)) + if self.poll_interval_sec < 0: + raise ValueError(f"poll_interval_sec must be >= 0 (got {self.poll_interval_sec})") + + self.max_poll_failures = int(config.get("max_poll_failures", 3)) + if self.max_poll_failures <= 0: + raise ValueError(f"max_poll_failures must be > 0 (got {self.max_poll_failures})") + + errors_maxlen = int(config.get("errors_maxlen", 1000)) + if errors_maxlen < 0: + raise ValueError(f"errors_maxlen must be >= 0 (got {errors_maxlen})") + self._stop_event = threading.Event() # signals worker shutdown + + self._threads: List[threading.Thread] = [] # worker thread handles + + self._tasks: Deque[QueuedTask] = deque() # pending task queue + self._results: Deque[Hashable] = deque() # completed DAG node IDs + self._tasks_condition = threading.Condition() # protects task queue + self._results_condition = threading.Condition() # protects results queue + self._queued_or_inflight: Set[Hashable] = set() # de-dup guard for node IDs + + self._active_sessions_count = 0 # 
global active session count + self._active_sessions_lock = threading.Lock() # protects active count + + self._errors_maxlen = errors_maxlen # cap for stored worker errors + self._errors: Deque[WorkerError] = deque(maxlen=self._errors_maxlen) # recent errors + self._errors_lock = threading.Lock() # protects error deque + self._fatal_error: Optional[WorkerError] = None # first fatal error + + self._start_workers() + + self._logger.info( + "Threaded adapter initialized " + f"(workers={self.num_worker_threads}, max_dequeue_per_loop={self.max_dequeue_per_loop})" + ) + + def _start_workers(self) -> None: + for thread_id in range(self.num_worker_threads): + t = threading.Thread( + target=self._worker_main_loop, + args=(thread_id,), + name=f"GridConnectorDagAdapter-W{thread_id}", + daemon=True, + ) + t.start() + self._threads.append(t) + + def shutdown(self, wait: bool = True, timeout: Optional[float] = None) -> None: + self._stop_event.set() + with self._tasks_condition: + self._tasks_condition.notify_all() + with self._results_condition: + self._results_condition.notify_all() + + if not wait: + return + + deadline = None if timeout is None else (time.time() + timeout) + for t in self._threads: + join_timeout = None if deadline is None else max(0.0, deadline - time.time()) + t.join(timeout=join_timeout) + + def __del__(self) -> None: + try: + self.shutdown(wait=False) + except Exception: + pass + + def submit_tasks(self, node_ids: List[str], dag_container: BaseDAG) -> Dict[Any, Dict[str, str]]: + """ + Enqueue tasks for worker threads. + + - Marks tasks as submitted on the scheduler thread to avoid re-queueing. + - Builds task definitions on the scheduler thread to avoid multi-threaded DAG access. + """ + self.raise_if_failed() + if not node_ids: + return {} + + items: List[QueuedTask] = [] + + for node_id in node_ids: + task_definition = dag_container.build_grid_task(node_id) + items.append(QueuedTask(node_id=node_id, task_definition=task_definition)) + + enqueued_ids: List[Hashable] = [] + with self._tasks_condition: + for item in items: + if item.node_id in self._queued_or_inflight: + continue + self._queued_or_inflight.add(item.node_id) + self._tasks.append(item) + enqueued_ids.append(item.node_id) + + # notify threads that new tasks are ready to be submitted to the grid + if enqueued_ids: + self._tasks_condition.notify_all() + + if enqueued_ids and hasattr(dag_container, "mark_task_submitted"): + for node_id in enqueued_ids: + dag_container.mark_task_submitted(node_id) # type: ignore[attr-defined] + + if enqueued_ids: + self._logger.info( + f"Enqueued {len(enqueued_ids)}/{len(node_ids)} tasks" + ) + return {} + + def poll_completed(self) -> List[Hashable]: + """Drain completed DAG task IDs from the results queue.""" + self.raise_if_failed() + completed: List[Hashable] = [] + with self._results_condition: + while self._results: + completed.append(self._results.popleft()) + if completed: + self._results_condition.notify_all() + if completed: + with self._tasks_condition: + for node_id in completed: + self._queued_or_inflight.discard(node_id) + return completed + + def active_count(self) -> int: + """Return number of active submissions across all workers.""" + with self._active_sessions_lock: + return self._active_sessions_count + + def get_errors(self) -> List[WorkerError]: + """Return a snapshot of worker errors (if any).""" + with self._errors_lock: + return list(self._errors) + + def raise_if_failed(self) -> None: + """Raise if a worker hit a fatal error or all workers stopped.""" + if 
self._fatal_error is not None: + err = self._fatal_error + raise RuntimeError(f"Threaded adapter worker failed (thread={err.thread_id}, stage={err.stage}): {err.error}") + + if self._threads and not any(t.is_alive() for t in self._threads): + raise RuntimeError("Threaded adapter has no live worker threads") + + def _record_error(self, thread_id: int, stage: str, exc: Exception) -> None: + with self._errors_lock: + err = WorkerError(thread_id=thread_id, stage=stage, error=exc, when_ts=time.time()) + self._errors.append(err) + if stage in {"worker_crash"} and self._fatal_error is None: + self._fatal_error = err + + def _requeue_session_tasks(self, thread_id: int, session_obj: SessionState) -> None: + remaining_node_ids = [ + session_obj.grid_to_dag[grid_tid] + for grid_tid in session_obj.remaining_grid_task_ids + if grid_tid in session_obj.grid_to_dag + ] + if not remaining_node_ids: + return + + with self._tasks_condition: + for node_id in remaining_node_ids: + task_definition = session_obj.task_definitions_by_node.get(node_id) + if task_definition is None: + self._record_error(thread_id, "missing_task_definition", KeyError(node_id)) + continue + self._tasks.append(QueuedTask(node_id=node_id, task_definition=task_definition)) + self._tasks_condition.notify_all() + + @staticmethod + def _extract_task_ids(submission_resp: Any) -> List[str]: + if submission_resp is None: + return [] + if isinstance(submission_resp, str): + try: + parsed = json.loads(submission_resp) + if isinstance(parsed, dict) and "task_ids" in parsed: + return parsed.get("task_ids") or [] + except Exception: + return [] + if isinstance(submission_resp, dict): + return submission_resp.get("task_ids") or [] + task_ids = getattr(submission_resp, "task_ids", None) + if task_ids: + return list(task_ids) + if hasattr(submission_resp, "get"): + try: + return submission_resp.get("task_ids", []) # type: ignore[no-any-return] + except Exception: + return [] + return [] + + @staticmethod + def _extract_session_id(submission_resp: Any) -> Optional[str]: + if submission_resp is None: + return None + if isinstance(submission_resp, str): + try: + parsed = json.loads(submission_resp) + if isinstance(parsed, dict) and "session_id" in parsed: + return parsed.get("session_id") + except Exception: + return submission_resp + if isinstance(submission_resp, dict): + return submission_resp.get("session_id") + sess = getattr(submission_resp, "session_id", None) + if sess: + return str(sess) + if hasattr(submission_resp, "get"): + try: + return submission_resp.get("session_id", None) # type: ignore[no-any-return] + except Exception: + return None + return None + + @staticmethod + def _extract_finished_task_ids(grid_response: Any) -> List[str]: + if not grid_response: + return [] + if isinstance(grid_response, dict): + finished = grid_response.get("finished") or [] + return [str(x) for x in finished] + if hasattr(grid_response, "finished"): + return [str(x) for x in getattr(grid_response, "finished", [])] + if hasattr(grid_response, "get"): + try: + finished = grid_response.get("finished") or [] # type: ignore[assignment] + return [str(x) for x in finished] + except Exception: + return [] + return [] + + def _worker_main_loop(self, thread_id: int) -> None: + """ + Worker thread main loop. + + Responsibilities: + <1.> Creates a thread-local grid connector + <2.> Repeatedly dequeue up to `max_dequeue_per_loop` queued tasks and submit them as one grid session. + <3.> Track the submitted session locally (`active_sessions`) and poll it for finished tasks. 
+    def _worker_main_loop(self, thread_id: int) -> None:
+        """
+        Worker thread main loop.
+
+        Responsibilities:
+        <1.> Create a thread-local grid connector.
+        <2.> Repeatedly dequeue up to `max_dequeue_per_loop` queued tasks and submit them as one grid session.
+        <3.> Track the submitted session locally (`active_sessions`) and poll it for finished tasks.
+        <4.> For each finished grid task ID, map it back to the DAG `node_id` and push it into the shared results queue.
+
+        Concurrency model:
+        - Uses `self._tasks_condition` to wait for new tasks and to dequeue atomically.
+        - Uses `self._results_condition` to push completed node IDs for `poll_completed()` to drain.
+        - Stops when `self._stop_event` is set; `shutdown()` wakes sleepers via `notify_all()`.
+
+        Error handling:
+        - If a submission fails after all retries, its tasks are returned to the queue.
+        """
+        logger = logging.getLogger(f"GridConnectorDagAdapter-W{thread_id}")
+
+        # <1.> Create a thread-local grid connector
+        try:
+            connector = self._connector_factory.create(thread_id=thread_id, logger=logger)
+            logger.info(f"Thread {thread_id}: initialized")
+        except Exception as e:
+            self._record_error(thread_id, "init_connector", e)
+            logger.error(f"Thread {thread_id}: failed to initialize connector: {e}", exc_info=True)
+            return
+
+        active_sessions: Dict[str, SessionState] = {}
+        local_session_counter = 0
+
+        try:
+            while not self._stop_event.is_set():
+                new_items: List[QueuedTask] = []
+
+                # <2.A> Dequeue up to `max_dequeue_per_loop` queued tasks
+                with self._tasks_condition:
+                    if not self._tasks and not active_sessions and not self._stop_event.is_set():
+                        self._tasks_condition.wait(timeout=self.poll_interval_sec)
+
+                    while self._tasks and len(new_items) < self.max_dequeue_per_loop:
+                        new_items.append(self._tasks.popleft())
+
+                # <2.B> Submit new items
+                if new_items:
+                    batch = new_items
+                    task_definitions = [t.task_definition for t in batch]
+                    node_ids = [t.node_id for t in batch]
+
+                    submission_handle = None
+                    # Stores the last exception raised while submitting tasks, so it can be logged if all retry attempts fail
+                    last_exc: Optional[Exception] = None
+
+                    for attempt in range(self.retry_attempts):
+                        try:
+                            submission_handle = connector.send(task_definitions)
+                            logger.info(f"Thread {thread_id}: submitted {len(task_definitions)} tasks")
+                            break
+                        except Exception as e:
+                            last_exc = e
+                            self._record_error(thread_id, "send", e)
+                            if attempt < self.retry_attempts - 1:
+                                time.sleep(1.0)
+
+                    if submission_handle is None:
+                        if last_exc:
+                            logger.error(f"Thread {thread_id}: failed to submit batch: {last_exc}", exc_info=True)
+                        with self._tasks_condition:
+                            for t in batch:
+                                self._tasks.appendleft(t)
+                            self._tasks_condition.notify_all()
+                        continue
+
+                    session_id = self._extract_session_id(submission_handle)
+                    task_ids = self._extract_task_ids(submission_handle)
+
+                    if not session_id:
+                        local_session_counter += 1
+                        session_id = f"t{thread_id}-{int(time.time())}-{local_session_counter}"
+
+                    if not task_ids or len(task_ids) != len(node_ids):
+                        self._record_error(
+                            thread_id,
+                            "invalid_submission_response",
+                            ValueError(f"Expected {len(node_ids)} task_ids, got {len(task_ids)}"),
+                        )
+                        with self._tasks_condition:
+                            for t in batch:
+                                self._tasks.appendleft(t)
+                            self._tasks_condition.notify_all()
+                        continue
+
+                    grid_to_dag: Dict[str, Hashable] = {}
+                    task_definitions_by_node = {t.node_id: t.task_definition for t in batch}
+                    for node_id, grid_tid in zip(node_ids, task_ids):
+                        grid_to_dag[str(grid_tid)] = node_id
+
+                    remaining = set(grid_to_dag.keys())
+                    active_sessions[session_id] = SessionState(
+                        session_id=session_id,
+                        submission_handle=submission_handle,
+                        grid_to_dag=grid_to_dag,
+                        remaining_grid_task_ids=remaining,
+                        task_definitions_by_node=task_definitions_by_node,
+                    )
+
+                    with self._active_sessions_lock:
+                        self._active_sessions_count += 1
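+
+                # Clarifying note: from this point the session lives only in this
+                # worker's local `active_sessions`; the shared `_active_sessions_count`
+                # gauge exists so `active_count()` can report in-flight submissions
+                # to the scheduler without locking per-session state.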
+                # <3.> Iterate over all in-flight grid submissions owned by this worker thread:
+                #      - if a session has newly completed tasks, push them to the results queue
+                #      - if all tasks of a session have completed, drop it from `active_sessions`
+                for session_id, session_obj in list(active_sessions.items()):
+                    try:
+                        results = connector.get_results(session_obj.submission_handle, timeout_sec=self.poll_timeout_sec)
+                    except Exception as e:
+                        self._record_error(thread_id, "get_results", e)
+                        session_obj.poll_failures += 1
+                        if session_obj.poll_failures >= self.max_poll_failures:
+                            logger.error(
+                                "Thread %s: session %s exceeded poll failure limit; requeueing remaining tasks",
+                                thread_id,
+                                session_id,
+                            )
+                            self._requeue_session_tasks(thread_id, session_obj)
+                            del active_sessions[session_id]
+                            with self._active_sessions_lock:
+                                self._active_sessions_count = max(0, self._active_sessions_count - 1)
+                        continue
+                    else:
+                        session_obj.poll_failures = 0
+
+                    finished_grid_task_ids = self._extract_finished_task_ids(results)
+                    if not finished_grid_task_ids:
+                        continue
+
+                    for grid_task_id in finished_grid_task_ids:
+                        # This task was already returned by a previous poll and pushed to the results queue.
+                        if grid_task_id not in session_obj.remaining_grid_task_ids:
+                            continue
+
+                        node_id = session_obj.grid_to_dag.get(grid_task_id)
+                        if node_id is None:
+                            self._record_error(thread_id, "map_grid_to_dag", KeyError(grid_task_id))
+                            continue
+
+                        session_obj.remaining_grid_task_ids.remove(grid_task_id)
+
+                        with self._results_condition:
+                            if self._stop_event.is_set():
+                                return
+                            self._results.append(node_id)
+                            self._results_condition.notify_all()
+
+                    if not session_obj.remaining_grid_task_ids:
+                        del active_sessions[session_id]
+                        with self._active_sessions_lock:
+                            self._active_sessions_count = max(0, self._active_sessions_count - 1)
+
+                if not new_items and active_sessions and self.poll_interval_sec:
+                    time.sleep(self.poll_interval_sec)
+        except Exception as e:
+            self._record_error(thread_id, "worker_crash", e)
+            logger.error(f"Thread {thread_id}: worker crashed: {e}", exc_info=True)
+            return
diff --git a/source/client/python/utils/utils/dag/base/__init__.py b/source/client/python/utils/utils/dag/base/__init__.py
new file mode 100644
index 00000000..16593d74
--- /dev/null
+++ b/source/client/python/utils/utils/dag/base/__init__.py
@@ -0,0 +1,6 @@
+"""
+Package for DAG-related classes and adapters.
+"""
+
+from .base_dag import BaseDAG  # noqa: F401
+from .nx_htc_dag_container import NxHtcDagContainer  # noqa: F401
diff --git a/source/client/python/utils/utils/dag/base/base_dag.py b/source/client/python/utils/utils/dag/base/base_dag.py
new file mode 100644
index 00000000..21f0ce7e
--- /dev/null
+++ b/source/client/python/utils/utils/dag/base/base_dag.py
@@ -0,0 +1,61 @@
+# Copyright 2024 Amazon.com, Inc. or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# Licensed under the Apache License, Version 2.0 https://aws.amazon.com/apache-2-0/
+
+"""
+Base DAG abstraction used by the DAG processor.
+
+This skeleton aligns with the DDD design in `docs/project/user_guide/DAG_client_side_architecture.md`:
+it defines the stable contract that domain-specific DAG adapters must implement
+so the processor can reason about ready nodes, mark progress, and detect
+completion without depending on any concrete DAG library.
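+
+A minimal illustrative sketch (an assumption for documentation purposes, not a
+shipped implementation): a dict-backed adapter could satisfy this contract as
+follows::
+
+    class DictDAG(BaseDAG):
+        def __init__(self, deps):  # deps: {node_id: set of prerequisite node_ids}
+            self._deps = deps
+            self._done = set()
+
+        def get_nodes_with_resolved_dependencies(self):
+            # Ready = not yet completed and every prerequisite completed
+            return [n for n, d in self._deps.items()
+                    if n not in self._done and d <= self._done]
+
+        def mark_node_completed(self, node_id):
+            self._done.add(node_id)
+
+        def is_dag_completed(self):
+            return len(self._done) == len(self._deps)
+
+        def get_node_by_id(self, node_id):
+            return self._deps.get(node_id)
+
+        def build_grid_task(self, node_id):
+            # Payload shape expected by the grid connector `send()` call
+            return {"worker_arguments": ["100", "1", "1"]}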
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, Hashable, Iterable + + +class BaseDAG(ABC): + """Minimal DAG contract for the processor to interact with domain DAGs.""" + + @abstractmethod + def get_nodes_with_resolved_dependencies(self) -> Iterable[Hashable]: + """ + Return an iterable of node identifiers whose dependencies are fully satisfied + and are ready for execution/submission. + """ + + @abstractmethod + def mark_node_completed(self, node_id: Hashable) -> None: + """ + Mark the given node as completed so downstream dependency checks can be updated. + """ + + @abstractmethod + def is_dag_completed(self) -> bool: + """ + Return True when all nodes have been processed/completed, otherwise False. + """ + + @abstractmethod + def get_node_by_id(self, node_id: Hashable): + """ + Return the node object/data for the given node identifier. + + Args: + node_id: The identifier of the node to retrieve + + Returns: + The node object or data associated with the given ID, or None if not found + + """ + + @abstractmethod + def build_grid_task(self, node_id: Hashable) -> Dict[str, Any]: + """ + Build and return a grid connector task definition for the given DAG node. + + The returned object must be suitable to pass to the grid connector `send()` call. + """ diff --git a/source/client/python/utils/utils/dag/base/dag_generator.py b/source/client/python/utils/utils/dag/base/dag_generator.py new file mode 100644 index 00000000..21386feb --- /dev/null +++ b/source/client/python/utils/utils/dag/base/dag_generator.py @@ -0,0 +1,391 @@ +# Copyright 2024 Amazon.com, Inc. or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# Licensed under the Apache License, Version 2.0 https://aws.amazon.com/apache-2-0/ + + +""" +DAG Generator: Generate test DAGs with configurable structure +""" + +import logging +import networkx as nx +import random +from typing import Dict, Any + +logger = logging.getLogger("DAGGenerator") + + +class DAGGenerator: + """Generates test DAGs with variable depth and breadth""" + + def __init__(self, dag_config: Dict[str, Any], seed: int = None): + """Initialize DAG generator + + Args: + dag_config: DAG-related configuration + seed: Random seed for reproducible generation + """ + self.config = dag_config + if seed is not None: + random.seed(seed) + + logger.debug("DAG Generator initialized") + + def generate_dag(self, depth: int, breadth: int) -> nx.DiGraph: + """Generate a DAG with specified depth and breadth + + Args: + depth: Number of levels in the DAG + breadth: Maximum number of children per node + + Returns: + Generated NetworkX DiGraph + """ + if depth < 1: + raise ValueError("Depth must be at least 1") + if breadth < 1: + raise ValueError("Breadth must be at least 1") + + logger.info(f"Generating DAG with depth={depth}, breadth={breadth}") + + dag = nx.DiGraph() + node_counter = 0 + + # Generate nodes level by level + current_level = [] + + for level in range(depth): + next_level = [] + + if level == 0: + # Root level - create only 1 root node + num_nodes = 1 + for i in range(num_nodes): + node_id = f"level_{level}_node_{i}" + task_type = "compute" if level == depth - 1 else "aggregation" + + dag.add_node( + node_id, + task_type=task_type, + worker_arguments=self._generate_task_arguments(), + status="pending" + ) + current_level.append(node_id) + node_counter += 1 + + logger.debug(f"Level {level}: Created {len(current_level)} root nodes") + # Don't set current_level = next_level for root level, keep the root 
+                continue
+            else:
+                # Create children for each node in the current level
+                if not current_level:
+                    logger.debug(f"Level {level}: No parent nodes, stopping")
+                    break
+
+                for parent_node in current_level:
+                    # Each parent gets exactly `breadth` children
+                    num_children = breadth
+                    logger.debug(f"Level {level}: Creating {num_children} children for {parent_node}")
+
+                    for i in range(num_children):
+                        child_id = f"level_{level}_node_{node_counter}"
+                        task_type = "compute" if level == depth - 1 else "aggregation"
+
+                        dag.add_node(
+                            child_id,
+                            task_type=task_type,
+                            worker_arguments=self._generate_task_arguments(),
+                            status="pending"
+                        )
+
+                        # Add edge from child to parent (child must complete before parent)
+                        dag.add_edge(child_id, parent_node)
+
+                        next_level.append(child_id)
+                        node_counter += 1
+
+                logger.debug(f"Level {level}: Created {len(next_level)} child nodes")
+
+            current_level = next_level
+
+        # Validate generated DAG
+        if not nx.is_directed_acyclic_graph(dag):
+            logger.error("Generated graph contains cycles!")
+            raise RuntimeError("Generated invalid DAG with cycles")
+
+        logger.info(f"Generated DAG with {len(dag.nodes())} nodes and {len(dag.edges())} edges")
+
+        # Log DAG statistics
+        self._log_dag_statistics(dag)
+
+        # Print DAG structure to stdout
+        self._print_dag_structure(dag)
+
+        return dag
+
+    def _generate_task_arguments(self) -> list:
+        """Generate task arguments from the configuration
+
+        Returns:
+            List of three string arguments: [sleep_time_ms, "1", "1"]
+        """
+        sleep_ms = self.config.get("mock_tasks_sleep_time_ms", 100)
+        param1 = str(sleep_ms)
+        return [param1, "1", "1"]
+
+    def _log_dag_statistics(self, dag: nx.DiGraph):
+        """Log statistics about the generated DAG
+
+        Args:
+            dag: DAG to analyze
+        """
+        try:
+            # Count node types
+            compute_nodes = sum(1 for _, data in dag.nodes(data=True)
+                                if data.get('task_type') == 'compute')
+            aggregation_nodes = sum(1 for _, data in dag.nodes(data=True)
+                                    if data.get('task_type') == 'aggregation')
+
+            # Calculate depth (longest path)
+            if dag.nodes():
+                # Find nodes with no predecessors (roots)
+                roots = [n for n in dag.nodes() if dag.in_degree(n) == 0]
+                max_depth = 0
+
+                for root in roots:
+                    for node in dag.nodes():
+                        if nx.has_path(dag, root, node):
+                            try:
+                                path_length = nx.shortest_path_length(dag, root, node)
+                                max_depth = max(max_depth, path_length + 1)
+                            except nx.NetworkXNoPath:
+                                pass
+            else:
+                max_depth = 0
+
+            logger.info("DAG Statistics:")
+            logger.info(f"  Total nodes: {len(dag.nodes())}")
+            logger.info(f"  Total edges: {len(dag.edges())}")
+            logger.info(f"  Compute nodes: {compute_nodes}")
+            logger.info(f"  Aggregation nodes: {aggregation_nodes}")
+            logger.info(f"  Actual depth: {max_depth}")
+
+        except Exception as e:
+            logger.warning(f"Failed to calculate DAG statistics: {str(e)}")
+
+    def _print_dag_structure(self, dag: nx.DiGraph):
+        """Print DAG structure in a tree-like format showing execution flow
+
+        Args:
+            dag: DAG to visualize
+        """
+        print("\n🌳 DAG STRUCTURE")
+        print("=" * 50)
+
+        if not dag.nodes():
+            print("📭 Empty DAG")
+            return
+
+        # Find leaf nodes (no outgoing edges) - these execute first
+        leaf_nodes = [n for n in dag.nodes() if dag.out_degree(n) == 0]
+        # Find root node (should be level_0_node_0)
+        root_nodes = [n for n in dag.nodes() if n.startswith('level_0_')]
+
+        # Print summary
+        total_nodes = len(dag.nodes())
+        total_edges = len(dag.edges())
+        print(f"📊 Summary: {total_nodes} nodes, {total_edges} edges")
+        print(f"🎯 Execution: {len(leaf_nodes)} leaf nodes → 1 root node")
+
print() + + # Show the tree structure from root down (logical structure) + if root_nodes: + root_node = root_nodes[0] # Should be level_0_node_0 + print("🌱 DAG Tree (Root → Leaves):") + visited = set() + self._print_dependency_tree(dag, root_node, "", True, visited) + print() + + print("⚡ Execution Order (leaf nodes execute first):") + # Group nodes by level for execution order display + levels = self._get_execution_levels(dag) + for level_num, level_nodes in enumerate(levels): + level_nodes_sorted = sorted(level_nodes) + if level_num == 0: + print(f" Level {level_num + 1} (Execute First): {level_nodes_sorted}") + elif level_num == len(levels) - 1: + print(f" Level {level_num + 1} (Execute Last): {level_nodes_sorted}") + else: + print(f" Level {level_num + 1}: {level_nodes_sorted}") + + print("=" * 50) + print() + + def _print_dependency_tree(self, dag: nx.DiGraph, node: str, prefix: str, is_last: bool, visited: set): + """Print dependency tree from root to leaves + + Args: + dag: The DAG + node: Current node to print + prefix: Current line prefix for indentation + is_last: Whether this is the last child at this level + visited: Set of already visited nodes + """ + if node in visited: + return + visited.add(node) + + # Get node info + node_data = dag.nodes[node] + task_type = node_data.get('task_type', 'unknown') + worker_args = node_data.get('worker_arguments', []) + exec_time = worker_args[0] if worker_args else "?" + + # Choose appropriate symbols + connector = "└── " if is_last else "├── " + type_icon = "🔄" if task_type == "aggregation" else "⚡" + + # Print current node + print(f"{prefix}{connector}{type_icon} {node} ({exec_time}ms)") + + # Get dependencies (nodes that must complete before this one) + dependencies = sorted(list(dag.predecessors(node))) + + # Print dependencies + for i, dep in enumerate(dependencies): + is_dep_last = (i == len(dependencies) - 1) + dep_prefix = prefix + (" " if is_last else "│ ") + self._print_dependency_tree(dag, dep, dep_prefix, is_dep_last, visited) + + def _get_execution_levels(self, dag: nx.DiGraph): + """Get nodes grouped by execution level (topological sort levels) + + Args: + dag: The DAG to analyze + + Returns: + List of lists, where each inner list contains nodes at that execution level + """ + # Create a copy to avoid modifying original + dag_copy = dag.copy() + levels = [] + + while dag_copy.nodes(): + # Find nodes with no incoming edges (ready to execute) + ready_nodes = [n for n in dag_copy.nodes() if dag_copy.in_degree(n) == 0] + + if not ready_nodes: + # Should not happen in a valid DAG + break + + levels.append(ready_nodes) + + # Remove these nodes and their edges + dag_copy.remove_nodes_from(ready_nodes) + + return levels + + def _print_tree_recursive(self, dag: nx.DiGraph, node: str, prefix: str, is_last: bool, visited: set): + """Recursively print tree structure + + Args: + dag: The DAG + node: Current node to print + prefix: Current line prefix for indentation + is_last: Whether this is the last child at this level + visited: Set of already visited nodes + """ + if node in visited: + return + visited.add(node) + + # Get node info + node_data = dag.nodes[node] + task_type = node_data.get('task_type', 'unknown') + worker_args = node_data.get('worker_arguments', []) + exec_time = worker_args[0] if worker_args else "?" 
+ + # Choose appropriate symbols + connector = "└── " if is_last else "├── " + type_icon = "🔄" if task_type == "aggregation" else "⚡" + + # Print current node + print(f"{prefix}{connector}{type_icon} {node} ({exec_time}ms)") + + # Get children + children = sorted(list(dag.successors(node))) + + # Print children + for i, child in enumerate(children): + is_child_last = (i == len(children) - 1) + child_prefix = prefix + (" " if is_last else "│ ") + self._print_tree_recursive(dag, child, child_prefix, is_child_last, visited) + + def generate_linear_dag(self, num_tasks: int) -> nx.DiGraph: + """Generate a linear DAG (chain of tasks) + + Args: + num_tasks: Number of tasks in the chain + + Returns: + Linear DAG + """ + logger.info(f"Generating linear DAG with {num_tasks} tasks") + + dag = nx.DiGraph() + + for i in range(num_tasks): + task_id = f"task_{i}" + task_type = "compute" # All tasks are compute tasks in linear chain + + dag.add_node( + task_id, + task_type=task_type, + worker_arguments=self._generate_task_arguments(), + status="pending" + ) + + # Add edge from previous task + if i > 0: + dag.add_edge(f"task_{i - 1}", task_id) + + logger.info(f"Generated linear DAG with {len(dag.nodes())} nodes and {len(dag.edges())} edges") + return dag + + def generate_fan_out_dag(self, num_leaf_tasks: int) -> nx.DiGraph: + """Generate a fan-out DAG (one root with multiple children) + + Args: + num_leaf_tasks: Number of leaf tasks + + Returns: + Fan-out DAG + """ + logger.info(f"Generating fan-out DAG with {num_leaf_tasks} leaf tasks") + + dag = nx.DiGraph() + + # Add root task + root_id = "root_task" + dag.add_node( + root_id, + task_type="aggregation", + worker_arguments=self._generate_task_arguments(), + status="pending" + ) + + # Add leaf tasks + for i in range(num_leaf_tasks): + leaf_id = f"leaf_task_{i}" + dag.add_node( + leaf_id, + task_type="compute", + worker_arguments=self._generate_task_arguments(), + status="pending" + ) + + # Add edge from root to leaf + dag.add_edge(root_id, leaf_id) + + logger.info(f"Generated fan-out DAG with {len(dag.nodes())} nodes and {len(dag.edges())} edges") + return dag diff --git a/source/client/python/utils/utils/dag/base/nx_htc_dag_container.py b/source/client/python/utils/utils/dag/base/nx_htc_dag_container.py new file mode 100644 index 00000000..9a342d3c --- /dev/null +++ b/source/client/python/utils/utils/dag/base/nx_htc_dag_container.py @@ -0,0 +1,477 @@ +# Copyright 2024 Amazon.com, Inc. or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# Licensed under the Apache License, Version 2.0 https://aws.amazon.com/apache-2-0/ + +""" +DAG Manager: NetworkX-based DAG management and dependency tracking +""" + +import logging +import threading +from typing import Any, Dict, Hashable, List, Optional, Set + +import networkx as nx + +from .base_dag import BaseDAG + +logger = logging.getLogger("NxHtcDagContainer") + + +class NxHtcDagContainer(BaseDAG): + """Manages DAG structure and task state using NetworkX""" + + def __init__(self, nx_dag: Optional[nx.DiGraph] = None): + """Initialize DAG manager.""" + self.dag: Optional[nx.DiGraph] = None + self._lock = threading.Lock() # Thread safety for state updates + logger.debug("DAG Manager initialized") + + if nx_dag: + self.set_dag(nx_dag) + + def get_nodes_with_resolved_dependencies(self) -> List[Hashable]: + """ + BaseDAG implementation: return nodes whose dependencies are satisfied. 
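+
+        Delegates to `get_ready_tasks()`.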
+ """ + return self.get_ready_tasks() + + def mark_node_completed(self, node_id: Hashable) -> bool: + """ + BaseDAG implementation: mark a node as completed. + """ + return self.mark_task_complete(node_id) + + def is_dag_completed(self) -> bool: + """ + BaseDAG implementation: return True when all nodes are completed. + """ + return self.is_dag_complete() + + def get_node_by_id(self, node_id: Hashable): + """ + BaseDAG implementation: return the node object/data for the given node identifier. + + Args: + node_id: The identifier of the node to retrieve + + Returns: + The node data dictionary associated with the given ID, or None if not found + """ + if not self.dag: + return None + + if node_id not in self.dag.nodes(): + return None + + # with self._lock: + return dict(self.dag.nodes[node_id]) + + def build_grid_task(self, node_id: Hashable) -> Dict[str, Any]: + """ + BaseDAG implementation: build a grid task definition for a DAG node. + + Currently this adapter maps DAG node metadata into the connector payload shape: + {"worker_arguments": [...]} + """ + node_data = self.get_node_by_id(node_id) + if node_data is None: + raise KeyError(f"Task {node_id} not found in DAG") + return {"worker_arguments": node_data.get("worker_arguments", ["1000", "1", "1"])} + + ########################################################################################################### + ########################################################################################################### + ########################################################################################################### + + def set_dag(self, dag: nx.DiGraph) -> bool: + """Set the DAG to be processed + + Args: + dag: NetworkX DiGraph representing the task DAG + + Returns: + True if DAG was set successfully, False otherwise + """ + try: + print("Setting DAG----") + with self._lock: + # Validate DAG structure + if not self._validate_dag(dag): + logger.error("Invalid DAG structure") + return False + + self.dag = dag.copy() + + # Initialize task states if not already set + for node_id in self.dag.nodes(): + if 'status' not in self.dag.nodes[node_id]: + self.dag.nodes[node_id]['status'] = 'pending' + + logger.info(f"DAG set with {len(self.dag.nodes())} nodes and {len(self.dag.edges())} edges") + return True + + except Exception as e: + logger.error(f"Failed to set DAG: {str(e)}") + return False + + def _validate_dag(self, dag: nx.DiGraph) -> bool: + """Validate DAG structure + + Args: + dag: DAG to validate + + Returns: + True if DAG is valid, False otherwise + """ + try: + # Check if it's a DAG (no cycles) + if not nx.is_directed_acyclic_graph(dag): + logger.error("Graph contains cycles - not a valid DAG") + return False + + # Check if all nodes have required attributes + for node_id, node_data in dag.nodes(data=True): + if 'task_type' not in node_data: + logger.error(f"Node {node_id} missing task_type attribute") + print(node_data) + return False + + if 'worker_arguments' not in node_data: + logger.error(f"Node {node_id} missing worker_arguments attribute") + return False + + # Validate worker_arguments format + args = node_data['worker_arguments'] + if not isinstance(args, list) or len(args) != 3: + logger.error(f"Node {node_id} has invalid worker_arguments format") + return False + + return True + + except Exception as e: + logger.error(f"DAG validation failed: {str(e)}") + return False + + def get_ready_tasks(self) -> List[str]: + """Get tasks that are ready for execution (all dependencies satisfied) + + Returns: + List of task IDs 
ready for execution + """ + if not self.dag: + return [] + + ready_tasks = [] + + with self._lock: + for node_id in self.dag.nodes(): + if self._is_task_ready(node_id): + ready_tasks.append(node_id) + + logger.debug(f"Found {len(ready_tasks)} ready tasks: {ready_tasks}") + return ready_tasks + + def _is_task_ready(self, task_id: str) -> bool: + """Check if a task is ready for execution + + Args: + task_id: Task to check + + Returns: + True if task is ready, False otherwise + """ + if not self.dag or task_id not in self.dag.nodes(): + return False + + # Task must be in pending state + if self.dag.nodes[task_id].get('status') != 'pending': + return False + + # All predecessor tasks must be completed + predecessors = list(self.dag.predecessors(task_id)) + for pred_id in predecessors: + if self.dag.nodes[pred_id].get('status') != 'completed': + return False + + return True + + def mark_task_submitted(self, task_id: str) -> bool: + """Mark a task as submitted + + Args: + task_id: Task to mark as submitted + + Returns: + True if successful, False otherwise + """ + return self._update_task_status(task_id, 'submitted') + + def mark_task_complete(self, task_id: str) -> bool: + """Mark a task as completed + + Args: + task_id: Task to mark as completed + + Returns: + True if successful, False otherwise + """ + return self._update_task_status(task_id, 'completed') + + def _update_task_status(self, task_id: str, status: str) -> bool: + """Update task status + + Args: + task_id: Task to update + status: New status + + Returns: + True if successful, False otherwise + """ + if not self.dag or task_id not in self.dag.nodes(): + logger.error(f"Task {task_id} not found in DAG") + return False + + try: + # with self._lock: + old_status = self.dag.nodes[task_id].get('status', 'unknown') + self.dag.nodes[task_id]['status'] = status + logger.debug(f"Task {task_id} status: {old_status} -> {status}") + return True + + except Exception as e: + logger.error(f"Failed to update task {task_id} status: {str(e)}") + return False + + def is_dag_complete(self) -> bool: + """Check if all tasks in the DAG are completed + + Returns: + True if all tasks are completed, False otherwise + """ + if not self.dag: + return False + + with self._lock: + for node_id in self.dag.nodes(): + if self.dag.nodes[node_id].get('status') != 'completed': + return False + + return True + + def get_task_status(self, task_id: str) -> Optional[str]: + """Get status of a specific task + + Args: + task_id: Task to check + + Returns: + Task status or None if task not found + """ + if not self.dag or task_id not in self.dag.nodes(): + return None + + with self._lock: + return self.dag.nodes[task_id].get('status') + + def get_completed_task_count(self) -> int: + """Get number of completed tasks + + Returns: + Number of completed tasks + """ + if not self.dag: + return 0 + + count = 0 + with self._lock: + for node_id in self.dag.nodes(): + if self.dag.nodes[node_id].get('status') == 'completed': + count += 1 + + return count + + def get_total_task_count(self) -> int: + """Get total number of tasks + + Returns: + Total number of tasks + """ + return len(self.dag.nodes()) if self.dag else 0 + + def get_dag_summary(self) -> Dict[str, Any]: + """Get summary of DAG status + + Returns: + Dictionary with DAG summary information + """ + if not self.dag: + return {"status": "no_dag"} + + status_counts = {} + with self._lock: + for node_id in self.dag.nodes(): + status = self.dag.nodes[node_id].get('status', 'unknown') + status_counts[status] = 
status_counts.get(status, 0) + 1 + + return { + "total_nodes": len(self.dag.nodes()), + "total_edges": len(self.dag.edges()), + "status_counts": status_counts, + "is_complete": self.is_dag_complete() + } + + def visualize_dag_status(self) -> str: + """Generate a visual representation of the DAG with current status + + Returns: + String representation of the DAG with status indicators + """ + if not self.dag: + return "No DAG loaded" + + # Status symbols + status_symbols = { + 'pending': '⏳', # Not ready for execution + 'ready': '🟡', # Ready for execution + 'submitted': '🔄', # Being executed + 'completed': '✅', # Completed + 'failed': '❌' # Failed + } + + # Get ready tasks to determine which pending tasks are actually ready + ready_tasks = set(self.get_ready_tasks()) + + with self._lock: + # Build status summary + status_counts = {} + for node_id in self.dag.nodes(): + status = self.dag.nodes[node_id].get('status', 'pending') + # Override status for ready tasks + if status == 'pending' and node_id in ready_tasks: + status = 'ready' + status_counts[status] = status_counts.get(status, 0) + 1 + + # Create the visualization + lines = [] + lines.append("") + lines.append("🔍 DAG STATUS VISUALIZATION") + lines.append("=" * 50) + + # Status summary + lines.append("📊 Status Summary:") + for status, count in sorted(status_counts.items()): + symbol = status_symbols.get(status, '❓') + lines.append(f" {symbol} {status.capitalize()}: {count}") + + lines.append("") + lines.append("🌳 DAG Structure with Status (Execution Flow: Leaves → Root):") + + # Find leaf nodes (nodes with no successors) - these execute first + leaf_nodes = [n for n in self.dag.nodes() if self.dag.out_degree(n) == 0] + + if not leaf_nodes: + lines.append(" No leaf nodes found") + return "\n".join(lines) + + # Recursively build tree structure from leaves to root + visited = set() + for leaf in sorted(leaf_nodes): + self._build_reverse_tree_visualization(leaf, lines, visited, ready_tasks, status_symbols, "") + + lines.append("") + lines.append("Legend:") + lines.append(" ⏳ Pending (dependencies not met)") + lines.append(" 🟡 Ready (can be executed)") + lines.append(" 🔄 Submitted (being executed)") + lines.append(" ✅ Completed") + lines.append(" ❌ Failed") + lines.append("=" * 50) + + return "\n".join(lines) + + def _build_reverse_tree_visualization(self, node_id: str, lines: List[str], visited: Set[str], + ready_tasks: Set[str], status_symbols: Dict[str, str], prefix: str): + """Recursively build tree visualization showing execution flow from leaves to root + + Args: + node_id: Current node to visualize + lines: List to append visualization lines to + visited: Set of already visited nodes + ready_tasks: Set of ready task IDs + status_symbols: Mapping of status to symbols + prefix: Current indentation prefix + """ + if node_id in visited: + return + + visited.add(node_id) + + # Get node status + status = self.dag.nodes[node_id].get('status', 'pending') + if status == 'pending' and node_id in ready_tasks: + status = 'ready' + + symbol = status_symbols.get(status, '❓') + + # Format node line (removed misleading timing information) + node_line = f"{prefix}└── {symbol} {node_id}" + lines.append(node_line) + + # Get parents (predecessors) - these execute after this node + parents = list(self.dag.predecessors(node_id)) + parents.sort() # Sort for consistent output + + # Recursively add parents + for i, parent in enumerate(parents): + is_last = (i == len(parents) - 1) + parent_prefix = prefix + (" " if is_last else "│ ") + 
self._build_reverse_tree_visualization(parent, lines, visited, ready_tasks, status_symbols, parent_prefix) + + def _build_tree_visualization(self, node_id: str, lines: List[str], visited: Set[str], + ready_tasks: Set[str], status_symbols: Dict[str, str], prefix: str): + """Recursively build tree visualization for a node and its children + + Args: + node_id: Current node to visualize + lines: List to append visualization lines to + visited: Set of already visited nodes + ready_tasks: Set of ready task IDs + status_symbols: Mapping of status to symbols + prefix: Current indentation prefix + """ + if node_id in visited: + return + + visited.add(node_id) + + # Get node status + status = self.dag.nodes[node_id].get('status', 'pending') + if status == 'pending' and node_id in ready_tasks: + status = 'ready' + + symbol = status_symbols.get(status, '❓') + + # Get node info + node_data = self.dag.nodes[node_id] + + # Add execution time if available + exec_time = "" + if 'worker_arguments' in node_data and len(node_data['worker_arguments']) >= 3: + try: + duration = int(node_data['worker_arguments'][2]) + exec_time = f" ({duration}ms)" + except (ValueError, IndexError): + pass + + # Format node line + node_line = f"{prefix}└── {symbol} {node_id}{exec_time}" + lines.append(node_line) + + # Get children (successors) + children = list(self.dag.successors(node_id)) + children.sort() # Sort for consistent output + + # Recursively add children + for i, child in enumerate(children): + is_last = (i == len(children) - 1) + child_prefix = prefix + (" " if is_last else "│ ") + self._build_tree_visualization(child, lines, visited, ready_tasks, status_symbols, child_prefix) diff --git a/source/client/python/utils/utils/dag/grid_connector_factory.py b/source/client/python/utils/utils/dag/grid_connector_factory.py new file mode 100644 index 00000000..7e411704 --- /dev/null +++ b/source/client/python/utils/utils/dag/grid_connector_factory.py @@ -0,0 +1,58 @@ +""" +Grid connector factories. + +This module isolates connector construction/authentication logic from adapters. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, Dict, Protocol + + +class BaseGridConnectorFactory(Protocol): + def create(self, thread_id: int, logger: logging.Logger) -> Any: + ... 
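+
+# Illustrative alternative factory satisfying the protocol above (a sketch;
+# the `MockGridConnector` usage mirrors `GridConnectorFactory.create` below):
+#
+#     class MockOnlyConnectorFactory:
+#         def create(self, thread_id: int, logger: logging.Logger) -> Any:
+#             from api.mock_connector import MockGridConnector
+#             connector = MockGridConnector()
+#             connector.init([])
+#             connector.authenticate()
+#             logger.debug(f"Thread {thread_id}: mock connector ready")
+#             return connector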
+ + +class GridConnectorFactory: + """Default connector factory based on adapter config.""" + + def __init__(self, config: Dict[str, Any], prototype_connector: Any = None) -> None: + self._config = config + self._prototype_connector = prototype_connector + + def create(self, thread_id: int, logger: logging.Logger) -> Any: + import os + + use_mock = bool(self._config.get("use_mock_grid", False)) + + if use_mock: + from api.mock_connector import MockGridConnector + + connector = MockGridConnector() + connector.init(self._config.get("htc_grid", [])) + connector.authenticate() + return connector + + try: + client_config_file = os.environ["AGENT_CONFIG_FILE"] + except KeyError: + client_config_file = "/etc/agent/Agent_config.tfvars.json" + + if self._prototype_connector is not None: + connector_cls = self._prototype_connector.__class__ + connector = connector_cls() + else: + from api.connector import AWSConnector + + connector = AWSConnector() + + with open(client_config_file, "r") as file: + client_config_json = json.loads(file.read()) + connector.init(client_config_json) + + connector.authenticate() + logger.debug(f"Thread {thread_id}: connector authenticated") + return connector diff --git a/source/client/python/utils/utils/dag/schedulers/__init__.py b/source/client/python/utils/utils/dag/schedulers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/source/client/python/utils/utils/dag/schedulers/htc_dag_scheduler.py b/source/client/python/utils/utils/dag/schedulers/htc_dag_scheduler.py new file mode 100644 index 00000000..7428d13a --- /dev/null +++ b/source/client/python/utils/utils/dag/schedulers/htc_dag_scheduler.py @@ -0,0 +1,194 @@ +""" +HTC DAG Scheduler: encapsulates the DAG processing loop previously in HTCClientDAG.run. +""" + +import logging +import time +import traceback +from typing import Any, Dict + +from utils.dag.base.base_dag import BaseDAG + + +class HTCDagScheduler: + """Runs the DAG processing loop using provided collaborators.""" + + def __init__( + self, + config: Dict[str, Any], + dag_container: BaseDAG, + grid_connector_adapter: Any, + grid_connector: Any, + ) -> None: + self.dcon = dag_container + self.grid_connector_adapter = grid_connector_adapter + self.grid_connector = grid_connector + self.config = config + self.logger = logging.getLogger("HTCDagScheduler") + + def run(self) -> bool: + """Execute the main DAG processing loop.""" + if not self.dcon.dag: + self.logger.error("No DAG loaded for processing") + return False + + self.logger.info("Starting DAG processing") + start_time = time.time() + + try: + iteration_count = 0 + previous_completed_count = 0 # Track completed tasks from previous iteration + + # <1.> While DAG is not complete... 
+            while not self.dcon.is_dag_complete():
+
+                iteration_start = time.time()
+                iteration_count += 1
+                self.logger.info(f"Processing iteration {iteration_count}")
+
+                # <2.> Find nodes that have no dependencies or whose dependencies have all been computed
+                t21 = time.perf_counter()
+                ready_nodes = self.dcon.get_ready_tasks()
+                t22 = time.perf_counter()
+
+                # <3.> Submit ready tasks to the grid connector adapter
+                t31 = time.perf_counter()
+                if ready_nodes:
+                    self.logger.debug(f"# nodes/tasks about to be submitted for processing: {len(ready_nodes)}")
+                    self.grid_connector_adapter.submit_tasks(ready_nodes, self.dcon)
+                t32 = time.perf_counter()
+
+                # <4.> Check if any of the previously submitted tasks have completed
+                t41 = time.perf_counter()
+                completed_dag_ids = list(self.grid_connector_adapter.poll_completed())
+                t42 = time.perf_counter()
+
+                # <5.> Update the DAG: mark newly completed tasks as completed
+                t51 = time.perf_counter()
+                for dag_tid in completed_dag_ids:
+                    self.dcon.mark_task_complete(dag_tid)
+                t52 = time.perf_counter()
+
+                # Calculate completed tasks this iteration
+                completed_count = self.dcon.get_completed_task_count()
+                completed_tasks_this_iteration = completed_count - previous_completed_count
+
+                # Logging and bookkeeping #################################################################
+                # Log timing breakdown for this iteration
+                self._log_iteration_timings(
+                    t21, t22, t31, t32, t41, t42, t51, t52,
+                    ready_tasks_count=len(ready_nodes),
+                    active_submissions_count=self.grid_connector_adapter.active_count(),
+                    completed_tasks_this_iteration=completed_tasks_this_iteration
+                )
+
+                total_count = self.dcon.get_total_task_count()
+                self.logger.info(f"Progress: {completed_count}/{total_count} tasks completed")
+
+                # Optional DAG visualization
+                if self.config.get("show_dag_visualization", False):
+                    dag_viz = self.dcon.visualize_dag_status()
+                    print(dag_viz)
+
+                self._print_iteration_summary(
+                    ready_tasks_count=len(ready_nodes),
+                    active_submissions_count=self.grid_connector_adapter.active_count(),
+                    completed_tasks_this_iteration=completed_tasks_this_iteration
+                )
+
+                # Update previous completed count for next iteration
+                previous_completed_count = completed_count
+
+                # Enforce minimum loop interval if configured
+                min_interval = self.config.get("polling_interval_seconds", 0)
+                elapsed = time.time() - iteration_start
+                if min_interval and elapsed < min_interval:
+                    sleep_duration = min_interval - elapsed
+                    self.logger.debug(f"Sleeping {sleep_duration:.2f}s to respect polling interval")
+                    time.sleep(sleep_duration)
+
+            total_time = time.time() - start_time
+            self.logger.info(f"DAG processing completed successfully in {total_time:.2f} seconds")
+            self.logger.info(f"Total iterations: {iteration_count}")
+            return True
+
+        except Exception as e:
+            self.logger.error(
+                f"DAG processing failed: {str(e)}\n"
+                f"Stack trace:\n{''.join(traceback.format_tb(e.__traceback__))}",
+                exc_info=True,
+            )
+
+        return False
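+
+    # Wiring sketch (illustrative; adapter/connector construction is elided and
+    # the config keys shown are the ones read in `run()` above):
+    #
+    #     container = NxHtcDagContainer(dag)
+    #     scheduler = HTCDagScheduler(
+    #         config={"polling_interval_seconds": 0.5, "show_dag_visualization": False},
+    #         dag_container=container,
+    #         grid_connector_adapter=adapter,   # e.g. GridConnectorDagAdapter
+    #         grid_connector=connector,
+    #     )
+    #     succeeded = scheduler.run()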
f"Prgs={completed_count}/{total_count}({progress_percentage:.1f}%)" + ] + + # Add DAG status counts if available + if self.dcon.dag: + dag_summary = self.dcon.get_dag_summary() + if dag_summary.get("status_counts"): + status_counts = dag_summary["status_counts"] + + if "pending" in status_counts: + status_parts.append(f"Pend={status_counts['pending']}") + if "submitted" in status_counts: + status_parts.append(f"Sub={status_counts['submitted']}") + if "completed" in status_counts: + status_parts.append(f"Done={status_counts['completed']}") + if "failed" in status_counts: + status_parts.append(f"Fail={status_counts['failed']}") + + self.logger.info(f"Status: {', '.join(status_parts)}") + + def _log_iteration_timings( + self, + t21: float, + t22: float, + t31: float, + t32: float, + t41: float, + t42: float, + t51: float, + t52: float, + ready_tasks_count: int, + active_submissions_count: int, + completed_tasks_this_iteration: int, + ) -> None: + """ + Log a one-line timing summary for key iteration segments with task counts. + + Example output: + TIMINGS s:ready=0.0003 submit=0.0010 poll=0.0025 mark=0.0004 total=0.0042 newRdy=5 A.S.=26 doneNow=3 + """ + ready = t22 - t21 + submit = t32 - t31 + poll = t42 - t41 + mark = t52 - t51 + total = t52 - t21 + + # Get total progress for logging + completed_count = self.dcon.get_completed_task_count() + total_count = self.dcon.get_total_task_count() + + self.logger.info( + f"TIMINGS ready={ready:>7.4f}s submit={submit:>7.4f}s " + f"poll={poll:>7.4f}s mark={mark:>7.4f}s total={total:>7.4f}s | " + f"newRdy={ready_tasks_count:>6} ActSub={active_submissions_count:>6} doneNow={completed_tasks_this_iteration:>6} " + f"totalDone={completed_count:>8}/{total_count:<8}" + ) diff --git a/source/client/python/utils/utils/ttl_experation_generator.py b/source/client/python/utils/utils/ttl_experation_generator.py index 4a9fe68f..ad364486 100644 --- a/source/client/python/utils/utils/ttl_experation_generator.py +++ b/source/client/python/utils/utils/ttl_experation_generator.py @@ -1,4 +1,4 @@ -# Copyright 2024 Amazon.com, Inc. or its affiliates. +# Copyright 2024 Amazon.com, Inc. or its affiliates. # SPDX-License-Identifier: Apache-2.0 # Licensed under the Apache License, Version 2.0 https://aws.amazon.com/apache-2-0/ diff --git a/source/control_plane/python/lambda/ttl_checker/ttl_checker.py b/source/control_plane/python/lambda/ttl_checker/ttl_checker.py index f7ce7281..89ec9a61 100644 --- a/source/control_plane/python/lambda/ttl_checker/ttl_checker.py +++ b/source/control_plane/python/lambda/ttl_checker/ttl_checker.py @@ -14,9 +14,6 @@ from utils import grid_error_logger as errlog from utils.state_table_common import ( - TASK_STATE_RETRYING, - TASK_STATE_INCONSISTENT, - TASK_STATE_FAILED, StateTableException, ) from api.queue_manager import queue_manager @@ -55,9 +52,7 @@ cw_client = boto3.client("cloudwatch") TTL_LAMBDA_ID = "TTL_LAMBDA" -TTL_LAMBDA_TMP_STATE = TASK_STATE_RETRYING -TTL_LAMBDA_FAILED_STATE = TASK_STATE_FAILED -TTL_LAMBDA_INCONSISTENT_STATE = TASK_STATE_INCONSISTENT + MAX_RETRIES = 5 RETRIEVE_EXPIRED_TASKS_LIMIT = 200