# -*- coding: utf-8 -*-
# This code is part of Qiskit.
#
# (C) Copyright IBM 2017, 2021.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""MPL Renderer."""
import logging
from typing import TYPE_CHECKING, Optional
import numpy as np
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.collections import LineCollection, PatchCollection
from shapely.geometry import CAP_STYLE, JOIN_STYLE, LineString
from qiskit_metal import Dict, config
from qiskit_metal.designs import QDesign
from qiskit_metal.renderers.renderer_mpl.patch import PolygonPatch
if not config.is_building_docs():
from qiskit_metal.toolbox_python.utility_functions import bad_fillet_idxs
if TYPE_CHECKING:
from qiskit_metal.elements.elements_handler import QGeometryTables
from .mpl_canvas import PlotCanvas
__all__ = ["QMplRenderer"]

# Element-wise wrapper around ``PolygonPatch``: applied to an array/Series of
# shapely polygons it returns an object array of patches, which
# ``QMplRenderer._render_poly_array`` hands to ``PatchCollection``.
# ``otypes=[object]`` fixes the output dtype up front, so numpy does not call
# ``PolygonPatch`` an extra time on the first element to infer it, and
# zero-length inputs are accepted instead of raising ValueError.
to_poly_patch = np.vectorize(PolygonPatch, otypes=[object])
class QMplRenderer:
    """Matplotlib renderer for Metal designs.

    Notes:
        Access via ``gui.canvas.metal_renderer``. The axis to render is passed
        in the ``render`` method.
    """

    # TODO: move to some config and user input also make widget
    # Matplotlib keyword styles per element type. ``base`` is always applied;
    # the ``subtracted`` / ``non-subtracted`` entry is layered on top of it,
    # and caller-supplied extras override both (see ``get_style``).
    styles = {
        "path": {
            "base": dict(linewidth=2, alpha=0.5),
            "subtracted": dict(),  # linestyle='--', edgecolors='k', color='gray'),
            "non-subtracted": dict(),
        },
        "poly": {
            "base": dict(linewidth=1, alpha=0.5, edgecolors="k"),
            "subtracted": dict(linestyle="--", color="gray"),
            "non-subtracted": dict(),
        },
        "JJ": {
            "base": dict(linewidth=1, alpha=0.2, edgecolors="k"),
            "subtracted": dict(linestyle="--", color="gray"),
            "non-subtracted": dict(),
        },
    }

    def __init__(self, canvas: "PlotCanvas", design: "QDesign", logger: logging.Logger):
        """
        Args:
            canvas (PlotCanvas): The canvas
            design (QDesign): The design
            logger (logging.Logger): The logger
        """
        super().__init__()
        self.logger = logger
        self.canvas = canvas
        self.ax = None
        self.design = design
        # ``resolution`` is stored as a string and converted with ``int()``
        # where used: as ``quad_segs`` when buffering paths and as the number
        # of points drawn per fillet corner.
        self.options = Dict(
            resolution="16",
        )
        # Filter view options
        self.hidden_layers = set()
        # Set of component ids which are integers.
        self._hidden_components = set()
        self.colors = [
            "#1f77b4",
            "#ff7f0e",
            "#2ca02c",
            "#d62728",
            "#9467bd",
            "#8c564b",
            "#e377c2",
            "#7f7f7f",
            "#bcbd22",
            "#17becf",
        ]
        self.set_design(design)

    def get_color_num(self, num: int) -> str:
        """Get the color from the given number, cycling through ``self.colors``.

        Args:
            num (int): number

        Return:
            str: color
        """
        return self.colors[num % len(self.colors)]

    def hide_component(self, name):
        """Hide the component with the given name.

        Args:
            name (str): Component name
        """
        comp_id = self.design.components[name].id
        self._hidden_components.add(comp_id)

    def show_component(self, name):
        """Show the component with the given name.

        Args:
            name (str): Component name
        """
        # The hidden set stores component ids (see ``hide_component``), so the
        # id -- not the name -- must be removed.
        comp_id = self.design.components[name].id
        self._hidden_components.discard(comp_id)

    def hide_layer(self, name):
        """Hide the layer with the given name.

        Args:
            name (str): Layer name
        """
        self.hidden_layers.add(name)

    def show_layer(self, name):
        """Show the layer with the given name.

        Args:
            name (str): Layer name
        """
        self.hidden_layers.discard(name)

    def set_design(self, design: "QDesign"):
        """Set the design and reset all view filters.

        Args:
            design (QDesign): The design
        """
        self.design = design
        self.clear_options()

    def clear_options(self):
        """Clear all options (hidden layers and hidden components)."""
        self._hidden_components.clear()
        self.hidden_layers.clear()

    def render(self, ax: "Axes"):
        """Assumes that the axis has been cleared already and so on.

        Args:
            ax (matplotlib.axes.Axes): mpl axis to draw on
        """
        self.logger.debug("Rendering element tables to plot window.")
        self.render_tables(ax)

    def get_mask(self, table: pd.DataFrame) -> pd.Series:
        """Gets the visibility mask.

        Args:
            table (pd.DataFrame): dataframe with ``layer`` and ``component`` columns

        Returns:
            pd.Series: boolean mask aligned with ``table``; True for rows that
            are neither on a hidden layer nor part of a hidden component.
        """
        # TODO: Ideally these should be replaced with interface functions,
        # not direct access to underlying internal representation
        hidden = table.layer.isin(self.hidden_layers) | table.component.isin(
            self._hidden_components
        )
        return ~hidden

    def _render_poly_array(self, ax: "Axes", poly_array: np.ndarray, mpl_kw: dict):
        """Render the poly array as a single ``PatchCollection``.

        Args:
            ax (Axes): The axis
            poly_array (np.array): The polygons
            mpl_kw (dict): The parameters dictionary
        """
        if len(poly_array) > 0:
            patches = to_poly_patch(poly_array)
            ax.add_collection(PatchCollection(patches, **mpl_kw))

    @property
    def qgeometry(self) -> "QGeometryTables":
        """Return the qgeometry of the design."""
        return self.design.qgeometry

    def get_style(self, element_type: str, subtracted=False, layer=None, extra=None):
        """Get the style.

        Args:
            element_type (str): The type of element (a key of ``styles``).
            subtracted (bool): True to use the subtracted style. Defaults to False.
            layer (layer): The layer. Defaults to None. Currently unused.
            extra (dict): Extra keywords that override the style. Defaults to None.

        Return:
            dict: Style dictionary
        """
        extra = extra or {}
        key = "subtracted" if subtracted else "non-subtracted"
        kw = {
            **self.styles[element_type].get("base", {}),
            **self.styles[element_type].get(key, {}),
            **extra,
        }
        # TODO: maybe pop keys that are invalid for line etc.
        # we could have a validation flag to validate for specific poly / path
        return kw

    def render_tables(self, ax: "Axes"):
        """Render every qgeometry table, subtracted rows first.

        Each element type ``X`` is dispatched to ``self.render_X``.

        Args:
            ax (Axes): The axes
        """
        for element_type, table in self.qgeometry.tables.items():
            table = table[self.get_mask(table)]
            # TODO: Check that the function exists
            render_func = getattr(self, f"render_{element_type}")
            subtracted = table["subtract"].eq(True)
            render_func(table[subtracted], ax, subtracted=True)
            # TODO: do by layer and color
            # self.get_color_num()
            render_func(table[~subtracted], ax, subtracted=False)

    def _buffer_by_width(self, table: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``table`` whose geometries are buffered by ``width / 2``.

        Uses flat caps and mitred joins, so a line becomes a polygon of the
        row's ``width``. The input table is not modified.
        """
        table = table.copy()
        quad_segs = int(self.options["resolution"])
        table["geometry"] = [
            geom.buffer(
                distance=float(width) / 2.0,
                cap_style=CAP_STYLE.flat,
                join_style=JOIN_STYLE.mitre,
                quad_segs=quad_segs,
            )
            for geom, width in zip(table["geometry"], table["width"])
        ]
        return table

    def render_junction(
        self,
        table: pd.DataFrame,
        ax: "Axes",
        subtracted: bool = False,
        extra_kw: Optional[dict] = None,
    ):
        """Render a table of junction geometry.

        A junction is basically drawn like a path with finite width and no
        fillet. Junctions with zero or missing width are not drawn; a warning
        is logged instead.

        Args:
            table (DataFrame): Element table
            ax (matplotlib.axes.Axes): Axis to render on
            subtracted (bool): Whether these elements are subtracted.
            extra_kw (dict): Style params
        """
        if len(table) < 1:
            return
        zero_width = (table.width == 0) | table.width.isna()
        drawable = table[~zero_width]
        if len(drawable) > 0:
            drawable = self._buffer_by_width(drawable)
            kw = self.get_style("JJ", subtracted=subtracted, extra=extra_kw)
            self.render_poly(drawable, ax, subtracted=subtracted, extra_kw=kw)
        if zero_width.any():
            self.logger.warning(
                "One or more junctions have zero width. Consider changing this."
            )

    def render_poly(
        self,
        table: pd.DataFrame,
        ax: "Axes",
        subtracted: bool = False,
        extra_kw: Optional[dict] = None,
    ):
        """Render a table of poly geometry.

        Args:
            table (DataFrame): Element table
            ax (matplotlib.axes.Axes): Axis to render on
            subtracted (bool): Whether these elements are subtracted.
            extra_kw (dict): Style params
        """
        if len(table) < 1:
            return
        kw = self.get_style("poly", subtracted=subtracted, extra=extra_kw)
        self._render_poly_array(ax, table.geometry, kw)

    def render_fillet(self, table):
        """Renders fillet path.

        Args:
            table (DataFrame): Table of elements with fillets

        Returns:
            DataFrame table with geometry field updated with a polygon filleted path.
        """
        table["geometry"] = table.apply(self.fillet_path, axis=1)
        return table

    def fillet_path(self, row):
        """Output the filleted path.

        Args:
            row (DataFrame): Row to fillet.

        Returns:
            LineString of the new filleted path (the original geometry when it
            has two or fewer vertices).
        """
        path = row["geometry"].coords
        if len(path) <= 2:  # only start and end points, no need to fillet
            return row["geometry"]
        points = int(self.options["resolution"])
        # Get list of vertices that can't be filleted
        no_fillet = bad_fillet_idxs(
            path, row["fillet"], self.design.template_options.PRECISION
        )
        # Collect pieces and stack once, instead of re-concatenating per corner.
        pieces = [np.array([path[0]])]
        # Iterate through every three-vertex corner; corner i + 1 is the middle vertex.
        for i, (start, corner, end) in enumerate(zip(path, path[1:], path[2:])):
            fillet = False
            if i + 1 not in no_fillet:
                fillet = self._calc_fillet(
                    np.array(start),
                    np.array(corner),
                    np.array(end),
                    row["fillet"],
                    points,
                )
            # Fall back to the sharp corner when it can't (or mustn't) be filleted.
            pieces.append(np.array([corner]) if fillet is False else fillet)
        pieces.append(np.array([path[-1]]))
        return LineString(np.vstack(pieces))

    def _calc_fillet(self, vertex_start, vertex_corner, vertex_end, radius, points=16):
        """Returns the filleted path based on the start, corner, and end
        vertices and the fillet radius.

        Args:
            vertex_start (np.ndarray): x-y coordinates of starting vertex.
            vertex_corner (np.ndarray): x-y coordinates of corner vertex.
            vertex_end (np.ndarray): x-y coordinates of end vertex.
            radius (float): Fillet radius.
            points (int): Number of points to draw in the fillet corner.

        Returns:
            np.ndarray of shape (points, 2) running from the tangent point
            nearest ``vertex_start`` to the one nearest ``vertex_end``, or
            False when the corner cannot be filleted (repeated vertices,
            collinear vertices, or a radius too large for the corner).
        """
        # Start, corner, and end vertices must be distinct
        if np.array_equal(vertex_start, vertex_corner) or np.array_equal(
            vertex_end, vertex_corner
        ):
            return False
        # Vectors pointing from corner to start and end vertices, respectively
        # Also calculate their lengths and unit vectors
        sc_vec = vertex_start - vertex_corner
        ec_vec = vertex_end - vertex_corner
        sc_norm = np.linalg.norm(sc_vec)
        ec_norm = np.linalg.norm(ec_vec)
        sc_uvec = sc_vec / sc_norm
        ec_uvec = ec_vec / ec_norm
        # Angle between previous unit vectors; clip guards against round-off
        # pushing the dot product just outside [-1, 1] (arccos would give NaN).
        end_angle = np.arccos(np.clip(np.dot(sc_uvec, ec_uvec), -1.0, 1.0))
        # Start, corner, and end vertices can't be collinear
        if (end_angle == 0) or (end_angle == np.pi):
            return False
        # Fillet circle must be small enough to fit inside corner
        if radius / np.tan(end_angle / 2) > min(sc_norm, ec_norm):
            return False
        # Unit vector pointing from corner vertex to center of fillet circle
        net_uvec = (sc_uvec + ec_uvec) / np.linalg.norm(sc_uvec + ec_uvec)
        # Coordinates of center of fillet circle
        circle_center = vertex_corner + net_uvec * radius / np.sin(end_angle / 2)
        # Angle of the ray from the circle center towards the corner, measured
        # from the +x axis. arctan2 resolves every quadrant, including the
        # vertical case (delta_x == 0) where the bisector is parallel to y.
        delta_x = vertex_corner[0] - circle_center[0]
        delta_y = vertex_corner[1] - circle_center[1]
        theta_mid = np.arctan2(delta_y, delta_x)
        # Start and end sweep angles are symmetric about the midpoint angle.
        half_sweep = (np.pi - end_angle) / 2
        theta_start = theta_mid - half_sweep
        theta_end = theta_mid + half_sweep
        # Orient the arc so it begins at the tangent point nearest vertex_start.
        p1 = circle_center + radius * np.array([np.cos(theta_start), np.sin(theta_start)])
        p2 = circle_center + radius * np.array([np.cos(theta_end), np.sin(theta_end)])
        if np.linalg.norm(vertex_start - p2) < np.linalg.norm(vertex_start - p1):
            theta_start, theta_end = theta_end, theta_start
        # Sample the arc; always emit at least the first tangent point.
        thetas = np.linspace(theta_start, theta_end, max(int(points), 1))
        return circle_center + radius * np.column_stack((np.cos(thetas), np.sin(thetas)))

    def render_path(
        self,
        table: pd.DataFrame,
        ax: "Axes",
        subtracted: bool = False,
        extra_kw: Optional[dict] = None,
    ):
        """Render a table of path geometry.

        Paths with non-zero width are filleted (where ``fillet`` is set and
        non-zero), buffered to their width and drawn as polygons. Paths with
        zero or missing width are drawn as plain lines.

        Args:
            table (DataFrame): Element table
            ax (matplotlib.axes.Axes): Axis to render on
            subtracted (bool): Whether these elements are subtracted.
            extra_kw (dict): Style params
        """
        if len(table) < 1:
            return
        # TODO: could there be a problem with float vs int here?
        zero_width = (table.width == 0) | table.width.isna()

        # Non-zero width: fillet, then convert to polygons. Work on a copy so
        # the caller's table is never modified.
        wide = table[~zero_width].copy()
        if len(wide) > 0:
            needs_fillet = wide.fillet.notnull() & (wide.fillet != 0)
            for index, row in wide[needs_fillet].iterrows():
                wide.at[index, "geometry"] = self.fillet_path(row)
            wide = self._buffer_by_width(wide)
            kw = self.get_style("poly", subtracted=subtracted, extra=extra_kw)
            self.render_poly(wide, ax, subtracted=subtracted, extra_kw=kw)

        # Zero width: draw as lines using the path style.
        # TODO: speed and vectorize?
        thin = table[zero_width]
        if len(thin) > 0:
            kw = self.get_style("path", subtracted=subtracted, extra=extra_kw)
            segments = [np.asarray(geom.coords) for geom in thin.geometry]
            ax.add_collection(LineCollection(segments, **kw))
# DEFAULT['renderer_mpl'] = Dict(
# annot_conectors=Dict(
# ofst=[0.025, 0.025],
# annotate_kw=dict( # called by ax.annotate
# color='r',
# arrowprops=dict(color='r', shrink=0.1, width=0.05, headwidth=0.1)
# ),
# line_kw=dict(lw=2, c='r')
# ),
# )
# class QRendererMPL(QRendererGui):
# """
# Renderer for matplotlib in a GUI environment.
# TODO: How do we handle component selection, etc.
# """
# name = 'mpl'
# element_extensions = dict()
# def render_shapely(self, obj, kw=None):
# # TODO: simplify, specialize, and update this function
# # right now, this is just calling the V0.1 old style
# render(obj, ax=self.ax, kw= {} or kw)
# def render_connectors(self):
# '''
# Plots all connectors on the active axes. Draws the 1D line that
# represents the "port" of a connector point. These are referenced for smart placement
# of Metal components, such as when using functions like Metal_CPW_Connect.
# TODO: add some filter for sense of what components are visible?
# or on what chip the connectors are
# '''
# for name, conn in self.design.connectors.items():
# line = LineString(conn.points)
# self.render_shapely(line, kw=DEFAULT.annot_conectors.line_kw)
# self.ax.annotate(name, xy=conn.middle[:2], xytext=conn.middle +
# np.array(DEFAULT.annot_conectors.ofst),
# **DEFAULT.annot_conectors.annotate_kw)