>>> chainer_linear = chainer.links.Linear(10, 10)
>>> torch_linear = torch.nn.Linear(10, 10)
>>> input = numpy.arange(100, dtype=numpy.float32).reshape(10, 2, 5)
>>> chainer_linear(input)
variable([[-1.27548418e+01, -6.84057713e-01, 4.33863544e+00,
6.55931234e+00, -3.39445019e+00, -2.51949596e+00,
1.03438854e+00, 7.69187212e-02, 1.93967378e+00,
2.17406702e+00],
[-3.20976524e+01, 2.56630421e+00, 2.20505409e+01,
1.03132820e+01, -9.60312557e+00, -1.00995445e+01,
1.22814524e+00, -2.46707058e+00, 1.31490164e+01,
-4.19132054e-01],
[-5.14404640e+01, 5.81666756e+00, 3.97624512e+01,
1.40672522e+01, -1.58118019e+01, -1.76795921e+01,
1.42190135e+00, -5.01105976e+00, 2.43583584e+01,
-3.01233125e+00],
[-7.07832718e+01, 9.06703091e+00, 5.74743538e+01,
1.78212242e+01, -2.20204792e+01, -2.52596416e+01,
1.61565781e+00, -7.55504990e+00, 3.55676994e+01,
-5.60552549e+00],
[-9.01260834e+01, 1.23173914e+01, 7.51862717e+01,
2.15751915e+01, -2.82291527e+01, -3.28396912e+01,
1.80941379e+00, -1.00990391e+01, 4.67770386e+01,
-8.19872189e+00],
[-1.09468895e+02, 1.55677538e+01, 9.28981705e+01,
2.53291626e+01, -3.44378281e+01, -4.04197388e+01,
2.00317073e+00, -1.26430283e+01, 5.79863853e+01,
-1.07919216e+01],
[-1.28811707e+02, 1.88181152e+01, 1.10610077e+02,
2.90831299e+01, -4.06464958e+01, -4.79997826e+01,
2.19692683e+00, -1.51870193e+01, 6.91957169e+01,
-1.33851223e+01],
[-1.48154510e+02, 2.20684814e+01, 1.28321976e+02,
3.28371048e+01, -4.68551788e+01, -5.55798340e+01,
2.39068103e+00, -1.77310085e+01, 8.04050751e+01,
-1.59783182e+01],
[-1.67497330e+02, 2.53188400e+01, 1.46033875e+02,
3.65910721e+01, -5.30638542e+01, -6.31598854e+01,
2.58443928e+00, -2.02749958e+01, 9.16144028e+01,
-1.85715141e+01],
[-1.86840134e+02, 2.85692101e+01, 1.63745789e+02,
4.03450470e+01, -5.92725296e+01, -7.07399368e+01,
2.77819300e+00, -2.28189812e+01, 1.02823753e+02,
-2.11647110e+01]])
>>> torch_linear(torch.Tensor(input))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/tianqi/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "/home/tianqi/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 87, in forward
return F.linear(input, self.weight, self.bias)
File "/home/tianqi/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/nn/functional.py", line 1593, in linear
output = input.matmul(weight.t())
RuntimeError: size mismatch, m1: [20 x 5], m2: [10 x 10] at /home/tianqi/repository/pytorch/aten/src/TH/generic/THTensorMath.cpp:41
>>>
class Flatten(torch.nn.Module):
    """Collapse trailing dimensions, mimicking Chainer's ``n_batch_axes``.

    Keeps the first ``n_batch_axes`` dimensions intact and flattens the
    rest into one, e.g. a ``(10, 2, 5)`` input with the default
    ``n_batch_axes=1`` becomes ``(10, 10)`` — the shape that
    ``torch.nn.Linear`` expects (see the traceback in the transcript above,
    where the 3-D input fails with a size mismatch).

    Args:
        n_batch_axes (int): Number of leading dimensions to preserve.
            Defaults to 1 (a single batch axis).
    """

    def __init__(self, n_batch_axes=1):
        # Bug fix: torch.nn.Module requires super().__init__() before the
        # module is usable — Module.__call__ reads hook/buffer dicts that
        # are only created there, so omitting it raises AttributeError.
        super().__init__()
        self.n_batch_axes = n_batch_axes

    def forward(self, x):
        """Flatten dimensions ``[n_batch_axes:]`` of ``x`` into one."""
        return torch.flatten(x, start_dim=self.n_batch_axes)