# Source code for onmt.ModelConstructor

"""
This file handles model creation: it consults the options
and creates each encoder and decoder accordingly.
"""
import torch.nn as nn

import onmt
import onmt.Models
import onmt.modules
from onmt.Models import NMTModel, MeanEncoder, RNNEncoder, \
                        StdRNNDecoder, InputFeedRNNDecoder
from onmt.modules import Embeddings, ImageEncoder, CopyGenerator, \
                         TransformerEncoder, TransformerDecoder, \
                         CNNEncoder, CNNDecoder, AudioEncoder


def make_embeddings(opt, word_dict, feature_dicts, for_encoder=True):
    """
    Build the Embeddings module for one side of the model.

    Args:
        opt: the option in current environment.
        word_dict(Vocab): words dictionary.
        feature_dicts([Vocab]): a list of feature dictionaries.
        for_encoder(bool): if True use `opt.src_word_vec_size` as the
            embedding size, otherwise `opt.tgt_word_vec_size`.

    Returns:
        the constructed Embeddings instance.
    """
    embedding_dim = (opt.src_word_vec_size if for_encoder
                     else opt.tgt_word_vec_size)

    # The same padding token is looked up in the word vocab and in every
    # feature vocab.
    pad_token = onmt.IO.PAD_WORD

    word_padding_idx = word_dict.stoi[pad_token]
    num_word_embeddings = len(word_dict)

    feats_padding_idx = [vocab.stoi[pad_token] for vocab in feature_dicts]
    num_feat_embeddings = list(map(len, feature_dicts))

    return Embeddings(
        embedding_dim,
        opt.position_encoding,
        opt.feat_merge,
        opt.feat_vec_exponent,
        opt.feat_vec_size,
        opt.dropout,
        word_padding_idx,
        feats_padding_idx,
        num_word_embeddings,
        num_feat_embeddings,
    )
def make_encoder(opt, embeddings):
    """
    Pick and construct the encoder selected by `opt.encoder_type`.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.

    Returns:
        a TransformerEncoder, CNNEncoder or MeanEncoder when
        `opt.encoder_type` is "transformer", "cnn" or "mean" respectively;
        an RNNEncoder for any other value.
    """
    if opt.encoder_type == "transformer":
        return TransformerEncoder(
            opt.enc_layers, opt.rnn_size, opt.dropout, embeddings)

    if opt.encoder_type == "cnn":
        return CNNEncoder(
            opt.enc_layers, opt.rnn_size, opt.cnn_kernel_width,
            opt.dropout, embeddings)

    if opt.encoder_type == "mean":
        return MeanEncoder(opt.enc_layers, embeddings)

    # Everything else ("rnn" or "brnn") is handled by the RNN encoder.
    return RNNEncoder(
        opt.rnn_type, opt.brnn, opt.enc_layers, opt.rnn_size,
        opt.dropout, embeddings)
def make_decoder(opt, embeddings):
    """
    Pick and construct the decoder selected by `opt.decoder_type`.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        a TransformerDecoder or CNNDecoder when `opt.decoder_type` is
        "transformer" or "cnn"; otherwise an InputFeedRNNDecoder if
        `opt.input_feed` is truthy, else a StdRNNDecoder.
    """
    if opt.decoder_type == "transformer":
        return TransformerDecoder(
            opt.dec_layers, opt.rnn_size, opt.global_attention,
            opt.copy_attn, opt.dropout, embeddings)

    if opt.decoder_type == "cnn":
        return CNNDecoder(
            opt.dec_layers, opt.rnn_size, opt.global_attention,
            opt.copy_attn, opt.cnn_kernel_width, opt.dropout, embeddings)

    # Both RNN decoder flavours take exactly the same constructor
    # arguments; only the class differs.
    rnn_decoder_cls = InputFeedRNNDecoder if opt.input_feed else StdRNNDecoder
    return rnn_decoder_cls(
        opt.rnn_type, opt.brnn, opt.dec_layers, opt.rnn_size,
        opt.global_attention, opt.coverage_attn, opt.context_gate,
        opt.copy_attn, opt.dropout, embeddings)
def make_base_model(model_opt, fields, gpu, checkpoint=None):
    """
    Build the complete model: encoder, decoder and generator.

    Args:
        model_opt: the option loaded from checkpoint.
        fields: `Field` objects for the model.
        gpu(bool): whether to use gpu.
        checkpoint: the model generated by train phase, or a resumed snapshot
                    model from a stopped training.
    Returns:
        the NMTModel.
    Raises:
        AssertionError: if `model_opt.model_type` is not one of "text",
            "img" or "audio", or if `model_opt.share_embeddings` is set
            for a model type other than "text".
    """
    # Raised explicitly (instead of via `assert`) so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if model_opt.model_type not in ["text", "img", "audio"]:
        raise AssertionError("Unsupported model type %s"
                             % (model_opt.model_type))

    # Only the "text" branch below builds source word embeddings. Fail
    # early with a clear message instead of hitting a NameError on
    # `src_embeddings` after the encoder has already been constructed.
    if model_opt.share_embeddings and model_opt.model_type != "text":
        raise AssertionError(
            "share_embeddings requires model type text (got %s): "
            "there are no source word embeddings to share"
            % (model_opt.model_type))

    # Make encoder.
    src_embeddings = None
    if model_opt.model_type == "text":
        src_dict = fields["src"].vocab
        feature_dicts = onmt.IO.collect_feature_dicts(fields, 'src')
        src_embeddings = make_embeddings(model_opt, src_dict,
                                         feature_dicts)
        encoder = make_encoder(model_opt, src_embeddings)
    elif model_opt.model_type == "img":
        encoder = ImageEncoder(model_opt.enc_layers,
                               model_opt.brnn,
                               model_opt.rnn_size,
                               model_opt.dropout)
    elif model_opt.model_type == "audio":
        encoder = AudioEncoder(model_opt.enc_layers,
                               model_opt.brnn,
                               model_opt.rnn_size,
                               model_opt.dropout,
                               model_opt.sample_rate,
                               model_opt.window_size)

    # Make decoder.
    tgt_dict = fields["tgt"].vocab
    # TODO: prepare for a future where tgt features are possible.
    feature_dicts = onmt.IO.collect_feature_dicts(fields, 'tgt')
    tgt_embeddings = make_embeddings(model_opt, tgt_dict,
                                     feature_dicts, for_encoder=False)

    # Share the embedding matrix - preprocess with share_vocab required
    if model_opt.share_embeddings:
        tgt_embeddings.word_lut.weight = src_embeddings.word_lut.weight

    decoder = make_decoder(model_opt, tgt_embeddings)

    # Make NMTModel(= encoder + decoder).
    model = NMTModel(encoder, decoder)
    model.model_type = model_opt.model_type

    # Make Generator.
    if not model_opt.copy_attn:
        # NOTE(review): LogSoftmax is built without an explicit `dim`;
        # newer PyTorch releases warn about the implicit choice. Confirm
        # the minimum supported torch version before adding `dim=-1`.
        generator = nn.Sequential(
            nn.Linear(model_opt.rnn_size, len(fields["tgt"].vocab)),
            nn.LogSoftmax())
        if model_opt.share_decoder_embeddings:
            # Tie the output projection to the decoder word embeddings.
            generator[0].weight = decoder.embeddings.word_lut.weight
    else:
        generator = CopyGenerator(model_opt, fields["src"].vocab,
                                  fields["tgt"].vocab)

    # Load the model states from checkpoint or initialize them.
    if checkpoint is not None:
        print('Loading model parameters.')
        model.load_state_dict(checkpoint['model'])
        generator.load_state_dict(checkpoint['generator'])
    else:
        if model_opt.param_init != 0.0:
            print('Intializing model parameters.')
            for p in model.parameters():
                p.data.uniform_(-model_opt.param_init, model_opt.param_init)
            for p in generator.parameters():
                p.data.uniform_(-model_opt.param_init, model_opt.param_init)
        # Pretrained vectors are loaded after the uniform init (so they
        # are not overwritten by it) and only when there is no checkpoint
        # to restore from.
        if hasattr(model.encoder, 'embeddings'):
            model.encoder.embeddings.load_pretrained_vectors(
                model_opt.pre_word_vecs_enc, model_opt.fix_word_vecs_enc)
        if hasattr(model.decoder, 'embeddings'):
            model.decoder.embeddings.load_pretrained_vectors(
                model_opt.pre_word_vecs_dec, model_opt.fix_word_vecs_dec)

    # Add generator to model (this registers it as parameter of model).
    model.generator = generator

    # Make the whole model leverage GPU if indicated to do so.
    if gpu:
        model.cuda()
    else:
        model.cpu()

    return model