# Source code for qbraid.visualization.plot_counts

# Copyright (C) 2023 qBraid
#
# This file is part of the qBraid-SDK
#
# The qBraid-SDK is free software released under the GNU General Public License v3
# or later. You can redistribute and/or modify it under the terms of the GPL v3.
# See the LICENSE file in the project root or <https://www.gnu.org/licenses/gpl-3.0.html>.
#
# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3.

"""
Module for plotting histograms of measurement counts and probability distributions of quantum states.

"""

from typing import Callable, Dict, List, Optional, Union

import matplotlib.pyplot as plt
from matplotlib import colormaps

# pylint: disable=too-many-arguments,unnecessary-lambda


def _counts_to_decimal(counts: dict) -> dict:
    """
    Converts a dictionary of counts to decimal form.

    Args:
        counts_dict (dict): A dictionary where keys are strings representing states and
                            values are integers representing counts.

    Returns:
        dict: A dictionary with the same keys as the input dictionary, but with values
              converted to their respective decimal proportions.

    Raises:
        ValueError: If the total count is zero.
        TypeError:  If the input dictionary contains non-integer values

    Example:
        >>> counts_to_decimal({"00": 10, "01": 15, "10": 20, "11": 5})
        {"00": 0.2, "01": 0.3, "10": 0.4, "11": 0.1}
    """

    try:
        total_count = sum(counts.values())
    except TypeError as err:
        raise TypeError("Counts values must be integers.") from err

    if total_count == 0:
        raise ValueError("Total count cannot be zero.")

    decimal_dict = {key: value / total_count for key, value in counts.items()}

    return decimal_dict


def _plot_data(
    counts: Union[List[Dict], dict],
    legend: Optional[Union[List[str], str]] = None,
    colors: Optional[Union[List[str], str]] = None,
    title: Optional[str] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    show_plot: Optional[bool] = True,
    save_path: Optional[str] = None,
    transform_fn: Optional[Callable[[dict], dict]] = None,
    label_format_fn: Optional[Callable[[float], str]] = lambda x: str(x),
):
    if not isinstance(counts, list):
        counts = [counts]
        if isinstance(colors, str):
            colors = [colors]
        if isinstance(legend, str):
            legend = [legend]

    if transform_fn:
        counts = [transform_fn(counts_dict) for counts_dict in counts]

    all_states = sorted(set(state for counts_dict in counts for state in counts_dict.keys()))

    num_dicts = len(counts)
    bar_width = 0.8 / num_dicts

    x_positions = range(len(all_states))

    if colors is None:
        cmap = colormaps.get_cmap("tab10")
        colors = [cmap(i / 10) for i in range(num_dicts)]

    if len(colors) != len(counts):
        raise ValueError("Number of colors must match number of datasets")

    if isinstance(legend, list) and len(legend) != len(counts):
        raise ValueError("Number of legend labels must match number of datasets")

    for i, counts_dict in enumerate(counts):
        counts_iter = [counts_dict.get(state, 0) for state in all_states]

        default_label = f"Job {i}" if num_dicts > 1 else None
        label = legend[i] if legend and i < len(legend) else default_label

        bars = plt.bar(
            [x + (i * bar_width) for x in x_positions],
            counts_iter,
            width=bar_width,
            color=colors[i % num_dicts],
            label=label,
            align="center",
        )

        y_min, y_max = plt.gca().get_ylim()
        y_range = y_max - y_min
        offset_percentage = 0.02  # Adjust this value to get the desired offset
        y_offset = y_range * offset_percentage

        for hbar, count in zip(bars, counts_iter):
            plt.text(
                hbar.get_x() + hbar.get_width() / 2,
                hbar.get_height() + +y_offset,
                label_format_fn(count),
                ha="center",
                va="bottom",
                color="black",
                fontsize=8,
            )

    plt.xticks(x_positions, all_states, rotation=45)

    y_ticks = plt.gca().get_yticks()
    plt.yticks(y_ticks)
    plt.grid(axis="y", linestyle="--", linewidth=0.7, alpha=0.7)

    if y_label:
        plt.ylabel(y_label)

    if x_label:
        plt.xlabel(x_label)

    if title:
        plt.title(title)

    if legend or num_dicts > 1:
        plt.legend()

    if save_path:
        plt.savefig(save_path)

    if show_plot:
        plt.show()


def plot_distribution(
    counts: Union[List[Dict], Dict],
    legend: Optional[Union[List[str], str]] = None,
    colors: Optional[Union[List[str], str]] = None,
    title: Optional[str] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    show_plot: Optional[bool] = True,
    save_path: Optional[str] = None,
):
    """
    Plots a histogram probability distribution of quantum states.

    Each counts dictionary is divided by its total count before plotting, and every
    bar is annotated with its value formatted to three decimal places.

    Args:
        counts (Union[List[Dict], Dict]): Dictionary or a list of dictionaries containing the
            quantum states as keys and their respective counts as values.
        legend (Optional[Union[List[str], str]]): List of strings or a single string
            representing the labels of the datasets. Defaults to None, in which case the
            datasets are labelled "Job 0", "Job 1", ... when more than one is given.
        colors (Optional[Union[List[str], str]]): List of strings or a single string
            representing the colors for each dataset. Defaults to None, where it generates
            a color sequence.
        title (Optional[str]): String representing the title of the plot. Defaults to None.
        x_label (Optional[str]): String representing the label for the x-axis.
            Defaults to None.
        y_label (Optional[str]): String representing the label for the y-axis.
            Defaults to None, in which case "Quasi-probabilities" is used.
        show_plot (Optional[bool]): Boolean representing whether to show the plot.
            Defaults to True.
        save_path (Optional[str]): String representing the path to save the plot.
            Defaults to None.

    Returns:
        None: This function does not return a value; it draws the plot and, depending on
        ``show_plot`` and ``save_path``, displays and/or saves it.

    Raises:
        ValueError: If the number of colors or legend labels does not match the number of
            datasets, or if a counts dictionary has a total count of zero.
        TypeError: If a counts dictionary contains values that cannot be summed.

    Example:
        .. code-block:: python

            counts_dict1 = {'00': 50, '01': 30, '10': 10, '11': 10}
            counts_dict2 = {'00': 20, '01': 40, '10': 30, '11': 10}

            plot_distribution(
                counts=[counts_dict1, counts_dict2],
                legend=['First Execution', 'Second Execution'],
                colors=['crimson', 'midnightblue'],
                title="Quantum State Probability Distribution",
                x_label="Quantum States",
                y_label="Probabilities"
            )
    """
    y_label = "Quasi-probabilities" if not y_label else y_label

    _plot_data(
        counts,
        legend=legend,
        colors=colors,
        title=title,
        x_label=x_label,
        y_label=y_label,
        show_plot=show_plot,
        save_path=save_path,
        transform_fn=_counts_to_decimal,
        label_format_fn=lambda x: f"{x:.3f}",
    )
def plot_histogram(
    counts: Union[List[Dict], Dict],
    legend: Optional[Union[List[str], str]] = None,
    colors: Optional[Union[List[str], str]] = None,
    title: Optional[str] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    show_plot: Optional[bool] = True,
    save_path: Optional[str] = None,
):
    """
    Plots a histogram of measurement counts against quantum states.

    The raw counts are plotted unchanged, and every bar is annotated with its count.

    Args:
        counts (Union[List[Dict], Dict]): Dictionary or a list of dictionaries containing the
            quantum states as keys and their respective counts as values.
        legend (Optional[Union[List[str], str]]): List of strings or a single string
            representing the labels of the datasets. Defaults to None, in which case the
            datasets are labelled "Job 0", "Job 1", ... when more than one is given.
        colors (Optional[Union[List[str], str]]): List of strings or a single string
            representing the colors for each dataset. Defaults to None, where it generates
            a color sequence.
        title (Optional[str]): String representing the title of the plot. Defaults to None.
        x_label (Optional[str]): String representing the label for the x-axis.
            Defaults to None.
        y_label (Optional[str]): String representing the label for the y-axis.
            Defaults to None, in which case "Counts" is used.
        show_plot (Optional[bool]): Boolean representing whether to show the plot.
            Defaults to True.
        save_path (Optional[str]): String representing the path to save the plot.
            Defaults to None.

    Returns:
        None: This function does not return a value; it draws the plot and, depending on
        ``show_plot`` and ``save_path``, displays and/or saves it.

    Raises:
        ValueError: If the number of colors or legend labels does not match the number of
            datasets.

    Example:
        .. code-block:: python

            counts_dict1 = {'00': 50, '01': 30, '10': 10, '11': 10}
            counts_dict2 = {'00': 20, '01': 40, '10': 30, '11': 10}

            plot_histogram(
                counts=[counts_dict1, counts_dict2],
                legend=['First Execution', 'Second Execution'],
                colors=['crimson', 'midnightblue'],
                title="Quantum State Measurement Counts",
                x_label="Quantum States",
                y_label="Counts"
            )
    """
    y_label = "Counts" if not y_label else y_label

    _plot_data(
        counts,
        legend=legend,
        colors=colors,
        title=title,
        x_label=x_label,
        y_label=y_label,
        show_plot=show_plot,
        save_path=save_path,
        label_format_fn=str,
    )