#!/usr/bin/env python
""" Translator Class and builder """
from __future__ import print_function
import codecs
import os
import math
import time
from itertools import count
import torch
import onmt.model_builder
import onmt.translate.beam
import onmt.inputters as inputters
import onmt.decoders.ensemble
from onmt.translate.beam_search import BeamSearch
from onmt.translate.random_sampling import RandomSampling
from onmt.utils.misc import tile, set_random_seed
from onmt.modules.copy_generator import collapse_copy_scores
def build_translator(opt, report_score=True, logger=None, out_file=None):
if out_file is None:
out_file = codecs.open(opt.output, 'w+', 'utf-8')
load_test_model = onmt.decoders.ensemble.load_test_model \
if len(opt.models) > 1 else onmt.model_builder.load_test_model
fields, model, model_opt = load_test_model(opt)
scorer = onmt.translate.GNMTGlobalScorer.from_opt(opt)
translator = Translator.from_opt(
model,
fields,
opt,
model_opt,
global_scorer=scorer,
out_file=out_file,
report_score=report_score,
logger=logger
)
return translator
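

# A minimal usage sketch (hedged): assumes ``opt`` is the argparse.Namespace
# produced by OpenNMT's translate option parser; the file name below is
# illustrative only.
#
#     translator = build_translator(opt, report_score=True)
#     scores, predictions = translator.translate(
#         src=open("test.src", "rb"), batch_size=30)
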
class Translator(object):
"""Translate a batch of sentences with a saved model.
Args:
model (onmt.modules.NMTModel): NMT model to use for translation
fields (dict[str, torchtext.data.Field]): A dict
mapping each side to its list of name-Field pairs.
src_reader (onmt.inputters.DataReaderBase): Source reader.
tgt_reader (onmt.inputters.TextDataReader): Target reader.
gpu (int): GPU device. Set to negative for no GPU.
        n_best (int): Number of best translations to return for each input.
min_length (int): See
:class:`onmt.translate.decode_strategy.DecodeStrategy`.
max_length (int): See
:class:`onmt.translate.decode_strategy.DecodeStrategy`.
        ratio (float): See
            :class:`onmt.translate.beam_search.BeamSearch`.
        beam_size (int): Number of beams.
        random_sampling_topk (int): See
            :class:`onmt.translate.random_sampling.RandomSampling`.
        random_sampling_temp (float): See
            :class:`onmt.translate.random_sampling.RandomSampling`.
stepwise_penalty (bool): Whether coverage penalty is applied every step
or not.
        dump_beam (str): If set, path to a JSON file where the beam trace
            is dumped for debugging.
block_ngram_repeat (int): See
:class:`onmt.translate.decode_strategy.DecodeStrategy`.
ignore_when_blocking (set or frozenset): See
:class:`onmt.translate.decode_strategy.DecodeStrategy`.
        replace_unk (bool): Replace unknown tokens with the source token
            that received the highest attention weight.
        phrase_table (str): Optional path to a phrase table used to look
            up replacements for unknown tokens.
data_type (str): Source data type.
verbose (bool): Print/log every translation.
        report_bleu (bool): Print/log BLEU metric.
        report_rouge (bool): Print/log ROUGE metric.
report_time (bool): Print/log total time/frequency.
copy_attn (bool): Use copy attention.
global_scorer (onmt.translate.GNMTGlobalScorer): Translation
scoring/reranking object.
out_file (TextIO or codecs.StreamReaderWriter): Output file.
        report_score (bool): Whether to report scores.
logger (logging.Logger or NoneType): Logger.
"""
def __init__(
self,
model,
fields,
src_reader,
tgt_reader,
gpu=-1,
n_best=1,
min_length=0,
max_length=100,
ratio=0.,
beam_size=30,
random_sampling_topk=1,
random_sampling_temp=1,
stepwise_penalty=None,
            dump_beam="",
block_ngram_repeat=0,
ignore_when_blocking=frozenset(),
replace_unk=False,
phrase_table="",
data_type="text",
verbose=False,
report_bleu=False,
report_rouge=False,
report_time=False,
copy_attn=False,
global_scorer=None,
out_file=None,
report_score=True,
logger=None,
seed=-1):
self.model = model
self.fields = fields
tgt_field = dict(self.fields)["tgt"].base_field
self._tgt_vocab = tgt_field.vocab
self._tgt_eos_idx = self._tgt_vocab.stoi[tgt_field.eos_token]
self._tgt_pad_idx = self._tgt_vocab.stoi[tgt_field.pad_token]
self._tgt_bos_idx = self._tgt_vocab.stoi[tgt_field.init_token]
self._tgt_unk_idx = self._tgt_vocab.stoi[tgt_field.unk_token]
self._tgt_vocab_len = len(self._tgt_vocab)
self._gpu = gpu
self._use_cuda = gpu > -1
self._dev = torch.device("cuda", self._gpu) \
if self._use_cuda else torch.device("cpu")
self.n_best = n_best
self.max_length = max_length
self.beam_size = beam_size
self.random_sampling_temp = random_sampling_temp
self.sample_from_topk = random_sampling_topk
self.min_length = min_length
self.ratio = ratio
self.stepwise_penalty = stepwise_penalty
self.dump_beam = dump_beam
self.block_ngram_repeat = block_ngram_repeat
self.ignore_when_blocking = ignore_when_blocking
self._exclusion_idxs = {
self._tgt_vocab.stoi[t] for t in self.ignore_when_blocking}
self.src_reader = src_reader
self.tgt_reader = tgt_reader
self.replace_unk = replace_unk
if self.replace_unk and not self.model.decoder.attentional:
raise ValueError(
"replace_unk requires an attentional decoder.")
self.phrase_table = phrase_table
self.data_type = data_type
self.verbose = verbose
self.report_bleu = report_bleu
self.report_rouge = report_rouge
self.report_time = report_time
self.copy_attn = copy_attn
self.global_scorer = global_scorer
        if self.global_scorer is not None and \
                self.global_scorer.has_cov_pen and \
                not self.model.decoder.attentional:
            raise ValueError(
                "Coverage penalty requires an attentional decoder.")
self.out_file = out_file
self.report_score = report_score
self.logger = logger
self.use_filter_pred = False
self._filter_pred = None
# for debugging
self.beam_trace = self.dump_beam != ""
self.beam_accum = None
if self.beam_trace:
self.beam_accum = {
"predicted_ids": [],
"beam_parent_ids": [],
"scores": [],
"log_probs": []}
set_random_seed(seed, self._use_cuda)
    @classmethod
def from_opt(
cls,
model,
fields,
opt,
model_opt,
global_scorer=None,
out_file=None,
report_score=True,
logger=None):
"""Alternate constructor.
Args:
model (onmt.modules.NMTModel): See :func:`__init__()`.
fields (dict[str, torchtext.data.Field]): See
:func:`__init__()`.
opt (argparse.Namespace): Command line options
model_opt (argparse.Namespace): Command line options saved with
the model checkpoint.
            global_scorer (onmt.translate.GNMTGlobalScorer): See
                :func:`__init__()`.
            out_file (TextIO or codecs.StreamReaderWriter): See
                :func:`__init__()`.
            report_score (bool): See :func:`__init__()`.
            logger (logging.Logger or NoneType): See :func:`__init__()`.
"""
src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
tgt_reader = inputters.str2reader["text"].from_opt(opt)
return cls(
model,
fields,
src_reader,
tgt_reader,
gpu=opt.gpu,
n_best=opt.n_best,
min_length=opt.min_length,
max_length=opt.max_length,
ratio=opt.ratio,
beam_size=opt.beam_size,
random_sampling_topk=opt.random_sampling_topk,
random_sampling_temp=opt.random_sampling_temp,
stepwise_penalty=opt.stepwise_penalty,
dump_beam=opt.dump_beam,
block_ngram_repeat=opt.block_ngram_repeat,
ignore_when_blocking=set(opt.ignore_when_blocking),
replace_unk=opt.replace_unk,
phrase_table=opt.phrase_table,
data_type=opt.data_type,
verbose=opt.verbose,
report_bleu=opt.report_bleu,
report_rouge=opt.report_rouge,
report_time=opt.report_time,
copy_attn=model_opt.copy_attn,
global_scorer=global_scorer,
out_file=out_file,
report_score=report_score,
logger=logger,
seed=opt.seed)
def _log(self, msg):
if self.logger:
self.logger.info(msg)
else:
print(msg)
def _gold_score(self, batch, memory_bank, src_lengths, src_vocabs,
use_src_map, enc_states, batch_size, src):
if "tgt" in batch.__dict__:
gs = self._score_target(
batch, memory_bank, src_lengths, src_vocabs,
batch.src_map if use_src_map else None)
self.model.decoder.init_state(src, memory_bank, enc_states)
else:
gs = [0] * batch_size
return gs
    def translate(
self,
src,
tgt=None,
src_dir=None,
batch_size=None,
attn_debug=False,
phrase_table=""):
"""Translate content of ``src`` and get gold scores from ``tgt``.
Args:
src: See :func:`self.src_reader.read()`.
tgt: See :func:`self.tgt_reader.read()`.
src_dir: See :func:`self.src_reader.read()` (only relevant
for certain types of data).
batch_size (int): size of examples per mini-batch
attn_debug (bool): enables the attention logging
Returns:
(`list`, `list`)
* all_scores is a list of `batch_size` lists of `n_best` scores
* all_predictions is a list of `batch_size` lists
of `n_best` predictions
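
        Example (a hedged sketch; assumes ``translator`` was built via
        :func:`build_translator` and that the file name, which is
        illustrative, points to data ``self.src_reader.read()`` accepts)::

            scores, predictions = translator.translate(
                src=open("test.src", "rb"), batch_size=30)
            # predictions[i] holds the n_best hypotheses for sentence i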
"""
if batch_size is None:
raise ValueError("batch_size must be set")
data = inputters.Dataset(
self.fields,
readers=([self.src_reader, self.tgt_reader]
if tgt else [self.src_reader]),
data=[("src", src), ("tgt", tgt)] if tgt else [("src", src)],
dirs=[src_dir, None] if tgt else [src_dir],
sort_key=inputters.str2sortkey[self.data_type],
filter_pred=self._filter_pred
)
data_iter = inputters.OrderedIterator(
dataset=data,
device=self._dev,
batch_size=batch_size,
train=False,
sort=False,
sort_within_batch=True,
shuffle=False
)
xlation_builder = onmt.translate.TranslationBuilder(
data, self.fields, self.n_best, self.replace_unk, tgt,
self.phrase_table
)
# Statistics
counter = count(1)
pred_score_total, pred_words_total = 0, 0
gold_score_total, gold_words_total = 0, 0
all_scores = []
all_predictions = []
start_time = time.time()
for batch in data_iter:
batch_data = self.translate_batch(
batch, data.src_vocabs, attn_debug
)
translations = xlation_builder.from_batch(batch_data)
for trans in translations:
all_scores += [trans.pred_scores[:self.n_best]]
pred_score_total += trans.pred_scores[0]
pred_words_total += len(trans.pred_sents[0])
if tgt is not None:
gold_score_total += trans.gold_score
gold_words_total += len(trans.gold_sent) + 1
n_best_preds = [" ".join(pred)
for pred in trans.pred_sents[:self.n_best]]
all_predictions += [n_best_preds]
self.out_file.write('\n'.join(n_best_preds) + '\n')
self.out_file.flush()
if self.verbose:
sent_number = next(counter)
output = trans.log(sent_number)
if self.logger:
self.logger.info(output)
else:
os.write(1, output.encode('utf-8'))
if attn_debug:
preds = trans.pred_sents[0]
preds.append('</s>')
attns = trans.attns[0].tolist()
if self.data_type == 'text':
srcs = trans.src_raw
else:
srcs = [str(item) for item in range(len(attns[0]))]
header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
output = header_format.format("", *srcs) + '\n'
for word, row in zip(preds, attns):
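                        # Mark the column with the highest attention weight:
                        # star the first max_index+1 format specs, then
                        # un-star the first max_index, leaving one '*' fill.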
max_index = row.index(max(row))
row_format = row_format.replace(
"{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
row_format = row_format.replace(
"{:*>10.7f} ", "{:>10.7f} ", max_index)
output += row_format.format(word, *row) + '\n'
row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
if self.logger:
self.logger.info(output)
else:
os.write(1, output.encode('utf-8'))
end_time = time.time()
if self.report_score:
msg = self._report_score('PRED', pred_score_total,
pred_words_total)
self._log(msg)
if tgt is not None:
msg = self._report_score('GOLD', gold_score_total,
gold_words_total)
self._log(msg)
if self.report_bleu:
msg = self._report_bleu(tgt)
self._log(msg)
if self.report_rouge:
msg = self._report_rouge(tgt)
self._log(msg)
if self.report_time:
total_time = end_time - start_time
self._log("Total translation time (s): %f" % total_time)
self._log("Average translation time (s): %f" % (
total_time / len(all_predictions)))
self._log("Tokens per second: %f" % (
pred_words_total / total_time))
        if self.dump_beam:
            import json
            json.dump(self.beam_accum,
                      codecs.open(self.dump_beam, 'w', 'utf-8'))
return all_scores, all_predictions
def _translate_random_sampling(
self,
batch,
src_vocabs,
max_length,
min_length=0,
sampling_temp=1.0,
keep_topk=-1,
return_attention=False):
"""Alternative to beam search. Do random sampling at each step."""
assert self.beam_size == 1
# TODO: support these blacklisted features.
assert self.block_ngram_repeat == 0
batch_size = batch.batch_size
# Encoder forward.
src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
self.model.decoder.init_state(src, memory_bank, enc_states)
use_src_map = self.copy_attn
results = {
"predictions": None,
"scores": None,
"attention": None,
"batch": batch,
"gold_score": self._gold_score(
batch, memory_bank, src_lengths, src_vocabs, use_src_map,
enc_states, batch_size, src)}
memory_lengths = src_lengths
src_map = batch.src_map if use_src_map else None
if isinstance(memory_bank, tuple):
mb_device = memory_bank[0].device
else:
mb_device = memory_bank.device
random_sampler = RandomSampling(
self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
batch_size, mb_device, min_length, self.block_ngram_repeat,
            self._exclusion_idxs, return_attention, max_length,
sampling_temp, keep_topk, memory_lengths)
for step in range(max_length):
# Shape: (1, B, 1)
decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)
log_probs, attn = self._decode_and_generate(
decoder_input,
memory_bank,
batch,
src_vocabs,
memory_lengths=memory_lengths,
src_map=src_map,
step=step,
batch_offset=random_sampler.select_indices
)
random_sampler.advance(log_probs, attn)
any_batch_is_finished = random_sampler.is_finished.any()
if any_batch_is_finished:
random_sampler.update_finished()
if random_sampler.done:
break
if any_batch_is_finished:
select_indices = random_sampler.select_indices
# Reorder states.
if isinstance(memory_bank, tuple):
memory_bank = tuple(x.index_select(1, select_indices)
for x in memory_bank)
else:
memory_bank = memory_bank.index_select(1, select_indices)
memory_lengths = memory_lengths.index_select(0, select_indices)
if src_map is not None:
src_map = src_map.index_select(1, select_indices)
self.model.decoder.map_state(
lambda state, dim: state.index_select(dim, select_indices))
results["scores"] = random_sampler.scores
results["predictions"] = random_sampler.predictions
results["attention"] = random_sampler.attention
return results
    def translate_batch(self, batch, src_vocabs, attn_debug):
"""Translate a batch of sentences."""
with torch.no_grad():
if self.beam_size == 1:
return self._translate_random_sampling(
batch,
src_vocabs,
self.max_length,
min_length=self.min_length,
sampling_temp=self.random_sampling_temp,
keep_topk=self.sample_from_topk,
return_attention=attn_debug or self.replace_unk)
else:
return self._translate_batch(
batch,
src_vocabs,
self.max_length,
min_length=self.min_length,
ratio=self.ratio,
n_best=self.n_best,
return_attention=attn_debug or self.replace_unk)
def _run_encoder(self, batch):
src, src_lengths = batch.src if isinstance(batch.src, tuple) \
else (batch.src, None)
enc_states, memory_bank, src_lengths = self.model.encoder(
src, src_lengths)
if src_lengths is None:
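            # Non-text inputs may carry no explicit lengths; assume every
            # example spans the full time dimension of the memory bank.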
assert not isinstance(memory_bank, tuple), \
'Ensemble decoding only supported for text data'
src_lengths = torch.Tensor(batch.batch_size) \
.type_as(memory_bank) \
.long() \
.fill_(memory_bank.size(0))
return src, enc_states, memory_bank, src_lengths
def _decode_and_generate(
self,
decoder_in,
memory_bank,
batch,
src_vocabs,
memory_lengths,
src_map=None,
step=None,
batch_offset=None):
if self.copy_attn:
# Turn any copied words into UNKs.
decoder_in = decoder_in.masked_fill(
decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx
)
        # Decoder forward: takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank.
        # At inference, tgt_len = 1 and batch = beam_size x batch_size;
        # for gold scoring, tgt_len is the actual target length and the
        # batch is a full, un-tiled batch.
dec_out, dec_attn = self.model.decoder(
decoder_in, memory_bank, memory_lengths=memory_lengths, step=step
)
# Generator forward.
if not self.copy_attn:
if "std" in dec_attn:
attn = dec_attn["std"]
else:
attn = None
log_probs = self.model.generator(dec_out.squeeze(0))
            # returns [(batch_size x beam_size), vocab] when decoding a
            # single step, or [tgt_len, batch_size, vocab] for a full sequence
else:
attn = dec_attn["copy"]
scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
attn.view(-1, attn.size(2)),
src_map)
            # Scores have shape [tgt_len x batch, vocab] or
            # [beam x batch, vocab].
if batch_offset is None:
scores = scores.view(batch.batch_size, -1, scores.size(-1))
else:
scores = scores.view(-1, self.beam_size, scores.size(-1))
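            # Merge scores of copied tokens (indexed in each example's
            # extended source vocab) back onto their shared target-vocab ids.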
scores = collapse_copy_scores(
scores,
batch,
self._tgt_vocab,
src_vocabs,
batch_dim=0,
batch_offset=batch_offset
)
scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
log_probs = scores.squeeze(0).log()
            # returns [(batch_size x beam_size), vocab] when decoding a
            # single step, or [tgt_len, batch_size, vocab] for a full sequence
return log_probs, attn
def _translate_batch(
self,
batch,
src_vocabs,
max_length,
min_length=0,
ratio=0.,
n_best=1,
return_attention=False):
# TODO: support these blacklisted features.
assert not self.dump_beam
# (0) Prep the components of the search.
use_src_map = self.copy_attn
beam_size = self.beam_size
batch_size = batch.batch_size
# (1) Run the encoder on the src.
src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
self.model.decoder.init_state(src, memory_bank, enc_states)
results = {
"predictions": None,
"scores": None,
"attention": None,
"batch": batch,
"gold_score": self._gold_score(
batch, memory_bank, src_lengths, src_vocabs, use_src_map,
enc_states, batch_size, src)}
# (2) Repeat src objects `beam_size` times.
# We use batch_size x beam_size
src_map = (tile(batch.src_map, beam_size, dim=1)
if use_src_map else None)
self.model.decoder.map_state(
lambda state, dim: tile(state, beam_size, dim=dim))
if isinstance(memory_bank, tuple):
memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
mb_device = memory_bank[0].device
else:
memory_bank = tile(memory_bank, beam_size, dim=1)
mb_device = memory_bank.device
memory_lengths = tile(src_lengths, beam_size)
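        # After tiling, the batch dimension is batch_size * beam_size, with
        # the beam_size copies of each example laid out contiguously.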
# (0) pt 2, prep the beam object
beam = BeamSearch(
beam_size,
n_best=n_best,
batch_size=batch_size,
global_scorer=self.global_scorer,
pad=self._tgt_pad_idx,
eos=self._tgt_eos_idx,
bos=self._tgt_bos_idx,
min_length=min_length,
ratio=ratio,
max_length=max_length,
mb_device=mb_device,
return_attention=return_attention,
stepwise_penalty=self.stepwise_penalty,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs,
memory_lengths=memory_lengths)
for step in range(max_length):
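            # Shape: (1, B x beam_size, 1)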
decoder_input = beam.current_predictions.view(1, -1, 1)
log_probs, attn = self._decode_and_generate(
decoder_input,
memory_bank,
batch,
src_vocabs,
memory_lengths=memory_lengths,
src_map=src_map,
step=step,
batch_offset=beam._batch_offset)
beam.advance(log_probs, attn)
any_beam_is_finished = beam.is_finished.any()
if any_beam_is_finished:
beam.update_finished()
if beam.done:
break
select_indices = beam.current_origin
if any_beam_is_finished:
# Reorder states.
if isinstance(memory_bank, tuple):
memory_bank = tuple(x.index_select(1, select_indices)
for x in memory_bank)
else:
memory_bank = memory_bank.index_select(1, select_indices)
memory_lengths = memory_lengths.index_select(0, select_indices)
if src_map is not None:
src_map = src_map.index_select(1, select_indices)
self.model.decoder.map_state(
lambda state, dim: state.index_select(dim, select_indices))
results["scores"] = beam.scores
results["predictions"] = beam.predictions
results["attention"] = beam.attention
return results
    # This is left in the code for now, but unused.
def _translate_batch_deprecated(self, batch, src_vocabs):
# (0) Prep each of the components of the search.
# And helper method for reducing verbosity.
use_src_map = self.copy_attn
beam_size = self.beam_size
batch_size = batch.batch_size
beam = [onmt.translate.Beam(
beam_size,
n_best=self.n_best,
            cuda=self._use_cuda,
global_scorer=self.global_scorer,
pad=self._tgt_pad_idx,
eos=self._tgt_eos_idx,
bos=self._tgt_bos_idx,
min_length=self.min_length,
stepwise_penalty=self.stepwise_penalty,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs)
for __ in range(batch_size)]
# (1) Run the encoder on the src.
src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
self.model.decoder.init_state(src, memory_bank, enc_states)
results = {
"predictions": [],
"scores": [],
"attention": [],
"batch": batch,
"gold_score": self._gold_score(
batch, memory_bank, src_lengths, src_vocabs, use_src_map,
enc_states, batch_size, src)}
# (2) Repeat src objects `beam_size` times.
        # We now use batch_size x beam_size (same as fast mode)
src_map = (tile(batch.src_map, beam_size, dim=1)
if use_src_map else None)
self.model.decoder.map_state(
lambda state, dim: tile(state, beam_size, dim=dim))
if isinstance(memory_bank, tuple):
memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
else:
memory_bank = tile(memory_bank, beam_size, dim=1)
memory_lengths = tile(src_lengths, beam_size)
# (3) run the decoder to generate sentences, using beam search.
for i in range(self.max_length):
if all((b.done for b in beam)):
break
# (a) Construct batch x beam_size nxt words.
# Get all the pending current beam words and arrange for forward.
inp = torch.stack([b.current_predictions for b in beam])
inp = inp.view(1, -1, 1)
# (b) Decode and forward
out, beam_attn = self._decode_and_generate(
inp, memory_bank, batch, src_vocabs,
memory_lengths=memory_lengths, src_map=src_map, step=i
)
out = out.view(batch_size, beam_size, -1)
beam_attn = beam_attn.view(batch_size, beam_size, -1)
# (c) Advance each beam.
select_indices_array = []
# Loop over the batch_size number of beam
for j, b in enumerate(beam):
if not b.done:
b.advance(out[j, :],
beam_attn.data[j, :, :memory_lengths[j]])
select_indices_array.append(
b.current_origin + j * beam_size)
select_indices = torch.cat(select_indices_array)
self.model.decoder.map_state(
lambda state, dim: state.index_select(dim, select_indices))
# (4) Extract sentences from beam.
for b in beam:
scores, ks = b.sort_finished(minimum=self.n_best)
hyps, attn = [], []
for times, k in ks[:self.n_best]:
hyp, att = b.get_hyp(times, k)
hyps.append(hyp)
attn.append(att)
results["predictions"].append(hyps)
results["scores"].append(scores)
results["attention"].append(attn)
return results
def _score_target(self, batch, memory_bank, src_lengths,
src_vocabs, src_map):
tgt = batch.tgt
tgt_in = tgt[:-1]
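        # Teacher forcing: feed the gold sequence minus its last token and
        # score each next gold token (the sequence minus BOS).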
log_probs, attn = self._decode_and_generate(
tgt_in, memory_bank, batch, src_vocabs,
memory_lengths=src_lengths, src_map=src_map)
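        # Zero pad-token log-probs so padding positions do not change scores.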
log_probs[:, :, self._tgt_pad_idx] = 0
gold = tgt[1:]
gold_scores = log_probs.gather(2, gold)
gold_scores = gold_scores.sum(dim=0).view(-1)
return gold_scores
def _report_score(self, name, score_total, words_total):
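        # Average score is per-word log-likelihood; PPL = exp(-avg score).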
if words_total == 0:
msg = "%s No words predicted" % (name,)
else:
msg = ("%s AVG SCORE: %.4f, %s PPL: %.4f" % (
name, score_total / words_total,
name, math.exp(-score_total / words_total)))
return msg
def _report_bleu(self, tgt_path):
import subprocess
base_dir = os.path.abspath(__file__ + "/../../..")
# Rollback pointer to the beginning.
self.out_file.seek(0)
print()
res = subprocess.check_output(
"perl %s/tools/multi-bleu.perl %s" % (base_dir, tgt_path),
stdin=self.out_file, shell=True
).decode("utf-8")
msg = ">> " + res.strip()
return msg
def _report_rouge(self, tgt_path):
import subprocess
path = os.path.split(os.path.realpath(__file__))[0]
msg = subprocess.check_output(
"python %s/tools/test_rouge.py -r %s -c STDIN" % (path, tgt_path),
shell=True, stdin=self.out_file
).decode("utf-8").strip()
return msg