from __future__ import division
import torch
import onmt
"""
Class for managing the internals of the beam search process.
Takes care of beams, back pointers, and scores.
"""
class Beam(object):
def __init__(self, size, n_best=1, cuda=False, vocab=None,
global_scorer=None):
self.size = size
self.tt = torch.cuda if cuda else torch
# The score for each translation on the beam.
self.scores = self.tt.FloatTensor(size).zero_()
self.allScores = []
# The backpointers at each time-step.
self.prevKs = []
# The outputs at each time-step.
self.nextYs = [self.tt.LongTensor(size)
.fill_(vocab.stoi[onmt.IO.PAD_WORD])]
self.nextYs[0][0] = vocab.stoi[onmt.IO.BOS_WORD]
self.vocab = vocab
        # Has EOS topped the beam yet?
self._eos = self.vocab.stoi[onmt.IO.EOS_WORD]
self.eosTop = False
# The attentions (matrix) for each time.
self.attn = []
# Time and k pair for finished.
self.finished = []
self.n_best = n_best
# Information for global scoring.
self.globalScorer = global_scorer
self.globalState = {}
    def getCurrentState(self):
"Get the outputs for the current timestep."
return self.nextYs[-1]
    def getCurrentOrigin(self):
"Get the backpointers for the current timestep."
return self.prevKs[-1]
    def advance(self, wordLk, attnOut):
"""
Given prob over words for every last beam `wordLk` and attention
`attnOut`: Compute and update the beam search.
Parameters:
* `wordLk`- probs of advancing from the last step (K x words)
* `attnOut`- attention at the last step
Returns: True if beam search is complete.
"""
numWords = wordLk.size(1)
# Sum the previous scores.
if len(self.prevKs) > 0:
beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
# Don't let EOS have children.
for i in range(self.nextYs[-1].size(0)):
if self.nextYs[-1][i] == self._eos:
beamLk[i] = -1e20
else:
beamLk = wordLk[0]
flatBeamLk = beamLk.view(-1)
bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
self.allScores.append(self.scores)
self.scores = bestScores
# bestScoresId is flattened beam x word array, so calculate which
# word and beam each score came from
        prevK = bestScoresId // numWords
self.prevKs.append(prevK)
self.nextYs.append((bestScoresId - prevK * numWords))
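        # Example: with size=4 and numWords=10, flat index 23 maps to
        # beam 23 // 10 = 2 and word 23 - 2 * 10 = 3.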
self.attn.append(attnOut.index_select(0, prevK))
if self.globalScorer is not None:
self.globalScorer.updateGlobalState(self)
for i in range(self.nextYs[-1].size(0)):
if self.nextYs[-1][i] == self._eos:
s = self.scores[i]
if self.globalScorer is not None:
globalScores = self.globalScorer.score(self, self.scores)
s = globalScores[i]
self.finished.append((s, len(self.nextYs) - 1, i))
        # End condition: the top of the beam is EOS. done() additionally
        # requires n_best finished hypotheses.
        if self.nextYs[-1][0] == self._eos:
# self.allScores.append(self.scores)
self.eosTop = True
    def done(self):
return self.eosTop and len(self.finished) >= self.n_best
    def sortFinished(self, minimum=None):
if minimum is not None:
i = 0
# Add from beam until we have minimum outputs.
while len(self.finished) < minimum:
s = self.scores[i]
if self.globalScorer is not None:
globalScores = self.globalScorer.score(self, self.scores)
s = globalScores[i]
                self.finished.append((s, len(self.nextYs) - 1, i))
                i += 1
self.finished.sort(key=lambda a: -a[0])
scores = [sc for sc, _, _ in self.finished]
ks = [(t, k) for _, t, k in self.finished]
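        # scores[i] pairs with ks[i] == (timestep, k); each pair can be
        # passed straight to getHyp().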
return scores, ks
    def getHyp(self, timestep, k):
"""
Walk back to construct the full hypothesis.
"""
hyp, attn = [], []
for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
hyp.append(self.nextYs[j+1][k])
attn.append(self.attn[j][k])
k = self.prevKs[j][k]
return hyp[::-1], torch.stack(attn[::-1])
class GNMTGlobalScorer(object):
"""
Google NMT ranking score from Wu et al.
"""
def __init__(self, alpha, beta):
self.alpha = alpha
self.beta = beta
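    # Illustrative construction (the alpha/beta values below are assumptions,
    # not defaults from the paper; tune per dataset):
    #
    #     scorer = GNMTGlobalScorer(alpha=0.6, beta=0.2)
    #     beam = Beam(5, vocab=vocab, global_scorer=scorer)
    #
    # score() then applies the Wu et al. length penalty,
    # lp(Y) = (5 + |Y|)^alpha / (5 + 1)^alpha, plus the coverage penalty
    # cp = beta * sum_i log(min(coverage_i, 1.0)).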
    def score(self, beam, logprobs):
"Additional term add to log probability"
cov = beam.globalState["coverage"]
pen = self.beta * torch.min(cov, cov.clone().fill_(1.0)).log().sum(1)
l_term = (((5 + len(beam.nextYs)) ** self.alpha) /
((5 + 1) ** self.alpha))
return (logprobs / l_term) + pen
    def updateGlobalState(self, beam):
"Keeps the coverage vector as sum of attens"
if len(beam.prevKs) == 1:
beam.globalState["coverage"] = beam.attn[-1]
else:
beam.globalState["coverage"] = beam.globalState["coverage"] \
.index_select(0, beam.prevKs[-1]).add(beam.attn[-1])