Hyperparameter Optimization for Custom Hardware-Aware SNNs

This guide explains how to use nwavesdk.optim.hpo to train and tune custom spiking neural networks built with H1v1 or H1v2 layers.

The goal is to make the workflow practical and flexible:

  • define your own hardware-aware architecture,
  • compose task losses and regularizers,
  • define the metrics you actually care about,
  • tune model, optimizer, scheduler, and training strategy parameters together,
  • inspect whether those choices are truly helping.

Why use this module

HPOWrapper gives you a structured way to run Ray Tune over an SNN training pipeline while keeping your model and training logic customizable.

In practice it helps when you want to:

  • compare many design choices (taus, optimizer, regularizers, schedulers),
  • optimize for non-standard targets (for example recall instead of loss),
  • include hardware-aware behavior during training (sign annealing, topology coherence),
  • avoid manual script rewrites every time you change what should be tuned.

Mental model of the pipeline

Each trial follows this flow:

  1. Build namespaced configs (model.*, loss.*, metric.*, optim.*, sched.*, nwave.*, trainer.*, data.*, early_stop.*).
  2. Build model and datasets.
  3. Validate model output contract.
  4. Build losses and metrics from their per-trial configs.
  5. Build optimizer and scheduler.
  6. Optionally build NWAVE schedulers (surrogate, sign_annealing).
  7. Run the trial with train_fn: either your own function or, by default, train_one_trial, a built-in training loop driven by the bound parameters.
  8. Report metrics to Ray Tune and checkpoint best epochs.

This separation is what makes the tool general: each piece is a function you provide at a high level, while the training integration is handled in the backend.
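As a rough illustration of step 1, the sketch below groups a flat, dot-namespaced trial config into one dictionary per namespace. The flat-key layout and the helper name split_namespaces are assumptions made for illustration; the wrapper's internal representation may differ.

```python
from collections import defaultdict

NAMESPACES = ("model", "loss", "metric", "optim", "sched",
              "nwave", "trainer", "data", "early_stop")


def split_namespaces(flat_cfg):
    """Group 'namespace.key' entries into one dict per namespace (hypothetical helper)."""
    grouped = defaultdict(dict)
    for key, value in flat_cfg.items():
        namespace, _, rest = key.partition(".")
        if namespace not in NAMESPACES or not rest:
            raise KeyError(f"Unknown or malformed config key: {key!r}")
        grouped[namespace][rest] = value
    return dict(grouped)


# Example trial config with illustrative values.
trial_cfg = {
    "model.hidden_1": 64,
    "model.tau_l1": 0.01,
    "optim.lr": 1e-3,
    "loss.main.correct_rate": 0.3,
    "early_stop.policy": "patience",
}
per_namespace = split_namespaces(trial_cfg)
print(per_namespace["model"])  # the dict a model builder would receive
print(per_namespace["loss"])   # nested keys stay dotted in this simple sketch
```

Each builder then only sees the part of the search space that concerns it, which is why binding a parameter in one namespace does not affect the others.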


Core contracts (important)

1) Model builder

model = model_builder(model_cfg: dict)

model_cfg comes from all keys bound through bind_model_param(...).

2) Model forward output contract

HPOWrapper validates that your model returns:

(spikes_list, membranes_list)

with:

  • both objects being lists/tuples,
  • same non-zero length,
  • each pair (spk_i, mem_i) same shape,
  • each tensor of rank 3, with expected shape (B, T, N): batch, timesteps, neurons.
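To debug a model before launching a full HPO run, the same rules can be checked by hand. The following is a minimal sketch of such a check, not the wrapper's actual validation code; check_output_contract and fake_tensor are hypothetical names, and the function only relies on each tensor exposing a .shape attribute (as torch.Tensor does).

```python
from types import SimpleNamespace


def check_output_contract(output):
    """Raise ValueError if output is not (spikes_list, membranes_list) as described above."""
    if not isinstance(output, (list, tuple)) or len(output) != 2:
        raise ValueError("model must return (spikes_list, membranes_list)")
    spikes_list, membranes_list = output
    for name, seq in (("spikes_list", spikes_list), ("membranes_list", membranes_list)):
        if not isinstance(seq, (list, tuple)):
            raise ValueError(f"{name} must be a list or tuple")
    if len(spikes_list) == 0 or len(spikes_list) != len(membranes_list):
        raise ValueError("both lists must have the same non-zero length")
    for i, (spk, mem) in enumerate(zip(spikes_list, membranes_list)):
        if tuple(spk.shape) != tuple(mem.shape):
            raise ValueError(f"layer {i}: spike/membrane shapes differ")
        if len(spk.shape) != 3:
            raise ValueError(f"layer {i}: expected rank 3 (B, T, N), got {tuple(spk.shape)}")


def fake_tensor(*shape):
    """Stand-in for a tensor so the example runs without torch."""
    return SimpleNamespace(shape=shape)


good = (
    [fake_tensor(4, 10, 16), fake_tensor(4, 10, 2)],  # spikes per layer
    [fake_tensor(4, 10, 16), fake_tensor(4, 10, 2)],  # membranes per layer
)
check_output_contract(good)  # passes silently
```

With a real model, call check_output_contract(model(x)) on one batch before starting the search.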

3) Loss builders

You pass a dictionary of builders:

loss_builders = {"main": build_main_loss}

Each builder receives trial loss config and returns a callable loss object.

With the default train_one_trial, only ctx.losses["main"] is used.

The returned loss callable is invoked with:

  • spk_out
  • y
  • spikes_list
  • membranes_list

and may return:

  • a scalar torch.Tensor,
  • a dict[str, Tensor] of components,
  • a LossOutput(total=..., components=...).
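The three accepted return types can be thought of as being coerced into one common shape. The sketch below illustrates that idea with plain floats standing in for tensors. The LossOutput class here is a local stand-in with the two fields named above (not the library class), and summing the values of a component dict to obtain the total is an assumption.

```python
from dataclasses import dataclass, field


@dataclass
class LossOutput:
    """Local stand-in with the two fields named in the text (not the library class)."""
    total: float
    components: dict = field(default_factory=dict)


def normalize_loss(result):
    """Coerce any of the three accepted return types into a LossOutput."""
    if isinstance(result, LossOutput):
        return result
    if isinstance(result, dict):
        # Assumption: a dict of components is reduced by summing its values.
        return LossOutput(total=sum(result.values()), components=dict(result))
    # Otherwise treat the result as a scalar total.
    return LossOutput(total=result, components={"total": result})


scalar_out = normalize_loss(0.5)
dict_out = normalize_loss({"mse": 0.25, "firing_rate": 0.5})
print(scalar_out.total, dict_out.total, sorted(dict_out.components))
```

Returning a dict or a LossOutput is what makes per-component logging possible; a bare scalar only yields a total.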

4) Metric builders

You pass:

metric_builders = {"acc": build_acc_metric, "recall": build_recall_metric, ...}

Each metric builder receives metric_cfg and returns a metric function.
Metric functions are called with:

  • spk_out
  • y
  • spikes_list
  • membranes_list
  • model

and should return a numeric scalar (float-like).


Building a Tunable Network

The pattern below uses an explicit layer-by-layer style:

  • H1v2 frontend + H1v2 synapse/layer stack,
  • fixed number of hidden layers (2),
  • hidden widths tunable (32 or 64),
  • output contract always (spikes_list, membranes_list).

import torch
import torch.nn as nn
from nwavesdk.layers import H1v2Frontend, H1v2Synapse, H1v2Layer, prepare_net


class SurrogateSNN(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.in_ch = cfg["in_ch"]
        self.num_classes = cfg["num_classes"]
        self.neurons_per_class = cfg["neurons_per_class"]
        self.out_neurons = self.num_classes * self.neurons_per_class
        self.dt = cfg["dt"]
        self.device_flag = cfg.get("device", "gpu")
        n1 = 16

        # Fixed depth, tunable width
        n2 = cfg["hidden_1"]   # tune.choice([32, 64])
        n3 = self.out_neurons

        self.f0 = H1v2Frontend(self.in_ch, device=self.device_flag)

        self.s1 = H1v2Synapse(self.in_ch, n1, device=self.device_flag, init=nn.init.xavier_uniform_, lif_threshold=1.0)
        self.l1 = H1v2Layer(n1, taus=torch.full((n1,), cfg["tau_l1"]), dt=self.dt, device=self.device_flag, layer_topology="FF")

        self.s2 = H1v2Synapse(n1, n2, device=self.device_flag, init=nn.init.xavier_uniform_, lif_threshold=1.0)
        self.l2 = H1v2Layer(n2, taus=torch.full((n2,), cfg["tau_l2"]), dt=self.dt, device=self.device_flag, layer_topology="FF")

        self.s3 = H1v2Synapse(n2, n3, device=self.device_flag, init=nn.init.xavier_uniform_, lif_threshold=1.0)
        self.l3 = H1v2Layer(n3, taus=torch.full((n3,), cfg["tau_l3"]), dt=self.dt, device=self.device_flag, layer_topology="FF")

        if self.device_flag == "gpu":
            self.to_gpu()

    def to_gpu(self):
        for m in [self.f0, self.s1, self.l1, self.s2, self.l2, self.s3, self.l3]:
            if hasattr(m, "to_gpu"):
                m.to_gpu()

    def forward_cpu(self, x):
        # x: (B, T, F)
        prepare_net(self)

        spk1, spk2, spk3 = [], [], []
        mem1, mem2, mem3 = [], [], []

        for t in range(x.shape[1]):
            q0 = self.f0(x[:, t, :])

            q1 = self.s1(q0)
            s1, m1 = self.l1(q1)

            q2 = self.s2(s1)
            s2, m2 = self.l2(q2)

            q3 = self.s3(s2)
            s3, m3 = self.l3(q3)

            spk1.append(s1); mem1.append(m1)
            spk2.append(s2); mem2.append(m2)
            spk3.append(s3); mem3.append(m3)

        spikes_list = [
            torch.stack(spk1, dim=1),
            torch.stack(spk2, dim=1),
            torch.stack(spk3, dim=1),
        ]
        membranes_list = [
            torch.stack(mem1, dim=1),
            torch.stack(mem2, dim=1),
            torch.stack(mem3, dim=1),
        ]
        return spikes_list, membranes_list

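    def forward(self, x):
        # x: (B, T, F)
        if self.device_flag != "gpu":
            # CPU path: unroll over time instead of moving the input to CUDA
            return self.forward_cpu(x)

        prepare_net(self)

        x = x.to("cuda", non_blocking=True).contiguous().to(torch.float32)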

        q0 = self.f0(x)

        q1 = self.s1(q0)
        s1, m1 = self.l1(q1)

        q2 = self.s2(s1)
        s2, m2 = self.l2(q2)

        q3 = self.s3(s2)
        s3, m3 = self.l3(q3)

        return [s1, s2, s3], [m1, m2, m3]

Tuning plain weight initialization parameters with HPO

If you want to tune parameters from standard torch.nn.init.* functions, keep the same SurrogateSNN style used above (explicit frontend + synapse + layer stack):

  • expose an init name in model_cfg,
  • expose only the init-specific scalar parameters you need (for example gain, a, mode),
  • apply the selected init right after each synapse creation.

import torch
import torch.nn as nn
from ray import tune
from nwavesdk.layers import H1v2Frontend, H1v2Synapse, H1v2Layer
from nwavesdk.optim.hpo.HPOWrapper import HPOWrapper, train_one_trial


def _apply_synapse_init(synapse, cfg):
    w = synapse.weight
    init_name = cfg["init_name"]

    if init_name == "xavier_uniform":
        nn.init.xavier_uniform_(w, gain=cfg["init_gain"])
    elif init_name == "xavier_normal":
        nn.init.xavier_normal_(w, gain=cfg["init_gain"])
    elif init_name == "kaiming_uniform":
        nn.init.kaiming_uniform_(
            w,
            a=cfg["init_a"],                # negative slope; only affects the gain for "leaky_relu"
            mode=cfg["init_mode"],          # "fan_in" or "fan_out"
            nonlinearity=cfg["init_nlin"],  # e.g. "relu", "leaky_relu"
        )
    elif init_name == "kaiming_normal":
        nn.init.kaiming_normal_(
            w,
            a=cfg["init_a"],
            mode=cfg["init_mode"],
            nonlinearity=cfg["init_nlin"],
        )
    else:
        raise ValueError(f"Unsupported init_name: {init_name}")


class SurrogateSNN(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.in_ch = cfg["in_ch"]
        self.num_classes = cfg["num_classes"]
        self.neurons_per_class = cfg["neurons_per_class"]
        self.out_neurons = self.num_classes * self.neurons_per_class
        self.dt = cfg["dt"]
        self.device_flag = cfg.get("device", "gpu")

        n1 = 16
        n2 = cfg["hidden_1"]
        n3 = self.out_neurons

        self.f0 = H1v2Frontend(self.in_ch, device=self.device_flag)
        self.s1 = H1v2Synapse(self.in_ch, n1, device=self.device_flag)
        self.s2 = H1v2Synapse(n1, n2, device=self.device_flag)
        self.s3 = H1v2Synapse(n2, n3, device=self.device_flag)

        for syn in [self.s1, self.s2, self.s3]:
            _apply_synapse_init(syn, cfg)

        self.l1 = H1v2Layer(n1, taus=torch.full((n1,), cfg["tau_l1"]), dt=self.dt, device=self.device_flag, layer_topology="FF")
        self.l2 = H1v2Layer(n2, taus=torch.full((n2,), cfg["tau_l2"]), dt=self.dt, device=self.device_flag, layer_topology="FF")
        self.l3 = H1v2Layer(n3, taus=torch.full((n3,), cfg["tau_l3"]), dt=self.dt, device=self.device_flag, layer_topology="FF")

    # forward / forward_cpu are the same as the SurrogateSNN example above.


def build_model(model_cfg):
    return SurrogateSNN(model_cfg)


# build_datasets, build_main_loss and build_acc_metric are defined in the sections below.
hpo = HPOWrapper(
    model_builder=build_model,
    dataset_builder=build_datasets,
    train_fn=train_one_trial,
    loss_builders={"main": build_main_loss},
    metric_builders={"acc": build_acc_metric},
    metric="val_acc",
    mode="max",
    experiment_name="plain_init_hpo",
)

hpo.bind_model_param("dt", 2e-3)
hpo.bind_model_param("in_ch", 13)
hpo.bind_model_param("num_classes", 2)
hpo.bind_model_param("neurons_per_class", 32)
hpo.bind_model_param("hidden_1", tune.choice([32, 64]))
hpo.bind_model_param("tau_l1", tune.loguniform(5e-3, 30e-3))
hpo.bind_model_param("tau_l2", tune.loguniform(10e-3, 50e-3))
hpo.bind_model_param("tau_l3", tune.loguniform(20e-3, 90e-3))
hpo.bind_model_param("device", "gpu")

# plain init search space
hpo.bind_model_param(
    "init_name",
    tune.choice(["xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal"]),
)
hpo.bind_model_param("init_gain", tune.uniform(0.5, 2.0))            # used by xavier*
hpo.bind_model_param("init_a", tune.uniform(0.0, 0.3))               # used by kaiming*
hpo.bind_model_param("init_mode", tune.choice(["fan_in", "fan_out"]))  # used by kaiming*
hpo.bind_model_param("init_nlin", tune.choice(["relu", "leaky_relu"])) # used by kaiming*

Tuning fluct_init parameters with HPO

When using fluct_init, keep the same SurrogateSNN network definition style and apply fluct_init inside build_model.

import os
import torch
import torch.nn as nn
from ray import tune
from nwavesdk.layers import H1v2Frontend, H1v2Synapse, H1v2Layer
from nwavesdk.init.fluct_init import fluct_init
from nwavesdk.optim.hpo.HPOWrapper import HPOWrapper, train_one_trial


def load_init_loader(dataset_path, train_file):
    return torch.load(os.path.join(dataset_path, train_file), weights_only=False)


class SurrogateSNN(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.in_ch = cfg["in_ch"]
        self.num_classes = cfg["num_classes"]
        self.neurons_per_class = cfg["neurons_per_class"]
        self.out_neurons = self.num_classes * self.neurons_per_class
        self.dt = cfg["dt"]
        self.device_flag = cfg.get("device", "gpu")

        n1 = 16
        n2 = cfg["hidden_1"]
        n3 = self.out_neurons

        self.f0 = H1v2Frontend(self.in_ch, device=self.device_flag)
        self.s1 = H1v2Synapse(self.in_ch, n1, device=self.device_flag, init=nn.init.xavier_uniform_)
        self.s2 = H1v2Synapse(n1, n2, device=self.device_flag, init=nn.init.xavier_uniform_)
        self.s3 = H1v2Synapse(n2, n3, device=self.device_flag, init=nn.init.xavier_uniform_)
        self.l1 = H1v2Layer(n1, taus=torch.full((n1,), cfg["tau_l1"]), dt=self.dt, device=self.device_flag, layer_topology="FF")
        self.l2 = H1v2Layer(n2, taus=torch.full((n2,), cfg["tau_l2"]), dt=self.dt, device=self.device_flag, layer_topology="FF")
        self.l3 = H1v2Layer(n3, taus=torch.full((n3,), cfg["tau_l3"]), dt=self.dt, device=self.device_flag, layer_topology="FF")

    # forward / forward_cpu are the same as the SurrogateSNN example above.


def build_model(model_cfg):
    model = SurrogateSNN(model_cfg)

    if model_cfg.get("use_fluct_init", False):
        init_loader = load_init_loader(
            model_cfg["init_dataset_path"],
            model_cfg.get("fluct_init_train_file", "train.pt"),
        )
        fluct_init(
            model,
            init_loader,
            xi_target=model_cfg["fluct_xi_target"],
            alpha=model_cfg["fluct_alpha"],
            n_batches=model_cfg["fluct_n_batches"],
            verbose=model_cfg.get("init_verbose", False),
        )

    return model


# build_datasets, build_main_loss and build_acc_metric are defined in the sections below.
hpo = HPOWrapper(
    model_builder=build_model,
    dataset_builder=build_datasets,
    train_fn=train_one_trial,
    loss_builders={"main": build_main_loss},
    metric_builders={"acc": build_acc_metric},
    metric="val_acc",
    mode="max",
    experiment_name="fluct_init_hpo",
)

hpo.bind_model_param("dt", 2e-3)
hpo.bind_model_param("in_ch", 13)
hpo.bind_model_param("num_classes", 2)
hpo.bind_model_param("neurons_per_class", 32)
hpo.bind_model_param("hidden_1", tune.choice([32, 64]))
hpo.bind_model_param("tau_l1", tune.loguniform(5e-3, 30e-3))
hpo.bind_model_param("tau_l2", tune.loguniform(10e-3, 50e-3))
hpo.bind_model_param("tau_l3", tune.loguniform(20e-3, 90e-3))
hpo.bind_model_param("device", "gpu")
hpo.bind_model_param("init_dataset_path", "/path/to/data")
hpo.bind_model_param("init_verbose", False)

# fluct_init search space
hpo.bind_model_param("use_fluct_init", True)
hpo.bind_model_param("fluct_init_train_file", "train.pt")
hpo.bind_model_param("fluct_xi_target", tune.uniform(1.0, 3.0))
hpo.bind_model_param("fluct_alpha", tune.uniform(0.7, 1.0))
hpo.bind_model_param("fluct_n_batches", tune.choice([1, 2, 4]))

Composing losses

CompositeLoss lets you combine heterogeneous terms without hard-coding a single monolithic loss.

from nwavesdk.loss import (
    balanced_population_mse,
    firing_rate_target_mse_loss,
    weight_magnitude_loss,
    topology_loss,
)
from nwavesdk.optim.hpo.loss_factory import CompositeLoss


def balanced_pop_term(spk_out, y, config, **kwargs):
    return {
        "mse": balanced_population_mse(
            spk_out,
            y,
            num_classes=config["num_classes"],
            neurons_per_class=config["neurons_per_class"],
            correct_rate=config["correct_rate"],
            incorrect_rate=config["incorrect_rate"],
        )
    }


def firing_rate_term(spikes_list, config, **kwargs):
    return {
        "firing_rate": firing_rate_target_mse_loss(
            spikes_list=spikes_list,
            offsets=config["offsets_fr"],
            multipliers=config["multipliers_fr"],
        )
    }

def hardware_reg_term(model, config, **kwargs):
    return {
        "weight_mag": weight_magnitude_loss(model, limit=config["limit"]),
        "topology": topology_loss(model, lam=config["lambda_topo"]),
    }

def build_main_loss(loss_cfg):
    return (
        CompositeLoss()
        .add_loss_term(
            "population",
            balanced_pop_term,
            config={
                "num_classes": loss_cfg["num_classes"],
                "neurons_per_class": loss_cfg["neurons_per_class"],
                "correct_rate": loss_cfg["correct_rate"],
                "incorrect_rate": loss_cfg["incorrect_rate"],
            },
            weight=loss_cfg.get("w_population", 1.0),
        )
        .add_loss_term(
            "activity",
            firing_rate_term,
            config={
                "offsets_fr": loss_cfg["offsets_fr"],
                "multipliers_fr": loss_cfg["multipliers_fr"],
            },
            weight=loss_cfg.get("w_activity", 1.0),
        )
        .add_loss_term(
            "hardware",
            hardware_reg_term,
            config={
                "limit": loss_cfg.get("weight_limit", 0.9),
                "lambda_topo": loss_cfg.get("lambda_topology", 1e-3),
            },
            weight=loss_cfg.get("w_hardware", 1.0),
        )
    )

This gives:

  • clear component logging per term,
  • weighted composition,
  • easy tuning of each term strength/config.

The default backend supports model-dependent regularizers directly:
train_one_trial passes model into the loss call, and CompositeLoss filters kwargs per loss-term signature.
This allows model-independent and model-dependent terms to be mixed without a custom train_fn.
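The per-signature filtering described above can be pictured with the standard inspect module. This is a minimal sketch of the idea, not the CompositeLoss implementation; call_with_supported_kwargs and the two toy terms are hypothetical, and plain Python objects stand in for tensors and the model.

```python
import inspect


def call_with_supported_kwargs(fn, **available):
    """Call fn with only the keyword arguments its signature can accept."""
    params = inspect.signature(fn).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return fn(**available)  # fn declares **kwargs, so it takes everything
    accepted = {k: v for k, v in available.items() if k in params}
    return fn(**accepted)


def activity_term(spikes_list, config):
    # Model-independent toy term: only needs the spikes.
    return {"firing_rate": config["scale"] * len(spikes_list)}


def weight_term(model, config):
    # Model-dependent toy term: only needs the model.
    return {"weight_mag": config["scale"] * model["n_weights"]}


call_kwargs = {
    "spk_out": None, "y": None,
    "spikes_list": [0, 0, 0], "membranes_list": [0, 0, 0],
    "model": {"n_weights": 10},
    "config": {"scale": 0.5},
}
activity = call_with_supported_kwargs(activity_term, **call_kwargs)
weights = call_with_supported_kwargs(weight_term, **call_kwargs)
print(activity, weights)
```

Because each term declares only what it needs, adding a new argument to the training call does not break existing terms.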

Warning

Calling topology_loss on the whole network may raise an error when the network includes a frontend. If you use topology_loss, apply it only to the dense synapses.


Defining metrics

Any callable metric is valid if it returns a scalar.

from nwavesdk.metrics import accuracy_population

def build_acc_metric(metric_cfg):
    def metric_fn(spk_out, y, **kwargs):
        return accuracy_population(
            spk_out,
            y,
            num_classes=metric_cfg.get("num_classes", 2),
        )
    return metric_fn

def build_recall_metric(metric_cfg):
    positive_class = metric_cfg.get("positive_class", 1)
    num_classes = metric_cfg.get("num_classes", 2)

    def metric_fn(spk_out, y, **kwargs):
        # replace with your own population-vote recall implementation
        preds = spk_out.sum(dim=1).view(spk_out.shape[0], num_classes, -1).sum(dim=2).argmax(dim=1)
        tp = ((preds == positive_class) & (y == positive_class)).float().sum().item()
        fn = ((preds != positive_class) & (y == positive_class)).float().sum().item()
        return tp / max(1.0, tp + fn)
    return metric_fn

Then choose the target Ray should optimize:

metric="val_recall"
mode="max"

All namespaces you can tune

The wrapper supports these binding APIs:

| API | Namespace in trial config | What it controls |
|---|---|---|
| bind_model_param(name, space) | model.<name> | architecture / neuron dynamics |
| bind_loss_param(loss_name, param_name, space) | loss.<loss_name>.<param_name> | loss term internals |
| bind_metric_param(metric_name, param_name, space) | metric.<metric_name>.<param_name> | metric behavior |
| bind_optimizer_param(name, space) | optim.<name> | optimizer type and params |
| bind_scheduler_param(name, space) | sched.<name> | LR scheduler config |
| bind_nwave_scheduler_param(name, space) | nwave.<name> | surrogate/sign annealing configs |
| bind_trainer_param(name, space) | trainer.<name> | epochs, device, etc. |
| bind_data_param(name, space) | data.<name> | dataset loader options |
| bind_early_stop_param(name, space) | early_stop.<name> | early stop policy |

Also available:

  • set_fixed_params(...) for non-tuned fixed values,
  • set_resources(cpu=..., gpu=...),
  • set_num_samples(n),
  • set_max_concurrent_trials(n),
  • set_memory_guard(...) to pre-check RAM budget and avoid OOM.

Scheduler layers (three different levels)

It is useful to keep these separate:

1) Optimizer LR scheduler (sched.*)

Default supported names:

  • none
  • plateau (ReduceLROnPlateau)
  • step (StepLR)
  • cosine (CosineAnnealingLR)
  • multistep (MultiStepLR)

Example:

hpo.bind_scheduler_param(
    "cfg",
    tune.choice([
        {"name": "none"},
        {"name": "plateau", "mode": "min", "factor": 0.3, "patience": 8, "threshold": 1e-3},
        {"name": "step", "step_size": 10, "gamma": 0.5},
        {"name": "cosine", "T_max": 60, "eta_min": 1e-6},
        {"name": "multistep", "milestones": [20, 40], "gamma": 0.3},
    ])
)

2) Ray Tune trial scheduler

Controls trial promotion/stopping across HPO runs.

Use:

  • set_tune_scheduler_config({...}), or
  • set_tune_scheduler(ray_scheduler_instance).

Supported config names:

  • asha (default),
  • hyperband,
  • median,
  • fifo / none.

3) NWAVE schedulers (nwave.*)

Surrogate scheduler (nwave.surrogate.*)

Controls surrogate parameter evolution during training.

Supported policies:

  • constant
  • linear
  • exp
  • cosine
  • step

Typical parameters:

  • policy
  • start_value
  • end_value
  • total_steps
  • step_size (for step)
  • gamma (for step/exp style progression)
  • min_value, max_value.
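To make the policy names concrete, here is one plausible way such a schedule could be computed from the parameters listed above. The formulas are assumptions for illustration only (in particular, gamma is treated as a multiplicative factor and the result is clamped to [min_value, max_value]); the function surrogate_value is hypothetical and the SDK's schedules may differ.

```python
import math


def surrogate_value(step, policy, start_value, end_value, total_steps,
                    step_size=1, gamma=1.0, min_value=None, max_value=None):
    """Value at a given step (hypothetical formulas, clamped to [min_value, max_value])."""
    frac = min(max(step / max(total_steps, 1), 0.0), 1.0)
    if policy == "constant":
        value = start_value
    elif policy == "linear":
        value = start_value + (end_value - start_value) * frac
    elif policy == "cosine":
        value = end_value + (start_value - end_value) * 0.5 * (1.0 + math.cos(math.pi * frac))
    elif policy == "exp":
        value = start_value * gamma ** step
    elif policy == "step":
        value = start_value * gamma ** (step // step_size)
    else:
        raise ValueError(f"Unsupported policy: {policy}")
    if min_value is not None:
        value = max(value, min_value)
    if max_value is not None:
        value = min(value, max_value)
    return value


# Compare the policies at a few steps of a 60-step schedule going from 10 to 20.
for policy in ("constant", "linear", "cosine"):
    print(policy, [round(surrogate_value(s, policy, 10.0, 20.0, 60), 2) for s in (0, 30, 60)])
```

Plotting a candidate schedule like this before binding it is a cheap way to confirm that start_value, end_value and total_steps match the number of training epochs.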

Sign annealing (nwave.sign_annealing.*)

Gradually tightens row-block sign parametrization.

Typical parameters:

  • alpha_start
  • alpha_end
  • total_epochs.

When enabled, additional sign-topology coherence metrics are reported.


Early stopping policies

Configure through bind_early_stop_param(...).

Supported policies:

  • none
  • threshold
  • patience (aliases: plateau, no_improve, no-improve)

Common fields:

  • policy
  • metric
  • mode (min or max)
  • min_epoch

Threshold policy adds:

  • value

Patience policy adds:

  • patience
  • min_delta
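The fields above map naturally onto a small per-epoch check. The sketch below shows one way the threshold and patience policies could behave; EarlyStopper is a hypothetical class written for illustration, it ignores the policy aliases, and the exact comparison rules of the built-in implementation are assumptions.

```python
class EarlyStopper:
    """Hypothetical per-epoch early-stop check for the 'threshold' and 'patience' policies."""

    def __init__(self, policy="none", mode="max", min_epoch=0,
                 value=None, patience=0, min_delta=0.0):
        self.policy, self.mode, self.min_epoch = policy, mode, min_epoch
        self.value, self.patience, self.min_delta = value, patience, min_delta
        self.best = None
        self.bad_epochs = 0

    def _improved(self, current, reference, delta=0.0):
        if self.mode == "max":
            return current > reference + delta
        return current < reference - delta

    def should_stop(self, epoch, metric_value):
        if self.policy == "none":
            return False
        if self.policy == "threshold":
            reached = (metric_value >= self.value if self.mode == "max"
                       else metric_value <= self.value)
            return epoch >= self.min_epoch and reached
        # patience policy: count consecutive epochs without a min_delta improvement
        if self.best is None or self._improved(metric_value, self.best, self.min_delta):
            self.best = metric_value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return epoch >= self.min_epoch and self.bad_epochs >= self.patience


stopper = EarlyStopper(policy="patience", mode="max", min_epoch=3, patience=2, min_delta=1e-3)
history = [0.60, 0.70, 0.70, 0.7005, 0.69]  # illustrative validation metric per epoch
stopped_at = next((epoch for epoch, v in enumerate(history, start=1)
                   if stopper.should_stop(epoch, v)), None)
print("stopped at epoch", stopped_at)
```

Note how min_delta makes the tiny gain at epoch 4 count as a non-improving epoch, and how min_epoch prevents stopping during the first epochs regardless of the metric.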

End-to-end template

This template can be used as a starting point:

import os
import torch
from ray import tune

from nwavesdk.optim.hpo.HPOWrapper import HPOWrapper, train_one_trial


def build_model(model_cfg):
    return SurrogateSNN(model_cfg)


def build_datasets(dataset_path, data_cfg):
    train_loader = torch.load(os.path.join(dataset_path, data_cfg["train_file"]), weights_only=False)
    val_loader = torch.load(os.path.join(dataset_path, data_cfg["val_file"]), weights_only=False)
    return train_loader, val_loader


hpo = HPOWrapper(
    model_builder=build_model,
    dataset_builder=build_datasets,
    train_fn=train_one_trial,
    loss_builders={"main": build_main_loss},
    metric_builders={
        "acc": build_acc_metric,
        "recall": build_recall_metric,
    },
    metric="val_recall",
    mode="max",
    experiment_name="h1v2_snn_hpo",
)

# model space
hpo.bind_model_param("in_ch", 13)
hpo.bind_model_param("num_classes", 2)
hpo.bind_model_param("neurons_per_class", 32)
hpo.bind_model_param("hidden_1", tune.choice([32, 64]))
hpo.bind_model_param("dt", tune.choice([1e-3, 2e-3, 4e-3]))
hpo.bind_model_param("tau_l1", tune.loguniform(5e-3, 30e-3))
hpo.bind_model_param("tau_l2", tune.loguniform(10e-3, 50e-3))
hpo.bind_model_param("tau_l3", tune.loguniform(20e-3, 90e-3))
hpo.bind_model_param("device", "gpu")

# optimizer space
hpo.bind_optimizer_param("name", tune.choice(["adam", "adamax"]))
hpo.bind_optimizer_param("lr", tune.loguniform(1e-4, 1e-2))

# loss space
hpo.bind_loss_param("main", "num_classes", 2)
hpo.bind_loss_param("main", "neurons_per_class", 32)
hpo.bind_loss_param("main", "correct_rate", tune.uniform(0.25, 0.45))
hpo.bind_loss_param("main", "incorrect_rate", tune.uniform(0.01, 0.10))
hpo.bind_loss_param("main", "offsets_fr", tune.choice([[0.05, 0.08, 0.10], [0.1, 0.2, 0.3]]))       # 3 layers
hpo.bind_loss_param("main", "multipliers_fr", tune.choice([[1, 2, 3], [10, 5, 1]]))                  # 3 layers
hpo.bind_loss_param("main", "w_population", tune.choice([1.0, 2.0]))
hpo.bind_loss_param("main", "w_activity", tune.choice([0.2, 0.5, 1.0]))
hpo.bind_loss_param("main", "w_hardware", tune.choice([0.0, 1e-3, 1e-2]))
hpo.bind_loss_param("main", "weight_limit", tune.choice([0.8, 0.9]))
hpo.bind_loss_param("main", "lambda_topology", tune.choice([1e-4, 1e-3, 1e-2]))

# metric space
hpo.bind_metric_param("acc", "num_classes", 2)
hpo.bind_metric_param("recall", "num_classes", 2)
hpo.bind_metric_param("recall", "positive_class", tune.choice([0, 1]))

# LR scheduler space
hpo.bind_scheduler_param(
    "cfg",
    tune.choice([
        {"name": "none"},
        {"name": "plateau", "mode": "min", "factor": 0.3, "patience": 8, "threshold": 1e-3, "min_lr": 1e-6},
        {"name": "step", "step_size": 10, "gamma": 0.5},
        {"name": "cosine", "T_max": 60, "eta_min": 1e-6},
        {"name": "multistep", "milestones": [20, 40], "gamma": 0.3},
    ])
)

# early stop
hpo.bind_early_stop_param("policy", tune.choice(["none", "threshold", "patience"]))
hpo.bind_early_stop_param("metric", tune.choice(["val_recall", "val_acc", "val_loss"]))
hpo.bind_early_stop_param("mode", tune.choice(["max", "min"]))
hpo.bind_early_stop_param("min_epoch", tune.choice([8, 12]))
hpo.bind_early_stop_param("value", tune.choice([0.90, 0.93]))        # used by threshold
hpo.bind_early_stop_param("patience", tune.choice([4, 8, 12]))       # used by patience
hpo.bind_early_stop_param("min_delta", tune.choice([0.0, 1e-3]))     # used by patience

# NWAVE schedulers
hpo.bind_nwave_scheduler_param("surrogate.policy", tune.choice(["constant", "linear", "cosine", "exp", "step"]))
hpo.bind_nwave_scheduler_param("surrogate.start_value", tune.choice([10.0, 15.0]))
hpo.bind_nwave_scheduler_param("surrogate.end_value", tune.choice([20.0, 30.0]))
hpo.bind_nwave_scheduler_param("surrogate.total_steps", 60)
hpo.bind_nwave_scheduler_param("surrogate.step_size", tune.choice([5, 10]))
hpo.bind_nwave_scheduler_param("surrogate.gamma", tune.choice([0.5, 0.8, 2.0]))
hpo.bind_nwave_scheduler_param("sign_annealing.alpha_start", tune.choice([0.5, 1.0]))
hpo.bind_nwave_scheduler_param("sign_annealing.alpha_end", tune.choice([10.0, 20.0]))
hpo.bind_nwave_scheduler_param("sign_annealing.total_epochs", 60)

# trainer + data
hpo.bind_trainer_param("epochs", 60)
hpo.bind_trainer_param("device", "cuda")
hpo.bind_data_param("train_file", "train.pt")
hpo.bind_data_param("val_file", "val.pt")

hpo.set_resources(cpu=2, gpu=0.25)
hpo.set_max_concurrent_trials(3)
hpo.set_num_samples(20)
hpo.set_memory_guard(
    enabled=True,
    strict=True,                 # raise if user-forced concurrency is above safe recommendation
    safety_fraction=0.80,        # keep part of RAM free to avoid runtime spikes
    dataset_memory_multiplier=1.25,
)
hpo.set_tune_scheduler_config({"name": "asha", "grace_period": 5, "reduction_factor": 2})

results = hpo.run("/path/to/data")
print("Best config:", hpo.best_config())

Memory guard and concurrency suggestion

Before launching Ray Tune, HPOWrapper.run(...) performs a host-RAM pre-check:

  • Reads available RAM from the machine.
  • Estimates dataset footprint from dataset_path and known data.* file/path hints.
  • Computes a recommended max_concurrent_trials based on:
      • estimated RAM per trial,
      • requested trial resources (set_resources),
      • num_samples.

Default behavior:

  • If you did not set max_concurrent_trials, the wrapper auto-caps concurrency to the recommended safe value.
  • If you did set max_concurrent_trials and it is above the recommended safe value, the wrapper raises early (when strict=True) with an actionable message.
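The arithmetic behind such a recommendation can be sketched as follows. This is an illustrative estimate, not the wrapper's actual formula: recommend_concurrency and check_requested are hypothetical helpers, RAM per trial is assumed to be the dataset footprint times dataset_memory_multiplier, and the non-strict behaviour (keeping the requested value) is an assumption.

```python
import math


def recommend_concurrency(available_ram_bytes, dataset_bytes, num_samples,
                          safety_fraction=0.80, dataset_memory_multiplier=1.25,
                          resource_limit=None):
    """Hypothetical estimate of how many trials fit in host RAM at once."""
    ram_budget = available_ram_bytes * safety_fraction
    ram_per_trial = max(dataset_bytes * dataset_memory_multiplier, 1.0)
    limits = [math.floor(ram_budget / ram_per_trial), num_samples]
    if resource_limit is not None:  # e.g. 1 GPU / 0.25 GPU per trial = 4 trials
        limits.append(resource_limit)
    return max(1, min(limits))


def check_requested(requested, recommended, strict=True):
    """Auto-cap when unset; raise when strict and the request is above the safe value."""
    if requested is None:
        return recommended
    if requested > recommended and strict:
        raise RuntimeError(
            f"max_concurrent_trials={requested} exceeds the safe value {recommended}; "
            "lower it or reduce per-trial memory."
        )
    return requested  # assumption: non-strict mode keeps the user's value


GIB = 1024 ** 3
by_ram_only = recommend_concurrency(16 * GIB, 2 * GIB, num_samples=20)
with_gpu_cap = recommend_concurrency(16 * GIB, 2 * GIB, num_samples=20, resource_limit=4)
print(by_ram_only, with_gpu_cap)
```

With 16 GiB available and a 2 GiB dataset, the RAM budget alone allows 5 concurrent trials in this sketch, and a per-trial GPU share of 0.25 on one GPU lowers that to 4.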

How to inspect if choices are effective

After the run, look beyond the selection metric alone.

1) Best trial and config

best = hpo.best_result()
print(best.metrics["val_recall"])
print(best.config)

2) Full trial table

df = results.get_dataframe()
print(df[["trial_id", "epoch", "val_loss", "val_recall", "lr"]].sort_values("val_recall", ascending=False).head())

3) Signals worth checking

To inspect trials during or after training, run the following command and open localhost:6006 in your browser:

tensorboard --logdir ray_results

The default train_one_trial reports the following TensorBoard scalars.

| TensorBoard key (pattern) | What it is | User-defined from outside? | Reported when | Not reported when |
|---|---|---|---|---|
| epoch | Current epoch index (1-based). | No. | Always, every epoch. | Never (if training is running). |
| train_loss | Alias of train/total (average train total loss over batches). | Indirectly: depends on your loss builder. | Always, every epoch. | Never (if training is running). |
| val_loss | Alias of val/total (average validation total loss over batches). | Indirectly: depends on your loss builder. | Always, every epoch. | Never (if training is running). |
| lr | Current optimizer learning rate (optimizer.param_groups[0]["lr"]). | Indirectly: depends on optimizer/scheduler config. | Always, every epoch. | Never (if training is running). |
| best_selected_metric | Best-so-far value of HPOWrapper(metric=..., mode=...) for the trial. | Yes: chosen via wrapper metric/mode. | Always, every epoch. | Never (if training is running). |
| train/<loss_component> | Per-component train averages from loss_out.components (for example train/main/population/mse, train/main/total, train/total). | Yes: entirely depends on your loss builder output. | If that loss component exists in your configured loss function. | If your loss builder does not produce that component key. |
| val/<loss_component> | Per-component validation averages from loss_out.components (for example val/main/population/mse, val/main/total, val/total). | Yes: entirely depends on your loss builder output. | If that loss component exists in your configured loss function. | If your loss builder does not produce that component key. |
| train/<metric_name> and train_<metric_name> | Train metric averages (two aliases for easier querying). | Yes: defined by metric_builders and bound metric config. | For each metric function registered in metric_builders. | If no metric builder is registered for that metric name. |
| val/<metric_name> and val_<metric_name> | Validation metric averages (two aliases for easier querying). | Yes: defined by metric_builders and bound metric config. | For each metric function registered in metric_builders. | If no metric builder is registered for that metric name. |
| train/layer_<i>/firing_rate | Mean spike activity for layer i in training. | No (computed from model outputs). | Always for each returned layer in spikes_list. | If the layer does not exist in model outputs. |
| train/layer_<i>/dead_neurons | Number of dead neurons in layer i during training epoch. | No (computed from model outputs). | Always for each returned layer. | If the layer does not exist in model outputs. |
| train/layer_<i>/dead_neurons_fraction | Fraction of dead neurons in layer i during training epoch. | No (computed from model outputs). | Always for each returned layer. | If the layer does not exist in model outputs. |
| train/layer_<i>/membrane_mean | Mean membrane value for layer i during training. | No (computed from model outputs). | Always for each returned layer. | If the layer does not exist in model outputs. |
| train/layer_<i>/membrane_var | Membrane variance for layer i during training. | No (computed from model outputs). | Always for each returned layer. | If the layer does not exist in model outputs. |
| val/layer_<i>/firing_rate | Mean spike activity for layer i in validation. | No (computed from model outputs). | Always for each returned layer in spikes_list. | If the layer does not exist in model outputs. |
| val/layer_<i>/dead_neurons | Number of dead neurons in layer i during validation epoch. | No (computed from model outputs). | Always for each returned layer. | If the layer does not exist in model outputs. |
| val/layer_<i>/dead_neurons_fraction | Fraction of dead neurons in layer i during validation epoch. | No (computed from model outputs). | Always for each returned layer. | If the layer does not exist in model outputs. |
| val/layer_<i>/membrane_mean | Mean membrane value for layer i during validation. | No (computed from model outputs). | Always for each returned layer. | If the layer does not exist in model outputs. |
| val/layer_<i>/membrane_var | Membrane variance for layer i during validation. | No (computed from model outputs). | Always for each returned layer. | If the layer does not exist in model outputs. |
| surrogate_value | Current value emitted by NWAVE surrogate scheduler step. | Yes: controlled by nwave.surrogate.* binding. | Only if surrogate scheduler is configured (non-empty nwave.surrogate.*). | If surrogate scheduler is not configured. |
| sign_alpha | Current sign-annealing alpha value. | Yes: controlled by nwave.sign_annealing.* binding. | Only if sign annealing is configured. | If sign annealing is not configured. |
| sign_topology/<layer>/<weight_id>_coherence_pct | Magnitude-weighted sign-topology coherence (%) for parametrized weight/recurrent_weights. | Indirectly: depends on sign-annealing setup and model parametrization. | If sign annealing is configured and the layer exposes parametrized 2D weights. | If sign annealing is disabled or no eligible parametrized weights exist. |
| early_stop_metric_value | Current value of the configured early-stop metric. | Yes: controlled by early_stop.metric. | Only when early_stop.policy is enabled (threshold/patience). | If early stop policy is none/disabled. |
| early_stop_target_value | Threshold target used to trigger stop. | Yes: controlled by early_stop.value. | Only for early_stop.policy="threshold". | For none or patience policy. |
| early_stop_bad_epochs | Number of consecutive non-improving epochs. | Yes: controlled by early_stop.patience and early_stop.min_delta. | Only for early_stop.policy="patience" (and aliases). | For none or threshold policy. |
| early_stop_best_metric | Best metric seen so far under patience policy logic. | Yes: controlled by early-stop config + selected metric. | Only for early_stop.policy="patience" (and aliases). | For none or threshold policy. |
| early_stop_triggered | 1 if early stop condition is met at this epoch, else 0. | Yes: depends on early-stop configuration. | Only when early stop policy is enabled. | If early stop policy is none/disabled. |

Note

Bound parameters such as model.*, optim.*, sched.*, nwave.*, trainer.*, data.*, and early_stop.* are trial configuration values, not epoch-wise scalar metrics. They can be inspected in Ray trial config/hparams views.
If a namespace is not bound, it is not used to build that component and therefore no related conditional runtime metrics (for example surrogate_value, sign_alpha, early_stop_*) will appear.

These metrics make it possible to answer:

  • did validation improve for the right reason?
  • did the network become too silent/saturated?
  • did regularization improve topology coherence or only hurt accuracy?
  • which scheduler gave a better convergence profile?

4) Statistical analysis with tbparse

The recommended package for analysing TensorBoard results is tbparse. It loads the logged scalars into a DataFrame equivalent to the TensorBoard view, so you can inspect results and see how metrics trend with hyperparameters. With those results you can either narrow the search space for the next HPO run or explain why some models reached certain metric values.
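A possible starting point is sketched below. It assumes tbparse's SummaryReader with the dir_name extra column and a log directory named ray_results; the helper names are hypothetical. Ray Tune commonly prefixes scalar tags (for example ray/tune/val_recall), so the helper matches on the tag suffix; inspect the tag column of your own logs to confirm. The demonstration uses hand-written rows in the same long format so that it runs without log files.

```python
from collections import defaultdict


def load_scalar_rows(logdir):
    """Read all TensorBoard scalars under logdir into a list of row dicts."""
    from tbparse import SummaryReader  # third-party: pip install tbparse

    reader = SummaryReader(logdir, extra_columns={"dir_name"})
    # Long-format DataFrame with columns: step, tag, value, dir_name
    return reader.scalars.to_dict("records")


def best_per_trial(rows, tag, mode="max"):
    """Best value of one scalar tag for each trial directory."""
    pick = max if mode == "max" else min
    values = defaultdict(list)
    for row in rows:
        if row["tag"] == tag or row["tag"].endswith("/" + tag):
            values[row["dir_name"]].append(row["value"])
    return {trial: pick(vals) for trial, vals in values.items()}


# Illustrative rows; with real logs use: rows = load_scalar_rows("ray_results")
rows = [
    {"dir_name": "trial_a", "tag": "ray/tune/val_recall", "step": 1, "value": 0.71},
    {"dir_name": "trial_a", "tag": "ray/tune/val_recall", "step": 2, "value": 0.78},
    {"dir_name": "trial_b", "tag": "ray/tune/val_recall", "step": 1, "value": 0.74},
    {"dir_name": "trial_b", "tag": "ray/tune/val_recall", "step": 2, "value": 0.70},
    {"dir_name": "trial_b", "tag": "ray/tune/val_loss", "step": 2, "value": 0.40},
]
best = best_per_trial(rows, "val_recall")
print(sorted(best.items(), key=lambda item: item[1], reverse=True))
```

From here, the per-trial summary can be joined with each trial's hyperparameters to look for trends, for example recall against tau_l1 or against the chosen LR scheduler.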

5) Restore best checkpoint

import os
import torch

best = hpo.best_result()
ckpt_dir = best.checkpoint.to_directory()
payload = torch.load(os.path.join(ckpt_dir, "best_model.pt"), map_location="cpu")
print(payload["epoch"], payload["best_selected_metric"])

Troubleshooting

Error: model contract assertion fails

Check that your model returns exactly:

return spikes_list, membranes_list

with same length, matching shapes, and rank 3 tensors.

Error: selected metric not reported

metric=... in HPOWrapper must match one reported key exactly (for example val_recall, val_acc, val_loss).
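A quick pre-flight check can catch this before any trial starts. The sketch below derives the expected keys from the naming described in this guide (train_loss, val_loss, plus train_<metric_name> and val_<metric_name> for each registered metric); reported_keys and check_selected_metric are hypothetical helpers, and a custom train_fn may report a different set.

```python
def reported_keys(metric_builders):
    """Keys a trial is expected to report, following the naming described in this guide."""
    keys = {"train_loss", "val_loss"}
    for name in metric_builders:
        keys.update({f"train_{name}", f"val_{name}"})
    return keys


def check_selected_metric(metric, metric_builders):
    """Raise early if the selection metric would never be reported."""
    valid = reported_keys(metric_builders)
    if metric not in valid:
        raise ValueError(f"metric={metric!r} is not reported; choose one of {sorted(valid)}")
    return metric


# Builders are only used for their names here, so placeholders are enough.
builders = {"acc": None, "recall": None}
print(check_selected_metric("val_recall", builders))
```

For example, selecting val_f1 with only acc and recall registered fails immediately with the list of valid keys instead of failing inside Ray Tune.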

Error: unsupported optimizer or scheduler

Default builders support:

  • optimizer: adam, adamax
  • LR scheduler: none, plateau, step, cosine, multistep

If you need more options, pass custom optimizer_builder / scheduler_builder.

Error: RC layer complains about missing prepared matrices

Call prepare_net(model) once per batch/sequence, before unrolling over timesteps.

Dataloader batch format error

The default train_one_trial accepts both:

  • (x, y)
  • (x, y, meta)

If you still get a batch-format error, make sure your loader yields a tuple/list with at least two elements.
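If you write a custom train_fn and want to accept the same two batch layouts, a small helper along these lines is enough. It is a minimal sketch, not the code used by train_one_trial, and unpack_batch is a hypothetical name.

```python
def unpack_batch(batch):
    """Return (x, y, meta) from a batch shaped (x, y) or (x, y, meta)."""
    if not isinstance(batch, (tuple, list)) or len(batch) < 2:
        raise ValueError(
            f"Expected a tuple/list with at least two elements, got {type(batch).__name__}"
        )
    x, y = batch[0], batch[1]
    meta = batch[2] if len(batch) > 2 else None  # meta is optional
    return x, y, meta


# Strings stand in for tensors to keep the example self-contained.
print(unpack_batch(("inputs", "labels")))
print(unpack_batch(["inputs", "labels", {"subject": 3}]))
```

A loader that yields a dict or a single tensor per batch would raise here, which mirrors the batch-format error described above.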

Bound a metric/loss param but nothing changes

Bound spaces are only consumed by registered builders.
If you bind metric.f1.* but no "f1" metric builder exists, those params are unused.


Final take

HPOWrapper is intentionally modular: it does not force a specific H1 architecture, a fixed loss, or a fixed metric set.

That is the main benefit for custom SNN research and engineering:

  • you keep control of model internals and task objectives,
  • you still get a clean, reproducible, scalable HPO workflow,
  • and you can inspect not only performance but also activity health and hardware-oriented behavior.

For a more in-depth look at a typical training script, browse the tutorial folder, which contains an end-to-end example.