# -*- coding: utf-8 -*-
from functools import partial
import six
import torch
from torchtext.data import Field, RawField
from onmt.inputters.datareader_base import DataReaderBase


class TextDataReader(DataReaderBase):
    def read(self, sequences, side, _dir=None):
"""Read text data from disk.
Args:
sequences (str or Iterable[str]):
path to text file or iterable of the actual text data.
side (str): Prefix used in return dict. Usually
``"src"`` or ``"tgt"``.
_dir (NoneType): Leave as ``None``. This parameter exists to
conform with the :func:`DataReaderBase.read()` signature.
        Yields:
            Dictionaries mapping ``side`` to the sequence string (decoded
            from bytes when necessary) and ``"indices"`` to the example's
            position within ``sequences``.
"""
assert _dir is None or _dir == "", \
"Cannot use _dir with TextDataReader."
if isinstance(sequences, str):
sequences = DataReaderBase._read_file(sequences)
for i, seq in enumerate(sequences):
if isinstance(seq, six.binary_type):
seq = seq.decode("utf-8")
yield {side: seq, "indices": i}
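
# Illustrative usage (a sketch, not part of the original module): reading
# from an in-memory iterable yields one dict per example, keyed by ``side``:
#
#     >>> list(TextDataReader().read(["hello world"], "src"))
#     [{'src': 'hello world', 'indices': 0}]
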
def text_sort_key(ex):
"""Sort using the number of tokens in the sequence."""
if hasattr(ex, "tgt"):
return len(ex.src[0]), len(ex.tgt[0])
return len(ex.src[0])
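
# Illustrative only: given a hypothetical list ``examples`` of torchtext
# Examples, this key sorts by source length, with target length as a
# tie-breaker when a ``tgt`` attribute is present:
#
#     >>> # examples.sort(key=text_sort_key)
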
# Meant to be wrapped with functools.partial, which binds the keyword
# arguments below (see text_fields).
def _feature_tokenize(
string, layer=0, tok_delim=None, feat_delim=None, truncate=None):
"""Split apart word features (like POS/NER tags) from the tokens.
Args:
string (str): A string with ``tok_delim`` joining tokens and
features joined by ``feat_delim``. For example,
``"hello|NOUN|'' Earth|NOUN|PLANET"``.
        layer (int): Which feature to extract. (Not used if there are no
            features, indicated by ``feat_delim is None``.) In the
            example above, layer 2 is ``["''", "PLANET"]``.
        tok_delim (str or NoneType): The token delimiter. ``None`` means
            splitting on any whitespace.
        feat_delim (str or NoneType): The delimiter between a token and
            its features, or ``None`` if the tokens carry no features.
        truncate (int or NoneType): Restrict sequences to this length of
            tokens.
Returns:
List[str] of tokens.
"""
tokens = string.split(tok_delim)
if truncate is not None:
tokens = tokens[:truncate]
if feat_delim is not None:
tokens = [t.split(feat_delim)[layer] for t in tokens]
return tokens
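
# A quick sanity check (illustrative; a plain "|" is used as the feature
# delimiter here for readability):
#
#     >>> _feature_tokenize("hello|NOUN Earth|NOUN", feat_delim="|", layer=1)
#     ['NOUN', 'NOUN']
#     >>> _feature_tokenize("hello Earth there", truncate=2)
#     ['hello', 'Earth']
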
class TextMultiField(RawField):
"""Container for subfields.
Text data might use POS/NER/etc labels in addition to tokens.
This class associates the "base" :class:`Field` with any subfields.
It also handles padding the data and stacking it.
Args:
base_name (str): Name for the base field.
base_field (Field): The token field.
feats_fields (Iterable[Tuple[str, Field]]): A list of name-field
pairs.
Attributes:
fields (Iterable[Tuple[str, Field]]): A list of name-field pairs.
The order is defined as the base field first, then
``feats_fields`` in alphabetical order.
"""
def __init__(self, base_name, base_field, feats_fields):
super(TextMultiField, self).__init__()
self.fields = [(base_name, base_field)]
for name, ff in sorted(feats_fields, key=lambda kv: kv[0]):
self.fields.append((name, ff))
@property
def base_field(self):
return self.fields[0][1]
def process(self, batch, device=None):
"""Convert outputs of preprocess into Tensors.
Args:
batch (List[List[List[str]]]): A list of length batch size.
Each element is a list of the preprocess results for each
                field (which are lists of str "words" or feature tags).
device (torch.device or str): The device on which the tensor(s)
are built.
Returns:
torch.LongTensor or Tuple[LongTensor, LongTensor]:
A tensor of shape ``(seq_len, batch_size, len(self.fields))``
where the field features are ordered like ``self.fields``.
If the base field returns lengths, these are also returned
and have shape ``(batch_size,)``.
"""
# batch (list(list(list))): batch_size x len(self.fields) x seq_len
batch_by_feat = list(zip(*batch))
base_data = self.base_field.process(batch_by_feat[0], device=device)
if self.base_field.include_lengths:
# lengths: batch_size
base_data, lengths = base_data
feats = [ff.process(batch_by_feat[i], device=device)
for i, (_, ff) in enumerate(self.fields[1:], 1)]
levels = [base_data] + feats
# data: seq_len x batch_size x len(self.fields)
data = torch.stack(levels, 2)
if self.base_field.include_lengths:
return data, lengths
else:
return data
def preprocess(self, x):
"""Preprocess data.
Args:
x (str): A sentence string (words joined by whitespace).
Returns:
List[List[str]]: A list of length ``len(self.fields)`` containing
lists of tokens/feature tags for the sentence. The output
is ordered like ``self.fields``.
"""
return [f.preprocess(x) for _, f in self.fields]
def __getitem__(self, item):
return self.fields[item]
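
# Putting it together (a sketch; the two Fields below are hypothetical, and
# their tokenizers are assumed to extract tokens/features as in text_fields):
#
#     >>> tok = Field(pad_token="<blank>", include_lengths=True)
#     >>> pos = Field(pad_token="<blank>")
#     >>> mf = TextMultiField("src", tok, [("src_feat_0", pos)])
#     >>> mf.base_field is tok
#     True
#     >>> # Once vocabs are built, mf.process(batch) returns a LongTensor of
#     >>> # shape (seq_len, batch_size, 2) plus a (batch_size,) lengths tensor.
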
def text_fields(**kwargs):
"""Create text fields.
Args:
base_name (str): Name associated with the field.
        n_feats (int): Number of word-level feats (not counting the tokens).
include_lengths (bool): Optionally return the sequence lengths.
pad (str, optional): Defaults to ``"<blank>"``.
bos (str or NoneType, optional): Defaults to ``"<s>"``.
eos (str or NoneType, optional): Defaults to ``"</s>"``.
        truncate (int or NoneType, optional): Defaults to ``None``.
Returns:
TextMultiField
"""
n_feats = kwargs["n_feats"]
include_lengths = kwargs["include_lengths"]
base_name = kwargs["base_name"]
pad = kwargs.get("pad", "<blank>")
bos = kwargs.get("bos", "<s>")
eos = kwargs.get("eos", "</s>")
truncate = kwargs.get("truncate", None)
fields_ = []
    feat_delim = u"￨" if n_feats > 0 else None
for i in range(n_feats + 1):
name = base_name + "_feat_" + str(i - 1) if i > 0 else base_name
tokenize = partial(
_feature_tokenize,
layer=i,
truncate=truncate,
feat_delim=feat_delim)
use_len = i == 0 and include_lengths
feat = Field(
init_token=bos, eos_token=eos,
pad_token=pad, tokenize=tokenize,
include_lengths=use_len)
fields_.append((name, feat))
assert fields_[0][0] == base_name # sanity check
field = TextMultiField(fields_[0][0], fields_[0][1], fields_[1:])
return field
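
# Typical construction (a sketch mirroring the defaults above):
#
#     >>> src_field = text_fields(base_name="src", n_feats=1,
#     ...                         include_lengths=True)
#     >>> [name for name, _ in src_field.fields]
#     ['src', 'src_feat_0']
#     >>> src_field.base_field.include_lengths
#     True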