
Plotting Utilities

This page documents the two main plotting helpers provided in this module:

  • plot_spike_raster
  • plot_confusion_matrix

They produce figures that work well on dark themes and transparent backgrounds. The figures are useful both for debugging and inspecting network behaviour during training and inference, and for publishing results.


plot_spike_raster

plot_spike_raster visualizes spiking activity as a raster plot for a single sample of a batch, across one or more selected network layers.

Each subplot corresponds to one layer's spike tensor in the provided list, and each row inside a subplot corresponds to a neuron. A vertical tick marks a spike event at a given timestep.

This plot is particularly useful for assessing activity in hidden layers, and for visually inspecting prediction balance and firing patterns in the output layer.
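Conceptually, each tick in a subplot is one non-zero entry of spks[layer][sample_idx]. The pure-Python sketch below lists the (timestep, neuron) coordinates that would be drawn for one layer. It assumes that any non-zero value counts as a spike. spike_events is a hypothetical helper and is not part of nwavesdk. For a PyTorch tensor, torch.nonzero(spk[sample_idx]) returns the same coordinates as a (K, 2) tensor.

```python
from typing import List, Sequence, Tuple


def spike_events(layer: Sequence[Sequence[Sequence[float]]], sample_idx: int) -> List[Tuple[int, int]]:
    """Return the (timestep, neuron) pairs where sample `sample_idx` spiked.

    `layer` is indexed as layer[b][t][n], i.e. it has shape (B, T, N).
    """
    sample = layer[sample_idx]  # shape (T, N)
    return [
        (t, n)
        for t, frame in enumerate(sample)
        for n, value in enumerate(frame)
        if value != 0  # assumption: any non-zero entry is a spike
    ]
```

Each returned pair corresponds to one vertical tick. t gives the position on the time axis and n gives the neuron row.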

Import with:

from nwavesdk.utils import plot_spike_raster

Expected inputs

plot_spike_raster(
    spks,
    sample_idx,
    savepath=None,
)

Parameters

  • spks Sequence of spike tensors, one per layer. Each tensor must have shape (B, T, N), where:

    • B = batch size
    • T = number of timesteps
    • N = number of neurons in that layer

  • sample_idx Index of the sample within the batch to visualize.

  • savepath (optional) If provided, the figure is saved to this path (PNG with transparent background) and the figure is closed.


Output

  • If savepath is not provided:
    The raster plot is shown interactively and the function returns None.

  • If savepath is provided:
    The plot is saved to that path.


Spike rate plots for model debugging

NWAVE dataloaders can optionally return three elements per batch: (x, y, fn). Here fn is the filename tensor associated with the samples in the batch. It is extremely useful for inspecting results.

If some samples are consistently misclassified, use fn to retrieve and visualize the original raw data and the network input (the x tensor over a few channels). Then plot the spike raster for that sample index. This lets you compare inputs the network fails to learn with inputs it classifies confidently.

This comparison often reveals input patterns that are under-represented in training, temporal structures that do not elicit stable firing, or inconsistencies between data statistics and network dynamics.

This can help isolate whether performance issues stem from data complexity or from model capacity.
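The workflow above starts by locating the failing samples in a batch. The minimal sketch below assumes you already have per-sample predictions and labels for one batch as Python sequences (for example via .tolist()). It also assumes fn can be indexed per sample. misclassified is a hypothetical helper and is not part of nwavesdk.

```python
from typing import List, Sequence, Tuple, TypeVar

F = TypeVar("F")  # whatever type a single filename entry has


def misclassified(preds: Sequence[int], labels: Sequence[int], filenames: Sequence[F]) -> List[Tuple[int, F]]:
    """Return (sample_idx, filename) for every sample in the batch whose prediction is wrong."""
    return [
        (idx, fn)
        for idx, (p, y, fn) in enumerate(zip(preds, labels, filenames))
        if p != y
    ]
```

You can pass each returned index as sample_idx to plot_spike_raster for the same batch. Use the paired filename to load the corresponding raw data.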


Common errors

You may encounter errors if:

  • spks[i][sample_idx] is out of bounds: sample_idx must be strictly lower than the batch size B
  • Spike tensors do not have 3 dimensions (B, T, N)
  • Spike tensors are not PyTorch tensors
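You can catch the shape-related errors above before plotting. The minimal sketch below assumes you pass the per-layer shapes, e.g. [tuple(t.shape) for t in spks]. find_raster_input_problems is a hypothetical helper. It does not check the tensor type, because that would require importing torch.

```python
from typing import List, Sequence, Tuple


def find_raster_input_problems(shapes: Sequence[Tuple[int, ...]], sample_idx: int) -> List[str]:
    """Describe every shape problem; an empty list means the inputs look valid."""
    problems = []
    for i, shape in enumerate(shapes):
        if len(shape) != 3:
            problems.append(f"layer {i}: expected 3 dims (B, T, N), got shape {shape}")
            continue  # the batch size is not meaningful without the expected layout
        batch_size = shape[0]
        if sample_idx >= batch_size:
            problems.append(
                f"layer {i}: sample_idx={sample_idx} must be strictly lower than B={batch_size}"
            )
    return problems
```

Returning a list instead of raising on the first problem reports every faulty layer at once.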

Example

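import torch
from nwavesdk.utils import plot_spike_raster

# Random binary spikes stand in for real activity: one (B, T, N) tensor per layer
spike_layers = [
    (torch.rand(8, 100, 64) < 0.05).float(),  # hidden layer: B=8, T=100, N1=64
    (torch.rand(8, 100, 10) < 0.05).float(),  # output layer: N2=10
]
plot_spike_raster(spks=spike_layers, sample_idx=0, savepath="raster.png")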

plot_confusion_matrix

plot_confusion_matrix runs inference on a classification model and plots the resulting confusion matrix.

For each true class, the matrix shows the percentage of its samples assigned to each predicted class.

The function is end-to-end. It switches the model to eval() mode and collects predictions over the given dataloader using the selected classification method. It then plots the confusion matrix and restores the model's original training state.
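The eval-then-restore behaviour described above is a common PyTorch pattern. Below is a minimal sketch written as a context manager. It relies only on the standard torch.nn.Module attribute training and the methods eval() and train(mode). temporarily_eval is a hypothetical name and is not part of nwavesdk.

```python
from contextlib import contextmanager
from typing import Iterator, Protocol


class TrainableModule(Protocol):
    """Structural type matching the torch.nn.Module members used below."""

    training: bool

    def eval(self) -> object: ...

    def train(self, mode: bool = True) -> object: ...


@contextmanager
def temporarily_eval(model: TrainableModule) -> Iterator[TrainableModule]:
    """Put `model` in eval mode for the duration of the block, then restore its mode."""
    was_training = model.training  # remember the caller's mode
    model.eval()
    try:
        yield model
    finally:
        model.train(was_training)  # restore even if inference raised
```

With a real model you would usually also disable gradient tracking during the loop over the dataloader. For example, wrap the loop in `with temporarily_eval(model), torch.no_grad():`.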

Import with:

from nwavesdk.utils import plot_confusion_matrix

Expected inputs

plot_confusion_matrix(
    model,
    classification_method,
    dataloader,
    index_last_spike,
    title=None,
    savepath=None,
)

Parameters

  • model A torch.nn.Module used for inference.

  • classification_method One of "population" or "spike_sum":

    • "population" – population voting over output neurons
    • "spike_sum" – class chosen by summed spikes over time
  • dataloader A torch.utils.data.DataLoader built from NWaveClassificationDataset.

  • index_last_spike Index used to select the spike tensor from the model output when the model returns a tuple or list. The correct value depends on how the network is designed and on the order in which it returns spikes and, optionally, membrane potentials.

  • title (optional) Title displayed on the plot.

  • savepath (optional) If provided, saves the figure and closes it.
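The pure-Python sketch below illustrates index_last_spike and the two classification methods for a single sample. It is not the nwavesdk implementation, and select_spikes and decode are hypothetical helpers. The neuron-to-class mapping is an assumption: one output neuron per class for "spike_sum", and equal contiguous groups of output neurons for "population". With a PyTorch (T, O) tensor, the per-neuron counts would come from spikes.sum(dim=0).

```python
from typing import Any, Sequence


def select_spikes(output: Any, index_last_spike: int) -> Any:
    """Pick the spike tensor from a model output.

    If the model returns a tuple/list such as (spikes, membranes), index_last_spike
    selects the element; any other output is assumed to be the spike tensor itself.
    """
    if not isinstance(output, (tuple, list)):
        return output
    if not -len(output) <= index_last_spike < len(output):
        raise IndexError(
            f"index_last_spike={index_last_spike} is out of range for {len(output)} outputs"
        )
    return output[index_last_spike]


def decode(spikes: Sequence[Sequence[float]], method: str, num_classes: int) -> int:
    """Predict a class from one sample's output spikes, indexed spikes[t][o] (shape (T, O))."""
    # Total spike count of each output neuron over all timesteps.
    counts = [sum(neuron) for neuron in zip(*spikes)]
    if method == "spike_sum":
        # Assumption: one output neuron per class (O == num_classes).
        scores = counts
    elif method == "population":
        # Assumption: the O neurons form num_classes equal, contiguous groups.
        group = len(counts) // num_classes
        scores = [sum(counts[c * group:(c + 1) * group]) for c in range(num_classes)]
    else:
        raise ValueError(f"unknown classification_method: {method!r}")
    # argmax; on ties, max() keeps the first (lowest) class index
    return max(range(len(scores)), key=scores.__getitem__)
```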


Output

Returns (fig, ax, cm), where fig is the matplotlib.figure.Figure, ax is the main Axes, and cm is the raw (count-based) confusion matrix of shape (C, C), with C the number of classes.

The plotted matrix is row-normalized (percentages per true class).
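The pure-Python sketch below shows how the raw count matrix cm relates to the row-normalized percentages in the plot. It is not the nwavesdk code. confusion_counts, row_normalize and accuracy are hypothetical helpers, and the sketch assumes integer class labels in [0, C). If cm is a NumPy array or a tensor, convert it with cm.tolist() first.

```python
from typing import List, Sequence


def confusion_counts(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> List[List[int]]:
    """cm[i][j] = number of samples of true class i predicted as class j (shape (C, C))."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def row_normalize(cm: Sequence[Sequence[int]]) -> List[List[float]]:
    """Convert counts to percentages per true class (row); empty rows stay at 0."""
    normalized = []
    for row in cm:
        total = sum(row)
        normalized.append([100.0 * v / total if total else 0.0 for v in row])
    return normalized


def accuracy(cm: Sequence[Sequence[int]]) -> float:
    """Overall accuracy: correct predictions (the diagonal) over all samples."""
    correct = sum(cm[i][i] for i in range(len(cm)))
    return correct / sum(sum(row) for row in cm)
```

The diagonal of the row-normalized matrix is the per-class recall. The overall accuracy has to be computed from the counts, because classes may have different numbers of samples.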


Common errors

Errors are raised if:

  • dataloader is not a DataLoader
  • The dataset is not NWaveClassificationDataset
  • classification_method is invalid
  • Model outputs do not have shape (B, T, O), where O is the number of output neurons
  • index_last_spike is out of range

Minimal example

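from nwavesdk.utils import plot_confusion_matrix

# model: a trained torch.nn.Module; test_loader: a DataLoader over an NWaveClassificationDataset
fig, ax, cm = plot_confusion_matrix(
    model=model,
    classification_method="population",
    dataloader=test_loader,
    index_last_spike=-1,  # selects the last element of a tuple/list model output
    title="Test Set Confusion Matrix",
    savepath="confusion_matrix.png",
)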

Notes

Tip

Both plots are optimized for transparent backgrounds and can be safely embedded in slides or papers with dark themes.

Warning

plot_confusion_matrix performs a full forward pass over every batch in the dataloader, so it can be slow on large datasets.
Avoid passing very large datasets unless necessary.