diff --git a/azure-quantum/README.md b/azure-quantum/README.md index dad2d8182..ea4db881d 100644 --- a/azure-quantum/README.md +++ b/azure-quantum/README.md @@ -29,15 +29,14 @@ To get started, visit the following Quickstart guides: ## General usage ## -To connect to your Azure Quantum Workspace, go to the [Azure Portal](https://portal.azure.com), navigate to your Workspace and copy-paste the resource ID and location into the code snippet below. +To connect to your Azure Quantum Workspace, go to the [Azure Portal](https://portal.azure.com), navigate to your Workspace and copy-paste the resource ID into the code snippet below. ```python from azure.quantum import Workspace -# Enter your Workspace details (resource ID and location) below +# Enter your Workspace resource ID below workspace = Workspace( - resource_id="", - location="" + resource_id="" ) ``` diff --git a/azure-quantum/azure/quantum/_constants.py b/azure-quantum/azure/quantum/_constants.py index 2c7ca17ae..4eeb32bca 100644 --- a/azure-quantum/azure/quantum/_constants.py +++ b/azure-quantum/azure/quantum/_constants.py @@ -55,6 +55,9 @@ class ConnectionConstants: DATA_PLANE_CREDENTIAL_SCOPE = "https://quantum.microsoft.com/.default" ARM_CREDENTIAL_SCOPE = "https://management.azure.com/.default" + DEFAULT_ARG_API_VERSION = "2021-03-01" + DEFAULT_WORKSPACE_API_VERSION = "2025-11-01-preview" + MSA_TENANT_ID = "9188040d-6c67-4c5b-b112-36a304b66dad" AUTHORITY = AzureIdentityInternals.get_default_authority() @@ -63,10 +66,14 @@ class ConnectionConstants: # pylint: disable=unnecessary-lambda-assignment GET_QUANTUM_PRODUCTION_ENDPOINT = \ lambda location: f"https://{location}.quantum.azure.com/" + GET_QUANTUM_PRODUCTION_ENDPOINT_v2 = \ + lambda location: f"https://{location}-v2.quantum.azure.com/" GET_QUANTUM_CANARY_ENDPOINT = \ lambda location: f"https://{location or 'eastus2euap'}.quantum.azure.com/" GET_QUANTUM_DOGFOOD_ENDPOINT = \ lambda location: f"https://{location}.quantum-test.azure.com/" + GET_QUANTUM_DOGFOOD_ENDPOINT_v2 = \ + lambda location: f"https://{location}-v2.quantum-test.azure.com/" ARM_PRODUCTION_ENDPOINT = "https://management.azure.com/" ARM_DOGFOOD_ENDPOINT = "https://api-dogfood.resources.windows-int.net/" @@ -93,3 +100,65 @@ class ConnectionConstants: GUID_REGEX_PATTERN = ( r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}" ) + +VALID_WORKSPACE_NAME_PATTERN = r"^[a-zA-Z0-9]+(-*[a-zA-Z0-9])*$" + +VALID_AZURE_REGIONS = { + "australiacentral", + "australiacentral2", + "australiaeast", + "australiasoutheast", + "austriaeast", + "belgiumcentral", + "brazilsouth", + "brazilsoutheast", + "canadacentral", + "canadaeast", + "centralindia", + "centralus", + "centraluseuap", + "chilecentral", + "eastasia", + "eastus", + "eastus2", + "eastus2euap", + "francecentral", + "francesouth", + "germanynorth", + "germanywestcentral", + "indonesiacentral", + "israelcentral", + "italynorth", + "japaneast", + "japanwest", + "koreacentral", + "koreasouth", + "malaysiawest", + "mexicocentral", + "newzealandnorth", + "northcentralus", + "northeurope", + "norwayeast", + "norwaywest", + "polandcentral", + "qatarcentral", + "southafricanorth", + "southafricawest", + "southcentralus", + "southindia", + "southeastasia", + "spaincentral", + "swedencentral", + "switzerlandnorth", + "switzerlandwest", + "uaecentral", + "uaenorth", + "uksouth", + "ukwest", + "westcentralus", + "westeurope", + "westindia", + "westus", + "westus2", + "westus3", +} diff --git a/azure-quantum/azure/quantum/_mgmt_client.py b/azure-quantum/azure/quantum/_mgmt_client.py new file mode 100644 index 000000000..0c520eb14 --- /dev/null +++ b/azure-quantum/azure/quantum/_mgmt_client.py @@ -0,0 +1,239 @@ +## +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +## +""" +Module providing the WorkspaceMgmtClient class for managing workspace operations. +Created to do not add additional azure-mgmt-* dependencies that can conflict with existing ones. +""" + +import logging +from http import HTTPStatus +from typing import Any, Optional, cast +from azure.core import PipelineClient +from azure.core.credentials import TokenProvider +from azure.core.pipeline import policies +from azure.core.rest import HttpRequest +from azure.core.exceptions import HttpResponseError +from azure.quantum._workspace_connection_params import WorkspaceConnectionParams +from azure.quantum._constants import ConnectionConstants +from azure.quantum._client._configuration import VERSION + +logger = logging.getLogger(__name__) + +__all__ = ["WorkspaceMgmtClient"] + + +class WorkspaceMgmtClient(): + """ + Client for Azure Quantum Workspace related ARM/ARG operations. + Uses PipelineClient under the hood which is standard for all Azure SDK clients, + see https://learn.microsoft.com/en-us/azure/developer/python/sdk/fundamentals/http-pipeline-retries. + + :param credential: + The credential to use to connect to Azure services. + + :param base_url: + The base URL for the ARM endpoint. + + :param user_agent: + Add the specified value as a prefix to the HTTP User-Agent header. + """ + + # Constants + DEFAULT_RETRY_TOTAL = 3 + CONTENT_TYPE_JSON = "application/json" + CONNECT_DOC_LINK = "https://learn.microsoft.com/en-us/azure/quantum/how-to-connect-workspace" + CONNECT_DOC_MESSAGE = f"To find details on how to connect to your workspace, please see {CONNECT_DOC_LINK}." + + def __init__(self, credential: TokenProvider, base_url: str, user_agent: Optional[str] = None) -> None: + """ + Initialize the WorkspaceMgmtClient. + + :param credential: + The credential to use to connect to Azure services. + + :param base_url: + The base URL for the ARM endpoint. + """ + self._credential = credential + self._base_url = base_url + self._policies = [ + policies.RequestIdPolicy(), + policies.HeadersPolicy({ + "Content-Type": self.CONTENT_TYPE_JSON, + "Accept": self.CONTENT_TYPE_JSON, + }), + policies.UserAgentPolicy(user_agent=user_agent, sdk_moniker="quantum/{}".format(VERSION)), + policies.RetryPolicy(retry_total=self.DEFAULT_RETRY_TOTAL), + policies.BearerTokenCredentialPolicy(self._credential, ConnectionConstants.ARM_CREDENTIAL_SCOPE), + ] + self._client: PipelineClient = PipelineClient(base_url=cast(str, base_url), policies=self._policies) + + def close(self) -> None: + self._client.close() + + def __enter__(self) -> 'WorkspaceMgmtClient': + self._client.__enter__() + return self + + def __exit__(self, *exc_details: Any) -> None: + self._client.__exit__(*exc_details) + + def load_workspace_from_arg(self, connection_params: WorkspaceConnectionParams) -> None: + """ + Queries Azure Resource Graph to find a workspace by name and optionally location, resource group, subscription. + Provided workspace name, location, resource group, and subscription in connection params must be validated beforehand. + + :param connection_params: + The workspace connection parameters to use and update. + """ + if not connection_params.workspace_name: + raise ValueError("Workspace name must be specified to try to load workspace details from ARG.") + + query = f""" + Resources + | where type =~ 'microsoft.quantum/workspaces' + | where name =~ '{connection_params.workspace_name}' + """ + + if connection_params.resource_group: + query += f"\n | where resourceGroup =~ '{connection_params.resource_group}'" + + if connection_params.location: + query += f"\n | where location =~ '{connection_params.location}'" + + query += """ + | extend endpointUri = tostring(properties.endpointUri) + | project name, subscriptionId, resourceGroup, location, endpointUri + """ + + request_body = { + "query": query + } + + if connection_params.subscription_id: + request_body["subscriptions"] = [connection_params.subscription_id] + + # Create request to Azure Resource Graph API + request = HttpRequest( + method="POST", + url=self._client.format_url("/providers/Microsoft.ResourceGraph/resources"), + params={"api-version": ConnectionConstants.DEFAULT_ARG_API_VERSION}, + json=request_body + ) + + try: + response = self._client.send_request(request) + response.raise_for_status() + result = response.json() + except Exception as e: + raise RuntimeError( + f"Could not load workspace details from Azure Resource Graph: {str(e)}.\n{self.CONNECT_DOC_MESSAGE}" + ) from e + + data = result.get('data', []) + + if not data: + raise ValueError(f"No matching workspace found with name '{connection_params.workspace_name}'. {self.CONNECT_DOC_MESSAGE}") + + if len(data) > 1: + raise ValueError( + f"Multiple Azure Quantum workspaces found with name '{connection_params.workspace_name}'. " + f"Please specify additional connection parameters. {self.CONNECT_DOC_MESSAGE}" + ) + + workspace_data = data[0] + + connection_params.subscription_id = workspace_data.get('subscriptionId') + connection_params.resource_group = workspace_data.get('resourceGroup') + connection_params.location = workspace_data.get('location') + connection_params.quantum_endpoint = workspace_data.get('endpointUri') + + logger.debug( + "Found workspace '%s' in subscription '%s', resource group '%s', location '%s', endpoint '%s'", + connection_params.workspace_name, + connection_params.subscription_id, + connection_params.resource_group, + connection_params.location, + connection_params.quantum_endpoint + ) + + # If one of the required parameters is missing, probably workspace in failed provisioning state + if not connection_params.is_complete(): + raise ValueError( + f"Failed to retrieve complete workspace details for workspace '{connection_params.workspace_name}'. " + "Please check that workspace is in valid state." + ) + + def load_workspace_from_arm(self, connection_params: WorkspaceConnectionParams) -> None: + """ + Fetches the workspace resource from ARM and sets location and endpoint URI params. + Provided workspace name, resource group, and subscription in connection params must be validated beforehand. + + :param connection_params: + The workspace connection parameters to use and update. + """ + if not all([connection_params.subscription_id, connection_params.resource_group, connection_params.workspace_name]): + raise ValueError("Missing required connection parameters to load workspace details from ARM.") + + api_version = connection_params.api_version or ConnectionConstants.DEFAULT_WORKSPACE_API_VERSION + + url = ( + f"/subscriptions/{connection_params.subscription_id}" + f"/resourceGroups/{connection_params.resource_group}" + f"/providers/Microsoft.Quantum/workspaces/{connection_params.workspace_name}" + ) + + request = HttpRequest( + method="GET", + url=self._client.format_url(url), + params={"api-version": api_version}, + ) + + try: + response = self._client.send_request(request) + response.raise_for_status() + workspace_data = response.json() + except HttpResponseError as e: + if e.status_code == HTTPStatus.NOT_FOUND: + raise ValueError( + f"Azure Quantum workspace '{connection_params.workspace_name}' " + f"not found in resource group '{connection_params.resource_group}' " + f"and subscription '{connection_params.subscription_id}'. " + f"{self.CONNECT_DOC_MESSAGE}" + ) from e + # Re-raise for other HTTP errors + raise + except Exception as e: + raise RuntimeError( + f"Could not load workspace details from ARM: {str(e)}.\n{self.CONNECT_DOC_MESSAGE}" + ) from e + + # Extract and apply location + location = workspace_data.get("location") + if location: + connection_params.location = location + logger.debug( + "Updated workspace location from ARM: %s", + location + ) + else: + raise ValueError( + f"Failed to retrieve location for workspace '{connection_params.workspace_name}'. " + f"Please check that workspace is in valid state." + ) + + # Extract and apply endpoint URI from properties + properties = workspace_data.get("properties", {}) + endpoint_uri = properties.get("endpointUri") + if endpoint_uri: + connection_params.quantum_endpoint = endpoint_uri + logger.debug( + "Updated workspace endpoint from ARM: %s", connection_params.quantum_endpoint + ) + else: + raise ValueError( + f"Failed to retrieve endpoint uri for workspace '{connection_params.workspace_name}'. " + f"Please check that workspace is in valid state." + ) diff --git a/azure-quantum/azure/quantum/_workspace_connection_params.py b/azure-quantum/azure/quantum/_workspace_connection_params.py index 8411fde23..db38836c6 100644 --- a/azure-quantum/azure/quantum/_workspace_connection_params.py +++ b/azure-quantum/azure/quantum/_workspace_connection_params.py @@ -20,6 +20,8 @@ EnvironmentVariables, ConnectionConstants, GUID_REGEX_PATTERN, + VALID_WORKSPACE_NAME_PATTERN, + VALID_AZURE_REGIONS, ) class WorkspaceConnectionParams: @@ -46,9 +48,19 @@ class WorkspaceConnectionParams: ResourceGroupName=(?P[^\s;]+); WorkspaceName=(?P[^\s;]+); ApiKey=(?P[^\s;]+); - QuantumEndpoint=(?Phttps://(?P[^\s\.]+).quantum(?:-test)?.azure.com/); + QuantumEndpoint=(?Phttps://(?P[^\s\.]+?)(?:-v2)?.quantum(?:-test)?.azure.com/); """, re.VERBOSE | re.IGNORECASE) + + WORKSPACE_NOT_FULLY_SPECIFIED_MSG = """ + Azure Quantum workspace not fully specified. + Please specify one of the following: + 1) A valid resource ID. + 2) A valid combination of subscription ID, + resource group name, and workspace name. + 3) A valid connection string (via Workspace.from_connection_string()). + 4) A valid workspace name. + """ def __init__( self, @@ -85,6 +97,8 @@ def __init__( self.client_id = None self.tenant_id = None self.api_version = None + # Track if connection string was used + self._used_connection_string = False # callback to create a new client if needed # for example, when changing the user agent self.on_new_client_request = on_new_client_request @@ -108,6 +122,81 @@ def __init__( workspace_name=workspace_name, ) self.apply_resource_id(resource_id=resource_id) + # Validate connection parameters if they are set + self._validate_connection_params() + + def _validate_connection_params(self): + self._validate_subscription_id() + self._validate_resource_group() + self._validate_workspace_name() + self._validate_location() + + def _validate_subscription_id(self): + # Validate that subscription id is a valid GUID + if self.subscription_id is not None: + if not isinstance(self.subscription_id, str): + raise ValueError("Subscription ID must be a string.") + if not re.match(f"^{GUID_REGEX_PATTERN}$", self.subscription_id, re.IGNORECASE): + raise ValueError("Subscription ID must be a valid GUID.") + + def _validate_resource_group(self): + # Validate resource group, see https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/resource-name-rules#microsoftresources + # Length 1-90, valid characters: alphanumeric, underscore, parentheses, hyphen, period (except at end), and Unicode characters: + # Uppercase Letter - Signified by the Unicode designation "Lu" (letter, uppercase); + # Lowercase Letter - Signified by the Unicode designation "Ll" (letter, lowercase); + # Titlecase Letter - Signified by the Unicode designation "Lt" (letter, titlecase); + # Modifier Letter - Signified by the Unicode designation "Lm" (letter, modifier); + # Other Letter - Signified by the Unicode designation "Lo" (letter, other); + # Decimal Digit Number - Signified by the Unicode designation "Nd" (number, decimal digit). + if self.resource_group is not None: + if not isinstance(self.resource_group, str): + raise ValueError("Resource group name must be a string.") + + if len(self.resource_group) < 1 or len(self.resource_group) > 90: + raise ValueError( + "Resource group name must be between 1 and 90 characters long." + ) + + err_msg = "Resource group name can only include alphanumeric, underscore, parentheses, hyphen, period (except at end), and Unicode characters that match the allowed characters." + if self.resource_group.endswith('.'): + raise ValueError(err_msg) + + import unicodedata + for i, char in enumerate(self.resource_group): + category = unicodedata.category(char) + if not ( + char in ('_', '(', ')', '-', '.') or + category in ('Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Nd') + ): + raise ValueError(err_msg) + + def _validate_workspace_name(self): + # Validate workspace name, see https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/resource-name-rules#microsoftquantum + # Length 2-54, valid characters: alphanumerics (a-zA-Z0-9) and hyphens, can't start or end with hyphen + if self.workspace_name is not None: + if not isinstance(self.workspace_name, str): + raise ValueError("Workspace name must be a string.") + + if len(self.workspace_name) < 2 or len(self.workspace_name) > 54: + raise ValueError( + "Workspace name must be between 2 and 54 characters long." + ) + + err_msg = "Workspace name can only include alphanumerics (a-zA-Z0-9) and hyphens, and cannot start or end with hyphen." + + if self.workspace_name.startswith('-') or self.workspace_name.endswith('-'): + raise ValueError(err_msg) + + if not re.match(VALID_WORKSPACE_NAME_PATTERN, self.workspace_name): + raise ValueError(err_msg) + + def _validate_location(self): + # Validate that location is one of the Azure regions https://learn.microsoft.com/en-us/azure/reliability/regions-list + if self.location is not None: + if not isinstance(self.location, str): + raise ValueError("Location must be a string.") + if self.location not in VALID_AZURE_REGIONS: + raise ValueError(f"Location must be one of the Azure regions listed in https://learn.microsoft.com/en-us/azure/reliability/regions-list.") @property def location(self): @@ -142,19 +231,8 @@ def environment(self, value: Union[str, EnvironmentKind]): def quantum_endpoint(self): """ The Azure Quantum data plane endpoint. - Defaults to well-known endpoint based on the environment. - """ - if self._quantum_endpoint: - return self._quantum_endpoint - if not self.location: - raise ValueError("Location not specified") - if self.environment is EnvironmentKind.PRODUCTION: - return ConnectionConstants.GET_QUANTUM_PRODUCTION_ENDPOINT(self.location) - if self.environment is EnvironmentKind.CANARY: - return ConnectionConstants.GET_QUANTUM_CANARY_ENDPOINT(self.location) - if self.environment is EnvironmentKind.DOGFOOD: - return ConnectionConstants.GET_QUANTUM_DOGFOOD_ENDPOINT(self.location) - raise ValueError(f"Unknown environment `{self.environment}`.") + """ + return self._quantum_endpoint @quantum_endpoint.setter def quantum_endpoint(self, value: str): @@ -235,6 +313,7 @@ def apply_connection_string(self, connection_string: str): if not match: raise ValueError("Invalid connection string") self._merge_re_match(match) + self._used_connection_string = True def merge( self, @@ -450,6 +529,32 @@ def get_full_user_agent(self): full_user_agent = (f"{app_id} {full_user_agent}" if full_user_agent else app_id) return full_user_agent + + def have_enough_for_discovery(self) -> bool: + """ + Returns true if we have enough parameters + to try to find the Azure Quantum Workspace. + """ + return (self.workspace_name + and self.get_credential_or_default()) + + def assert_have_enough_for_discovery(self): + """ + Raises ValueError if we don't have enough parameters + to try to find the Azure Quantum Workspace. + """ + if not self.have_enough_for_discovery(): + raise ValueError(self.WORKSPACE_NOT_FULLY_SPECIFIED_MSG) + + def can_build_resource_id(self) -> bool: + """ + Returns true if we have all necessary parameters + to identify the Azure Quantum Workspace resource. + """ + return (self.subscription_id + and self.resource_group + and self.workspace_name + and self.get_credential_or_default()) def is_complete(self) -> bool: """ @@ -460,6 +565,7 @@ def is_complete(self) -> bool: and self.subscription_id and self.resource_group and self.workspace_name + and self.quantum_endpoint and self.get_credential_or_default()) def assert_complete(self): @@ -468,15 +574,7 @@ def assert_complete(self): to connect to the Azure Quantum Workspace. """ if not self.is_complete(): - raise ValueError( - """ - Azure Quantum workspace not fully specified. - Please specify one of the following: - 1) A valid combination of location and resource ID. - 2) A valid combination of location, subscription ID, - resource group name, and workspace name. - 3) A valid connection string (via Workspace.from_connection_string()). - """) + raise ValueError(self.WORKSPACE_NOT_FULLY_SPECIFIED_MSG) def default_from_env_vars(self) -> WorkspaceConnectionParams: """ @@ -512,10 +610,13 @@ def default_from_env_vars(self) -> WorkspaceConnectionParams: or not self.workspace_name or not self.credential ): - self._merge_connection_params( - connection_params=WorkspaceConnectionParams( - connection_string=os.environ.get(EnvironmentVariables.CONNECTION_STRING)), - merge_default_mode=True) + env_connection_string = os.environ.get(EnvironmentVariables.CONNECTION_STRING) + if env_connection_string: + self._merge_connection_params( + connection_params=WorkspaceConnectionParams( + connection_string=env_connection_string), + merge_default_mode=True) + self._used_connection_string = True return self @classmethod diff --git a/azure-quantum/azure/quantum/workspace.py b/azure-quantum/azure/quantum/workspace.py index 284a9d5ef..ac7d54ca6 100644 --- a/azure-quantum/azure/quantum/workspace.py +++ b/azure-quantum/azure/quantum/workspace.py @@ -21,6 +21,7 @@ Tuple, Union, ) +from typing_extensions import Self from azure.core.paging import ItemPaged from azure.quantum._client import ServicesClient from azure.quantum._client.models import JobDetails, ItemDetails, SessionDetails @@ -49,6 +50,7 @@ get_container_uri, ContainerClient ) +from azure.quantum._mgmt_client import WorkspaceMgmtClient if TYPE_CHECKING: from azure.quantum.target import Target @@ -62,10 +64,11 @@ class Workspace: """ Represents an Azure Quantum workspace. - When creating a Workspace object, callers have two options for + When creating a Workspace object, callers have several options for identifying the Azure Quantum workspace (in order of precedence): - 1. specify a valid location and resource ID; or - 2. specify a valid location, subscription ID, resource group, and workspace name. + 1. specify a valid resource ID; or + 2. specify a valid subscription ID, resource group, and workspace name; or + 3. specify a valid workspace name. You can also use a connection string to specify the connection parameters to an Azure Quantum Workspace by calling @@ -110,6 +113,12 @@ class Workspace: Add the specified value as a prefix to the HTTP User-Agent header when communicating to the Azure Quantum service. """ + + # Internal parameter names + _FROM_CONNECTION_STRING_PARAM = '_from_connection_string' + _QUANTUM_ENDPOINT_PARAM = '_quantum_endpoint' + _MGMT_CLIENT_PARAM = '_mgmt_client' + def __init__( self, subscription_id: Optional[str] = None, @@ -122,6 +131,14 @@ def __init__( user_agent: Optional[str] = None, **kwargs: Any, ) -> None: + # Extract internal params before passing kwargs to WorkspaceConnectionParams + # Param to track whether the workspace was created from a connection string + from_connection_string = kwargs.pop(Workspace._FROM_CONNECTION_STRING_PARAM, False) + # In case from connection string, quantum_endpoint must be passed + quantum_endpoint = kwargs.pop(Workspace._QUANTUM_ENDPOINT_PARAM, None) + # Params to pass a mock in tests + self._mgmt_client = kwargs.pop(Workspace._MGMT_CLIENT_PARAM, None) + connection_params = WorkspaceConnectionParams( location=location, subscription_id=subscription_id, @@ -129,13 +146,14 @@ def __init__( workspace_name=name, credential=credential, resource_id=resource_id, + quantum_endpoint=quantum_endpoint, user_agent=user_agent, **kwargs ).default_from_env_vars() logger.info("Using %s environment.", connection_params.environment) - connection_params.assert_complete() + connection_params.assert_have_enough_for_discovery() connection_params.on_new_client_request = self._on_new_client_request @@ -145,6 +163,32 @@ def __init__( self._resource_group = connection_params.resource_group self._workspace_name = connection_params.workspace_name + if not self._mgmt_client: + credential = connection_params.get_credential_or_default() + self._mgmt_client = WorkspaceMgmtClient( + credential=credential, + base_url=connection_params.arm_endpoint, + user_agent=connection_params.get_full_user_agent(), + ) + + # pylint: disable=protected-access + using_connection_string = ( + from_connection_string + or connection_params._used_connection_string + ) + + # Populate workspace details from ARG if not using connection string and + # name is provided but missing subscription and/or resource group + if not using_connection_string \ + and not connection_params.can_build_resource_id(): + self._mgmt_client.load_workspace_from_arg(connection_params) + + # Populate workspace details from ARM if not using connection string and not loaded from ARG + if not using_connection_string and not connection_params.is_complete(): + self._mgmt_client.load_workspace_from_arm(connection_params) + + connection_params.assert_complete() + # Create QuantumClient self._client = self._create_client() @@ -277,6 +321,8 @@ def from_connection_string(cls, connection_string: str, **kwargs) -> Workspace: :rtype: Workspace """ connection_params = WorkspaceConnectionParams(connection_string=connection_string) + kwargs[cls._FROM_CONNECTION_STRING_PARAM] = True + kwargs[cls._QUANTUM_ENDPOINT_PARAM] = connection_params.quantum_endpoint return cls( subscription_id=connection_params.subscription_id, resource_group=connection_params.resource_group, @@ -1023,4 +1069,16 @@ def _create_orderby(self, orderby_property: str, is_asc: bool) -> str: return orderby else: return None - \ No newline at end of file + + def close(self) -> None: + self._mgmt_client.close() + self._client.close() + + def __enter__(self) -> Self: + self._client.__enter__() + self._mgmt_client.__enter__() + return self + + def __exit__(self, *exc_details: Any) -> None: + self._mgmt_client.__exit__(*exc_details) + self._client.__exit__(*exc_details) diff --git a/azure-quantum/tests/unit/common.py b/azure-quantum/tests/unit/common.py index 49b642427..7838a46ce 100644 --- a/azure-quantum/tests/unit/common.py +++ b/azure-quantum/tests/unit/common.py @@ -7,7 +7,8 @@ import os import json import time -from unittest.mock import patch +from typing import Optional, Any +from unittest.mock import patch, MagicMock from vcr.request import Request as VcrRequest from azure_devtools.scenario_tests.base import ReplayableTest @@ -40,6 +41,7 @@ WORKSPACE = "myworkspace" LOCATION = "eastus" STORAGE = "mystorage" +ENDPOINT_URI = f"https://{LOCATION}.quantum.azure.com/" # TODO change to f"https://{WORKSPACE}.{LOCATION}.quantum.azure.com/" when recordings are removed to follow format returned by ARM. Currently set to format which is used in connection string to avoid mass updates of recordings. API_KEY = "myapikey" APP_ID = "testapp" DEFAULT_TIMEOUT_SECS = 300 @@ -285,6 +287,27 @@ def clear_env_vars(self, os_environ): if env_var in os_environ: del os_environ[env_var] + def create_mock_mgmt_client(self) -> MagicMock: + """ + Create a mock WorkspaceMgmtClient to avoid ARM/ARG calls during tests. + """ + mock_mgmt_client = MagicMock() + + def mock_load_workspace_from_arm(connection_params): + connection_params.location = LOCATION + connection_params.quantum_endpoint = ENDPOINT_URI + + def mock_load_workspace_from_arg(connection_params): + connection_params.subscription_id = SUBSCRIPTION_ID + connection_params.resource_group = RESOURCE_GROUP + connection_params.location = LOCATION + connection_params.quantum_endpoint = ENDPOINT_URI + + mock_mgmt_client.load_workspace_from_arm = mock_load_workspace_from_arm + mock_mgmt_client.load_workspace_from_arg = mock_load_workspace_from_arg + + return mock_mgmt_client + def create_workspace( self, credential = None, @@ -303,6 +326,11 @@ def create_workspace( client_id=ZERO_UID, client_secret=PLACEHOLDER) + mock_mgmt_client = None + # When not in live mode use object mock instead of recording as we are going to get rid of the recordings anyway + if not self.is_live: + mock_mgmt_client = self.create_mock_mgmt_client() + workspace = Workspace( credential=credential, subscription_id=connection_params.subscription_id, @@ -310,11 +338,47 @@ def create_workspace( name=connection_params.workspace_name, location=connection_params.location, user_agent=connection_params.user_agent_app_id, + _mgmt_client=mock_mgmt_client, **kwargs ) return workspace + def create_workspace_with_params( + self, + subscription_id: Optional[str] = None, + resource_group: Optional[str] = None, + name: Optional[str] = None, + storage: Optional[str] = None, + resource_id: Optional[str] = None, + location: Optional[str] = None, + credential: Optional[object] = None, + user_agent: Optional[str] = None, + **kwargs: Any) -> Workspace: + """ + Create workspace with explicit parameters, using a mock management client + when not in live mode. + """ + mock_mgmt_client = None + # When not in live mode use object mock instead of recording as we are going to get rid of the recordings anyway + if not self.is_live: + mock_mgmt_client = self.create_mock_mgmt_client() + + workspace = Workspace( + subscription_id=subscription_id, + resource_group=resource_group, + name=name, + storage=storage, + resource_id=resource_id, + location=location, + credential=credential, + user_agent=user_agent, + _mgmt_client=mock_mgmt_client, + **kwargs + ) + + return workspace + def create_echo_target( self, credential = None, diff --git a/azure-quantum/tests/unit/test_cirq.py b/azure-quantum/tests/unit/test_cirq.py index c4b3803e1..eac59c253 100644 --- a/azure-quantum/tests/unit/test_cirq.py +++ b/azure-quantum/tests/unit/test_cirq.py @@ -14,7 +14,7 @@ from azure.quantum.cirq import AzureQuantumService from azure.quantum.cirq.targets.target import Target -from common import QuantumTestBase, ONE_UID, LOCATION, DEFAULT_TIMEOUT_SECS +from common import QuantumTestBase, ONE_UID, DEFAULT_TIMEOUT_SECS from test_workspace import SIMPLE_RESOURCE_ID class TestCirq(QuantumTestBase): @@ -64,9 +64,7 @@ def test_cirq_service_init_with_workspace_not_raises_deprecation(self): # Cause all warnings to always be triggered. warnings.simplefilter("always") # Try to trigger a warning. - workspace = Workspace( - resource_id=SIMPLE_RESOURCE_ID, - location=LOCATION) + workspace = self.create_workspace_with_params(resource_id=SIMPLE_RESOURCE_ID) AzureQuantumService(workspace) # Verify @@ -75,36 +73,38 @@ def test_cirq_service_init_with_workspace_not_raises_deprecation(self): def test_cirq_service_init_without_workspace_raises_deprecation(self): # testing warning according to https://docs.python.org/3/library/warnings.html#testing-warnings import warnings - - with warnings.catch_warnings(record=True) as w: - # Cause all warnings to always be triggered. - warnings.simplefilter("always") - # Try to trigger a warning. - AzureQuantumService( - resource_id=SIMPLE_RESOURCE_ID, - location=LOCATION) - # Verify - assert len(w) == 1 - assert issubclass(w[-1].category, DeprecationWarning) - assert "Consider passing \"workspace\" argument explicitly" in str(w[-1].message) - - # Validate rising deprecation warning even if workspace is passed, but other parameters are also passed - with warnings.catch_warnings(record=True) as w: - # Cause all warnings to always be triggered. - warnings.simplefilter("always") - # Try to trigger a warning. - workspace = Workspace( - resource_id=SIMPLE_RESOURCE_ID, - location=LOCATION) - - AzureQuantumService( - workspace=workspace, + from unittest.mock import patch + + # Create mock mgmt_client to avoid ARM calls + mock_mgmt_client = self.create_mock_mgmt_client() + + with patch('azure.quantum.workspace.WorkspaceMgmtClient', return_value=mock_mgmt_client): + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. + warnings.simplefilter("always") + # Try to trigger a warning. + AzureQuantumService(resource_id=SIMPLE_RESOURCE_ID) + # Verify + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "Consider passing \"workspace\" argument explicitly" in str(w[-1].message) + + # Validate rising deprecation warning even if workspace is passed, but other parameters are also passed + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. + warnings.simplefilter("always") + # Try to trigger a warning. + workspace = Workspace( resource_id=SIMPLE_RESOURCE_ID, - location=LOCATION) - # Verify - assert len(w) == 1 - assert issubclass(w[-1].category, DeprecationWarning) - assert "Consider passing \"workspace\" argument explicitly" in str(w[-1].message) + _mgmt_client=mock_mgmt_client) + + AzureQuantumService( + workspace=workspace, + resource_id=SIMPLE_RESOURCE_ID) + # Verify + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "Consider passing \"workspace\" argument explicitly" in str(w[-1].message) @pytest.mark.quantinuum @pytest.mark.ionq diff --git a/azure-quantum/tests/unit/test_mgmt_client.py b/azure-quantum/tests/unit/test_mgmt_client.py new file mode 100644 index 000000000..291b759e8 --- /dev/null +++ b/azure-quantum/tests/unit/test_mgmt_client.py @@ -0,0 +1,431 @@ +## +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +## +""" +Unit tests for the WorkspaceMgmtClient class. +""" + +import pytest +from unittest.mock import MagicMock, patch +from http import HTTPStatus +from azure.core.exceptions import HttpResponseError +from azure.quantum._mgmt_client import WorkspaceMgmtClient +from azure.quantum._workspace_connection_params import WorkspaceConnectionParams +from azure.quantum._constants import ConnectionConstants +from common import ( + SUBSCRIPTION_ID, + RESOURCE_GROUP, + WORKSPACE, + LOCATION, + ENDPOINT_URI, +) + + +class TestWorkspaceMgmtClient: + """Test suite for WorkspaceMgmtClient class.""" + + @pytest.fixture + def mock_credential(self): + """Create a mock credential.""" + return MagicMock() + + @pytest.fixture + def base_url(self): + """Return the ARM base URL.""" + return ConnectionConstants.ARM_PRODUCTION_ENDPOINT + + @pytest.fixture + def mgmt_client(self, mock_credential, base_url): + """Create a WorkspaceMgmtClient instance.""" + return WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + @pytest.fixture + def connection_params(self): + """Create a WorkspaceConnectionParams instance.""" + return WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + ) + + def test_init_creates_client(self, mock_credential, base_url): + """Test that initialization creates a properly configured client.""" + client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url, + user_agent="test-agent" + ) + + assert client._credential == mock_credential + assert client._base_url == base_url + assert client._client is not None + assert len(client._policies) == 5 + + def test_init_without_user_agent(self, mock_credential, base_url): + """Test initialization without user agent.""" + client = WorkspaceMgmtClient( + credential=mock_credential, + base_url=base_url + ) + + assert client._credential == mock_credential + assert client._base_url == base_url + assert client._client is not None + + def test_context_manager_enter(self, mgmt_client): + """Test __enter__ returns self.""" + with patch.object(mgmt_client._client, '__enter__', return_value=mgmt_client._client): + result = mgmt_client.__enter__() + assert result == mgmt_client + + def test_context_manager_exit(self, mgmt_client): + """Test __exit__ calls client exit.""" + with patch.object(mgmt_client._client, '__exit__') as mock_exit: + mgmt_client.__exit__(None, None, None) + mock_exit.assert_called_once() + + def test_close(self, mgmt_client): + """Test close method calls client close.""" + with patch.object(mgmt_client._client, 'close') as mock_close: + mgmt_client.close() + mock_close.assert_called_once() + + def test_load_workspace_from_arg_success(self, mgmt_client, connection_params): + """Test successful workspace loading from ARG.""" + # Setup mock response + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + # Clear some params to test ARG fills them + connection_params.subscription_id = None + connection_params.location = None + connection_params.quantum_endpoint = None + + mgmt_client.load_workspace_from_arg(connection_params) + + assert connection_params.subscription_id == SUBSCRIPTION_ID + assert connection_params.resource_group == RESOURCE_GROUP + assert connection_params.workspace_name == WORKSPACE + assert connection_params.location == LOCATION + assert connection_params.quantum_endpoint == ENDPOINT_URI + + def test_load_workspace_from_arg_with_resource_group_filter(self, mgmt_client): + """Test ARG query includes resource group filter when provided.""" + connection_params = WorkspaceConnectionParams( + workspace_name=WORKSPACE, + resource_group=RESOURCE_GROUP + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arg(connection_params) + + # Verify the request was made and contains resource group filter + call_args = mock_send.call_args + request = call_args[0][0] + assert RESOURCE_GROUP in str(request.content) + + def test_load_workspace_from_arg_with_location_filter(self, mgmt_client): + """Test ARG query includes location filter when provided.""" + connection_params = WorkspaceConnectionParams( + workspace_name=WORKSPACE, + location=LOCATION + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arg(connection_params) + + # Verify the request was made and contains location filter + call_args = mock_send.call_args + request = call_args[0][0] + assert LOCATION in str(request.content) + + def test_load_workspace_from_arg_with_subscription_filter(self, mgmt_client): + """Test ARG query includes subscription filter when provided.""" + connection_params = WorkspaceConnectionParams( + workspace_name=WORKSPACE, + subscription_id=SUBSCRIPTION_ID + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arg(connection_params) + + # Verify the request includes subscriptions filter + call_args = mock_send.call_args + request = call_args[0][0] + request_body = request.content + assert 'subscriptions' in request_body + + def test_load_workspace_from_arg_no_workspace_name(self, mgmt_client): + """Test that missing workspace name raises ValueError.""" + connection_params = WorkspaceConnectionParams() + + with pytest.raises(ValueError, match="Workspace name must be specified"): + mgmt_client.load_workspace_from_arg(connection_params) + + def test_load_workspace_from_arg_no_matching_workspace(self, mgmt_client, connection_params): + """Test error when no matching workspace found.""" + mock_response = MagicMock() + mock_response.json.return_value = {'data': []} + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + with pytest.raises(ValueError, match="No matching workspace found"): + mgmt_client.load_workspace_from_arg(connection_params) + + def test_load_workspace_from_arg_multiple_workspaces(self, mgmt_client, connection_params): + """Test error when multiple workspaces found.""" + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [ + { + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }, + { + 'name': WORKSPACE, + 'subscriptionId': 'another-sub-id', + 'resourceGroup': 'another-rg', + 'location': 'westus', + 'endpointUri': 'https://another.endpoint.com/' + } + ] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + with pytest.raises(ValueError, match="Multiple Azure Quantum workspaces found"): + mgmt_client.load_workspace_from_arg(connection_params) + + def test_load_workspace_from_arg_incomplete_workspace_data(self, mgmt_client, connection_params): + """Test error when workspace data is incomplete.""" + # Missing endpointUri + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + with pytest.raises(ValueError, match="Failed to retrieve complete workspace details"): + mgmt_client.load_workspace_from_arg(connection_params) + + def test_load_workspace_from_arg_request_exception(self, mgmt_client, connection_params): + """Test handling of request exceptions.""" + with patch.object(mgmt_client._client, 'send_request', side_effect=Exception("Network error")): + with pytest.raises(RuntimeError, match="Could not load workspace details from Azure Resource Graph"): + mgmt_client.load_workspace_from_arg(connection_params) + + def test_load_workspace_from_arm_success(self, mgmt_client, connection_params): + """Test successful workspace loading from ARM.""" + mock_response = MagicMock() + mock_response.json.return_value = { + 'location': LOCATION, + 'properties': { + 'endpointUri': ENDPOINT_URI + } + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + # Clear location and endpoint to test ARM fills them + connection_params.location = None + connection_params.quantum_endpoint = None + + mgmt_client.load_workspace_from_arm(connection_params) + + assert connection_params.location == LOCATION + assert connection_params.quantum_endpoint == ENDPOINT_URI + + def test_load_workspace_from_arm_missing_required_params(self, mgmt_client): + """Test error when required connection parameters are missing.""" + connection_params = WorkspaceConnectionParams( + workspace_name=WORKSPACE + ) + + with pytest.raises(ValueError, match="Missing required connection parameters"): + mgmt_client.load_workspace_from_arm(connection_params) + + def test_load_workspace_from_arm_workspace_not_found(self, mgmt_client, connection_params): + """Test error when workspace not found in ARM.""" + mock_error = HttpResponseError() + mock_error.status_code = HTTPStatus.NOT_FOUND + + with patch.object(mgmt_client._client, 'send_request', side_effect=mock_error): + with pytest.raises(ValueError, match="not found in resource group"): + mgmt_client.load_workspace_from_arm(connection_params) + + def test_load_workspace_from_arm_http_error(self, mgmt_client, connection_params): + """Test handling of other HTTP errors.""" + mock_error = HttpResponseError() + mock_error.status_code = HTTPStatus.FORBIDDEN + + with patch.object(mgmt_client._client, 'send_request', side_effect=mock_error): + with pytest.raises(HttpResponseError): + mgmt_client.load_workspace_from_arm(connection_params) + + def test_load_workspace_from_arm_missing_location(self, mgmt_client, connection_params): + """Test error when location is missing in ARM response.""" + mock_response = MagicMock() + mock_response.json.return_value = { + 'properties': { + 'endpointUri': ENDPOINT_URI + } + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + with pytest.raises(ValueError, match="Failed to retrieve location"): + mgmt_client.load_workspace_from_arm(connection_params) + + def test_load_workspace_from_arm_missing_endpoint(self, mgmt_client, connection_params): + """Test error when endpoint URI is missing in ARM response.""" + mock_response = MagicMock() + mock_response.json.return_value = { + 'location': LOCATION, + 'properties': {} + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response): + with pytest.raises(ValueError, match="Failed to retrieve endpoint uri"): + mgmt_client.load_workspace_from_arm(connection_params) + + def test_load_workspace_from_arm_request_exception(self, mgmt_client, connection_params): + """Test handling of request exceptions from ARM.""" + with patch.object(mgmt_client._client, 'send_request', side_effect=Exception("Network error")): + with pytest.raises(RuntimeError, match="Could not load workspace details from ARM"): + mgmt_client.load_workspace_from_arm(connection_params) + + def test_load_workspace_from_arm_uses_custom_api_version(self, mgmt_client): + """Test that custom API version is used when provided.""" + connection_params = WorkspaceConnectionParams( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + api_version="2024-01-01" + ) + + mock_response = MagicMock() + mock_response.json.return_value = { + 'location': LOCATION, + 'properties': { + 'endpointUri': ENDPOINT_URI + } + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arm(connection_params) + + # Verify the custom API version was used + call_args = mock_send.call_args + request = call_args[0][0] + assert "2024-01-01" in request.url + + def test_load_workspace_from_arm_uses_default_api_version(self, mgmt_client, connection_params): + """Test that default API version is used when not provided.""" + mock_response = MagicMock() + mock_response.json.return_value = { + 'location': LOCATION, + 'properties': { + 'endpointUri': ENDPOINT_URI + } + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arm(connection_params) + + # Verify the default API version was used + call_args = mock_send.call_args + request = call_args[0][0] + assert ConnectionConstants.DEFAULT_WORKSPACE_API_VERSION in request.url + + def test_load_workspace_from_arg_constructs_correct_url(self, mgmt_client, connection_params): + """Test that ARG request uses correct URL.""" + mock_response = MagicMock() + mock_response.json.return_value = { + 'data': [{ + 'name': WORKSPACE, + 'subscriptionId': SUBSCRIPTION_ID, + 'resourceGroup': RESOURCE_GROUP, + 'location': LOCATION, + 'endpointUri': ENDPOINT_URI + }] + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arg(connection_params) + + # Verify the request URL + call_args = mock_send.call_args + request = call_args[0][0] + assert "/providers/Microsoft.ResourceGraph/resources" in request.url + assert ConnectionConstants.DEFAULT_ARG_API_VERSION in request.url + + def test_load_workspace_from_arm_constructs_correct_url(self, mgmt_client, connection_params): + """Test that ARM request uses correct URL.""" + mock_response = MagicMock() + mock_response.json.return_value = { + 'location': LOCATION, + 'properties': { + 'endpointUri': ENDPOINT_URI + } + } + + with patch.object(mgmt_client._client, 'send_request', return_value=mock_response) as mock_send: + mgmt_client.load_workspace_from_arm(connection_params) + + # Verify the request URL contains expected components + call_args = mock_send.call_args + request = call_args[0][0] + assert f"/subscriptions/{SUBSCRIPTION_ID}" in request.url + assert f"/resourceGroups/{RESOURCE_GROUP}" in request.url + assert f"/providers/Microsoft.Quantum/workspaces/{WORKSPACE}" in request.url diff --git a/azure-quantum/tests/unit/test_pagination.py b/azure-quantum/tests/unit/test_pagination.py index ae93ed19b..c4c8fea33 100644 --- a/azure-quantum/tests/unit/test_pagination.py +++ b/azure-quantum/tests/unit/test_pagination.py @@ -14,16 +14,14 @@ SUBSCRIPTION_ID, RESOURCE_GROUP, WORKSPACE, - LOCATION, ) class TestWorkspacePagination(QuantumTestBase): def test_filter_valid(self): - ws = Workspace( + ws = self.create_workspace_with_params( subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION, ) # pylint: disable=protected-access @@ -42,11 +40,10 @@ def test_filter_valid(self): def test_orderby_valid(self): var_names = ["Name", "ItemType", "JobType", "ProviderId", "Target", "State", "CreationTime"] - ws = Workspace( + ws = self.create_workspace_with_params( subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION, ) for var_name in var_names: @@ -64,11 +61,10 @@ def test_orderby_valid(self): self.assertEqual(orderby, expected) def test_orderby_invalid(self): - ws = Workspace( + ws = self.create_workspace_with_params( subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION, ) # pylint: disable=protected-access self.assertRaises(ValueError, ws._create_orderby, "test", True) diff --git a/azure-quantum/tests/unit/test_qiskit.py b/azure-quantum/tests/unit/test_qiskit.py index 902c7df59..bec16cdb9 100644 --- a/azure-quantum/tests/unit/test_qiskit.py +++ b/azure-quantum/tests/unit/test_qiskit.py @@ -16,7 +16,7 @@ from qiskit.providers import Options from qiskit.providers.exceptions import QiskitBackendNotFoundError -from common import QuantumTestBase, DEFAULT_TIMEOUT_SECS, LOCATION +from common import QuantumTestBase, DEFAULT_TIMEOUT_SECS from test_workspace import SIMPLE_RESOURCE_ID from azure.quantum.workspace import Workspace @@ -591,7 +591,7 @@ def test_qiskit_provider_init_with_workspace_not_raises_deprecation(self): # Cause all warnings to always be triggered. warnings.simplefilter("always") # Try to trigger a warning. - workspace = Workspace(resource_id=SIMPLE_RESOURCE_ID, location=LOCATION) + workspace = self.create_workspace_with_params(resource_id=SIMPLE_RESOURCE_ID) AzureQuantumProvider(workspace) warns = [ @@ -606,44 +606,52 @@ def test_qiskit_provider_init_with_workspace_not_raises_deprecation(self): def test_qiskit_provider_init_without_workspace_raises_deprecation(self): # testing warning according to https://docs.python.org/3/library/warnings.html#testing-warnings - with warnings.catch_warnings(record=True) as w: - # Cause all warnings to always be triggered. - warnings.simplefilter("always") - # Try to trigger a warning. - AzureQuantumProvider(resource_id=SIMPLE_RESOURCE_ID, location=LOCATION) - - warns = [ - warn - for warn in w - if 'Consider passing "workspace" argument explicitly.' - in warn.message.args[0] - ] - - # Verify - assert len(warns) == 1 - assert issubclass(warns[0].category, DeprecationWarning) - - # Validate rising deprecation warning even if workspace is passed, but other parameters are also passed - with warnings.catch_warnings(record=True) as w: - # Cause all warnings to always be triggered. - warnings.simplefilter("always") - # Try to trigger a warning. - workspace = Workspace(resource_id=SIMPLE_RESOURCE_ID, location=LOCATION) - - AzureQuantumProvider( - workspace=workspace, resource_id=SIMPLE_RESOURCE_ID, location=LOCATION - ) + from unittest.mock import patch + + # Create mock mgmt_client to avoid ARM calls + mock_mgmt_client = self.create_mock_mgmt_client() + + with patch('azure.quantum.workspace.WorkspaceMgmtClient', return_value=mock_mgmt_client): + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. + warnings.simplefilter("always") + # Try to trigger a warning. + AzureQuantumProvider(resource_id=SIMPLE_RESOURCE_ID) + + warns = [ + warn + for warn in w + if 'Consider passing "workspace" argument explicitly.' + in warn.message.args[0] + ] + + # Verify + assert len(warns) == 1 + assert issubclass(warns[0].category, DeprecationWarning) + + # Validate rising deprecation warning even if workspace is passed, but other parameters are also passed + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. + warnings.simplefilter("always") + # Try to trigger a warning. + workspace = Workspace( + resource_id=SIMPLE_RESOURCE_ID, + _mgmt_client=mock_mgmt_client) + + AzureQuantumProvider( + workspace=workspace, resource_id=SIMPLE_RESOURCE_ID + ) - warns = [ - warn - for warn in w - if 'Consider passing "workspace" argument explicitly.' - in warn.message.args[0] - ] + warns = [ + warn + for warn in w + if 'Consider passing "workspace" argument explicitly.' + in warn.message.args[0] + ] - # Verify - assert len(warns) == 1 - assert issubclass(warns[0].category, DeprecationWarning) + # Verify + assert len(warns) == 1 + assert issubclass(warns[0].category, DeprecationWarning) @pytest.mark.ionq @pytest.mark.live_test diff --git a/azure-quantum/tests/unit/test_workspace.py b/azure-quantum/tests/unit/test_workspace.py index 9c424c73e..8f97a6ca4 100644 --- a/azure-quantum/tests/unit/test_workspace.py +++ b/azure-quantum/tests/unit/test_workspace.py @@ -12,7 +12,10 @@ WORKSPACE, LOCATION, STORAGE, + ENDPOINT_URI, API_KEY, + ZERO_UID, + PLACEHOLDER, ) from azure.quantum import Workspace from azure.quantum._constants import ( @@ -38,54 +41,125 @@ quantum_endpoint=ConnectionConstants.GET_QUANTUM_PRODUCTION_ENDPOINT(LOCATION) ) +SIMPLE_CONNECTION_STRING_V2 = ConnectionConstants.VALID_CONNECTION_STRING( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + api_key=API_KEY, + quantum_endpoint=ConnectionConstants.GET_QUANTUM_PRODUCTION_ENDPOINT_v2(LOCATION) +) + class TestWorkspace(QuantumTestBase): def test_create_workspace_instance_valid(self): + def assert_all_required_params(ws: Workspace): + self.assertEqual(ws.subscription_id, SUBSCRIPTION_ID) + self.assertEqual(ws.resource_group, RESOURCE_GROUP) + self.assertEqual(ws.name, WORKSPACE) + self.assertEqual(ws.location, LOCATION) + self.assertEqual(ws._connection_params.quantum_endpoint, ENDPOINT_URI) + + mock_mgmt_client = self.create_mock_mgmt_client() + ws = Workspace( subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION, + _mgmt_client=mock_mgmt_client, ) - self.assertEqual(ws.subscription_id, SUBSCRIPTION_ID) - self.assertEqual(ws.resource_group, RESOURCE_GROUP) - self.assertEqual(ws.name, WORKSPACE) - self.assertEqual(ws.location, LOCATION) + assert_all_required_params(ws) ws = Workspace( subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION, storage=STORAGE, + _mgmt_client=mock_mgmt_client, ) + assert_all_required_params(ws) self.assertEqual(ws.storage, STORAGE) ws = Workspace( resource_id=SIMPLE_RESOURCE_ID, - location=LOCATION, + _mgmt_client=mock_mgmt_client, ) - self.assertEqual(ws.subscription_id, SUBSCRIPTION_ID) - self.assertEqual(ws.resource_group, RESOURCE_GROUP) - self.assertEqual(ws.name, WORKSPACE) - self.assertEqual(ws.location, LOCATION) + assert_all_required_params(ws) ws = Workspace( resource_id=SIMPLE_RESOURCE_ID, storage=STORAGE, - location=LOCATION, + _mgmt_client=mock_mgmt_client, ) + assert_all_required_params(ws) self.assertEqual(ws.storage, STORAGE) + ws = Workspace( + name=WORKSPACE, + _mgmt_client=mock_mgmt_client, + ) + assert_all_required_params(ws) + + ws = Workspace( + name=WORKSPACE, + storage=STORAGE, + _mgmt_client=mock_mgmt_client, + ) + assert_all_required_params(ws) + self.assertEqual(ws.storage, STORAGE) + + ws = Workspace( + name=WORKSPACE, + location=LOCATION, + _mgmt_client=mock_mgmt_client, + ) + assert_all_required_params(ws) + + ws = Workspace( + name=WORKSPACE, + subscription_id=SUBSCRIPTION_ID, + _mgmt_client=mock_mgmt_client, + ) + assert_all_required_params(ws) + + ws = Workspace( + name=WORKSPACE, + subscription_id=SUBSCRIPTION_ID, + location=LOCATION, + _mgmt_client=mock_mgmt_client, + ) + assert_all_required_params(ws) + + ws = Workspace( + name=WORKSPACE, + resource_group=RESOURCE_GROUP, + _mgmt_client=mock_mgmt_client, + ) + assert_all_required_params(ws) + + ws = Workspace( + name=WORKSPACE, + resource_group=RESOURCE_GROUP, + location=LOCATION, + _mgmt_client=mock_mgmt_client, + ) + assert_all_required_params(ws) + def test_create_workspace_locations(self): - # User-provided location name should be normalized + # Location name should be normalized location = "East US" + mock_mgmt_client = mock.MagicMock() + def mock_load_workspace_from_arm(connection_params): + connection_params.location = location + connection_params.quantum_endpoint = ENDPOINT_URI + mock_mgmt_client.load_workspace_from_arm = mock_load_workspace_from_arm + ws = Workspace( + name=WORKSPACE, subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, - name=WORKSPACE, - location=location, + _mgmt_client=mock_mgmt_client, ) + self.assertEqual(ws.location, "eastus") def test_env_connection_string(self): @@ -110,6 +184,7 @@ def test_env_connection_string(self): id(workspace.credential)) def test_workspace_from_connection_string(self): + mock_mgmt_client = self.create_mock_mgmt_client() with mock.patch.dict( os.environ, clear=True @@ -136,7 +211,7 @@ def test_workspace_from_connection_string(self): wrong_subscription_id = "00000000-2BAD-2BAD-2BAD-000000000000" wrong_resource_group = "wrongrg" wrong_workspace = "wrong-workspace" - wrong_location = "wrong-location" + wrong_location = "westus" # make sure the values above are really different from the default values self.assertNotEqual(wrong_subscription_id, SUBSCRIPTION_ID) @@ -168,7 +243,10 @@ def test_workspace_from_connection_string(self): self.assertIsInstance(workspace.credential, AzureKeyCredential) # if we pass a credential, then it should be used - workspace = Workspace(credential=EnvironmentCredential()) + os.environ[EnvironmentVariables.AZURE_CLIENT_ID] = ZERO_UID + os.environ[EnvironmentVariables.AZURE_TENANT_ID] = ZERO_UID + os.environ[EnvironmentVariables.AZURE_CLIENT_SECRET] = PLACEHOLDER + workspace = Workspace(credential=EnvironmentCredential(), _mgmt_client=mock_mgmt_client) self.assertIsInstance(workspace.credential, EnvironmentCredential) # the connection string passed as a parameter should override the @@ -204,6 +282,70 @@ def test_workspace_from_connection_string(self): self.assertEqual(workspace.subscription_id, SUBSCRIPTION_ID) self.assertEqual(workspace.resource_group, RESOURCE_GROUP) self.assertEqual(workspace.name, WORKSPACE) + + def test_workspace_from_connection_string_v2(self): + """Test that v2 QuantumEndpoint format is correctly parsed.""" + with mock.patch.dict( + os.environ, + clear=True + ): + workspace = Workspace.from_connection_string(SIMPLE_CONNECTION_STRING_V2) + self.assertEqual(workspace.location, LOCATION) + self.assertEqual(workspace.subscription_id, SUBSCRIPTION_ID) + self.assertEqual(workspace.resource_group, RESOURCE_GROUP) + self.assertEqual(workspace.name, WORKSPACE) + self.assertIsInstance(workspace.credential, AzureKeyCredential) + self.assertEqual(workspace.credential.key, API_KEY) + # pylint: disable=protected-access + self.assertIsInstance( + workspace._client._config.authentication_policy, + AzureKeyCredentialPolicy) + auth_policy = workspace._client._config.authentication_policy + self.assertEqual(auth_policy._name, ConnectionConstants.QUANTUM_API_KEY_HEADER) + self.assertEqual(id(auth_policy._credential), + id(workspace.credential)) + + def test_workspace_from_connection_string_v2_dogfood(self): + """Test v2 QuantumEndpoint with dogfood environment.""" + canary_location = "eastus2euap" + dogfood_connection_string_v2 = ConnectionConstants.VALID_CONNECTION_STRING( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + workspace_name=WORKSPACE, + api_key=API_KEY, + quantum_endpoint=ConnectionConstants.GET_QUANTUM_DOGFOOD_ENDPOINT_v2(canary_location) + ) + + with mock.patch.dict(os.environ, clear=True): + workspace = Workspace.from_connection_string(dogfood_connection_string_v2) + self.assertEqual(workspace.location, canary_location) + self.assertEqual(workspace.subscription_id, SUBSCRIPTION_ID) + self.assertEqual(workspace.resource_group, RESOURCE_GROUP) + self.assertEqual(workspace.name, WORKSPACE) + self.assertIsInstance(workspace.credential, AzureKeyCredential) + self.assertEqual(workspace.credential.key, API_KEY) + + def test_env_connection_string_v2(self): + """Test v2 QuantumEndpoint from environment variable.""" + with mock.patch.dict(os.environ): + self.clear_env_vars(os.environ) + os.environ[EnvironmentVariables.CONNECTION_STRING] = SIMPLE_CONNECTION_STRING_V2 + + workspace = Workspace() + self.assertEqual(workspace.location, LOCATION) + self.assertEqual(workspace.subscription_id, SUBSCRIPTION_ID) + self.assertEqual(workspace.name, WORKSPACE) + self.assertEqual(workspace.resource_group, RESOURCE_GROUP) + self.assertIsInstance(workspace.credential, AzureKeyCredential) + self.assertEqual(workspace.credential.key, API_KEY) + # pylint: disable=protected-access + self.assertIsInstance( + workspace._client._config.authentication_policy, + AzureKeyCredentialPolicy) + auth_policy = workspace._client._config.authentication_policy + self.assertEqual(auth_policy._name, ConnectionConstants.QUANTUM_API_KEY_HEADER) + self.assertEqual(id(auth_policy._credential), + id(workspace.credential)) def test_create_workspace_instance_invalid(self): def assert_value_error(exception): @@ -213,48 +355,20 @@ def assert_value_error(exception): with mock.patch.dict(os.environ): self.clear_env_vars(os.environ) - # missing location + # missing workspace name with self.assertRaises(ValueError) as context: Workspace( - location=None, subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, - name=WORKSPACE, - ) - assert_value_error(context.exception) - - # missing location - with self.assertRaises(ValueError) as context: - Workspace(resource_id=SIMPLE_RESOURCE_ID) - assert_value_error(context.exception) - - # missing subscription id - with self.assertRaises(ValueError) as context: - Workspace( - location=LOCATION, - subscription_id=None, - resource_group=RESOURCE_GROUP, - name=WORKSPACE - ) - assert_value_error(context.exception) - - # missing resource group - with self.assertRaises(ValueError) as context: - Workspace( - location=LOCATION, - subscription_id=SUBSCRIPTION_ID, - resource_group=None, - name=WORKSPACE + name=None ) assert_value_error(context.exception) - # missing workspace name + # provide only subscription id and resource group with self.assertRaises(ValueError) as context: Workspace( - location=LOCATION, subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, - name=None ) assert_value_error(context.exception) @@ -266,7 +380,6 @@ def assert_value_error(exception): # invalid resource id with self.assertRaises(ValueError) as context: Workspace( - location=LOCATION, resource_id="invalid/resource/id") self.assertIn("Invalid resource id", context.exception.args[0]) @@ -353,42 +466,38 @@ def test_workspace_user_agent_appid(self): # no UserAgent parameter and no EnvVar AppId os.environ[EnvironmentVariables.USER_AGENT_APPID] = "" - ws = Workspace( + ws = self.create_workspace_with_params( subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION ) self.assertIsNone(ws.user_agent) # no UserAgent parameter and with EnvVar AppId os.environ[EnvironmentVariables.USER_AGENT_APPID] = app_id - ws = Workspace( + ws = self.create_workspace_with_params( subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION ) self.assertEqual(ws.user_agent, app_id) # with UserAgent parameter and no EnvVar AppId os.environ[EnvironmentVariables.USER_AGENT_APPID] = "" - ws = Workspace( + ws = self.create_workspace_with_params( subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION, user_agent=user_agent ) self.assertEqual(ws.user_agent, user_agent) # with UserAgent parameter and EnvVar AppId os.environ[EnvironmentVariables.USER_AGENT_APPID] = app_id - ws = Workspace( + ws = self.create_workspace_with_params( subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION, user_agent=user_agent ) self.assertEqual(ws.user_agent, @@ -396,11 +505,10 @@ def test_workspace_user_agent_appid(self): # Append with UserAgent parameter and with EnvVar AppId os.environ[EnvironmentVariables.USER_AGENT_APPID] = app_id - ws = Workspace( + ws = self.create_workspace_with_params( subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION, user_agent=user_agent ) ws.append_user_agent("featurex") @@ -412,11 +520,61 @@ def test_workspace_user_agent_appid(self): # Append with no UserAgent parameter and no EnvVar AppId os.environ[EnvironmentVariables.USER_AGENT_APPID] = "" - ws = Workspace( + ws = self.create_workspace_with_params( subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE, - location=LOCATION ) ws.append_user_agent("featurex") self.assertEqual(ws.user_agent, "featurex") + + def test_workspace_context_manager(self): + """Test that Workspace can be used as a context manager""" + mock_mgmt_client = self.create_mock_mgmt_client() + + # Test with statement + with Workspace( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + name=WORKSPACE, + _mgmt_client=mock_mgmt_client, + ) as ws: + # Verify workspace is properly initialized + self.assertEqual(ws.subscription_id, SUBSCRIPTION_ID) + self.assertEqual(ws.resource_group, RESOURCE_GROUP) + self.assertEqual(ws.name, WORKSPACE) + self.assertEqual(ws.location, LOCATION) + + # Verify internal clients are accessible + self.assertIsNotNone(ws._client) + self.assertIsNotNone(ws._mgmt_client) + + def test_workspace_context_manager_calls_enter_exit(self): + """Test that __enter__ and __exit__ are called on internal clients""" + mock_mgmt_client = self.create_mock_mgmt_client() + + ws = Workspace( + subscription_id=SUBSCRIPTION_ID, + resource_group=RESOURCE_GROUP, + name=WORKSPACE, + _mgmt_client=mock_mgmt_client, + ) + + # Mock the internal clients' __enter__ and __exit__ methods + ws._client.__enter__ = mock.MagicMock(return_value=ws._client) + ws._client.__exit__ = mock.MagicMock(return_value=None) + ws._mgmt_client.__enter__ = mock.MagicMock(return_value=ws._mgmt_client) + ws._mgmt_client.__exit__ = mock.MagicMock(return_value=None) + + # Use workspace as context manager + with ws as context_ws: + # Verify __enter__ was called on both clients + ws._client.__enter__.assert_called_once() + ws._mgmt_client.__enter__.assert_called_once() + + # Verify context manager returns the workspace instance + self.assertIs(context_ws, ws) + + # Verify __exit__ was called on both clients after exiting context + ws._client.__exit__.assert_called_once() + ws._mgmt_client.__exit__.assert_called_once() diff --git a/azure-quantum/tests/unit/test_workspace_connection_params_validation.py b/azure-quantum/tests/unit/test_workspace_connection_params_validation.py new file mode 100644 index 000000000..08c5a7cb0 --- /dev/null +++ b/azure-quantum/tests/unit/test_workspace_connection_params_validation.py @@ -0,0 +1,322 @@ +## +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +## + +import unittest +from azure.quantum._workspace_connection_params import WorkspaceConnectionParams + + +class TestWorkspaceConnectionParamsValidation(unittest.TestCase): + """Test validation of WorkspaceConnectionParams fields.""" + + def test_valid_subscription_ids(self): + """Test that valid subscription_ids are accepted.""" + valid_ids = [ + "12345678-1234-1234-1234-123456789abc", + "ABCDEF01-2345-6789-ABCD-EF0123456789", + "abcdef01-2345-6789-abcd-ef0123456789", + ] + for subscription_id in valid_ids: + params = WorkspaceConnectionParams(subscription_id=subscription_id) + self.assertEqual(params.subscription_id, subscription_id) + + def test_invalid_subscription_ids(self): + """Test that invalid subscription_ids raise ValueError.""" + invalid_ids = [ + ("not-a-guid", "Subscription ID must be a valid GUID."), + (12345, "Subscription ID must be a string."), + ] + for subscription_id, expected_message in invalid_ids: + with self.assertRaises(ValueError) as context: + WorkspaceConnectionParams(subscription_id=subscription_id) + self.assertIn(expected_message, str(context.exception)) + + def test_valid_resource_groups(self): + """Test that valid resource_groups are accepted.""" + valid_groups = [ + "my-resource-group", + "MyResourceGroup", + "resource_group_123", + "rg123", + "a" * 90, # Max length (90 chars) + "a", # Min length (1 char) + "Resource_Group-1", + "my.resource.group", # Periods allowed (except at end) + "group(test)", # Parentheses allowed + "group(test)name", + "(parentheses)", + "test-group_name", + "GROUP-123", + "123-group", + "Test.Group.Name", + "my-group.v2", + "rg_test(prod)-v1.2", + "café", # Unicode letters (Lo) + "日本語", # Unicode letters (Lo) + "Казан", # Unicode letters (Lu, Ll) + "αβγ", # Greek letters (Ll) + "test-café-123", # Mixed ASCII and Unicode + "group_名前", # Mixed ASCII and Unicode + "test.group(1)-name_v2", # Multiple special chars + ] + for resource_group in valid_groups: + params = WorkspaceConnectionParams(resource_group=resource_group) + self.assertEqual(params.resource_group, resource_group) + + def test_invalid_resource_groups(self): + """Test that invalid resource_groups raise ValueError.""" + rg_invalid_chars_msg = "Resource group name can only include alphanumeric, underscore, parentheses, hyphen, period (except at end), and Unicode characters that match the allowed characters." + invalid_groups = [ + ("my/resource/group", rg_invalid_chars_msg), + ("my\\resource\\group", rg_invalid_chars_msg), + ("my resource group", rg_invalid_chars_msg), + (12345, "Resource group name must be a string."), + ("group.", rg_invalid_chars_msg), # Period at end + ("my-group.", rg_invalid_chars_msg), # Period at end + ("test.group.", rg_invalid_chars_msg), # Period at end + ("a" * 91, "Resource group name must be between 1 and 90 characters long."), # Too long + ("group@test", rg_invalid_chars_msg), # @ symbol + ("group#test", rg_invalid_chars_msg), # # symbol + ("group$test", rg_invalid_chars_msg), # $ symbol + ("group%test", rg_invalid_chars_msg), # % symbol + ("group^test", rg_invalid_chars_msg), # ^ symbol + ("group&test", rg_invalid_chars_msg), # & symbol + ("group*test", rg_invalid_chars_msg), # * symbol + ("group+test", rg_invalid_chars_msg), # + symbol + ("group=test", rg_invalid_chars_msg), # = symbol + ("group[test]", rg_invalid_chars_msg), # Square brackets + ("group{test}", rg_invalid_chars_msg), # Curly brackets + ("group|test", rg_invalid_chars_msg), # Pipe + ("group:test", rg_invalid_chars_msg), # Colon + ("group;test", rg_invalid_chars_msg), # Semicolon + ("group\"test", rg_invalid_chars_msg), # Quote + ("group'test", rg_invalid_chars_msg), # Single quote + ("group", rg_invalid_chars_msg), # Angle brackets + ("group,test", rg_invalid_chars_msg), # Comma + ("group?test", rg_invalid_chars_msg), # Question mark + ("group!test", rg_invalid_chars_msg), # Exclamation mark + ("group`test", rg_invalid_chars_msg), # Backtick + ("group~test", rg_invalid_chars_msg), # Tilde + ("test\ngroup", rg_invalid_chars_msg), # Newline + ("test\tgroup", rg_invalid_chars_msg), # Tab + ] + for resource_group, expected_message in invalid_groups: + with self.assertRaises(ValueError) as context: + WorkspaceConnectionParams(resource_group=resource_group) + self.assertIn(expected_message, str(context.exception)) + + def test_empty_resource_group(self): + """Test that empty resource_group is treated as None (not set).""" + # Empty strings are treated as falsy in the merge logic and not set + params = WorkspaceConnectionParams(resource_group="") + self.assertIsNone(params.resource_group) + + def test_valid_workspace_names(self): + """Test that valid workspace names are accepted.""" + valid_names = [ + "12", + "a1", + "1a", + "ab", + "myworkspace", + "WORKSPACE", + "MyWorkspace", + "myWorkSpace", + "myworkspacE", + "1234567890", + "123workspace", + "workspace123", + "w0rksp4c3", + "123abc456def", + "abc123", + # with hyphens + "my-workspace", + "my-work-space", + "workspace-with-a-long-name-that-is-still-valid", + "a-b-c-d-e", + "my-workspace-2", + "workspace-1-2-3", + "1-a", + "b-2", + "1-2", + "a-b", + "1-b-2", + "a-1-b", + "workspace" + "-" * 10 + "test", + "a" * 54, # Max length (54 chars) + "1" * 54, # Max length with numbers + ] + for workspace_name in valid_names: + params = WorkspaceConnectionParams(workspace_name=workspace_name) + self.assertEqual(params.workspace_name, workspace_name) + + def test_invalid_workspace_names(self): + """Test that invalid workspace names raise ValueError.""" + not_valid_names = [ + ("a", "Workspace name must be between 2 and 54 characters long."), + ("1", "Workspace name must be between 2 and 54 characters long."), + ("a" * 55, "Workspace name must be between 2 and 54 characters long."), + ("1" * 55, "Workspace name must be between 2 and 54 characters long."), + ("my_workspace", "Workspace name can only include alphanumerics (a-zA-Z0-9) and hyphens, and cannot start or end with hyphen."), + ("my/workspace", "Workspace name can only include alphanumerics (a-zA-Z0-9) and hyphens, and cannot start or end with hyphen."), + ("my workspace", "Workspace name can only include alphanumerics (a-zA-Z0-9) and hyphens, and cannot start or end with hyphen."), + ("-myworkspace", "Workspace name can only include alphanumerics (a-zA-Z0-9) and hyphens, and cannot start or end with hyphen."), + ("myworkspace-", "Workspace name can only include alphanumerics (a-zA-Z0-9) and hyphens, and cannot start or end with hyphen."), + (12345, "Workspace name must be a string."), + ] + for workspace_name, expected_message in not_valid_names: + with self.assertRaises(ValueError) as context: + WorkspaceConnectionParams(workspace_name=workspace_name) + self.assertIn(expected_message, str(context.exception)) + + def test_empty_workspace_name(self): + """Test that empty workspace_name is treated as None (not set).""" + # Empty strings are treated as falsy in the merge logic and not set + params = WorkspaceConnectionParams(workspace_name="") + self.assertIsNone(params.workspace_name) + + def test_valid_locations(self): + """Test that valid locations are accepted and normalized.""" + valid_locations = [ + ("East US", "eastus"), + ("West Europe", "westeurope"), + ("eastus", "eastus"), + ("westus2", "westus2"), + ("EASTUS", "eastus"), + ("WestUs2", "westus2"), + ("South Central US", "southcentralus"), + ("North Europe", "northeurope"), + ("Southeast Asia", "southeastasia"), + ("Japan East", "japaneast"), + ("UK South", "uksouth"), + ("Australia East", "australiaeast"), + ("Central India", "centralindia"), + ("France Central", "francecentral"), + ("Germany West Central", "germanywestcentral"), + ("Switzerland North", "switzerlandnorth"), + ("UAE North", "uaenorth"), + ("Brazil South", "brazilsouth"), + ("Korea Central", "koreacentral"), + ("South Africa North", "southafricanorth"), + ("Norway East", "norwayeast"), + ("Sweden Central", "swedencentral"), + ("Qatar Central", "qatarcentral"), + ("Poland Central", "polandcentral"), + ("Italy North", "italynorth"), + ("Israel Central", "israelcentral"), + ("Spain Central", "spaincentral"), + ("Austria East", "austriaeast"), + ("Belgium Central", "belgiumcentral"), + ("Chile Central", "chilecentral"), + ("Indonesia Central", "indonesiacentral"), + ("Malaysia West", "malaysiawest"), + ("Mexico Central", "mexicocentral"), + ("New Zealand North", "newzealandnorth"), + ("westus3", "westus3"), + ("canadacentral", "canadacentral"), + ("westcentralus", "westcentralus"), + ] + for location, expected in valid_locations: + params = WorkspaceConnectionParams(location=location) + self.assertEqual(params.location, expected) + + def test_invalid_locations(self): + """Test that invalid locations raise ValueError.""" + location_invalid_region_msg = "Location must be one of the Azure regions listed in https://learn.microsoft.com/en-us/azure/reliability/regions-list." + invalid_locations = [ + (" ", location_invalid_region_msg), + ("invalid-region", location_invalid_region_msg), + ("us-east", location_invalid_region_msg), + ("east-us", location_invalid_region_msg), + ("westus4", location_invalid_region_msg), + ("southus", location_invalid_region_msg), + ("centraleurope", location_invalid_region_msg), + ("asiaeast", location_invalid_region_msg), + ("chinaeast", location_invalid_region_msg), + ("usgovtexas", location_invalid_region_msg), + ("East US 3", location_invalid_region_msg), + ("not a region", location_invalid_region_msg), + (12345, "Location must be a string."), + (3.14, "Location must be a string."), + (True, "Location must be a string."), + ] + for location, expected_message in invalid_locations: + with self.assertRaises(ValueError) as context: + WorkspaceConnectionParams(location=location) + self.assertIn(expected_message, str(context.exception)) + + def test_empty_location(self): + """Test that empty location is treated as None (not set).""" + # Empty strings are treated as falsy in the merge logic and not set + params = WorkspaceConnectionParams(location="") + self.assertIsNone(params.location) + + # None is also allowed and treated as not set + params = WorkspaceConnectionParams(location=None) + self.assertIsNone(params.location) + + def test_none_values_are_allowed(self): + """Test that None values for optional fields are allowed.""" + # This should not raise any exceptions + params = WorkspaceConnectionParams( + subscription_id=None, + resource_group=None, + workspace_name=None, + location=None, + user_agent=None, + ) + self.assertIsNone(params.subscription_id) + self.assertIsNone(params.resource_group) + self.assertIsNone(params.workspace_name) + self.assertIsNone(params.location) + self.assertIsNone(params.user_agent) + + def test_multiple_valid_parameters(self): + """Test that multiple valid parameters work together.""" + params = WorkspaceConnectionParams( + subscription_id="12345678-1234-1234-1234-123456789abc", + resource_group="my-resource-group", + workspace_name="my-workspace", + location="East US", + user_agent="my-app/1.0", + ) + self.assertEqual(params.subscription_id, "12345678-1234-1234-1234-123456789abc") + self.assertEqual(params.resource_group, "my-resource-group") + self.assertEqual(params.workspace_name, "my-workspace") + self.assertEqual(params.location, "eastus") + self.assertEqual(params.user_agent, "my-app/1.0") + + def test_validation_on_resource_id(self): + """Test that validation works when using resource_id.""" + # Valid resource_id should work + resource_id = ( + "/subscriptions/12345678-1234-1234-1234-123456789abc" + "/resourceGroups/my-rg" + "/providers/Microsoft.Quantum" + "/Workspaces/my-ws" + ) + params = WorkspaceConnectionParams(resource_id=resource_id) + self.assertEqual(params.subscription_id, "12345678-1234-1234-1234-123456789abc") + self.assertEqual(params.resource_group, "my-rg") + self.assertEqual(params.workspace_name, "my-ws") + + def test_validation_on_connection_string(self): + """Test that validation works when using connection_string.""" + # Valid connection string should work + connection_string = ( + "SubscriptionId=12345678-1234-1234-1234-123456789abc;" + "ResourceGroupName=my-rg;" + "WorkspaceName=my-ws;" + "ApiKey=test-key;" + "QuantumEndpoint=https://eastus.quantum.azure.com/;" + ) + params = WorkspaceConnectionParams(connection_string=connection_string) + self.assertEqual(params.subscription_id, "12345678-1234-1234-1234-123456789abc") + self.assertEqual(params.resource_group, "my-rg") + self.assertEqual(params.workspace_name, "my-ws") + self.assertEqual(params.location, "eastus") + + +if __name__ == "__main__": + unittest.main() diff --git a/samples/hello-world/HW-ionq-cirq.ipynb b/samples/hello-world/HW-ionq-cirq.ipynb index 2faa9a26f..dc58ab815 100644 --- a/samples/hello-world/HW-ionq-cirq.ipynb +++ b/samples/hello-world/HW-ionq-cirq.ipynb @@ -49,10 +49,7 @@ "from azure.quantum import Workspace\n", "from azure.quantum.cirq import AzureQuantumService\n", "\n", - "workspace = Workspace(\n", - " resource_id = \"\",\n", - " location = \"\",\n", - ")\n", + "workspace = Workspace(resource_id = \"\")\n", "\n", "service = AzureQuantumService(workspace)" ] diff --git a/samples/hello-world/HW-ionq-qiskit.ipynb b/samples/hello-world/HW-ionq-qiskit.ipynb index f7153274d..fdb9af103 100644 --- a/samples/hello-world/HW-ionq-qiskit.ipynb +++ b/samples/hello-world/HW-ionq-qiskit.ipynb @@ -49,10 +49,7 @@ "from azure.quantum import Workspace\n", "from azure.quantum.qiskit import AzureQuantumProvider\n", "\n", - "workspace = Workspace(\n", - " resource_id = \"\",\n", - " location = \"\",\n", - ")\n", + "workspace = Workspace(resource_id = \"\")\n", "\n", "provider = AzureQuantumProvider(workspace)" ] diff --git a/samples/hello-world/HW-ionq-qsharp.ipynb b/samples/hello-world/HW-ionq-qsharp.ipynb index e01c801b7..9a5d3c834 100644 --- a/samples/hello-world/HW-ionq-qsharp.ipynb +++ b/samples/hello-world/HW-ionq-qsharp.ipynb @@ -52,10 +52,7 @@ "source": [ "from azure.quantum import Workspace\n", "\n", - "workspace = Workspace (\n", - " resource_id = \"\",\n", - " location = \"\"\n", - ")" + "workspace = Workspace (resource_id = \"\")" ] }, { diff --git a/samples/hello-world/HW-pasqal-pulser.ipynb b/samples/hello-world/HW-pasqal-pulser.ipynb index 124a81d14..7909be8f3 100644 --- a/samples/hello-world/HW-pasqal-pulser.ipynb +++ b/samples/hello-world/HW-pasqal-pulser.ipynb @@ -50,12 +50,9 @@ "source": [ "from azure.quantum import Workspace\n", "\n", - "# Your `resource_id` and `location` should be available on the Overview page of your Quantum Workspace.\n", + "# Your `resource_id` should be available on the Overview page of your Quantum Workspace.\n", "\n", - "workspace = Workspace(\n", - " resource_id = \"\",\n", - " location = \"\",\n", - ")" + "workspace = Workspace(resource_id = \"\")" ] }, { diff --git a/samples/hello-world/HW-quantinuum-cirq.ipynb b/samples/hello-world/HW-quantinuum-cirq.ipynb index e09dcf4a2..dae83401e 100644 --- a/samples/hello-world/HW-quantinuum-cirq.ipynb +++ b/samples/hello-world/HW-quantinuum-cirq.ipynb @@ -59,10 +59,7 @@ "from azure.quantum import Workspace\n", "from azure.quantum.cirq import AzureQuantumService\n", "\n", - "workspace = Workspace(\n", - " resource_id = \"\",\n", - " location = \"\",\n", - ")\n", + "workspace = Workspace(resource_id = \"\")\n", "\n", "service = AzureQuantumService(workspace)" ] diff --git a/samples/hello-world/HW-quantinuum-qiskit.ipynb b/samples/hello-world/HW-quantinuum-qiskit.ipynb index dc97796f6..94d3f23f8 100644 --- a/samples/hello-world/HW-quantinuum-qiskit.ipynb +++ b/samples/hello-world/HW-quantinuum-qiskit.ipynb @@ -53,10 +53,7 @@ "from azure.quantum import Workspace\n", "from azure.quantum.qiskit import AzureQuantumProvider\n", "\n", - "workspace = Workspace(\n", - " resource_id = \"\",\n", - " location = \"\",\n", - ")\n", + "workspace = Workspace(resource_id = \"\")\n", "\n", "provider = AzureQuantumProvider(workspace)" ] diff --git a/samples/hello-world/HW-quantinuum-qsharp.ipynb b/samples/hello-world/HW-quantinuum-qsharp.ipynb index 70be03086..2c1ede6ca 100644 --- a/samples/hello-world/HW-quantinuum-qsharp.ipynb +++ b/samples/hello-world/HW-quantinuum-qsharp.ipynb @@ -62,10 +62,7 @@ "source": [ "from azure.quantum import Workspace\n", "\n", - "workspace = Workspace (\n", - " resource_id = \"\",\n", - " location = \"\"\n", - ")" + "workspace = Workspace (resource_id = \"\")" ] }, { diff --git a/samples/hello-world/HW-rigetti-qiskit.ipynb b/samples/hello-world/HW-rigetti-qiskit.ipynb index fe36e70a9..b09a59b8f 100644 --- a/samples/hello-world/HW-rigetti-qiskit.ipynb +++ b/samples/hello-world/HW-rigetti-qiskit.ipynb @@ -53,10 +53,7 @@ "from azure.quantum import Workspace\n", "from azure.quantum.qiskit import AzureQuantumProvider\n", "\n", - "workspace = Workspace(\n", - " resource_id = \"\",\n", - " location = \"\",\n", - ")\n", + "workspace = Workspace(resource_id = \"\")\n", "\n", "provider = AzureQuantumProvider(workspace)" ] diff --git a/samples/hello-world/HW-rigetti-qsharp.ipynb b/samples/hello-world/HW-rigetti-qsharp.ipynb index 8f08adfd3..067813199 100644 --- a/samples/hello-world/HW-rigetti-qsharp.ipynb +++ b/samples/hello-world/HW-rigetti-qsharp.ipynb @@ -52,10 +52,7 @@ "source": [ "from azure.quantum import Workspace\n", "\n", - "workspace = Workspace (\n", - " resource_id = \"\",\n", - " location = \"\"\n", - ")" + "workspace = Workspace (resource_id = \"\")" ] }, { diff --git a/samples/hidden-shift/hidden-shift.ipynb b/samples/hidden-shift/hidden-shift.ipynb index b63080557..893189a04 100644 --- a/samples/hidden-shift/hidden-shift.ipynb +++ b/samples/hidden-shift/hidden-shift.ipynb @@ -71,10 +71,7 @@ "from azure.quantum import Workspace\n", "from azure.quantum.qiskit import AzureQuantumProvider\n", "\n", - "workspace = Workspace(\n", - " resource_id = \"\",\n", - " location = \"\",\n", - ")\n", + "workspace = Workspace(resource_id = \"\")\n", "\n", "provider = AzureQuantumProvider(workspace)" ] diff --git a/samples/quantum-approximation-optimization/qaoa-for-quadratic-unconstrained-binary-optimization.ipynb b/samples/quantum-approximation-optimization/qaoa-for-quadratic-unconstrained-binary-optimization.ipynb index 188b149b6..27f7cac6a 100644 --- a/samples/quantum-approximation-optimization/qaoa-for-quadratic-unconstrained-binary-optimization.ipynb +++ b/samples/quantum-approximation-optimization/qaoa-for-quadratic-unconstrained-binary-optimization.ipynb @@ -654,17 +654,14 @@ { "metadata": {}, "cell_type": "markdown", - "source": "Please set your `resource_id` and `location` below." + "source": "Please set your `resource_id` below." }, { "cell_type": "code", "metadata": {}, "source": [ "from azure.quantum import Workspace\n", - "workspace = Workspace(\n", - " resource_id = \"\",\n", - " location = \"\",\n", - ")\n", + "workspace = Workspace(resource_id = \"\")\n", "target = workspace.get_targets(name=\"pasqal.sim.emu-tn\")\n", "\n", "if isinstance(target, list):\n", @@ -788,4 +785,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/samples/quantum-signal-processing/signal-processing.ipynb b/samples/quantum-signal-processing/signal-processing.ipynb index e4d7ff0ba..3e68fbd95 100644 --- a/samples/quantum-signal-processing/signal-processing.ipynb +++ b/samples/quantum-signal-processing/signal-processing.ipynb @@ -313,9 +313,7 @@ "from azure.quantum import Workspace\n", "from azure.quantum.qiskit import AzureQuantumProvider\n", "\n", - "workspace = Workspace(\n", - " resource_id = \"/subscriptions/677fc922-91d0-4bf6-9b06-4274d319a0fa/resourceGroups/xiou/providers/Microsoft.Quantum/Workspaces/xiou-notebooks-demo\",\n", - " location = \"eastus2euap\")\n", + "workspace = Workspace(resource_id = \"/subscriptions/677fc922-91d0-4bf6-9b06-4274d319a0fa/resourceGroups/xiou/providers/Microsoft.Quantum/Workspaces/xiou-notebooks-demo\")\n", "\n", "\n", "provider = AzureQuantumProvider(workspace)" diff --git a/samples/sessions/introduction-to-sessions.ipynb b/samples/sessions/introduction-to-sessions.ipynb index 1cabbc422..e4ce1a4a0 100644 --- a/samples/sessions/introduction-to-sessions.ipynb +++ b/samples/sessions/introduction-to-sessions.ipynb @@ -57,10 +57,7 @@ "from azure.quantum import Workspace\n", "from azure.quantum.qiskit import AzureQuantumProvider\n", "\n", - "workspace = Workspace(\n", - " resource_id = \"\",\n", - " location = \"\",\n", - ")\n", + "workspace = Workspace(resource_id = \"\")\n", "\n", "provider = AzureQuantumProvider(workspace)\n", "\n", diff --git a/samples/vqe/VQE-qiskit-hydrogen-session.ipynb b/samples/vqe/VQE-qiskit-hydrogen-session.ipynb index 324047dca..87a1ae1f0 100644 --- a/samples/vqe/VQE-qiskit-hydrogen-session.ipynb +++ b/samples/vqe/VQE-qiskit-hydrogen-session.ipynb @@ -258,10 +258,7 @@ "from azure.quantum import Workspace\n", "from azure.quantum.qiskit import AzureQuantumProvider\n", "\n", - "workspace = Workspace(\n", - " resource_id = \"\",\n", - " location = \"\",\n", - ")\n", + "workspace = Workspace(resource_id = \"\")\n", "\n", "# Connect to the Azure Quantum workspace via a Qiskit provider\n", "provider = AzureQuantumProvider(workspace)"