Commit e34bd2c0 authored by Vincent Nguyen, committed by GitHub

Revert "fix onmt as library example (#1292)"

This reverts commit 492f0cf9.
Showing with 120 additions and 137 deletions
For this example, we will assume that we have run `preprocess.py` to create our datasets. For instance:

> python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 10000 -tgt_vocab_size 10000
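With `-save_data data/data` and the default sharding, preprocessing leaves behind a vocabulary file and one shard per split; the names below are simply the paths this example loads later:

```
data/data.vocab.pt
data/data.train.0.pt
data/data.valid.0.pt
```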
We begin by loading in the vocabulary for the model of interest. This will let us check the vocabulary sizes and look up the padding indices.
```python
# Imports assumed by the snippets below.
import torch
import torch.nn as nn
import onmt

vocab_fields = torch.load("data/data.vocab.pt")

src_text_field = vocab_fields["src"].base_field
src_vocab = src_text_field.vocab
src_padding = src_vocab.stoi[src_text_field.pad_token]

tgt_text_field = vocab_fields["tgt"].base_field
tgt_vocab = tgt_text_field.vocab
tgt_padding = tgt_vocab.stoi[tgt_text_field.pad_token]
```
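As a quick sanity check (not part of the original example), you can confirm the vocabulary sizes and padding indices that the rest of the script relies on:

```python
# Quick sanity check on the loaded fields (illustrative only).
print("src vocab size:", len(src_vocab))   # -src_vocab_size plus special tokens
print("tgt vocab size:", len(tgt_vocab))   # -tgt_vocab_size plus special tokens
print("src padding index:", src_padding)
print("tgt padding index:", tgt_padding)
```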
Next we specify the core model itself. Here we will build a small model with an encoder and an attention-based input-feeding decoder. Both will be RNNs, and the encoder will be bidirectional.
```python
emb_size = 100
rnn_size = 500

# Specify the core model.
encoder_embeddings = onmt.modules.Embeddings(emb_size, len(src_vocab),
                                             word_padding_idx=src_padding)
encoder = onmt.encoders.RNNEncoder(hidden_size=rnn_size, num_layers=1,
                                   rnn_type="LSTM", bidirectional=True,
                                   embeddings=encoder_embeddings)

decoder_embeddings = onmt.modules.Embeddings(emb_size, len(tgt_vocab),
                                             word_padding_idx=tgt_padding)
decoder = onmt.decoders.decoder.InputFeedRNNDecoder(
    hidden_size=rnn_size, num_layers=1, bidirectional_encoder=True,
    rnn_type="LSTM", embeddings=decoder_embeddings)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = onmt.models.model.NMTModel(encoder, decoder)
model.to(device)

# Specify the tgt word generator and loss computation module
model.generator = nn.Sequential(
    nn.Linear(rnn_size, len(tgt_vocab)),
    nn.LogSoftmax(dim=-1))
model.generator.to(device)  # keep the generator on the same device as the model

loss = onmt.utils.loss.NMTLossCompute(
    criterion=nn.NLLLoss(ignore_index=tgt_padding, reduction="sum"),
    generator=model.generator)
```
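Before training, it can be worth checking how big the model actually is; this plain-PyTorch count is an addition to the original walkthrough:

```python
# Count trainable parameters (plain PyTorch; illustrative only).
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("trainable parameters:", n_params)
```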
Now we set up the optimizer. Our wrapper around a core torch optim class handles learning rate updates and gradient normalization automatically.
```python
lr = 1
torch_optimizer = torch.optim.SGD(model.parameters(), lr=lr)
optim = onmt.utils.optimizers.Optimizer(
    torch_optimizer, learning_rate=lr, max_grad_norm=2)
```
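The wrapper accepts any torch optimizer. If plain SGD with `lr = 1` is too aggressive for your data, one variation (the values here are only an example, not from the original) is to wrap Adam the same way:

```python
# Same wrapper, different inner optimizer (illustrative values).
lr = 0.001
torch_optimizer = torch.optim.Adam(model.parameters(), lr=lr)
optim = onmt.utils.optimizers.Optimizer(
    torch_optimizer, learning_rate=lr, max_grad_norm=2)
```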
Now we load the data from disk with the associated vocab fields. To iterate through the data itself we use a wrapper around a torchtext iterator class. We specify one for both the training and validation data.
```python
# Load some data
from itertools import chain
train_data_file = "data/data.train.0.pt"
valid_data_file = "data/data.valid.0.pt"
train_iter = onmt.inputters.inputter.DatasetLazyIter(dataset_paths=[train_data_file],
                                                     fields=vocab_fields,
                                                     batch_size=50,
                                                     batch_size_multiple=1,
                                                     batch_size_fn=None,
                                                     device=device,
                                                     is_train=True,
                                                     repeat=True)

valid_iter = onmt.inputters.inputter.DatasetLazyIter(dataset_paths=[valid_data_file],
                                                     fields=vocab_fields,
                                                     batch_size=10,
                                                     batch_size_multiple=1,
                                                     batch_size_fn=None,
                                                     device=device,
                                                     is_train=False,
                                                     repeat=False)
```
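If you want to see what the iterator yields before handing it to the trainer, you can pull a single batch; the attributes below assume the standard torchtext `Batch` interface:

```python
# Peek at one batch (illustrative only).
batch = next(iter(valid_iter))
print(type(batch))        # a torchtext-style Batch object
print(batch.batch_size)   # number of examples in this batch
```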
Finally we train. Keeping track of the output requires a report manager.
```python
report_manager = onmt.utils.ReportMgr(
    report_every=50, start_time=None, tensorboard_writer=None)

trainer = onmt.Trainer(model=model,
                       train_loss=loss,
                       valid_loss=loss,
                       optim=optim,
                       report_manager=report_manager)

trainer.train(train_iter=train_iter,
              train_steps=400,
              valid_iter=valid_iter,
              valid_steps=200)
```
```
[2019-02-15 16:34:17,475 INFO] Start training loop and validate every 200 steps...
[2019-02-15 16:34:17,601 INFO] Loading dataset from data/data.train.0.pt, number of examples: 10000
[2019-02-15 16:35:43,873 INFO] Step 50/ 400; acc: 11.54; ppl: 1714.07; xent: 7.45; lr: 1.00000; 662/656 tok/s; 86 sec
[2019-02-15 16:37:05,965 INFO] Step 100/ 400; acc: 13.75; ppl: 534.80; xent: 6.28; lr: 1.00000; 675/671 tok/s; 168 sec
[2019-02-15 16:38:31,289 INFO] Step 150/ 400; acc: 15.02; ppl: 439.96; xent: 6.09; lr: 1.00000; 675/668 tok/s; 254 sec
[2019-02-15 16:39:56,715 INFO] Step 200/ 400; acc: 16.08; ppl: 357.62; xent: 5.88; lr: 1.00000; 642/647 tok/s; 339 sec
[2019-02-15 16:39:56,811 INFO] Loading dataset from data/data.valid.0.pt, number of examples: 3000
[2019-02-15 16:41:13,415 INFO] Validation perplexity: 208.73
[2019-02-15 16:41:13,415 INFO] Validation accuracy: 23.3507
[2019-02-15 16:41:13,567 INFO] Loading dataset from data/data.train.0.pt, number of examples: 10000
[2019-02-15 16:42:41,562 INFO] Step 250/ 400; acc: 17.07; ppl: 310.41; xent: 5.74; lr: 1.00000; 347/344 tok/s; 504 sec
[2019-02-15 16:44:04,899 INFO] Step 300/ 400; acc: 19.17; ppl: 262.81; xent: 5.57; lr: 1.00000; 665/661 tok/s; 587 sec
[2019-02-15 16:45:33,653 INFO] Step 350/ 400; acc: 19.38; ppl: 244.81; xent: 5.50; lr: 1.00000; 649/642 tok/s; 676 sec
[2019-02-15 16:47:06,141 INFO] Step 400/ 400; acc: 20.44; ppl: 214.75; xent: 5.37; lr: 1.00000; 593/598 tok/s; 769 sec
[2019-02-15 16:47:06,265 INFO] Loading dataset from data/data.valid.0.pt, number of examples: 3000
[2019-02-15 16:48:27,328 INFO] Validation perplexity: 150.277
[2019-02-15 16:48:27,328 INFO] Validation accuracy: 24.2132
```
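The example does not save what it trains. If you want to keep the weights for later, one minimal option is a plain PyTorch `state_dict` (this is not the full OpenNMT-py checkpoint format that `train.py` writes, and the path is only an example):

```python
# Persist the trained weights (plain state_dict; illustrative path).
torch.save({"model": model.state_dict(),
            "generator": model.generator.state_dict()},
           "data/library_example_model.pt")
```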
To use the model, we need to load up the translation functions. A Translator object requires the vocab fields, readers for the source and target, and a global scorer.
```python
import onmt.translate

src_reader = onmt.inputters.str2reader["text"]
tgt_reader = onmt.inputters.str2reader["text"]
scorer = onmt.translate.GNMTGlobalScorer(alpha=0.7,
                                         beta=0.,
                                         length_penalty="avg",
                                         coverage_penalty="none")
gpu = 0 if torch.cuda.is_available() else -1
translator = onmt.translate.Translator(model=model,
                                       fields=vocab_fields,
                                       src_reader=src_reader,
                                       tgt_reader=tgt_reader,
                                       global_scorer=scorer,
                                       gpu=gpu)
builder = onmt.translate.TranslationBuilder(data=torch.load(valid_data_file),
                                            fields=vocab_fields)

for batch in valid_iter:
    trans_batch = translator.translate_batch(
        batch=batch, src_vocabs=[src_vocab],
        attn_debug=False)
    translations = builder.from_batch(trans_batch)
    for trans in translations:
        print(trans.log(0))
    break
```
```
[2019-02-15 16:48:27,419 INFO] Loading dataset from data/data.valid.0.pt, number of examples: 3000

PRED SCORE: -4.0690
SENT 0: ('The', 'competitors', 'have', 'other', 'advantages', ',', 'too', '.')
PRED 0: .

PRED SCORE: -1.0983
SENT 0: ('The', 'company', '&apos;s', 'durability', 'goes', 'back', 'to', 'its', 'first', 'boss', ',', 'a', 'visionary', ',', 'Thomas', 'J.', 'Watson', 'Sr.')
PRED 0: .

PRED SCORE: -1.5950
SENT 0: ('&quot;', 'From', 'what', 'we', 'know', 'today', ',', 'you', 'have', 'to', 'ask', 'how', 'I', 'could', 'be', 'so', 'wrong', '.', '&quot;')
PRED 0: .

PRED SCORE: -1.5128
SENT 0: ('Boeing', 'Co', 'shares', 'rose', '1.5%', 'to', '$', '67.94', '.')
PRED 0: .

PRED SCORE: -1.5578
SENT 0: ('Some', 'did', 'not', 'believe', 'him', ',', 'they', 'said', 'that', 'he', 'got', 'dizzy', 'even', 'in', 'the', 'truck', ',', 'but', 'always', 'wanted', 'to', 'fulfill', 'his', 'dream', ',', 'that', 'of', 'becoming', 'a', 'pilot', '.')
PRED 0: .

PRED SCORE: -0.9623
SENT 0: ('In', 'your', 'opinion', ',', 'the', 'council', 'should', 'ensure', 'that', 'the', 'band', 'immediately', 'above', 'the', 'Ronda', 'de', 'Dalt', 'should', 'provide', 'in', 'its', 'entirety', ',', 'an', 'area', 'of', 'equipment', 'to', 'conduct', 'a', 'smooth', 'transition', 'between', 'the', 'city', 'and', 'the', 'green', '.')
PRED 0: .

PRED SCORE: -0.8703
SENT 0: ('The', 'clerk', 'of', 'the', 'court', ',', 'Jorge', 'Yanez', ',', 'went', 'to', 'the', 'jail', 'of', 'the', 'municipality', 'of', 'San', 'Nicolas', 'of', 'Garza', 'to', 'notify', 'Jonah', 'that', 'he', 'has', 'been', 'legally', 'pardoned', 'and', 'his', 'record', 'will', 'be', 'filed', '.')
PRED 0: .

PRED SCORE: -1.4778
SENT 0: ('&quot;', 'In', 'a', 'research', 'it', 'is', 'reported', 'that', 'there', 'are', 'no', 'parts', 'or', 'components', 'of', 'the', 'ship', 'in', 'another', 'place', ',', 'the', 'impact', 'is', 'presented', 'in', 'a', 'structural', 'way', '.')
PRED 0: .

PRED SCORE: -1.3341
SENT 0: ('On', 'the', 'asphalt', 'covering', ',', 'he', 'added', ',', 'is', 'placed', 'a', 'final', 'layer', 'called', 'rolling', 'covering', ',', 'which', 'is', 'made', '\u200b', '\u200b', 'of', 'a', 'fine', 'stone', 'material', ',', 'meaning', 'sand', 'also', 'dipped', 'into', 'the', 'asphalt', '.')
PRED 0: .

PRED SCORE: -1.5192
SENT 0: ('This', 'is', '200', 'bar', 'on', 'leaving', 'and', '100', 'bar', 'on', 'arrival', '.')
PRED 0: .

...
```
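Besides the formatted log output, each Translation object exposes its hypotheses directly; the `pred_sents` and `pred_scores` attribute names below are taken from `onmt.translate`'s Translation objects and should be treated as an assumption:

```python
# Inspect the best hypothesis of the first example directly (illustrative).
best = translations[0]
print(" ".join(best.pred_sents[0]))   # best prediction as a plain string
print(best.pred_scores[0])            # its score
```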