Source code for pytket.utils.graph

# Copyright 2019-2024 Cambridge Quantum Computing
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from itertools import combinations
from tempfile import NamedTemporaryFile

import graphviz as gv  # type: ignore
import networkx as nx  # type: ignore

from pytket.circuit import Circuit


[docs] class Graph:
[docs] def __init__(self, c: Circuit): """ A class for visualising a circuit as a directed acyclic graph (DAG). Note: in order to use graph-rendering methods, such as :py:meth:`Graph.save_DAG`, it is necessary to have the Graphviz tools installed and on your path. See the `Graphviz website <https://graphviz.org/download/>`_ for instructions on how to install them. :param c: Circuit :type c: pytket.Circuit """ ( q_inputs, c_inputs, w_inputs, q_outputs, c_outputs, w_outputs, input_names, output_names, node_data, edge_data, ) = c._dag_data self.q_inputs = q_inputs self.c_inputs = c_inputs self.w_inputs = w_inputs self.q_outputs = q_outputs self.c_outputs = c_outputs self.w_outputs = w_outputs self.input_names = input_names self.output_names = output_names self.node_data = node_data self.Gnx: nx.MultiDiGraph | None = None self.G: gv.Digraph | None = None self.Gqc: gv.Graph | None = None self.edge_data: dict[tuple[int, int], list[tuple[int, int, str]]] = defaultdict( list ) self.port_counts: dict = defaultdict(int) for src_node, tgt_node, src_port, tgt_port, edge_type in edge_data: self.edge_data[(src_node, tgt_node)].append((src_port, tgt_port, edge_type)) self.port_counts[(src_node, src_port)] += 1
[docs] def as_nx(self) -> nx.MultiDiGraph: """ Return a logical representation of the circuit as a DAG. :returns: Representation of the DAG :rtype: networkx.MultiDiGraph """ if self.Gnx is not None: return self.Gnx Gnx = nx.MultiDiGraph() for node, desc in self.node_data.items(): Gnx.add_node(node, desc=desc) for nodepair, portpairlist in self.edge_data.items(): src_node, tgt_node = nodepair for src_port, tgt_port, edge_type in portpairlist: Gnx.add_edge( src_node, tgt_node, src_port=src_port, tgt_port=tgt_port, edge_type=edge_type, ) # Add node IDs to edges for edge in nx.topological_sort(nx.line_graph(Gnx)): src_node, tgt_node, _ = edge # List parent edges with matching port number src_port = Gnx.edges[edge]["src_port"] prev_edges = [ e for e in Gnx.in_edges(src_node, keys=True) if Gnx.edges[e]["tgt_port"] == src_port ] if not prev_edges: # The source must be an input node unit_id = src_node nx.set_edge_attributes(Gnx, {edge: {"unit_id": unit_id}}) else: # The parent must be unique assert len(prev_edges) == 1 prev_edge = prev_edges[0] unit_id = Gnx.edges[prev_edge]["unit_id"] nx.set_edge_attributes(Gnx, {edge: {"unit_id": unit_id}}) # Remove unnecessary port attributes to avoid clutter: for node in Gnx.nodes: if Gnx.in_degree(node) == 1: for edge in Gnx.in_edges(node, keys=True): nx.set_edge_attributes(Gnx, {edge: {"tgt_port": None}}) for edge in Gnx.out_edges(node, keys=True): nx.set_edge_attributes(Gnx, {edge: {"src_port": None}}) self.Gnx = Gnx return Gnx
[docs] def get_DAG(self) -> gv.Digraph: """ Return a visual representation of the DAG as a graphviz object. :returns: Representation of the DAG :rtype: graphviz.DiGraph """ if self.G is not None: return self.G G = gv.Digraph( "Circuit", strict=True, ) G.attr(rankdir="LR", ranksep="0.3", nodesep="0.15", margin="0") q_color = "blue" c_color = "slategray" b_color = "gray" w_color = "green" gate_color = "lightblue" boundary_cluster_attr = { "style": "rounded, filled", "color": "lightgrey", "margin": "5", } boundary_node_attr = {"fontname": "Courier", "fontsize": "8"} with G.subgraph(name="cluster_q_inputs") as c: c.attr(rank="source", **boundary_cluster_attr) c.node_attr.update(shape="point", color=q_color) for node in self.q_inputs: c.node( str((node, 0)), xlabel=self.input_names[node], **boundary_node_attr ) with G.subgraph(name="cluster_c_inputs") as c: c.attr(rank="source", **boundary_cluster_attr) c.node_attr.update(shape="point", color=c_color) for node in self.c_inputs: c.node( str((node, 0)), xlabel=self.input_names[node], **boundary_node_attr ) with G.subgraph(name="cluster_w_inputs") as c: c.attr(rank="source", **boundary_cluster_attr) c.node_attr.update(shape="point", color=w_color) for node in self.w_inputs: c.node( str((node, 0)), xlabel=self.input_names[node], **boundary_node_attr ) with G.subgraph(name="cluster_q_outputs") as c: c.attr(rank="sink", **boundary_cluster_attr) c.node_attr.update(shape="point", color=q_color) for node in self.q_outputs: c.node( str((node, 0)), xlabel=self.output_names[node], **boundary_node_attr ) with G.subgraph(name="cluster_c_outputs") as c: c.attr(rank="sink", **boundary_cluster_attr) c.node_attr.update(shape="point", color=c_color) for node in self.c_outputs: c.node( str((node, 0)), xlabel=self.output_names[node], **boundary_node_attr ) with G.subgraph(name="cluster_w_outputs") as c: c.attr(rank="sink", **boundary_cluster_attr) c.node_attr.update(shape="point", color=w_color) for node in self.w_outputs: c.node( str((node, 0)), xlabel=self.output_names[node], **boundary_node_attr ) boundary_nodes = ( self.q_inputs | self.c_inputs | self.w_inputs | self.q_outputs | self.c_outputs | self.w_outputs ) Gnx = self.as_nx() node_cluster_attr = { "style": "rounded, filled", "color": gate_color, "fontname": "Times-Roman", "fontsize": "10", "margin": "5", "lheight": "100", } port_node_attr = { "shape": "point", "weight": "2", "fontname": "Helvetica", "fontsize": "8", } for node, ndata in Gnx.nodes.items(): if node not in boundary_nodes: with G.subgraph(name="cluster_" + str(node)) as c: c.attr(label=ndata["desc"], **node_cluster_attr) n_ports = Gnx.in_degree(node) if n_ports == 1: c.node(name=str((node, 0)), **port_node_attr) else: for i in range(n_ports): c.node(name=str((node, i)), xlabel=str(i), **port_node_attr) edge_colors = { "Quantum": q_color, "Boolean": b_color, "Classical": c_color, "WASM": w_color, } edge_attr = { "weight": "2", "arrowhead": "vee", "arrowsize": "0.2", "headclip": "true", "tailclip": "true", } for edge, edata in Gnx.edges.items(): src_node, tgt_node, _ = edge src_port = edata["src_port"] or 0 tgt_port = edata["tgt_port"] or 0 edge_type = edata["edge_type"] src_nodename = str((src_node, src_port)) tgt_nodename = str((tgt_node, tgt_port)) G.edge( src_nodename, tgt_nodename, color=edge_colors[edge_type], **edge_attr ) self.G = G return G
[docs] def save_DAG(self, name: str, fmt: str = "pdf") -> None: """ Save an image of the DAG to a file. The actual filename will be "<name>.<fmt>". A wide range of formats is supported. See https://graphviz.org/doc/info/output.html. :param name: Prefix of file name :type name: str :param fmt: File format, e.g. "pdf", "png", ... :type fmt: str """ G = self.get_DAG() G.render(name, cleanup=True, format=fmt, quiet=True)
[docs] def view_DAG(self) -> str: """ View the DAG. This method creates a temporary file, and returns its filename so that the caller may delete it afterwards. :returns: filename of temporary created file """ G = self.get_DAG() filename = NamedTemporaryFile(delete=False).name G.view(filename, quiet=True) return filename
[docs] def get_qubit_graph(self) -> gv.Graph: """ Return a visual representation of the qubit connectivity graph as a graphviz object. :returns: Representation of the qubit connectivity graph of the circuit :rtype: graphviz.Graph """ if self.Gqc is not None: return self.Gqc Gnx = self.as_nx() Gqcnx = nx.Graph() for node in Gnx.nodes(): qubits = [] for e in Gnx.in_edges(node, keys=True): unit_id = Gnx.edges[e]["unit_id"] if unit_id in self.q_inputs: qubits.append(unit_id) Gqcnx.add_edges_from(combinations(qubits, 2)) G = gv.Graph( "Qubit connectivity", node_attr={ "shape": "circle", "color": "blue", "fontname": "Courier", "fontsize": "10", }, engine="neato", ) G.edges( (self.input_names[src], self.input_names[tgt]) for src, tgt in Gqcnx.edges() ) self.Gqc = G return G
[docs] def view_qubit_graph(self) -> str: """ View the qubit connectivity graph. This method creates a temporary file, and returns its filename so that the caller may delete it afterwards. :returns: filename of temporary created file """ G = self.get_qubit_graph() filename = NamedTemporaryFile(delete=False).name G.view(filename, quiet=True) return filename
[docs] def save_qubit_graph(self, name: str, fmt: str = "pdf") -> None: """ Save an image of the qubit connectivity graph to a file. The actual filename will be "<name>.<fmt>". A wide range of formats is supported. See https://graphviz.org/doc/info/output.html. :param name: Prefix of file name :type name: str :param fmt: File format, e.g. "pdf", "png", ... :type fmt: str """ G = self.get_qubit_graph() G.render(name, cleanup=True, format=fmt, quiet=True)