Unverified Commit e723f2aa authored by Vincent Nguyen's avatar Vincent Nguyen Committed by GitHub
Browse files

Merge pull request #926 from vince62s/fix-bptt

Fix truncated BPTT (cf. issue #891)
Showing with 25 additions and 11 deletions
+25 -11
......@@ -254,7 +254,7 @@ class Trainer(object):
for batch in true_batchs:
target_size = batch.tgt.size(0)
# Truncated BPTT
# Truncated BPTT: reminder not compatible with accum > 1
if self.trunc_size:
trunc_size = self.trunc_size
else:
......@@ -287,20 +287,31 @@ class Trainer(object):
total_stats.update(batch_stats)
report_stats.update(batch_stats)
# 4. Update the parameters and statistics.
if self.grad_accum_count == 1:
# Multi GPU gradient gather
if self.n_gpu > 1:
grads = [p.grad.data for p in self.model.parameters()
if p.requires_grad
and p.grad is not None]
onmt.utils.distributed.all_reduce_and_rescale_tensors(
grads, float(1))
self.optim.step()
# If truncated, don't backprop fully.
if dec_state is not None:
dec_state.detach()
# 3.bis Multi GPU gradient gather
if self.n_gpu > 1:
grads = [p.grad.data for p in self.model.parameters()
if p.requires_grad
and p.grad is not None]
onmt.utils.distributed.all_reduce_and_rescale_tensors(
grads, float(1))
# 4. Update the parameters and statistics.
self.optim.step()
# in case of multi step gradient accumulation,
# update only after accum batches
if self.grad_accum_count > 1:
if self.n_gpu > 1:
grads = [p.grad.data for p in self.model.parameters()
if p.requires_grad
and p.grad is not None]
onmt.utils.distributed.all_reduce_and_rescale_tensors(
grads, float(1))
self.optim.step()
def _start_report_manager(self, start_time=None):
"""
......
......@@ -18,6 +18,9 @@ def main(opt):
if opt.epochs:
raise AssertionError("-epochs is deprecated please use -train_steps.")
if opt.truncated_decoder > 0 and opt.accum_count > 1:
raise AssertionError("BPTT is not compatible with -accum > 1")
if len(opt.gpuid) > 1:
multi_main(opt)
else:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment