Introduction to spectral graph wavelets
This tutorial shows how to construct a wavelet frame, a kind of filter bank, and apply it to a graph signal. It walks you through computing the wavelet coefficients of a graph signal, visualizing filters in the vertex domain, and using the wavelets to estimate the curvature of a 3D shape.
As usual, we first have to import some packages.
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from pygsp2 import graphs, filters, plotting, utils
Then we can load a graph. The graph we’ll use is a nearest-neighbor graph of a point cloud of the Stanford bunny. It will allow us to get interesting visual results using wavelets.
>>> G = graphs.Bunny()
Note
At this stage we could compute the Fourier basis using
pygsp2.graphs.Graph.compute_fourier_basis(), but this would take some
time, and can be avoided with a Chebyshev polynomial approximation to
graph filtering. See the documentation of the
pygsp2.filters.Filter.filter() filtering function and
[HVG11] for details on how it is done.
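To make the idea behind the note concrete, here is a minimal NumPy sketch of Chebyshev filtering on a small path graph. It is an illustration under stated assumptions, not the code used by pygsp2: the helper names (path_laplacian, cheby_coefficients, cheby_filter), the kernel, and the polynomial order are all hypothetical choices. The eigendecomposition appears only to produce a reference result; the approximation itself needs nothing but matrix-vector products with the Laplacian.

```python
import numpy as np

def path_laplacian(n):
    """Combinatorial Laplacian L = D - W of an unweighted path graph."""
    W = np.zeros((n, n))
    idx = np.arange(n - 1)
    W[idx, idx + 1] = 1
    W[idx + 1, idx] = 1
    return np.diag(W.sum(axis=1)) - W

def cheby_coefficients(kernel, lmax, order, n_points=200):
    """Chebyshev coefficients of `kernel` on [0, lmax] (Chebyshev-Gauss quadrature)."""
    theta = np.pi * (np.arange(n_points) + 0.5) / n_points
    # Map the Chebyshev nodes cos(theta) from [-1, 1] to [0, lmax].
    samples = kernel(lmax / 2 * (np.cos(theta) + 1))
    k = np.arange(order + 1)[:, None]
    return 2 / n_points * (np.cos(k * theta) @ samples)

def cheby_filter(L, signal, coeffs, lmax):
    """Apply c_0 / 2 + sum_k c_k T_k(L_scaled) to `signal` without any eigenvectors."""
    n = L.shape[0]
    L_scaled = 2 / lmax * L - np.eye(n)  # spectrum mapped to [-1, 1]
    t_prev = signal
    t_curr = L_scaled @ signal
    out = 0.5 * coeffs[0] * t_prev + coeffs[1] * t_curr
    for c in coeffs[2:]:
        # Three-term recurrence: T_k = 2 x T_{k-1} - T_{k-2}.
        t_prev, t_curr = t_curr, 2 * (L_scaled @ t_curr) - t_prev
        out = out + c * t_curr
    return out

L = path_laplacian(30)
eigenvalues, U = np.linalg.eigh(L)  # only needed for the reference result
lmax = eigenvalues[-1]
kernel = lambda x: np.exp(-2 * x)   # a smooth low-pass kernel (assumption)

signal = np.zeros(30)
signal[10] = 1

exact = U @ (kernel(eigenvalues) * (U.T @ signal))
approx = cheby_filter(L, signal, cheby_coefficients(kernel, lmax, order=30), lmax)
max_error = np.max(np.abs(exact - approx))
```

On a large graph the Laplacian would be stored as a sparse matrix and an upper bound on the largest eigenvalue would replace the exact lmax, so the cost is one sparse product per polynomial order.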
Simple filtering: heat diffusion
Before tackling wavelets, let’s observe the effect of a single filter on a graph signal. We first design a few heat kernel filters, each with a different scale.
>>> taus = [10, 25, 50]
>>> g = filters.Heat(G, taus)
Let’s create a signal as a Kronecker delta located on one vertex, e.g. vertex 20. That signal is our heat source.
>>> s = np.zeros(G.N)
>>> DELTA = 20
>>> s[DELTA] = 1
We can now simulate heat diffusion by filtering our signal s with each of our heat kernels.
>>> s = g.filter(s, method='chebyshev')
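What this filtering step computes can be sketched with plain NumPy on a toy ring graph. The sketch below assumes the heat kernel \(g(\lambda) = e^{-\tau \lambda}\) and an exact eigendecomposition; the scaling convention used by the library may differ (for instance a normalization by the largest eigenvalue), and the graph and scales are hypothetical. It shows that larger scales spread the heat further while the total amount of heat is preserved.

```python
import numpy as np

n = 40
# Ring graph: each vertex is linked to its two neighbors.
idx = np.arange(n)
W = np.zeros((n, n))
W[idx, (idx + 1) % n] = 1
W[(idx + 1) % n, idx] = 1
L = np.diag(W.sum(axis=1)) - W

eigenvalues, U = np.linalg.eigh(L)

# Heat source: a Kronecker delta on vertex 20.
delta = np.zeros(n)
delta[20] = 1

taus = [1, 5, 25]
# One column per scale, i.e. an array of shape (n_vertices, n_filters).
diffused = np.stack(
    [U @ (np.exp(-tau * eigenvalues) * (U.T @ delta)) for tau in taus], axis=1)

peaks = diffused.max(axis=0)   # height of the heat peak at each scale
totals = diffused.sum(axis=0)  # total heat at each scale
```

The peak stays on the source vertex and gets lower as the scale grows, while each column still sums to one because the constant vector is an eigenvector of the Laplacian with eigenvalue zero.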
Finally, we plot the filtered signals to show heat diffusion at different scales.
>>> fig = plt.figure(figsize=(10, 3))
>>> for i in range(g.Nf):
...     ax = fig.add_subplot(1, g.Nf, i+1, projection='3d')
...     title = r'Heat diffusion, $\tau={}$'.format(taus[i])
...     _ = G.plot(s[:, i], colorbar=False, title=title, ax=ax)
...     ax.set_axis_off()
>>> fig.tight_layout()
Note
The pygsp2.filters.Filter.localize() method can be used to visualize a
filter in the vertex domain instead of doing it manually.
Visualizing wavelet atoms
Let’s now replace the Heat filter with a filter bank of wavelets. We can create a
filter bank using one of the predefined filters, such as
pygsp2.filters.MexicanHat to design a set of Mexican hat wavelets.
>>> g = filters.MexicanHat(G, Nf=6) # Nf = 6 filters in the filter bank.
Then plot the frequency response of those filters.
>>> fig, ax = plt.subplots(figsize=(10, 5))
>>> _ = g.plot(title='Filter bank of Mexican hat wavelets', ax=ax)
Note
We can see that the wavelet atoms are stacked on the low frequency part of
the spectrum. A better coverage could be obtained by adapting the filter
bank with pygsp2.filters.WarpedTranslates or by using another
filter bank like pygsp2.filters.Itersine.
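The coverage mentioned in the note can be checked numerically. The sketch below builds a Mexican-hat-like bank from a band-pass kernel \(x e^{-x}\) dilated at log-spaced scales, plus one low-pass kernel. The function name mexican_hat_bank and all constants are assumptions made for illustration and are not claimed to match the parametrization of pygsp2.filters.MexicanHat. Summing the squared responses over the spectrum gives the frame bounds A and B, and the peak locations show where the band-pass filters concentrate.

```python
import numpy as np

def mexican_hat_bank(lmax, n_filters=6):
    """Return one low-pass kernel followed by n_filters - 1 dilated band-pass kernels.

    Hypothetical parametrization, chosen for illustration only.
    """
    band_pass = lambda x: x * np.exp(-x)
    lmin = lmax / 20
    # Scales log-spaced from 2 / lmin (coarse) down to 1 / lmax (fine).
    scales = np.logspace(np.log10(2 / lmin), np.log10(1 / lmax), n_filters - 1)
    low_pass = lambda x: 1.2 * np.exp(-1) * np.exp(-(x / (0.4 * lmin)) ** 4)
    kernels = [low_pass]
    for t in scales:
        # Bind t as a default argument so each kernel keeps its own scale.
        kernels.append(lambda x, t=t: band_pass(t * x))
    return kernels

lmax = 2.0
x = np.linspace(0, lmax, 500)
kernels = mexican_hat_bank(lmax)
response = np.stack([g(x) for g in kernels], axis=1)  # shape (500, 6)

# Frame bounds: extreme values of the summed squared responses.
energy = (response ** 2).sum(axis=1)
A, B = energy.min(), energy.max()

# Frequency at which each band-pass filter peaks.
peaks = x[response[:, 1:].argmax(axis=0)]
```

With these assumed constants, four of the five band-pass filters peak in the lower half of the spectrum, which is the stacking effect described above. A ratio B / A close to one would indicate an almost tight frame.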
We can visualize the atoms as we did with the heat kernel, by filtering a Kronecker delta placed at one specific vertex.
>>> s = g.localize(DELTA)
>>>
>>> fig = plt.figure(figsize=(10, 2.5))
>>> for i in range(3):
...     ax = fig.add_subplot(1, 3, i+1, projection='3d')
...     _ = G.plot(s[:, i], title='Wavelet {}'.format(i+1), ax=ax)
...     ax.set_axis_off()
>>> fig.tight_layout()
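Localizing a filter at a vertex amounts to reading one column of the operator \(g(L)\). The NumPy sketch below illustrates this on a toy ring graph with a single band-pass kernel; both are hypothetical choices, and the library may additionally apply a normalization constant that is not reproduced here.

```python
import numpy as np

n = 40
idx = np.arange(n)
W = np.zeros((n, n))
W[idx, (idx + 1) % n] = 1
W[(idx + 1) % n, idx] = 1
L = np.diag(W.sum(axis=1)) - W
eigenvalues, U = np.linalg.eigh(L)

# Band-pass kernel with g(0) = 0, so every atom has zero mean.
kernel = lambda x: 5 * x * np.exp(-5 * x)
g_of_L = U @ np.diag(kernel(eigenvalues)) @ U.T  # dense operator g(L)

# Filtering a Kronecker delta placed on `center` gives the atom at that vertex.
center = 20
delta = np.zeros(n)
delta[center] = 1
atom = g_of_L @ delta
```

The atom equals column `center` of g_of_L, sums to zero because the kernel vanishes at the zero eigenvalue, and has its largest magnitude on the vertex where it is localized.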
Curvature estimation
As a last and more applied example, let us try to estimate the curvature of the underlying 3D model by only using spectral filtering on the nearest-neighbor graph formed by its point cloud.
A simple way to accomplish that is to use the coordinate map \([x, y, z]\) and filter it using the wavelets defined above. Doing so gives us a 3-dimensional signal \([g_i(L)x, g_i(L)y, g_i(L)z], \ i \in \{0, \ldots, N_f - 1\}\), which describes variation along the 3 coordinates.
>>> s = G.coords
>>> s = g.filter(s)
The curvature is then estimated by taking the \(\ell_1\) or \(\ell_2\) norm across the three coordinates, which leaves one value per vertex and per filter.
>>> s = np.linalg.norm(s, ord=2, axis=1)
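The same two steps can be traced on a toy shape where the answer is known. The sketch below is an illustration under assumptions: it uses a ring graph over points on the perimeter of a unit square and two hand-picked kernels that vanish at zero, instead of the nearest-neighbor graph and the Mexican hat bank. The edges of the square are flat and its corners are sharply curved, so the estimate should light up at the four corners only.

```python
import numpy as np

side = 10        # vertices per side of the square
n = 4 * side
h = 1.0 / side   # spacing between consecutive vertices

# Walk along the perimeter of the unit square, embedded in the z = 0 plane.
steps = np.repeat(np.array([[1, 0, 0], [0, 1, 0], [-1, 0, 0], [0, -1, 0]]), side, axis=0)
coords = np.cumsum(h * steps, axis=0) - h * steps  # vertex k = sum of the first k steps

# Ring graph linking consecutive points along the perimeter.
idx = np.arange(n)
W = np.zeros((n, n))
W[idx, (idx + 1) % n] = 1
W[(idx + 1) % n, idx] = 1
L = np.diag(W.sum(axis=1)) - W
eigenvalues, U = np.linalg.eigh(L)

# Two kernels vanishing at 0, so constant offsets of the coordinates are removed.
kernels = [lambda x: x, lambda x: x * np.exp(-x)]

# Filter each coordinate with each kernel: shape (n_vertices, 3, n_filters).
filtered = np.stack(
    [U @ (g(eigenvalues)[:, None] * (U.T @ coords)) for g in kernels], axis=2)

# Norm across the 3 coordinates: shape (n_vertices, n_filters).
curvature = np.linalg.norm(filtered, ord=2, axis=1)

corners = [0, side, 2 * side, 3 * side]
```

For the first kernel the filter is the Laplacian itself, so the estimate vanishes on the straight edges and equals \(\sqrt{2} h\) at each corner. The second kernel smooths that response along the curve but still peaks at the corners.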
Let’s finally plot the result to observe that we indeed have a measure of the curvature at different scales.
>>> fig = plt.figure(figsize=(10, 7))
>>> for i in range(4):
...     ax = fig.add_subplot(2, 2, i+1, projection='3d')
...     title = 'Curvature estimation (scale {})'.format(i+1)
...     _ = G.plot(s[:, i], title=title, ax=ax)
...     ax.set_axis_off()
>>> fig.tight_layout()