import torch
from torch.autograd import Variable
import onmt
import onmt.Models
import onmt.ModelConstructor
import onmt.modules
import onmt.IO
from onmt.Utils import use_gpu
[docs]class Translator(object):
def __init__(self, opt, dummy_opt={}):
# Add in default model arguments, possibly added since training.
self.opt = opt
checkpoint = torch.load(opt.model,
map_location=lambda storage, loc: storage)
self.fields = onmt.IO.load_fields(checkpoint['vocab'],
data_type=opt.data_type)
model_opt = checkpoint['opt']
for arg in dummy_opt:
if arg not in model_opt:
model_opt.__dict__[arg] = dummy_opt[arg]
self._type = model_opt.encoder_type
self.copy_attn = model_opt.copy_attn
self.model = onmt.ModelConstructor.make_base_model(
model_opt, self.fields, use_gpu(opt), checkpoint)
self.model.eval()
self.model.generator.eval()
# Length + Coverage Penalty
self.alpha = opt.alpha
self.beta = opt.beta
# for debugging
self.beam_accum = None
[docs] def initBeamAccum(self):
self.beam_accum = {
"predicted_ids": [],
"beam_parent_ids": [],
"scores": [],
"log_probs": []}
[docs] def buildTargetTokens(self, pred, src, attn, copy_vocab):
vocab = self.fields["tgt"].vocab
tokens = []
for tok in pred:
if tok < len(vocab):
tokens.append(vocab.itos[tok])
else:
tokens.append(copy_vocab.itos[tok - len(vocab)])
if tokens[-1] == onmt.IO.EOS_WORD:
tokens = tokens[:-1]
break
if self.opt.replace_unk and (attn is not None) and (src is not None):
for i in range(len(tokens)):
if tokens[i] == vocab.itos[onmt.IO.UNK]:
_, maxIndex = attn[i].max(0)
tokens[i] = self.fields["src"].vocab.itos[src[maxIndex[0]]]
return tokens
def _runTarget(self, batch, data):
data_type = data.data_type
if data_type == 'text':
_, src_lengths = batch.src
else:
src_lengths = None
src = onmt.IO.make_features(batch, 'src', data_type)
tgt_in = onmt.IO.make_features(batch, 'tgt')[:-1]
# (1) run the encoder on the src
encStates, context = self.model.encoder(src, src_lengths)
decStates = self.model.decoder.init_decoder_state(
src, context, encStates)
# (2) if a target is specified, compute the 'goldScore'
# (i.e. log likelihood) of the target under the model
tt = torch.cuda if self.opt.cuda else torch
goldScores = tt.FloatTensor(batch.batch_size).fill_(0)
decOut, decStates, attn = self.model.decoder(
tgt_in, context, decStates, context_lengths=src_lengths)
tgt_pad = self.fields["tgt"].vocab.stoi[onmt.IO.PAD_WORD]
for dec, tgt in zip(decOut, batch.tgt[1:].data):
# Log prob of each word.
out = self.model.generator.forward(dec)
tgt = tgt.unsqueeze(1)
scores = out.data.gather(1, tgt)
scores.masked_fill_(tgt.eq(tgt_pad), 0)
goldScores += scores
return goldScores
[docs] def translateBatch(self, batch, data):
beam_size = self.opt.beam_size
batch_size = batch.batch_size
# (1) Run the encoder on the src.
data_type = data.data_type
src = onmt.IO.make_features(batch, 'src', data_type)
if data_type == 'text':
_, src_lengths = batch.src
else:
src_lengths = None
encStates, context = self.model.encoder(src, src_lengths)
decStates = self.model.decoder.init_decoder_state(
src, context, encStates)
if src_lengths is None:
src_lengths = torch.Tensor(batch_size).type_as(context.data)\
.long()\
.fill_(context.size(0))
# (1b) Initialize for the decoder.
def var(a): return Variable(a, volatile=True)
def rvar(a): return var(a.repeat(1, beam_size, 1))
# Repeat everything beam_size times.
context = rvar(context.data)
context_lengths = src_lengths.repeat(beam_size)
if data_type == 'text':
srcMap = rvar(batch.src_map.data)
else:
srcMap = None
decStates.repeat_beam_size_times(beam_size)
scorer = onmt.GNMTGlobalScorer(self.alpha, self.beta)
beam = [onmt.Beam(beam_size, n_best=self.opt.n_best,
cuda=self.opt.cuda,
vocab=self.fields["tgt"].vocab,
global_scorer=scorer)
for __ in range(batch_size)]
# (2) run the decoder to generate sentences, using beam search.
def bottle(m):
return m.view(batch_size * beam_size, -1)
def unbottle(m):
return m.view(beam_size, batch_size, -1)
for i in range(self.opt.max_sent_length):
if all((b.done() for b in beam)):
break
# Construct batch x beam_size nxt words.
# Get all the pending current beam words and arrange for forward.
inp = var(torch.stack([b.getCurrentState() for b in beam])
.t().contiguous().view(1, -1))
# Turn any copied words to UNKs
# 0 is unk
if self.copy_attn:
inp = inp.masked_fill(
inp.gt(len(self.fields["tgt"].vocab) - 1), 0)
# Temporary kludge solution to handle changed dim expectation
# in the decoder
inp = inp.unsqueeze(2)
# Run one step.
decOut, decStates, attn = self.model.decoder(
inp, context, decStates, context_lengths=context_lengths)
decOut = decOut.squeeze(0)
# decOut: beam x rnn_size
# (b) Compute a vector of batch*beam word scores.
if not self.copy_attn:
out = self.model.generator.forward(decOut).data
out = unbottle(out)
# beam x tgt_vocab
else:
out = self.model.generator.forward(decOut,
attn["copy"].squeeze(0),
srcMap)
# beam x (tgt_vocab + extra_vocab)
out = data.collapse_copy_scores(
unbottle(out.data),
batch, self.fields["tgt"].vocab)
# beam x tgt_vocab
out = out.log()
# (c) Advance each beam.
for j, b in enumerate(beam):
b.advance(
out[:, j],
unbottle(attn["std"]).data[:, j, :context_lengths[j]])
decStates.beam_update(j, b.getCurrentOrigin(), beam_size)
if "tgt" in batch.__dict__:
allGold = self._runTarget(batch, data)
else:
allGold = [0] * batch_size
# (3) Package everything up.
allHyps, allScores, allAttn = [], [], []
for b in beam:
n_best = self.opt.n_best
scores, ks = b.sortFinished(minimum=n_best)
hyps, attn = [], []
for i, (times, k) in enumerate(ks[:n_best]):
hyp, att = b.getHyp(times, k)
hyps.append(hyp)
attn.append(att)
allHyps.append(hyps)
allScores.append(scores)
allAttn.append(attn)
return allHyps, allScores, allAttn, allGold
[docs] def translate(self, batch, data):
# (1) convert words to indexes
batch_size = batch.batch_size
# (2) translate
pred, predScore, attn, goldScore = self.translateBatch(batch, data)
assert(len(goldScore) == len(pred))
pred, predScore, attn, goldScore, indices = list(zip(
*sorted(zip(pred, predScore, attn, goldScore,
batch.indices.data),
key=lambda x: x[-1])))
inds, perm = torch.sort(batch.indices.data)
# (3) convert indexes to words
predBatch, goldBatch = [], []
data_type = data.data_type
if data_type == 'text':
src = batch.src[0].data.index_select(1, perm)
else:
src = None
if self.opt.tgt:
tgt = batch.tgt.data.index_select(1, perm)
for b in range(batch_size):
if data_type == 'text':
src_vocab = data.src_vocabs[inds[b]]
else:
src_vocab = None
predBatch.append(
[self.buildTargetTokens(pred[b][n], src[:, b]
if src is not None else None,
attn[b][n], src_vocab)
for n in range(self.opt.n_best)])
if self.opt.tgt:
goldBatch.append(
self.buildTargetTokens(tgt[1:, b], src[:, b]
if src is not None else None,
None, None))
return predBatch, goldBatch, predScore, goldScore, attn, src, indices