Source code for onmt.translate.decode_strategy

import torch


class DecodeStrategy(object):
    """Base class for generation strategies.

    Args:
        pad (int): Index of the padding token in the output vocab.
        bos (int): Index of the begin-of-sentence token in the output vocab.
        eos (int): Index of the end-of-sentence token in the output vocab.
        batch_size (int): Current batch size.
        device (torch.device or str): Device for memory bank (encoder).
        parallel_paths (int): Decoding strategies like beam search use
            parallel paths. Each batch is repeated ``parallel_paths`` times
            in relevant state tensors.
        min_length (int): Shortest acceptable generation, not counting
            begin-of-sentence or end-of-sentence.
        max_length (int): Longest acceptable sequence, not counting
            begin-of-sentence (presumably there has been no EOS yet if
            max_length is used as a cutoff).
        block_ngram_repeat (int): Block beams where
            ``block_ngram_repeat``-grams repeat.
        exclusion_tokens (set[int]): If a gram contains any of these
            tokens, it may repeat.
        return_attention (bool): Whether to collect attention as well.
            If this is true, it is assumed that the decoder is attentional.

    Attributes:
        pad (int): See above.
        bos (int): See above.
        eos (int): See above.
        predictions (list[list[LongTensor]]): For each batch, holds a
            list of beam prediction sequences.
        scores (list[list[FloatTensor]]): For each batch, holds a
            list of scores.
        attention (list[list[FloatTensor or list[]]]): For each
            batch, holds a list of attention sequence tensors
            (or empty lists) having shape ``(step, inp_seq_len)`` where
            ``inp_seq_len`` is the length of the sample (not the max
            length of all inp seqs).
        alive_seq (LongTensor): Shape ``(B x parallel_paths, step)``.
            This sequence grows in the ``step`` axis on each call to
            :func:`advance()`.
        is_finished (ByteTensor): Shape ``(B, parallel_paths)``.
            Initialized to all zeros (no path has finished).
        alive_attn (FloatTensor or NoneType): If tensor, shape is
            ``(step, B x parallel_paths, inp_seq_len)``, where
            ``inp_seq_len`` is the (max) length of the input sequence.
        min_length (int): See above.
        max_length (int): See above.
        block_ngram_repeat (int): See above.
        exclusion_tokens (set[int]): See above.
        return_attention (bool): See above.
        done (bool): Whether decoding has concluded for the whole batch.
            Initialized to ``False``.
    """

    def __init__(self, pad, bos, eos, batch_size, device, parallel_paths,
                 min_length, block_ngram_repeat, exclusion_tokens,
                 return_attention, max_length):

        # magic indices
        self.pad = pad
        self.bos = bos
        self.eos = eos

        # result caching
        self.predictions = [[] for _ in range(batch_size)]
        self.scores = [[] for _ in range(batch_size)]
        self.attention = [[] for _ in range(batch_size)]

        self.alive_seq = torch.full(
            [batch_size * parallel_paths, 1], self.bos,
            dtype=torch.long, device=device)
        self.is_finished = torch.zeros(
            [batch_size, parallel_paths],
            dtype=torch.uint8, device=device)
        self.alive_attn = None

        self.min_length = min_length
        self.max_length = max_length
        self.block_ngram_repeat = block_ngram_repeat
        self.exclusion_tokens = exclusion_tokens
        self.return_attention = return_attention

        self.done = False

    def __len__(self):
        # current decoding step, counting the initial BOS column
        return self.alive_seq.shape[1]

    def ensure_min_length(self, log_probs):
        # forbid EOS while the sequence is shorter than min_length
        if len(self) <= self.min_length:
            log_probs[:, self.eos] = -1e20

    def ensure_max_length(self):
        # add one to account for BOS. Don't account for EOS because hitting
        # this implies it hasn't been found.
        if len(self) == self.max_length + 1:
            self.is_finished.fill_(1)

    def block_ngram_repeats(self, log_probs):
        cur_len = len(self)
        if self.block_ngram_repeat > 0 and cur_len > 1:
            for path_idx in range(self.alive_seq.shape[0]):
                # skip BOS
                hyp = self.alive_seq[path_idx, 1:]
                ngrams = set()
                fail = False
                gram = []
                for i in range(cur_len - 1):
                    # Last n tokens, n = block_ngram_repeat
                    gram = (gram + [hyp[i].item()])[-self.block_ngram_repeat:]
                    # skip the blocking if any token in gram is excluded
                    if set(gram) & self.exclusion_tokens:
                        continue
                    if tuple(gram) in ngrams:
                        fail = True
                    ngrams.add(tuple(gram))
                if fail:
                    # squash the whole distribution for this path
                    log_probs[path_idx] = -10e20
    def advance(self, log_probs, attn):
        """DecodeStrategy subclasses should override :func:`advance()`.

        Advance is used to update ``self.alive_seq``, ``self.is_finished``,
        and, when appropriate, ``self.alive_attn``.
        """

        raise NotImplementedError()
    def update_finished(self):
        """DecodeStrategy subclasses should override
        :func:`update_finished()`.

        ``update_finished`` is used to update ``self.predictions``,
        ``self.scores``, and other "output" attributes.
        """

        raise NotImplementedError()
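
The methods above are only hooks; concrete strategies such as beam search subclass DecodeStrategy and implement advance() and update_finished(). As a minimal sketch (not part of OpenNMT-py), the single-path greedy subclass below shows how the hooks are meant to be driven: mask the log-probabilities with ensure_min_length() and block_ngram_repeats() before picking a token, then let ensure_max_length() force termination. GreedyDemo and the random "model" scores are hypothetical names and data made up for this illustration.

# Hypothetical minimal subclass, for illustration only. For simplicity,
# finished paths keep decoding until every path has finished.
class GreedyDemo(DecodeStrategy):
    def advance(self, log_probs, attn):
        # mask EOS while too short, then penalize repeated n-grams
        self.ensure_min_length(log_probs)
        self.block_ngram_repeats(log_probs)

        # greedy pick: one path per batch item, shape (B, 1)
        topk_ids = log_probs.argmax(dim=1, keepdim=True)
        self.alive_seq = torch.cat([self.alive_seq, topk_ids], dim=1)
        self.is_finished = topk_ids.eq(self.eos).view(
            self.is_finished.shape).to(torch.uint8)

        # force-finish all paths once max_length is exceeded
        self.ensure_max_length()

    def update_finished(self):
        self.done = bool(self.is_finished.all())
        if self.done:
            # cache predictions, dropping BOS (scores/attention omitted)
            for b in range(self.is_finished.shape[0]):
                self.predictions[b].append(self.alive_seq[b, 1:])


# usage sketch with uniform random scores standing in for a model
strategy = GreedyDemo(
    pad=0, bos=1, eos=2, batch_size=2, device="cpu", parallel_paths=1,
    min_length=2, block_ngram_repeat=2, exclusion_tokens=set(),
    return_attention=False, max_length=10)
while not strategy.done:
    fake_log_probs = torch.rand(2, 50).log_softmax(dim=1)
    strategy.advance(fake_log_probs, attn=None)
    strategy.update_finished()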
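
As a quick standalone check of the n-gram blocking rule: a path whose history repeats a ``block_ngram_repeat``-gram has its entire next-token distribution squashed, so it can never win a subsequent top-k selection. The toy tensors below are made up:

# made-up toy data illustrating block_ngram_repeats()
strat = DecodeStrategy(
    pad=0, bos=1, eos=2, batch_size=1, device="cpu", parallel_paths=1,
    min_length=0, block_ngram_repeat=2, exclusion_tokens=set(),
    return_attention=False, max_length=20)
# pretend the path already produced "5 6 5 6": the bigram (5, 6) repeats
strat.alive_seq = torch.tensor([[1, 5, 6, 5, 6]])
log_probs = torch.zeros(1, 10)
strat.block_ngram_repeats(log_probs)
print(log_probs[0, 0].item())  # -1e+21: every token for this path is blocked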