Plotting Utilities
This page documents the two main plotting helpers provided in this module:
- plot_spike_raster
- plot_confusion_matrix
Both produce figures that read well on dark or transparent backgrounds. They are useful for debugging and inspecting network behaviour during training/inference, and for publishing results.
plot_spike_raster
plot_spike_raster visualizes spiking activity as a raster plot for a single sample of the batch, across one or more selected network layers.
Each subplot corresponds to one layer's spike tensor from the provided list, and each row inside a subplot corresponds to a neuron. A vertical tick marks a spike event at a given timestep.
This plot is particularly useful for assessing network activity in hidden layers and for visually inspecting prediction balance and patterns in the output layer.
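To make the layout concrete, here is a minimal sketch of how one raster subplot could be built from a (B, T, N) spike tensor. It is not the nwavesdk implementation: spike_events and draw_raster are hypothetical helpers, the code works on NumPy arrays (a PyTorch tensor can be converted with .detach().cpu().numpy()), and draw_raster expects a Matplotlib Axes created by the caller.

```python
import numpy as np


def spike_events(layer_spikes, sample_idx):
    """Hypothetical helper: spike times of every neuron for one sample.

    layer_spikes: array of shape (B, T, N); non-zero entries are spikes.
    Returns a list with one array of timesteps per neuron (one raster row each).
    """
    sample = np.asarray(layer_spikes)[sample_idx]  # shape (T, N)
    return [np.flatnonzero(sample[:, n]) for n in range(sample.shape[1])]


def draw_raster(ax, layer_spikes, sample_idx):
    """Hypothetical helper: draw one raster subplot on a Matplotlib Axes.

    Neuron n is drawn on row y = n; each vertical tick is one spike.
    """
    events = spike_events(layer_spikes, sample_idx)
    ax.eventplot(events, lineoffsets=np.arange(len(events)), linelengths=0.8)
    ax.set_xlabel("timestep")
    ax.set_ylabel("neuron")
```

A multi-layer figure would then call draw_raster once per layer, each on its own subplot.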
Import with:
from nwavesdk.utils import plot_spike_raster
Expected inputs
plot_spike_raster(
    spks,
    sample_idx,
    savepath=None,
)
Parameters
- spks: Sequence of spike tensors, one per layer. Each tensor must have shape (B, T, N), where:
  - B = batch size
  - T = number of timesteps
  - N = number of neurons in that layer
- sample_idx: Index of the sample within the batch to visualize.
- savepath (optional): If provided, the figure is saved to this path (PNG with transparent background) and then closed.
Output
- If savepath is not provided: the raster plot is shown interactively and the function returns None.
- If savepath is provided: the plot is saved to that path.
Spike rate plots for model debugging
NWAVE dataloaders can optionally return three elements per batch: (x, y, fn). Here fn contains the filename associated with each sample in the batch, which is extremely useful for inspecting results.
If some samples are consistently misclassified, use fn to retrieve and visualize the original raw data and the network input (the x tensor over some channels). Then plot the spike raster for that sample index to compare inputs the network fails to learn with inputs it classifies confidently.
This comparison often reveals input patterns that are under-represented in the training set, temporal structures that do not produce stable firing, or inconsistencies between the data statistics and the network dynamics.
It can help determine whether performance issues stem from data complexity or from model capacity.
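The debugging workflow above can be sketched as a small helper that collects misclassified samples together with their position in the batch and their filename. find_misclassified is hypothetical, not part of nwavesdk: it assumes you have already computed one prediction per sample for each (x, y, fn) batch and pass them in as (preds, targets, filenames) triples.

```python
def find_misclassified(batches):
    """Hypothetical helper: list misclassified samples with their filenames.

    batches: iterable of (preds, targets, filenames), one triple per batch,
    each a sequence with one entry per sample.
    """
    wrong = []
    for batch_idx, (preds, targets, filenames) in enumerate(batches):
        for sample_idx, (p, t, f) in enumerate(zip(preds, targets, filenames)):
            if int(p) != int(t):
                wrong.append({
                    "batch": batch_idx,
                    "sample_idx": sample_idx,  # value to pass to plot_spike_raster
                    "filename": f,             # locates the raw recording
                    "true": int(t),
                    "pred": int(p),
                })
    return wrong
```

With a non-shuffled DataLoader (shuffle=False), the batch and sample_idx fields make it easy to re-run that exact batch and pass its spike tensors to plot_spike_raster.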
Common errors
You may encounter errors if:
- spks[i][sample_idx] is out of bounds: note that sample_idx must be strictly lower than the batch size
- spike tensors do not have 3 dimensions (B, T, N)
- spike tensors are not PyTorch tensors
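Since these errors are easy to hit when handling several layers of different sizes, a small check before calling plot_spike_raster can give clearer messages. check_raster_inputs is a hypothetical helper, not an nwavesdk function. It duck-types on .shape, so it does not verify that the inputs are PyTorch tensors, and it rejects negative indices to match the "strictly lower than the batch size" rule.

```python
def check_raster_inputs(spks, sample_idx):
    """Hypothetical pre-flight check for plot_spike_raster inputs.

    Raises ValueError for tensors that are not 3-D (B, T, N) and IndexError
    when sample_idx is not a valid, non-negative index into the batch.
    """
    for i, layer in enumerate(spks):
        shape = getattr(layer, "shape", None)  # torch.Size and NumPy shapes are tuples
        if shape is None or len(shape) != 3:
            raise ValueError(f"spks[{i}] must have shape (B, T, N), got {shape}")
        batch_size = shape[0]
        if not 0 <= sample_idx < batch_size:
            raise IndexError(
                f"sample_idx={sample_idx} must be in [0, {batch_size}) for spks[{i}]"
            )
```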
Example
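from nwavesdk.utils import plot_spike_raster

# spike_layers: list of spike tensors [(B, T, N1), (B, T, N2), ...], one per layer
plot_spike_raster(
    spks=spike_layers,
    sample_idx=0,  # must be strictly lower than B
    savepath="raster.png",
)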
plot_confusion_matrix
plot_confusion_matrix runs inference on a classification model and produces the confusion matrix.
For each true class, the matrix shows the percentage of its samples assigned to each predicted class.
This function is end-to-end: it switches the model to eval() mode, collects predictions over the given dataloader (using the selected classification method), plots the confusion matrix, and restores the model's original training state.
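The save-and-restore step described above follows a common PyTorch pattern, sketched here as a context manager. eval_mode is a hypothetical name; the code only relies on the training attribute and the eval() / train(mode) methods that torch.nn.Module provides. It does not disable gradient tracking, so real inference code would combine it with torch.no_grad().

```python
from contextlib import contextmanager


@contextmanager
def eval_mode(model):
    """Hypothetical helper: temporarily put a torch.nn.Module-like model in
    eval mode, then restore its previous training state (even on errors)."""
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        model.train(was_training)
```

For example, `with eval_mode(model), torch.no_grad():` wraps the inference loop; the try/finally restores the original state even if inference raises an exception.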
Import with:
from nwavesdk.utils import plot_confusion_matrix
Expected inputs
plot_confusion_matrix(
    model,
    classification_method,
    dataloader,
    index_last_spike,
    title=None,
    savepath=None,
)
Parameters
- model: A torch.nn.Module used for inference.
- classification_method: One of "population" or "spike_sum":
  - "population": population voting over output neurons
  - "spike_sum": class chosen by summed spikes over time
- dataloader: A torch.utils.data.DataLoader built from NWaveClassificationDataset.
- index_last_spike: Index used to select the spike tensor from the model output when the model returns a tuple or list. It depends on how the network is designed and in which order it returns spikes and, optionally, membrane potentials.
- title (optional): Title displayed on the plot.
- savepath (optional): If provided, saves the figure and closes it.
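To illustrate the two classification_method options, here is a NumPy sketch operating on output spikes of shape (B, T, O). The spike_sum rule follows the description above: the class is the output neuron with the most spikes summed over time. For population, the exact grouping nwavesdk uses is not documented here, so the code assumes a hypothetical layout where the O output neurons are split into num_classes equal, contiguous groups and each group's total spike count is its vote. The function names and the num_classes parameter are illustrative only.

```python
import numpy as np


def predict_spike_sum(out_spikes):
    """out_spikes: (B, T, O). Class = output neuron with most spikes over time."""
    return np.asarray(out_spikes).sum(axis=1).argmax(axis=1)


def predict_population(out_spikes, num_classes):
    """Assumed population scheme: neurons [k*g, (k+1)*g) vote for class k,
    with g = O // num_classes; the group with the most spikes wins."""
    counts = np.asarray(out_spikes).sum(axis=1)  # (B, O) spikes per neuron
    batch, num_out = counts.shape
    if num_out % num_classes != 0:
        raise ValueError("O must be a multiple of num_classes")
    group = num_out // num_classes
    votes = counts.reshape(batch, num_classes, group).sum(axis=2)  # (B, C)
    return votes.argmax(axis=1)
```

In both cases np.argmax breaks ties in favour of the lowest class index.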
Output
Returns (fig, ax, cm), where fig is the matplotlib.figure.Figure, ax is the main Axes, and cm is the raw (count-based) confusion matrix with shape (C, C), C being the number of classes.
The plotted matrix is row-normalized (percentages per true class).
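The relation between the returned cm and the plotted percentages can be sketched in NumPy: cm[i, j] counts samples of true class i predicted as class j, and each row is divided by its total. The helper names are hypothetical. Rows with no samples are left at 0 instead of dividing by zero; this is one possible convention and may differ from the library's.

```python
import numpy as np


def confusion_counts(y_true, y_pred, num_classes):
    """Raw count matrix: cm[i, j] = samples with true class i predicted as j."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def row_normalize(cm):
    """Percentages per true class; empty rows stay at 0."""
    totals = cm.sum(axis=1, keepdims=True)  # (C, 1) samples per true class
    return np.divide(100.0 * cm, totals, out=np.zeros(cm.shape), where=totals > 0)
```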
Common errors
Errors are raised if:
- dataloader is not a DataLoader
- the dataset is not an NWaveClassificationDataset
- classification_method is invalid
- model outputs do not have shape (B, T, O)
- index_last_spike is out of range
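Selecting the spike tensor with index_last_spike, and the shape check behind the errors above, might look roughly like this. select_spikes is a hypothetical helper. It assumes that a model returning a bare tensor needs no indexing, which the source does not specify.

```python
def select_spikes(output, index_last_spike):
    """Hypothetical helper: pick the (B, T, O) spike tensor from a model output."""
    if isinstance(output, (tuple, list)):
        # e.g. (membranes, spikes): index_last_spike=-1 picks the spikes
        if not -len(output) <= index_last_spike < len(output):
            raise IndexError(
                f"index_last_spike={index_last_spike} out of range "
                f"for {len(output)} outputs"
            )
        output = output[index_last_spike]
    if len(output.shape) != 3:
        raise ValueError(f"expected spikes of shape (B, T, O), got {tuple(output.shape)}")
    return output
```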
Minimal example
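from nwavesdk.utils import plot_confusion_matrix

# test_loader: DataLoader built from an NWaveClassificationDataset
fig, ax, cm = plot_confusion_matrix(
    model=model,
    classification_method="population",
    dataloader=test_loader,
    index_last_spike=-1,  # use the last element of the model output as spikes
    title="Test Set Confusion Matrix",
    savepath="confusion_matrix.png",
)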
Notes
Tip
Both plots are optimized for transparent backgrounds and can be safely embedded in slides or papers with dark themes.
Warning
plot_confusion_matrix performs a full forward pass over the dataloader.
Avoid using very large datasets unless necessary.