Source code for onmt.translate.random_sampling

import torch

from onmt.translate.decode_strategy import DecodeStrategy


def sample_with_temperature(logits, sampling_temp, keep_topk):
    """Select next tokens randomly from the top k possible next tokens.

    Samples from a categorical distribution over the ``keep_topk`` words
    using the category probabilities ``logits / sampling_temp``.

    Args:
        logits (FloatTensor): Shaped ``(batch_size, vocab_size)``.
            These can be logits (``(-inf, inf)``) or log-probs (``(-inf, 0]``).
            (The distribution actually uses the log-probabilities
            ``logits - logits.logsumexp(-1)``, which equals the logits if
            they are log-probabilities summing to 1.)
        sampling_temp (float): Used to scale down logits. The higher the
            value, the more likely it is that a non-max word will be
            sampled.
        keep_topk (int): This many words could potentially be chosen. The
            other logits are set to have probability 0.

    Returns:
        (LongTensor, FloatTensor):

        * topk_ids: Shaped ``(batch_size, 1)``. These are
          the sampled word indices in the output vocab.
        * topk_scores: Shaped ``(batch_size, 1)``. These
          are essentially ``(logits / sampling_temp)[topk_ids]``.
    """

    if sampling_temp == 0.0 or keep_topk == 1:
        # For temp=0.0, take the argmax to avoid divide-by-zero errors.
        # keep_topk=1 is also equivalent to argmax.
        topk_scores, topk_ids = logits.topk(1, dim=-1)
        if sampling_temp > 0:
            topk_scores /= sampling_temp
    else:
        logits = torch.div(logits, sampling_temp)

        if keep_topk > 0:
            top_values, top_indices = torch.topk(logits, keep_topk, dim=1)
            kth_best = top_values[:, -1].view([-1, 1])
            kth_best = kth_best.repeat([1, logits.shape[1]]).float()

            # Set all logits that are not in the top-k to -10000.
            # This puts the probabilities close to 0.
            ignore = torch.lt(logits, kth_best)
            logits = logits.masked_fill(ignore, -10000)

        dist = torch.distributions.Multinomial(
            logits=logits, total_count=1)
        topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True)
        topk_scores = logits.gather(dim=1, index=topk_ids)
    return topk_ids, topk_scores
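
A minimal usage sketch of the function above; the logit values below are made up for illustration and are not part of the library source:

    import torch
    from onmt.translate.random_sampling import sample_with_temperature

    # Toy scores for a batch of one over a 4-word vocab (illustrative values).
    logits = torch.tensor([[2.0, 1.0, 0.1, -3.0]])

    # Sample among the 2 highest-scoring tokens at temperature 0.7:
    # everything outside the top 2 is masked to -10000 before sampling.
    ids, scores = sample_with_temperature(logits, sampling_temp=0.7, keep_topk=2)
    print(ids.shape, scores.shape)  # torch.Size([1, 1]) torch.Size([1, 1])

    # sampling_temp=0.0 (or keep_topk=1) degenerates to a plain argmax.
    ids, _ = sample_with_temperature(logits, sampling_temp=0.0, keep_topk=-1)
    assert ids.item() == 0
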
class RandomSampling(DecodeStrategy):
    """Select next tokens randomly from the top k possible next tokens.

    The ``scores`` attribute's lists are the score, after applying temperature,
    of the final prediction (either EOS or the final token in the event
    that ``max_length`` is reached).

    Args:
        pad (int): See base.
        bos (int): See base.
        eos (int): See base.
        batch_size (int): See base.
        device (torch.device or str): See base ``device``.
        min_length (int): See base.
        max_length (int): See base.
        block_ngram_repeat (int): See base.
        exclusion_tokens (set[int]): See base.
        return_attention (bool): See base.
        sampling_temp (float): See
            :func:`~onmt.translate.random_sampling.sample_with_temperature()`.
        keep_topk (int): See
            :func:`~onmt.translate.random_sampling.sample_with_temperature()`.
        memory_length (LongTensor): Lengths of encodings. Used for masking
            attention.
    """

    def __init__(self, pad, bos, eos, batch_size, device, min_length,
                 block_ngram_repeat, exclusion_tokens, return_attention,
                 max_length, sampling_temp, keep_topk, memory_length):
        super(RandomSampling, self).__init__(
            pad, bos, eos, batch_size, device, 1,
            min_length, block_ngram_repeat, exclusion_tokens,
            return_attention, max_length)
        self.sampling_temp = sampling_temp
        self.keep_topk = keep_topk
        self.topk_scores = None
        self.memory_length = memory_length
        self.batch_size = batch_size
        self.select_indices = torch.arange(self.batch_size, dtype=torch.long,
                                           device=device)
        self.original_batch_idx = torch.arange(self.batch_size,
                                               dtype=torch.long, device=device)

    def advance(self, log_probs, attn):
        """Select next tokens randomly from the top k possible next tokens.

        Args:
            log_probs (FloatTensor): Shaped ``(batch_size, vocab_size)``.
                These can be logits (``(-inf, inf)``) or log-probs
                (``(-inf, 0]``). (The distribution actually uses the
                log-probabilities ``logits - logits.logsumexp(-1)``, which
                equals the logits if they are log-probabilities summing to 1.)
            attn (FloatTensor): Shaped ``(1, B, inp_seq_len)``.
        """

        self.ensure_min_length(log_probs)
        self.block_ngram_repeats(log_probs)
        topk_ids, self.topk_scores = sample_with_temperature(
            log_probs, self.sampling_temp, self.keep_topk)

        self.is_finished = topk_ids.eq(self.eos)

        self.alive_seq = torch.cat([self.alive_seq, topk_ids], -1)
        if self.return_attention:
            if self.alive_attn is None:
                self.alive_attn = attn
            else:
                self.alive_attn = torch.cat([self.alive_attn, attn], 0)

        self.ensure_max_length()

    def update_finished(self):
        """Finalize scores and predictions."""
        # shape: (sum(self.is_finished), 1)
        finished_batches = self.is_finished.view(-1).nonzero()
        for b in finished_batches.view(-1):
            b_orig = self.original_batch_idx[b]
            self.scores[b_orig].append(self.topk_scores[b, 0])
            self.predictions[b_orig].append(self.alive_seq[b, 1:])
            self.attention[b_orig].append(
                self.alive_attn[:, b, :self.memory_length[b]]
                if self.alive_attn is not None else [])
        self.done = self.is_finished.all()
        if self.done:
            return
        is_alive = ~self.is_finished.view(-1)
        self.alive_seq = self.alive_seq[is_alive]
        if self.alive_attn is not None:
            self.alive_attn = self.alive_attn[:, is_alive]
        self.select_indices = is_alive.nonzero().view(-1)
        self.original_batch_idx = self.original_batch_idx[is_alive]
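
A hedged sketch of the loop a translator would drive this strategy with. The token ids, shapes, and the random log-probabilities standing in for a real decoder step are placeholders, and it assumes the base ``DecodeStrategy`` constructor already sets up ``alive_seq`` and the per-example bookkeeping lists, as it does in this version of the module:

    import torch
    from onmt.translate.random_sampling import RandomSampling

    batch_size, vocab_size, src_len, max_length = 4, 100, 9, 30
    strategy = RandomSampling(
        pad=1, bos=2, eos=3, batch_size=batch_size, device="cpu",
        min_length=0, block_ngram_repeat=0, exclusion_tokens=set(),
        return_attention=False, max_length=max_length, sampling_temp=0.8,
        keep_topk=10,
        memory_length=torch.full((batch_size,), src_len, dtype=torch.long))

    for step in range(max_length):
        n_alive = strategy.alive_seq.size(0)
        # A real translator computes these from the decoder; random stand-ins here.
        log_probs = torch.randn(n_alive, vocab_size).log_softmax(-1)
        attn = torch.rand(1, n_alive, src_len)
        strategy.advance(log_probs, attn)
        if strategy.is_finished.any():
            strategy.update_finished()
            if strategy.done:
                break

    predictions = strategy.predictions  # one list of finished sequences per example
    scores = strategy.scores            # matching final-token scores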