# Source code for qbraid.transpiler.conversions.cirq.cirq_qasm_parser

# Copyright (C) 2023 qBraid
# Copyright (C) The Cirq Developers
#
# This file is part of the qBraid-SDK.
#
# The qBraid-SDK is free software released under the GNU General Public License v3
# or later. This specific file, adapted from Cirq, is dual-licensed under both the
# Apache License, Version 2.0, and the GPL v3. You may not use this file except in
# compliance with the applicable license. You may obtain a copy of the Apache License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# This file includes code adapted from Cirq (https://github.com/quantumlib/Cirq)
# with modifications by qBraid. The original copyright notice is included above.
# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3.

# qbraid: skip-header
# isort: skip_file
# pylint: skip-file
# flake8: noqa
# fmt: off

"""
Module defining qBraid Cirq QASM parser.

"""
import functools
import operator
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Union, TYPE_CHECKING

import numpy as np
# import sympy
from ply import yacc

from cirq import ops, Circuit, NamedQubit, CX
from cirq.circuits.qasm_output import QasmUGate
from cirq.contrib.qasm_import._lexer import QasmLexer
from cirq.contrib.qasm_import.exception import QasmException

import qbraid.transpiler.conversions.cirq.custom_ops as qbraid_cirq_gates

# Redefined the lexer tokens (4/7/21) to suppress the warning:
#   Token ['IF', 'NE'] defined, but not used
# The list below therefore leaves out IF and NE.
# NOTE(review): this rebinds the `tokens` attribute on the imported QasmLexer
# class itself, so presumably any other user of that class in the same process
# sees this list as well — confirm that is intended.
QasmLexer.tokens = [
    "FORMAT_SPEC",
    "NUMBER",
    "NATURAL_NUMBER",
    "QELIBINC",
    "ID",
    "PI",
    "QREG",
    "CREG",
    "MEASURE",
    "ARROW",
]

if TYPE_CHECKING:
    import cirq


class Qasm:
    """Holds everything produced by parsing one QASM program.

    Attributes:
        supportedFormat: whether the program has a supported format.
        qelib1Include: whether the Quantum Experience standard header
            is present.
        qregs: the quantum registers, as passed in by the caller.
        cregs: the classical registers, as passed in by the caller.
        circuit: the circuit built from the program.
    """

    def __init__(
        self,
        supported_format: bool,
        qelib1_include: bool,
        qregs: dict,
        cregs: dict,
        c: Circuit,
    ):
        # header information: format support and the standard-header flag
        self.supportedFormat, self.qelib1Include = supported_format, qelib1_include
        # declared registers
        self.qregs, self.cregs = qregs, cregs
        # resulting circuit
        self.circuit = c


class QasmGateStatement:
    """Describes how one OpenQASM gate call becomes cirq.GateOperation's.

    An instance checks the number of parameters and quantum arguments of a
    call and, through the 'on' method, produces the matching
    cirq.GateOperation's.
    """

    def __init__(
        self,
        qasm_gate: str,
        cirq_gate: Union[ops.Gate, Callable[[List[float]], ops.Gate]],
        num_params: int,
        num_args: int,
    ):
        """Initializes a Qasm gate statement.
        Args:
            qasm_gate: The symbol of the QASM gate.
            cirq_gate: Either a ready-made cirq gate, or a callable that
                builds one from the list of parameters.
            num_params: The number of params taken by this gate.
            num_args: The number of qubits (used in validation) this gate takes.
        """
        # a gate needs at least one quantum argument to act on
        assert num_args >= 1

        self.qasm_gate = qasm_gate
        self.cirq_gate = cirq_gate
        self.num_params = num_params
        self.num_args = num_args

    def _validate_args(self, args: List[List[ops.Qid]], lineno: int):
        """Raise QasmException unless exactly `num_args` arguments were given."""
        if len(args) == self.num_args:
            return
        raise QasmException(
            f"{self.qasm_gate} only takes {self.num_args} arg(s) (qubits and/or registers), "
            f"got: {len(args)}, at line {lineno}"
        )

    def _validate_params(self, params: List[float], lineno: int):
        """Raise QasmException unless exactly `num_params` parameters were given."""
        if len(params) == self.num_params:
            return
        raise QasmException(
            f"{self.qasm_gate} takes {self.num_params} parameter(s), "
            f"got: {len(params)}, at line {lineno}"
        )

    def on(
        self, params: List[float], args: List[List[ops.Qid]], lineno: int
    ) -> Iterable[ops.Operation]:
        """Lazily yield the operations for one call of this gate.

        Args:
            params: The numeric parameters of the call.
            args: The quantum arguments; each one is a list of qubits, a
                single qubit being a list of length one.
            lineno: Source line of the call, used in error messages.

        Raises:
            QasmException: on a wrong argument or parameter count, on
                registers whose lengths do not match, or when one operation
                would act on the same qubit twice.
        """
        self._validate_args(args, lineno)
        self._validate_params(params, lineno)

        reg_sizes = np.unique([len(register) for register in args])
        # Acceptable: a single size, or two sizes where the smaller is 1.
        sizes_ok = len(reg_sizes) <= 2 and (len(reg_sizes) <= 1 or reg_sizes[0] == 1)
        if not sizes_ok:
            raise QasmException(
                f"Non matching quantum registers of length {reg_sizes} at line {lineno}"
            )

        # The stored gate is either usable as is, or a factory that needs
        # the parameters of this call.
        final_gate: ops.Gate
        if isinstance(self.cirq_gate, ops.Gate):
            final_gate = self.cirq_gate
        else:
            final_gate = self.cirq_gate(params)

        # OpenQASM gates accept single qubits as well as whole registers
        # (https://arxiv.org/abs/1707.03429), and the two may be mixed.
        # Single qubits are registers of size 1 here, so broadcasting the
        # arguments against each other pairs qubit i of every register with
        # the one qubit of each size-1 argument, giving reg_size operations.
        pairwise_broadcast = cast(
            Callable[[List['cirq.Qid'], List['cirq.Qid']], List['cirq.Qid']], np.broadcast
        )
        op_qubits = functools.reduce(pairwise_broadcast, args)

        for target in op_qubits:
            if isinstance(target, ops.Qid):
                yield final_gate.on(target)
                continue
            if len(np.unique(target)) < len(target):
                raise QasmException(f"Overlapping qubits in arguments at line {lineno}")
            yield final_gate.on(*target)


class QasmParser:
    """Cirq Parser for QASM strings.

    The parser is built by PLY from this object: the docstring of every
    ``p_*`` method is a grammar production, so those docstrings are part of
    the program and must not be edited as if they were documentation.

    A parser instance keeps its state (registers, circuit, cached result)
    between calls, and `parse` returns the cached result once one exists, so
    use a fresh instance per program.

    Example:

        qasm = "OPENQASM 2.0; qreg q1[2]; CX q1[0], q1[1];"
        parsedQasm = QasmParser().parse(qasm)
    """

    def __init__(self):
        self.parser = yacc.yacc(module=self, debug=False, write_tables=False)
        self.circuit = Circuit()
        # register name -> register size
        self.qregs: Dict[str, int] = {}
        self.cregs: Dict[str, int] = {}
        self.qelibinc = False
        self.lexer = QasmLexer()
        self.supported_format = False
        # result of the first `parse` call; later calls return it unchanged
        self.parsedQasm: Optional[Qasm] = None
        # qubit name (see `make_name`) -> qubit, so each name maps to one object
        self.qubits: Dict[str, ops.Qid] = {}
        # functions that may be called inside a parameter expression
        self.functions = {
            'sin': np.sin,
            'cos': np.cos,
            'tan': np.tan,
            'exp': np.exp,
            'ln': np.log,
            'sqrt': np.sqrt,
            'acos': np.arccos,
            'atan': np.arctan,
            'asin': np.arcsin,
        }
        # binary operators that may appear inside a parameter expression
        self.binary_operators = {
            '+': operator.add,
            '-': operator.sub,
            '*': operator.mul,
            '/': operator.truediv,
            '^': operator.pow,
        }

    # Gates that are available without including qelib1.inc.
    basic_gates: Dict[str, QasmGateStatement] = {
        'CX': QasmGateStatement(qasm_gate='CX', cirq_gate=CX, num_params=0, num_args=2),
        'U': QasmGateStatement(
            qasm_gate='U',
            num_params=3,
            num_args=1,
            # QasmUGate expects half turns
            cirq_gate=(lambda params: QasmUGate(*[p / np.pi for p in params])),
        ),
    }

    # Gates that are only available once qelib1.inc has been included.
    qelib_gates = {
        'rx': QasmGateStatement(
            qasm_gate='rx', cirq_gate=(lambda params: ops.rx(params[0])), num_params=1, num_args=1
        ),
        'crx': QasmGateStatement(
            qasm_gate='crx',
            cirq_gate=(lambda params: ops.ControlledGate(ops.rx(params[0]))),
            num_params=1,
            num_args=2,
        ),
        'sx': QasmGateStatement(
            qasm_gate='sx', num_params=0, num_args=1, cirq_gate=ops.XPowGate(exponent=0.5)
        ),
        'sxdg': QasmGateStatement(
            qasm_gate='sxdg', num_params=0, num_args=1, cirq_gate=ops.XPowGate(exponent=-0.5)
        ),
        'ry': QasmGateStatement(
            qasm_gate='ry', cirq_gate=(lambda params: ops.ry(params[0])), num_params=1, num_args=1
        ),
        'rz': QasmGateStatement(
            qasm_gate='rz', cirq_gate=(lambda params: ops.rz(params[0])), num_params=1, num_args=1
        ),
        'id': QasmGateStatement(
            qasm_gate='id', cirq_gate=ops.IdentityGate(1), num_params=0, num_args=1
        ),
        'u1': QasmGateStatement(
            qasm_gate='u1',
            cirq_gate=(lambda params: ops.ZPowGate(exponent=params[0] / np.pi)),
            num_params=1,
            num_args=1,
        ),
        'u2': QasmGateStatement(
            qasm_gate='u2',
            cirq_gate=(lambda params: qbraid_cirq_gates.U2Gate(*params)),
            num_params=2,
            num_args=1,
        ),
        'u3': QasmGateStatement(
            qasm_gate='u3',
            cirq_gate=(lambda params: qbraid_cirq_gates.U3Gate(*params)),
            num_params=3,
            num_args=1,
        ),
        'u': QasmGateStatement(
            qasm_gate='u',
            cirq_gate=(lambda params: qbraid_cirq_gates.U3Gate(*params)),
            num_params=3,
            num_args=1,
        ),
        'r': QasmGateStatement(
            qasm_gate='r',
            num_params=2,
            num_args=1,
            cirq_gate=(
                lambda params: QasmUGate(
                    params[0] / np.pi, (params[1] / np.pi) - 0.5, (-params[1] / np.pi) + 0.5
                )
            ),
        ),
        'x': QasmGateStatement(qasm_gate='x', num_params=0, num_args=1, cirq_gate=ops.X),
        'y': QasmGateStatement(qasm_gate='y', num_params=0, num_args=1, cirq_gate=ops.Y),
        'z': QasmGateStatement(qasm_gate='z', num_params=0, num_args=1, cirq_gate=ops.Z),
        'h': QasmGateStatement(qasm_gate='h', num_params=0, num_args=1, cirq_gate=ops.H),
        's': QasmGateStatement(qasm_gate='s', num_params=0, num_args=1, cirq_gate=ops.S),
        'cs': QasmGateStatement(
            qasm_gate='cs', num_params=0, num_args=2, cirq_gate=ops.ControlledGate(ops.S)
        ),
        't': QasmGateStatement(qasm_gate='t', num_params=0, num_args=1, cirq_gate=ops.T),
        'cx': QasmGateStatement(qasm_gate='cx', cirq_gate=CX, num_params=0, num_args=2),
        'cy': QasmGateStatement(
            qasm_gate='cy', cirq_gate=ops.ControlledGate(ops.Y), num_params=0, num_args=2
        ),
        'cz': QasmGateStatement(qasm_gate='cz', cirq_gate=ops.CZ, num_params=0, num_args=2),
        'ccz': QasmGateStatement(qasm_gate='ccz', cirq_gate=ops.CCZ, num_params=0, num_args=3),
        'ch': QasmGateStatement(
            qasm_gate='ch', cirq_gate=ops.ControlledGate(ops.H), num_params=0, num_args=2
        ),
        'swap': QasmGateStatement(qasm_gate='swap', cirq_gate=ops.SWAP, num_params=0, num_args=2),
        'cswap': QasmGateStatement(
            qasm_gate='cswap', num_params=0, num_args=3, cirq_gate=ops.CSWAP
        ),
        'ccx': QasmGateStatement(qasm_gate='ccx', num_params=0, num_args=3, cirq_gate=ops.CCX),
        'c3x': QasmGateStatement(
            qasm_gate='c3x', num_params=0, num_args=4, cirq_gate=ops.ControlledGate(ops.CCX)
        ),
        'c4x': QasmGateStatement(
            qasm_gate='c4x',
            num_params=0,
            num_args=5,
            cirq_gate=ops.ControlledGate(ops.ControlledGate(ops.CCX)),
        ),
        'sdg': QasmGateStatement(qasm_gate='sdg', num_params=0, num_args=1, cirq_gate=ops.S**-1),
        'csdg': QasmGateStatement(
            qasm_gate='csdg', num_params=0, num_args=2, cirq_gate=ops.ControlledGate(ops.S**-1)
        ),
        'tdg': QasmGateStatement(qasm_gate='tdg', num_params=0, num_args=1, cirq_gate=ops.T**-1),
        'crz': QasmGateStatement(
            qasm_gate='crz',
            cirq_gate=(lambda params: ops.ControlledGate(ops.rz(params[0]))),
            num_params=1,
            num_args=2,
        ),
        'cry': QasmGateStatement(
            qasm_gate='cry',
            cirq_gate=(lambda params: ops.ControlledGate(ops.ry(params[0]))),
            num_params=1,
            num_args=2,
        ),
        'csx': QasmGateStatement(
            qasm_gate='csx',
            num_params=0,
            num_args=2,
            cirq_gate=ops.ControlledGate(ops.XPowGate(exponent=0.5)),
        ),
        'c3sx': QasmGateStatement(
            qasm_gate='c3sx',
            num_params=0,
            num_args=4,
            cirq_gate=ops.ControlledGate(
                ops.ControlledGate(ops.ControlledGate(ops.XPowGate(exponent=0.5)))
            ),
        ),
        'c3sqrtx': QasmGateStatement(
            qasm_gate='c3sqrtx',
            num_params=0,
            num_args=4,
            cirq_gate=ops.ControlledGate(
                ops.ControlledGate(ops.ControlledGate(ops.XPowGate(exponent=0.5)))
            ),
        ),
        'cu1': QasmGateStatement(
            qasm_gate='cu1',
            cirq_gate=(
                lambda params: ops.ControlledGate(ops.ZPowGate(exponent=params[0] / np.pi))
            ),
            num_params=1,
            num_args=2,
        ),
        'cu3': QasmGateStatement(
            qasm_gate='cu3',
            cirq_gate=(lambda params: ops.ControlledGate(qbraid_cirq_gates.U3Gate(*params))),
            num_params=3,
            num_args=2,
        ),
        # NOTE(review): this entry is identical to 'cu3' and takes three
        # parameters; presumably some qelib1.inc versions declare `cu` with a
        # fourth (phase) parameter — TODO confirm against the targeted header.
        'cu': QasmGateStatement(
            qasm_gate='cu',
            cirq_gate=(lambda params: ops.ControlledGate(qbraid_cirq_gates.U3Gate(*params))),
            num_params=3,
            num_args=2,
        ),
        'p': QasmGateStatement(
            qasm_gate='p',
            cirq_gate=(lambda params: ops.ZPowGate(exponent=params[0] / np.pi)),
            num_params=1,
            num_args=1,
        ),
        'cp': QasmGateStatement(
            qasm_gate='cp',
            cirq_gate=(lambda params: ops.CZPowGate(exponent=params[0] / np.pi)),
            num_params=1,
            num_args=2,
        ),
        'iswap': QasmGateStatement(
            qasm_gate='iswap', cirq_gate=ops.ISWAP, num_params=0, num_args=2
        ),
        'rzz': QasmGateStatement(
            qasm_gate='rzz',
            cirq_gate=(lambda params: qbraid_cirq_gates.rzz(params[0])),
            num_params=1,
            num_args=2,
        ),
    }

    all_gates = {**basic_gates, **qelib_gates}

    tokens = QasmLexer.tokens
    start = 'start'

    precedence = (('left', '+', '-'), ('left', '*', '/'), ('right', '^'))

    def p_start(self, p):
        """start : qasm"""
        p[0] = p[1]

    def p_qasm_format_only(self, p):
        """qasm : format"""
        self.supported_format = True
        p[0] = Qasm(self.supported_format, self.qelibinc, self.qregs, self.cregs, self.circuit)

    def p_qasm_no_format_specified_error(self, p):
        """qasm : QELIBINC
        | circuit"""
        # Reached when the program starts with an include or a statement
        # instead of the format line.
        if self.supported_format is False:
            raise QasmException("Missing 'OPENQASM 2.0;' statement")

    def p_qasm_include(self, p):
        """qasm : qasm QELIBINC"""
        self.qelibinc = True
        p[0] = Qasm(self.supported_format, self.qelibinc, self.qregs, self.cregs, self.circuit)

    def p_qasm_circuit(self, p):
        """qasm : qasm circuit"""
        p[0] = Qasm(self.supported_format, self.qelibinc, self.qregs, self.cregs, p[2])

    def p_format(self, p):
        """format : FORMAT_SPEC"""
        if p[1] != "2.0":
            raise QasmException(
                "Unsupported OpenQASM version: {}, "
                "only 2.0 is supported currently by Cirq".format(p[1])
            )

    # circuit : new_reg circuit
    #         | gate_op circuit
    #         | measurement circuit
    #         | if circuit
    #         | empty

    def p_circuit_reg(self, p):
        """circuit : new_reg circuit"""
        p[0] = self.circuit

    def p_circuit_gate_or_measurement(self, p):
        """circuit : circuit gate_op
        | circuit measurement"""
        self.circuit.append(p[2])
        p[0] = self.circuit

    def p_circuit_empty(self, p):
        """circuit : empty"""
        p[0] = self.circuit

    # qreg and creg

    def p_new_reg(self, p):
        """new_reg : QREG ID '[' NATURAL_NUMBER ']' ';'
        | CREG ID '[' NATURAL_NUMBER ']' ';'"""
        name, length = p[2], p[4]
        # quantum and classical registers share one namespace
        if name in self.qregs or name in self.cregs:
            raise QasmException(f"{name} is already defined at line {p.lineno(2)}")
        if length == 0:
            raise QasmException(f"Illegal, zero-length register '{name}' at line {p.lineno(4)}")
        if p[1] == "qreg":
            self.qregs[name] = length
        else:
            self.cregs[name] = length
        p[0] = (name, length)

    # gate operations
    # gate_op : ID qargs
    #         | ID ( params ) qargs

    def p_gate_op_no_params(self, p):
        """gate_op : ID qargs"""
        self._resolve_gate_operation(p[2], gate=p[1], p=p, params=[])

    def p_gate_op_with_params(self, p):
        """gate_op : ID '(' params ')' qargs"""
        self._resolve_gate_operation(args=p[5], gate=p[1], p=p, params=p[3])

    def _resolve_gate_operation(
        self, args: List[List[ops.Qid]], gate: str, p: Any, params: List[float]
    ):
        """Look up `gate` and store its operations in `p[0]`.

        Only `basic_gates` are known until qelib1.inc has been included;
        afterwards `all_gates` are.

        Raises:
            QasmException: if the gate name is not in the active gate set.
        """
        gate_set = self.basic_gates if not self.qelibinc else self.all_gates
        if gate not in gate_set:
            msg = 'Unknown gate "{}" at line {}{}'.format(
                gate,
                p.lineno(1),
                ", did you forget to include qelib1.inc?" if not self.qelibinc else "",
            )
            raise QasmException(msg)
        p[0] = gate_set[gate].on(args=args, params=params, lineno=p.lineno(1))

    # params : parameter ',' params
    #        | parameter

    def p_params_multiple(self, p):
        """params : expr ',' params"""
        p[3].insert(0, p[1])
        p[0] = p[3]

    def p_params_single(self, p):
        """params : expr"""
        p[0] = [p[1]]

    # expr : term
    #      | func '(' expression ')'
    #      | binary_op
    #      | unary_op

    def p_expr_term(self, p):
        """expr : term"""
        p[0] = p[1]

    def p_expr_parens(self, p):
        """expr : '(' expr ')'"""
        p[0] = p[2]

    def p_expr_function_call(self, p):
        """expr : ID '(' expr ')'"""
        func = p[1]
        if func not in self.functions:
            raise QasmException(f"Function not recognized: '{func}' at line {p.lineno(1)}")
        p[0] = self.functions[func](p[3])

    def p_expr_unary(self, p):
        """expr : '-' expr
        | '+' expr"""
        if p[1] == '-':
            p[0] = -p[2]
        else:
            p[0] = p[2]

    def p_expr_binary(self, p):
        """expr : expr '*' expr
        | expr '/' expr
        | expr '+' expr
        | expr '-' expr
        | expr '^' expr
        """
        p[0] = self.binary_operators[p[2]](p[1], p[3])

    def p_term(self, p):
        """term : NUMBER
        | NATURAL_NUMBER
        | PI"""
        p[0] = p[1]

    # qargs : qarg ',' qargs
    #       | qarg ';'

    def p_args_multiple(self, p):
        """qargs : qarg ',' qargs"""
        p[3].insert(0, p[1])
        p[0] = p[3]

    def p_args_single(self, p):
        """qargs : qarg ';'"""
        p[0] = [p[1]]

    # qarg : ID
    #      | ID '[' NATURAL_NUMBER ']'

    def p_quantum_arg_register(self, p):
        """qarg : ID"""
        reg = p[1]
        if reg not in self.qregs:
            raise QasmException(f'Undefined quantum register "{reg}" at line {p.lineno(1)}')
        qubits = []
        for idx in range(self.qregs[reg]):
            arg_name = self.make_name(idx, reg)
            # create each qubit once and reuse it on later references
            if arg_name not in self.qubits:
                self.qubits[arg_name] = NamedQubit(arg_name)
            qubits.append(self.qubits[arg_name])
        p[0] = qubits

    # carg : ID
    #      | ID '[' NATURAL_NUMBER ']'

    def p_classical_arg_register(self, p):
        """carg : ID"""
        reg = p[1]
        if reg not in self.cregs:
            raise QasmException(f'Undefined classical register "{reg}" at line {p.lineno(1)}')

        p[0] = [self.make_name(idx, reg) for idx in range(self.cregs[reg])]

    def make_name(self, idx, reg):
        """Return the name used for element `idx` of register `reg`: "<reg>_<idx>"."""
        return str(reg) + "_" + str(idx)

    def p_quantum_arg_bit(self, p):
        """qarg : ID '[' NATURAL_NUMBER ']'"""
        reg = p[1]
        idx = p[3]
        arg_name = self.make_name(idx, reg)
        if reg not in self.qregs:
            raise QasmException(f'Undefined quantum register "{reg}" at line {p.lineno(1)}')
        size = self.qregs[reg]
        if idx >= size:
            raise QasmException(
                'Out of bounds qubit index {} '
                'on register {} of size {} '
                'at line {}'.format(idx, reg, size, p.lineno(1))
            )
        if arg_name not in self.qubits:
            self.qubits[arg_name] = NamedQubit(arg_name)
        p[0] = [self.qubits[arg_name]]

    def p_classical_arg_bit(self, p):
        """carg : ID '[' NATURAL_NUMBER ']'"""
        reg = p[1]
        idx = p[3]
        arg_name = self.make_name(idx, reg)
        if reg not in self.cregs:
            raise QasmException(f'Undefined classical register "{reg}" at line {p.lineno(1)}')

        size = self.cregs[reg]
        if idx >= size:
            raise QasmException(
                'Out of bounds bit index {} '
                'on classical register {} of size {} '
                'at line {}'.format(idx, reg, size, p.lineno(1))
            )
        p[0] = [arg_name]

    # measurement operations
    # measurement : MEASURE qarg ARROW carg

    def p_measurement(self, p):
        """measurement : MEASURE qarg ARROW carg ';'"""
        qreg = p[2]
        creg = p[4]

        if len(qreg) != len(creg):
            raise QasmException(
                'mismatched register sizes {} -> {} for measurement '
                'at line {}'.format(len(qreg), len(creg), p.lineno(1))
            )

        # one single-qubit measurement per (qubit, classical bit) pair,
        # keyed by the classical bit's name
        p[0] = [
            ops.MeasurementGate(num_qubits=1, key=key).on(qubit)
            for qubit, key in zip(qreg, creg)
        ]

    # if operations
    # if : IF '(' carg EQ NATURAL_NUMBER ')' ID qargs

    # def p_if(self, p):
    #     """if : IF '(' carg EQ NATURAL_NUMBER ')' gate_op"""
    #     # We have to split the register into bits (since that's what measurement does above),
    #     # and create one condition per bit, checking against that part of the binary value.
    #     conditions = []
    #     for i, key in enumerate(p[3]):
    #         v = (p[5] >> i) & 1
    #         conditions.append(sympy.Eq(sympy.Symbol(key), v))
    #     p[0] = [
    #         ops.ClassicallyControlledOperation(conditions=conditions, sub_operation=tuple(p[7])[0])
    #     ]

    def p_error(self, p):
        # Syntax-error hook: `p` is the offending token, or None when the
        # input ended while more tokens were expected.
        if p is None:
            raise QasmException('Unexpected end of file')

        # debug_context() returns two lines (excerpt, then a caret line), so
        # it has to start on a line of its own for the caret to line up.
        raise QasmException(
            f"""Syntax error: '{p.value}'
{self.debug_context(p)}
at line {p.lineno}, column {self.find_column(p)}"""
        )

    def find_column(self, p):
        """Return the 1-based column of token `p` within its source line."""
        line_start = self.qasm.rfind('\n', 0, p.lexpos) + 1
        return (p.lexpos - line_start) + 1

    def p_empty(self, p):
        """empty :"""

    def parse(self, qasm: str) -> Qasm:
        """Parse `qasm` and return the result.

        The result of the first call is cached: once it is set, later calls
        return it unchanged and ignore their argument.
        """
        if self.parsedQasm is None:
            self.qasm = qasm
            self.lexer.input(self.qasm)
            self.parsedQasm = self.parser.parse(lexer=self.lexer)
        return self.parsedQasm

    def debug_context(self, p):
        """Return a short source excerpt around token `p` with a caret under it.

        The excerpt covers at most 5 characters on either side of the start
        of the token and never crosses a line break. The second line of the
        result places "^" under the first character of the token.
        """
        debug_start = max(self.qasm.rfind('\n', 0, p.lexpos) + 1, p.lexpos - 5)
        # str.find returns -1 when there is no line break in the window; that
        # value must not be used as a slice end, or the excerpt would run to
        # the end of the source instead of stopping 5 characters after the
        # token.
        line_end = self.qasm.find('\n', p.lexpos, p.lexpos + 5)
        debug_end = p.lexpos + 5 if line_end == -1 else line_end

        return (
            "..."
            + self.qasm[debug_start:debug_end]
            + "\n"
            + (" " * (3 + p.lexpos - debug_start))
            + "^"
        )