Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
14 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 125 additions & 4 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import unittest
from pathlib import Path

Expand Down Expand Up @@ -1532,14 +1533,12 @@ def test_unpack_plutus_data(self):
)

def test_parse(self):
p = parse(
"""
p = parse("""
(program
1.0.0
[ [ [ (force (delay [(lam i_0 (con integer 2)) (con bytestring #02)])) (builtin addInteger) ] (error) ] (con pair<list<integer>,unit> [[],()]) ]
)
"""
)
""")
print(dumps(p))

@parameterized.expand(
Expand Down Expand Up @@ -2009,3 +2008,125 @@ def test_invalid_list(self):
with self.assertRaises(ValueError) as context:
data_from_json(param)
self.assertIn("expected a list", str(context.exception).lower())

def test_haskell_string_escapes(self):
"""Test Haskell decimal (\\DDD) and octal (\\oOOO) string escapes.

Conformance test string-04:
Input: (con string "\\t\\"\\83\\x75\\x63\\o143e\\x73s\\o041\\o042\\n")
Expected decoded value: \\t"Success!"\\n
"""
program = (
r'(program 1.0.0 (con string "\t\"\83\x75\x63\o143e\x73s\o041\o042\n"))'
)
p = parse(program)
# The string value should be: tab + "Success!" + newline
self.assertEqual(p.term.value, '\t"Success!"\n')

def test_haskell_decimal_escape(self):
"""Test standalone Haskell decimal escape."""
program = r'(program 1.0.0 (con string "\65"))'
p = parse(program)
self.assertEqual(p.term.value, "A") # chr(65) == 'A'

def test_haskell_octal_escape(self):
"""Test standalone Haskell octal escape."""
program = r'(program 1.0.0 (con string "\o101"))'
p = parse(program)
self.assertEqual(p.term.value, "A") # chr(0o101) == 'A'

def test_array_type_keyword(self):
"""Test that 'array' is accepted as an alias for 'list'.

Haskell ref: PlutusCore.Default.Universe (defaultUni) lists
'array' as an alternative name for the list type constructor.
"""
program = "(program 1.0.0 (con (list integer) [1, 2, 3]))"
p_list = parse(program)
program_array = "(program 1.0.0 (con (array integer) [1, 2, 3]))"
p_array = parse(program_array)
self.assertEqual(p_list.term.values, p_array.term.values)

def test_array_type_keyword_aiken_dialect(self):
"""Test 'array' works in Aiken dialect (caret notation)."""
program = "(program 1.0.0 (con array<integer> [1, 2]))"
p = parse(program)
self.assertEqual(len(p.term.values), 2)

def test_strict_mode_trailing_data(self):
"""Test that strict=True rejects programs with trailing bytes.

PlutusV3 (Conway-era) requires strict deserialization — no
trailing bytes are allowed after the flat-encoded program.
"""
from uplc.tools import flatten, unflatten

program = parse("(program 1.0.0 (con integer 1))")
flat_bytes = flatten(program)
# Normal mode should accept
unflatten(flat_bytes, strict=False)
# Strict mode should also accept clean encoding
unflatten(flat_bytes, strict=True)

def test_case_on_bool(self):
"""Test case expression scrutinizing a Bool value.

CEK machine must convert Bool to constr-like: False=tag 0, True=tag 1.
"""
program = parse(
"(program 1.1.0 (case (con bool True) (con integer 10) (con integer 20)))"
)
result = eval(program)
self.assertEqual(result.result.value, 20)

def test_case_on_unit(self):
"""Test case expression scrutinizing a Unit value.

CEK machine must convert Unit to constr-like: tag 0, no fields.
"""
program = parse("(program 1.1.0 (case (con unit ()) (con integer 42)))")
result = eval(program)
self.assertEqual(result.result.value, 42)

def test_case_on_list_nil(self):
"""Test case on empty list: nil = tag 0."""
program = parse(
"(program 1.1.0 (case (con (list integer) []) (con integer 1) (con integer 2)))"
)
result = eval(program)
self.assertEqual(result.result.value, 1)

def test_case_on_list_cons(self):
"""Test case on non-empty list: cons = tag 1 with fields [head, tail]."""
program = parse(
"(program 1.1.0 (case (con (list integer) [5, 6]) (con integer 1) (lam h (lam t (con integer 2)))))"
)
result = eval(program)
self.assertEqual(result.result.value, 2)

def test_zero_cost_builtin_raises(self):
"""Unknown builtins should raise RuntimeError, not return Budget(0,0)."""
from uplc.machine import budget_cost_of_op_on_model
from uplc.cost_model import BuiltinCostModel

empty_model = BuiltinCostModel(cpu={}, memory={})
with self.assertRaises(RuntimeError):
budget_cost_of_op_on_model(empty_model, "FakeBuiltin")

def test_cost_model_file_extension(self):
"""Cost model loader should match '.json' suffix, not 'json'."""
import uplc.cost_model as cm

# The fix: file.suffix returns ".json", not "json"
# We verify the code uses ".json" by checking the source
import inspect

source = inspect.getsource(cm.load_network_config)
self.assertIn('.suffix == ".json"', source)

def test_schnorr_error_message(self):
"""Schnorr verification should report 'Schnorr', not 'ECDSA'."""
import uplc.ast as uplc_ast

source = inspect.getsource(uplc_ast.verify_schnorr_secp256k1)
self.assertNotIn("ECDSA", source)
10 changes: 6 additions & 4 deletions uplc/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,8 @@ def verify_ed25519(pk: BuiltinByteString, m: BuiltinByteString, s: BuiltinByteSt
def verify_ecdsa_secp256k1(
pk: BuiltinByteString, m: BuiltinByteString, s: BuiltinByteString
):
# TODO length checks
# Let the underlying crypto library validate sizes — the Haskell spec
# uses varying encodings (compressed/uncompressed pubkeys, DER/compact sigs)
if pysecp256k1 is None:
_LOGGER.error("libsecp256k1 is not installed. ECDSA verification will not work")
raise RuntimeError("ECDSA not supported")
Expand All @@ -1005,10 +1006,11 @@ def verify_ecdsa_secp256k1(
def verify_schnorr_secp256k1(
pk: BuiltinByteString, m: BuiltinByteString, s: BuiltinByteString
):
# TODO length checks
if pysecp256k1 is None:
_LOGGER.error("libsecp256k1 is not installed. ECDSA verification will not work")
raise RuntimeError("ECDSA not supported")
_LOGGER.error(
"libsecp256k1 is not installed. Schnorr verification will not work"
)
raise RuntimeError("Schnorr not supported")
if schnorrsig is None:
_LOGGER.error(
"libsecp256k1 is installed without schnorr support. Schnorr verification will not work"
Expand Down
2 changes: 1 addition & 1 deletion uplc/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ def load_network_config(config_date: datetime.date):
network_config_dir = NETWORK_CONFIG_DIR.joinpath(latest_dir_name)
file = None
for file in network_config_dir.iterdir():
if file.suffix == "json":
if file.suffix == ".json":
break
if file is None:
raise ValueError("Latest network config could not be loaded")
Expand Down
9 changes: 9 additions & 0 deletions uplc/flat_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,15 @@ def read_case(self) -> Case:
def finalize(self):
self.move_to_byte_boundary(True)

def has_trailing_data(self) -> bool:
"""Check if there are non-padding bits after the current position.

After read_program() + finalize(), any remaining bits beyond the
byte boundary are trailing data. Returns True if trailing data
exists (i.e., the reader hasn't consumed everything).
"""
return self._pos < len(self._bits)

def read_bits(self, num: int) -> str:
bits = self._bits[self._pos : self._pos + num]
self._pos += num
Expand Down
32 changes: 28 additions & 4 deletions uplc/machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ def budget_cost_of_op_on_model(
values=[],
):
if op not in model.cpu or op not in model.memory:
return Budget(0, 0)
raise RuntimeError(
f"No cost model entry for builtin {op!r}. "
f"This builtin may not be available in the selected Plutus version."
)
return Budget(
cpu=model.cpu[op].cost(*args, values=values),
memory=model.memory[op].cost(*args, values=values),
Expand Down Expand Up @@ -244,14 +247,35 @@ def return_compute(self, context, value):
Constr(context.tag, resolved_fields),
)
elif isinstance(context, FrameCases):
if not isinstance(value, Constr):
# Convert constant types to constr-like (tag, fields) for case scrutiny
if isinstance(value, Constr):
tag, fields = value.tag, value.fields
elif isinstance(value, BuiltinBool):
# False=0, True=1, no fields
tag, fields = (1 if value.value else 0), []
elif isinstance(value, BuiltinUnit):
# unit -> tag 0, no fields
tag, fields = 0, []
elif isinstance(value, BuiltinPair):
# pair (l, r) -> tag 0, fields [l, r]
tag, fields = 0, [value.l_value, value.r_value]
elif isinstance(value, BuiltinList):
# [] -> tag 0 (nil), [x, ...xs] -> tag 1 with fields [x, xs]
if len(value.values) == 0:
tag, fields = 0, []
else:
tag, fields = 1, [
value.values[0],
BuiltinList(list(value.values[1:]), value.sample_value),
]
else:
raise RuntimeError("Scrutinized non-constr in case")
try:
branch = context.branches[value.tag]
branch = context.branches[tag]
except IndexError as e:
raise RuntimeError("No branch provided for constr tag") from None
return Compute(
transfer_arg_stack(value.fields, context.ctx),
transfer_arg_stack(fields, context.ctx),
context.env,
branch,
)
Expand Down
63 changes: 59 additions & 4 deletions uplc/parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import ast as python_ast
import re

from rply import ParserGenerator
Expand All @@ -15,6 +14,58 @@
Case,
)

_HASKELL_ESCAPE_RE = re.compile(
r"\\(?:"
r"o([0-7]+)" # group 1: \oOOO octal
r"|x([0-9a-fA-F]{2})" # group 2: \xHH hex (exactly 2 digits)
r"|u([0-9a-fA-F]{4})" # group 3: \uHHHH unicode (4 hex digits)
r"|U([0-9a-fA-F]{8})" # group 4: \UHHHHHHHH unicode (8 hex digits)
r"|(\d+)" # group 5: \DDD decimal
r"|([\\\"\'abfnrtv&])" # group 6: single-char escapes
r")"
)

# Standard single-character Haskell/Python escapes
_SIMPLE_ESCAPES = {
"\\": "\\",
'"': '"',
"'": "'",
"a": "\a",
"b": "\b",
"f": "\f",
"n": "\n",
"r": "\r",
"t": "\t",
"v": "\v",
"&": "", # Haskell's \& is a null-width escape (empty string)
}


def _decode_haskell_string(s: str) -> str:
"""Decode a Haskell string literal (with surrounding quotes removed).

Handles all escape sequences: \\n, \\t, \\\\, \\", \\xHH (hex),
\\DDD (decimal), and \\oOOO (octal).
"""

def replace_escape(m):
if m.group(1) is not None: # \oOOO octal
return chr(int(m.group(1), 8))
if m.group(2) is not None: # \xHH hex
return chr(int(m.group(2), 16))
if m.group(3) is not None: # \uHHHH unicode
return chr(int(m.group(3), 16))
if m.group(4) is not None: # \UHHHHHHHH unicode
return chr(int(m.group(4), 16))
if m.group(5) is not None: # \DDD decimal
return chr(int(m.group(5), 10))
if m.group(6) is not None: # single-char escape
return _SIMPLE_ESCAPES[m.group(6)]
return m.group(0) # fallback: leave as-is

return _HASKELL_ESCAPE_RE.sub(replace_escape, s)


PLUTUS_V2 = (1, 0, 0)
PLUTUS_V3 = (1, 1, 0)
PLUTUS_VERSIONS = {PLUTUS_V2, PLUTUS_V3}
Expand Down Expand Up @@ -153,21 +204,23 @@ def constanttype(p):
return ast.BuiltinUnit()
if name == "data":
return ast.PlutusData()
if name == "array":
return ast.BuiltinList([], ast.PlutusData()) # default element type
raise SyntaxError(f"Unknown builtin type {name}")

@self.pg.production("constanttype : name CARET_OPEN constanttype CARET_CLOSE")
def constanttype(p):
# the Aiken dialect
name = p[0].value
if name == "list":
if name == "list" or name == "array":
return ast.BuiltinList([], p[2])
raise SyntaxError(f"Unknown builtin type {name}")

@self.pg.production("constanttype : PAREN_OPEN name constanttype PAREN_CLOSE")
def constanttype(p):
# the Plutus dialect
name = p[1].value
if name == "list":
if name == "list" or name == "array":
return ast.BuiltinList([], p[2])
raise SyntaxError(f"Unknown builtin type {name}")

Expand Down Expand Up @@ -222,7 +275,9 @@ def expression(p):
@self.pg.production("builtinvalue : TEXT")
def expression(p):
s = p[0].value
return python_ast.literal_eval(s)
# Strip surrounding quotes and decode all escape sequences
# including Haskell-specific \DDD (decimal) and \oOOO (octal)
return _decode_haskell_string(s[1:-1])

@self.pg.production("builtinvalue : PAREN_OPEN PAREN_CLOSE")
def expression(p):
Expand Down
18 changes: 16 additions & 2 deletions uplc/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,26 @@ def flatten(x: Program) -> bytes:
return cbor2.dumps(x_flattened)


def unflatten(x_cbor: bytes) -> Program:
"""Returns the program from a singly-CBOR wrapped flat encoding"""
def unflatten(x_cbor: bytes, *, strict: bool = False) -> Program:
"""Returns the program from a singly-CBOR wrapped flat encoding.

Args:
x_cbor: CBOR-wrapped flat-encoded UPLC program bytes.
strict: If True, reject programs with trailing bytes after the
flat encoding. PlutusV3 requires strict mode (Conway-era
tightening). PlutusV1/V2 are lenient (trailing bytes ignored).
"""
x = cbor2.loads(x_cbor)
x_bin = "".join(f"{i:08b}" for i in x)
reader = UplcDeserializer(x_bin)
x_debrujin = reader.read_program()
reader.finalize()
if strict and reader.has_trailing_data():
raise ValueError(
f"Trailing data after flat-encoded program "
f"({len(reader._bits) - reader._pos} bits remaining). "
f"PlutusV3 requires strict deserialization with no trailing bytes."
)
x_uplc = UnDeBrujinVariableTransformer().visit(x_debrujin)
return x_uplc

Expand Down
Loading