Training a Network to maximise its curvature
WIP
import numpy as np
import torch
from torch import nn, optim
from matplotlib import pyplot as plt
from swarm import networks, core, animator, activations
import env
# Large square canvas for every figure produced below.
plt.rcParams["figure.figsize"] = (12.0, 12.0)
# Seed shared by the whole swarm run for reproducibility.
SEED = 20
# env.FULL switches between a quick smoke-test configuration and the
# full-length experiment (400 epochs x 500 bees).
if not env.FULL:
    NUM_EPOCHS = 4
    NUM_BEES = 5
else:
    NUM_EPOCHS = 400
    NUM_BEES = 500
def autocorr(x: torch.Tensor, n=10):
    """Simplified autocorrelation measure of a 1-D tensor.

    Normalises ``x`` to unit norm, then accumulates the lagged inner
    products ``x[lag:] . x[:len(x)-lag]`` for lags ``0..n-1`` (lag 0
    contributes the squared norm, i.e. 1 after normalisation).

    Args:
        x: 1-D tensor whose autocorrelation is measured.
        n: number of lags (including lag 0) to accumulate.

    Returns:
        Scalar tensor with the summed lagged products (plain 0 if n == 0).
    """
    # Normalise so that minimising autocorrelation cannot be gamed by
    # shrinking the signal towards zero; dividing by the max instead
    # runs into floating point trouble.
    x = x / x.norm()
    length = len(x)
    return sum((x[lag:] * x[: length - lag]).sum() for lag in range(n))
def diff(x: torch.Tensor) -> torch.Tensor:
    """First forward difference along the leading dimension."""
    shifted, base = x[1:], x[:-1]
    return shifted - base
def second_deriv(ypred: torch.Tensor, mse_weight=0.01):
    """Curvature-rewarding loss: ``mse_weight * ||ypred|| - ||d2(ypred)||``.

    Minimising this value drives the discrete second difference (curvature)
    of the prediction up, while a small norm penalty stops the output from
    simply blowing up in magnitude.

    Args:
        ypred: 1-D tensor of predictions over the input grid.
        mse_weight: weight of the magnitude-penalty term.

    Returns:
        Scalar tensor loss (lower means curvier).
    """
    # Discrete first and second forward differences.
    first = ypred[1:] - ypred[:-1]
    second = first[1:] - first[:-1]
    # NOTE(review): despite the historical name "mse", this is the plain
    # norm of ypred -- it just discourages large outputs.
    magnitude = ypred.norm()
    return mse_weight * magnitude - second.norm()
class Sin(torch.nn.Module):
    """Elementwise sine, wrapped as an nn.Module so it can be used as an activation."""

    def forward(self, x: torch.Tensor):
        return x.sin()
def solo_train(
    x, hidden=2, width=2, activation=nn.ReLU, num_epochs=10, lr=0.001, momentum=0.9, corr_len=10
):
    """Train one network on ``x`` with the second-derivative loss, yielding per-epoch state.

    Generator suitable for swarm.core.swarm_train: each epoch it yields
    ``(ypred, loss)`` *before* the backward/step, so the yielded prediction
    corresponds to the weights at the start of that epoch.

    Args:
        x: input tensor fed to the network each epoch.
        hidden: number of hidden layers passed to networks.flat_net.
        width: width of each hidden layer.
        activation: activation module class for the network.
        num_epochs: number of optimisation steps.
        lr: SGD learning rate.
        momentum: SGD momentum.
        corr_len: NOTE(review) accepted but never used in this function --
            presumably a leftover from an autocorrelation-based loss.

    Yields:
        (ypred, loss) tuples, one per epoch.

    Raises:
        RuntimeError: if the loss becomes NaN (badly configured experiment).
    """
    net = networks.flat_net(hidden_depth=hidden, width=width, activation=activation)
    optimiser = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    # Initial value is overwritten on the first loop iteration; the lambda
    # closes over the variable, so it always reads the current epoch's weight.
    mse_weight = 1
    loss_func = lambda ypred: second_deriv(ypred, mse_weight=mse_weight)
    for epoch in range(num_epochs):
        # Magnitude-penalty weight: capped at 0.1, decaying as 1/(epoch+1)
        # once epoch >= 9, so the norm penalty fades over training.
        mse_weight = min(1 / (epoch + 1), 0.1)
        optimiser.zero_grad()
        ypred = net(x)
        loss = loss_func(ypred)
        if torch.isnan(loss):
            raise RuntimeError("NaN loss, poorly configured experiment")
        yield ypred, loss
        loss.backward()
        optimiser.step()
def main():
    """Run the swarm curvature experiment and render the result animation.

    Trains NUM_BEES independent networks via solo_train on a fixed grid,
    prints the per-bee loss table, and writes the prediction animation
    to ``sd.mp4``.
    """
    # Fixed 1-D input grid shared by every bee.
    x = torch.linspace(-10, 10, 100)
    beeparams = {
        "x": x,
        "num_epochs": NUM_EPOCHS,
        "lr": 0.005,
        "momentum": 0.5,
        "width": 50,
        "hidden": 3,
        "activation": activations.Tanh,
    }
    results = core.swarm_train(solo_train, beeparams, num_bees=NUM_BEES, seed=SEED, fields="ypred,loss")
    print(results["loss"])
    # There is no ground-truth target here; yd only pins the animation's
    # y-axis by marking two endpoint values.
    yd = np.zeros(len(x))
    yd[0] = -0.5
    yd[-1] = 0.5
    animator.make_animation(
        x.detach().numpy(), yd=yd, data=results["ypred"], title="secondderiv", destfile="sd.mp4"
    )