Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions src/braket/default_simulator/openqasm/program_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from abc import ABC, abstractmethod
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from copy import copy
from dataclasses import fields
from functools import singledispatchmethod
Expand Down Expand Up @@ -1996,30 +1997,42 @@ def _measure_and_branch(self, target: tuple[int]) -> None:
"""Sample outcomes per active path and branch with proportional shot
allocation.

For each qubit in target, for each active path:
1. Ask the subclass-supplied ``_get_qubit_samples`` for
``path.shots`` sampled bit outcomes of the qubit on this path.
2. Split the path: one child gets shots that measured 0, the other
gets shots that measured 1.
For each qubit in target:
1. Ask the subclass-supplied ``_get_qubit_samples`` for ``path.shots``
sampled bit outcomes of the qubit on each active path. When the
simulator opts in via ``parallelize_paths`` these per-path
samplings are fanned out to a thread pool.
2. For each path, split it: one child gets shots that measured 0, the
other gets shots that measured 1.
3. If one outcome has 0 shots, don't create that branch (deterministic
case).
4. Remove paths with 0 shots from the active set.
"""
for qubit_idx in target:
saved_active = list(self._active_path_indices)
per_path_samples = self._collect_qubit_samples(saved_active, qubit_idx)
new_active_indices = []
for path_idx in list(self._active_path_indices):
self._branch_single_qubit(path_idx, qubit_idx, new_active_indices)
for path_idx, qubit_samples in zip(saved_active, per_path_samples):
self._branch_single_qubit(path_idx, qubit_idx, qubit_samples, new_active_indices)
self._active_path_indices = new_active_indices

def _collect_qubit_samples(self, path_indices: list[int], qubit_idx: int) -> list[np.ndarray]:
    """Sample ``qubit_idx`` once on every path named in ``path_indices``.

    The result holds one ``_get_qubit_samples`` return value per path, in the
    same order as ``path_indices``. The per-path calls are handed to a thread
    pool only when a simulator is attached, that simulator reports
    ``parallelize_paths`` as truthy, and more than one path is requested;
    in every other case the paths are sampled one after another.
    """
    selected = [self._paths[idx] for idx in path_indices]
    simulator = self._simulator
    run_in_threads = (
        simulator is not None and simulator.parallelize_paths and len(selected) > 1
    )
    if not run_in_threads:
        return [self._get_qubit_samples(path, qubit_idx) for path in selected]

    def sample_path(path):
        return self._get_qubit_samples(path, qubit_idx)

    with ThreadPoolExecutor() as pool:
        # Executor.map yields results in submission order, so the output
        # stays aligned with ``path_indices``.
        return list(pool.map(sample_path, selected))

def _branch_single_qubit(
self, path_idx: int, qubit_idx: int, new_active_indices: list[int]
self,
path_idx: int,
qubit_idx: int,
qubit_samples: np.ndarray,
new_active_indices: list[int],
) -> None:
"""Branch a single path on a single qubit measurement."""
path = self._paths[path_idx]

# Defer to the concrete simulator to sample the target qubit's bit for
# each of ``path.shots`` shots; then the shot-split is just a tally.
qubit_samples = self._get_qubit_samples(path, qubit_idx)
path_shots = path.shots
shots_for_1 = int(np.sum(qubit_samples))
shots_for_0 = path_shots - shots_for_1
Expand Down
49 changes: 42 additions & 7 deletions src/braket/default_simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import numpy as np
Expand Down Expand Up @@ -163,6 +164,11 @@ def run(
return self.run_program_set(circuit_ir, *args, **kwargs)
return self.run_jaqcd(circuit_ir, *args, **kwargs)

@property
def parallelize_paths(self) -> bool:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be an int to allow for configurable thread count?

"""bool: Whether to run path simulations in parallel."""
return False

def create_program_context(self) -> AbstractProgramContext:
    """Return a new ``ProgramContext`` constructed with this simulator passed as ``simulator``."""
    return ProgramContext(simulator=self)

Expand Down Expand Up @@ -878,14 +884,29 @@ def _run_branched(
if circuit.qubit_set:
sim_qubit_count = max(sim_qubit_count, max(circuit.qubit_set) + 1)

# Aggregate samples across all active paths
paths = list(context.active_paths)
if self.parallelize_paths and len(paths) > 1:
with ThreadPoolExecutor() as pool:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the performance impact of this versus letting the simulator parallelize each serial simulation?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have the same question. In the beginning there is only one path, and the paths are only added during a path simulation. Would it cause a conflict? Is there really a performance gain?

per_path_samples = list(
pool.map(
_evolve_path_and_sample,
[self] * len(paths),
[path.instructions for path in paths],
[sim_qubit_count] * len(paths),
[path.shots for path in paths],
[batch_size] * len(paths),
)
)
else:
per_path_samples = [
_evolve_path_and_sample(
self, path.instructions, sim_qubit_count, path.shots, batch_size
)
for path in paths
]
all_samples = []
for path in context.active_paths:
sim = self.initialize_simulation(
qubit_count=sim_qubit_count, shots=path.shots, batch_size=batch_size
)
sim.evolve(path.instructions)
all_samples.extend(sim.retrieve_samples())
for samples in per_path_samples:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use per_path_samples?

all_samples.extend(samples)

# Build measurements in the same format as _formatted_measurements
measurements = [
Expand Down Expand Up @@ -995,3 +1016,17 @@ def run_jaqcd(
)

return self._create_results_obj(results, circuit_ir, simulation)


def _evolve_path_and_sample(
    simulator: BaseLocalSimulator,
    instructions,
    qubit_count: int,
    shots: int,
    batch_size: int,
):
    """Run one path's instructions on a fresh simulation and return its samples.

    A new simulation is obtained from ``simulator.initialize_simulation`` with
    the given ``qubit_count``, ``shots`` and ``batch_size``, evolved through
    ``instructions``, and the value of its ``retrieve_samples()`` is returned.
    Written as a module-level function so it can be passed directly to
    ``Executor.map``.
    """
    simulation = simulator.initialize_simulation(
        qubit_count=qubit_count,
        shots=shots,
        batch_size=batch_size,
    )
    simulation.evolve(instructions)
    samples = simulation.retrieve_samples()
    return samples
41 changes: 41 additions & 0 deletions test/unit_tests/braket/default_simulator/test_mcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5030,3 +5030,44 @@ def test_flat_context_preserves_mcm_while_loop(self):
"}"
)
assert Interpreter(context=FlatProgramContext()).run(qasm).circuit == qasm


class TestParallelizePaths:
    """Exercise the thread-pool branches guarded by ``parallelize_paths`` in
    ``BaseLocalSimulator._run_branched`` and
    ``ProgramContext._collect_qubit_samples``."""

    class _ParallelStateVectorSimulator(StateVectorSimulator):
        """State-vector simulator that opts in to per-path parallelism."""

        @property
        def parallelize_paths(self) -> bool:
            return True

    def test_parallel_paths_produce_expected_distribution(self):
        """Run a program with two mid-circuit measurements on the
        parallelized simulator and check the observed outcomes.

        The second measurement happens after the first one has already
        branched, so it is sampled on more than one active path, which is the
        condition for the thread-pool branch of ``_collect_qubit_samples``."""
        qasm = """
OPENQASM 3.0;
qubit[3] q;
bit b0;
bit b1;
h q[0];
b0 = measure q[0];
if (b0 == 1) {
h q[1];
} else {
h q[1];
}
b1 = measure q[1];
if (b1 == 1) {
x q[2];
}
"""
        sim = self._ParallelStateVectorSimulator()
        program = OpenQASMProgram(source=qasm, inputs={})
        result = sim.run_openqasm(program, shots=1000)
        counts = Counter("".join(shot) for shot in result.measurements)
        # Regardless of b0, q[1] is Hadamarded then measured → 50/50.
        # Regardless of b1, q[2] ends in state b1.
        allowed_outcomes = {"000", "001", "011", "100", "101", "111"}
        assert set(counts) <= allowed_outcomes
        assert sum(counts.values()) == 1000
Loading