# I followed the instructions in the paper but was unable to reproduce its results.
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from matplotlib import pyplot as plt
class Net(nn.Module):
    """Fully connected ReLU classifier on flattened images.

    Architecture: ``in_features -> hidden -> hidden -> num_classes`` with a ReLU
    after each hidden layer. The defaults (784 -> 200 -> 200 -> 10) match the
    original hard-coded sizes for flattened 28x28 MNIST images.

    Args:
        in_features: size of each flattened input vector.
        hidden: width of both hidden layers.
        num_classes: number of output scores.
    """

    def __init__(self, in_features=28 * 28, hidden=200, num_classes=10):
        super(Net, self).__init__()
        # Parameter-bearing layers stay at indices 0, 2 and 4, so state_dict
        # keys (fc.0.*, fc.2.*, fc.4.*) are unchanged from the original model.
        self.fc = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            # No activation after the output layer. A trailing ReLU clamps
            # every score to >= 0 and gives zero gradient to any output whose
            # pre-activation is negative. With MSE against one-hot targets
            # (and a large weight-init scale), the correct-class score can then
            # stay stuck at 0 and never be learned.
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        """Map inputs of shape (..., in_features) to scores of shape (..., num_classes)."""
        return self.fc(x)
# --- Data and run configuration ---------------------------------------------
batch_size = 128
num_epochs = 1000       # each epoch is one pass over the small training subset below
num_train_batches = 8   # training subset = first 8 shuffled batches (8 * 128 = 1024 images)

# Use the GPU when one is present; fall back to the CPU instead of crashing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so the shuffled training subset (and the network
# initialisation that follows) is the same on every run.
torch.manual_seed(0)

# download=True fetches MNIST into ./dataset only when it is not already there;
# download=False makes a fresh checkout fail instead.
train_set = torchvision.datasets.MNIST('dataset', train=True, download=True, transform=transforms.ToTensor())
test_set = torchvision.datasets.MNIST('dataset', train=False, download=True, transform=transforms.ToTensor())

# Pinned host memory is only useful for host-to-GPU copies.
pin_memory = device.type == 'cuda'
dataloader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size,
    shuffle=True,
    pin_memory=pin_memory)
test_dataloader = torch.utils.data.DataLoader(
    test_set, batch_size=batch_size,
    shuffle=False,
    pin_memory=pin_memory)


def _to_device(batch):
    """Move an (images, labels) batch to `device`, flattening images to rows of 28 * 28."""
    images, labels = batch
    return (images.to(device, non_blocking=True).view(-1, 28 * 28),
            labels.to(device, non_blocking=True))


# Materialise the fixed training subset and the whole test set on `device`
# once, so every epoch reuses the same tensors instead of re-reading the loaders.
# zip stops after num_train_batches without fetching an extra batch.
train_inputs = [batch for _, batch in zip(range(num_train_batches), dataloader)]
inputs = [_to_device(batch) for batch in train_inputs]
print('Number of train_set: ', sum(y.numel() for _, y in inputs))
test_inputs = [_to_device(batch) for batch in test_dataloader]
# Build the model on `device`, then multiply every weight matrix (but not the
# 1-D bias vectors, whose shapes are printed instead) by a constant 9.0.
net = Net().to(device)
with torch.no_grad():  # in-place edits of leaf parameters must not be recorded by autograd
    for weight_or_bias in net.parameters():
        if weight_or_bias.dim() != 1:
            weight_or_bias.mul_(9.0)
            continue
        print(weight_or_bias.shape)
# AdamW with its default learning rate and decoupled weight decay 5e-3.
optimizer = optim.AdamW(net.parameters(), weight_decay=5e-3)
# Per-epoch train / test accuracy in percent, collected for plotting.
accs, test_accs = [], []
for epoch in range(num_epochs):
    # Evaluate on the full test set before this epoch's updates, so
    # test_accs[0] is the accuracy of the freshly initialised network.
    net.eval()
    test_correct = 0
    with torch.no_grad():
        for x, y in test_inputs:
            test_correct += (net(x).argmax(dim=1) == y).sum().item()

    # One pass over the fixed training subset. Train accuracy for each batch
    # is measured on that batch's forward pass, just before its optimizer step.
    net.train()
    correct = 0
    seen = 0          # samples actually processed (does not assume full batches)
    epoch_loss = 0.0  # sample-weighted sum of batch losses
    for x, y in inputs:
        optimizer.zero_grad()
        output = net(x)
        # MSE against one-hot targets (not cross-entropy).
        loss = F.mse_loss(output, F.one_hot(y, 10).float())
        correct += (output.detach().argmax(dim=1) == y).sum().item()
        seen += y.numel()
        epoch_loss += loss.item() * y.numel()
        loss.backward()
        optimizer.step()

    acc = 100. * correct / seen
    test_acc = 100. * test_correct / len(test_set)
    accs.append(acc)
    test_accs.append(test_acc)
    # Report the mean training loss over the epoch, not just the last batch's.
    print(epoch, epoch_loss / seen, acc, test_acc)
# Plot the per-epoch accuracy curves with labels so train and test can be told apart.
plt.plot(accs, label='train')
plt.plot(test_accs, label='test')
plt.xlabel('epoch')
plt.ylabel('accuracy (%)')
plt.legend()
plt.show()