Hyperparameter Optimization for Custom Hardware-Aware SNNs
This guide explains how to use nwavesdk.optim.hpo to train and tune custom spiking neural networks built with H1v1 or H1v2 layers.
The goal is to make the workflow practical and flexible:
- define your own hardware-aware architecture,
- compose task losses and regularizers,
- define the metrics you actually care about,
- tune model, optimizer, scheduler, and training strategy parameters together,
- inspect whether those choices are truly helping.
Why use this module
HPOWrapper gives you a structured way to run Ray Tune over an SNN training pipeline while keeping your model and training logic customizable.
In practice it helps when you want to:
- compare many design choices (taus, optimizer, regularizers, schedulers),
- optimize for non-standard targets (for example recall instead of loss),
- include hardware-aware behavior during training (sign annealing, topology coherence),
- avoid manual script rewrites every time you change what should be tuned.
Mental model of the pipeline
Each trial follows this flow:
- Build namespaced configs (model.*, loss.*, metric.*, optim.*, sched.*, nwave.*, trainer.*, data.*, early_stop.*).
- Build model and datasets.
- Validate model output contract.
- Build losses and metrics from their per-trial configs.
- Build optimizer and scheduler.
- Optionally build NWAVE schedulers (surrogate, sign_annealing).
- Run the trial with train_fn (supplying your own is optional: the default, train_one_trial, is a training loop driven by the bound parameters).
- Report metrics to Ray Tune and checkpoint best epochs.
This separation is what makes the tool general: each piece is a function you provide at a high level, while the training integration is handled in the backend.
Core contracts (important)
1) Model builder
model = model_builder(model_cfg: dict)
model_cfg comes from all keys bound through bind_model_param(...).
2) Model forward output contract
HPOWrapper validates that your model returns:
(spikes_list, membranes_list)
with:
- both objects being lists/tuples,
- same non-zero length,
- each pair (spk_i, mem_i) having the same shape,
- each tensor being rank 3, with expected shape (B, T, N).
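The checks listed above can be sketched as a small standalone function. This is only an illustration of the contract, not HPOWrapper's actual validation code: check_output_contract is a hypothetical name, and the sketch assumes nothing more than a .shape attribute on each output, which torch tensors provide.

```python
from types import SimpleNamespace


def check_output_contract(output):
    """Hypothetical re-statement of the (spikes_list, membranes_list) contract."""
    if not (isinstance(output, (list, tuple)) and len(output) == 2):
        raise ValueError("model must return (spikes_list, membranes_list)")
    spikes_list, membranes_list = output
    for name, seq in (("spikes_list", spikes_list), ("membranes_list", membranes_list)):
        if not isinstance(seq, (list, tuple)):
            raise TypeError(f"{name} must be a list or tuple")
    if len(spikes_list) == 0 or len(spikes_list) != len(membranes_list):
        raise ValueError("both lists must have the same non-zero length")
    for i, (spk, mem) in enumerate(zip(spikes_list, membranes_list)):
        if tuple(spk.shape) != tuple(mem.shape):
            raise ValueError(f"layer {i}: spikes and membranes differ in shape")
        if len(spk.shape) != 3:
            raise ValueError(f"layer {i}: expected rank 3 (B, T, N), got {tuple(spk.shape)}")
    return len(spikes_list)


def fake(*shape):
    # Stand-in exposing only .shape; in practice these would be torch tensors.
    return SimpleNamespace(shape=shape)


n_layers = check_output_contract(
    ([fake(8, 100, 16), fake(8, 100, 64)], [fake(8, 100, 16), fake(8, 100, 64)])
)
```

Running a check like this on one batch before launching a long HPO run is a cheap way to catch a shape mismatch early.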
3) Loss builders
You pass a dictionary of builders:
loss_builders = {"main": build_main_loss}
Each builder receives trial loss config and returns a callable loss object.
With default train_one_trial, only ctx.losses["main"] is used.
The returned loss callable is invoked with:
- spk_out
- y
- spikes_list
- membranes_list
and may return:
- a scalar torch.Tensor,
- a dict[str, Tensor] of components,
- a LossOutput(total=..., components=...).
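To make the three accepted return types concrete, the sketch below reduces each of them to a (total, components) pair. It is an assumption about how such a normalization could work, not the SDK's code: LossOutputSketch is a hypothetical stand-in for LossOutput, and summing the dict values to obtain the total is a guess. Plain floats are used so the example runs without torch; the same logic applies to scalar tensors.

```python
from dataclasses import dataclass, field


@dataclass
class LossOutputSketch:
    """Hypothetical stand-in for LossOutput(total=..., components=...)."""
    total: float
    components: dict = field(default_factory=dict)


def normalize_loss_output(out):
    """Reduce the three accepted return types to (total, components)."""
    if isinstance(out, LossOutputSketch):
        return out.total, dict(out.components)
    if isinstance(out, dict):
        # Assumption: for a dict, the total is the plain sum of its components.
        return sum(out.values()), dict(out)
    # Otherwise treat the value as an already-reduced scalar.
    return out, {"total": out}


total, parts = normalize_loss_output({"mse": 0.5, "firing_rate": 0.25})
```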
4) Metric builders
You pass:
metric_builders = {"acc": build_acc_metric, "recall": build_recall_metric, ...}
Each metric builder receives metric_cfg and returns a metric function.
Metric functions are called with:
- spk_out
- y
- spikes_list
- membranes_list
- model
and should return a numeric scalar (float-like).
Building a Tunable Network
The pattern below uses an explicit layer-by-layer style:
- H1v2 frontend + H1v2 synapse/layer stack,
- fixed number of hidden layers (2),
- hidden widths tunable (32 or 64),
- output contract always (spikes_list, membranes_list).
import torch
import torch.nn as nn
from nwavesdk.layers import H1v2Frontend, H1v2Synapse, H1v2Layer, prepare_net


class SurrogateSNN(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.in_ch = cfg["in_ch"]
        self.num_classes = cfg["num_classes"]
        self.neurons_per_class = cfg["neurons_per_class"]
        self.out_neurons = self.num_classes * self.neurons_per_class
        self.dt = cfg["dt"]
        self.device_flag = cfg.get("device", "gpu")

        n1 = 16
        # Fixed depth, tunable width
        n2 = cfg["hidden_1"]  # tune.choice([32, 64])
        n3 = self.out_neurons

        self.f0 = H1v2Frontend(self.in_ch, device=self.device_flag)
        self.s1 = H1v2Synapse(self.in_ch, n1, device=self.device_flag, init=nn.init.xavier_uniform_, lif_threshold=1.0)
        self.l1 = H1v2Layer(n1, taus=torch.full((n1,), cfg["tau_l1"]), dt=self.dt, device=self.device_flag, layer_topology="FF")
        self.s2 = H1v2Synapse(n1, n2, device=self.device_flag, init=nn.init.xavier_uniform_, lif_threshold=1.0)
        self.l2 = H1v2Layer(n2, taus=torch.full((n2,), cfg["tau_l2"]), dt=self.dt, device=self.device_flag, layer_topology="FF")
        self.s3 = H1v2Synapse(n2, n3, device=self.device_flag, init=nn.init.xavier_uniform_, lif_threshold=1.0)
        self.l3 = H1v2Layer(n3, taus=torch.full((n3,), cfg["tau_l3"]), dt=self.dt, device=self.device_flag, layer_topology="FF")

        if self.device_flag == "gpu":
            self.to_gpu()

    def to_gpu(self):
        for m in [self.f0, self.s1, self.l1, self.s2, self.l2, self.s3, self.l3]:
            if hasattr(m, "to_gpu"):
                m.to_gpu()

    def forward_cpu(self, x):
        # x: (B, T, F)
        prepare_net(self)
        spk1, spk2, spk3 = [], [], []
        mem1, mem2, mem3 = [], [], []
        # Unroll over time, one timestep per iteration
        for t in range(x.shape[1]):
            q0 = self.f0(x[:, t, :])
            q1 = self.s1(q0)
            s1, m1 = self.l1(q1)
            q2 = self.s2(s1)
            s2, m2 = self.l2(q2)
            q3 = self.s3(s2)
            s3, m3 = self.l3(q3)
            spk1.append(s1); mem1.append(m1)
            spk2.append(s2); mem2.append(m2)
            spk3.append(s3); mem3.append(m3)
        # Stack along the time axis to get (B, T, N) per layer
        spikes_list = [
            torch.stack(spk1, dim=1),
            torch.stack(spk2, dim=1),
            torch.stack(spk3, dim=1),
        ]
        membranes_list = [
            torch.stack(mem1, dim=1),
            torch.stack(mem2, dim=1),
            torch.stack(mem3, dim=1),
        ]
        return spikes_list, membranes_list

    def forward(self, x):
        # x: (B, T, F)
        prepare_net(self)
        x = x.to("cuda", non_blocking=True).contiguous().to(torch.float32)
        q0 = self.f0(x)
        q1 = self.s1(q0)
        s1, m1 = self.l1(q1)
        q2 = self.s2(s1)
        s2, m2 = self.l2(q2)
        q3 = self.s3(s2)
        s3, m3 = self.l3(q3)
        return [s1, s2, s3], [m1, m2, m3]
Tuning plain weight initialization parameters with HPO
If you want to tune parameters from standard torch.nn.init.* functions, keep the same SurrogateSNN style used above (explicit frontend + synapse + layer stack):
- expose an init name in model_cfg,
- expose only the init-specific scalar parameters you need (for example gain, a, mode),
- apply the selected init right after each synapse creation.
import torch
import torch.nn as nn
from ray import tune
from nwavesdk.layers import H1v2Frontend, H1v2Synapse, H1v2Layer
from nwavesdk.optim.hpo.HPOWrapper import HPOWrapper, train_one_trial


def _apply_synapse_init(synapse, cfg):
    w = synapse.weight
    init_name = cfg["init_name"]
    if init_name == "xavier_uniform":
        nn.init.xavier_uniform_(w, gain=cfg["init_gain"])
    elif init_name == "xavier_normal":
        nn.init.xavier_normal_(w, gain=cfg["init_gain"])
    elif init_name == "kaiming_uniform":
        nn.init.kaiming_uniform_(
            w,
            a=cfg["init_a"],
            mode=cfg["init_mode"],  # "fan_in" or "fan_out"
            nonlinearity=cfg["init_nlin"],  # e.g. "relu", "leaky_relu"
        )
    elif init_name == "kaiming_normal":
        nn.init.kaiming_normal_(
            w,
            a=cfg["init_a"],
            mode=cfg["init_mode"],
            nonlinearity=cfg["init_nlin"],
        )
    else:
        raise ValueError(f"Unsupported init_name: {init_name}")


class SurrogateSNN(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.in_ch = cfg["in_ch"]
        self.num_classes = cfg["num_classes"]
        self.neurons_per_class = cfg["neurons_per_class"]
        self.out_neurons = self.num_classes * self.neurons_per_class
        self.dt = cfg["dt"]
        self.device_flag = cfg.get("device", "gpu")

        n1 = 16
        n2 = cfg["hidden_1"]
        n3 = self.out_neurons

        self.f0 = H1v2Frontend(self.in_ch, device=self.device_flag)
        self.s1 = H1v2Synapse(self.in_ch, n1, device=self.device_flag)
        self.s2 = H1v2Synapse(n1, n2, device=self.device_flag)
        self.s3 = H1v2Synapse(n2, n3, device=self.device_flag)
        # Apply the trial-selected init to every synapse
        for syn in [self.s1, self.s2, self.s3]:
            _apply_synapse_init(syn, cfg)

        self.l1 = H1v2Layer(n1, taus=torch.full((n1,), cfg["tau_l1"]), dt=self.dt, device=self.device_flag, layer_topology="FF")
        self.l2 = H1v2Layer(n2, taus=torch.full((n2,), cfg["tau_l2"]), dt=self.dt, device=self.device_flag, layer_topology="FF")
        self.l3 = H1v2Layer(n3, taus=torch.full((n3,), cfg["tau_l3"]), dt=self.dt, device=self.device_flag, layer_topology="FF")

    # forward / forward_cpu are the same as the SurrogateSNN example above.


def build_model(model_cfg):
    return SurrogateSNN(model_cfg)


hpo = HPOWrapper(
    model_builder=build_model,
    dataset_builder=build_datasets,
    train_fn=train_one_trial,
    loss_builders={"main": build_main_loss},
    metric_builders={"acc": build_acc_metric},
    metric="val_acc",
    mode="max",
    experiment_name="plain_init_hpo",
)

hpo.bind_model_param("dt", 2e-3)
hpo.bind_model_param("in_ch", 13)
hpo.bind_model_param("num_classes", 2)
hpo.bind_model_param("neurons_per_class", 32)
hpo.bind_model_param("hidden_1", tune.choice([32, 64]))
hpo.bind_model_param("tau_l1", tune.loguniform(5e-3, 30e-3))
hpo.bind_model_param("tau_l2", tune.loguniform(10e-3, 50e-3))
hpo.bind_model_param("tau_l3", tune.loguniform(20e-3, 90e-3))
hpo.bind_model_param("device", "gpu")

# plain init search space
hpo.bind_model_param(
    "init_name",
    tune.choice(["xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal"]),
)
hpo.bind_model_param("init_gain", tune.uniform(0.5, 2.0))  # used by xavier*
hpo.bind_model_param("init_a", tune.uniform(0.0, 0.3))  # used by kaiming*
hpo.bind_model_param("init_mode", tune.choice(["fan_in", "fan_out"]))  # used by kaiming*
hpo.bind_model_param("init_nlin", tune.choice(["relu", "leaky_relu"]))  # used by kaiming*
Tuning fluct_init parameters with HPO
When using fluct_init, keep the same SurrogateSNN network definition style and apply fluct_init inside build_model.
import os
import torch
import torch.nn as nn
from ray import tune
from nwavesdk.layers import H1v2Frontend, H1v2Synapse, H1v2Layer
from nwavesdk.init.fluct_init import fluct_init
from nwavesdk.optim.hpo.HPOWrapper import HPOWrapper, train_one_trial


def load_init_loader(dataset_path, train_file):
    return torch.load(os.path.join(dataset_path, train_file), weights_only=False)


class SurrogateSNN(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.in_ch = cfg["in_ch"]
        self.num_classes = cfg["num_classes"]
        self.neurons_per_class = cfg["neurons_per_class"]
        self.out_neurons = self.num_classes * self.neurons_per_class
        self.dt = cfg["dt"]
        self.device_flag = cfg.get("device", "gpu")

        n1 = 16
        n2 = cfg["hidden_1"]
        n3 = self.out_neurons

        self.f0 = H1v2Frontend(self.in_ch, device=self.device_flag)
        self.s1 = H1v2Synapse(self.in_ch, n1, device=self.device_flag, init=nn.init.xavier_uniform_)
        self.s2 = H1v2Synapse(n1, n2, device=self.device_flag, init=nn.init.xavier_uniform_)
        self.s3 = H1v2Synapse(n2, n3, device=self.device_flag, init=nn.init.xavier_uniform_)

        self.l1 = H1v2Layer(n1, taus=torch.full((n1,), cfg["tau_l1"]), dt=self.dt, device=self.device_flag, layer_topology="FF")
        self.l2 = H1v2Layer(n2, taus=torch.full((n2,), cfg["tau_l2"]), dt=self.dt, device=self.device_flag, layer_topology="FF")
        self.l3 = H1v2Layer(n3, taus=torch.full((n3,), cfg["tau_l3"]), dt=self.dt, device=self.device_flag, layer_topology="FF")

    # forward / forward_cpu are the same as the SurrogateSNN example above.


def build_model(model_cfg):
    model = SurrogateSNN(model_cfg)
    # Re-initialize with fluct_init only when the trial config asks for it
    if model_cfg.get("use_fluct_init", False):
        init_loader = load_init_loader(
            model_cfg["init_dataset_path"],
            model_cfg.get("fluct_init_train_file", "train.pt"),
        )
        fluct_init(
            model,
            init_loader,
            xi_target=model_cfg["fluct_xi_target"],
            alpha=model_cfg["fluct_alpha"],
            n_batches=model_cfg["fluct_n_batches"],
            verbose=model_cfg.get("init_verbose", False),
        )
    return model


hpo = HPOWrapper(
    model_builder=build_model,
    dataset_builder=build_datasets,
    train_fn=train_one_trial,
    loss_builders={"main": build_main_loss},
    metric_builders={"acc": build_acc_metric},
    metric="val_acc",
    mode="max",
    experiment_name="fluct_init_hpo",
)

hpo.bind_model_param("dt", 2e-3)
hpo.bind_model_param("in_ch", 13)
hpo.bind_model_param("num_classes", 2)
hpo.bind_model_param("neurons_per_class", 32)
hpo.bind_model_param("hidden_1", tune.choice([32, 64]))
hpo.bind_model_param("tau_l1", tune.loguniform(5e-3, 30e-3))
hpo.bind_model_param("tau_l2", tune.loguniform(10e-3, 50e-3))
hpo.bind_model_param("tau_l3", tune.loguniform(20e-3, 90e-3))
hpo.bind_model_param("device", "gpu")
hpo.bind_model_param("init_dataset_path", "/path/to/data")
hpo.bind_model_param("init_verbose", False)

# fluct_init search space
hpo.bind_model_param("use_fluct_init", True)
hpo.bind_model_param("fluct_init_train_file", "train.pt")
hpo.bind_model_param("fluct_xi_target", tune.uniform(1.0, 3.0))
hpo.bind_model_param("fluct_alpha", tune.uniform(0.7, 1.0))
hpo.bind_model_param("fluct_n_batches", tune.choice([1, 2, 4]))
Composing losses
CompositeLoss lets you combine heterogeneous terms without hard-coding a single monolithic loss.
from nwavesdk.loss import (
    balanced_population_mse,
    firing_rate_target_mse_loss,
    weight_magnitude_loss,
    topology_loss,
)
from nwavesdk.optim.hpo.loss_factory import CompositeLoss


def balanced_pop_term(spk_out, y, config, **kwargs):
    return {
        "mse": balanced_population_mse(
            spk_out,
            y,
            num_classes=config["num_classes"],
            neurons_per_class=config["neurons_per_class"],
            correct_rate=config["correct_rate"],
            incorrect_rate=config["incorrect_rate"],
        )
    }


def firing_rate_term(spikes_list, config, **kwargs):
    return {
        "firing_rate": firing_rate_target_mse_loss(
            spikes_list=spikes_list,
            offsets=config["offsets_fr"],
            multipliers=config["multipliers_fr"],
        )
    }


def hardware_reg_term(model, config, **kwargs):
    return {
        "weight_mag": weight_magnitude_loss(model, limit=config["limit"]),
        "topology": topology_loss(model, lam=config["lambda_topo"]),
    }


def build_main_loss(loss_cfg):
    # Each term gets its own config slice and a tunable weight
    return (
        CompositeLoss()
        .add_loss_term(
            "population",
            balanced_pop_term,
            config={
                "num_classes": loss_cfg["num_classes"],
                "neurons_per_class": loss_cfg["neurons_per_class"],
                "correct_rate": loss_cfg["correct_rate"],
                "incorrect_rate": loss_cfg["incorrect_rate"],
            },
            weight=loss_cfg.get("w_population", 1.0),
        )
        .add_loss_term(
            "activity",
            firing_rate_term,
            config={
                "offsets_fr": loss_cfg["offsets_fr"],
                "multipliers_fr": loss_cfg["multipliers_fr"],
            },
            weight=loss_cfg.get("w_activity", 1.0),
        )
        .add_loss_term(
            "hardware",
            hardware_reg_term,
            config={
                "limit": loss_cfg.get("weight_limit", 0.9),
                "lambda_topo": loss_cfg.get("lambda_topology", 1e-3),
            },
            weight=loss_cfg.get("w_hardware", 1.0),
        )
    )
This gives:
- clear component logging per term,
- weighted composition,
- easy tuning of each term strength/config.
The default backend supports model-dependent regularizers directly:
train_one_trial passes model into the loss call, and CompositeLoss filters kwargs per loss-term signature.
This allows model-independent and model-dependent terms to be mixed without a custom train_fn.
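The per-signature filtering mentioned above can be pictured with Python's inspect module. The sketch below is an assumption about the general technique, not CompositeLoss's actual implementation; call_with_accepted_kwargs, rate_term, and weight_term are hypothetical names, and a plain dict stands in for the model.

```python
import inspect


def call_with_accepted_kwargs(fn, **available):
    """Call fn with only the keyword arguments its signature can accept."""
    params = inspect.signature(fn).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        # fn declares **kwargs, so it can absorb everything that is available.
        return fn(**available)
    accepted = {k: v for k, v in available.items() if k in params}
    return fn(**accepted)


def rate_term(spikes_list, config):
    # Model-independent term: never sees the model.
    return config["scale"] * len(spikes_list)


def weight_term(model, config):
    # Model-dependent term: only asks for the model.
    return config["scale"] * model["n_weights"]


inputs = dict(
    spk_out=None, y=None, spikes_list=[1, 2, 3], membranes_list=[],
    model={"n_weights": 10}, config={"scale": 2},
)
rate_value = call_with_accepted_kwargs(rate_term, **inputs)
weight_value = call_with_accepted_kwargs(weight_term, **inputs)
```

With this kind of dispatch, each term declares only the inputs it needs, which is why the loss terms shown earlier can take different leading arguments.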
Warning
Applying topology_loss to an entire network that includes a frontend may raise an error. If you use topology_loss, apply it only to the dense synapses.
Defining metrics
Any callable metric is valid if it returns a scalar.
from nwavesdk.metrics import accuracy_population


def build_acc_metric(metric_cfg):
    def metric_fn(spk_out, y, **kwargs):
        return accuracy_population(
            spk_out,
            y,
            num_classes=metric_cfg.get("num_classes", 2),
        )
    return metric_fn


def build_recall_metric(metric_cfg):
    positive_class = metric_cfg.get("positive_class", 1)
    num_classes = metric_cfg.get("num_classes", 2)

    def metric_fn(spk_out, y, **kwargs):
        # replace with your own population-vote recall implementation
        # Sum spikes over time, then over each class's neuron group, and vote
        preds = spk_out.sum(dim=1).view(spk_out.shape[0], num_classes, -1).sum(dim=2).argmax(dim=1)
        tp = ((preds == positive_class) & (y == positive_class)).float().sum().item()
        fn = ((preds != positive_class) & (y == positive_class)).float().sum().item()
        return tp / max(1.0, tp + fn)
    return metric_fn
Then choose the target Ray should optimize:
metric="val_recall"
mode="max"
All namespaces you can tune
The wrapper supports these binding APIs:
| API | Namespace in trial config | What it controls |
|---|---|---|
| bind_model_param(name, space) | model.<name> | architecture / neuron dynamics |
| bind_loss_param(loss_name, param_name, space) | loss.<loss_name>.<param_name> | loss term internals |
| bind_metric_param(metric_name, param_name, space) | metric.<metric_name>.<param_name> | metric behavior |
| bind_optimizer_param(name, space) | optim.<name> | optimizer type and params |
| bind_scheduler_param(name, space) | sched.<name> | LR scheduler config |
| bind_nwave_scheduler_param(name, space) | nwave.<name> | surrogate/sign annealing configs |
| bind_trainer_param(name, space) | trainer.<name> | epochs, device, etc. |
| bind_data_param(name, space) | data.<name> | dataset loader options |
| bind_early_stop_param(name, space) | early_stop.<name> | early stop policy |
Also available:
- set_fixed_params(...) for non-tuned fixed values,
- set_resources(cpu=..., gpu=...),
- set_num_samples(n),
- set_max_concurrent_trials(n),
- set_memory_guard(...) to pre-check RAM budget and avoid OOM.
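As a mental aid for the namespaces in the table, the sketch below groups dotted keys into one dict per component. It assumes the trial config can be viewed as flat "namespace.name" keys; how HPOWrapper stores and splits them internally is not stated here, and split_namespaces is a hypothetical helper.

```python
def split_namespaces(flat_config):
    """Group 'namespace.rest.of.key' entries into one dict per namespace."""
    grouped = {}
    for key, value in flat_config.items():
        namespace, _, name = key.partition(".")
        grouped.setdefault(namespace, {})[name] = value
    return grouped


# Hypothetical sampled trial config, for illustration only.
trial_config = {
    "model.hidden_1": 64,
    "model.tau_l1": 0.01,
    "optim.lr": 1e-3,
    "loss.main.w_activity": 0.5,
    "nwave.surrogate.policy": "linear",
}
per_component = split_namespaces(trial_config)
model_cfg = per_component["model"]
```

Here model_cfg holds only the keys bound through bind_model_param, which matches the model builder contract described earlier.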
Scheduler layers (three different levels)
It is useful to keep these separate:
1) Optimizer LR scheduler (sched.*)
Default supported names:
- none
- plateau (ReduceLROnPlateau)
- step (StepLR)
- cosine (CosineAnnealingLR)
- multistep (MultiStepLR)
Example:
hpo.bind_scheduler_param(
    "cfg",
    tune.choice([
        {"name": "none"},
        {"name": "plateau", "mode": "min", "factor": 0.3, "patience": 8, "threshold": 1e-3},
        {"name": "step", "step_size": 10, "gamma": 0.5},
        {"name": "cosine", "T_max": 60, "eta_min": 1e-6},
        {"name": "multistep", "milestones": [20, 40], "gamma": 0.3},
    ])
)
2) Ray Tune trial scheduler
Controls trial promotion/stopping across HPO runs.
Use:
- set_tune_scheduler_config({...}), or
- set_tune_scheduler(ray_scheduler_instance).
Supported config names:
- asha (default),
- hyperband,
- median,
- fifo / none.
3) NWAVE schedulers (nwave.*)
Surrogate scheduler (nwave.surrogate.*)
Controls surrogate parameter evolution during training.
Supported policies:
- constant
- linear
- exp
- cosine
- step
Typical parameters:
- policy
- start_value
- end_value
- total_steps
- step_size (for step)
- gamma (for step/exp style progression)
- min_value, max_value.
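To give a feel for how these parameters could interact, here is a minimal sketch of the five policies. The formulas are illustrative assumptions (in particular for exp and step, which are modeled as geometric progressions from start_value); the SDK's exact definitions may differ, and surrogate_value is a hypothetical function name.

```python
import math


def surrogate_value(policy, step, start_value, end_value, total_steps,
                    step_size=1, gamma=1.0, min_value=None, max_value=None):
    """Illustrative schedule formulas; not the SDK's implementation."""
    frac = min(max(step / max(total_steps, 1), 0.0), 1.0)
    if policy == "constant":
        value = start_value
    elif policy == "linear":
        value = start_value + (end_value - start_value) * frac
    elif policy == "cosine":
        value = end_value + (start_value - end_value) * 0.5 * (1.0 + math.cos(math.pi * frac))
    elif policy == "exp":
        value = start_value * gamma ** step
    elif policy == "step":
        value = start_value * gamma ** (step // step_size)
    else:
        raise ValueError(f"unknown policy: {policy}")
    # Optional clamping to [min_value, max_value]
    if min_value is not None:
        value = max(value, min_value)
    if max_value is not None:
        value = min(value, max_value)
    return value


halfway_linear = surrogate_value("linear", 30, 10.0, 20.0, 60)
stepped = surrogate_value("step", 25, 10.0, 20.0, 60, step_size=10, gamma=2.0, max_value=30.0)
```

Plotting such a function over the planned number of epochs before a run is a quick way to check that a sampled combination of start_value, end_value, and gamma behaves sensibly.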
Sign annealing (nwave.sign_annealing.*)
Gradually tightens row-block sign parametrization.
Typical parameters:
- alpha_start
- alpha_end
- total_epochs.
When enabled, additional sign-topology coherence metrics are reported.
Early stopping policies
Configure through bind_early_stop_param(...).
Supported policies:
- none
- threshold
- patience (aliases: plateau, no_improve, no-improve)
Common fields:
- policy
- metric
- mode (min or max)
- min_epoch
Threshold policy adds:
- value
Patience policy adds:
- patience
- min_delta
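The fields above map naturally onto a small state machine. The sketch below shows one plausible reading of the threshold and patience policies; it is not the SDK's implementation, and EarlyStopSketch is a hypothetical class whose handling of min_epoch and min_delta is an assumption.

```python
class EarlyStopSketch:
    """Illustrative threshold/patience logic; not the SDK's implementation."""

    def __init__(self, policy="none", mode="max", min_epoch=0,
                 value=None, patience=0, min_delta=0.0):
        self.policy, self.mode, self.min_epoch = policy, mode, min_epoch
        self.value, self.patience, self.min_delta = value, patience, min_delta
        self.best = None
        self.bad_epochs = 0

    def _better(self, current, reference, margin=0.0):
        if self.mode == "max":
            return current > reference + margin
        return current < reference - margin

    def should_stop(self, epoch, metric_value):
        if self.policy == "none":
            return False
        if self.policy == "threshold":
            if self.mode == "max":
                reached = metric_value >= self.value
            else:
                reached = metric_value <= self.value
            return epoch >= self.min_epoch and reached
        # Patience: count consecutive epochs without a min_delta improvement.
        if self.best is None or self._better(metric_value, self.best, self.min_delta):
            self.best = metric_value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return epoch >= self.min_epoch and self.bad_epochs >= self.patience


stopper = EarlyStopSketch(policy="patience", mode="max", min_epoch=2, patience=2, min_delta=0.01)
history = [0.70, 0.80, 0.805, 0.80, 0.79]
stop_epoch = next((e for e, v in enumerate(history, start=1) if stopper.should_stop(e, v)), None)
```

Note that mode has to agree with the metric (max for accuracy or recall, min for loss); a mismatched pair would make either policy trigger at the wrong time.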
End-to-end template
This template can be used as a starting point:
import os
import torch
from ray import tune
from nwavesdk.optim.hpo.HPOWrapper import HPOWrapper, train_one_trial


def build_model(model_cfg):
    return SurrogateSNN(model_cfg)


def build_datasets(dataset_path, data_cfg):
    train_loader = torch.load(os.path.join(dataset_path, data_cfg["train_file"]), weights_only=False)
    val_loader = torch.load(os.path.join(dataset_path, data_cfg["val_file"]), weights_only=False)
    return train_loader, val_loader


hpo = HPOWrapper(
    model_builder=build_model,
    dataset_builder=build_datasets,
    train_fn=train_one_trial,
    loss_builders={"main": build_main_loss},
    metric_builders={
        "acc": build_acc_metric,
        "recall": build_recall_metric,
    },
    metric="val_recall",
    mode="max",
    experiment_name="h1v2_snn_hpo",
)

# model space
hpo.bind_model_param("in_ch", 13)
hpo.bind_model_param("num_classes", 2)
hpo.bind_model_param("neurons_per_class", 32)
hpo.bind_model_param("hidden_1", tune.choice([32, 64]))
hpo.bind_model_param("dt", tune.choice([1e-3, 2e-3, 4e-3]))
hpo.bind_model_param("tau_l1", tune.loguniform(5e-3, 30e-3))
hpo.bind_model_param("tau_l2", tune.loguniform(10e-3, 50e-3))
hpo.bind_model_param("tau_l3", tune.loguniform(20e-3, 90e-3))
hpo.bind_model_param("device", "gpu")

# optimizer space
hpo.bind_optimizer_param("name", tune.choice(["adam", "adamax"]))
hpo.bind_optimizer_param("lr", tune.loguniform(1e-4, 1e-2))

# loss space
hpo.bind_loss_param("main", "num_classes", 2)
hpo.bind_loss_param("main", "neurons_per_class", 32)
hpo.bind_loss_param("main", "correct_rate", tune.uniform(0.25, 0.45))
hpo.bind_loss_param("main", "incorrect_rate", tune.uniform(0.01, 0.10))
hpo.bind_loss_param("main", "offsets_fr", tune.choice([[0.05, 0.08, 0.10], [0.1, 0.2, 0.3]]))  # 3 layers
hpo.bind_loss_param("main", "multipliers_fr", tune.choice([[1, 2, 3], [10, 5, 1]]))  # 3 layers
hpo.bind_loss_param("main", "w_population", tune.choice([1.0, 2.0]))
hpo.bind_loss_param("main", "w_activity", tune.choice([0.2, 0.5, 1.0]))
hpo.bind_loss_param("main", "w_hardware", tune.choice([0.0, 1e-3, 1e-2]))
hpo.bind_loss_param("main", "weight_limit", tune.choice([0.8, 0.9]))
hpo.bind_loss_param("main", "lambda_topology", tune.choice([1e-4, 1e-3, 1e-2]))

# metric space
hpo.bind_metric_param("acc", "num_classes", 2)
hpo.bind_metric_param("recall", "num_classes", 2)
hpo.bind_metric_param("recall", "positive_class", tune.choice([0, 1]))

# LR scheduler space
hpo.bind_scheduler_param(
    "cfg",
    tune.choice([
        {"name": "none"},
        {"name": "plateau", "mode": "min", "factor": 0.3, "patience": 8, "threshold": 1e-3, "min_lr": 1e-6},
        {"name": "step", "step_size": 10, "gamma": 0.5},
        {"name": "cosine", "T_max": 60, "eta_min": 1e-6},
        {"name": "multistep", "milestones": [20, 40], "gamma": 0.3},
    ])
)

# early stop
hpo.bind_early_stop_param("policy", tune.choice(["none", "threshold", "patience"]))
hpo.bind_early_stop_param("metric", tune.choice(["val_recall", "val_acc", "val_loss"]))
hpo.bind_early_stop_param("mode", tune.choice(["max", "min"]))
hpo.bind_early_stop_param("min_epoch", tune.choice([8, 12]))
hpo.bind_early_stop_param("value", tune.choice([0.90, 0.93]))  # used by threshold
hpo.bind_early_stop_param("patience", tune.choice([4, 8, 12]))  # used by patience
hpo.bind_early_stop_param("min_delta", tune.choice([0.0, 1e-3]))  # used by patience

# NWAVE schedulers
hpo.bind_nwave_scheduler_param("surrogate.policy", tune.choice(["constant", "linear", "cosine", "exp", "step"]))
hpo.bind_nwave_scheduler_param("surrogate.start_value", tune.choice([10.0, 15.0]))
hpo.bind_nwave_scheduler_param("surrogate.end_value", tune.choice([20.0, 30.0]))
hpo.bind_nwave_scheduler_param("surrogate.total_steps", 60)
hpo.bind_nwave_scheduler_param("surrogate.step_size", tune.choice([5, 10]))
hpo.bind_nwave_scheduler_param("surrogate.gamma", tune.choice([0.5, 0.8, 2.0]))
hpo.bind_nwave_scheduler_param("sign_annealing.alpha_start", tune.choice([0.5, 1.0]))
hpo.bind_nwave_scheduler_param("sign_annealing.alpha_end", tune.choice([10.0, 20.0]))
hpo.bind_nwave_scheduler_param("sign_annealing.total_epochs", 60)

# trainer + data
hpo.bind_trainer_param("epochs", 60)
hpo.bind_trainer_param("device", "cuda")
hpo.bind_data_param("train_file", "train.pt")
hpo.bind_data_param("val_file", "val.pt")

hpo.set_resources(cpu=2, gpu=0.25)
hpo.set_max_concurrent_trials(3)
hpo.set_num_samples(20)
hpo.set_memory_guard(
    enabled=True,
    strict=True,  # raise if user-forced concurrency is above safe recommendation
    safety_fraction=0.80,  # keep part of RAM free to avoid runtime spikes
    dataset_memory_multiplier=1.25,
)
hpo.set_tune_scheduler_config({"name": "asha", "grace_period": 5, "reduction_factor": 2})

results = hpo.run("/path/to/data")
print("Best config:", hpo.best_config())
Memory guard and concurrency suggestion
Before launching Ray Tune, HPOWrapper.run(...) now performs a host-RAM pre-check:
- Reads available RAM from the machine.
- Estimates dataset footprint from dataset_path and known data.* file/path hints.
- Computes a recommended max_concurrent_trials based on:
  - estimated RAM per trial,
  - requested trial resources (set_resources),
  - num_samples.
Default behavior:
- If you did not set max_concurrent_trials, the wrapper auto-caps concurrency to the recommended safe value.
- If you did set max_concurrent_trials and it is above the recommended safe value, the wrapper raises early (when strict=True) with an actionable message.
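The arithmetic behind such a recommendation can be sketched as follows. This is an illustrative estimate under stated assumptions (RAM per trial approximated as dataset size times dataset_memory_multiplier, CPU count as the only resource cap); the wrapper's real heuristic may differ, and recommend_max_concurrent is a hypothetical function.

```python
import math


def recommend_max_concurrent(available_ram_bytes, dataset_bytes, num_samples,
                             safety_fraction=0.80, dataset_memory_multiplier=1.25,
                             total_cpus=None, cpus_per_trial=1):
    """Illustrative estimate only; not the wrapper's actual heuristic."""
    ram_per_trial = dataset_bytes * dataset_memory_multiplier
    budget = available_ram_bytes * safety_fraction
    by_ram = math.floor(budget / ram_per_trial) if ram_per_trial > 0 else num_samples
    limits = [by_ram, num_samples]
    if total_cpus is not None:
        # Cap by how many trials the requested CPU share allows.
        limits.append(total_cpus // cpus_per_trial)
    return max(1, min(limits))


GiB = 1024 ** 3
recommended = recommend_max_concurrent(
    available_ram_bytes=32 * GiB, dataset_bytes=4 * GiB, num_samples=20,
    total_cpus=16, cpus_per_trial=2,
)
```

In this example, 32 GiB of free RAM with a 0.80 safety fraction leaves a 25.6 GiB budget; at 5 GiB per trial that allows 5 concurrent trials, which is tighter than the CPU cap of 8.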
How to inspect if choices are effective
After a run, look beyond the selection metric alone.
1) Best trial and config
best = hpo.best_result()
print(best.metrics["val_recall"])
print(best.config)
2) Full trial table
df = results.get_dataframe()
print(df[["trial_id", "epoch", "val_loss", "val_recall", "lr"]].sort_values("val_recall", ascending=False).head())
3) Signals worth checking
To inspect trials during or after training, run the following command and open localhost:6006 in your browser:
tensorboard --logdir ray_results
The default train_one_trial reports the following TensorBoard scalars.
| TensorBoard key (pattern) | What it is | User-defined from outside? | Reported when | Not reported when |
|---|---|---|---|---|
| epoch | Current epoch index (1-based). | No. | Always, every epoch. | Never (if training is running). |
| train_loss | Alias of train/total (average train total loss over batches). | Indirectly: depends on your loss builder. | Always, every epoch. | Never (if training is running). |
| val_loss | Alias of val/total (average validation total loss over batches). | Indirectly: depends on your loss builder. | Always, every epoch. | Never (if training is running). |
| lr | Current optimizer learning rate (optimizer.param_groups[0]["lr"]). | Indirectly: depends on optimizer/scheduler config. | Always, every epoch. | Never (if training is running). |
| best_selected_metric | Best-so-far value of HPOWrapper(metric=..., mode=...) for the trial. | Yes: chosen via wrapper metric/mode. | Always, every epoch. | Never (if training is running). |
| train/<loss_component> | Per-component train averages from loss_out.components (for example train/main/population/mse, train/main/total, train/total). | Yes: entirely depends on your loss builder output. | If that loss component exists in your configured loss function. | If your loss builder does not produce that component key. |
| val/<loss_component> | Per-component validation averages from loss_out.components (for example val/main/population/mse, val/main/total, val/total). | Yes: entirely depends on your loss builder output. | If that loss component exists in your configured loss function. | If your loss builder does not produce that component key. |
| train/<metric_name> and train_<metric_name> | Train metric averages (two aliases for easier querying). | Yes: defined by metric_builders and bound metric config. | For each metric function registered in metric_builders. | If no metric builder is registered for that metric name. |
| val/<metric_name> and val_<metric_name> | Validation metric averages (two aliases for easier querying). | Yes: defined by metric_builders and bound metric config. | For each metric function registered in metric_builders. | If no metric builder is registered for that metric name. |
| train/layer_<i>/firing_rate | Mean spike activity for layer i in training. | No (computed from model outputs). | Always for each returned layer in spikes_list. | If the layer does not exist in model outputs. |
| train/layer_<i>/dead_neurons | Number of dead neurons in layer i during training epoch. | No (computed from model outputs). | Always for each returned layer. | If the layer does not exist in model outputs. |
| train/layer_<i>/dead_neurons_fraction | Fraction of dead neurons in layer i during training epoch. | No (computed from model outputs). | Always for each returned layer. | If the layer does not exist in model outputs. |
| train/layer_<i>/membrane_mean | Mean membrane value for layer i during training. | No (computed from model outputs). | Always for each returned layer. | If the layer does not exist in model outputs. |
| train/layer_<i>/membrane_var | Membrane variance for layer i during training. | No (computed from model outputs). | Always for each returned layer. | If the layer does not exist in model outputs. |
| val/layer_<i>/firing_rate | Mean spike activity for layer i in validation. | No (computed from model outputs). | Always for each returned layer in spikes_list. | If the layer does not exist in model outputs. |
| val/layer_<i>/dead_neurons | Number of dead neurons in layer i during validation epoch. | No (computed from model outputs). | Always for each returned layer. | If the layer does not exist in model outputs. |
| val/layer_<i>/dead_neurons_fraction | Fraction of dead neurons in layer i during validation epoch. | No (computed from model outputs). | Always for each returned layer. | If the layer does not exist in model outputs. |
| val/layer_<i>/membrane_mean | Mean membrane value for layer i during validation. | No (computed from model outputs). | Always for each returned layer. | If the layer does not exist in model outputs. |
| val/layer_<i>/membrane_var | Membrane variance for layer i during validation. | No (computed from model outputs). | Always for each returned layer. | If the layer does not exist in model outputs. |
| surrogate_value | Current value emitted by NWAVE surrogate scheduler step. | Yes: controlled by nwave.surrogate.* binding. | Only if surrogate scheduler is configured (non-empty nwave.surrogate.*). | If surrogate scheduler is not configured. |
| sign_alpha | Current sign-annealing alpha value. | Yes: controlled by nwave.sign_annealing.* binding. | Only if sign annealing is configured. | If sign annealing is not configured. |
| sign_topology/<layer>/<weight_id>_coherence_pct | Magnitude-weighted sign-topology coherence (%) for parametrized weight/recurrent_weights. | Indirectly: depends on sign-annealing setup and model parametrization. | If sign annealing is configured and the layer exposes parametrized 2D weights. | If sign annealing is disabled or no eligible parametrized weights exist. |
| early_stop_metric_value | Current value of the configured early-stop metric. | Yes: controlled by early_stop.metric. | Only when early_stop.policy is enabled (threshold/patience). | If early stop policy is none/disabled. |
| early_stop_target_value | Threshold target used to trigger stop. | Yes: controlled by early_stop.value. | Only for early_stop.policy="threshold". | For none or patience policy. |
| early_stop_bad_epochs | Number of consecutive non-improving epochs. | Yes: controlled by early_stop.patience and early_stop.min_delta. | Only for early_stop.policy="patience" (and aliases). | For none or threshold policy. |
| early_stop_best_metric | Best metric seen so far under patience policy logic. | Yes: controlled by early-stop config + selected metric. | Only for early_stop.policy="patience" (and aliases). | For none or threshold policy. |
| early_stop_triggered | 1 if early stop condition is met at this epoch, else 0. | Yes: depends on early-stop configuration. | Only when early stop policy is enabled. | If early stop policy is none/disabled. |
Note
Bound parameters such as model.*, optim.*, sched.*, nwave.*, trainer.*,
data.*, and early_stop.* are trial configuration values, not epoch-wise scalar metrics.
They can be inspected in Ray trial config/hparams views.
If a namespace is not bound, it is not used to build that component and therefore no related
conditional runtime metrics (for example surrogate_value, sign_alpha, early_stop_*) will appear.
These metrics make it possible to answer:
- did validation improve for the right reason?
- did the network become too silent/saturated?
- did regularization improve topology coherence or only hurt accuracy?
- which scheduler gave a better convergence profile?
4) Statistical analysis with tbparse
The recommended package for inspecting TensorBoard results is tbparse. It loads the logged scalars into a dataframe equivalent to the TensorBoard view, from which you can inspect results and the trends of metrics against hyperparameters. With those results, you can either steer the next HPO runs toward a better-informed search space, or explain why some models reached certain metrics.
5) Restore best checkpoint
import os
import torch
best = hpo.best_result()
ckpt_dir = best.checkpoint.to_directory()
payload = torch.load(os.path.join(ckpt_dir, "best_model.pt"), map_location="cpu")
print(payload["epoch"], payload["best_selected_metric"])
Troubleshooting
Error: model contract assertion fails
Check that your model returns exactly:
return spikes_list, membranes_list
with same length, matching shapes, and rank 3 tensors.
Error: selected metric not reported
metric=... in HPOWrapper must match one reported key exactly (for example val_recall, val_acc, val_loss).
Error: unsupported optimizer or scheduler
Default builders support:
- optimizer: adam, adamax
- LR scheduler: none, plateau, step, cosine, multistep
If you need more options, pass custom optimizer_builder / scheduler_builder.
Error: RC layer complains about missing prepared matrices
Call prepare_net(model) once per batch/sequence before timestep unroll.
Dataloader batch format error
The default train_one_trial accepts both:
- (x, y)
- (x, y, meta)
If you still get a batch-format error, make sure your loader yields a tuple/list with at least two elements.
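If you write a custom train_fn and want to accept the same two batch layouts, a helper along these lines is one option. It is a sketch of the idea rather than the code used by train_one_trial, and unpack_batch is a hypothetical name.

```python
def unpack_batch(batch):
    """Return (x, y, meta) from an (x, y) or (x, y, meta) batch."""
    if not isinstance(batch, (tuple, list)) or len(batch) < 2:
        raise ValueError("expected a tuple/list batch with at least (x, y)")
    x, y = batch[0], batch[1]
    # meta is optional; use None when the loader yields only (x, y).
    meta = batch[2] if len(batch) > 2 else None
    return x, y, meta


x, y, meta = unpack_batch(([0.1, 0.2], 1))
```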
Bound a metric/loss param but nothing changes
Bound spaces are only consumed by registered builders.
If you bind metric.f1.* but no "f1" metric builder exists, those params are unused.
Final take
HPOWrapper is intentionally modular: it does not force a specific H1 architecture, a fixed loss, or a fixed metric set.
That is the main benefit for custom SNN research and engineering:
- you keep control of model internals and task objectives,
- you still get a clean, reproducible, scalable HPO workflow,
- and you can inspect not only performance but also activity health and hardware-oriented behavior.
For a more in-depth look at a typical training script, browse the tutorial folder, which contains an end-to-end example.