diff --git a/src/braket/default_simulator/openqasm/program_context.py b/src/braket/default_simulator/openqasm/program_context.py index ed5f44d0..dc53e4c5 100644 --- a/src/braket/default_simulator/openqasm/program_context.py +++ b/src/braket/default_simulator/openqasm/program_context.py @@ -15,6 +15,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable +from concurrent.futures import ThreadPoolExecutor from copy import copy from dataclasses import fields from functools import singledispatchmethod @@ -1996,30 +1997,42 @@ def _measure_and_branch(self, target: tuple[int]) -> None: """Sample outcomes per active path and branch with proportional shot allocation. - For each qubit in target, for each active path: - 1. Ask the subclass-supplied ``_get_qubit_samples`` for - ``path.shots`` sampled bit outcomes of the qubit on this path. - 2. Split the path: one child gets shots that measured 0, the other - gets shots that measured 1. + For each qubit in target: + 1. Ask the subclass-supplied ``_get_qubit_samples`` for ``path.shots`` + sampled bit outcomes of the qubit on each active path. When the + simulator opts in via ``parallelize_paths`` these per-path + samplings are fanned out to a thread pool. + 2. For each path, split it: one child gets shots that measured 0, the + other gets shots that measured 1. 3. If one outcome has 0 shots, don't create that branch (deterministic case). 4. Remove paths with 0 shots from the active set. """ for qubit_idx in target: + saved_active = list(self._active_path_indices) + per_path_samples = self._collect_qubit_samples(saved_active, qubit_idx) new_active_indices = [] - for path_idx in list(self._active_path_indices): - self._branch_single_qubit(path_idx, qubit_idx, new_active_indices) + for path_idx, qubit_samples in zip(saved_active, per_path_samples): + self._branch_single_qubit(path_idx, qubit_idx, qubit_samples, new_active_indices) self._active_path_indices = new_active_indices + def _collect_qubit_samples(self, path_indices: list[int], qubit_idx: int) -> list[np.ndarray]: + paths = [self._paths[idx] for idx in path_indices] + if self._simulator is not None and self._simulator.parallelize_paths and len(paths) > 1: + with ThreadPoolExecutor() as pool: + return list(pool.map(lambda path: self._get_qubit_samples(path, qubit_idx), paths)) + return [self._get_qubit_samples(path, qubit_idx) for path in paths] + def _branch_single_qubit( - self, path_idx: int, qubit_idx: int, new_active_indices: list[int] + self, + path_idx: int, + qubit_idx: int, + qubit_samples: np.ndarray, + new_active_indices: list[int], ) -> None: """Branch a single path on a single qubit measurement.""" path = self._paths[path_idx] - # Defer to the concrete simulator to sample the target qubit's bit for - # each of ``path.shots`` shots; then the shot-split is just a tally. - qubit_samples = self._get_qubit_samples(path, qubit_idx) path_shots = path.shots shots_for_1 = int(np.sum(qubit_samples)) shots_for_0 = path_shots - shots_for_1 diff --git a/src/braket/default_simulator/simulator.py b/src/braket/default_simulator/simulator.py index 9f2675e7..c379177a 100644 --- a/src/braket/default_simulator/simulator.py +++ b/src/braket/default_simulator/simulator.py @@ -15,6 +15,7 @@ import warnings from abc import ABC, abstractmethod from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor from typing import Any import numpy as np @@ -163,6 +164,11 @@ def run( return self.run_program_set(circuit_ir, *args, **kwargs) return self.run_jaqcd(circuit_ir, *args, **kwargs) + @property + def parallelize_paths(self) -> bool: + """bool: Whether to run path simulations in parallel.""" + return False + def create_program_context(self) -> AbstractProgramContext: return ProgramContext(simulator=self) @@ -878,14 +884,29 @@ def _run_branched( if circuit.qubit_set: sim_qubit_count = max(sim_qubit_count, max(circuit.qubit_set) + 1) - # Aggregate samples across all active paths + paths = list(context.active_paths) + if self.parallelize_paths and len(paths) > 1: + with ThreadPoolExecutor() as pool: + per_path_samples = list( + pool.map( + _evolve_path_and_sample, + [self] * len(paths), + [path.instructions for path in paths], + [sim_qubit_count] * len(paths), + [path.shots for path in paths], + [batch_size] * len(paths), + ) + ) + else: + per_path_samples = [ + _evolve_path_and_sample( + self, path.instructions, sim_qubit_count, path.shots, batch_size + ) + for path in paths + ] all_samples = [] - for path in context.active_paths: - sim = self.initialize_simulation( - qubit_count=sim_qubit_count, shots=path.shots, batch_size=batch_size - ) - sim.evolve(path.instructions) - all_samples.extend(sim.retrieve_samples()) + for samples in per_path_samples: + all_samples.extend(samples) # Build measurements in the same format as _formatted_measurements measurements = [ @@ -995,3 +1016,17 @@ def run_jaqcd( ) return self._create_results_obj(results, circuit_ir, simulation) + + +def _evolve_path_and_sample( + simulator: BaseLocalSimulator, + instructions, + qubit_count: int, + shots: int, + batch_size: int, +): + sim = simulator.initialize_simulation( + qubit_count=qubit_count, shots=shots, batch_size=batch_size + ) + sim.evolve(instructions) + return sim.retrieve_samples() diff --git a/test/unit_tests/braket/default_simulator/test_mcm.py b/test/unit_tests/braket/default_simulator/test_mcm.py index 4435cac7..7563bd02 100644 --- a/test/unit_tests/braket/default_simulator/test_mcm.py +++ b/test/unit_tests/braket/default_simulator/test_mcm.py @@ -5030,3 +5030,44 @@ def test_flat_context_preserves_mcm_while_loop(self): "}" ) assert Interpreter(context=FlatProgramContext()).run(qasm).circuit == qasm + + +class TestParallelizePaths: + """Cover the ``parallelize_paths`` thread-pool branches in + ``BaseLocalSimulator._run_branched`` and + ``ProgramContext._collect_qubit_samples``.""" + + class _ParallelStateVectorSimulator(StateVectorSimulator): + @property + def parallelize_paths(self) -> bool: + return True + + def test_parallel_paths_produce_expected_distribution(self): + """A parallelized simulator produces a correct distribution for a + program whose second measurement fires on multiple already-branched + paths — which exercises the thread-pool branch of + ``_collect_qubit_samples``.""" + qasm = """ + OPENQASM 3.0; + qubit[3] q; + bit b0; + bit b1; + h q[0]; + b0 = measure q[0]; + if (b0 == 1) { + h q[1]; + } else { + h q[1]; + } + b1 = measure q[1]; + if (b1 == 1) { + x q[2]; + } + """ + simulator = self._ParallelStateVectorSimulator() + result = simulator.run_openqasm(OpenQASMProgram(source=qasm, inputs={}), shots=1000) + counts = Counter("".join(m) for m in result.measurements) + # Regardless of b0, q[1] is Hadamarded then measured → 50/50. + # Regardless of b1, q[2] ends in state b1. + assert set(counts.keys()).issubset({"000", "001", "011", "100", "101", "111"}) + assert sum(counts.values()) == 1000