# -*- coding: utf-8 -*-
r"""
The :mod:`pygsp2.plotting` module implements functionality to plot PyGSP2 objects
with a `pyqtgraph <https://www.pyqtgraph.org>`_ or `matplotlib
<https://matplotlib.org>`_ drawing backend (which can be controlled by the
:data:`BACKEND` constant or individually for each plotting call).
Most users won't use this module directly.
Graphs (from :mod:`pygsp2.graphs`) are to be plotted with
:meth:`pygsp2.graphs.Graph.plot` and
:meth:`pygsp2.graphs.Graph.plot_spectrogram`.
Filters (from :mod:`pygsp2.filters`) are to be plotted with
:meth:`pygsp2.filters.Filter.plot`.
.. data:: BACKEND
The default drawing backend to use if none are provided to the plotting
functions. Should be either ``'matplotlib'`` or ``'pyqtgraph'``. In general
pyqtgraph is better for interactive exploration while matplotlib is better
at generating figures to be included in papers or elsewhere.
"""
import functools
import numpy as np
from matplotlib import colormaps
from pygsp2 import utils
_logger = utils.build_logger(__name__)
BACKEND = 'matplotlib'
CMAP = 'viridis'
_qtg_widgets = []
_plt_figures = []
def _import_plt():
try:
import matplotlib as mpl
from matplotlib import pyplot as plt
from mpl_toolkits import mplot3d
except Exception as e:
raise ImportError('Cannot import matplotlib. Choose another backend '
'or try to install it with '
'pip (or conda) install matplotlib. '
'Original exception: {}'.format(e))
return mpl, plt, mplot3d
def _import_qtg():
try:
import pyqtgraph as qtg
import pyqtgraph.opengl as gl
from pyqtgraph.Qt import QtGui
except Exception as e:
raise ImportError('Cannot import pyqtgraph. Choose another backend '
'or try to install it with '
'pip (or conda) install pyqtgraph. You will also '
'need PyQt5 (or PySide) and PyOpenGL. '
'Original exception: {}'.format(e))
return qtg, gl, QtGui
def _plt_handle_figure(plot):
r"""Handle the common work (creating an axis if not given, setting the
title) of all matplotlib plot commands.
"""
# Preserve documentation of plot.
@functools.wraps(plot)
def inner(obj, **kwargs):
# Create a figure and an axis if none were passed.
if kwargs['ax'] is None:
_, plt, _ = _import_plt()
fig = plt.figure()
global _plt_figures
_plt_figures.append(fig)
if (hasattr(obj, 'coords') and obj.coords.ndim == 2 and obj.coords.shape[1] == 3):
kwargs['ax'] = fig.add_subplot(111, projection='3d')
else:
kwargs['ax'] = fig.add_subplot(111)
title = kwargs.pop('title')
plot(obj, **kwargs)
kwargs['ax'].set_title(title)
try:
fig.show(warn=False)
except NameError:
# No figure created, an axis was passed.
pass
return kwargs['ax'].figure, kwargs['ax']
return inner
[docs]
def close_all():
r"""Close all opened windows."""
global _qtg_widgets
for widget in _qtg_widgets:
widget.close()
_qtg_widgets = []
global _plt_figures
for fig in _plt_figures:
_, plt, _ = _import_plt()
plt.close(fig)
_plt_figures = []
[docs]
def show(*args, **kwargs):
r"""Show created figures, alias to ``plt.show()``.
By default, showing plots does not block the prompt.
Calling this function will block execution.
"""
_, plt, _ = _import_plt()
plt.show(*args, **kwargs)
[docs]
def close(*args, **kwargs):
r"""Close last created figure, alias to ``plt.close()``."""
_, plt, _ = _import_plt()
plt.close(*args, **kwargs)
def _qtg_plot_graph(G, edges, vertex_size, title):
qtg, gl, QtGui = _import_qtg()
if G.coords.shape[1] == 2:
widget = qtg.GraphicsLayoutWidget()
view = widget.addViewBox()
view.setAspectLocked()
if edges:
pen = tuple(np.array(G.plotting['edge_color']) * 255)
else:
pen = None
adj = _get_coords(G, edge_list=True)
g = qtg.GraphItem(pos=G.coords, adj=adj, pen=pen, size=vertex_size / 10)
view.addItem(g)
elif G.coords.shape[1] == 3:
if not QtGui.QApplication.instance():
QtGui.QApplication([]) # We want only one application.
widget = gl.GLViewWidget()
widget.opts['distance'] = 10
if edges:
x, y, z = _get_coords(G)
pos = np.stack((x, y, z), axis=1)
g = gl.GLLinePlotItem(pos=pos, mode='lines', color=G.plotting['edge_color'])
widget.addItem(g)
gp = gl.GLScatterPlotItem(pos=G.coords, size=vertex_size / 3, color=G.plotting['vertex_color'])
widget.addItem(gp)
widget.setWindowTitle(title)
widget.show()
global _qtg_widgets
_qtg_widgets.append(widget)
def _plot_filter(filters, n, eigenvalues, sum, labels, title, ax, **kwargs):
r"""Plot the spectral response of a filter bank.
Parameters
----------
n : int
Number of points where the filters are evaluated.
eigenvalues : boolean
Whether to show the eigenvalues of the graph Laplacian.
The eigenvalues should have been computed with
:meth:`~pygsp2.graphs.Graph.compute_fourier_basis`.
By default, the eigenvalues are shown if they are available.
sum : boolean
Whether to plot the sum of the squared magnitudes of the filters.
Default False if there is only one filter in the bank, True otherwise.
labels : boolean
Whether to label the filters.
Default False if there is only one filter in the bank, True otherwise.
title : str
Title of the figure.
ax : :class:`matplotlib.axes.Axes`
Axes where to draw the graph. Optional, created if not passed.
kwargs : dict
Additional parameters passed to the matplotlib plot function.
Useful for example to change the linewidth, linestyle, or set a label.
Returns
-------
fig : :class:`matplotlib.figure.Figure`
The figure the plot belongs to. Only with the matplotlib backend.
ax : :class:`matplotlib.axes.Axes`
The axes the plot belongs to. Only with the matplotlib backend.
Notes
-----
This function is only implemented with the matplotlib backend.
Examples
--------
>>> import matplotlib
>>> G = graphs.Logo()
>>> mh = filters.MexicanHat(G)
>>> fig, ax = mh.plot()
"""
if eigenvalues is None:
eigenvalues = (filters.G._e is not None)
if sum is None:
sum = (filters.n_filters > 1)
if labels is None:
labels = (filters.n_filters > 1)
if title is None:
title = repr(filters)
return _plt_plot_filter(filters, n=n, eigenvalues=eigenvalues, sum=sum, labels=labels, title=title, ax=ax, **kwargs)
@_plt_handle_figure
def _plt_plot_filter(filters, n, eigenvalues, sum, labels, ax, **kwargs):
x = np.linspace(0, filters.G.lmax, n)
params = dict(alpha=0.5)
params.update(kwargs)
if eigenvalues:
# Evaluate the filter bank at the eigenvalues to avoid plotting
# artifacts, for example when deltas are centered on the eigenvalues.
x = np.sort(np.concatenate([x, filters.G.e]))
y = filters.evaluate(x).T
lines = ax.plot(x, y, **params)
# TODO: plot highlighted eigenvalues
if sum:
line_sum, = ax.plot(x, np.sum(y**2, 1), 'k', **kwargs)
if labels:
for i, line in enumerate(lines):
line.set_label(fr'$g_{{{i}}}(\lambda)$')
if sum:
line_sum.set_label(r'$\sum_i g_i^2(\lambda)$')
ax.legend()
if eigenvalues:
segs = np.empty((len(filters.G.e), 2, 2))
segs[:, 0, 0] = segs[:, 1, 0] = filters.G.e
segs[:, :, 1] = [0, 1]
mpl, _, _ = _import_plt()
ax.add_collection(
mpl.collections.LineCollection(segs, transform=ax.get_xaxis_transform(), zorder=0, color=[0.9] * 3, linewidth=1,
label='eigenvalues'))
# Plot dots where the evaluation matters.
y = filters.evaluate(filters.G.e).T
params.pop('label', None)
for i in range(y.shape[1]):
params.update(color=lines[i].get_color())
ax.plot(filters.G.e, y[:, i], '.', **params)
if sum:
params.update(color=line_sum.get_color())
ax.plot(filters.G.e, np.sum(y**2, 1), '.', **params)
ax.set_xlabel(r"laplacian's eigenvalues (graph frequencies) $\lambda$")
ax.set_ylabel(r'filter response $g(\lambda)$')
def _plot_graph(G, vertex_color, vertex_size, highlight, edges, edge_color, edge_width, indices, colorbar, limits, ax, title,
backend, cmap, alphan, alphav, edge_weights):
r"""Plot a graph with signals as color or vertex size.
Parameters
----------
vertex_color : array_like or color
Signal to plot as vertex color (length is the number of vertices).
If None, vertex color is set to `graph.plotting['vertex_color']`.
Alternatively, a color can be set in any format accepted by matplotlib.
Each vertex color can by specified by an RGB(A) array of dimension
`n_vertices` x 3 (or 4).
vertex_size : array_like or int
Signal to plot as vertex size (length is the number of vertices).
Vertex size ranges from 0.5 to 2 times `graph.plotting['vertex_size']`.
If None, vertex size is set to `graph.plotting['vertex_size']`.
Alternatively, a size can be passed as an integer.
The pyqtgraph backend only accepts an integer size.
highlight : iterable
List of indices of vertices to be highlighted.
Useful for example to show where a filter was localized.
Only available with the matplotlib backend.
edges : bool
Whether to draw edges in addition to vertices.
Default to True if less than 10,000 edges to draw.
Note that drawing many edges can be slow.
edge_color : array_like or color
Signal to plot as edge color (length is the number of edges).
Edge color is given by `graph.plotting['edge_color']` and transparency
ranges from 0.2 to 0.9.
If None, edge color is set to `graph.plotting['edge_color']`.
Alternatively, a color can be set in any format accepted by matplotlib.
Each edge color can by specified by an RGB(A) array of dimension
`n_edges` x 3 (or 4).
Only available with the matplotlib backend.
edge_width : array_like or int
Signal to plot as edge width (length is the number of edges).
Edge width ranges from 0.5 to 2 times `graph.plotting['edge_width']`.
If None, edge width is set to `graph.plotting['edge_width']`.
Alternatively, a width can be passed as an integer.
Only available with the matplotlib backend.
indices : bool
Whether to print the node indices (in the adjacency / Laplacian matrix
and signal vectors) on top of each node.
Useful to locate a node of interest.
Only available with the matplotlib backend.
colorbar : bool
Whether to plot a colorbar indicating the signal's amplitude.
Only available with the matplotlib backend.
limits : [vmin, vmax]
Map colors from vmin to vmax.
Defaults to signal minimum and maximum value.
Only available with the matplotlib backend.
ax : :class:`matplotlib.axes.Axes`
Axes where to draw the graph. Optional, created if not passed.
Only available with the matplotlib backend.
title : str
Title of the figure.
backend: {'matplotlib', 'pyqtgraph', None}
Defines the drawing backend to use.
Defaults to :data:`pygsp2.plotting.BACKEND`.
cmap : str
Colormap of the figure.
alphan : float
Transparency channel for the graph's nodes.
alphav : float
Transparency channel for the graph's vertices.
edge_weights : array_like
Signal to plot as vertex color (length is the number of vertices).
Returns
-------
fig : :class:`matplotlib.figure.Figure`
The figure the plot belongs to. Only with the matplotlib backend.
ax : :class:`matplotlib.axes.Axes`
The axes the plot belongs to. Only with the matplotlib backend.
Notes
-----
The orientation of directed edges is not shown. If edges exist in both
directions, they will be drawn on top of each other.
Examples
--------
>>> import matplotlib
>>> graph = graphs.Sensor(20, seed=42)
>>> graph.compute_fourier_basis(n_eigenvectors=4)
>>> _, _, weights = graph.get_edge_list()
>>> fig, ax = graph.plot(graph.U[:, 1], vertex_size=graph.dw,
... edge_color=weights)
>>> graph.plotting['vertex_size'] = 300
>>> graph.plotting['edge_width'] = 5
>>> graph.plotting['edge_style'] = '--'
>>> fig, ax = graph.plot(edge_width=weights, edge_color=(0, .8, .8, .5),
... vertex_color='black')
>>> fig, ax = graph.plot(vertex_size=graph.dw, indices=True,
... highlight=[17, 3, 16], edges=False)
"""
if not hasattr(G, 'coords') or G.coords is None:
raise AttributeError('Graph has no coordinate set. '
'Please run G.set_coordinates() first.')
check_2d_3d = (G.coords.ndim != 2) or (G.coords.shape[1] not in [2, 3])
if G.coords.ndim != 1 and check_2d_3d:
raise AttributeError('Coordinates should be in 1D, 2D or 3D space.')
if G.coords.shape[0] != G.N:
raise AttributeError('Graph needs G.N = {} coordinates.'.format(G.N))
if backend is None:
backend = BACKEND
def check_shape(signal, name, length, many=False):
if (signal.ndim == 0) or (signal.shape[0] != length):
txt = '{}: signal should have length {}.'
txt = txt.format(name, length)
raise ValueError(txt)
if (not many) and (signal.ndim != 1):
txt = '{}: can plot only one signal (not {}).'
txt = txt.format(name, signal.shape[1])
raise ValueError(txt)
def normalize(x):
"""Scale values in [intercept, 1]. Return 0.5 if constant.
Set intercept value in G.plotting["normalize_intercept"]
with value in [0, 1], default is .25.
"""
#ptp = x.ptp()
ptp = np.ptp(x)
if ptp == 0:
return np.full(x.shape, 0.5)
else:
intercept = G.plotting['normalize_intercept']
return (1. - intercept) * (x - x.min()) / ptp + intercept
def is_color(color):
if backend == 'matplotlib':
mpl, _, _ = _import_plt()
if mpl.colors.is_color_like(color):
return True # single color
try:
return all(map(mpl.colors.is_color_like, color)) # color list
except TypeError:
return False # e.g., color is an int
else:
return False # No support for pyqtgraph (yet).
if cmap is None:
cmap = CMAP
else:
if cmap not in colormaps:
print('Wrong colormap')
cmap = CMAP
alphan = np.abs(alphan)
alphav = np.abs(alphav)
if vertex_color is None:
limits = [0, 0]
colorbar = False
if backend == 'matplotlib':
vertex_color = (G.plotting['vertex_color'], )
elif is_color(vertex_color):
limits = [0, 0]
colorbar = False
else:
vertex_color = np.asanyarray(vertex_color).squeeze()
check_shape(vertex_color, 'Vertex color', G.n_vertices, many=(G.coords.ndim == 1))
if vertex_size is None:
vertex_size = G.plotting['vertex_size']
elif not np.isscalar(vertex_size):
vertex_size = np.asanyarray(vertex_size).squeeze()
check_shape(vertex_size, 'Vertex size', G.n_vertices)
vertex_size = G.plotting['vertex_size'] * 4 * normalize(vertex_size)**2
if edges is None:
edges = G.Ne < 10e3
if edge_color is None:
edge_color = (G.plotting['edge_color'], )
elif not is_color(edge_color):
edge_color = np.asanyarray(edge_color).squeeze()
check_shape(edge_color, 'Edge color', G.n_edges)
edge_color = 0.9 * normalize(edge_color)
edge_color = [
np.tile(G.plotting['edge_color'][:3], [len(edge_color), 1]),
edge_color[:, np.newaxis],
]
edge_color = np.concatenate(edge_color, axis=1)
if edge_width is None:
edge_width = G.plotting['edge_width']
elif not np.isscalar(edge_width):
edge_width = np.array(edge_width).squeeze()
check_shape(edge_width, 'Edge width', G.n_edges)
edge_width = G.plotting['edge_width'] * 2 * normalize(edge_width)
if limits is None:
limits = [1.05 * vertex_color.min(), 1.05 * vertex_color.max()]
if title is None:
title = G.__repr__(limit=4)
if backend == 'pyqtgraph':
if vertex_color is None:
_qtg_plot_graph(G, edges=edges, vertex_size=vertex_size, title=title)
else:
_qtg_plot_signal(G, signal=vertex_color, vertex_size=vertex_size, edges=edges, limits=limits, title=title)
elif backend == 'matplotlib':
return _plt_plot_graph(G, vertex_color=vertex_color, vertex_size=vertex_size, highlight=highlight, edges=edges,
indices=indices, colorbar=colorbar, edge_color=edge_color, edge_width=edge_width, limits=limits,
ax=ax, title=title, cmap=cmap, alphan=alphan, alphav=alphav, edge_weights=edge_weights)
else:
raise ValueError('Unknown backend {}.'.format(backend))
@_plt_handle_figure
def _plt_plot_graph(G, vertex_color, vertex_size, highlight, edges, edge_color, edge_width, indices, colorbar, limits, ax, cmap,
alphan, alphav, edge_weights):
mpl, plt, mplot3d = _import_plt()
plt.set_cmap(cmap)
cmap = mpl.colormaps.get_cmap(cmap)
def is_color(color):
mpl, _, _ = _import_plt()
if mpl.colors.is_color_like(color):
return True # single color
try:
return all(map(mpl.colors.is_color_like, color)) # color list
except TypeError:
return False # e.g., color is an int
if edges and (G.coords.ndim != 1): # No edges for 1D plots.
sources, targets, _ = G.get_edge_list()
edges = [
G.coords[sources],
G.coords[targets],
]
edges = np.stack(edges, axis=1)
if G.coords.shape[1] == 2:
LineCollection = mpl.collections.LineCollection
elif G.coords.shape[1] == 3:
LineCollection = mplot3d.art3d.Line3DCollection
if edge_weights is not None:
normalized_signal = edge_weights / np.max(edge_weights)
colors = cmap(normalized_signal)
ax.add_collection(
LineCollection(
edges,
linewidths=edge_width,
colors=colors,
linestyles=G.plotting['edge_style'],
zorder=1,
alpha=alphav,
))
else:
ax.add_collection(
LineCollection(edges, linewidths=edge_width, colors=edge_color, linestyles=G.plotting['edge_style'], zorder=1,
alpha=alphav))
try:
iter(highlight)
except TypeError:
highlight = [highlight]
coords_hl = G.coords[highlight]
if G.coords.ndim == 1:
ax.plot(G.coords, vertex_color, alpha=alphan)
ax.set_ylim(limits)
for coord_hl in coords_hl:
ax.axvline(x=coord_hl, color=G.plotting['highlight_color'], linewidth=2)
else:
# Prevent matplotlib warning when using cmap without a valid color signal.
if ((vertex_color is None or isinstance(vertex_color, str)) and
limits is not None and len(limits) == 2 and
cmap is not None):
vertex_color = np.full(G.N, 0.5)
if is_color(vertex_color):
sc = ax.scatter(*G.coords.T, c=vertex_color, s=vertex_size, marker='o', linewidths=0, alpha=alphan, zorder=2)
else:
sc = ax.scatter(*G.coords.T, c=vertex_color, s=vertex_size, marker='o', linewidths=0, alpha=alphan, zorder=2,
vmin=limits[0], vmax=limits[1])
if np.isscalar(vertex_size):
size_hl = vertex_size
else:
size_hl = vertex_size[highlight]
ax.scatter(*coords_hl.T, s=2 * size_hl, zorder=3, marker='o', c='None', edgecolors=G.plotting['highlight_color'],
linewidths=2)
if G.coords.shape[1] == 3:
try:
ax.view_init(elev=G.plotting['elevation'], azim=G.plotting['azimuth'])
ax.dist = G.plotting['distance']
except KeyError:
pass
if G.coords.ndim != 1 and colorbar:
plt.colorbar(sc, ax=ax)
if indices:
for node in range(G.N):
ax.text(
*tuple(G.coords[node]), # accomodate 2D and 3D
s=node,
color='white',
horizontalalignment='center',
verticalalignment='center')
def _qtg_plot_signal(G, signal, edges, vertex_size, limits, title):
qtg, gl, QtGui = _import_qtg()
if G.coords.shape[1] == 2:
widget = qtg.GraphicsLayoutWidget()
view = widget.addViewBox()
elif G.coords.shape[1] == 3:
if not QtGui.QApplication.instance():
QtGui.QApplication([]) # We want only one application.
widget = gl.GLViewWidget()
widget.opts['distance'] = 10
if edges:
if G.coords.shape[1] == 2:
adj = _get_coords(G, edge_list=True)
pen = tuple(np.array(G.plotting['edge_color']) * 255)
g = qtg.GraphItem(pos=G.coords, adj=adj, symbolBrush=None, symbolPen=None, pen=pen)
view.addItem(g)
elif G.coords.shape[1] == 3:
x, y, z = _get_coords(G)
pos = np.stack((x, y, z), axis=1)
g = gl.GLLinePlotItem(pos=pos, mode='lines', color=G.plotting['edge_color'])
widget.addItem(g)
pos = [1, 8, 24, 40, 56, 64]
color = np.array([[0, 0, 143, 255], [0, 0, 255, 255], [0, 255, 255, 255], [255, 255, 0, 255], [255, 0, 0, 255],
[128, 0, 0, 255]])
cmap = qtg.ColorMap(pos, color)
signal = 1 + 63 * (signal - limits[0]) / limits[1] - limits[0]
if G.coords.shape[1] == 2:
gp = qtg.ScatterPlotItem(G.coords[:, 0], G.coords[:, 1], size=vertex_size / 10, brush=cmap.map(signal, 'qcolor'))
view.addItem(gp)
if G.coords.shape[1] == 3:
gp = gl.GLScatterPlotItem(pos=G.coords, size=vertex_size / 3, color=cmap.map(signal, 'float'))
widget.addItem(gp)
widget.setWindowTitle(title)
widget.show()
global _qtg_widgets
_qtg_widgets.append(widget)
def _plot_spectrogram(G, node_idx):
r"""Plot the graph's spectrogram.
Parameters
----------
node_idx : ndarray
Order to sort the nodes in the spectrogram.
By default, does not reorder the nodes.
Notes
-----
This function is only implemented for the pyqtgraph backend at the moment.
Examples
--------
>>> G = graphs.Ring(15)
>>> G.plot_spectrogram()
"""
from pygsp2 import features
qtg, _, _ = _import_qtg()
if not hasattr(G, 'spectr'):
features.compute_spectrogram(G)
M = G.spectr.shape[1]
spectr = G.spectr[node_idx, :] if node_idx is not None else G.spectr
spectr = np.ravel(spectr)
min_spec, max_spec = spectr.min(), spectr.max()
pos = np.array([0., 0.25, 0.5, 0.75, 1.])
color = [[20, 133, 212, 255], [53, 42, 135, 255], [48, 174, 170, 255], [210, 184, 87, 255], [249, 251, 14, 255]]
color = np.array(color, dtype=np.ubyte)
cmap = qtg.ColorMap(pos, color)
spectr = (spectr.astype(float) - min_spec) / (max_spec - min_spec)
widget = qtg.GraphicsLayoutWidget()
label = 'frequencies {}:{:.2f}:{:.2f}'.format(0, G.lmax / M, G.lmax)
v = widget.addPlot(labels={'bottom': 'nodes', 'left': label})
v.setAspectLocked()
spi = qtg.ScatterPlotItem(np.repeat(np.arange(G.N), M), np.ravel(np.tile(np.arange(M), (1, G.N))), pxMode=False, symbol='s',
size=1, brush=cmap.map(spectr, 'qcolor'))
v.addItem(spi)
widget.setWindowTitle('Spectrogram of {}'.format(G.__repr__(limit=4)))
widget.show()
global _qtg_widgets
_qtg_widgets.append(widget)
def _get_coords(G, edge_list=False):
sources, targets, _ = G.get_edge_list()
if edge_list:
return np.stack((sources, targets), axis=1)
coords = [np.stack((G.coords[sources, d], G.coords[targets, d]), axis=0) for d in range(G.coords.shape[1])]
if G.coords.shape[1] == 2:
return coords
elif G.coords.shape[1] == 3:
return [coord.reshape(-1, order='F') for coord in coords]