diff --git a/gltest/contracts/contract.py b/gltest/contracts/contract.py index d105503..87babb7 100644 --- a/gltest/contracts/contract.py +++ b/gltest/contracts/contract.py @@ -71,6 +71,7 @@ def transact_method( wait_retries: Optional[int] = None, wait_triggered_transactions: bool = False, wait_triggered_transactions_status: TransactionStatus = TransactionStatus.ACCEPTED, + wait_triggered_transactions_depth: int = 3, transaction_context: Optional[TransactionContext] = None, ): """ @@ -117,17 +118,62 @@ def transact_method( interval=actual_wait_interval, retries=actual_wait_retries, ) - if wait_triggered_transactions: - triggered_transactions = receipt.get("triggered_transactions", []) - for triggered_transaction in triggered_transactions: - client.wait_for_transaction_receipt( - transaction_hash=triggered_transaction, - status=wait_triggered_transactions_status, - interval=actual_wait_interval, - retries=actual_wait_retries, - ) + if wait_triggered_transactions and wait_triggered_transactions_depth > 0: + pending_receipts = [receipt] + for _ in range(wait_triggered_transactions_depth): + next_receipts = [] + for current_receipt in pending_receipts: + triggered_transactions = current_receipt.get( + "triggered_transactions", [] + ) + for triggered_transaction in triggered_transactions: + triggered_receipt = client.wait_for_transaction_receipt( + transaction_hash=triggered_transaction, + status=wait_triggered_transactions_status, + interval=actual_wait_interval, + retries=actual_wait_retries, + ) + next_receipts.append(triggered_receipt) + if not next_receipts: + break + pending_receipts = next_receipts return receipt + def raw_transact_method( + value: int = 0, + consensus_max_rotations: Optional[int] = None, + transaction_context: Optional[TransactionContext] = None, + ): + """ + Send the transaction and return the transaction hash without waiting. + """ + general_config = get_general_config() + leader_only = ( + general_config.get_leader_only() + if general_config.check_studio_based_rpc() + else False + ) + client = get_gl_client() + sim_config = None + if transaction_context: + try: + sim_config = SimConfig(**transaction_context) + except TypeError as e: + raise ValueError( + f"Invalid transaction_context keys: {sorted(transaction_context.keys())}" + ) from e + tx_hash = client.write_contract( + address=self.address, + function_name=method_name, + account=self.account, + value=value, + consensus_max_rotations=consensus_max_rotations, + leader_only=leader_only, + args=args, + sim_config=sim_config, + ) + return tx_hash + def analyze_method( provider: str, model: str, @@ -161,6 +207,7 @@ def analyze_method( method_name=method_name, read_only=False, transact_method=transact_method, + raw_transact_method=raw_transact_method, analyze_method=analyze_method, ) diff --git a/gltest/contracts/contract_factory.py b/gltest/contracts/contract_factory.py index 75e8971..79df222 100644 --- a/gltest/contracts/contract_factory.py +++ b/gltest/contracts/contract_factory.py @@ -115,6 +115,7 @@ def deploy( wait_transaction_status: TransactionStatus = TransactionStatus.ACCEPTED, wait_triggered_transactions: bool = False, wait_triggered_transactions_status: TransactionStatus = TransactionStatus.ACCEPTED, + wait_triggered_transactions_depth: int = 3, transaction_context: Optional[TransactionContext] = None, ) -> Contract: """ @@ -132,6 +133,7 @@ def deploy( wait_transaction_status=wait_transaction_status, wait_triggered_transactions=wait_triggered_transactions, wait_triggered_transactions_status=wait_triggered_transactions_status, + wait_triggered_transactions_depth=wait_triggered_transactions_depth, transaction_context=transaction_context, ) @@ -151,6 +153,7 @@ def deploy_contract_tx( wait_transaction_status: TransactionStatus = TransactionStatus.ACCEPTED, wait_triggered_transactions: bool = False, wait_triggered_transactions_status: TransactionStatus = TransactionStatus.ACCEPTED, + wait_triggered_transactions_depth: int = 3, transaction_context: Optional[TransactionContext] = None, ) -> GenLayerTransaction: """ @@ -197,15 +200,25 @@ def deploy_contract_tx( interval=actual_wait_interval, retries=actual_wait_retries, ) - if wait_triggered_transactions: - triggered_transactions = tx_receipt.get("triggered_transactions", []) - for triggered_transaction in triggered_transactions: - client.wait_for_transaction_receipt( - transaction_hash=triggered_transaction, - status=wait_triggered_transactions_status, - interval=actual_wait_interval, - retries=actual_wait_retries, - ) + if wait_triggered_transactions and wait_triggered_transactions_depth > 0: + pending_receipts = [tx_receipt] + for _ in range(wait_triggered_transactions_depth): + next_receipts = [] + for current_receipt in pending_receipts: + triggered_transactions = current_receipt.get( + "triggered_transactions", [] + ) + for triggered_transaction in triggered_transactions: + triggered_receipt = client.wait_for_transaction_receipt( + transaction_hash=triggered_transaction, + status=wait_triggered_transactions_status, + interval=actual_wait_interval, + retries=actual_wait_retries, + ) + next_receipts.append(triggered_receipt) + if not next_receipts: + break + pending_receipts = next_receipts return tx_receipt except Exception as e: raise DeploymentError( diff --git a/gltest/contracts/contract_functions.py b/gltest/contracts/contract_functions.py index e773468..0215afc 100644 --- a/gltest/contracts/contract_functions.py +++ b/gltest/contracts/contract_functions.py @@ -10,6 +10,7 @@ class ContractFunction: call_method: Optional[Callable] = None analyze_method: Optional[Callable] = None transact_method: Optional[Callable] = None + raw_transact_method: Optional[Callable] = None def call( self, @@ -32,6 +33,7 @@ def transact( wait_retries: Optional[int] = None, wait_triggered_transactions: bool = False, wait_triggered_transactions_status: TransactionStatus = TransactionStatus.ACCEPTED, + wait_triggered_transactions_depth: int = 3, transaction_context: Optional[TransactionContext] = None, ): if self.read_only: @@ -44,6 +46,21 @@ def transact( wait_retries=wait_retries, wait_triggered_transactions=wait_triggered_transactions, wait_triggered_transactions_status=wait_triggered_transactions_status, + wait_triggered_transactions_depth=wait_triggered_transactions_depth, + transaction_context=transaction_context, + ) + + def raw_transact( + self, + value: int = 0, + consensus_max_rotations: Optional[int] = None, + transaction_context: Optional[TransactionContext] = None, + ): + if self.read_only: + raise ValueError("Cannot raw_transact read-only method") + return self.raw_transact_method( + value=value, + consensus_max_rotations=consensus_max_rotations, transaction_context=transaction_context, ) diff --git a/gltest/types.py b/gltest/types.py index 694ebc0..767c034 100644 --- a/gltest/types.py +++ b/gltest/types.py @@ -1,4 +1,5 @@ # Re-export genlayer-py types +from __future__ import annotations from genlayer_py.types import ( CalldataAddress, GenLayerTransaction, @@ -6,7 +7,8 @@ CalldataEncodable, TransactionHashVariant, ) -from typing import List, TypedDict, Dict, Any +from typing import List, TypedDict, Dict, Any, Optional, Literal +from dataclasses import dataclass, field class MockedLLMResponse(TypedDict): @@ -49,3 +51,38 @@ class TransactionContext(TypedDict, total=False): validators: List[ValidatorConfig] # List to create virtual validators genvm_datetime: str # ISO format datetime string + + +@dataclass +class TransactionTree: + """A tree structure representing a transaction and its triggered children.""" + + receipt: GenLayerTransaction + children: List[TransactionTree] = field(default_factory=list) + + def flatten(self) -> List[GenLayerTransaction]: + """Flatten the tree into a list of receipts (breadth-first order).""" + result = [self.receipt] + for child in self.children: + result.extend(child.flatten()) + return result + + def get_children_receipts( + self, triggered_on: Optional[Literal["accepted", "finalized"]] = None + ) -> List[GenLayerTransaction]: + """Get receipts of direct children, optionally filtered by triggered_on status. + + Args: + triggered_on: Optional status to filter by ("accepted" or "finalized"). + If None, returns all children receipts. + + Returns: + A list of receipts from direct children. + """ + if triggered_on is None: + return [child.receipt for child in self.children] + return [ + child.receipt + for child in self.children + if child.receipt.get("triggered_on") == triggered_on + ] diff --git a/gltest/utils.py b/gltest/utils.py index 99a88a2..aa3ae4a 100644 --- a/gltest/utils.py +++ b/gltest/utils.py @@ -1,4 +1,8 @@ +from typing import List, Optional from genlayer_py.types import GenLayerTransaction +from gltest.types import TransactionStatus, TransactionTree +from gltest.clients import get_gl_client +from gltest_cli.config.general import get_general_config def extract_contract_address(receipt: GenLayerTransaction) -> str: @@ -12,3 +16,75 @@ def extract_contract_address(receipt: GenLayerTransaction) -> str: return receipt["data"]["contract_address"] else: raise ValueError("Transaction receipt missing contract address") + + +def wait_for_transaction( + tx_hash: str, + wait_transaction_status: TransactionStatus = TransactionStatus.ACCEPTED, + wait_interval: Optional[int] = None, + wait_retries: Optional[int] = None, + wait_triggered_transactions: bool = False, + wait_triggered_transactions_status: TransactionStatus = TransactionStatus.ACCEPTED, + wait_triggered_transactions_depth: int = 3, +) -> TransactionTree: + """Wait for a transaction and optionally its triggered transactions. + + Args: + tx_hash: The transaction hash to wait for. + wait_transaction_status: The status to wait for on the main transaction. + wait_interval: Polling interval in seconds. Uses default if not specified. + wait_retries: Number of retries. Uses default if not specified. + wait_triggered_transactions: Whether to wait for triggered transactions. + wait_triggered_transactions_status: The status to wait for on triggered transactions. + wait_triggered_transactions_depth: Maximum depth to follow triggered transactions. + + Returns: + A TransactionTree with the root transaction and nested children for + triggered transactions. Use .flatten() to get a flat list of receipts, + or .children to access direct children, or .get_children_receipts() for all descendants. + """ + general_config = get_general_config() + actual_wait_interval = ( + wait_interval + if wait_interval is not None + else general_config.get_default_wait_interval() + ) + actual_wait_retries = ( + wait_retries + if wait_retries is not None + else general_config.get_default_wait_retries() + ) + + client = get_gl_client() + receipt = client.wait_for_transaction_receipt( + transaction_hash=tx_hash, + status=wait_transaction_status, + interval=actual_wait_interval, + retries=actual_wait_retries, + ) + + root = TransactionTree(receipt=receipt) + + if wait_triggered_transactions and wait_triggered_transactions_depth > 0: + pending_nodes = [root] + for _ in range(wait_triggered_transactions_depth): + next_nodes = [] + for current_node in pending_nodes: + triggered_transactions = current_node.receipt.get( + "triggered_transactions", [] + ) + for triggered_transaction in triggered_transactions: + triggered_receipt = client.wait_for_transaction_receipt( + transaction_hash=triggered_transaction, + status=wait_triggered_transactions_status, + interval=actual_wait_interval, + retries=actual_wait_retries, + ) + child_node = TransactionTree(receipt=triggered_receipt) + current_node.children.append(child_node) + next_nodes.append(child_node) + if not next_nodes: + break + pending_nodes = next_nodes + + return root