Source code for qbraid.providers.aws.provider

# Copyright (C) 2023 qBraid
#
# This file is part of the qBraid-SDK
#
# The qBraid-SDK is free software released under the GNU General Public License v3
# or later. You can redistribute and/or modify it under the terms of the GPL v3.
# See the LICENSE file in the project root or <https://www.gnu.org/licenses/gpl-3.0.html>.
#
# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3.

"""
Module for configuring provider credentials and authentication.

"""

import os
from typing import TYPE_CHECKING, List, Optional

import boto3
from boto3.session import Session
from braket.aws import AwsDevice, AwsSession
from qbraid_core.services.quantum import quantum_lib_proxy_state

from qbraid.exceptions import QbraidError
from qbraid.providers.provider import QuantumProvider

if TYPE_CHECKING:
    import braket.aws


class BraketProvider(QuantumProvider):
    """
    This class is responsible for managing the interactions and
    authentications with the AWS services.

    Attributes:
        aws_access_key_id (str): AWS access key ID for authenticating with AWS services.
        aws_secret_access_key (str): AWS secret access key for authenticating with AWS services.
    """

    # Region names recognized when parsing a device ARN, and the regions
    # searched by default in ``get_tasks_by_tag``.
    REGIONS = ("us-east-1", "us-west-1", "us-west-2", "eu-west-2")

    # Region returned when no default region can be resolved.
    _FALLBACK_REGION = "us-east-1"

    # Region used for a device whose ARN carries no recognized region,
    # keyed by the provider segment of the ARN path.
    _PROVIDER_REGIONS = {
        "ionq": "us-east-1",
        "quera": "us-east-1",
        "xanadu": "us-east-1",
        "oqc": "eu-west-2",
        "rigetti": "us-west-1",
    }

    def __init__(
        self,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
    ):
        """
        Initializes the BraketProvider object with optional AWS credentials.

        Each credential that is not supplied (or is empty) is read from the
        corresponding environment variable instead, and is left as ``None``
        if that variable is unset.

        Args:
            aws_access_key_id (str, optional): AWS access key ID. Defaults to None,
                in which case ``AWS_ACCESS_KEY_ID`` is read from the environment.
            aws_secret_access_key (str, optional): AWS secret access key. Defaults to
                None, in which case ``AWS_SECRET_ACCESS_KEY`` is read from the environment.
        """
        self.aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
        self.aws_secret_access_key = aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")

    def save_config(self):
        """Save the current configuration.

        Raises:
            NotImplementedError: Always; saving is not implemented for this provider.
        """
        raise NotImplementedError

    @staticmethod
    def _get_default_region() -> str:
        """Returns the default AWS region.

        Returns:
            str: The region of a default boto3 session, or ``"us-east-1"`` when the
            session has no region or cannot be created.
        """
        try:
            # ``region_name`` can be None without raising when no region is
            # configured, so the fallback has to be applied here as well.
            return Session().region_name or BraketProvider._FALLBACK_REGION
        except Exception:  # pylint: disable=broad-exception-caught
            return BraketProvider._FALLBACK_REGION

    @staticmethod
    def _get_s3_default_bucket() -> str:
        """Return name of the default bucket to use in the AWS Session.

        Returns:
            str: The default bucket reported by a default ``AwsSession``, or
            ``"amazon-braket-qbraid-provider"`` if that lookup fails for any reason.
        """
        try:
            aws_session = AwsSession()
            return aws_session.default_bucket()
        except Exception:  # pylint: disable=broad-exception-caught
            return "amazon-braket-qbraid-provider"

    def _get_region_name(self, device_arn: str) -> str:
        """Returns the AWS region name for a device ARN.

        The region is taken from the fourth ``:``-separated field of the ARN when
        it is one of ``REGIONS``. Otherwise it is looked up from the provider
        segment (second-to-last ``/``-separated part). If neither yields a region,
        including when the ARN is too short to contain those parts, the default
        region is returned.

        Args:
            device_arn (str): The device ARN.

        Returns:
            str: The resolved AWS region name.
        """
        arn_fields = device_arn.split(":")
        if len(arn_fields) > 3 and arn_fields[3] in self.REGIONS:
            return arn_fields[3]

        path_parts = device_arn.split("/")
        provider = path_parts[-2] if len(path_parts) >= 2 else None
        provider_region = self._PROVIDER_REGIONS.get(provider)
        if provider_region is not None:
            return provider_region

        return self._get_default_region()

    def _get_aws_session(self, region_name: Optional[str] = None) -> "braket.aws.AwsSession":
        """Returns an AwsSession built from this provider's credentials.

        Args:
            region_name (str, optional): Region for the underlying boto3 session.
                Defaults to None, in which case the default region is used.

        Returns:
            braket.aws.AwsSession: Session wrapping a boto3 session created with the
            stored access key and secret, using the default S3 bucket.
        """
        region_name = region_name or self._get_default_region()
        default_bucket = self._get_s3_default_bucket()
        boto_session = Session(
            region_name=region_name,
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
        )
        return AwsSession(boto_session=boto_session, default_bucket=default_bucket)

    def get_devices(
        self,
        aws_session: Optional["braket.aws.AwsSession"] = None,
        statuses: Optional[List[str]] = None,
        **kwargs,
    ) -> List["braket.aws.AwsDevice"]:
        """Return a list of backends matching the specified filtering.

        Args:
            aws_session (braket.aws.AwsSession, optional): Session used for the query.
                Defaults to None, in which case a session for the default region is created.
            statuses (List[str], optional): Device statuses to include. Defaults to
                None, which selects both ``"ONLINE"`` and ``"OFFLINE"``.
            **kwargs: Additional filters forwarded to ``AwsDevice.get_devices``.

        Returns:
            List[braket.aws.AwsDevice]: The devices returned by ``AwsDevice.get_devices``.
        """
        aws_session = self._get_aws_session() if aws_session is None else aws_session
        statuses = ["ONLINE", "OFFLINE"] if statuses is None else statuses
        return AwsDevice.get_devices(aws_session=aws_session, statuses=statuses, **kwargs)

    def get_device(self, device_id: str) -> "braket.aws.AwsDevice":
        """Returns the AWS device.

        Args:
            device_id (str): The device ARN.

        Returns:
            braket.aws.AwsDevice: Device bound to a session in the region resolved
            from the ARN.
        """
        region_name = self._get_region_name(device_id)  # deviceArn
        aws_session = self._get_aws_session(region_name=region_name)
        return AwsDevice(arn=device_id, aws_session=aws_session)

    def get_tasks_by_tag(
        self,
        key: str,
        values: Optional[List[str]] = None,
        region_names: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Retrieves a list of quantum task ARNs that match the specified tag keys or key/value pairs.

        Args:
            key (str): The tag key to match.
            values (Optional[List[str]]): A list of tag values to match against the provided key.
                If None, tasks with the specified key, regardless of its value, are matched.
            region_names (Optional[List[str]]): A list of region names to search.
                If None or empty, all regions in `self.REGIONS` are searched.

        Returns:
            List[str]: A list of ARNs for quantum tasks that match the given tag criteria.

        Raises:
            QbraidError: If the function is called within a qBraid quantum job environment where
                AWS S3 requests are not supported.
        """
        try:
            jobs_state = quantum_lib_proxy_state("braket")
            jobs_enabled = jobs_state.get("enabled", False)
        except ValueError:
            jobs_enabled = False

        if jobs_enabled:
            raise QbraidError("AWS S3 requests not supported by qBraid quantum jobs.")

        search_regions = region_names or list(self.REGIONS)
        tag_filters = [{"Key": key, "Values": values if values is not None else []}]

        tasks: List[str] = []
        for region_name in search_regions:
            client = boto3.client("resourcegroupstaggingapi", region_name=region_name)
            # get_resources returns results one page at a time; walk every page
            # so matches beyond the first page are not dropped.
            paginator = client.get_paginator("get_resources")
            for page in paginator.paginate(TagFilters=tag_filters):
                tasks.extend(
                    mapping["ResourceARN"] for mapping in page["ResourceTagMappingList"]
                )

        return tasks