Skip to content

Commit

Permalink
add functional tests (#17)
Browse files Browse the repository at this point in the history
* Add basic functional test

Also update pre-commit

* Add basic quantum workflow

0: -H- <Z>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update changelog

* Add braket cost estimate to quantum workflow

* Rename tests

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
cjao and pre-commit-ci[bot] authored Sep 2, 2022
1 parent 52727fd commit 3b95813
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 7 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fail_fast: true

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
rev: v4.3.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
Expand All @@ -34,18 +34,18 @@ repos:
- id: requirements-txt-fixer

- repo: https://github.com/pre-commit/mirrors-isort
rev: v5.8.0
rev: v5.10.1
hooks:
- id: isort
args: ["--profile", "black"]

- repo: https://github.com/ambv/black
rev: 21.5b1
rev: 22.8.0
hooks:
- id: black
language_version: python3.8

- repo: https://github.com/PyCQA/flake8
rev: 3.9.2
rev: 5.0.4
hooks:
- id: flake8
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Enabled Codecov
- Added tests
- Update pre-commit hooks

## [0.4.1] - 2022-08-23

Expand Down
8 changes: 5 additions & 3 deletions tests/functional_test.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from typing import List
from covalent._results_manager import Result

import covalent as ct
from covalent._results_manager import Result


def test_executor_functional():
@ct.electron(executor="braket")
def hybrid_task(size: int, shots: int, angles: List):
import pennylane as qml
import random
import os
import random

import pennylane as qml

device_arn = os.environ["AMZN_BRAKET_DEVICE_ARN"]
s3_bucket = os.environ["AMZN_BRAKET_OUT_S3_BUCKET"]
Expand Down
91 changes: 91 additions & 0 deletions tests/functional_tests/basic_quantum_workflow_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import json
import os
import subprocess

import covalent as ct
from braket.tracking import Tracker

from covalent_braket_plugin.braket import BraketExecutor

terraform_dir = os.getenv("TF_DIR")

proc = subprocess.run(
[
"terraform",
f"-chdir={terraform_dir}",
"output",
"-json",
],
check=True,
capture_output=True,
)

s3_bucket_name = json.loads(proc.stdout.decode())["s3_bucket_name"]["value"]
ecr_repo_name = json.loads(proc.stdout.decode())["ecr_repo_name"]["value"]
iam_role_name = json.loads(proc.stdout.decode())["iam_role_name"]["value"]

credentials_file = os.getenv("AWS_SHARED_CREDENTIALS_FILE")
profile = os.getenv("AWS_PROFILE")

ex = BraketExecutor(
credentials=credentials_file,
profile=profile,
s3_bucket_name=s3_bucket_name,
ecr_repo_name=ecr_repo_name,
braket_job_execution_role_name=iam_role_name,
cache_dir="/tmp/covalent",
poll_freq=30,
quantum_device="arn:aws:braket:::device/quantum-simulator/amazon/sv1",
classical_device="ml.m5.large",
storage=30,
time_limit=300,
)


@ct.electron(executor=ex)
def my_hybrid_task(num_qubits: int):
import pennylane as qml

# These are passed to the Hybrid Jobs container at runtime
device_arn = os.environ["AMZN_BRAKET_DEVICE_ARN"]
s3_bucket = os.environ["AMZN_BRAKET_OUT_S3_BUCKET"]
s3_task_dir = os.environ["AMZN_BRAKET_TASK_RESULTS_S3_URI"].split(s3_bucket)[1]

device = qml.device(
"braket.aws.qubit",
device_arn=device_arn,
s3_destination_folder=(s3_bucket, s3_task_dir),
wires=num_qubits,
)

@qml.qnode(device=device)
def simple_circuit():
qml.Hadamard(wires=[0])
return qml.expval(qml.PauliZ(wires=[0]))

with Tracker() as tracker:
res = simple_circuit().numpy()
return res, tracker


@ct.electron
def get_cost(tracker: Tracker):
return tracker.simulator_tasks_cost()


@ct.lattice
def simple_quantum_workflow(num_qubits: int):
res, tracker = my_hybrid_task(num_qubits=num_qubits)
cost = get_cost(tracker)
return res, cost


dispatch_id = ct.dispatch(simple_quantum_workflow)(1)
print("Dispatch id:", dispatch_id)
result_object = ct.get_result(dispatch_id, wait=True)

res, cost = result_object.result
print("Result:", res)
print("Cost:", cost)

assert res == 0.0
67 changes: 67 additions & 0 deletions tests/functional_tests/basic_workflow_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import json
import os
import subprocess

import covalent as ct

from covalent_braket_plugin.braket import BraketExecutor

terraform_dir = os.getenv("TF_DIR")

proc = subprocess.run(
[
"terraform",
f"-chdir={terraform_dir}",
"output",
"-json",
],
check=True,
capture_output=True,
)

s3_bucket_name = json.loads(proc.stdout.decode())["s3_bucket_name"]["value"]
ecr_repo_name = json.loads(proc.stdout.decode())["ecr_repo_name"]["value"]
iam_role_name = json.loads(proc.stdout.decode())["iam_role_name"]["value"]

credentials_file = os.getenv("AWS_SHARED_CREDENTIALS_FILE")
profile = os.getenv("AWS_PROFILE")

ex = BraketExecutor(
credentials=credentials_file,
profile=profile,
s3_bucket_name=s3_bucket_name,
ecr_repo_name=ecr_repo_name,
braket_job_execution_role_name=iam_role_name,
cache_dir="/tmp/covalent",
poll_freq=30,
quantum_device="arn:aws:braket:::device/quantum-simulator/amazon/sv1",
classical_device="ml.m5.large",
storage=30,
time_limit=300,
)


@ct.electron(executor=ex)
def join_words(a, b):
return ", ".join([a, b])


@ct.electron(executor="local")
def excitement(a):
return f"{a}!"


# Construct a workflow of tasks
@ct.lattice(executor="local")
def simple_workflow(a, b):
phrase = join_words(a, b)
return excitement(phrase)


dispatch_id = ct.dispatch(simple_workflow)("Hello", "World")

result_object = ct.get_result(dispatch_id, wait=True)
print("Actual result:", result_object.result)
print("Expected result:", "Hello, World!")

assert result_object.result == "Hello, World!"

0 comments on commit 3b95813

Please sign in to comment.