Source code for qiskit_metal.renderers.renderer_mpl.mpl_renderer

# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2017, 2021.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""MPL Renderer."""
import logging
import random
import sys
from typing import TYPE_CHECKING, List

import matplotlib as mpl
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from cycler import cycler
from .patch import PolygonPatch
from IPython.display import display
from matplotlib.axes import Axes
from matplotlib.backends.backend_qt5agg import \
    FigureCanvasQTAgg as FigureCanvas
from matplotlib.cbook import _OrderedSet
from matplotlib.collections import LineCollection, PatchCollection
from matplotlib.figure import Figure
from matplotlib.transforms import Bbox

from shapely.geometry import CAP_STYLE, JOIN_STYLE, LineString

from ... import Dict
from ...designs import QDesign
from .mpl_interaction import MplInteraction, PanAndZoom
from .mpl_toolbox import _axis_set_watermark_img, clear_axis, get_prop_cycle

from .. import config
if not config.is_building_docs():
    from ...toolbox_python.utility_functions import log_error_easy
    from qiskit_metal.toolbox_python.utility_functions import bad_fillet_idxs

if TYPE_CHECKING:
    from ..._gui.main_window import MetalGUI
    from ..._gui.widgets.plot_widget.plot_window import QMainWindowPlot
    from .mpl_canvas import PlotCanvas
    from qiskit_metal.elements.elements_handler import QGeometryTables

__all__ = ['QMplRenderer']

to_poly_patch = np.vectorize(PolygonPatch)


[docs] class QMplRenderer(): """Matplotlib handle all rendering of an axis. The axis is given in the function render. Access: self = gui.canvas.metal_renderer """ def __init__(self, canvas: 'PlotCanvas', design: QDesign, logger: logging.Logger): """ Args: canvas (PlotCanvas): The canvas design (QDesign): The design logger (logging.Logger): The logger """ super().__init__() self.logger = logger self.canvas = canvas self.ax = None self.design = design self.options = Dict(resolution='16',) # Filter view options self.hidden_layers = set() # Set of component ids which are integers. self._hidden_components = set() self.colors = [ '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf' ] self.set_design(design)
[docs] def get_color_num(self, num: int) -> str: """Get the color from the given number. Args: num (int): number Return: str: color """ return self.colors[num % len(self.colors)]
[docs] def hide_component(self, name): """Hide the component with the given name. Args: name (str): Component name """ comp_id = self.design.components[name].id self._hidden_components.add(comp_id)
[docs] def show_component(self, name): """Show the component with the given name. Args: name (str): Component name """ comp_id = self.design.components[name].id self._hidden_components.discard(name)
[docs] def hide_layer(self, name): """Hide the layer with the given name. Args: name (str): Layer name """ self.hidden_layers.add(name)
[docs] def show_layer(self, name): """Show the layer with the given name. Args: name (str): Layer name """ self.hidden_layers.discard(name)
[docs] def set_design(self, design: QDesign): """Set the design. Args: design (QDesign): The design """ self.design = design self.clear_options()
# TODO
[docs] def clear_options(self): """Clear all options.""" self._hidden_components.clear() self.hidden_layers.clear()
[docs] def render(self, ax: Axes): """Assumes that the axis has been cleared already and so on. Args: ax (matplotlib.axes.Axes): mpl axis to draw on """ self.logger.debug('Rendering element tables to plot window.') self.render_tables(ax)
[docs] def get_mask(self, table: pd.DataFrame) -> pd.Series: """Gets the mask. Args: table (pd.DataFrame): dataframe Returns: pd.Series: return pandas index series with boolen mask - i.e., which are not hidden or otherwise """ # TODO: Ideally these should be replaced with interface functions, # not direct access to underlying internal representation mask = table.layer.isin(self.hidden_layers) mask = table.component.isin(self._hidden_components) return ~mask # not
def _render_poly_array(self, ax: Axes, poly_array: np.array, mpl_kw: dict): """Render the poly array. Args: ax (Axes): The axis poly_array (np.array): The poly mpl_kw (dict): The parameters dictionary """ if len(poly_array) > 0: poly_array = to_poly_patch(poly_array) ax.add_collection(PatchCollection(poly_array, **mpl_kw)) @property def qgeometry(self) -> 'QGeometryTables': """Return the qgeometry of the design.""" return self.design.qgeometry # TODO: move to some config and user input also make widget styles = { 'path': { 'base': dict(linewidth=2, alpha=0.5), 'subtracted': dict(), # linestyle='--', edgecolors='k', color='gray'), 'non-subtracted': dict() }, 'poly': { 'base': dict(linewidth=1, alpha=0.5, edgecolors='k'), 'subtracted': dict(linestyle='--', color='gray'), 'non-subtracted': dict() }, 'JJ': { 'base': dict(linewidth=1, alpha=0.2, edgecolors='k'), 'subtracted': dict(linestyle='--', color='gray'), 'non-subtracted': dict() } } """Styles"""
[docs] def get_style(self, element_type: str, subtracted=False, layer=None, extra=None): """Get the style. Args: element_type (str): The type of element. subtracted (bool): True to subtract the key. Defaults to False. layer (layer): The layer. Defaults to None. extra (dict): Extra stuff to add. Defaults to None. Return: dict: Style dictionary """ # element_type - poly path extra = extra or {} key = 'subtracted' if subtracted else 'non-subtracted' kw = { **self.styles[element_type].get('base', {}), **self.styles[element_type].get(key, {}), **extra } # TODO: maybe pop keys that are invalid for line etc. # we could have a validation flag to validate for specific poly / path return kw
[docs] def render_tables(self, ax: Axes): """Render the tables. Args: ax (Axes): The axes """ for element_type, table in self.qgeometry.tables.items(): # Mask the table table = table[self.get_mask(table)] # subtracted mask = table['subtract'] == True render_func = getattr(self, f'render_{element_type}') render_func(table[mask], ax, subtracted=True) # non-subtracted table1 = table[~mask] # TODO: do by layer and color # self.get_color_num() # TODO: Check that the function exists render_func = getattr(self, f'render_{element_type}') render_func(table1, ax, subtracted=False)
[docs] def render_junction(self, table: pd.DataFrame, ax: Axes, subtracted: bool = False, extra_kw: dict = None): """Render a table of junction geometry. A junction is basically drawn like a path with finite width and no fillet. Args: table (DataFrame): Element table ax (matplotlib.axes.Axes): Axis to render on extra_kw (dict): Style params """ if len(table) > 0: mask = (table.width == 0) | table.width.isna() table1 = table[~mask] if len(table1) > 0: table1.geometry = table1[['geometry', 'width']].apply( lambda x: x[0].buffer(distance=float(x[1]) / 2., cap_style=CAP_STYLE.flat, join_style=JOIN_STYLE.mitre, resolution=int(self.options[ 'resolution'])), axis=1) kw = self.get_style('JJ', subtracted=subtracted, extra=extra_kw) self.render_poly(table1, ax, subtracted=subtracted, extra_kw=kw) table1 = table[mask] if len(table1) > 0: self.logger.warning( 'One or more junctions have zero width. Consider changing this.' )
[docs] def render_poly(self, table: pd.DataFrame, ax: Axes, subtracted: bool = False, extra_kw: dict = None): """Render a table of poly geometry. Args: table (DataFrame): Element table ax (matplotlib.axes.Axes): Axis to render on kw (dict): Style params """ if len(table) < 1: return kw = self.get_style('poly', subtracted=subtracted, extra=extra_kw) self._render_poly_array(ax, table.geometry, kw)
[docs] def render_fillet(self, table): """Renders fillet path. Args: table (DataFrame): Table of elements with fillets Returns: DataFrame table with geometry field updated with a polygon filleted path. """ table['geometry'] = table.apply(self.fillet_path, axis=1) return table
[docs] def fillet_path(self, row): """Output the filleted path. Args: row (DataFrame): Row to fillet. Returns: Polygon of the new filleted path. """ path = row["geometry"].coords if len(path) <= 2: # only start and end points, no need to fillet return row["geometry"] newpath = np.array([path[0]]) # Get list of vertices that can't be filleted no_fillet = bad_fillet_idxs(path, row["fillet"], self.design.template_options.PRECISION) # Iterate through every three-vertex corner for (i, (start, corner, end)) in enumerate(zip(path, path[1:], path[2:])): if i + 1 in no_fillet: # don't fillet this corner newpath = np.concatenate((newpath, np.array([corner]))) else: fillet = self._calc_fillet(np.array(start), np.array(corner), np.array(end), row["fillet"], int(self.options['resolution'])) if fillet is not False: newpath = np.concatenate((newpath, fillet)) else: newpath = np.concatenate((newpath, np.array([corner]))) newpath = np.concatenate((newpath, np.array([end]))) return LineString(newpath)
def _calc_fillet(self, vertex_start, vertex_corner, vertex_end, radius, points=16): """Returns the filleted path based on the start, corner, and end vertices and the fillet radius. Args: vertex_start (np.ndarray): x-y coordinates of starting vertex. vertex_corner (np.ndarray): x-y coordinates of corner vertex. vertex_end (np.ndarray): x-y coordinates of end vertex. radius (float): Fillet radius. points (int): Number of points to draw in the fillet corner. """ # Start, corner, and end vertices must be distinct if np.array_equal(vertex_start, vertex_corner) or np.array_equal( vertex_end, vertex_corner): return False # Vectors pointing from corner to start and end vertices, respectively # Also calculate their lengths and unit vectors sc_vec = vertex_start - vertex_corner ec_vec = vertex_end - vertex_corner sc_norm = np.linalg.norm(sc_vec) ec_norm = np.linalg.norm(ec_vec) sc_uvec = sc_vec / sc_norm ec_uvec = ec_vec / ec_norm # Angle between previous unit vectors end_angle = np.arccos(np.dot(sc_uvec, ec_uvec)) # Start, corner, and end vertices can't be collinear if (end_angle == 0) or (end_angle == np.pi): return False # Fillet circle must be small enough to fit inside corner if radius / np.tan(end_angle / 2) > min(sc_norm, ec_norm): return False # Unit vector pointing from corner vertex to center of fillet circle net_uvec = (sc_uvec + ec_uvec) / np.linalg.norm(sc_uvec + ec_uvec) # Coordinates of center of fillet circle circle_center = vertex_corner + net_uvec * radius / np.sin( end_angle / 2) # Deltas represent displacement from corner vertex to circle center # Midpoint angle from circle center to corner, wrt to horizontal extending from former # Note: arctan is fine for angles in range (-pi / 2, pi / 2] but needs extra pi factor otherwise delta_x = vertex_corner[0] - circle_center[0] delta_y = vertex_corner[1] - circle_center[1] if delta_x: theta_mid = np.arctan(delta_y / delta_x) + np.pi * int(delta_x < 0) else: theta_mid = np.pi * ((1 - 2 * int(delta_y < 0)) + int(delta_y < 0)) # Start and end sweep angles determined relative to midpoint angle # Swap them as needed to resolve ambiguity in arctan theta_start = theta_mid - (np.pi - end_angle) / 2 theta_end = theta_mid + (np.pi - end_angle) / 2 p1 = circle_center + radius * np.array( [np.cos(theta_start), np.sin(theta_start)]) p2 = circle_center + radius * np.array( [np.cos(theta_end), np.sin(theta_end)]) if np.linalg.norm(vertex_start - p2) < np.linalg.norm(vertex_start - p1): theta_start, theta_end = theta_end, theta_start # Populate the fillet corner, skipping the start point since it's already added path = np.array([ circle_center + radius * np.array( [np.cos(theta_start), np.sin(theta_start)]) ]) for theta in np.linspace(theta_start, theta_end, points)[1:]: path = np.concatenate( (path, np.array([ circle_center + radius * np.array([np.cos(theta), np.sin(theta)]) ]))) return path
[docs] def render_path(self, table: pd.DataFrame, ax: Axes, subtracted: bool = False, extra_kw: dict = None): """Render a table of path geometry. Args: table (DataFrame): Element table ax (matplotlib.axes.Axes): Axis to render on kw (dict): Style params """ if len(table) < 1: return # mask for all non zero width paths # TODO: could there be a problem with float vs int here? mask = (table.width == 0) | table.width.isna() # print(f'subtracted={subtracted}\n\n') # display(table) # display(imask) # convert to polys - handle non zero width table1 = table[~mask] mask2 = (table1.fillet == 0) table2 = table1[~mask2] for index, row in table2[table2.fillet.notnull()].iterrows(): table1.loc[index, 'geometry'] = self.fillet_path(row) if len(table1) > 0: table1.geometry = table1[['geometry', 'width']].apply(lambda x: x[ 0].buffer(distance=float(x[1]) / 2., cap_style=CAP_STYLE.flat, join_style=JOIN_STYLE.mitre, resolution=int(self.options['resolution'])), axis=1) kw = self.get_style('poly', subtracted=subtracted, extra=extra_kw) # render components self.render_poly(table1, ax, subtracted=subtracted, extra_kw=kw) # handle zero width table1 = table[mask] # best way to plot? # TODO: speed and vectorize? if len(table1) > 0: kw = self.get_style('path', subtracted=subtracted, extra=extra_kw) line_segments = LineCollection(table1.geometry) ax.add_collection(line_segments)
# DEFAULT['renderer_mpl'] = Dict( # annot_conectors=Dict( # ofst=[0.025, 0.025], # annotate_kw=dict( # called by ax.annotate # color='r', # arrowprops=dict(color='r', shrink=0.1, width=0.05, headwidth=0.1) # ), # line_kw=dict(lw=2, c='r') # ), # ) # class QRendererMPL(QRendererGui): # """ # Renderer for matplotlib in a GUI environment. # TODO: How do we handle component selection, etc. # """ # name = 'mpl' # element_extensions = dict() # def render_shapely(self, obj, kw=None): # # TODO: simplify, specialize, and update this function # # right now, this is just calling the V0.1 old style # render(obj, ax=self.ax, kw= {} or kw) # def render_connectors(self): # ''' # Plots all connectors on the active axes. Draws the 1D line that # represents the "port" of a connector point. These are referenced for smart placement # of Metal components, such as when using functions like Metal_CPW_Connect. # TODO: add some filter for sense of what components are visible? # or on what chip the connectors are # ''' # for name, conn in self.design.connectors.items(): # line = LineString(conn.points) # self.render_shapely(line, kw=DEFAULT.annot_conectors.line_kw) # self.ax.annotate(name, xy=conn.middle[:2], xytext=conn.middle + # np.array(DEFAULT.annot_conectors.ofst), # **DEFAULT.annot_conectors.annotate_kw)