import torch
import torch.nn as nn


class Bottle(nn.Module):
    """Mixin: fold a 3-D input to 2-D, apply the parent forward, unfold."""

    def forward(self, input):
        if len(input.size()) <= 2:
            return super(Bottle, self).forward(input)
        size = input.size()[:2]
        # Fold (d0, d1, ...) into (d0 * d1, ...), then restore the shape.
        out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
        return out.contiguous().view(size[0], size[1], -1)
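

# A shape sketch of the fold/unfold trick Bottle relies on (the sizes
# below are illustrative assumptions, not from the source):
def _example_fold_unfold():
    x = torch.randn(4, 10, 16)          # (d0, d1, rest)
    flat = x.view(4 * 10, -1)           # (40, 16): what the parent sees
    assert flat.contiguous().view(4, 10, -1).size() == x.size()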


class Bottle2(nn.Module):
    """Mixin: fold a 4-D input to 3-D, apply the parent forward, unfold."""

    def forward(self, input):
        if len(input.size()) <= 3:
            return super(Bottle2, self).forward(input)
        size = input.size()
        # Fold the first two dimensions together; keep the rest intact.
        out = super(Bottle2, self).forward(input.view(size[0] * size[1],
                                                      size[2], size[3]))
        return out.contiguous().view(size[0], size[1], size[2], size[3])


class LayerNorm(nn.Module):
    ''' Layer normalization module '''

    def __init__(self, d_hid, eps=1e-3):
        super(LayerNorm, self).__init__()
        self.eps = eps
        # Learnable per-feature gain and bias.
        self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
        self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)

    def forward(self, z):
        if z.size(1) == 1:
            # std of a single element is undefined (NaN), so pass through.
            return z
        mu = torch.mean(z, dim=1)
        sigma = torch.std(z, dim=1)
        # HACK. PyTorch is changing behavior: older versions kept the
        # reduced dimension, newer ones drop it, so restore it if needed.
        if mu.dim() == 1:
            mu = mu.unsqueeze(1)
            sigma = sigma.unsqueeze(1)
        ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
        ln_out = ln_out.mul(self.a_2.expand_as(ln_out)) \
            + self.b_2.expand_as(ln_out)
        return ln_out
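

# A minimal usage sketch for LayerNorm (illustrative sizes, assuming a
# modern PyTorch where plain tensors behave as Variables once did):
def _example_layer_norm():
    ln = LayerNorm(16)
    z = torch.randn(4, 16)              # (batch, d_hid)
    out = ln(z)                         # per-row mean ~0, std ~1
    assert out.size() == (4, 16)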


# Concrete modules that gain 3-D support through the Bottle mixin.
class BottleLinear(Bottle, nn.Linear):
    pass


class BottleLayerNorm(Bottle, LayerNorm):
    pass


class BottleSoftmax(Bottle, nn.Softmax):
    pass
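

# A minimal usage sketch for the Bottle mixin (illustrative sizes, not
# from the source): BottleLinear folds (batch, seq, in_features) to
# (batch * seq, in_features), applies the linear layer, and unfolds.
def _example_bottle_linear():
    lin = BottleLinear(16, 8)
    x = torch.randn(4, 10, 16)          # (batch, seq, in_features)
    assert lin(x).size() == (4, 10, 8)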


class Elementwise(nn.ModuleList):
    """
    A simple network container.
    Parameters are a list of modules.
    Inputs are a 3d Variable whose last dimension is the same length
    as the list.
    Outputs are the result of applying modules to inputs elementwise.
    An optional merge parameter allows the outputs to be reduced to a
    single Variable.
    """

    def __init__(self, merge=None, *args):
        assert merge in [None, 'first', 'concat', 'sum', 'mlp']
        self.merge = merge
        super(Elementwise, self).__init__(*args)

    def forward(self, input):
        # Split the last dimension into one 2-D slice per module.
        inputs = [feat.squeeze(2) for feat in input.split(1, dim=2)]
        assert len(self) == len(inputs)
        outputs = [f(x) for f, x in zip(self, inputs)]
        if self.merge == 'first':
            return outputs[0]
        elif self.merge == 'concat' or self.merge == 'mlp':
            # 'mlp' merging concatenates here as well.
            return torch.cat(outputs, 2)
        elif self.merge == 'sum':
            return sum(outputs)
        else:
            return outputs
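

# A minimal usage sketch for Elementwise (vocabulary sizes, embedding
# dimensions, and shapes are illustrative assumptions, not from the
# source): one embedding per feature column, concatenated along dim 2.
def _example_elementwise():
    emb = Elementwise('concat', [nn.Embedding(100, 8), nn.Embedding(20, 4)])
    words = torch.randint(0, 100, (5, 3))     # (seq, batch) word indices
    feats = torch.randint(0, 20, (5, 3))      # (seq, batch) feature indices
    idx = torch.stack([words, feats], dim=2)  # (seq, batch, 2)
    assert emb(idx).size() == (5, 3, 12)      # 8 + 4 concatenated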