Network-level Losses
These losses are intended to improve training stability, activity regulation, and classification behavior of spiking neural networks. They do not encode hardware constraints and can be used independently of the target platform.
Firing-rate Target MSE Loss
Function: firing_rate_target_mse_loss
This loss penalizes deviations of the average firing rate of each layer from a desired target value.
Import with:
from nwavesdk.loss import firing_rate_target_mse_loss
Description
For each layer, the loss computes:
- The mean firing rate across time and neurons
- The squared error with respect to a target firing rate
- A per-layer scaling factor
The final loss is the sum over all layers.
This loss is particularly useful to:
- Stabilize training
- Prevent pathological regimes (silent or saturated layers)
- Enforce biologically and/or hardware-inspired firing-rate budgets
Mathematical Form
For layer \(i\), the per-layer penalty is:

\[
\mathcal{L}_i = m_i \, (r_i - o_i)^2
\]

Where:
- \(r_i\) is the mean firing rate of layer \(i\)
- \(o_i\) is the target offset (desired firing rate for the \(i\)-th layer)
- \(m_i\) is a scaling multiplier

The total loss is the sum over all layers:

\[
\mathcal{L} = \sum_i \mathcal{L}_i = \sum_i m_i \, (r_i - o_i)^2
\]
Arguments
spikes_list: Sequence of spike tensors, one per layer; shape typically [B, T, N]
offsets: Target firing rates (one per layer)
multipliers: Scaling coefficients (one per layer)
Returns
A scalar tensor representing the total penalty.
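As an illustration, here is a minimal sketch of the computation described above, assuming PyTorch tensors of shape [B, T, N]; the actual nwavesdk implementation may differ in details:

```python
import torch

def firing_rate_target_mse_sketch(spikes_list, offsets, multipliers):
    # Sketch of the loss described above; the real nwavesdk code may differ.
    total = torch.tensor(0.0)
    for spikes, offset, mult in zip(spikes_list, offsets, multipliers):
        # Mean firing rate across batch, time, and neurons: [B, T, N] -> scalar
        rate = spikes.float().mean()
        # Per-layer scaled squared error against the target rate
        total = total + mult * (rate - offset) ** 2
    return total

# Example: two layers of random binary spikes, with per-layer targets
spikes_list = [torch.randint(0, 2, (8, 100, 64)),
               torch.randint(0, 2, (8, 100, 32))]
loss = firing_rate_target_mse_sketch(spikes_list,
                                     offsets=[0.1, 0.2],
                                     multipliers=[1.0, 1.0])
```

Note that a silent layer (all-zero spikes) with a nonzero offset still produces a nonzero penalty, which is what pushes the network out of the silent regime.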
Balanced Population MSE Loss
Function: balanced_population_mse
This loss enforces class-specific population activity in classification tasks.
Import with:
from nwavesdk.loss import balanced_population_mse
Description
Output neurons are divided into fixed-size populations, one per class. For each sample:
- Neurons belonging to the correct class are encouraged to fire at a high total rate
- Neurons belonging to incorrect classes are encouraged to fire at a low total rate
The loss computes a mean squared error between:
- The sum of spikes over time for each neuron
- A target spike count derived from the desired firing rate
Mathematical Intuition
Let \( S_{b,t,o} \in \{0,1\} \) be the spike output of neuron \( o \) at time \( t \) for sample \( b \).
The activity of each output neuron is computed as the sum of spikes over time:

\[
A_{b,o} = \sum_{t} S_{b,t,o}
\]

Each neuron is assigned a target spike count depending on whether it belongs to the correct class:

\[
\hat{A}_{b,o} =
\begin{cases}
r_{\text{correct}} \cdot T & \text{if neuron } o \text{ belongs to the class of sample } b \\
r_{\text{incorrect}} \cdot T & \text{otherwise}
\end{cases}
\]

where \(T\) is the number of time steps and \(r_{\text{correct}}\), \(r_{\text{incorrect}}\) are the desired firing rates.

The loss is a standard mean squared error over all samples and output neurons:

\[
\mathcal{L} = \frac{1}{B \cdot O} \sum_{b=1}^{B} \sum_{o=1}^{O} \left( A_{b,o} - \hat{A}_{b,o} \right)^2
\]
Key Difference from snntorch Population Losses
Unlike some population-based losses (e.g. in snntorch) that regulate mean activity,
this loss operates on the sum of spikes over time.
This distinction is subtle but important:
- Mean-based losses can be satisfied by a small subset of neurons
- This often leads to dead or inactive neurons
- Sum-based enforcement distributes activity more evenly across the population
As a result, this loss is:
- More robust against neuron collapse
- Better suited for long training runs
- More compatible with hardware constraints that assume active populations
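A small numeric example makes the mean-vs-sum distinction concrete (the patterns and targets here are made up for illustration):

```python
import torch

T, N = 10, 4                       # time steps, neurons in one population
target_count = 2.0                 # desired spikes per neuron over the window

# Pattern A: a single neuron carries all of the activity
a = torch.zeros(T, N)
a[:8, 0] = 1.0
# Pattern B: the same total activity, spread evenly (2 spikes per neuron)
b = torch.zeros(T, N)
for n in range(N):
    b[2 * n:2 * n + 2, n] = 1.0

# Both patterns have the same population-mean firing rate,
# so a mean-based loss cannot tell them apart:
assert a.mean() == b.mean()        # 8 spikes / 40 slots = 0.2

# A per-neuron sum-based MSE penalizes the collapsed pattern only:
mse_a = ((a.sum(dim=0) - target_count) ** 2).mean()   # large: one neuron dominates
mse_b = ((b.sum(dim=0) - target_count) ** 2).mean()   # zero: activity is balanced
```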
Arguments
spk: Spike tensor of shape [batch, time, output_neurons]
targets: Target class labels, shape [batch]
num_classes: Number of output classes
neurons_per_class: Neurons assigned to each class (population per class)
correct_rate: Desired firing rate for the correct class
incorrect_rate: Desired firing rate for incorrect classes
Returns
A scalar MSE loss over all output neurons.
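For reference, here is a minimal sketch consistent with the description above (shapes and target construction are assumptions; the actual nwavesdk implementation may differ):

```python
import torch
import torch.nn.functional as F

def balanced_population_mse_sketch(spk, targets, num_classes, neurons_per_class,
                                   correct_rate, incorrect_rate):
    # Sketch of the loss described above; the real nwavesdk code may differ.
    B, T, O = spk.shape
    assert O == num_classes * neurons_per_class
    counts = spk.float().sum(dim=1)                  # [B, O] spike count per neuron
    # Class index of each output neuron: [O]
    neuron_class = torch.arange(O) // neurons_per_class
    # True where neuron o belongs to the class of sample b: [B, O]
    is_correct = neuron_class.unsqueeze(0) == targets.unsqueeze(1)
    # Target spike count: desired rate times the number of time steps
    target_counts = torch.where(is_correct,
                                torch.tensor(correct_rate * T),
                                torch.tensor(incorrect_rate * T))
    return F.mse_loss(counts, target_counts)

# Example: batch of 1, 4 time steps, 2 classes with 2 neurons each
spk = torch.ones(1, 4, 4)
loss = balanced_population_mse_sketch(spk, torch.tensor([0]),
                                      num_classes=2, neurons_per_class=2,
                                      correct_rate=1.0, incorrect_rate=0.0)
```

In the example, all four neurons fire on every step, so the two neurons of the wrong class overshoot their zero target and contribute the entire penalty.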