Source code for onmt.translate.beam

from __future__ import division
import torch
from onmt.translate import penalties

import warnings


class Beam(object):
    """Class for managing the internals of the beam search process.

    Takes care of beams, back pointers, and scores.

    Args:
        size (int): Number of beams to use.
        pad (int): Index of the padding token in the output vocab.
        bos (int): Index of the beginning-of-sentence token in the
            output vocab.
        eos (int): Index of the end-of-sentence token in the output vocab.
        n_best (int): Don't stop until at least this many beams have
            reached EOS.
        cuda (bool): Whether to keep the beam state on the GPU.
        global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance.
        min_length (int): Shortest acceptable generation, not counting
            begin-of-sentence or end-of-sentence.
        stepwise_penalty (bool): Apply coverage penalty at every step.
        block_ngram_repeat (int): Block beams where
            ``block_ngram_repeat``-grams repeat.
        exclusion_tokens (set[int]): If a gram contains any of these
            token indices, it may repeat.
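
    Example:
        An illustrative sketch, not from the original module; the
        special-token indices are arbitrary and ``decode_step`` is a
        hypothetical decoder returning per-beam log probs and attention::

            scorer = GNMTGlobalScorer(0., 0., "none", "none")
            beam = Beam(5, pad=1, bos=2, eos=3, global_scorer=scorer)
            while not beam.done:  # a real loop would also cap the length
                log_probs, attn = decode_step(beam.current_predictions)
                beam.advance(log_probs, attn)
            scores, ks = beam.sort_finished(minimum=beam.n_best)
            best_hyp, best_attn = beam.get_hyp(*ks[0])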
    """

    def __init__(self, size, pad, bos, eos,
                 n_best=1, cuda=False,
                 global_scorer=None,
                 min_length=0,
                 stepwise_penalty=False,
                 block_ngram_repeat=0,
                 exclusion_tokens=frozenset()):

        self.size = size
        self.tt = torch.cuda if cuda else torch

        # The score for each translation on the beam.
        self.scores = self.tt.FloatTensor(size).zero_()
        self.all_scores = []

        # The backpointers at each time-step.
        self.prev_ks = []

        # The outputs at each time-step.
        self.next_ys = [self.tt.LongTensor(size).fill_(pad)]
        self.next_ys[0][0] = bos
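        # Only beam 0 starts from BOS; the rest are initialized to padding,
        # so the first advance() step expands candidates from beam 0 alone
        # (the `word_probs[0]` branch below).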

        # Has EOS topped the beam yet?
        self._eos = eos
        self.eos_top = False

        # The attention distributions for each timestep.
        self.attn = []

        # (score, timestep, beam index) triples for finished hypotheses.
        self.finished = []
        self.n_best = n_best

        # Information for global scoring.
        self.global_scorer = global_scorer
        self.global_state = {}

        # Minimum prediction length
        self.min_length = min_length

        # Apply Penalty at every step
        self.stepwise_penalty = stepwise_penalty
        self.block_ngram_repeat = block_ngram_repeat
        self.exclusion_tokens = exclusion_tokens

    @property
    def current_predictions(self):
        """Get the outputs (predicted tokens) for the current timestep."""
        return self.next_ys[-1]

    @property
    def current_origin(self):
        """Get the backpointers for the current timestep."""
        return self.prev_ks[-1]

    def advance(self, word_probs, attn_out):
        """
        Given prob over words for every last beam `wordLk` and attention
        `attn_out`: Compute and update the beam search.

        Args:
            word_probs (FloatTensor): probs of advancing from the last step
                ``(K, words)``
            attn_out (FloatTensor): attention at the last step

        Returns:
            bool: True if beam search is complete.
        """

        num_words = word_probs.size(1)
        if self.stepwise_penalty:
            self.global_scorer.update_score(self, attn_out)
        # force the output to be longer than self.min_length
        cur_len = len(self.next_ys)
        if cur_len <= self.min_length:
            # assumes there are len(word_probs) predictions OTHER
            # than EOS that are greater than -1e20
            for k in range(len(word_probs)):
                word_probs[k][self._eos] = -1e20

        # Sum the previous scores.
        if len(self.prev_ks) > 0:
            beam_scores = word_probs + self.scores.unsqueeze(1)
            # Don't let EOS have children.
            for i in range(self.next_ys[-1].size(0)):
                if self.next_ys[-1][i] == self._eos:
                    beam_scores[i] = -1e20

            # Block ngram repeats
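            # Illustration (not from the original source): with
            # block_ngram_repeat=2, a hypothesis whose tokens so far are
            # [7, 9, 7, 9] contains the bigram (7, 9) twice, so its beam
            # score is pushed to -10e20 below, unless a repeated gram
            # contains a token from exclusion_tokens.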
            if self.block_ngram_repeat > 0:
                le = len(self.next_ys)
                for j in range(self.next_ys[-1].size(0)):
                    hyp, _ = self.get_hyp(le - 1, j)
                    ngrams = set()
                    fail = False
                    gram = []
                    for i in range(le - 1):
                        # Last n tokens, n = block_ngram_repeat
                        gram = (gram +
                                [hyp[i].item()])[-self.block_ngram_repeat:]
                        # Skip the blocking if it is in the exclusion list
                        if set(gram) & self.exclusion_tokens:
                            continue
                        if tuple(gram) in ngrams:
                            fail = True
                        ngrams.add(tuple(gram))
                    if fail:
                        beam_scores[j] = -10e20
        else:
            beam_scores = word_probs[0]
        flat_beam_scores = beam_scores.view(-1)
        best_scores, best_scores_id = flat_beam_scores.topk(self.size, 0,
                                                            True, True)

        self.all_scores.append(self.scores)
        self.scores = best_scores

        # best_scores_id is flattened beam x word array, so calculate which
        # word and beam each score came from
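        # Illustration: with size=5 beams and num_words=50000, flat index
        # 123456 decomposes into beam 123456 // 50000 = 2 and word
        # 123456 - 2 * 50000 = 23456.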
        prev_k = best_scores_id // num_words
        self.prev_ks.append(prev_k)
        self.next_ys.append(best_scores_id - prev_k * num_words)
        self.attn.append(attn_out.index_select(0, prev_k))
        self.global_scorer.update_global_state(self)

        for i in range(self.next_ys[-1].size(0)):
            if self.next_ys[-1][i] == self._eos:
                global_scores = self.global_scorer.score(self, self.scores)
                s = global_scores[i]
                self.finished.append((s, len(self.next_ys) - 1, i))

        # End condition: the hypothesis at the top of the beam ends in EOS.
        if self.next_ys[-1][0] == self._eos:
            self.all_scores.append(self.scores)
            self.eos_top = True

    @property
    def done(self):
        return self.eos_top and len(self.finished) >= self.n_best

    def sort_finished(self, minimum=None):
        if minimum is not None:
            i = 0
            # Add from beam until we have minimum outputs.
            while len(self.finished) < minimum:
                global_scores = self.global_scorer.score(self, self.scores)
                s = global_scores[i]
                self.finished.append((s, len(self.next_ys) - 1, i))
                i += 1

        self.finished.sort(key=lambda a: -a[0])
        scores = [sc for sc, _, _ in self.finished]
        ks = [(t, k) for _, t, k in self.finished]
        return scores, ks

    def get_hyp(self, timestep, k):
        """Walk back to construct the full hypothesis."""
        hyp, attn = [], []
        for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1):
            hyp.append(self.next_ys[j + 1][k])
            attn.append(self.attn[j][k])
            k = self.prev_ks[j][k]
        return hyp[::-1], torch.stack(attn[::-1])


class GNMTGlobalScorer(object):
    """NMT re-ranking.

    Args:
        alpha (float): Length parameter.
        beta (float): Coverage parameter.
        length_penalty (str): Length penalty strategy.
        coverage_penalty (str): Coverage penalty strategy.

    Attributes:
        alpha (float): See above.
        beta (float): See above.
        length_penalty (callable): See :class:`penalties.PenaltyBuilder`.
        coverage_penalty (callable): See :class:`penalties.PenaltyBuilder`.
        has_cov_pen (bool): See :class:`penalties.PenaltyBuilder`.
        has_len_pen (bool): See :class:`penalties.PenaltyBuilder`.
    """

    @classmethod
    def from_opt(cls, opt):
        return cls(
            opt.alpha,
            opt.beta,
            opt.length_penalty,
            opt.coverage_penalty)

    def __init__(self, alpha, beta, length_penalty, coverage_penalty):
        self._validate(alpha, beta, length_penalty, coverage_penalty)
        self.alpha = alpha
        self.beta = beta
        penalty_builder = penalties.PenaltyBuilder(coverage_penalty,
                                                   length_penalty)
        self.has_cov_pen = penalty_builder.has_cov_pen
        # Term will be subtracted from probability
        self.cov_penalty = penalty_builder.coverage_penalty
        self.has_len_pen = penalty_builder.has_len_pen
        # Probability will be divided by this
        self.length_penalty = penalty_builder.length_penalty

    @classmethod
    def _validate(cls, alpha, beta, length_penalty, coverage_penalty):
        # These warnings indicate that either the alpha/beta
        # forces a penalty to be a no-op, or a penalty is a no-op but
        # the alpha/beta would suggest otherwise.
        if length_penalty is None or length_penalty == "none":
            if alpha != 0:
                warnings.warn("Non-default `alpha` with no length penalty. "
                              "`alpha` has no effect.")
        else:
            # using some length penalty
            if length_penalty == "wu" and alpha == 0.:
                warnings.warn("Using length penalty Wu with alpha==0 "
                              "is equivalent to using length penalty none.")
        if coverage_penalty is None or coverage_penalty == "none":
            if beta != 0:
                warnings.warn("Non-default `beta` with no coverage penalty. "
                              "`beta` has no effect.")
        else:
            # using some coverage penalty
            if beta == 0.:
                warnings.warn("Non-default coverage penalty with beta==0 "
                              "is equivalent to using coverage penalty none.")
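
    # Illustrative construction (not from the original source); "wu" is the
    # length-penalty option referenced in _validate above:
    #
    #     scorer = GNMTGlobalScorer(alpha=0.6, beta=0.0,
    #                               length_penalty="wu",
    #                               coverage_penalty="none")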

    def score(self, beam, logprobs):
        """Rescore a prediction based on penalty functions."""
        len_pen = self.length_penalty(len(beam.next_ys), self.alpha)
        normalized_probs = logprobs / len_pen
        if not beam.stepwise_penalty:
            penalty = self.cov_penalty(beam.global_state["coverage"],
                                       self.beta)
            normalized_probs -= penalty

        return normalized_probs
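
    # Worked example (assuming the Wu et al., 2016 length penalty provided
    # by penalties.PenaltyBuilder, lp = ((5 + length) / 6) ** alpha): with
    # alpha=0.9 and a 12-token hypothesis, lp = (17 / 6) ** 0.9 ~= 2.55, so
    # score() divides the summed log probs by about 2.55.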

    def update_score(self, beam, attn):
        """Update scores of a Beam that is not finished."""
        if "prev_penalty" in beam.global_state.keys():
            beam.scores.add_(beam.global_state["prev_penalty"])
            penalty = self.cov_penalty(beam.global_state["coverage"] + attn,
                                       self.beta)
            beam.scores.sub_(penalty)

    def update_global_state(self, beam):
        """Keeps the coverage vector as sum of attentions."""
        if len(beam.prev_ks) == 1:
            beam.global_state["prev_penalty"] = beam.scores.clone().fill_(0.0)
            beam.global_state["coverage"] = beam.attn[-1]
            self.cov_total = beam.attn[-1].sum(1)
        else:
            self.cov_total += torch.min(beam.attn[-1],
                                        beam.global_state['coverage']).sum(1)
            beam.global_state["coverage"] = beam.global_state["coverage"] \
                .index_select(0, beam.prev_ks[-1]).add(beam.attn[-1])

            prev_penalty = self.cov_penalty(beam.global_state["coverage"],
                                            self.beta)
            beam.global_state["prev_penalty"] = prev_penalty
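
    # Coverage bookkeeping sketch (illustrative): at the first step the
    # coverage vector is just the first attention matrix; afterwards it is
    # the previous coverage, reordered by the step's backpointers, plus the
    # new attention:
    #
    #     coverage_t = coverage_{t-1}[prev_k] + attn_t
    #
    # The coverage penalty for the next step is recomputed from this
    # running sum.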