diff --git a/src/braket/program_sets/program_set.py b/src/braket/program_sets/program_set.py index 0f4964af5..73e092a0a 100644 --- a/src/braket/program_sets/program_set.py +++ b/src/braket/program_sets/program_set.py @@ -13,7 +13,8 @@ from __future__ import annotations -from collections.abc import Mapping, Sequence +from collections.abc import Iterator, Mapping, Sequence +from dataclasses import dataclass from braket.ir.openqasm import ProgramSet as OpenQASMProgramSet @@ -97,6 +98,167 @@ def total_shots(self) -> int: raise ValueError("No per-executable shots defined") return self._shots_per_executable * self.total_executables + def enumerate_executables(self) -> Iterator[tuple[int, int, int]]: + """Yield ``(binding_index, parameter_set_index, observable_index)`` tuples in order, + one per executable. + + The iteration order is: iterate over ``self.entries``; within each entry, + iterate over parameter set indices; within each parameter set index, + iterate over observable indices. The total number of yields is ``self.total_executables``. + + For ``Circuit``s and ``CircuitBinding``s with no input sets, ``parameter_set_index`` is 0. + For entries with no observables, ``observable_index`` is 0. For ``CircuitBinding``s with a + ``Sum`` Hamiltonian, ``observable_index`` ranges over the summands. + + This ordering is used by ``split`` to build its index map and by + ``ProgramSetQuantumTaskResult.merge`` to merge results back into the original shape. + + Yields: + tuple[int, int, int]: ``(binding_index, parameter_set_index, observable_index)``. + """ + for binding_idx, prog in enumerate(self._programs): + if isinstance(prog, Circuit): + yield binding_idx, 0, 0 + continue + num_obs = len(prog.observables) if prog.observables is not None else 1 + for ps_idx in range(len(prog.input_sets) if prog.input_sets is not None else 1): + for obs_idx in range(num_obs): + yield binding_idx, ps_idx, obs_idx + + def split(self, max_executables: int) -> tuple[list[ProgramSet], list[list[int]]]: + """Split this program set into program sets of at most ``max_executables`` executables, + alongside a map that records the position in the original program set of each executable + in each of the generated program sets. + + When a single parameter set index of a ``CircuitBinding`` would by itself exceed + ``max_executables`` due to its observable list or ``Sum`` Hamiltonian being larger than + the budget, the observable list is split into chunks of at most ``max_executables`` entries + (``Sum`` summands are sliced with coefficients preserved). Observable splitting is only + performed when necessary; otherwise the full observable list or ``Sum`` is kept intact. + + Args: + max_executables (int): The maximum number of executables per program + set. Must be positive. + + Returns: + tuple[list[ProgramSet], list[list[int]]]: ``(program_sets, index_map)``. + ``index_map[k][j]`` is the index of the executable that the j-th executable of + ``program_sets[k]`` represents. + If this program set already fits within ``max_executables``, the returned + program-set list is ``[self]`` and the index_map is ``[[0, 1, ..., + total_executables - 1]]``. + + Raises: + ValueError: If ``max_executables`` is not positive. + + Examples: + >>> ps = ProgramSet([ + ... CircuitBinding(c1, inputs1, obs1), # 100 param sets, 4 observables + ... CircuitBinding(c2, inputs2, obs2), # 50 param sets, 2 observables + ... ]) + >>> subs, index_map = ps.split(120) + >>> [s.total_executables for s in subs] + [120, 120, 120, 120, 20] + >>> sum(len(m) for m in index_map) == ps.total_executables + True + """ + if max_executables <= 0: + raise ValueError(f"max_executables must be positive, got {max_executables}") + + if self.total_executables <= max_executables: + return [self], [list(range(self.total_executables))] + + program_sets = [] + index_map = [] + current = [] + current_size = 0 + for block in self._executable_blocks(max_executables): + if current and current_size + block.size > max_executables: + sub, sub_map = self._build_program_set(current) + program_sets.append(sub) + index_map.append(sub_map) + current = [] + current_size = 0 + current.append(block) + current_size += block.size + sub, sub_map = self._build_program_set(current) + program_sets.append(sub) + index_map.append(sub_map) + + return program_sets, index_map + + def _executable_blocks(self, max_executables: int) -> list[_ExecutableBlock]: + blocks = [] + orig_idx = 0 + for prog_idx, prog in enumerate(self._programs): + if isinstance(prog, Circuit): + blocks.append( + _ExecutableBlock( + prog_idx=prog_idx, + param_set_index=None, + obs_slice=None, + size=1, + original_indices=[orig_idx], + ) + ) + orig_idx += 1 + continue + + num_ps = len(prog.input_sets) if prog.input_sets is not None else 1 + obs_windows = _observable_windows( + len(prog.observables) if prog.observables is not None else 1, max_executables + ) + split_observables = len(obs_windows) > 1 + for ps_idx in range(num_ps) if prog.input_sets is not None else [None]: + for start, stop in obs_windows: + size = stop - start + blocks.append( + _ExecutableBlock( + prog_idx=prog_idx, + param_set_index=ps_idx, + obs_slice=slice(start, stop) if split_observables else None, + size=size, + original_indices=list(range(orig_idx, orig_idx + size)), + ) + ) + orig_idx += size + return blocks + + def _build_program_set(self, blocks: list[_ExecutableBlock]) -> tuple[ProgramSet, list[int]]: + entries = [] + sub_map = [] + i = 0 + while i < len(blocks): + head = blocks[i] + prog = self._programs[head.prog_idx] + if head.param_set_index is None: + entries.append(_apply_obs_slice(prog, head.obs_slice)) + sub_map.extend(head.original_indices) + i += 1 + continue + + j = i + while ( + j + 1 < len(blocks) + and blocks[j + 1].prog_idx == head.prog_idx + and blocks[j + 1].obs_slice == blocks[j].obs_slice + and blocks[j + 1].param_set_index == blocks[j].param_set_index + 1 + ): + j += 1 + start = head.param_set_index + stop = blocks[j].param_set_index + 1 + entries.append( + CircuitBinding( + prog.circuit, + input_sets=prog.input_sets.as_list()[start:stop], + observables=_slice_observables(prog.observables, head.obs_slice), + ) + ) + for k in range(i, j + 1): + sub_map.extend(blocks[k].original_indices) + i = j + 1 + return ProgramSet(entries, self._shots_per_executable), sub_map + @staticmethod def zip( circuits: Sequence[Circuit] | CircuitBinding, @@ -206,6 +368,64 @@ def __repr__(self): ) +@dataclass +class _ExecutableBlock: + """Multi-index range for an equivalence class of executables sharing the same combination of + ``(circuit, observable list/Sum Hamiltonian, single parameter assignment)``. + + Attributes: + prog_idx: Index of the originating program in ``ProgramSet.entries``. + param_set_index: Index into the originating ``CircuitBinding``'s ``input_sets``, or ``None`` + for ``Circuit`` entries and ``CircuitBinding``s with no input sets. + obs_slice: Slice into the originating observable list or ``Sum`` summands when observables + were split to fit the budget; ``None`` means the full original observable list + (or no observables). + size: Number of executables this block represents (== ``len(original_indices)``). + original_indices: The indices of this block's executables + in the order of the original program set. + """ + + prog_idx: int + param_set_index: int | None + obs_slice: slice | None + size: int + original_indices: list[int] + + +def _observable_windows(num_observables: int, max_executables: int) -> list[tuple[int, int]]: + if num_observables <= max_executables: + return [(0, num_observables)] + windows = [] + start = 0 + while start < num_observables: + stop = min(start + max_executables, num_observables) + windows.append((start, stop)) + start = stop + return windows + + +def _slice_observables( + observables: Sum | Sequence[Observable] | None, obs_slice: slice | None +) -> Sum | Sequence[Observable] | None: + if obs_slice is None or observables is None: + return observables + if isinstance(observables, Sum): + return Sum(list(observables.summands)[obs_slice]) + return list(observables)[obs_slice] + + +def _apply_obs_slice( + prog: CircuitBinding | Circuit, obs_slice: slice | None +) -> CircuitBinding | Circuit: + if obs_slice is None or isinstance(prog, Circuit) or prog.observables is None: + return prog + return CircuitBinding( + prog.circuit, + input_sets=prog.input_sets, + observables=_slice_observables(prog.observables, obs_slice), + ) + + def _zip_circuit_bindings( circuit_binding: CircuitBinding, input_sets: Sequence[Mapping[str, float]] | None, diff --git a/test/unit_tests/braket/program_sets/test_program_set.py b/test/unit_tests/braket/program_sets/test_program_set.py index 2706007d4..4a5b1792c 100644 --- a/test/unit_tests/braket/program_sets/test_program_set.py +++ b/test/unit_tests/braket/program_sets/test_program_set.py @@ -534,3 +534,263 @@ def test_inequality(circuit_rx_parametrized): program_set = ProgramSet([binding, binding]) assert program_set != ProgramSet([binding, circuit_rx_parametrized]) assert program_set != circuit_rx_parametrized + + +def test_split_already_fits(circuit_rx_parametrized): + binding = CircuitBinding(circuit_rx_parametrized, input_sets=[{"theta": 1.23}, {"theta": 3.21}]) + program_set = ProgramSet(binding) + subs, mapping = program_set.split(10) + assert subs == [program_set] + assert subs[0] is program_set + assert mapping == [[0, 1]] + + +def test_split_exact_fit(circuit_rx_parametrized): + binding = CircuitBinding(circuit_rx_parametrized, input_sets=[{"theta": 1.23}, {"theta": 3.21}]) + program_set = ProgramSet(binding) + subs, mapping = program_set.split(2) + assert subs == [program_set] + assert subs[0] is program_set + assert mapping == [[0, 1]] + + +def test_split_plain_circuits(): + circs = [ghz(1), ghz(2), ghz(3), ghz(1), ghz(2)] + program_set = ProgramSet(circs, shots_per_executable=10) + subs, mapping = program_set.split(2) + assert [s.total_executables for s in subs] == [2, 2, 1] + assert subs[0].entries == circs[0:2] + assert subs[1].entries == circs[2:4] + assert subs[2].entries == circs[4:5] + assert mapping == [[0, 1], [2, 3], [4]] + + +def test_split_single_binding_packed(circuit_rx_parametrized): + inputs = {"theta": [float(i) for i in range(10)]} + binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs) + program_set = ProgramSet(binding) + subs, mapping = program_set.split(3) + assert [s.total_executables for s in subs] == [3, 3, 3, 1] + # Each sub-program-set is a single coalesced binding over a contiguous slice. + for s in subs: + assert len(s) == 1 + assert s.entries[0].circuit == circuit_rx_parametrized + assert s.entries[0].observables is None + thetas = [] + for s in subs: + thetas.extend(s.entries[0].input_sets.as_dict()["theta"]) + assert thetas == inputs["theta"] + assert mapping == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + + +def test_split_with_observables(circuit_rx_parametrized): + # 5 parameter-set indices, 4 observables => 5 classes of size 4. + inputs = {"theta": [float(i) for i in range(5)]} + observables = [X(0), Y(0), Z(0), X(0) @ Y(1)] + binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs, observables=observables) + program_set = ProgramSet(binding) + subs, mapping = program_set.split(8) + assert [s.total_executables for s in subs] == [8, 8, 4] + # Observables propagate unchanged (never split across sub-program-sets). + for s in subs: + assert s.entries[0].observables == observables + assert sum(mapping, []) == list(range(20)) + + +def test_split_with_sum_hamiltonian(circuit_rx_parametrized): + # Sum with 3 summands => class size = 3 per parameter-set index. + inputs = {"theta": [float(i) for i in range(4)]} + hamiltonian = 1.0 * X(0) + 2.0 * Y(0) + 3.0 * Z(0) + binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs, observables=hamiltonian) + program_set = ProgramSet(binding) + subs, mapping = program_set.split(6) + assert [s.total_executables for s in subs] == [6, 6] + # Sum preserved intact (no observable-splitting needed at max=6). + for s in subs: + assert s.entries[0].observables is hamiltonian + assert sum(mapping, []) == list(range(12)) + + +def test_split_worked_example(circuit_rx_parametrized): + # Two bindings: c1 with 100 param sets × 4 obs, c2 with 50 param sets × 2 obs. + c1 = circuit_rx_parametrized + c2 = Circuit().rx(0, FreeParameter("phi")) + obs1 = [X(0), Y(0), Z(0), X(0) @ Y(1)] + obs2 = [X(0), Z(0)] + binding1 = CircuitBinding(c1, {"theta": [float(i) for i in range(100)]}, obs1) + binding2 = CircuitBinding(c2, {"phi": [float(i) for i in range(50)]}, obs2) + program_set = ProgramSet([binding1, binding2]) + + subs, mapping = program_set.split(120) + # Greedy packing fills each bucket up to the budget before flushing. + assert [s.total_executables for s in subs] == [120, 120, 120, 120, 20] + assert sum(s.total_executables for s in subs) == program_set.total_executables + # First three buckets are pure c1 (30 × 4 each). + for i in range(3): + assert len(subs[i]) == 1 + assert subs[i].entries[0].circuit == c1 + assert len(subs[i].entries[0].input_sets) == 30 + # Bucket 3 straddles both bindings (10 × 4 + 40 × 2 = 120); coalesced per binding. + assert len(subs[3]) == 2 + assert subs[3].entries[0].circuit == c1 + assert len(subs[3].entries[0].input_sets) == 10 + assert subs[3].entries[1].circuit == c2 + assert len(subs[3].entries[1].input_sets) == 40 + # Last bucket is pure c2 remainder (10 × 2 = 20). + assert len(subs[4]) == 1 + assert subs[4].entries[0].circuit == c2 + assert len(subs[4].entries[0].input_sets) == 10 + # Mapping covers every original executable exactly once, in order. + assert sum(mapping, []) == list(range(500)) + + +def test_split_preserves_shots(circuit_rx_parametrized): + inputs = {"theta": [float(i) for i in range(5)]} + binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs) + program_set = ProgramSet(binding, shots_per_executable=100) + subs, _ = program_set.split(2) + assert all(s.shots_per_executable == 100 for s in subs) + assert sum(s.total_shots for s in subs) == program_set.total_shots + + +def test_split_coalesces_adjacent_same_binding(circuit_rx_parametrized): + # 6 parameter-set indices, class size 1, max_executables=4 => buckets of 4, 2. + # Each bucket should contain one coalesced multi-parameter-set binding, + # not four (resp. two) separate single-parameter-set bindings. + inputs = {"theta": [float(i) for i in range(6)]} + binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs) + program_set = ProgramSet(binding) + subs, _ = program_set.split(4) + assert [len(s) for s in subs] == [1, 1] + assert len(subs[0].entries[0].input_sets) == 4 + assert len(subs[1].entries[0].input_sets) == 2 + + +def test_split_binding_without_input_sets(circuit_rx_parametrized): + # A binding with only observables is a single class of size len(observables). + c1 = circuit_rx_parametrized + c2 = Circuit().rx(0, FreeParameter("phi")) + binding_a = CircuitBinding(c1, observables=[X(0), Y(0)]) # size 2 + binding_b = CircuitBinding(c2, observables=[X(0), Y(0), Z(0)]) # size 3 + program_set = ProgramSet([binding_a, binding_b]) + subs, mapping = program_set.split(3) + assert [s.total_executables for s in subs] == [2, 3] + assert subs[0].entries == [binding_a] + assert subs[1].entries == [binding_b] + assert mapping == [[0, 1], [2, 3, 4]] + + +def test_split_non_positive_raises(circuit_rx_parametrized): + binding = CircuitBinding(circuit_rx_parametrized, input_sets=[{"theta": 1.23}]) + program_set = ProgramSet(binding) + with pytest.raises(ValueError, match="must be positive"): + program_set.split(0) + with pytest.raises(ValueError, match="must be positive"): + program_set.split(-3) + + +def test_split_oversize_list_observables_are_chunked(circuit_rx_parametrized): + # A single class of 10 observables with max_executables=3 becomes 4 sub-program-sets + # of sizes 3, 3, 3, 1, each with a sliced observable list. + observables = [X(0), Y(0), Z(0), X(0), Y(0), Z(0), X(0), Y(0), Z(0), X(0)] + binding = CircuitBinding(circuit_rx_parametrized, observables=observables) + program_set = ProgramSet(binding) + subs, mapping = program_set.split(3) + assert [s.total_executables for s in subs] == [3, 3, 3, 1] + assert mapping == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + slices = [list(s.entries[0].observables) for s in subs] + assert slices == [observables[0:3], observables[3:6], observables[6:9], observables[9:10]] + + +def test_split_oversize_sum_hamiltonian_is_chunked(circuit_rx_parametrized): + # Sum with 7 summands, max_executables=3 → sub-Sums of sizes 3, 3, 1 with + # coefficients preserved on each summand. + ham = 1.0 * X(0) + 2.0 * Y(0) + 3.0 * Z(0) + 4.0 * X(0) + 5.0 * Y(0) + 6.0 * Z(0) + 7.0 * X(0) + binding = CircuitBinding(circuit_rx_parametrized, observables=ham) + program_set = ProgramSet(binding) + subs, mapping = program_set.split(3) + assert [s.total_executables for s in subs] == [3, 3, 1] + assert mapping == [[0, 1, 2], [3, 4, 5], [6]] + # Each sub-observable is a Sum whose summands come from the original in order. + expected_summands = list(ham.summands) + got_summands: list = [] + for s in subs: + sub_obs = s.entries[0].observables + assert isinstance(sub_obs, type(ham)) + got_summands.extend(sub_obs.summands) + assert got_summands == expected_summands + + +def test_split_oversize_observables_with_multiple_param_sets(circuit_rx_parametrized): + # 2 parameter sets x 5 observables, max_executables=3 ⇒ each parameter-set index + # splits into two observable windows ((0,3) size 3 and (3,5) size 2). The packer + # can't coalesce across parameter sets because they're interleaved by window, so we + # end up with 4 sub-program-sets. + inputs = {"theta": [1.0, 2.0]} + observables = [X(0), Y(0), Z(0), X(0), Y(0)] + binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs, observables=observables) + program_set = ProgramSet(binding) + subs, mapping = program_set.split(3) + assert [s.total_executables for s in subs] == [3, 2, 3, 2] + # Mapping follows canonical ordering: ps=0,obs=0..4 = indices 0..4; ps=1,obs=0..4 = 5..9. + assert mapping == [[0, 1, 2], [3, 4], [5, 6, 7], [8, 9]] + assert sum(mapping, []) == list(range(program_set.total_executables)) + + +def test_split_sub_program_sets_are_serializable(circuit_rx_parametrized): + inputs = {"theta": [float(i) for i in range(10)]} + observables = [X(0), Y(0)] + binding = CircuitBinding(circuit_rx_parametrized, input_sets=inputs, observables=observables) + program_set = ProgramSet(binding) + subs, _ = program_set.split(6) + # Each sub-program set is a fully formed ProgramSet: to_ir() works and returns a + # single-program IR (one coalesced CircuitBinding per sub-program set here). + for s in subs: + ir = s.to_ir() + assert len(ir.programs) == len(s) + + +def test_enumerate_executables_plain_circuits(): + ps = ProgramSet([ghz(1), ghz(2), ghz(3)]) + assert list(ps.enumerate_executables()) == [(0, 0, 0), (1, 0, 0), (2, 0, 0)] + + +def test_enumerate_executables_binding_with_input_sets_only(circuit_rx_parametrized): + binding = CircuitBinding(circuit_rx_parametrized, input_sets={"theta": [0.1, 0.2, 0.3]}) + ps = ProgramSet(binding) + assert list(ps.enumerate_executables()) == [(0, 0, 0), (0, 1, 0), (0, 2, 0)] + + +def test_enumerate_executables_binding_with_observables_only(circuit_rx_parametrized): + binding = CircuitBinding(circuit_rx_parametrized, observables=[X(0), Y(0), Z(0)]) + ps = ProgramSet(binding) + assert list(ps.enumerate_executables()) == [(0, 0, 0), (0, 0, 1), (0, 0, 2)] + + +def test_enumerate_executables_mixed(): + # circuit, binding with 2 ps x 3 obs, binding with 2 ps no obs, binding with 4 obs no ps. + c0 = ghz(1) + c1 = Circuit().rx(0, FreeParameter("t")).cnot(0, 1) + c2 = Circuit().rx(0, FreeParameter("p")) + c3 = Circuit().h(0) + b1 = CircuitBinding(c1, {"t": [0.1, 0.2]}, [X(0), Y(0), Z(0)]) + b2 = CircuitBinding(c2, {"p": [0.3, 0.4]}) + b3 = CircuitBinding(c3, observables=[X(0), Y(0), Z(0), X(0) @ Y(1)]) + ps = ProgramSet([c0, b1, b2, b3]) + expected = [ + (0, 0, 0), + (1, 0, 0), + (1, 0, 1), + (1, 0, 2), + (1, 1, 0), + (1, 1, 1), + (1, 1, 2), + (2, 0, 0), + (2, 1, 0), + (3, 0, 0), + (3, 0, 1), + (3, 0, 2), + (3, 0, 3), + ] + assert list(ps.enumerate_executables()) == expected + assert len(expected) == ps.total_executables