Skip to content
Merged
4 changes: 2 additions & 2 deletions docs/source/user/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pre-processing data, constructing model and training model.
from fastNLP.modules import aggregation
from fastNLP.modules import decoder

from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.dataset_loader import ClassDataSetLoader
from fastNLP.loader.preprocess import ClassPreprocess
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.core.inference import ClassificationInfer
Expand Down Expand Up @@ -50,7 +50,7 @@ pre-processing data, constructing model and training model.
train_path = 'test/data_for_tests/text_classify.txt' # training set file

# load dataset
ds_loader = ClassDatasetLoader("train", train_path)
ds_loader = ClassDataSetLoader("train", train_path)
data = ds_loader.load()

# pre-process dataset
Expand Down
4 changes: 2 additions & 2 deletions examples/readme_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastNLP.core.predictor import ClassificationInfer
from fastNLP.core.preprocess import ClassPreprocess
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.dataset_loader import ClassDataSetLoader
from fastNLP.models.base_model import BaseModel
from fastNLP.modules import aggregator
from fastNLP.modules import decoder
Expand Down Expand Up @@ -36,7 +36,7 @@ def forward(self, x):
train_path = './data_for_tests/text_classify.txt' # training set file

# load dataset
ds_loader = ClassDatasetLoader(train_path)
ds_loader = ClassDataSetLoader()
data = ds_loader.load()

# pre-process dataset
Expand Down
21 changes: 4 additions & 17 deletions fastNLP/core/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, dataset, batch_size, sampler, use_cuda):
:param dataset: a DataSet object
:param batch_size: int, the size of the batch
:param sampler: a Sampler object
:param use_cuda: bool, whetjher to use GPU
:param use_cuda: bool, whether to use GPU

"""
self.dataset = dataset
Expand All @@ -37,15 +37,12 @@ def __next__(self):
"""

:return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
batch_x also contains an item (str: list of int) about origin lengths,
which means ("field_name_origin_len": origin lengths).
E.g.
::
{'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]})

batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True.
The names of fields are defined in preprocessor's convert_to_dataset method.

"""
if self.curidx >= len(self.idx_list):
Expand All @@ -54,34 +51,24 @@ def __next__(self):
endidx = min(self.curidx + self.batch_size, len(self.idx_list))
padding_length = {field_name: max(field_length[self.curidx: endidx])
for field_name, field_length in self.lengths.items()}
origin_lengths = {field_name: field_length[self.curidx: endidx]
for field_name, field_length in self.lengths.items()}

batch_x, batch_y = defaultdict(list), defaultdict(list)

# transform index to tensor and do padding for sequences
for idx in range(self.curidx, endidx):
x, y = self.dataset.to_tensor(idx, padding_length)
for name, tensor in x.items():
batch_x[name].append(tensor)
for name, tensor in y.items():
batch_y[name].append(tensor)

batch_origin_length = {}
# combine instances into a batch
# combine instances to form a batch
for batch in (batch_x, batch_y):
for name, tensor_list in batch.items():
if self.use_cuda:
batch[name] = torch.stack(tensor_list, dim=0).cuda()
else:
batch[name] = torch.stack(tensor_list, dim=0)

# add origin lengths in batch_x
for name, tensor in batch_x.items():
if self.use_cuda:
batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name]).cuda()
else:
batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name])
batch_x.update(batch_origin_length)

self.curidx = endidx
return batch_x, batch_y

182 changes: 179 additions & 3 deletions fastNLP/core/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import random
from collections import defaultdict
from copy import deepcopy

from fastNLP.core.field import TextField
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.loader.dataset_loader import POSDataSetLoader, ClassDataSetLoader


def create_dataset_from_lists(str_lists: list, word_vocab: dict, has_target: bool = False, label_vocab: dict = None):
Expand Down Expand Up @@ -65,17 +69,19 @@ class DataSet(list):
"""A DataSet object is a list of Instance objects.

"""
def __init__(self, name="", instances=None):

def __init__(self, name="", instances=None, load_func=None):
"""

:param name: str, the name of the dataset. (default: "")
:param instances: list of Instance objects. (default: None)

:param load_func: a function that takes the dataset path (string) as input and returns multi-level lists.
"""
list.__init__([])
self.name = name
if instances is not None:
self.extend(instances)
self.data_set_load_func = load_func

def index_all(self, vocab):
for ins in self:
Expand Down Expand Up @@ -109,3 +115,173 @@ def get_length(self):
for field_name, field_length in ins.get_length().items():
lengths[field_name].append(field_length)
return lengths

def convert(self, data):
"""Convert lists of strings into Instances with Fields, creating Vocabulary for labeled data. Used in Training."""
raise NotImplementedError

def convert_with_vocabs(self, data, vocabs):
"""Convert lists of strings into Instances with Fields, using existing Vocabulary, with labels. Used in Testing."""
raise NotImplementedError

def convert_for_infer(self, data, vocabs):
"""Convert lists of strings into Instances with Fields, using existing Vocabulary, without labels. Used in predicting."""

def load(self, data_path, vocabs=None, infer=False):
"""Load data from the given files.

:param data_path: str, the path to the data
:param infer: bool. If True, there is no label information in the data. Default: False.
:param vocabs: dict of (name: Vocabulary object), used to index data. If not provided, a new vocabulary will be constructed.

"""
raw_data = self.data_set_load_func(data_path)
if infer is True:
self.convert_for_infer(raw_data, vocabs)
else:
if vocabs is not None:
self.convert_with_vocabs(raw_data, vocabs)
else:
self.convert(raw_data)

def load_raw(self, raw_data, vocabs):
"""Load raw data without loader. Used in FastNLP class.

:param raw_data:
:param vocabs:
:return:
"""
self.convert_for_infer(raw_data, vocabs)

def split(self, ratio, shuffle=True):
"""Train/dev splitting

:param ratio: float, between 0 and 1. The ratio of development set in origin data set.
:param shuffle: bool, whether shuffle the data set before splitting. Default: True.
:return train_set: a DataSet object, representing the training set
dev_set: a DataSet object, representing the validation set

"""
assert 0 < ratio < 1
if shuffle:
random.shuffle(self)
split_idx = int(len(self) * ratio)
dev_set = deepcopy(self)
train_set = deepcopy(self)
del train_set[:split_idx]
del dev_set[split_idx:]
return train_set, dev_set


class SeqLabelDataSet(DataSet):
def __init__(self, instances=None, load_func=POSDataSetLoader().load):
super(SeqLabelDataSet, self).__init__(name="", instances=instances, load_func=load_func)
self.word_vocab = Vocabulary()
self.label_vocab = Vocabulary()

def convert(self, data):
"""Convert lists of strings into Instances with Fields.

:param data: 3-level lists. Entries are strings.
"""
for example in data:
word_seq, label_seq = example[0], example[1]
# list, list
self.word_vocab.update(word_seq)
self.label_vocab.update(label_seq)
x = TextField(word_seq, is_target=False)
x_len = LabelField(len(word_seq), is_target=False)
y = TextField(label_seq, is_target=False)
instance = Instance()
instance.add_field("word_seq", x)
instance.add_field("truth", y)
instance.add_field("word_seq_origin_len", x_len)
self.append(instance)
self.index_field("word_seq", self.word_vocab)
self.index_field("truth", self.label_vocab)
# no need to index "word_seq_origin_len"

def convert_with_vocabs(self, data, vocabs):
for example in data:
word_seq, label_seq = example[0], example[1]
# list, list
x = TextField(word_seq, is_target=False)
x_len = LabelField(len(word_seq), is_target=False)
y = TextField(label_seq, is_target=False)
instance = Instance()
instance.add_field("word_seq", x)
instance.add_field("truth", y)
instance.add_field("word_seq_origin_len", x_len)
self.append(instance)
self.index_field("word_seq", vocabs["word_vocab"])
self.index_field("truth", vocabs["label_vocab"])
# no need to index "word_seq_origin_len"

def convert_for_infer(self, data, vocabs):
for word_seq in data:
# list
x = TextField(word_seq, is_target=False)
x_len = LabelField(len(word_seq), is_target=False)
instance = Instance()
instance.add_field("word_seq", x)
instance.add_field("word_seq_origin_len", x_len)
self.append(instance)
self.index_field("word_seq", vocabs["word_vocab"])
# no need to index "word_seq_origin_len"


class TextClassifyDataSet(DataSet):
def __init__(self, instances=None, load_func=ClassDataSetLoader().load):
super(TextClassifyDataSet, self).__init__(name="", instances=instances, load_func=load_func)
self.word_vocab = Vocabulary()
self.label_vocab = Vocabulary(need_default=False)

def convert(self, data):
for example in data:
word_seq, label = example[0], example[1]
# list, str
self.word_vocab.update(word_seq)
self.label_vocab.update(label)
x = TextField(word_seq, is_target=False)
y = LabelField(label, is_target=True)
instance = Instance()
instance.add_field("word_seq", x)
instance.add_field("label", y)
self.append(instance)
self.index_field("word_seq", self.word_vocab)
self.index_field("label", self.label_vocab)

def convert_with_vocabs(self, data, vocabs):
for example in data:
word_seq, label = example[0], example[1]
# list, str
x = TextField(word_seq, is_target=False)
y = LabelField(label, is_target=True)
instance = Instance()
instance.add_field("word_seq", x)
instance.add_field("label", y)
self.append(instance)
self.index_field("word_seq", vocabs["word_vocab"])
self.index_field("label", vocabs["label_vocab"])

def convert_for_infer(self, data, vocabs):
for word_seq in data:
# list
x = TextField(word_seq, is_target=False)
instance = Instance()
instance.add_field("word_seq", x)
self.append(instance)
self.index_field("word_seq", vocabs["word_vocab"])


def change_field_is_target(data_set, field_name, new_target):
"""Change the flag of is_target in a field.

:param data_set: a DataSet object
:param field_name: str, the name of the field
:param new_target: one of (True, False, None), representing this field is batch_x / is batch_y / neither.

"""
for inst in data_set:
inst.fields[field_name].is_target = new_target

8 changes: 6 additions & 2 deletions fastNLP/core/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def to_tensor(self, padding_length: int):


class LabelField(Field):
"""The Field representing a single label. Can be a string or integer.

"""
def __init__(self, label, is_target=True):
super(LabelField, self).__init__(is_target)
self.label = label
Expand All @@ -73,13 +76,14 @@ def get_length(self):

def index(self, vocab):
if self._index is None:
self._index = vocab[self.label]
if isinstance(self.label, str):
self._index = vocab[self.label]
return self._index

def to_tensor(self, padding_length):
if self._index is None:
if isinstance(self.label, int):
return torch.LongTensor([self.label])
return torch.tensor(self.label)
elif isinstance(self.label, str):
raise RuntimeError("Field {} not indexed. Call index method.".format(self.label))
else:
Expand Down
7 changes: 5 additions & 2 deletions fastNLP/core/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ def to_tensor(self, padding_length: dict):
tensor_x = {}
tensor_y = {}
for name, field in self.fields.items():
if field.is_target:
if field.is_target is True:
tensor_y[name] = field.to_tensor(padding_length[name])
else:
elif field.is_target is False:
tensor_x[name] = field.to_tensor(padding_length[name])
else:
# is_target is None
continue
return tensor_x, tensor_y
17 changes: 16 additions & 1 deletion fastNLP/core/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,25 @@ def _borrow_from_pytorch(loss_name):
"""Given a name of a loss function, return it from PyTorch.

:param loss_name: str, the name of a loss function

- cross_entropy: combines log softmax and nll loss in a single function.
- nll: negative log likelihood

:return loss: a PyTorch loss
"""

class InnerCrossEntropy:
"""A simple wrapper to guarantee input shapes."""

def __init__(self):
self.f = torch.nn.CrossEntropyLoss()

def __call__(self, predict, truth):
truth = truth.view(-1, )
return self.f(predict, truth)

if loss_name == "cross_entropy":
return torch.nn.CrossEntropyLoss()
return InnerCrossEntropy()
elif loss_name == 'nll':
return torch.nn.NLLLoss()
else:
Expand Down
Loading