opennmt.models.sequence_classifier module

Sequence classifier.

class opennmt.models.sequence_classifier.SequenceClassifier(inputter, encoder, labels_vocabulary_file_key, encoding='average', daisy_chain_variables=False, name='seqclassifier')

Bases: opennmt.models.model.Model

A sequence classifier.

__init__(inputter, encoder, labels_vocabulary_file_key, encoding='average', daisy_chain_variables=False, name='seqclassifier')

Initializes a sequence classifier.

Parameters:
  • inputter – An opennmt.inputters.inputter.Inputter to process the input data.
  • encoder – An opennmt.encoders.encoder.Encoder to encode the input.
  • labels_vocabulary_file_key – The data configuration key of the labels vocabulary file containing one label per line.
  • encoding – “average” or “last” (case insensitive), the encoding vector to extract from the encoder outputs.
  • daisy_chain_variables – If True, copy variables in a daisy chain between devices for this model. Not compatible with RNN-based models.
  • name – The name of this model.
Raises:

ValueError – if encoding is invalid.
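As an illustration of the encoding parameter, the two options can be thought of as two ways of reducing the per-timestep encoder outputs to a single vector. The following is a minimal sketch in plain Python, not the library's implementation: the function name extract_encoding is hypothetical, and ignoring padded timesteps via a length argument is an assumption made for the example.

```python
def extract_encoding(outputs, length, encoding="average"):
    """Reduce per-timestep vectors to a single encoding vector.

    outputs: list of timestep vectors (lists of floats), possibly padded.
    length: number of valid (non-padding) timesteps (assumed >= 1).
    """
    mode = encoding.lower()  # the documented option is case insensitive
    if mode not in ("average", "last"):
        raise ValueError("Invalid encoding type: {}".format(encoding))

    valid = outputs[:length]  # assumption: padded steps are ignored
    if mode == "last":
        # Keep only the vector of the last valid timestep.
        return list(valid[-1])

    # Average each dimension over the valid timesteps.
    depth = len(valid[0])
    return [sum(step[d] for step in valid) / length for d in range(depth)]


outputs = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]  # third step is padding
print(extract_encoding(outputs, 2, "average"))  # [2.0, 3.0]
print(extract_encoding(outputs, 2, "LAST"))     # [3.0, 4.0]
```

Passing any other value, such as "max", raises ValueError in this sketch, mirroring the documented behavior for an invalid encoding.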

compute_loss(outputs, labels, training=True, params=None)

Computes the loss.

Parameters:
  • outputs – The model outputs (usually unnormalized scores, i.e. logits).
  • labels – The dict of label tf.Tensor objects.
  • training – If True, compute the training loss.
  • params – A dictionary of hyperparameters.
Returns:

The loss or a tuple containing the computed loss and the loss to display.
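For a classifier whose outputs are unnormalized scores, a common choice of loss is softmax cross-entropy against the index of the true label. This section does not state which loss the model uses, so the sketch below is an assumption for illustration only; the function name softmax_cross_entropy and the sample values are hypothetical.

```python
import math


def softmax_cross_entropy(logits, label_index):
    """Cross-entropy between softmax(logits) and a one-hot label."""
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    # -log(softmax(logits)[label_index])
    return log_sum_exp - logits[label_index]


batch_logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
batch_labels = [0, 2]  # index of the true class for each example

losses = [
    softmax_cross_entropy(logits, label)
    for logits, label in zip(batch_logits, batch_labels)
]
mean_loss = sum(losses) / len(losses)  # average over the batch
print(mean_loss)
```

In the second example the scores are uniform over three classes, so its loss equals log(3) regardless of the label.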

compute_metrics(predictions, labels)

Computes additional metrics on the predictions.

Parameters:
  • predictions – The model predictions.
  • labels – The dict of label tf.Tensor objects.
Returns:

A dict of metrics. See the eval_metric_ops field of tf.estimator.EstimatorSpec.
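A typical additional metric for classification is accuracy, the fraction of predictions that match the reference label. Which metrics this model actually reports is not stated here, so the following plain-Python sketch is only an assumption about what such a metric computes; the accuracy helper and the sample labels are hypothetical.

```python
def accuracy(predicted, gold):
    """Fraction of predicted classes equal to the reference classes."""
    if len(predicted) != len(gold):
        raise ValueError("predicted and gold must have the same length")
    if not gold:
        return 0.0  # convention chosen for this sketch on empty input
    correct = sum(1 for p, g in zip(predicted, gold) if p == g)
    return correct / len(gold)


predicted = ["positive", "negative", "neutral", "positive"]
gold = ["positive", "negative", "positive", "positive"]

metrics = {"accuracy": accuracy(predicted, gold)}
print(metrics)  # {'accuracy': 0.75}
```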

print_prediction(prediction, params=None, stream=None)

Prints the model prediction.

Parameters:
  • prediction – The evaluated prediction.
  • params – (optional) Dictionary of formatting parameters.
  • stream – (optional) The stream to print to.

class opennmt.models.sequence_classifier.ClassInputter(vocabulary_file_key)

Bases: opennmt.inputters.text_inputter.TextInputter

Reads classes from a text file.

make_features(element=None, features=None, training=None)

Tokenizes raw text.