diff --git a/.gitignore b/.gitignore index 37d40f3e..e7978e9a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ -.DS_Store .codex +.DS_Store binary_sizes_baseline.json baseline.json trigger.sh diff --git a/examples/smart_contracts/always_true.py b/examples/smart_contracts/always_true.py index af4061a2..faf1909e 100644 --- a/examples/smart_contracts/always_true.py +++ b/examples/smart_contracts/always_true.py @@ -2,5 +2,7 @@ from opshin.prelude import * -def validator(context: ScriptContext) -> None: - pass +@dataclass() +class AlwaysTrue(Contract): + def raw(self, _context: ScriptContext) -> None: + pass diff --git a/examples/smart_contracts/assert_sum.py b/examples/smart_contracts/assert_sum.py index 576b0351..98a1ffb4 100644 --- a/examples/smart_contracts/assert_sum.py +++ b/examples/smart_contracts/assert_sum.py @@ -2,9 +2,11 @@ from opshin.prelude import * -def validator(context: ScriptContext) -> None: - datum: int = own_datum_unsafe(context) - redeemer: int = context.redeemer - assert ( - datum + redeemer == 42 - ), f"Expected datum and redeemer to sum to 42, but they sum to {datum + redeemer}" +@dataclass() +class AssertSum(Contract): + def spend_with_datum( + self, datum: int, redeemer: int, _context: ScriptContext + ) -> None: + assert ( + datum + redeemer == 42 + ), f"Expected datum and redeemer to sum to 42, but they sum to {datum + redeemer}" diff --git a/examples/smart_contracts/bitwise.py b/examples/smart_contracts/bitwise.py index 26a20ee3..9255f736 100644 --- a/examples/smart_contracts/bitwise.py +++ b/examples/smart_contracts/bitwise.py @@ -1,22 +1,19 @@ #!opshin from opshin.prelude import * -from opshin.std.integrity import check_integrity from opshin.std.math import * -def validator(context: ScriptContext) -> None: - """ - A contract that checks whether the bitwise AND of the datum and redeemer is zero. - """ - maybe_datum = own_datum(context) - # it is possible that no datum is attached (new in PlutusV3) - if isinstance(maybe_datum, NoOutputDatum): - # in that case, no datum was attached and we accept the transaction +@dataclass() +class Bitwise(Contract): + def spend_no_datum(self, _redeemer: int, _context: ScriptContext) -> None: return - else: - datum: int = maybe_datum.datum - # In plutus v3, the redeemer is in the script context - redeemer: int = context.redeemer + + def spend_with_datum( + self, datum: int, redeemer: int, _context: ScriptContext + ) -> None: + """ + A contract that checks whether the bitwise AND of the datum and redeemer is zero. + """ datum_bytes = bytes_big_from_unsigned_int(datum) redeemer_bytes = bytes_big_from_unsigned_int(redeemer) # compute the bitwise XOR of the two byte arrays diff --git a/examples/smart_contracts/gift.py b/examples/smart_contracts/gift.py index 4a31b930..34550acd 100755 --- a/examples/smart_contracts/gift.py +++ b/examples/smart_contracts/gift.py @@ -9,12 +9,15 @@ class WithdrawDatum(PlutusData): pubkeyhash: PubKeyHash -def validator(context: ScriptContext) -> None: - datum: WithdrawDatum = own_datum_unsafe(context) - # check that the datum has correct structure (recommended for user inputs) - # can be omitted if the datum can not make its way into a permanent script state (i.e., not stored in an output) - check_integrity(datum) - sig_present = datum.pubkeyhash in context.transaction.signatories - assert ( - sig_present - ), f"Required signature missing, expected {datum.pubkeyhash.hex()} but got {[s.hex() for s in context.transaction.signatories]}" +@dataclass() +class Gift(Contract): + def spend_with_datum( + self, datum: WithdrawDatum, _redeemer: Anything, context: ScriptContext + ) -> None: + # check that the datum has correct structure (recommended for user inputs) + # can be omitted if the datum can not make its way into a permanent script state (i.e., not stored in an output) + check_integrity(datum) + sig_present = datum.pubkeyhash in context.transaction.signatories + assert ( + sig_present + ), f"Required signature missing, expected {datum.pubkeyhash.hex()} but got {[s.hex() for s in context.transaction.signatories]}" diff --git a/examples/smart_contracts/inspect_script_context.py b/examples/smart_contracts/inspect_script_context.py index bab7f365..f2d46c86 100644 --- a/examples/smart_contracts/inspect_script_context.py +++ b/examples/smart_contracts/inspect_script_context.py @@ -2,5 +2,11 @@ from opshin.prelude import * -def validator(context: ScriptContext): - print(context.to_cbor()) +# this validator will print the datum, redeemer and script context passed from the node in a readable format +@dataclass() +class InspectScriptContext(Contract): + def raw(self, context: ScriptContext) -> None: + print(f"script context (CBOR hex): {context.to_cbor().hex()}") + print(f"script context (native): {context}") + + assert False, "Failing in order to show script logs" diff --git a/examples/smart_contracts/liquidity_pool.py b/examples/smart_contracts/liquidity_pool.py index efa48c39..90c605c2 100644 --- a/examples/smart_contracts/liquidity_pool.py +++ b/examples/smart_contracts/liquidity_pool.py @@ -607,37 +607,39 @@ def get_spending_purpose(context: ScriptContext) -> Spending: return purpose -def validator(context: ScriptContext) -> None: - """ - Validates that the pool is spent correctly - DISCLAIMER: This is a simple example to demonstrate onchain based contract upgradeability and should not be used in production. - """ - purpose = get_spending_purpose(context) - redeemer: PoolAction = context.redeemer - check_integrity(redeemer) - datum: PoolState = own_datum_unsafe(context) - check_integrity(datum) - own_input_info = context.transaction.inputs[redeemer.pool_input_index] - assert ( - own_input_info.out_ref == purpose.tx_out_ref - ), "Index of own input does not match purpose" - - own_output = context.transaction.outputs[redeemer.pool_output_index] - if isinstance(redeemer, AddLiquidity) or isinstance(redeemer, RemoveLiquidity): - check_valid_license_present( - redeemer.license_input_index, - context.transaction, - datum.up_pool_params.license_policy_id, - ) - check_change_liquidity(datum, redeemer, context, own_input_info, own_output) - elif isinstance(redeemer, SwapAsset): - check_valid_license_present( - redeemer.license_input_index, - context.transaction, - datum.up_pool_params.license_policy_id, - ) - check_swap(datum, redeemer, context, own_input_info, own_output) - elif isinstance(redeemer, PoolUpgrade): - check_upgrade(datum, redeemer, context, own_input_info, own_output) - else: - assert False, "Unknown redeemer" +@dataclass() +class LiquidityPool(Contract): + def spend_with_datum( + self, datum: PoolState, redeemer: PoolAction, context: ScriptContext + ) -> None: + """ + Validates that the pool is spent correctly + DISCLAIMER: This is a simple example to demonstrate onchain based contract upgradeability and should not be used in production. + """ + purpose = get_spending_purpose(context) + check_integrity(redeemer) + check_integrity(datum) + own_input_info = context.transaction.inputs[redeemer.pool_input_index] + assert ( + own_input_info.out_ref == purpose.tx_out_ref + ), "Index of own input does not match purpose" + + own_output = context.transaction.outputs[redeemer.pool_output_index] + if isinstance(redeemer, AddLiquidity) or isinstance(redeemer, RemoveLiquidity): + check_valid_license_present( + redeemer.license_input_index, + context.transaction, + datum.up_pool_params.license_policy_id, + ) + check_change_liquidity(datum, redeemer, context, own_input_info, own_output) + elif isinstance(redeemer, SwapAsset): + check_valid_license_present( + redeemer.license_input_index, + context.transaction, + datum.up_pool_params.license_policy_id, + ) + check_swap(datum, redeemer, context, own_input_info, own_output) + elif isinstance(redeemer, PoolUpgrade): + check_upgrade(datum, redeemer, context, own_input_info, own_output) + else: + assert False, "Unknown redeemer" diff --git a/examples/smart_contracts/marketplace.py b/examples/smart_contracts/marketplace.py index 184e6910..d5cc181d 100755 --- a/examples/smart_contracts/marketplace.py +++ b/examples/smart_contracts/marketplace.py @@ -53,23 +53,25 @@ def check_owner_signed(signatories: List[PubKeyHash], owner: PubKeyHash) -> None ), f"Owner did not sign transaction, requires {owner.hex()} but got {[s.hex() for s in signatories]}" -def validator(context: ScriptContext) -> None: - purpose = context.purpose - datum: Listing = own_datum_unsafe(context) - check_integrity(datum) - redeemer: ListingAction = context.redeemer - check_integrity(redeemer) - - tx_info = context.transaction - assert isinstance(purpose, Spending), f"Wrong script purpose: {purpose}" - own_utxo = resolve_spent_utxo(tx_info.inputs, purpose) - own_addr = own_utxo.address - - check_single_utxo_spent(tx_info.inputs, own_addr) - # It is recommended to explicitly check all options with isinstance for user input - if isinstance(redeemer, Buy): - check_paid(tx_info.outputs, datum.vendor, datum.price) - elif isinstance(redeemer, Unlist): - check_owner_signed(tx_info.signatories, datum.owner) - else: - assert False, "Wrong redeemer" +@dataclass() +class Marketplace(Contract): + def spend_with_datum( + self, datum: Listing, redeemer: ListingAction, context: ScriptContext + ) -> None: + purpose = context.purpose + check_integrity(datum) + check_integrity(redeemer) + + tx_info = context.transaction + assert isinstance(purpose, Spending), f"Wrong script purpose: {purpose}" + own_utxo = resolve_spent_utxo(tx_info.inputs, purpose) + own_addr = own_utxo.address + + check_single_utxo_spent(tx_info.inputs, own_addr) + # It is recommended to explicitly check all options with isinstance for user input + if isinstance(redeemer, Buy): + check_paid(tx_info.outputs, datum.vendor, datum.price) + elif isinstance(redeemer, Unlist): + check_owner_signed(tx_info.signatories, datum.owner) + else: + assert False, "Wrong redeemer" diff --git a/examples/smart_contracts/micropayments.py b/examples/smart_contracts/micropayments.py index 6129882e..0e4cc5e7 100644 --- a/examples/smart_contracts/micropayments.py +++ b/examples/smart_contracts/micropayments.py @@ -65,112 +65,110 @@ class TearDown(PlutusData): ChannelAction = Union[Micropayments, TearDown] -def validator(context: ScriptContext) -> None: - redeemer: ChannelAction = context.redeemer - # Ensure that the redeemer is well formed - check_integrity(redeemer) - purpose = context.purpose - assert isinstance(purpose, Spending), "Can only spend from the contract" - own_utxo = own_spent_utxo(context.transaction.inputs, purpose) - datum: PaymentChannel = resolve_datum_unsafe(own_utxo, context.transaction) - check_integrity(datum) - ( - balance_alice_datum, - pubkeyhash_alice, - balance_bob_datum, - pubkeyhash_bob, - nonce_datum, - ) = astuple(datum) - - if isinstance(redeemer, TearDown): - # Ensure that either party signed this request - assert ( - pubkeyhash_alice in context.transaction.signatories - or pubkeyhash_bob in context.transaction.signatories - ), f"Neither Alice nor Bob signed the transaction, signatory list: {context.transaction.signatories}" - # Ensure that all participants receive their amounts - amount_alice = 0 - amount_bob = 0 - for o in context.transaction.outputs: - # Note: in a real world scenario, you will want to make sure the stake key hash matches too! - pkh = o.address.payment_credential.credential_hash - if pkh == pubkeyhash_alice: - amount_alice += o.value.get(b"", {b"": 0}).get(b"", 0) - elif pkh == pubkeyhash_bob: - amount_bob += o.value.get(b"", {b"": 0}).get(b"", 0) - assert ( - amount_alice >= balance_alice_datum - ), f"Alice does not receive enough, expecting {balance_alice_datum}, receiving {amount_alice}" - assert ( - amount_bob >= balance_bob_datum - ), f"Bob does not receive enough, expecting {balance_bob_datum}, receiving {amount_bob}" - # That's it! - elif isinstance(redeemer, Micropayments): - # Squash apply the micropayments - balance_alice = balance_alice_datum - balance_bob = balance_bob_datum - nonce = nonce_datum - # Ensure that the payments are all valid and accumulate state - for payment in redeemer.payments: +@dataclass() +class Micropayments(Contract): + def spend_with_datum( + self, datum: PaymentChannel, redeemer: ChannelAction, context: ScriptContext + ) -> None: + # Ensure that the redeemer is well formed + check_integrity(redeemer) + check_integrity(datum) + ( + balance_alice_datum, + pubkeyhash_alice, + balance_bob_datum, + pubkeyhash_bob, + nonce_datum, + ) = astuple(datum) + + if isinstance(redeemer, TearDown): + # Ensure that either party signed this request assert ( - payment.nonce > nonce - ), f"Invalid nonce, replay attack detected ({payment.nonce} <= {nonce})" + pubkeyhash_alice in context.transaction.signatories + or pubkeyhash_bob in context.transaction.signatories + ), f"Neither Alice nor Bob signed the transaction, signatory list: {context.transaction.signatories}" + # Ensure that all participants receive their amounts + amount_alice = 0 + amount_bob = 0 + for o in context.transaction.outputs: + # Note: in a real world scenario, you will want to make sure the stake key hash matches too! + pkh = o.address.payment_credential.credential_hash + if pkh == pubkeyhash_alice: + amount_alice += o.value.get(b"", {b"": 0}).get(b"", 0) + elif pkh == pubkeyhash_bob: + amount_bob += o.value.get(b"", {b"": 0}).get(b"", 0) assert ( - payment.amount > 0 - ), f"Invalid amount transfer {payment.amount}, must be positive" - nonce = payment.nonce - if isinstance(payment, MicropaymentAlice): - assert verify_ed25519_signature( - pubkeyhash_alice, - (str(payment.amount) + str(nonce)).encode(), - payment.sig, - ), "Invalid signature of Alice for micropayment" - balance_alice -= payment.amount - balance_bob += payment.amount - elif isinstance(payment, MicropaymentBob): - assert verify_ed25519_signature( - pubkeyhash_bob, - (str(payment.amount) + str(nonce)).encode(), - payment.sig, - ), "Invalid signature of Bob for micropayment" - balance_alice += payment.amount - balance_bob -= payment.amount - else: - assert False, "Invalid type of micropayment!" - - # we cast the purpose to spending, every other purpose does not make sense - purpose: Spending = context.purpose - own_tx_out_ref = purpose.tx_out_ref - # this stunt is just to find the output that goes to the same address as the input we are validating to be spent - own_tx_out = [ - i for i in context.transaction.inputs if i.out_ref == own_tx_out_ref - ][0].resolved - own_address = own_tx_out.address - cont_tx_out = [ - o for o in context.transaction.outputs if o.address == own_address - ][0] - # The value = locked tokens must not change - for pid, tn_dict in own_tx_out.value.items(): - for tokenname, amount in tn_dict.items(): + amount_alice >= balance_alice_datum + ), f"Alice does not receive enough, expecting {balance_alice_datum}, receiving {amount_alice}" + assert ( + amount_bob >= balance_bob_datum + ), f"Bob does not receive enough, expecting {balance_bob_datum}, receiving {amount_bob}" + # That's it! + elif isinstance(redeemer, Micropayments): + # Squash apply the micropayments + balance_alice = balance_alice_datum + balance_bob = balance_bob_datum + nonce = nonce_datum + # Ensure that the payments are all valid and accumulate state + for payment in redeemer.payments: assert ( - amount <= cont_tx_out.value[pid][tokenname] - ), f"Value of token in payment channel has decreased from {amount} to {cont_tx_out.value[pid][tokenname]}" - cont_datum = cont_tx_out.datum - assert isinstance( - cont_datum, SomeOutputDatum - ), f"Must inline attached datum, got {cont_datum}" - # We cast the datum to payment channel (it is stored without structure in the ledger) - cont_datum_content: PaymentChannel = cont_datum.datum - # Technically not needed because we compare for exact equality below, but good practice - check_integrity(cont_datum_content) - # Ensure that the state is correctly updated - assert cont_datum_content == PaymentChannel( - balance_alice, - pubkeyhash_alice, - balance_bob, - pubkeyhash_bob, - nonce, - ) - else: - # Other redeemers are not allowed! - assert False, "Wrong redeemer passed!" + payment.nonce > nonce + ), f"Invalid nonce, replay attack detected ({payment.nonce} <= {nonce})" + assert ( + payment.amount > 0 + ), f"Invalid amount transfer {payment.amount}, must be positive" + nonce = payment.nonce + if isinstance(payment, MicropaymentAlice): + assert verify_ed25519_signature( + pubkeyhash_alice, + (str(payment.amount) + str(nonce)).encode(), + payment.sig, + ), "Invalid signature of Alice for micropayment" + balance_alice -= payment.amount + balance_bob += payment.amount + elif isinstance(payment, MicropaymentBob): + assert verify_ed25519_signature( + pubkeyhash_bob, + (str(payment.amount) + str(nonce)).encode(), + payment.sig, + ), "Invalid signature of Bob for micropayment" + balance_alice += payment.amount + balance_bob -= payment.amount + else: + assert False, "Invalid type of micropayment!" + + purpose: Spending = context.purpose + own_tx_out_ref = purpose.tx_out_ref + # this stunt is just to find the output that goes to the same address as the input we are validating to be spent + own_tx_out = [ + i for i in context.transaction.inputs if i.out_ref == own_tx_out_ref + ][0].resolved + own_address = own_tx_out.address + cont_tx_out = [ + o for o in context.transaction.outputs if o.address == own_address + ][0] + # The value = locked tokens must not change + for pid, tn_dict in own_tx_out.value.items(): + for tokenname, amount in tn_dict.items(): + assert ( + amount <= cont_tx_out.value[pid][tokenname] + ), f"Value of token in payment channel has decreased from {amount} to {cont_tx_out.value[pid][tokenname]}" + cont_datum = cont_tx_out.datum + assert isinstance( + cont_datum, SomeOutputDatum + ), f"Must inline attached datum, got {cont_datum}" + # We cast the datum to payment channel (it is stored without structure in the ledger) + cont_datum_content: PaymentChannel = cont_datum.datum + # Technically not needed because we compare for exact equality below, but good practice + check_integrity(cont_datum_content) + # Ensure that the state is correctly updated + assert cont_datum_content == PaymentChannel( + balance_alice, + pubkeyhash_alice, + balance_bob, + pubkeyhash_bob, + nonce, + ) + else: + # Other redeemers are not allowed! + assert False, "Wrong redeemer passed!" diff --git a/examples/smart_contracts/parameterized.py b/examples/smart_contracts/parameterized.py index 1d2bfe03..e614001a 100644 --- a/examples/smart_contracts/parameterized.py +++ b/examples/smart_contracts/parameterized.py @@ -7,6 +7,9 @@ # this contract can be parameterized at compile time. Pass the parameter with the build command # # $ opshin build examples/smart_contracts/parameterized.py '{"int": 42}' -def validator(parameter: int, ctx: ScriptContext) -> None: - r: int = ctx.redeemer - assert r == parameter, "Wrong redeemer" +@dataclass() +class Parameterized(Contract): + parameter: int + + def spend_no_datum(self, redeemer: int, _context: ScriptContext) -> None: + assert redeemer == self.parameter, "Wrong redeemer" diff --git a/examples/smart_contracts/simple_script.py b/examples/smart_contracts/simple_script.py index 3c397221..2ba1cca7 100644 --- a/examples/smart_contracts/simple_script.py +++ b/examples/smart_contracts/simple_script.py @@ -108,9 +108,14 @@ def validate_script( return res -# to fully emulate simple script behaviour, compile with --force-three-params # the script is a contract parameter, pass it into the build command -def validator(script: Script, context: ScriptContext) -> None: - assert validate_script( - script, context.transaction.signatories, context.transaction.validity_range - ), "Simple Script validation failed!" +@dataclass() +class SimpleScript(Contract): + script: Script + + def raw(self, context: ScriptContext) -> None: + assert validate_script( + self.script, + context.transaction.signatories, + context.transaction.validity_range, + ), "Simple Script validation failed!" diff --git a/examples/smart_contracts/vesting.py b/examples/smart_contracts/vesting.py index 2f3f9bab..aabe37e6 100644 --- a/examples/smart_contracts/vesting.py +++ b/examples/smart_contracts/vesting.py @@ -28,11 +28,11 @@ def deadline_reached(params: VestingParams, context: ScriptContext) -> bool: return is_after(params.deadline, context.transaction.validity_range) -def validator(context: ScriptContext) -> None: - purpose = context.purpose - assert isinstance(purpose, Spending) - own_utxo = own_spent_utxo(context.transaction.inputs, purpose) - datum: VestingParams = resolve_datum_unsafe(own_utxo, context.transaction) - assert signed_by_beneficiary(datum, context), "beneficiary's signature missing" - assert deadline_reached(datum, context), "deadline not reached" - return None +@dataclass() +class Vesting(Contract): + def spend_with_datum( + self, datum: VestingParams, _redeemer: Anything, context: ScriptContext + ) -> None: + assert signed_by_beneficiary(datum, context), "beneficiary's signature missing" + assert deadline_reached(datum, context), "deadline not reached" + return None diff --git a/examples/smart_contracts/wrapped_token.py b/examples/smart_contracts/wrapped_token.py index 03982c22..f9ae0f84 100644 --- a/examples/smart_contracts/wrapped_token.py +++ b/examples/smart_contracts/wrapped_token.py @@ -61,38 +61,39 @@ def all_tokens_locked_at_contract_address( # parameters controlling which token is to be wrapped and how many decimal places to add # compile the contract as follows to obtain the parameterized contract (for preprod milk) # -# moreover this contract should always be called with three virtual parameters, so enable --force-three-params -# -# $ opshin build spending examples/smart_contracts/wrapped_token.py '{"bytes": "ae810731b5d21c0d182d89c60a1eff7095dffd1c0dce8707a8611099"}' '{"bytes": "4d494c4b"}' '{"int": 1000000}' --force-three-params -def validator( - token_policy_id: bytes, - token_name: bytes, - wrapping_factor: int, - ctx: ScriptContext, -) -> None: - purpose = ctx.purpose - if isinstance(purpose, Minting): - # whenever tokens should be burned/minted, the minting purpose will be triggered - own_addr = own_address(purpose.policy_id) - own_pid = purpose.policy_id - elif isinstance(purpose, Spending): - # whenever something is unlocked from the contract, the spending purpose will be triggered - own_utxo = own_spent_utxo(ctx.transaction.inputs, purpose) - own_pid = own_policy_id(own_utxo) - own_addr = own_utxo.address - else: - assert False, "Incorrect purpose given" - token = Token(token_policy_id, token_name) - all_locked = all_tokens_locked_at_contract_address( - ctx.transaction.outputs, own_addr, token - ) - all_unlocked = all_tokens_unlocked_from_contract_address( - ctx.transaction.inputs, own_addr, token - ) - all_minted = ctx.transaction.mint.get(own_pid, {b"": 0}).get(b"w" + token_name, 0) - print(f"unlocked from contract: {all_unlocked}") - print(f"locked at contract: {all_locked}") - print(f"minted: {all_minted}") - assert ( - (all_locked - all_unlocked) * wrapping_factor - ) == all_minted, f"Wrong amount of tokens minted, difference: {(all_locked - all_unlocked) * wrapping_factor - all_minted}" +# $ opshin build spending examples/smart_contracts/wrapped_token.py '{"bytes": "ae810731b5d21c0d182d89c60a1eff7095dffd1c0dce8707a8611099"}' '{"bytes": "4d494c4b"}' '{"int": 1000000}' +@dataclass() +class WrappedToken(Contract): + token_policy_id: bytes + token_name: bytes + wrapping_factor: int + + def raw(self, ctx: ScriptContext) -> None: + purpose = ctx.purpose + if isinstance(purpose, Minting): + # whenever tokens should be burned/minted, the minting purpose will be triggered + own_addr = own_address(purpose.policy_id) + own_pid = purpose.policy_id + elif isinstance(purpose, Spending): + # whenever something is unlocked from the contract, the spending purpose will be triggered + own_utxo = own_spent_utxo(ctx.transaction.inputs, purpose) + own_pid = own_policy_id(own_utxo) + own_addr = own_utxo.address + else: + assert False, "Incorrect purpose given" + token = Token(self.token_policy_id, self.token_name) + all_locked = all_tokens_locked_at_contract_address( + ctx.transaction.outputs, own_addr, token + ) + all_unlocked = all_tokens_unlocked_from_contract_address( + ctx.transaction.inputs, own_addr, token + ) + all_minted = ctx.transaction.mint.get(own_pid, {b"": 0}).get( + b"w" + self.token_name, 0 + ) + print(f"unlocked from contract: {all_unlocked}") + print(f"locked at contract: {all_locked}") + print(f"minted: {all_minted}") + assert ( + (all_locked - all_unlocked) * self.wrapping_factor + ) == all_minted, f"Wrong amount of tokens minted, difference: {(all_locked - all_unlocked) * self.wrapping_factor - all_minted}" diff --git a/opshin/__main__.py b/opshin/__main__.py index e855ab99..b174f5e1 100644 --- a/opshin/__main__.py +++ b/opshin/__main__.py @@ -38,6 +38,7 @@ from .util import CompilerError, data_from_json, OPSHIN_LOG_HANDLER from .prelude import ScriptContext from .compiler_config import * +from .contract_interface import discover_contract_module from uplc import cost_model @@ -277,7 +278,14 @@ def perform_command(args): sys.path.pop() # load the passed parameters if not a lib try: - argspec = inspect.signature(sc.validator if lib is None else getattr(sc, lib)) + contract_info = discover_contract_module(sc) if lib is None else None + if lib is None: + validator_callable = ( + sc.validator if contract_info is None else contract_info.validator + ) + else: + validator_callable = getattr(sc, lib) + argspec = inspect.signature(validator_callable) except AttributeError: raise AssertionError( f"Contract has no function called '{'validator' if lib is None else lib}'. Make sure the compiled contract contains one function called 'validator'." @@ -324,7 +332,7 @@ def perform_command(args): print("Python execution started") with redirect_stdout(open(os.devnull, "w")): try: - py_ret = sc.validator(*parsed_params) + py_ret = validator_callable(*parsed_params) except Exception as e: py_ret = e command = Command.eval_uplc @@ -419,11 +427,16 @@ def perform_command(args): built_code = builder._build(code) script_arts = PlutusContract( built_code, - # TODO this actually does not work anymore - datum_type=None, - redeemer_type=None, + datum_type=None if contract_info is None else contract_info.datum_type, + redeemer_type=( + None if contract_info is None else contract_info.redeemer_type + ), parameter_types=param_types, - purpose=(Purpose.any,), + purpose=( + (Purpose.any,) + if contract_info is None + else tuple(Purpose[purpose] for purpose in contract_info.purpose_names) + ), title=pathlib.Path(input_file).stem, ) script_arts.dump(target_dir) diff --git a/opshin/compiler.py b/opshin/compiler.py index d0324dc9..ed4d387f 100644 --- a/opshin/compiler.py +++ b/opshin/compiler.py @@ -59,6 +59,7 @@ from .rewrite.rewrite_annotate_fallthrough import RewriteAnnotateFallthrough from .rewrite.rewrite_augassign import RewriteAugAssign from .rewrite.rewrite_cast_condition import RewriteConditions +from .rewrite.rewrite_contract_methods import RewriteContractMethods from .rewrite.rewrite_empty_dicts import RewriteEmptyDicts from .rewrite.rewrite_empty_lists import RewriteEmptyLists from .rewrite.rewrite_destructuring_assign import RewriteDestructuringAssign @@ -1437,6 +1438,7 @@ def compile( compile_pipeline = [ # Important to call this one first - it imports all further files RewriteImport(filename=filename), + RewriteContractMethods(), # Rewrites that simplify the python code RewriteForbiddenReturn(), OptimizeUnionExpansion() if config.expand_union_types else NoOp(), diff --git a/opshin/contract_interface.py b/opshin/contract_interface.py new file mode 100644 index 00000000..0606a04c --- /dev/null +++ b/opshin/contract_interface.py @@ -0,0 +1,333 @@ +import dataclasses +import inspect +import typing + +from .ledger.api_v3 import Minting, Publishing, Proposing, Spending, Voting, Withdrawing +from .prelude import ( + Contract as PreludeContract, + NoOutputDatum, + ScriptContext, + SomeOutputDatum, + SomeOutputDatumHash, + own_datum, +) + + +@dataclasses.dataclass(frozen=True) +class ContractMethodSpec: + method_name: str + purpose_class: typing.Optional[type] + purpose_name: str + onchain_argument_count: int + + +CONTRACT_METHOD_SPECS = ( + ContractMethodSpec("raw", None, "any", 1), + ContractMethodSpec("spend_no_datum", Spending, "spending", 2), + ContractMethodSpec("spend_with_datum", Spending, "spending", 3), + ContractMethodSpec("mint", Minting, "minting", 2), + ContractMethodSpec("withdraw", Withdrawing, "rewarding", 2), + ContractMethodSpec("publish", Publishing, "certifying", 2), + ContractMethodSpec("vote", Voting, "voting", 2), + ContractMethodSpec("propose", Proposing, "proposing", 2), +) + +CONTRACT_METHOD_SPEC_MAP = { + contract_method.method_name: contract_method + for contract_method in CONTRACT_METHOD_SPECS +} + + +@dataclasses.dataclass(frozen=True) +class ContractMethodDetails: + spec: ContractMethodSpec + method: typing.Callable + argument_names: typing.Tuple[str, ...] + datum_type: typing.Optional[type] + redeemer_type: typing.Optional[type] + return_type: typing.Any + + +def _datum_loading_strategy(annotation: typing.Any) -> str: + origin = typing.get_origin(annotation) + if origin is typing.Union: + union_members = typing.get_args(annotation) + attachment_types = {NoOutputDatum, SomeOutputDatum, SomeOutputDatumHash} + if all(member in attachment_types for member in union_members): + return "attachment" + assert ( + NoOutputDatum not in union_members + ), "Contracts must use spend_no_datum instead of Union[..., NoOutputDatum]." + return "unsafe_raw" + + +@dataclasses.dataclass(frozen=True) +class ContractModuleInfo: + validator: typing.Callable + parameter_types: typing.List[typing.Tuple[str, typing.Any]] + method_details: typing.Tuple[ContractMethodDetails, ...] + has_raw_override: bool + + @property + def purpose_names(self) -> typing.Tuple[str, ...]: + if self.has_raw_override or not self.method_details: + return ("any",) + return tuple(detail.spec.purpose_name for detail in self.method_details) + + @property + def datum_type(self) -> typing.Optional[typing.Tuple[str, typing.Any]]: + if self.has_raw_override or not self.method_details: + return None + spending_methods = [ + detail + for detail in self.method_details + if detail.spec.method_name == "spend_with_datum" + ] + if not spending_methods: + return None + detail = spending_methods[0] + if detail.datum_type is inspect.Signature.empty: + return None + return ("datum", detail.datum_type) + + @property + def redeemer_type(self) -> typing.Optional[typing.Tuple[str, typing.Any]]: + if self.has_raw_override or not self.method_details: + return None + redeemer_types = [] + for detail in self.method_details: + if detail.redeemer_type is inspect.Signature.empty: + return None + redeemer_types.append(detail.redeemer_type) + if not redeemer_types: + return None + first = redeemer_types[0] + if any(redeemer_type != first for redeemer_type in redeemer_types[1:]): + return None + return ("redeemer", first) + + +def _contract_parameter_types( + contract_class: type, +) -> typing.List[typing.Tuple[str, typing.Any]]: + assert ( + "CONSTR_ID" not in contract_class.__dict__ + ), "Contract classes must not define CONSTR_ID." + annotations = inspect.get_annotations(contract_class) + unannotated_fields = [ + name + for name, value in contract_class.__dict__.items() + if not name.startswith("__") + and name not in annotations + and name not in CONTRACT_METHOD_SPEC_MAP + and not inspect.isfunction(value) + and not isinstance(value, (staticmethod, classmethod)) + ] + assert ( + not unannotated_fields + ), "Contract fields must be annotated; unannotated Contract fields are not supported." + return list(annotations.items()) + + +def _method_annotations( + method: typing.Callable, onchain_argument_count: int +) -> typing.Tuple[typing.Tuple[str, ...], typing.Any, typing.Any]: + signature = inspect.signature(method) + parameters = list(signature.parameters.values()) + assert ( + parameters + ), f"Contract method '{method.__name__}' must accept self as first parameter." + assert ( + parameters[0].name == "self" + ), f"Contract method '{method.__name__}' must accept self as first parameter." + expected_parameter_count = onchain_argument_count + 1 + assert len(parameters) == expected_parameter_count, ( + f"Contract method '{method.__name__}' must accept self plus " + f"{onchain_argument_count} on-chain parameters." + ) + method_parameters = parameters[1:] + context_parameter = method_parameters[-1] + if context_parameter.annotation is not inspect.Signature.empty: + assert ( + context_parameter.annotation == ScriptContext + ), f"Contract method '{method.__name__}' must annotate context as ScriptContext." + datum_type = None + redeemer_type = None + if onchain_argument_count == 3: + datum_type = method_parameters[0].annotation + if datum_type is not inspect.Signature.empty: + _datum_loading_strategy(datum_type) + redeemer_type = method_parameters[1].annotation + elif onchain_argument_count == 2: + redeemer_type = method_parameters[0].annotation + return ( + tuple(parameter.name for parameter in method_parameters), + datum_type, + redeemer_type, + ) + + +def _make_unique_name(preferred_name: str, used_names: typing.Set[str]) -> str: + if preferred_name not in used_names: + return preferred_name + suffix = 0 + while f"{preferred_name}_{suffix}" in used_names: + suffix += 1 + return f"{preferred_name}_{suffix}" + + +def _build_contract_validator( + contract_class: type, + parameter_types: typing.List[typing.Tuple[str, typing.Any]], + method_details: typing.Tuple[ContractMethodDetails, ...], + has_raw_override: bool, +): + preferred_context_name = ( + method_details[0].argument_names[-1] if method_details else "context" + ) + context_parameter_name = _make_unique_name( + preferred_context_name, + {field_name for field_name, _ in parameter_types}, + ) + + def validator(*args): + contract_parameter_count = len(parameter_types) + contract = contract_class(*args[:contract_parameter_count]) + context = args[contract_parameter_count] + if has_raw_override: + raw_method = next( + detail.method + for detail in method_details + if detail.spec.method_name == "raw" + ) + return raw_method(contract, context) + if not method_details: + return PreludeContract.raw(contract, context) + purpose = context.purpose + spending_methods = { + detail.spec.method_name: detail + for detail in method_details + if detail.spec.purpose_class is Spending + } + if isinstance(purpose, Spending) and spending_methods: + attached_datum = own_datum(context) + if isinstance(attached_datum, NoOutputDatum): + spend_no_datum = spending_methods.get("spend_no_datum") + assert ( + spend_no_datum is not None + ), "No datum was attached to the UTxO being spent by this Contract." + return spend_no_datum.method(contract, context.redeemer, context) + spend_with_datum = spending_methods.get("spend_with_datum") + if spend_with_datum is None: + assert False, "Contract has no spending entrypoint for attached datums." + datum_loading_strategy = _datum_loading_strategy( + spend_with_datum.datum_type + ) + datum = ( + attached_datum + if datum_loading_strategy == "attachment" + else attached_datum.datum + ) + return spend_with_datum.method(contract, datum, context.redeemer, context) + for detail in method_details: + if not isinstance(purpose, detail.spec.purpose_class): + continue + return detail.method(contract, context.redeemer, context) + return PreludeContract.raw(contract, context) + + validator.__name__ = "validator" + if has_raw_override: + return_annotations = [ + detail.return_type + for detail in method_details + if detail.spec.method_name == "raw" + ] + else: + return_annotations = [detail.return_type for detail in method_details] + if return_annotations and any( + return_type != return_annotations[0] for return_type in return_annotations[1:] + ): + raise AssertionError( + "All Contract entrypoint methods must have the same return annotation." + ) + validator.__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter( + name=field_name, + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=field_type, + ) + for field_name, field_type in parameter_types + ] + + [ + inspect.Parameter( + name=context_parameter_name, + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=ScriptContext, + ) + ], + return_annotation=( + return_annotations[0] if return_annotations else inspect.Signature.empty + ), + ) + validator.__annotations__ = { + field_name: field_type for field_name, field_type in parameter_types + } + validator.__annotations__[context_parameter_name] = ScriptContext + if return_annotations and return_annotations[0] is not inspect.Signature.empty: + validator.__annotations__["return"] = return_annotations[0] + return validator + + +def discover_contract_module(module) -> typing.Optional[ContractModuleInfo]: + validator = getattr(module, "validator", None) + if callable(validator): + return None + contract_classes = [ + value + for value in module.__dict__.values() + if inspect.isclass(value) + and value is not PreludeContract + and value.__module__ == module.__name__ + and issubclass(value, PreludeContract) + ] + if not contract_classes: + return None + assert ( + len(contract_classes) == 1 + ), "A contract module may define only one Contract subclass." + contract_class = contract_classes[0] + parameter_types = _contract_parameter_types(contract_class) + method_details = [] + for spec in CONTRACT_METHOD_SPECS: + method = contract_class.__dict__.get(spec.method_name) + if method is None: + continue + argument_names, datum_type, redeemer_type = _method_annotations( + method, spec.onchain_argument_count + ) + method_details.append( + ContractMethodDetails( + spec=spec, + method=method, + argument_names=argument_names, + datum_type=datum_type, + redeemer_type=redeemer_type, + return_type=inspect.signature(method).return_annotation, + ) + ) + has_raw_override = any( + detail.spec.method_name == "raw" for detail in method_details + ) + generated_validator = _build_contract_validator( + contract_class, + parameter_types, + tuple(method_details), + has_raw_override, + ) + return ContractModuleInfo( + validator=generated_validator, + parameter_types=parameter_types, + method_details=tuple(method_details), + has_raw_override=has_raw_override, + ) diff --git a/opshin/prelude.py b/opshin/prelude.py index 9f8b9c5d..82e6650e 100644 --- a/opshin/prelude.py +++ b/opshin/prelude.py @@ -1,6 +1,88 @@ from opshin.ledger.api_v3 import * +class Contract: + """ + Base class for contract entrypoint classes. + + User-defined contracts should inherit from this class and expose the + supported entrypoint methods on the subclass. + + There is intentionally no generic `spend(...)` entrypoint. Spending + contracts should implement `spend_no_datum(...)`, + `spend_with_datum(...)`, or both. Contracts that need fully custom + spending dispatch logic should fall back to `raw(...)`. + """ + + def raw(self, context: ScriptContext) -> Anything: + """ + Generic entrypoint for contracts. + + By default this dispatches to the purpose-specific entrypoints on the + subclass. Override `raw(...)` to take full control over dispatch. + """ + purpose = context.purpose + if isinstance(purpose, Spending): + attached_datum = own_datum(context) + if isinstance(attached_datum, NoOutputDatum): + return self.spend_no_datum(context.redeemer, context) + return self.spend_with_datum(attached_datum, context.redeemer, context) + if isinstance(purpose, Minting): + return self.mint(context.redeemer, context) + if isinstance(purpose, Withdrawing): + return self.withdraw(context.redeemer, context) + if isinstance(purpose, Publishing): + return self.publish(context.redeemer, context) + if isinstance(purpose, Voting): + return self.vote(context.redeemer, context) + if isinstance(purpose, Proposing): + return self.propose(context.redeemer, context) + assert False, "Unsupported script purpose for Contract" + + def spend_no_datum(self, redeemer: Anything, context: ScriptContext) -> Anything: + """ + Spending entrypoint used when the spent output has no datum attached. + + If the contract also defines `spend_with_datum(...)`, the compiler + dispatches here only when the spent output datum is `NoOutputDatum`. + """ + assert False, "Contract.spend_no_datum must be overridden" + + def spend_with_datum( + self, datum: Anything, redeemer: Anything, context: ScriptContext + ) -> Anything: + """ + Spending entrypoint for contracts that need the spent output datum. + + The compiler loads the datum before calling this entrypoint. Use + `OutputDatum` to receive the wrapped attachment or a concrete datum type + to receive the unwrapped datum value. If the contract also defines + `spend_no_datum(...)`, the compiler dispatches here only when a datum + is attached. + """ + assert False, "Contract.spend_with_datum must be overridden" + + def mint(self, redeemer: Anything, context: ScriptContext) -> Anything: + """Minting entrypoint for contracts executed in a minting context.""" + assert False, "Contract.mint must be overridden" + + def withdraw(self, redeemer: Anything, context: ScriptContext) -> Anything: + """Withdrawal entrypoint for contracts executed in a rewarding context.""" + assert False, "Contract.withdraw must be overridden" + + def publish(self, redeemer: Anything, context: ScriptContext) -> Anything: + """Certificate publication entrypoint for certifying script contexts.""" + assert False, "Contract.publish must be overridden" + + def vote(self, redeemer: Anything, context: ScriptContext) -> Anything: + """Voting entrypoint for contracts executed in a voting context.""" + assert False, "Contract.vote must be overridden" + + def propose(self, redeemer: Anything, context: ScriptContext) -> Anything: + """Proposal entrypoint for contracts executed in a proposing context.""" + assert False, "Contract.propose must be overridden" + + @dataclass(unsafe_hash=True) class Nothing(PlutusData): """ diff --git a/opshin/rewrite/rewrite_contract_methods.py b/opshin/rewrite/rewrite_contract_methods.py new file mode 100644 index 00000000..03edf4bc --- /dev/null +++ b/opshin/rewrite/rewrite_contract_methods.py @@ -0,0 +1,834 @@ +import ast +from copy import deepcopy + +from ..contract_interface import CONTRACT_METHOD_SPECS +from ..util import CompilingNodeTransformer, custom_fix_missing_locations + + +class _RewriteContractSelfReferences(ast.NodeTransformer): + def __init__(self, field_name_map, helper_function_names, local_name_map): + self.field_parameter_names = tuple(field_name_map.values()) + self.field_name_map = dict(field_name_map) + self.helper_function_names = dict(helper_function_names) + self.local_name_map = dict(local_name_map) + + def visit_Attribute(self, node: ast.Attribute): + if isinstance(node.value, ast.Name) and node.value.id == "self": + assert isinstance( + node.ctx, ast.Load + ), "Contract fields are immutable and may not be assigned through self." + if node.attr in self.field_name_map: + return ast.copy_location( + ast.Name(id=self.field_name_map[node.attr], ctx=ast.Load()), + node, + ) + if node.attr in self.helper_function_names: + return ast.copy_location( + ast.Name(id=self.helper_function_names[node.attr], ctx=ast.Load()), + node, + ) + assert False, f"Contract has no field or method named '{node.attr}'." + return self.generic_visit(node) + + def visit_Name(self, node: ast.Name): + if node.id not in self.local_name_map: + return node + return ast.copy_location( + ast.Name(id=self.local_name_map[node.id], ctx=node.ctx), + node, + ) + + def visit_Call(self, node: ast.Call): + node = self.generic_visit(node) + if not isinstance(node.func, ast.Name): + return node + if node.func.id not in self.helper_function_names.values(): + return node + return ast.copy_location( + ast.Call( + func=node.func, + args=[ + ast.Name(id=field_name, ctx=ast.Load()) + for field_name in self.field_parameter_names + ] + + node.args, + keywords=node.keywords, + ), + node, + ) + + +class RewriteContractMethods(CompilingNodeTransformer): + step = "Rewriting Contract entrypoints" + _internal_name_prefix = "__contract+" + + def _is_contract_class(self, statement: ast.stmt) -> bool: + return isinstance(statement, ast.ClassDef) and any( + isinstance(base, ast.Name) and base.id == "Contract" + for base in statement.bases + ) + + def visit_Module(self, node: ast.Module) -> ast.Module: + node = self.generic_visit(node) + has_validator = any( + isinstance(statement, ast.FunctionDef) and statement.name == "validator" + for statement in node.body + ) + if has_validator: + return node + contract_classes = [ + statement for statement in node.body if self._is_contract_class(statement) + ] + if not contract_classes: + return node + assert ( + len(contract_classes) == 1 + ), "A contract module may define only one Contract subclass." + contract_class = contract_classes[0] + contract_methods = { + statement.name: statement + for statement in contract_class.body + if isinstance(statement, ast.FunctionDef) + } + supported_methods = [ + contract_method + for contract_method in CONTRACT_METHOD_SPECS + if contract_method.method_name in contract_methods + ] + self._check_contract_class(contract_class, supported_methods, contract_methods) + field_annotations = self._field_annotations(contract_class) + module_names = self._module_bound_names(node) + generated_names = set(module_names) + field_parameter_names = self._field_parameter_names( + field_annotations, generated_names + ) + helper_function_names = self._helper_function_names( + contract_methods, generated_names + ) + method_argument_name_maps = self._method_argument_name_maps( + contract_methods, generated_names + ) + supported_method_names = { + supported_method.method_name for supported_method in supported_methods + } + referenced_methods = self._self_method_references(contract_methods.values()) + lifted_methods = [ + self._lift_contract_method( + method, + field_annotations, + field_parameter_names, + helper_function_names, + method_argument_name_maps[method.name], + ) + for method in contract_methods.values() + if method.name not in supported_method_names + or method.name in referenced_methods + ] + rewritten_entrypoint_bodies = { + supported_method.method_name: self._rewrite_method_body( + contract_methods[supported_method.method_name], + field_parameter_names, + helper_function_names, + method_argument_name_maps[supported_method.method_name], + ) + for supported_method in supported_methods + } + context_argument_name = self._make_reserved_name("context", generated_names) + generated_names.add(context_argument_name) + if "raw" in contract_methods: + return_annotation = deepcopy(contract_methods["raw"].returns) + elif supported_methods: + return_annotation = deepcopy( + contract_methods[supported_methods[0].method_name].returns + ) + for supported_method in supported_methods[1:]: + candidate = contract_methods[supported_method.method_name].returns + assert ast.dump(candidate) == ast.dump( + return_annotation + ), "All Contract entrypoint methods must have the same return annotation." + else: + return_annotation = ast.Name(id="Anything", ctx=ast.Load()) + validator_function = ast.FunctionDef( + name="validator", + args=ast.arguments( + posonlyargs=[], + args=[ + ast.arg( + arg=field_parameter_names[field_name], + annotation=deepcopy(annotation), + ) + for field_name, annotation in field_annotations + ] + + [ + ast.arg( + arg=context_argument_name, + annotation=ast.Name(id="ScriptContext", ctx=ast.Load()), + ) + ], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=self._validator_body( + field_parameter_names, + contract_methods, + method_argument_name_maps, + rewritten_entrypoint_bodies, + context_argument_name, + ), + decorator_list=[], + returns=return_annotation, + type_comment=None, + ) + validator_function = custom_fix_missing_locations( + ast.copy_location(validator_function, contract_class), contract_class + ) + rewritten_body = [] + for statement in node.body: + if statement is contract_class: + rewritten_body.extend(lifted_methods) + rewritten_body.append(validator_function) + else: + rewritten_body.append(statement) + node.body = rewritten_body + return node + + def _module_bound_names(self, node: ast.Module): + bound_names = set() + for statement in node.body: + if isinstance( + statement, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef) + ): + bound_names.add(statement.name) + elif isinstance(statement, ast.Import): + for alias in statement.names: + bound_names.add(alias.asname or alias.name.split(".")[0]) + elif isinstance(statement, ast.ImportFrom): + for alias in statement.names: + bound_names.add(alias.asname or alias.name) + elif isinstance(statement, ast.Assign): + for target in statement.targets: + if isinstance(target, ast.Name): + bound_names.add(target.id) + elif isinstance(statement, ast.AnnAssign) and isinstance( + statement.target, ast.Name + ): + bound_names.add(statement.target.id) + return bound_names + + def _check_contract_class( + self, contract_class, supported_methods, contract_methods + ): + for supported_method in supported_methods: + self._check_method_signature( + contract_methods[supported_method.method_name], supported_method + ) + assert not any( + ( + isinstance(statement, ast.AnnAssign) + and isinstance(statement.target, ast.Name) + and statement.target.id == "CONSTR_ID" + ) + or ( + isinstance(statement, ast.Assign) + and any( + isinstance(target, ast.Name) and target.id == "CONSTR_ID" + for target in statement.targets + ) + ) + for statement in contract_class.body + ), "Contract classes must not define CONSTR_ID." + assert not any( + isinstance(statement, ast.Assign) + and any(isinstance(target, ast.Name) for target in statement.targets) + for statement in contract_class.body + ), "Contract fields must be annotated; unannotated Contract fields are not supported." + + def _field_annotations(self, contract_class): + field_annotations = [] + for statement in contract_class.body: + if not isinstance(statement, ast.AnnAssign): + continue + assert isinstance( + statement.target, ast.Name + ), "Contract fields must be named attributes." + field_annotations.append( + (statement.target.id, deepcopy(statement.annotation)) + ) + return field_annotations + + def _field_parameter_names(self, field_annotations, used_names): + field_parameter_names = {} + for field_name, _ in field_annotations: + internal_name = self._make_reserved_name(f"field_{field_name}", used_names) + field_parameter_names[field_name] = internal_name + used_names.add(internal_name) + return field_parameter_names + + def _helper_function_names(self, contract_methods, used_names): + helper_function_names = {} + for method_name in contract_methods: + helper_function_name = self._make_reserved_name( + f"method_{method_name}", used_names + ) + helper_function_names[method_name] = helper_function_name + used_names.add(helper_function_name) + return helper_function_names + + def _method_argument_name_maps(self, contract_methods, used_names): + method_argument_name_maps = {} + for method_name, method in contract_methods.items(): + argument_name_map = {} + for argument in method.args.args[1:]: + internal_name = self._make_reserved_name( + f"arg_{method_name}_{argument.arg}", used_names + ) + argument_name_map[argument.arg] = internal_name + used_names.add(internal_name) + method_argument_name_maps[method_name] = argument_name_map + return method_argument_name_maps + + def _self_method_references(self, methods): + referenced_methods = set() + for method in methods: + for node in ast.walk(method): + if not isinstance(node, ast.Call): + continue + if not isinstance(node.func, ast.Attribute): + continue + if not ( + isinstance(node.func.value, ast.Name) + and node.func.value.id == "self" + ): + continue + referenced_methods.add(node.func.attr) + return referenced_methods + + def _rewrite_method_body( + self, + method, + field_parameter_names, + helper_function_names, + argument_name_map, + ): + body_rewriter = _RewriteContractSelfReferences( + field_parameter_names, helper_function_names, argument_name_map + ) + return [body_rewriter.visit(deepcopy(statement)) for statement in method.body] + + def _lift_contract_method( + self, + method: ast.FunctionDef, + field_annotations, + field_parameter_names, + helper_function_names, + argument_name_map, + ) -> ast.FunctionDef: + assert not method.decorator_list, "Contract methods must not have decorators." + assert ( + not method.args.posonlyargs + and not method.args.vararg + and not method.args.kwonlyargs + and not method.args.kwarg + and not method.args.defaults + and not method.args.kw_defaults + ), "Contract methods must use plain positional parameters without defaults." + lifted_method = ast.FunctionDef( + name=helper_function_names[method.name], + args=ast.arguments( + posonlyargs=[], + args=[ + ast.arg( + arg=field_parameter_names[field_name], + annotation=deepcopy(annotation), + ) + for field_name, annotation in field_annotations + ] + + [ + ast.arg( + arg=argument_name_map[argument.arg], + annotation=deepcopy(argument.annotation), + ) + for argument in method.args.args[1:] + ], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=self._rewrite_method_body( + method, + field_parameter_names, + helper_function_names, + argument_name_map, + ), + decorator_list=[], + returns=deepcopy(method.returns), + type_comment=method.type_comment, + ) + return custom_fix_missing_locations( + ast.copy_location(lifted_method, method), method + ) + + def _make_reserved_name(self, preferred_name: str, used_names): + reserved_name = f"{self._internal_name_prefix}{preferred_name}" + if reserved_name not in used_names: + return reserved_name + suffix = 0 + while f"{reserved_name}_{suffix}" in used_names: + suffix += 1 + return f"{reserved_name}_{suffix}" + + def _check_method_signature(self, method: ast.FunctionDef, supported_method): + assert ( + not method.args.posonlyargs + and not method.args.vararg + and not method.args.kwonlyargs + and not method.args.kwarg + and not method.args.defaults + and not method.args.kw_defaults + ), f"Contract method '{method.name}' must use plain positional parameters without defaults." + actual_arguments = tuple(argument.arg for argument in method.args.args) + assert ( + actual_arguments + ), f"Contract method '{method.name}' must accept self as first parameter." + assert ( + actual_arguments[0] == "self" + ), f"Contract method '{method.name}' must accept self as first parameter." + expected_argument_count = supported_method.onchain_argument_count + 1 + assert len(actual_arguments) == expected_argument_count, ( + f"Contract method '{method.name}' must accept self plus " + f"{supported_method.onchain_argument_count} on-chain parameters." + ) + context_annotation = method.args.args[-1].annotation + if context_annotation is not None: + assert ( + isinstance(context_annotation, ast.Name) + and context_annotation.id == "ScriptContext" + ), f"Contract method '{method.name}' must annotate context as ScriptContext." + + def _annotation_union_members(self, annotation): + if isinstance(annotation, ast.Name) and annotation.id == "OutputDatum": + return [annotation] + if ( + isinstance(annotation, ast.Subscript) + and isinstance(annotation.value, ast.Name) + and annotation.value.id == "Union" + ): + if isinstance(annotation.slice, ast.Tuple): + return list(annotation.slice.elts) + return [annotation.slice] + return [annotation] + + def _datum_loading_strategy(self, annotation): + if isinstance(annotation, ast.Name) and annotation.id == "OutputDatum": + return "attachment" + union_members = self._annotation_union_members(annotation) + member_names = { + member.id for member in union_members if isinstance(member, ast.Name) + } + attachment_names = {"NoOutputDatum", "SomeOutputDatum", "SomeOutputDatumHash"} + if len(member_names) == len(union_members) and member_names.issubset( + attachment_names + ): + return "attachment" + assert ( + "NoOutputDatum" not in member_names + ), "Contracts must use spend_no_datum instead of Union[..., NoOutputDatum]." + return "unsafe_raw" + + def _validator_body( + self, + field_parameter_names, + contract_methods, + method_argument_name_maps, + rewritten_entrypoint_bodies, + context_argument_name, + ): + used_names = set(field_parameter_names.values()) | {context_argument_name} + body = [] + if "raw" in contract_methods: + context_name = method_argument_name_maps["raw"][ + contract_methods["raw"].args.args[1].arg + ] + if context_name != context_argument_name: + body.append( + ast.Assign( + targets=[ast.Name(id=context_name, ctx=ast.Store())], + value=ast.Name(id=context_argument_name, ctx=ast.Load()), + ) + ) + body.extend(deepcopy(rewritten_entrypoint_bodies["raw"])) + return body + + branch_specs = [] + supported_method_names = { + method_name + for method_name in contract_methods + if method_name in {spec.method_name for spec in CONTRACT_METHOD_SPECS} + } + if ( + "spend_no_datum" in supported_method_names + or "spend_with_datum" in supported_method_names + ): + branch_specs.append( + ( + "Spending", + self._spending_body( + contract_methods, + method_argument_name_maps, + rewritten_entrypoint_bodies, + context_argument_name, + used_names, + ), + ) + ) + + for method_name, purpose_class_name in ( + ("mint", "Minting"), + ("withdraw", "Withdrawing"), + ("publish", "Publishing"), + ("vote", "Voting"), + ("propose", "Proposing"), + ): + if method_name not in contract_methods: + continue + branch_specs.append( + ( + purpose_class_name, + self._specialized_entrypoint_body( + method_name, + contract_methods, + method_argument_name_maps, + rewritten_entrypoint_bodies, + context_argument_name, + ), + ) + ) + + if len(branch_specs) == 1: + return branch_specs[0][1] + + if not branch_specs: + branch_specs = [ + ( + "Spending", + self._spending_body( + contract_methods, + method_argument_name_maps, + rewritten_entrypoint_bodies, + context_argument_name, + used_names, + ), + ) + ] + for method_name, purpose_class_name in ( + ("mint", "Minting"), + ("withdraw", "Withdrawing"), + ("publish", "Publishing"), + ("vote", "Voting"), + ("propose", "Proposing"), + ): + branch_specs.append( + ( + purpose_class_name, + self._specialized_entrypoint_body( + method_name, + contract_methods, + method_argument_name_maps, + rewritten_entrypoint_bodies, + context_argument_name, + ), + ) + ) + + purpose_name = self._make_reserved_name("purpose", used_names) + used_names.add(purpose_name) + body.append( + ast.Assign( + targets=[ast.Name(id=purpose_name, ctx=ast.Store())], + value=ast.Attribute( + value=ast.Name(id=context_argument_name, ctx=ast.Load()), + attr="purpose", + ctx=ast.Load(), + ), + ) + ) + current_branch = None + for index, (purpose_class_name, branch_body) in enumerate(branch_specs): + branch = ast.If( + test=ast.Call( + func=ast.Name(id="isinstance", ctx=ast.Load()), + args=[ + ast.Name(id=purpose_name, ctx=ast.Load()), + ast.Name(id=purpose_class_name, ctx=ast.Load()), + ], + keywords=[], + ), + body=branch_body, + orelse=[], + ) + if index == 0: + body.append(branch) + current_branch = branch + else: + current_branch.orelse = [branch] + current_branch = branch + current_branch.orelse = [ + ast.Assert( + test=ast.Constant(value=False), + msg=ast.Constant(value="Unsupported script purpose for Contract"), + ) + ] + return body + + def _spending_body( + self, + contract_methods, + method_argument_name_maps, + rewritten_entrypoint_bodies, + context_argument_name, + used_names, + ): + if ( + "spend_no_datum" not in contract_methods + and "spend_with_datum" not in contract_methods + ): + return self._missing_entrypoint_body("spend_no_datum") + if ( + "spend_no_datum" not in contract_methods + and "spend_with_datum" in contract_methods + ): + method = contract_methods["spend_with_datum"] + datum_annotation = method.args.args[1].annotation + if self._datum_loading_strategy(datum_annotation) == "unsafe_raw": + return self._spend_with_datum_unsafe_body( + contract_methods, + method_argument_name_maps, + rewritten_entrypoint_bodies, + context_argument_name, + ) + + spending_used_names = set(used_names) + attached_datum_name = self._make_reserved_name( + "attached_datum", spending_used_names + ) + spending_body = [ + ast.AnnAssign( + target=ast.Name(id=attached_datum_name, ctx=ast.Store()), + annotation=ast.Name(id="OutputDatum", ctx=ast.Load()), + value=ast.Call( + func=ast.Name(id="own_datum", ctx=ast.Load()), + args=[ast.Name(id=context_argument_name, ctx=ast.Load())], + keywords=[], + ), + simple=1, + ) + ] + + no_datum_body = self._specialized_entrypoint_body( + "spend_no_datum", + contract_methods, + method_argument_name_maps, + rewritten_entrypoint_bodies, + context_argument_name, + ) + with_datum_body = self._spend_with_datum_body( + contract_methods, + method_argument_name_maps, + rewritten_entrypoint_bodies, + context_argument_name, + attached_datum_name, + ) + spending_body.append( + ast.If( + test=ast.Call( + func=ast.Name(id="isinstance", ctx=ast.Load()), + args=[ + ast.Name(id=attached_datum_name, ctx=ast.Load()), + ast.Name(id="NoOutputDatum", ctx=ast.Load()), + ], + keywords=[], + ), + body=no_datum_body, + orelse=with_datum_body, + ) + ) + return spending_body + + def _missing_entrypoint_body(self, method_name): + return [ + ast.Assert( + test=ast.Constant(value=False), + msg=ast.Constant(value=f"Contract.{method_name} must be overridden"), + ) + ] + + def _specialized_entrypoint_body( + self, + method_name, + contract_methods, + method_argument_name_maps, + rewritten_entrypoint_bodies, + context_argument_name, + ): + method = contract_methods.get(method_name) + if method is None: + return self._missing_entrypoint_body(method_name) + method_argument_names = [argument.arg for argument in method.args.args[1:]] + branch_names = method_argument_name_maps[method_name] + branch_body = [] + context_name = branch_names[method_argument_names[-1]] + if context_name != context_argument_name: + branch_body.append( + ast.Assign( + targets=[ast.Name(id=context_name, ctx=ast.Store())], + value=ast.Name(id=context_argument_name, ctx=ast.Load()), + ) + ) + redeemer_argument_name = method_argument_names[0] + redeemer_name = branch_names[redeemer_argument_name] + redeemer_annotation = deepcopy(method.args.args[1].annotation) + branch_body.extend( + [ + ast.AnnAssign( + target=ast.Name(id=redeemer_name, ctx=ast.Store()), + annotation=redeemer_annotation, + value=ast.Attribute( + value=ast.Name(id=context_name, ctx=ast.Load()), + attr="redeemer", + ctx=ast.Load(), + ), + simple=1, + ), + ] + + deepcopy(rewritten_entrypoint_bodies[method_name]) + ) + return branch_body + + def _spend_with_datum_body( + self, + contract_methods, + method_argument_name_maps, + rewritten_entrypoint_bodies, + context_argument_name, + attached_datum_name, + ): + method = contract_methods.get("spend_with_datum") + if method is None: + return self._missing_entrypoint_body("spend_with_datum") + branch_names = method_argument_name_maps["spend_with_datum"] + datum_name = branch_names[method.args.args[1].arg] + redeemer_name = branch_names[method.args.args[2].arg] + context_name = branch_names[method.args.args[3].arg] + datum_annotation = deepcopy(method.args.args[1].annotation) + redeemer_annotation = deepcopy(method.args.args[2].annotation) + datum_loading_strategy = self._datum_loading_strategy(datum_annotation) + with_datum_body = [] + if context_name != context_argument_name: + with_datum_body.append( + ast.Assign( + targets=[ast.Name(id=context_name, ctx=ast.Store())], + value=ast.Name(id=context_argument_name, ctx=ast.Load()), + ) + ) + with_datum_body.append( + ast.AnnAssign( + target=ast.Name(id=redeemer_name, ctx=ast.Store()), + annotation=redeemer_annotation, + value=ast.Attribute( + value=ast.Name(id=context_name, ctx=ast.Load()), + attr="redeemer", + ctx=ast.Load(), + ), + simple=1, + ) + ) + with_datum_body.append( + ast.Assert( + test=ast.Call( + func=ast.Name(id="isinstance", ctx=ast.Load()), + args=[ + ast.Name(id=attached_datum_name, ctx=ast.Load()), + ast.Name(id="SomeOutputDatum", ctx=ast.Load()), + ], + keywords=[], + ), + msg=ast.Constant( + value="No datum was attached to the UTxO being spent by this Contract." + ), + ) + if datum_loading_strategy == "unsafe_raw" + else ast.Pass() + ) + with_datum_value = ( + ast.Name(id=attached_datum_name, ctx=ast.Load()) + if datum_loading_strategy == "attachment" + else ast.Attribute( + value=ast.Name(id=attached_datum_name, ctx=ast.Load()), + attr="datum", + ctx=ast.Load(), + ) + ) + with_datum_body.append( + ast.AnnAssign( + target=ast.Name(id=datum_name, ctx=ast.Store()), + annotation=datum_annotation, + value=with_datum_value, + simple=1, + ) + ) + with_datum_body.extend( + deepcopy(rewritten_entrypoint_bodies["spend_with_datum"]) + ) + return with_datum_body + + def _spend_with_datum_unsafe_body( + self, + contract_methods, + method_argument_name_maps, + rewritten_entrypoint_bodies, + context_argument_name, + ): + method = contract_methods.get("spend_with_datum") + if method is None: + return self._missing_entrypoint_body("spend_with_datum") + branch_names = method_argument_name_maps["spend_with_datum"] + datum_name = branch_names[method.args.args[1].arg] + redeemer_name = branch_names[method.args.args[2].arg] + context_name = branch_names[method.args.args[3].arg] + datum_annotation = deepcopy(method.args.args[1].annotation) + redeemer_annotation = deepcopy(method.args.args[2].annotation) + body = [] + if context_name != context_argument_name: + body.append( + ast.Assign( + targets=[ast.Name(id=context_name, ctx=ast.Store())], + value=ast.Name(id=context_argument_name, ctx=ast.Load()), + ) + ) + body.append( + ast.AnnAssign( + target=ast.Name(id=datum_name, ctx=ast.Store()), + annotation=datum_annotation, + value=ast.Call( + func=ast.Name(id="own_datum_unsafe", ctx=ast.Load()), + args=[ast.Name(id=context_name, ctx=ast.Load())], + keywords=[], + ), + simple=1, + ) + ) + body.append( + ast.AnnAssign( + target=ast.Name(id=redeemer_name, ctx=ast.Store()), + annotation=redeemer_annotation, + value=ast.Attribute( + value=ast.Name(id=context_name, ctx=ast.Load()), + attr="redeemer", + ctx=ast.Load(), + ), + simple=1, + ) + ) + body.extend(deepcopy(rewritten_entrypoint_bodies["spend_with_datum"])) + return body diff --git a/opshin/rewrite/rewrite_import.py b/opshin/rewrite/rewrite_import.py index 052d40b3..dce6c03c 100644 --- a/opshin/rewrite/rewrite_import.py +++ b/opshin/rewrite/rewrite_import.py @@ -120,5 +120,13 @@ def visit_ImportFrom( resolved_imports=self.resolved_imports, ) recursively_resolved: Module = recursive_resolver.visit(resolved) + if module.__name__ == "opshin.prelude": + recursively_resolved.body = [ + statement + for statement in recursively_resolved.body + if not ( + isinstance(statement, ast.ClassDef) and statement.name == "Contract" + ) + ] self.resolved_imports.update(recursive_resolver.resolved_imports) return recursively_resolved.body diff --git a/opshin/rewrite/rewrite_import_dataclasses.py b/opshin/rewrite/rewrite_import_dataclasses.py index b587ae54..8afd74a3 100644 --- a/opshin/rewrite/rewrite_import_dataclasses.py +++ b/opshin/rewrite/rewrite_import_dataclasses.py @@ -102,7 +102,14 @@ def visit_Call(self, node: Call) -> Call: ), "astuple must be imported via 'from dataclasses import astuple'" return node + def _is_contract_class(self, node: ClassDef) -> bool: + return any( + isinstance(base, Name) and base.id == "Contract" for base in node.bases + ) + def visit_ClassDef(self, node: ClassDef) -> ClassDef: + if self._is_contract_class(node): + return node assert ( self.imports_dataclasses ), "dataclasses must be imported in order to use datum classes" diff --git a/opshin/rewrite/rewrite_import_plutusdata.py b/opshin/rewrite/rewrite_import_plutusdata.py index 47298ab4..622af9fd 100644 --- a/opshin/rewrite/rewrite_import_plutusdata.py +++ b/opshin/rewrite/rewrite_import_plutusdata.py @@ -13,6 +13,11 @@ class RewriteImportPlutusData(CompilingNodeTransformer): imports_plutus_data = False + def _is_contract_class(self, node: ClassDef) -> bool: + return any( + isinstance(base, Name) and base.id == "Contract" for base in node.bases + ) + def visit_ImportFrom(self, node: ImportFrom) -> Optional[ImportFrom]: if node.module != "pycardano": return node @@ -35,6 +40,8 @@ def visit_ImportFrom(self, node: ImportFrom) -> Optional[ImportFrom]: return None def visit_ClassDef(self, node: ClassDef) -> ClassDef: + if self._is_contract_class(node): + return node assert ( len(node.decorator_list) == 1 ), f"Class definitions must have no decorators but @dataclass, {node.name} has {tuple(node.decorator_list)}" diff --git a/scripts/binary_size_tracker.py b/scripts/binary_size_tracker.py index 0d071f29..43175b26 100755 --- a/scripts/binary_size_tracker.py +++ b/scripts/binary_size_tracker.py @@ -17,6 +17,7 @@ import shutil import subprocess import sys +import tempfile import yaml from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -39,27 +40,23 @@ def load_config(config_file: Optional[str] = None) -> Dict: { "name": "assert_sum", "path": "examples/smart_contracts/assert_sum.py", - "purpose": "spending", "description": "Simple spending validator with assertion", }, { "name": "marketplace", "path": "examples/smart_contracts/marketplace.py", - "purpose": "spending", "description": "Marketplace contract with complex data structures", }, { "name": "gift", "path": "examples/smart_contracts/gift.py", - "purpose": "spending", "description": "Gift contract with simple logic", }, { - "name": "dual_use", - "path": "examples/smart_contracts/dual_use.py", - "purpose": "spending", - "extra_flags": ["-fforce-three-params"], - "description": "Dual-use contract with multiple entry points", + "name": "wrapped_token", + "path": "examples/smart_contracts/wrapped_token.py", + "extra_flags": ["--parameters", "3"], + "description": "Dual-use contract to generate a wrapped token", }, ], "optimization_levels": ["O1", "O2", "O3"], @@ -92,7 +89,7 @@ def compile_contract( if work_dir is None: work_dir = os.getcwd() - output_dir = f"size_test_{optimization}" + output_dir = tempfile.mkdtemp(prefix=f"size_test_{optimization}_", dir=work_dir) cmd = [ "uv", @@ -114,16 +111,18 @@ def compile_contract( if exit_code != 0: print(f"Failed to compile {contract_path} with {optimization}: {stderr}") - raise Exception(f"Compilation failed: {stderr}") + shutil.rmtree(output_dir, ignore_errors=True) + return None - cbor_file = Path(work_dir) / output_dir / "script.cbor" + cbor_file = Path(output_dir) / "script.cbor" if cbor_file.exists(): size = cbor_file.stat().st_size # Clean up - shutil.rmtree(Path(work_dir) / output_dir, ignore_errors=True) + shutil.rmtree(output_dir, ignore_errors=True) return size else: print(f"CBOR file not found for {contract_path}") + shutil.rmtree(output_dir, ignore_errors=True) return None diff --git a/scripts/check_binary_sizes.py b/scripts/check_binary_sizes.py index b1351bbb..825be5b0 100755 --- a/scripts/check_binary_sizes.py +++ b/scripts/check_binary_sizes.py @@ -9,6 +9,7 @@ import argparse import os import requests +import subprocess import sys import tempfile from pathlib import Path @@ -57,6 +58,20 @@ def download_latest_baseline(repo: str = "OpShin/opshin") -> str: return None +def run_tracker_comparison(tracker_script: Path, baseline_file: str) -> int: + result = subprocess.run( + [ + sys.executable, + str(tracker_script), + "compare", + "--baseline-file", + baseline_file, + ], + check=False, + ) + return result.returncode + + def main(): parser = argparse.ArgumentParser( description="Check binary sizes against latest release baseline" @@ -95,7 +110,9 @@ def main(): try: # Run the comparison print("\nRunning binary size comparison...") - os.system(f"python {tracker_script} compare --baseline-file {baseline_file}") + return_code = run_tracker_comparison(tracker_script, baseline_file) + if return_code != 0: + sys.exit(return_code) finally: # Clean up temp file if we downloaded it diff --git a/tests/test_binary_size_tools.py b/tests/test_binary_size_tools.py new file mode 100644 index 00000000..b8002579 --- /dev/null +++ b/tests/test_binary_size_tools.py @@ -0,0 +1,59 @@ +import importlib.util +import tempfile +import unittest +from pathlib import Path +from unittest.mock import Mock, patch + +REPO_ROOT = Path(__file__).resolve().parents[1] + + +def load_module(module_name: str, path: str): + spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +CHECK_BINARY_SIZES = load_module( + "check_binary_sizes", str(REPO_ROOT / "scripts" / "check_binary_sizes.py") +) +BINARY_SIZE_TRACKER = load_module( + "binary_size_tracker", str(REPO_ROOT / "scripts" / "binary_size_tracker.py") +) + + +class BinarySizeToolTests(unittest.TestCase): + def test_compile_contract_returns_none_on_failure(self): + with patch.object( + BINARY_SIZE_TRACKER, + "run_command", + return_value=(1, "", "boom"), + ): + size = BINARY_SIZE_TRACKER.compile_contract( + "examples/smart_contracts/assert_sum.py", + "O1", + work_dir=str(REPO_ROOT), + ) + self.assertIsNone(size) + + def test_fallback_config_uses_current_contracts(self): + config = BINARY_SIZE_TRACKER.load_config("/definitely/missing/config.yaml") + contract_names = {contract["name"] for contract in config["contracts"]} + self.assertIn("wrapped_token", contract_names) + self.assertNotIn("dual_use", contract_names) + + def test_run_tracker_comparison_propagates_failure(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as baseline_file: + completed_process = Mock(returncode=7) + with patch.object( + CHECK_BINARY_SIZES.subprocess, + "run", + return_value=completed_process, + ) as run_mock: + return_code = CHECK_BINARY_SIZES.run_tracker_comparison( + Path("/tmp/binary_size_tracker.py"), + baseline_file.name, + ) + self.assertEqual(return_code, 7) + run_mock.assert_called_once() diff --git a/tests/test_contract_class.py b/tests/test_contract_class.py new file mode 100644 index 00000000..a49365ef --- /dev/null +++ b/tests/test_contract_class.py @@ -0,0 +1,513 @@ +import subprocess +import tempfile +import types +import unittest +import ast + +from opshin.contract_interface import discover_contract_module +from opshin.prelude import * +from opshin.rewrite.rewrite_contract_methods import RewriteContractMethods +from opshin.util import CompilerError +from .utils import eval_uplc_value + + +def make_tx_info(inputs, purpose, redeemer): + return TxInfo( + inputs=inputs, + reference_inputs=[], + outputs=[], + fee=0, + mint={}, + certificates=[], + withdrawals={}, + validity_range=POSIXTimeRange( + lower_bound=LowerBoundPOSIXTime(FinitePOSIXTime(0), TrueData()), + upper_bound=UpperBoundPOSIXTime(PosInfPOSIXTime(), FalseData()), + ), + signatories=[], + redeemers={purpose: redeemer}, + datums={}, + id=b"\x01" * 32, + votes={}, + proposal_procedures=[], + current_treasury_amount=NoValue(), + treasury_donation=NoValue(), + ) + + +def make_spending_context(datum, redeemer): + out_ref = TxOutRef(id=b"\x02" * 32, idx=0) + purpose = Spending(tx_out_ref=out_ref) + inputs = [ + TxInInfo( + out_ref=out_ref, + resolved=TxOut( + address=Address( + payment_credential=ScriptCredential(b"\x03" * 28), + staking_credential=NoStakingCredential(), + ), + value={b"": {b"": 0}}, + datum=SomeOutputDatum(datum), + reference_script=NoScriptHash(), + ), + ) + ] + return ScriptContext( + transaction=make_tx_info(inputs, purpose, redeemer), + redeemer=redeemer, + purpose=purpose, + ) + + +def make_spending_context_without_datum(redeemer): + out_ref = TxOutRef(id=b"\x05" * 32, idx=0) + purpose = Spending(tx_out_ref=out_ref) + inputs = [ + TxInInfo( + out_ref=out_ref, + resolved=TxOut( + address=Address( + payment_credential=ScriptCredential(b"\x06" * 28), + staking_credential=NoStakingCredential(), + ), + value={b"": {b"": 0}}, + datum=NoOutputDatum(), + reference_script=NoScriptHash(), + ), + ) + ] + return ScriptContext( + transaction=make_tx_info(inputs, purpose, redeemer), + redeemer=redeemer, + purpose=purpose, + ) + + +def make_minting_context(redeemer): + purpose = Minting(policy_id=b"\x04" * 28) + return ScriptContext( + transaction=make_tx_info([], purpose, redeemer), + redeemer=redeemer, + purpose=purpose, + ) + + +CONTRACT_SOURCE = """ +from opshin.prelude import * + +@dataclass() +class ArithmeticContract(Contract): + offset: int + + def spend_with_datum( + self, datum: int, redeemer: int, context: ScriptContext + ) -> int: + return datum + redeemer + self.offset + + def mint(self, redeemer: int, context: ScriptContext) -> int: + return redeemer * self.offset +""" + +SPEND_WITHOUT_DATUM_SOURCE = """ +from opshin.prelude import * + +@dataclass() +class NoDatumContract(Contract): + offset: int + + def spend_no_datum(self, redeemer: int, context: ScriptContext) -> int: + return redeemer + self.offset +""" + +RAW_CONTRACT_SOURCE = """ +from opshin.prelude import * + +@dataclass() +class RawContract(Contract): + offset: int + + def raw(self, script_context: ScriptContext) -> int: + redeemer: int = script_context.redeemer + return self.offset + redeemer +""" + +RAW_WITH_HELPER_ENTRYPOINT_SOURCE = """ +from opshin.prelude import * + +@dataclass() +class RawWithHelperContract(Contract): + offset: int + + def raw(self, context: ScriptContext) -> int: + redeemer: int = context.redeemer + return self.spend_no_datum(redeemer, context) + + def spend_no_datum(self, redeemer: int, context: ScriptContext) -> int: + return self.offset + redeemer +""" + +COLLIDING_NAMES_CONTRACT_SOURCE = """ +from opshin.prelude import * + +@dataclass() +class CollidingNamesContract(Contract): + context: int + redeemer: int + + def mint(self, policy_redeemer: int, script_context: ScriptContext) -> int: + return self.context + self.redeemer + policy_redeemer +""" + +INVALID_CONSTR_ID_CONTRACT_SOURCE = """ +from opshin.prelude import * + +@dataclass() +class InvalidConstrIdContract(Contract): + CONSTR_ID = 0 + + def raw(self, context: ScriptContext) -> None: + pass +""" + +INVALID_UNANNOTATED_FIELD_CONTRACT_SOURCE = """ +from opshin.prelude import * + +@dataclass() +class InvalidUnannotatedFieldContract(Contract): + offset = 0 + + def raw(self, context: ScriptContext) -> None: + pass +""" + +OUTPUT_DATUM_CONTRACT_SOURCE = """ +from opshin.prelude import * + +@dataclass() +class OutputDatumContract(Contract): + def spend_with_datum( + self, datum: OutputDatum, redeemer: int, context: ScriptContext + ) -> int: + if isinstance(datum, NoOutputDatum): + return redeemer + assert isinstance(datum, SomeOutputDatum) + unwrapped_datum: int = datum.datum + return unwrapped_datum + redeemer +""" + +SPEND_WITH_DATUM_ONLY_SOURCE = """ +from opshin.prelude import * + +@dataclass() +class SpendWithDatumOnlyContract(Contract): + def spend_with_datum( + self, datum: int, redeemer: int, context: ScriptContext + ) -> int: + return datum + redeemer +""" + +DOUBLE_SPENDING_ENTRYPOINTS_SOURCE = """ +from opshin.prelude import * + +@dataclass() +class DualSpendingContract(Contract): + def spend_no_datum(self, redeemer: int, context: ScriptContext) -> int: + return redeemer + + def spend_with_datum(self, datum: int, redeemer: int, context: ScriptContext) -> int: + return datum + redeemer +""" + +INVALID_OPTIONAL_DATUM_CONTRACT_SOURCE = """ +from typing import Union + +from opshin.prelude import * + +@dataclass() +class InvalidOptionalDatumContract(Contract): + offset: int + + def spend_with_datum( + self, datum: Union[int, NoOutputDatum], redeemer: int, context: ScriptContext + ) -> int: + return self.offset + redeemer +""" + +HELPER_METHOD_CONTRACT_SOURCE = """ +from opshin.prelude import * + +@dataclass() +class HelperMethodContract(Contract): + offset: int + + def add_offset(self, value: int) -> int: + return self.offset + value + + def mint(self, redeemer: int, context: ScriptContext) -> int: + return self.add_offset(redeemer) +""" + +LOCAL_PURPOSE_NAME_CONTRACT_SOURCE = """ +from opshin.prelude import * + +@dataclass() +class LocalPurposeNameContract(Contract): + def mint(self, redeemer: int, context: ScriptContext) -> int: + purpose = context.purpose + assert isinstance(purpose, Minting) + return redeemer +""" + +RENAMED_CONTRACT_SOURCE = """ +from opshin.prelude import * + +@dataclass() +class MyContract(Contract): + offset: int + + def spend_no_datum(self, redeemer: int, context: ScriptContext) -> int: + return self.offset + redeemer +""" + +EMPTY_CONTRACT_SOURCE = """ +from opshin.prelude import * + +@dataclass() +class EmptyContract(Contract): + offset: int +""" + + +class ContractClassTests(unittest.TestCase): + def test_contract_spend_with_datum_dispatches_through_validator(self): + ret = eval_uplc_value(CONTRACT_SOURCE, 7, make_spending_context(2, 3)) + self.assertEqual(ret, 12) + + def test_contract_spend_no_datum_dispatches_without_datum(self): + ret = eval_uplc_value( + SPEND_WITHOUT_DATUM_SOURCE, 7, make_spending_context_without_datum(3) + ) + self.assertEqual(ret, 10) + + def test_contract_spend_no_datum_rejects_attached_datum(self): + with self.assertRaises(RuntimeError): + eval_uplc_value(SPEND_WITHOUT_DATUM_SOURCE, 7, make_spending_context(2, 3)) + + def test_contract_mint_dispatches_through_validator(self): + ret = eval_uplc_value(CONTRACT_SOURCE, 5, make_minting_context(4)) + self.assertEqual(ret, 20) + + def test_contract_raw_dispatches_through_validator(self): + ret = eval_uplc_value(RAW_CONTRACT_SOURCE, 5, make_minting_context(4)) + self.assertEqual(ret, 9) + + def test_contract_raw_override_may_call_specialized_entrypoints(self): + ret = eval_uplc_value( + RAW_WITH_HELPER_ENTRYPOINT_SOURCE, + 5, + make_spending_context_without_datum(4), + ) + self.assertEqual(ret, 9) + + def test_contract_spend_supports_output_datum_annotation(self): + ret = eval_uplc_value(OUTPUT_DATUM_CONTRACT_SOURCE, make_spending_context(2, 3)) + self.assertEqual(ret, 5) + + def test_contract_helper_methods_are_lifted(self): + ret = eval_uplc_value(HELPER_METHOD_CONTRACT_SOURCE, 5, make_minting_context(4)) + self.assertEqual(ret, 9) + + def test_contract_rewrite_avoids_local_name_collisions(self): + ret = eval_uplc_value( + LOCAL_PURPOSE_NAME_CONTRACT_SOURCE, make_minting_context(4) + ) + self.assertEqual(ret, 4) + + def test_runtime_contract_discovery_builds_validator(self): + module = types.ModuleType("contract_module") + exec(CONTRACT_SOURCE, module.__dict__) + contract_info = discover_contract_module(module) + self.assertIsNotNone(contract_info) + self.assertEqual(contract_info.purpose_names, ("spending", "minting")) + self.assertEqual(contract_info.validator(3, make_minting_context(6)), 18) + + def test_runtime_contract_discovery_builds_raw_validator(self): + module = types.ModuleType("contract_module") + exec(RAW_CONTRACT_SOURCE, module.__dict__) + contract_info = discover_contract_module(module) + self.assertIsNotNone(contract_info) + self.assertEqual(contract_info.purpose_names, ("any",)) + self.assertEqual(contract_info.validator(3, make_minting_context(6)), 9) + + def test_runtime_contract_discovery_builds_inherited_raw_validator(self): + module = types.ModuleType("contract_module") + exec(EMPTY_CONTRACT_SOURCE, module.__dict__) + contract_info = discover_contract_module(module) + self.assertIsNotNone(contract_info) + with self.assertRaises(AssertionError): + contract_info.validator(3, make_minting_context(6)) + + def test_runtime_contract_discovery_builds_spend_no_datum_validator(self): + module = types.ModuleType("contract_module") + exec(SPEND_WITHOUT_DATUM_SOURCE, module.__dict__) + contract_info = discover_contract_module(module) + self.assertIsNotNone(contract_info) + self.assertEqual( + contract_info.validator(3, make_spending_context_without_datum(7)), 10 + ) + + def test_runtime_contract_discovery_rejects_attached_datum_for_spend_no_datum(self): + module = types.ModuleType("contract_module") + exec(SPEND_WITHOUT_DATUM_SOURCE, module.__dict__) + contract_info = discover_contract_module(module) + self.assertIsNotNone(contract_info) + with self.assertRaises(AssertionError): + contract_info.validator(3, make_spending_context(2, 7)) + + def test_runtime_contract_discovery_finds_contract_subclass(self): + module = types.ModuleType("contract_module") + exec(RENAMED_CONTRACT_SOURCE, module.__dict__) + contract_info = discover_contract_module(module) + self.assertIsNotNone(contract_info) + self.assertEqual( + contract_info.validator(3, make_spending_context_without_datum(7)), 10 + ) + + def test_contract_handles_parameter_name_collisions(self): + ret = eval_uplc_value( + COLLIDING_NAMES_CONTRACT_SOURCE, + 10, + 20, + make_minting_context(7), + ) + self.assertEqual(ret, 37) + + def test_contract_rejects_constr_id_definition(self): + with self.assertRaises(CompilerError) as exc: + eval_uplc_value(INVALID_CONSTR_ID_CONTRACT_SOURCE, make_minting_context(0)) + self.assertIsInstance(exc.exception.orig_err, AssertionError) + + def test_contract_rejects_unannotated_field_definition(self): + with self.assertRaises(CompilerError) as exc: + eval_uplc_value( + INVALID_UNANNOTATED_FIELD_CONTRACT_SOURCE, make_minting_context(0) + ) + self.assertIsInstance(exc.exception.orig_err, AssertionError) + + def test_main_compiles_contract_class_without_explicit_validator(self): + with tempfile.TemporaryDirectory() as tmpdir: + contract_path = f"{tmpdir}/contract.py" + with open(contract_path, "w") as fp: + fp.write(CONTRACT_SOURCE) + result = subprocess.run( + ["opshin", "compile", contract_path, '{"int": 5}'], + capture_output=True, + text=True, + cwd=tmpdir, + ) + self.assertEqual(result.returncode, 0, result.stderr) + + def test_main_compiles_contract_class_with_inherited_raw(self): + with tempfile.TemporaryDirectory() as tmpdir: + contract_path = f"{tmpdir}/contract.py" + with open(contract_path, "w") as fp: + fp.write(EMPTY_CONTRACT_SOURCE) + result = subprocess.run( + ["opshin", "compile", contract_path, '{"int": 5}'], + capture_output=True, + text=True, + cwd=tmpdir, + ) + self.assertEqual(result.returncode, 0, result.stderr) + + def test_runtime_contract_discovery_rejects_constr_id_definition(self): + module = types.ModuleType("contract_module") + exec(INVALID_CONSTR_ID_CONTRACT_SOURCE, module.__dict__) + with self.assertRaises(AssertionError): + discover_contract_module(module) + + def test_runtime_contract_discovery_rejects_unannotated_field_definition(self): + module = types.ModuleType("contract_module") + exec(INVALID_UNANNOTATED_FIELD_CONTRACT_SOURCE, module.__dict__) + with self.assertRaises(AssertionError): + discover_contract_module(module) + + def test_runtime_contract_discovery_dispatches_between_spending_entrypoints(self): + module = types.ModuleType("contract_module") + exec(DOUBLE_SPENDING_ENTRYPOINTS_SOURCE, module.__dict__) + contract_info = discover_contract_module(module) + self.assertIsNotNone(contract_info) + self.assertEqual( + contract_info.validator(make_spending_context_without_datum(7)), 7 + ) + self.assertEqual(contract_info.validator(make_spending_context(5, 7)), 12) + + def test_contract_rejects_optional_nooutputdatum_union(self): + with self.assertRaises(CompilerError) as exc: + eval_uplc_value( + INVALID_OPTIONAL_DATUM_CONTRACT_SOURCE, 5, make_spending_context(2, 3) + ) + self.assertIsInstance(exc.exception.orig_err, AssertionError) + + def test_runtime_contract_discovery_rejects_optional_nooutputdatum_union(self): + module = types.ModuleType("contract_module") + exec(INVALID_OPTIONAL_DATUM_CONTRACT_SOURCE, module.__dict__) + with self.assertRaises(AssertionError): + discover_contract_module(module) + + def test_contract_rewrite_removes_contract_class(self): + rewritten_module = RewriteContractMethods().visit(ast.parse(CONTRACT_SOURCE)) + self.assertFalse( + any( + isinstance(statement, ast.ClassDef) + and any( + isinstance(base, ast.Name) and base.id == "Contract" + for base in statement.bases + ) + for statement in rewritten_module.body + ) + ) + + def test_contract_rewrite_uses_reserved_internal_names(self): + rewritten_module = RewriteContractMethods().visit( + ast.parse(LOCAL_PURPOSE_NAME_CONTRACT_SOURCE) + ) + bound_names = { + node.id + for node in ast.walk(rewritten_module) + if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store) + } + internal_names = {name for name in bound_names if name.startswith("__contract")} + self.assertTrue(internal_names) + self.assertTrue(all("+" in name for name in internal_names)) + + def test_contract_rewrite_uses_prelude_datum_helpers(self): + rewritten_module = RewriteContractMethods().visit(ast.parse(CONTRACT_SOURCE)) + helper_calls = { + node.func.id + for node in ast.walk(rewritten_module) + if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) + } + self.assertIn("own_datum_unsafe", helper_calls) + + def test_contract_rewrite_specializes_spend_with_datum_only_contract(self): + rewritten_module = RewriteContractMethods().visit( + ast.parse(SPEND_WITH_DATUM_ONLY_SOURCE) + ) + validator = next( + statement + for statement in rewritten_module.body + if isinstance(statement, ast.FunctionDef) and statement.name == "validator" + ) + helper_calls = { + node.func.id + for node in ast.walk(validator) + if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) + } + referenced_names = { + node.id for node in ast.walk(validator) if isinstance(node, ast.Name) + } + self.assertIn("own_datum_unsafe", helper_calls) + self.assertNotIn("own_datum", helper_calls) + self.assertNotIn("Minting", referenced_names) + self.assertNotIn("Spending", referenced_names) diff --git a/tests/test_uplc_patch.py b/tests/test_uplc_patch.py index 2d8e0535..3d4ffeef 100644 --- a/tests/test_uplc_patch.py +++ b/tests/test_uplc_patch.py @@ -1,5 +1,5 @@ -import unittest import os +import unittest from frozendict import frozendict import uplc.ast as uplc_ast