
Network-level Losses

These losses are intended to improve training stability, activity regulation, and classification behavior of spiking neural networks. They do not encode hardware constraints and can be used independently of the target platform.


Firing-rate Target MSE Loss

Function: firing_rate_target_mse_loss

This loss penalizes deviations of the average firing rate of each layer from a desired target value.

Import with:

from nwavesdk.loss import firing_rate_target_mse_loss

Description

For each layer, the loss computes:

  1. The mean firing rate across time and neurons
  2. The squared error with respect to a target firing rate
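  3. Scaling by a per-layer multiplier (applied to the deviation before squaring, so the squared error is effectively weighted by the square of the multiplier)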

The final loss is the sum over all layers.

This loss is particularly useful to:

  • Stabilize training
  • Prevent pathological regimes (silent or saturated layers)
  • Enforce biologically and/or hardware-inspired firing-rate budgets

Mathematical Form

For layer \(i\):

\[ L_i = \left( (r_i - o_i) \cdot m_i \right)^2 \]

Where:

  • \(r_i\) is the mean firing rate of layer \(i\)
  • \(o_i\) is the target offset (the desired firing rate for the \(i\)-th layer)
  • \(m_i\) is a per-layer scaling multiplier

The total loss is:

\[ L = \sum_i L_i \]
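As a purely illustrative example with made-up numbers, take \(r_i\) as the fraction of entries in the layer's spike tensor that are spikes. Suppose layer 1 has \(B \cdot T \cdot N = 8\) entries, 2 of which are spikes, with \(o_1 = 0.1\) and \(m_1 = 2\), and layer 2 is completely silent with \(o_2 = 0.2\) and \(m_2 = 1\):

\[ r_1 = \tfrac{2}{8} = 0.25, \qquad L_1 = \left( (0.25 - 0.1) \cdot 2 \right)^2 = 0.09 \]

\[ r_2 = 0, \qquad L_2 = \left( (0 - 0.2) \cdot 1 \right)^2 = 0.04 \]

\[ L = L_1 + L_2 = 0.13 \]

The silent layer contributes a nonzero penalty, which is what pushes it out of the silent regime during training.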

Arguments

  • spikes_list: Sequence of spike tensors, one per layer, each typically of shape [B, T, N] (batch, time, neurons)
  • offsets: Target firing rates (one per layer)
  • multipliers: Scaling coefficients (one per layer)

Returns

A scalar tensor representing the total penalty.
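The computation described above can be sketched without any dependencies, using nested Python lists in place of tensors. This is not nwavesdk's implementation: the helper names `mean_firing_rate` and `firing_rate_target_mse_sketch` are hypothetical, and averaging over the batch dimension as well as time and neurons is an assumption. With PyTorch tensors, the per-layer rate would simply be `spikes.float().mean()`.

```python
from typing import Sequence

# Nested lists stand in for tensors: spikes[b][t][n] is 0 or 1.
SpikeTensor = Sequence[Sequence[Sequence[float]]]


def mean_firing_rate(spikes: SpikeTensor) -> float:
    """Average spikes per neuron per time step over batch, time and neurons (assumed)."""
    total, count = 0.0, 0
    for sample in spikes:
        for step in sample:
            total += sum(step)
            count += len(step)
    return total / count


def firing_rate_target_mse_sketch(
    spikes_list: Sequence[SpikeTensor],
    offsets: Sequence[float],
    multipliers: Sequence[float],
) -> float:
    """Hypothetical reference: sum over layers of ((r_i - o_i) * m_i) ** 2."""
    if not len(spikes_list) == len(offsets) == len(multipliers):
        raise ValueError("expected one offset and one multiplier per layer")
    loss = 0.0
    for spikes, o_i, m_i in zip(spikes_list, offsets, multipliers):
        r_i = mean_firing_rate(spikes)
        # The multiplier scales the deviation before squaring
        loss += ((r_i - o_i) * m_i) ** 2
    return loss


# Layer 1: B=1, T=4, N=2 with 2 spikes -> r_1 = 0.25
layer1 = [[[1, 0], [0, 0], [0, 1], [0, 0]]]
# Layer 2: B=1, T=2, N=2, completely silent -> r_2 = 0.0
layer2 = [[[0, 0], [0, 0]]]

loss = firing_rate_target_mse_sketch([layer1, layer2], offsets=[0.1, 0.2], multipliers=[2.0, 1.0])
print(round(loss, 6))  # 0.13
```

Because layers are combined by a plain sum, a larger multiplier makes the corresponding layer's target dominate the total penalty.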


Balanced Population MSE Loss

Function: balanced_population_mse

This loss enforces class-specific population activity in classification tasks.

Import with:

from nwavesdk.loss import balanced_population_mse

Description

Output neurons are divided into fixed-size populations, one per class. For each sample:

  • Neurons belonging to the correct class are encouraged to fire at a high total rate
  • Neurons belonging to incorrect classes are encouraged to fire at a low total rate

The loss computes a mean squared error between:

  • The sum of spikes over time for each neuron
  • A target spike count derived from the desired firing rate

Mathematical Intuition

Let \( S_{b,t,o} \in \{0,1\} \) be the spike output of neuron \( o \) at time \( t \) for sample \( b \).

The activity of each output neuron is computed as the sum of spikes over time:

\[ K_{b,o} = \sum_{t} S_{b,t,o} \]

Each neuron is assigned a target spike count depending on whether it belongs to the correct class:

\[ K^{\ast}_{b,o} = \begin{cases} T \cdot r_{\text{correct}} & \text{correct class} \\ T \cdot r_{\text{incorrect}} & \text{incorrect classes} \end{cases} \]

The loss is a standard mean squared error:

\[ \mathcal{L} = \mathrm{MSE}(K, K^{\ast}) \]
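For illustration only (all numbers are made up), take \(T = 10\), two classes with one output neuron each, \(r_{\text{correct}} = 0.8\), \(r_{\text{incorrect}} = 0.1\), and a single sample of class 0 whose two neurons emit 6 and 2 spikes. Assuming the MSE averages over all batch and output-neuron entries:

\[ K^{\ast} = (10 \cdot 0.8,\; 10 \cdot 0.1) = (8,\; 1), \qquad K = (6,\; 2) \]

\[ \mathcal{L} = \frac{(6 - 8)^2 + (2 - 1)^2}{2} = \frac{4 + 1}{2} = 2.5 \]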

Key Difference from snntorch Population Losses

Unlike some population-based losses (e.g. in snntorch) that regulate mean activity, this loss operates on the sum of spikes over time.

This distinction is subtle but important:

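  • A loss on the population's mean activity can be satisfied by a small subset of highly active neurons
  • The remaining neurons receive no pressure to fire, which often leaves them dead or inactive
  • Giving every neuron its own target spike count distributes activity more evenly across the population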

As a result, this loss is:

  • More robust against neuron collapse
  • Better suited for long training runs
  • More compatible with hardware constraints that assume active populations
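To make the collapse argument concrete, here is a toy comparison using two hypothetical helpers written for this page (neither is an snntorch nor an nwavesdk function). `population_mean_loss` penalizes only the population's mean firing rate, a simplified stand-in for mean-based losses, while `per_neuron_count_loss` gives each neuron its own target count of \(T \cdot r\), as in the formulation above.

```python
from typing import List


def population_mean_loss(counts: List[int], T: int, rate: float) -> float:
    """Squared error between the population's mean firing rate and the target rate."""
    mean_rate = sum(counts) / (len(counts) * T)
    return (mean_rate - rate) ** 2


def per_neuron_count_loss(counts: List[int], T: int, rate: float) -> float:
    """MSE between each neuron's spike count and the target count T * rate."""
    target = T * rate
    return sum((k - target) ** 2 for k in counts) / len(counts)


T, rate = 10, 0.5
healthy = [5, 5, 5, 5]      # every neuron fires 5 times in 10 steps
collapsed = [10, 10, 0, 0]  # two neurons saturate, two are silent

for name, counts in [("healthy", healthy), ("collapsed", collapsed)]:
    print(name, population_mean_loss(counts, T, rate), per_neuron_count_loss(counts, T, rate))
# healthy 0.0 0.0
# collapsed 0.0 25.0
```

Both populations look identical to the mean-based loss, whereas the per-neuron loss singles out the collapsed one.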

Arguments

  • spk: Spike tensor of shape [batch, time, output_neurons]
  • targets: Target class labels of shape [batch]
  • num_classes: Number of output classes
  • neurons_per_class: Neurons assigned to each class (population per class)
  • correct_rate: Desired firing rate for the correct class
  • incorrect_rate: Desired firing rate for incorrect classes

Returns

A scalar MSE loss over all output neurons.
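A dependency-free reference sketch of the computation described above follows. It is not nwavesdk's implementation, and the name `balanced_population_mse_sketch` is hypothetical. Two assumptions are baked in: neurons are assigned to classes in contiguous blocks (neuron \(o\) belongs to class `o // neurons_per_class`), and the MSE is averaged over every batch and output-neuron entry. In PyTorch the same computation would typically be vectorized with `spk.sum(dim=1)` and `torch.nn.functional.mse_loss`.

```python
from typing import Sequence


def balanced_population_mse_sketch(
    spk: Sequence[Sequence[Sequence[float]]],  # [batch][time][output_neurons], 0/1 spikes
    targets: Sequence[int],                    # [batch] class labels
    num_classes: int,
    neurons_per_class: int,
    correct_rate: float,
    incorrect_rate: float,
) -> float:
    n_out = num_classes * neurons_per_class
    if len(spk[0][0]) != n_out:
        raise ValueError("output size must equal num_classes * neurons_per_class")
    batch, T = len(spk), len(spk[0])
    sq_err = 0.0
    for b in range(batch):
        for o in range(n_out):
            # K_{b,o}: total spike count of output neuron o over all time steps
            k = sum(spk[b][t][o] for t in range(T))
            # Assumption: contiguous populations, so neuron o belongs to class o // neurons_per_class
            in_correct_class = o // neurons_per_class == targets[b]
            # K*_{b,o}: target count T * rate for the neuron's role in this sample
            k_star = T * (correct_rate if in_correct_class else incorrect_rate)
            sq_err += (k - k_star) ** 2
    # Assumed mean reduction over all batch x output-neuron entries
    return sq_err / (batch * n_out)


# One sample of class 0, T = 10, one neuron per class.
# Neuron 0 (correct class) fires 6 times, neuron 1 (incorrect class) fires twice.
spk = [[[1 if t < 6 else 0, 1 if t < 2 else 0] for t in range(10)]]
loss = balanced_population_mse_sketch(
    spk, [0], num_classes=2, neurons_per_class=1, correct_rate=0.8, incorrect_rate=0.1
)
print(loss)  # 2.5
```

If the library lays out populations differently (for example, interleaved rather than contiguous), only the class-membership line needs to change.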