Source code for onmt.Beam

from __future__ import division
import torch
import onmt

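"""
 Classes for managing the internals of the beam search process.

 `Beam` takes care of beams, back pointers, and scores;
 `GNMTGlobalScorer` provides optional GNMT-style rescoring
 (length normalization and coverage penalty).
"""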


class Beam(object):
    """
    Beam search state for one source sentence.

    For every decoding step it records the chosen tokens (`nextYs`), the
    beam index each of them extends (`prevKs`) and the attention rows
    (`attn`). It also keeps the running cumulative scores (`scores`) and
    the hypotheses that have emitted EOS (`finished`, stored as
    (score, timestep, beam index) triples).
    """

    def __init__(self, size, n_best=1, cuda=False, vocab=None,
                 global_scorer=None):
        """
        Args:
            size: beam width, i.e. number of hypotheses kept at every step.
            n_best: number of finished hypotheses `done()` waits for.
            cuda: allocate state tensors through `torch.cuda`.
            vocab: vocabulary with a `stoi` mapping. Required in practice
                (PAD/BOS/EOS ids are looked up on it) despite the `None`
                default.
            global_scorer: optional rescorer exposing `score(beam, scores)`
                and `updateGlobalState(beam)`, e.g. `GNMTGlobalScorer`.
        """
        self.size = size
        self.tt = torch.cuda if cuda else torch

        # The score for each translation on the beam.
        self.scores = self.tt.FloatTensor(size).zero_()
        self.allScores = []

        # The backpointers at each time-step.
        self.prevKs = []

        # The outputs at each time-step. Slot 0 starts as BOS; every other
        # slot is PAD and is not a real hypothesis yet.
        self.nextYs = [self.tt.LongTensor(size)
                       .fill_(vocab.stoi[onmt.IO.PAD_WORD])]
        self.nextYs[0][0] = vocab.stoi[onmt.IO.BOS_WORD]
        self.vocab = vocab

        # Has EOS topped the beam yet.
        self._eos = self.vocab.stoi[onmt.IO.EOS_WORD]
        self.eosTop = False

        # The attentions (matrix) for each time.
        self.attn = []

        # (score, timestep, beam index) triples for finished hypotheses.
        self.finished = []
        self.n_best = n_best

        # Information for global scoring.
        self.globalScorer = global_scorer
        self.globalState = {}

    def getCurrentState(self):
        "Get the outputs for the current timestep."
        return self.nextYs[-1]

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def _finalScores(self):
        "Return the current beam scores, rescored by the global scorer if any."
        if self.globalScorer is not None:
            return self.globalScorer.score(self, self.scores)
        return self.scores

    def advance(self, wordLk, attnOut):
        """
        Given prob over words for every last beam `wordLk` and attention
        `attnOut`: compute and update the beam search.

        Parameters:

        * `wordLk`- scores of advancing from the last step (K x words);
          presumably log-probabilities, since they are added to the
          running scores.
        * `attnOut`- attention at the last step, one row per beam entry.

        Returns: True if beam search is complete (same as `done()`).
        """
        numWords = wordLk.size(1)

        if len(self.prevKs) > 0:
            # Sum the previous scores.
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)

            # Don't let EOS have children.
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] == self._eos:
                    beamLk[i] = -1e20
        else:
            # First step: only slot 0 holds a real hypothesis (BOS), so
            # expand row 0 only.
            beamLk = wordLk[0]

        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)

        self.allScores.append(self.scores)
        self.scores = bestScores

        # bestScoresId indexes the flattened (beam x word) array; recover
        # the source beam and the word. Floor division keeps prevK an
        # integer tensor: true division `/` yields floats on recent
        # PyTorch, which index_select rejects.
        prevK = bestScoresId // numWords
        self.prevKs.append(prevK)
        self.nextYs.append(bestScoresId - prevK * numWords)
        self.attn.append(attnOut.index_select(0, prevK))

        if self.globalScorer is not None:
            self.globalScorer.updateGlobalState(self)

        # Record every hypothesis that just emitted EOS. Final scores are
        # computed at most once per step, and only if some EOS appeared.
        finalScores = None
        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                if finalScores is None:
                    finalScores = self._finalScores()
                self.finished.append(
                    (finalScores[i], len(self.nextYs) - 1, i))

        # End condition is when top-of-beam is EOS.
        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

        return self.done()

    def done(self):
        "True once EOS has topped the beam and n_best hypotheses finished."
        return self.eosTop and len(self.finished) >= self.n_best

    def sortFinished(self, minimum=None):
        """
        Sort finished hypotheses by score, best first.

        Args:
            minimum: if given, pad `finished` with the current beam entries
                (beam slots 0, 1, 2, ...) until it holds at least `minimum`
                items or the beam is exhausted.

        Returns:
            (scores, ks): scores in descending order and the matching
            (timestep, beam index) pairs to pass to `getHyp`.
        """
        if minimum is not None and len(self.finished) < minimum:
            finalScores = self._finalScores()
            timestep = len(self.nextYs) - 1
            i = 0
            # Add from beam until we have minimum outputs; each slot is
            # used once, and we never index past the beam width.
            while len(self.finished) < minimum and i < self.size:
                self.finished.append((finalScores[i], timestep, i))
                i += 1

        self.finished.sort(key=lambda a: -a[0])
        scores = [sc for sc, _, _ in self.finished]
        ks = [(t, k) for _, t, k in self.finished]
        return scores, ks

    def getHyp(self, timestep, k):
        """
        Walk back to construct the full hypothesis.

        Args:
            timestep: step at which the hypothesis ends (the `t` of a
                (t, k) pair returned by `sortFinished`).
            k: beam index of the hypothesis at `timestep`.

        Returns:
            (hyp, attn): the token list in forward order, and the attention
            rows stacked along a new first dimension.
        """
        hyp, attn = [], []
        for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
            hyp.append(self.nextYs[j + 1][k])
            attn.append(self.attn[j][k])
            k = self.prevKs[j][k]
        return hyp[::-1], torch.stack(attn[::-1])
class GNMTGlobalScorer(object):
    """
    Google NMT ranking score from Wu et al.

    Rescores beam hypotheses with a length-normalization term and a
    coverage penalty derived from the accumulated attention.
    """

    def __init__(self, alpha, beta):
        # alpha: exponent of the length-normalization term.
        # beta: weight of the coverage penalty.
        self.alpha, self.beta = alpha, beta

    def score(self, beam, logprobs):
        """
        Return `logprobs` divided by the length term plus the coverage penalty.

        The length term is ((5 + len(beam.nextYs)) / 6) ** alpha. The
        penalty is beta * sum(log(min(coverage, 1))) over dimension 1 of
        `beam.globalState["coverage"]`.
        """
        coverage = beam.globalState["coverage"]
        capped = coverage.clamp(max=1.0)
        penalty = self.beta * capped.log().sum(1)

        numerator = (5 + len(beam.nextYs)) ** self.alpha
        denominator = (5 + 1) ** self.alpha
        lengthTerm = numerator / denominator

        return (logprobs / lengthTerm) + penalty

    def updateGlobalState(self, beam):
        """
        Accumulate the newest attention into the coverage state.

        After the first step the coverage is just that step's attention.
        Later, the previous coverage rows are reordered by the latest
        back-pointers before the new attention is added.
        """
        latest = beam.attn[-1]
        if len(beam.prevKs) == 1:
            coverage = latest
        else:
            reordered = beam.globalState["coverage"].index_select(
                0, beam.prevKs[-1])
            coverage = reordered.add(latest)
        beam.globalState["coverage"] = coverage