opennmt.models.sequence_to_sequence module

Standard sequence-to-sequence model.

opennmt.models.sequence_to_sequence.shift_target_sequence(inputter, data)[source]

Prepares shifted target sequences.

Given a target sequence a b c, the decoder input should be <s> a b c and the output should be a b c </s> for the dynamic decoding to start on <s> and stop on </s>.


Returns:The updated data dictionary with ids the sequence prefixed with the start token id and ids_out the sequence suffixed with the end token id. Additionally, the length is increased by 1 to reflect the added token on both sequences.
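The shifting logic can be sketched in plain Python (the real function operates on tf.Tensor fields of the data dictionary; the token ids below are hypothetical):

```python
START_ID, END_ID = 1, 2  # assumed ids for the <s> and </s> tokens

def shift_target_sequence(data):
    """Prefixes the decoder input with <s> and suffixes the output with </s>."""
    ids = data["ids"]
    data["ids"] = [START_ID] + ids      # decoder input: <s> a b c
    data["ids_out"] = ids + [END_ID]    # expected output: a b c </s>
    data["length"] += 1                 # one token was added to each sequence
    return data

data = shift_target_sequence({"ids": [10, 11, 12], "length": 3})
print(data["ids"])      # [1, 10, 11, 12]
print(data["ids_out"])  # [10, 11, 12, 2]
print(data["length"])   # 4
```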

class opennmt.models.sequence_to_sequence.EmbeddingsSharingLevel[source]

Bases: object

Level of embeddings sharing.

Possible values are:

  • NONE: no sharing (default)
  • SOURCE_TARGET_INPUT: share source and target word embeddings
  • TARGET: share target word embeddings and softmax weights
  • ALL: share word embeddings and softmax weights
NONE = 0
SOURCE_TARGET_INPUT = 1
TARGET = 2
ALL = 3
static share_input_embeddings(level)[source]

Returns True if input embeddings should be shared at level.

static share_target_embeddings(level)[source]

Returns True if target embeddings should be shared at level.
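The two predicates reduce to simple membership tests on the level, as in this sketch (the values for SOURCE_TARGET_INPUT and TARGET are assumed to follow the documented ordering between NONE = 0 and ALL = 3):

```python
NONE, SOURCE_TARGET_INPUT, TARGET, ALL = 0, 1, 2, 3

def share_input_embeddings(level):
    # Source and target input embeddings are shared at these levels.
    return level in (SOURCE_TARGET_INPUT, ALL)

def share_target_embeddings(level):
    # Target embeddings and softmax weights are shared at these levels.
    return level in (TARGET, ALL)

print(share_input_embeddings(ALL))     # True
print(share_target_embeddings(TARGET)) # True
print(share_input_embeddings(NONE))    # False
```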

class opennmt.models.sequence_to_sequence.SequenceToSequence(source_inputter, target_inputter, encoder, decoder, share_embeddings=0, alignment_file_key='train_alignments', daisy_chain_variables=False, name='seq2seq')[source]

Bases: opennmt.models.model.Model

A sequence to sequence model.

__init__(source_inputter, target_inputter, encoder, decoder, share_embeddings=0, alignment_file_key='train_alignments', daisy_chain_variables=False, name='seq2seq')[source]

Initializes a sequence-to-sequence model.


Raises:TypeError – if target_inputter is not an opennmt.inputters.text_inputter.WordEmbedder (the same applies to source_inputter when embeddings sharing is enabled) or if source_inputter and target_inputter do not have the same dtype.


auto_config(num_devices=1)[source]

Returns automatic configuration values specific to this model.

Parameters:num_devices – The number of devices used for the training.
Returns:A partial training configuration.
compute_loss(outputs, labels, training=True, params=None)[source]

Computes the loss.

Parameters:
  • outputs – The model outputs (usually unscaled probabilities).
  • labels – The dict of labels tf.Tensor.
  • training – Compute training loss.
  • params – A dictionary of hyperparameters.

Returns:The loss or a tuple containing the computed loss and the loss to display.
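For intuition only, the per-token cross-entropy over a single unbatched sequence can be sketched as follows (the actual method works on batched tf.Tensor outputs and the labels dict, with padding masked by the sequence length):

```python
import math

def sequence_loss(log_probs, label_ids, length):
    """Average negative log-likelihood of the gold labels over `length` steps."""
    total = 0.0
    for t in range(length):  # positions beyond `length` would be padding
        total += -log_probs[t][label_ids[t]]
    return total / length

# Two-step toy example: log-probabilities over a 3-word vocabulary.
log_probs = [[math.log(0.7), math.log(0.2), math.log(0.1)],
             [math.log(0.1), math.log(0.8), math.log(0.1)]]
loss = sequence_loss(log_probs, [0, 1], 2)
print(round(loss, 4))  # 0.2899
```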

print_prediction(prediction, params=None, stream=None)[source]

Prints the model prediction.

Parameters:
  • prediction – The evaluated prediction.
  • params – (optional) Dictionary of formatting parameters.
  • stream – (optional) The stream to print to.
class opennmt.models.sequence_to_sequence.SequenceToSequenceInputter(features_inputter, labels_inputter, share_parameters=False, alignment_file_key=None)[source]

Bases: opennmt.inputters.inputter.ExampleInputter

A custom opennmt.inputters.inputter.ExampleInputter that possibly injects alignment information during training.

initialize(metadata, asset_dir=None, asset_prefix='')[source]

Initializes the inputter.

Parameters:
  • metadata – A dictionary containing additional metadata set by the user.
  • asset_dir – The directory where assets can be written. If None, no assets are returned.
  • asset_prefix – The prefix to attach to asset filenames.

Returns:A dictionary containing additional assets used by the inputter.

make_dataset(data_file, training=None)[source]

Creates the base dataset required by this inputter.

Parameters:
  • data_file – The data file.
  • training – Run in training mode.


make_features(element=None, features=None, training=None)[source]

Creates features from data.

Parameters:
  • element – An element from the dataset.
  • features – An optional dictionary of features to augment.
  • training – Run in training mode.

Returns:A dictionary of tf.Tensor.

opennmt.models.sequence_to_sequence.alignment_matrix_from_pharaoh(alignment_line, source_length, target_length, dtype=tf.float32)[source]

Parses Pharaoh alignments into an alignment matrix.

Parameters:
  • alignment_line – A string tf.Tensor in the Pharaoh format.
  • source_length – The length of the source sentence, without special symbols.
  • target_length – The length of the target sentence, without special symbols.
  • dtype – The output matrix dtype. Defaults to tf.float32 for convenience when computing the guided alignment loss.

Returns:The alignment matrix as a 2-D tf.Tensor of type dtype and shape [target_length, source_length], where [i, j] = 1 if the i-th target word is aligned with the j-th source word.
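A pure-Python sketch of the parsing (the real function returns a tf.Tensor and parses the line with TensorFlow string ops; Pharaoh pairs are written source_index-target_index):

```python
def alignment_matrix_from_pharaoh(alignment_line, source_length, target_length):
    """Builds a [target_length, source_length] 0/1 matrix from "i-j" pairs."""
    matrix = [[0.0] * source_length for _ in range(target_length)]
    for pair in alignment_line.split():
        src, tgt = map(int, pair.split("-"))
        matrix[tgt][src] = 1.0
    return matrix

print(alignment_matrix_from_pharaoh("0-0 1-2 2-1", 3, 3))
# [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
```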

opennmt.models.sequence_to_sequence.guided_alignment_cost(attention_probs, gold_alignment, sequence_length, guided_alignment_type, guided_alignment_weight=1)[source]

Computes the guided alignment cost.

Parameters:
  • attention_probs – The attention probabilities, a float tf.Tensor of shape \([B, T_t, T_s]\).
  • gold_alignment – The true alignment matrix, a float tf.Tensor of shape \([B, T_t, T_s]\).
  • sequence_length – The length of each sequence.
  • guided_alignment_type – The type of guided alignment cost function to compute (can be: ce, mse).
  • guided_alignment_weight – The weight applied to the guided alignment cost.

Returns:The guided alignment cost.
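For intuition, the ce variant penalizes low attention probability at the gold-aligned positions. A simplified, unbatched sketch (the real function works on batched tf.Tensor values and masks positions beyond sequence_length):

```python
import math

def guided_alignment_cost_ce(attention_probs, gold_alignment, weight=1.0):
    """Cross-entropy between the gold alignment rows and the attention rows."""
    total = 0.0
    for attn_row, gold_row in zip(attention_probs, gold_alignment):
        for prob, gold in zip(attn_row, gold_row):
            if gold > 0:
                total += -gold * math.log(prob + 1e-9)  # epsilon for stability
    return weight * total / len(attention_probs)

cost = guided_alignment_cost_ce(
    [[0.9, 0.1], [0.2, 0.8]],   # attention over 2 source words, 2 target steps
    [[1.0, 0.0], [0.0, 1.0]])   # gold: target word i aligned to source word i
print(round(cost, 4))  # 0.1643
```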

opennmt.models.sequence_to_sequence.align_tokens_from_attention(tokens, attention)[source]

Returns aligned tokens from the attention.

Parameters:
  • tokens – The tokens on which the attention is applied as a string tf.Tensor of shape \([B, T_s]\).
  • attention – The attention vector of shape \([B, T_t, T_s]\).

Returns:The aligned tokens as a string tf.Tensor of shape \([B, T_t]\).

opennmt.models.sequence_to_sequence.replace_unknown_target(target_tokens, source_tokens, attention, unknown_token='<unk>')[source]

Replaces all target unknown tokens by the source token with the highest attention.

Parameters:
  • target_tokens – A string tf.Tensor of shape \([B, T_t]\).
  • source_tokens – A string tf.Tensor of shape \([B, T_s]\).
  • attention – The attention vector of shape \([B, T_t, T_s]\).
  • unknown_token – The target token to replace.

Returns:A string tf.Tensor with the same shape and type as target_tokens but with all instances of unknown_token replaced by the aligned source token.
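A pure-Python sketch for a single sequence (the real function is batched and operates on tf.Tensor; the tokens and attention weights below are made up):

```python
def replace_unknown_target(target_tokens, source_tokens, attention,
                           unknown_token="<unk>"):
    """Replaces each unknown target token by the most-attended source token."""
    result = []
    for token, weights in zip(target_tokens, attention):
        if token == unknown_token:
            best = max(range(len(weights)), key=weights.__getitem__)
            result.append(source_tokens[best])
        else:
            result.append(token)
    return result

print(replace_unknown_target(
    ["the", "<unk>", "cat"],
    ["le", "chat", "noir"],
    [[0.8, 0.1, 0.1],    # "the"   attends mostly to "le"
     [0.1, 0.2, 0.7],    # "<unk>" attends mostly to "noir"
     [0.1, 0.8, 0.1]]))  # "cat"   attends mostly to "chat"
# ['the', 'noir', 'cat']
```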