import torch
from onmt.translate.decode_strategy import DecodeStrategy
class BeamSearch(DecodeStrategy):
"""Generation beam search.
Note that the attributes list is not exhaustive. Rather, it highlights
tensors to document their shape. (Since the state variables' "batch"
size decreases as beams finish, we denote this axis with a B rather than
``batch_size``).
Args:
beam_size (int): Number of beams to use (see base ``parallel_paths``).
batch_size (int): See base.
pad (int): See base.
bos (int): See base.
eos (int): See base.
n_best (int): Don't stop until at least this many beams have
reached EOS.
mb_device (torch.device or str): See base ``device``.
global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance.
min_length (int): See base.
max_length (int): See base.
return_attention (bool): See base.
block_ngram_repeat (int): See base.
exclusion_tokens (set[int]): See base.
memory_lengths (LongTensor): Lengths of encodings. Used for
masking attentions.
Attributes:
top_beam_finished (ByteTensor): Shape ``(B,)``.
_batch_offset (LongTensor): Shape ``(B,)``.
        _beam_offset (LongTensor): Shape ``(batch_size,)``.
alive_seq (LongTensor): See base.
topk_log_probs (FloatTensor): Shape ``(B x beam_size,)``. These
are the scores used for the topk operation.
select_indices (LongTensor or NoneType): Shape
``(B x beam_size,)``. This is just a flat view of the
``_batch_index``.
topk_scores (FloatTensor): Shape
``(B, beam_size)``. These are the
scores a sequence will receive if it finishes.
topk_ids (LongTensor): Shape ``(B, beam_size)``. These are the
word indices of the topk predictions.
_batch_index (LongTensor): Shape ``(B, beam_size)``.
_prev_penalty (FloatTensor or NoneType): Shape
``(B, beam_size)``. Initialized to ``None``.
_coverage (FloatTensor or NoneType): Shape
``(1, B x beam_size, inp_seq_len)``.
        hypotheses (list[list[Tuple[Tensor]]]): One sub-list per sentence in
            the batch; each entry is a tuple of score (float), sequence
            (long), and attention (float or None).
"""
def __init__(self, beam_size, batch_size, pad, bos, eos, n_best, mb_device,
global_scorer, min_length, max_length, return_attention,
block_ngram_repeat, exclusion_tokens, memory_lengths,
stepwise_penalty, ratio):
super(BeamSearch, self).__init__(
pad, bos, eos, batch_size, mb_device, beam_size, min_length,
block_ngram_repeat, exclusion_tokens, return_attention,
max_length)
# beam parameters
self.global_scorer = global_scorer
self.beam_size = beam_size
self.n_best = n_best
self.batch_size = batch_size
self.ratio = ratio
# result caching
self.hypotheses = [[] for _ in range(batch_size)]
# beam state
self.top_beam_finished = torch.zeros([batch_size], dtype=torch.uint8)
self.best_scores = torch.full([batch_size], -1e10, dtype=torch.float,
device=mb_device)
self._batch_offset = torch.arange(batch_size, dtype=torch.long)
self._beam_offset = torch.arange(
0, batch_size * beam_size, step=beam_size, dtype=torch.long,
device=mb_device)
self.topk_log_probs = torch.tensor(
[0.0] + [float("-inf")] * (beam_size - 1), device=mb_device
).repeat(batch_size)
self.select_indices = None
self._memory_lengths = memory_lengths
# buffers for the topk scores and 'backpointer'
self.topk_scores = torch.empty((batch_size, beam_size),
dtype=torch.float, device=mb_device)
self.topk_ids = torch.empty((batch_size, beam_size), dtype=torch.long,
device=mb_device)
self._batch_index = torch.empty([batch_size, beam_size],
dtype=torch.long, device=mb_device)
self.done = False
# "global state" of the old beam
self._prev_penalty = None
self._coverage = None
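        # Two ways of applying the coverage penalty: "stepwise" folds it into
        # the running beam scores at every step (see advance), while the
        # "vanilla" variant only subtracts it from the score a hypothesis
        # would receive if it ended at the current step.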
self._stepwise_cov_pen = (
stepwise_penalty and self.global_scorer.has_cov_pen)
self._vanilla_cov_pen = (
not stepwise_penalty and self.global_scorer.has_cov_pen)
self._cov_pen = self.global_scorer.has_cov_pen
@property
def current_predictions(self):
return self.alive_seq[:, -1]
@property
def current_origin(self):
return self.select_indices
@property
def current_backptr(self):
# for testing
return self.select_indices.view(self.batch_size, self.beam_size)\
.fmod(self.beam_size)
    def advance(self, log_probs, attn):
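        """Advance the beam one step and update the beam state.

        Args:
            log_probs (FloatTensor): Log-probabilities over the vocabulary
                for every live beam, shape ``(B x beam_size, vocab_size)``.
            attn (FloatTensor): Attention over the source for the current
                step, shape ``(1, B x beam_size, inp_seq_len)``. Only used
                when attention is returned or a coverage penalty is set.
        """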
vocab_size = log_probs.size(-1)
# using integer division to get an integer _B without casting
_B = log_probs.shape[0] // self.beam_size
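        # _B is the number of sentences still being decoded; every flat
        # state tensor has _B * beam_size rows.
        # With a stepwise coverage penalty, add back the penalty subtracted
        # at the previous step and re-apply it for the coverage extended by
        # the newest attention, keeping the running scores consistent.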
if self._stepwise_cov_pen and self._prev_penalty is not None:
self.topk_log_probs += self._prev_penalty
self.topk_log_probs -= self.global_scorer.cov_penalty(
self._coverage + attn, self.global_scorer.beta).view(
_B, self.beam_size)
# force the output to be longer than self.min_length
step = len(self)
self.ensure_min_length(log_probs)
        # Multiply probs by the beam probability (addition in log space).
log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)
self.block_ngram_repeats(log_probs)
# if the sequence ends now, then the penalty is the current
# length + 1, to include the EOS token
length_penalty = self.global_scorer.length_penalty(
step + 1, alpha=self.global_scorer.alpha)
# Flatten probs into a list of possibilities.
curr_scores = log_probs / length_penalty
curr_scores = curr_scores.reshape(_B, self.beam_size * vocab_size)
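        # topk over beam_size * vocab_size candidates chooses, per sentence,
        # the beam_size best (previous beam, next word) continuations.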
torch.topk(curr_scores, self.beam_size, dim=-1,
out=(self.topk_scores, self.topk_ids))
# Recover log probs.
# Length penalty is just a scalar. It doesn't matter if it's applied
# before or after the topk.
torch.mul(self.topk_scores, length_penalty, out=self.topk_log_probs)
# Resolve beam origin and map to batch index flat representation.
torch.div(self.topk_ids, vocab_size, out=self._batch_index)
self._batch_index += self._beam_offset[:_B].unsqueeze(1)
self.select_indices = self._batch_index.view(_B * self.beam_size)
self.topk_ids.fmod_(vocab_size) # resolve true word ids
# Append last prediction.
self.alive_seq = torch.cat(
[self.alive_seq.index_select(0, self.select_indices),
self.topk_ids.view(_B * self.beam_size, 1)], -1)
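        # alive_seq is now (B x beam_size, step + 1): each surviving beam's
        # prefix plus its newly chosen token.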
if self.return_attention or self._cov_pen:
current_attn = attn.index_select(1, self.select_indices)
if step == 1:
self.alive_attn = current_attn
# update global state (step == 1)
if self._cov_pen: # coverage penalty
self._prev_penalty = torch.zeros_like(self.topk_log_probs)
self._coverage = current_attn
else:
self.alive_attn = self.alive_attn.index_select(
1, self.select_indices)
self.alive_attn = torch.cat([self.alive_attn, current_attn], 0)
# update global state (step > 1)
if self._cov_pen:
self._coverage = self._coverage.index_select(
1, self.select_indices)
self._coverage += current_attn
self._prev_penalty = self.global_scorer.cov_penalty(
self._coverage, beta=self.global_scorer.beta).view(
_B, self.beam_size)
if self._vanilla_cov_pen:
# shape: (batch_size x beam_size, 1)
cov_penalty = self.global_scorer.cov_penalty(
self._coverage,
beta=self.global_scorer.beta)
self.topk_scores -= cov_penalty.view(_B, self.beam_size)
self.is_finished = self.topk_ids.eq(self.eos)
self.ensure_max_length()
    def update_finished(self):
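        """Finalize hypotheses that emitted EOS and prune finished sentences.

        Finished beams are stored in ``hypotheses``; once a sentence meets
        its end condition, its ``n_best`` results are moved to ``scores``,
        ``predictions`` and ``attention``, and every state tensor is shrunk
        to the sentences that are still being decoded.
        """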
# Penalize beams that finished.
_B_old = self.topk_log_probs.shape[0]
step = self.alive_seq.shape[-1] # 1 greater than the step in advance
self.topk_log_probs.masked_fill_(self.is_finished, -1e10)
# on real data (newstest2017) with the pretrained transformer,
# it's faster to not move this back to the original device
self.is_finished = self.is_finished.to('cpu')
self.top_beam_finished |= self.is_finished[:, 0].eq(1)
predictions = self.alive_seq.view(_B_old, self.beam_size, step)
attention = (
self.alive_attn.view(
step - 1, _B_old, self.beam_size, self.alive_attn.size(-1))
if self.alive_attn is not None else None)
non_finished_batch = []
for i in range(self.is_finished.size(0)):
b = self._batch_offset[i]
finished_hyp = self.is_finished[i].nonzero().view(-1)
# Store finished hypotheses for this batch.
for j in finished_hyp:
if self.ratio > 0:
s = self.topk_scores[i, j] / (step + 1)
if self.best_scores[b] < s:
self.best_scores[b] = s
self.hypotheses[b].append((
self.topk_scores[i, j],
predictions[i, j, 1:], # Ignore start_token.
attention[:, i, j, :self._memory_lengths[i]]
if attention is not None else None))
# End condition is the top beam finished and we can return
# n_best hypotheses.
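            # With a length ratio, stop as soon as the best still-alive beam,
            # normalized by the expected target length, can no longer beat
            # the best finished hypothesis for this sentence.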
if self.ratio > 0:
pred_len = self._memory_lengths[i] * self.ratio
finish_flag = ((self.topk_scores[i, 0] / pred_len)
<= self.best_scores[b]) or \
self.is_finished[i].all()
else:
finish_flag = self.top_beam_finished[i] != 0
if finish_flag and len(self.hypotheses[b]) >= self.n_best:
best_hyp = sorted(
self.hypotheses[b], key=lambda x: x[0], reverse=True)
for n, (score, pred, attn) in enumerate(best_hyp):
if n >= self.n_best:
break
self.scores[b].append(score)
self.predictions[b].append(pred)
self.attention[b].append(
attn if attn is not None else [])
else:
non_finished_batch.append(i)
non_finished = torch.tensor(non_finished_batch)
# If all sentences are translated, no need to go further.
if len(non_finished) == 0:
self.done = True
return
_B_new = non_finished.shape[0]
# Remove finished batches for the next step.
self.top_beam_finished = self.top_beam_finished.index_select(
0, non_finished)
self._batch_offset = self._batch_offset.index_select(0, non_finished)
non_finished = non_finished.to(self.topk_ids.device)
self.topk_log_probs = self.topk_log_probs.index_select(0,
non_finished)
self._batch_index = self._batch_index.index_select(0, non_finished)
self.select_indices = self._batch_index.view(_B_new * self.beam_size)
self.alive_seq = predictions.index_select(0, non_finished) \
.view(-1, self.alive_seq.size(-1))
self.topk_scores = self.topk_scores.index_select(0, non_finished)
self.topk_ids = self.topk_ids.index_select(0, non_finished)
if self.alive_attn is not None:
inp_seq_len = self.alive_attn.size(-1)
self.alive_attn = attention.index_select(1, non_finished) \
.view(step - 1, _B_new * self.beam_size, inp_seq_len)
if self._cov_pen:
self._coverage = self._coverage \
.view(1, _B_old, self.beam_size, inp_seq_len) \
.index_select(1, non_finished) \
.view(1, _B_new * self.beam_size, inp_seq_len)
if self._stepwise_cov_pen:
self._prev_penalty = self._prev_penalty.index_select(
0, non_finished)
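
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). A hedged illustration of
# the loop a translator typically runs around this class: score the current
# predictions, call ``advance``, prune finished sentences with
# ``update_finished``, and stop once ``done``. The token ids, sizes, and the
# dummy "model" below are assumptions for illustration only; the
# GNMTGlobalScorer construction assumes its (alpha, beta, length_penalty,
# coverage_penalty) signature, and the integer ``torch.div`` above assumes a
# PyTorch version contemporary with this module.
if __name__ == "__main__":
    from onmt.translate import GNMTGlobalScorer

    batch_size, beam_size, vocab_size, src_len = 2, 3, 10, 7
    pad, bos, eos = 1, 2, 3
    scorer = GNMTGlobalScorer(0.0, 0.0, "none", "none")
    beam = BeamSearch(
        beam_size, batch_size, pad, bos, eos, n_best=2, mb_device="cpu",
        global_scorer=scorer, min_length=0, max_length=10,
        return_attention=False, block_ngram_repeat=0, exclusion_tokens=set(),
        memory_lengths=torch.full([batch_size * beam_size], src_len,
                                  dtype=torch.long),
        stepwise_penalty=False, ratio=0.0)

    for _ in range(10):
        live = beam.alive_seq.size(0)
        # A real loop feeds beam.current_predictions to the decoder; here a
        # dummy "model" emits random log-probs and uniform attention.
        log_probs = torch.randn(live, vocab_size).log_softmax(-1)
        attn = torch.full([1, live, src_len], 1.0 / src_len)
        beam.advance(log_probs, attn)
        if beam.is_finished.any():
            beam.update_finished()
            if beam.done:
                break
        # A real loop would also reorder the decoder state here using
        # beam.current_origin (flat indices of the surviving beams).

    print([len(p) for p in beam.predictions])  # n_best hypotheses per sentence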