Distribution of Weights in a Network
Varun Nayyar, 2020-08-23
Let us consider the simplest possible neural network: one input \(x\), one output \(y\) and some non-linearity \(f\). This is expressed as
\[y = f(wx + b)\]
where \(w\) and \(b\) are the weight and bias of the network. Rewriting this in a slightly different form,
\[y = f\left(w\left(x - \left(-\frac{b}{w}\right)\right)\right),\]
we see that the activation function is centered at \(x = -b/w\) (for \(w \neq 0\)).
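As a quick numerical check of this with ReLU as \(f\), the sketch below evaluates a single unit \(\max(0, wx + b)\) for hand-picked values \(w = 2\), \(b = -3\) (illustrative only, not taken from the experiment) and confirms that the kink sits at \(-b/w = 1.5\). Flipping the signs of both parameters keeps the kink in the same place but switches which side of it the unit is active on. The helper name `relu_unit` is hypothetical.

```python
import numpy as np

def relu_unit(x, w, b):
    """A single hidden unit with ReLU activation: max(0, w*x + b)."""
    return np.maximum(0.0, w * x + b)

x = np.linspace(-5, 5, 1001)

# Positive weight: the unit is off to the left of -b/w and rises with slope w to the right.
w, b = 2.0, -3.0
kink = -b / w  # 1.5
y = relu_unit(x, w, b)
off_left = np.all(y[x < kink] == 0.0)
linear_right = np.allclose(y[x > kink], w * (x[x > kink] - kink))

# Negative weight (and bias): same kink location, but the unit is now off to the right.
y_neg = relu_unit(x, -w, -b)
off_right = np.all(y_neg[x > kink] == 0.0)
```

So for ReLU, \(-b/w\) is not just a "center" but the exact point where the unit switches on or off.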
For this experiment, we look at the distribution of \(-b/w\) for a swarm of independently initialised networks fitting a:
- Trig function: sin and cos have very obvious turning points.
- ReLU activation: as a very simple activation, \(-b/w\) is exactly the position of each unit's kink, so it marks exactly where the fitted curve can turn.
- Single hidden layer: this makes interpretation easier, since the output is simply a weighted sum of the hidden units.
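To see why a single hidden layer keeps things interpretable, here is a small NumPy sketch of a width-5 ReLU layer with random parameters (the width, seed, grid and read-out weights `a`, `c` are illustrative assumptions, not the experiment's settings). The output is a weighted sum of shifted ReLUs, so its slope can only change at the kinks \(-b_i/w_i\). The code estimates the slope on a fine grid and checks that every detected bend lies next to one of those kinks.

```python
import numpy as np

rng = np.random.default_rng(0)
width = 5
w = rng.uniform(-1, 1, width)  # hidden-layer weights
b = rng.uniform(-1, 1, width)  # hidden-layer biases
a = rng.uniform(-1, 1, width)  # output-layer weights
c = 0.1                        # output bias

def net(x):
    """Single hidden ReLU layer: x of shape (n,) -> hidden (n, width) -> output (n,)."""
    hidden = np.maximum(0.0, np.outer(x, w) + b)
    return hidden @ a + c

kinks = -b / w  # where each hidden unit switches on or off

x = np.linspace(-10, 10, 20001)
dx = x[1] - x[0]
slope = np.diff(net(x)) / dx           # slope on each grid interval
jumps = np.abs(np.diff(slope)) > 1e-6  # grid points where the slope changes
bend_x = x[1:-1][jumps]                # x-positions of the detected bends

# Distance from each detected bend to its nearest kink (should be within a grid step or two).
nearest = np.min(np.abs(bend_x[:, None] - kinks[None, :]), axis=1)
```

Between consecutive kinks the output is exactly linear, which is why the distribution of \(-b/w\) tells us where the network has chosen to bend its fit.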
import numpy as np
import torch
from torch import nn
from matplotlib import pyplot as plt
from IPython.display import Video
from swarm import core, animator, networks
import env
plt.rcParams["figure.figsize"] = (12.0, 12.0)
SEED = 20
if not env.FULL:
    NUM_EPOCHS = 4
    NUM_BEES = 5
else:
    NUM_EPOCHS = 400
    NUM_BEES = 500
def bee_trainer(xt, yt, width=2, num_epochs=200):
    """Define a simple training loop for use with swarm"""
    # One hidden ReLU layer of `width` units, built by swarm.networks
    net = networks.flat_net(1, width, activation=nn.ReLU)
    optimiser = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    loss_func = torch.nn.MSELoss()
    for epoch in range(num_epochs):
        optimiser.zero_grad()
        ypred = net(xt)
        loss = loss_func(ypred, yt)
        if torch.isnan(loss):
            raise RuntimeError("NaN loss, poorly configured experiment")
        loss.backward()
        optimiser.step()
        # Assuming flat_net registers the hidden layer first, these are its weight and bias;
        # copies are yielded so later optimiser steps don't overwrite the recorded values
        weight, bias, *_ = net.parameters()
        yield ypred, weight.detach().flatten().numpy().copy(), bias.detach().numpy().copy()
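`networks.flat_net` comes from the swarm package and its internals aren't shown here. Assuming it builds one hidden layer followed by a linear read-out (the hypothetical `flat_net_sketch` below, written with plain PyTorch), this sketch shows why `weight, bias, *_ = net.parameters()` picks out the hidden layer, and checks that each unit's pre-activation is zero at its own \(-b/w\).

```python
import torch
from torch import nn

def flat_net_sketch(width: int) -> nn.Module:
    """Hypothetical stand-in for networks.flat_net(1, width, activation=nn.ReLU)."""
    return nn.Sequential(nn.Linear(1, width), nn.ReLU(), nn.Linear(width, 1))

torch.manual_seed(0)
net = flat_net_sketch(20)

# parameters() yields tensors in registration order, so the first two are the
# hidden layer's weight (shape (20, 1)) and bias (shape (20,)).
weight, bias, *_ = net.parameters()
w = weight.detach().flatten().numpy()
b = bias.detach().numpy()
kinks = -b / w  # x-position where each hidden unit switches on or off

# Evaluate the hidden layer's pre-activations at every kink: entry [i, i] is
# unit i evaluated at its own kink, which should be (numerically) zero.
x = torch.tensor(kinks, dtype=torch.float32).unsqueeze(1)  # shape (20, 1)
pre = net[0](x)                                            # shape (20, 20)
diag = torch.diagonal(pre).detach()
```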
def main():
    xt = torch.linspace(-3 * np.pi, 3 * np.pi, 101)
    yt = torch.sin(xt)
    bp = {"xt": xt, "yt": yt, "width": 20, "num_epochs": NUM_EPOCHS}
    res = core.swarm_train(bee_trainer, bp, num_bees=NUM_BEES, fields="ypred,weights,biases", seed=SEED)
    # -b/w is the x-position of each hidden unit's ReLU kink
    bw = -res["biases"] / res["weights"]
    # Near-zero weights give huge (or infinite) ratios, so clip them into the plotted window
    bw = bw.clip(-10, 10)
    ls = animator.LineSwarm.standard(
        xt.detach().numpy(), yt.detach().numpy(), res["ypred"][::10], set_xlim=(-10, 10)
    )
    hist = animator.HistogramSwarm.from_swarm(
        bw, 100, set_title="- Biases/Weights", set_ylabel="Count", set_xlim=(-10, 10)
    )
    animator.swarm_animate([ls, hist], "weight_distr.mp4")
main()
Video("weight_distr.mp4", embed=True)
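The clip in `main` matters because a unit whose weight is (close to) zero puts its kink arbitrarily far away. A small NumPy illustration with hand-picked toy values (not experiment output) shows what `clip(-10, 10)` does in those cases: large and infinite ratios are folded onto the edges of the plotting window, while the degenerate \(0/0\) case stays NaN.

```python
import numpy as np

# Toy hidden-layer parameters; the last three weights are deliberate edge cases.
weights = np.array([0.5, -2.0, 1.0 / 64, 0.0, 0.0])
biases = np.array([1.0, 1.0, 1.0, 0.25, 0.0])

# Same expression as in main(); silence the divide-by-zero warnings for the edge cases.
with np.errstate(divide="ignore", invalid="ignore"):
    bw = -biases / weights  # [-2.0, 0.5, -64.0, -inf, nan]

clipped = bw.clip(-10, 10)  # [-2.0, 0.5, -10.0, -10.0, nan]

# One possible (hypothetical) clean-up step: drop NaNs before histogramming.
finite_kinks = clipped[~np.isnan(clipped)]
```

A consequence worth keeping in mind when reading the animation: clipped units pile up in the outermost histogram bins at \(\pm 10\), so counts there may partly reflect the clip rather than a turning point.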
Weight Distributions
We can see that the \(-b/w\) values cluster around the places where the sine curve turns. As you'd expect with the weights and biases starting quite close to 0, most of the bends the network places are used on the turning points nearest the origin rather than on those at the extremities of the interval.
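One way to turn the visual clustering into numbers is to count how many kinks land near each turning point of \(\sin x\), which sit at \(x = \pi/2 + k\pi\). The sketch below lists the six turning points inside \([-3\pi, 3\pi]\) and uses a hypothetical helper with an assumed tolerance of 0.5; the kink positions are toy values for illustration, not results from the experiment.

```python
import numpy as np

# Turning points of sin(x) are at x = pi/2 + k*pi; keep those inside [-3*pi, 3*pi].
k = np.arange(-4, 4)
turning = np.pi / 2 + k * np.pi
turning = turning[np.abs(turning) <= 3 * np.pi]  # six points, from -2.5*pi to 2.5*pi

def kinks_near_turning_points(kinks, turning, tol=0.5):
    """Count how many kinks lie within `tol` of each turning point (hypothetical helper)."""
    dist = np.abs(np.asarray(kinks)[:, None] - turning[None, :])
    return (dist <= tol).sum(axis=0)

# Toy kink positions for illustration only (not taken from the experiment).
toy_kinks = np.array([-1.6, -1.5, 1.5, 1.6, 4.7, 0.0])
counts = kinks_near_turning_points(toy_kinks, turning)  # [0, 0, 2, 2, 1, 0]
```

Applied to the real `bw` from `main` (for instance the values at the final epoch, with any NaNs dropped), this would give per-turning-point counts; the exact indexing depends on the array layout `swarm_train` returns, which isn't shown here.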