Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
torch:
device: 'cpu'
seed: 42

prepare_data:
train_data:
Expand Down
61 changes: 36 additions & 25 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from argparse import ArgumentParser
from collections import Counter

Expand All @@ -19,14 +20,20 @@
)
from pytorch_ner.save import save_model
from pytorch_ner.train import train
from pytorch_ner.utils import set_global_seed


def main(path_to_config: str):

with open(path_to_config, mode="r") as fp:
config = yaml.safe_load(fp)

# check existence of save path_to_folder
if os.path.exists(config["save"]["path_to_folder"]):
raise FileExistsError("save directory already exists")

device = torch.device(config["torch"]["device"])
set_global_seed(config["torch"]["seed"])

# LOAD DATA

Expand All @@ -46,12 +53,13 @@ def main(path_to_config: str):
verbose=config["prepare_data"]["val_data"]["verbose"],
)

test_token_seq, test_label_seq = prepare_conll_data_format(
path=config["prepare_data"]["test_data"]["path"],
sep=config["prepare_data"]["test_data"]["sep"],
lower=config["prepare_data"]["test_data"]["lower"],
verbose=config["prepare_data"]["test_data"]["verbose"],
)
if "test_data" in config["prepare_data"]:
test_token_seq, test_label_seq = prepare_conll_data_format(
path=config["prepare_data"]["test_data"]["path"],
sep=config["prepare_data"]["test_data"]["sep"],
lower=config["prepare_data"]["test_data"]["lower"],
verbose=config["prepare_data"]["test_data"]["verbose"],
)

# token2idx / label2idx

Expand Down Expand Up @@ -85,13 +93,14 @@ def main(path_to_config: str):
preprocess=config["dataloader"]["preprocess"],
)

testset = NERDataset(
token_seq=test_token_seq,
label_seq=test_label_seq,
token2idx=token2idx,
label2idx=label2idx,
preprocess=config["dataloader"]["preprocess"],
)
if "test_data" in config["prepare_data"]:
testset = NERDataset(
token_seq=test_token_seq,
label_seq=test_label_seq,
token2idx=token2idx,
label2idx=label2idx,
preprocess=config["dataloader"]["preprocess"],
)

# collators

Expand All @@ -107,11 +116,12 @@ def main(path_to_config: str):
percentile=100, # hardcoded
)

test_collator = NERCollator(
token_padding_value=token2idx[config["dataloader"]["token_padding"]],
label_padding_value=label2idx[config["dataloader"]["label_padding"]],
percentile=100, # hardcoded
)
if "test_data" in config["prepare_data"]:
test_collator = NERCollator(
token_padding_value=token2idx[config["dataloader"]["token_padding"]],
label_padding_value=label2idx[config["dataloader"]["label_padding"]],
percentile=100, # hardcoded
)

# dataloaders

Expand All @@ -130,12 +140,13 @@ def main(path_to_config: str):
collate_fn=val_collator,
)

testloader = DataLoader(
dataset=testset,
batch_size=1, # hardcoded
shuffle=False, # hardcoded
collate_fn=test_collator,
)
if "test_data" in config["prepare_data"]:
testloader = DataLoader(
dataset=testset,
batch_size=1, # hardcoded
shuffle=False, # hardcoded
collate_fn=test_collator,
)

# INIT MODEL

Expand Down Expand Up @@ -196,7 +207,7 @@ def main(path_to_config: str):
model=model,
trainloader=trainloader,
valloader=valloader,
testloader=testloader,
testloader=testloader if "test_data" in config["prepare_data"] else None,
criterion=criterion,
optimizer=optimizer,
device=device,
Expand Down
2 changes: 1 addition & 1 deletion pytorch_ner/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
return np.array(tokens), np.array(labels), np.array(lengths)


class NERCollator(object):
class NERCollator:
"""
Collator that handles variable-size sentences.
"""
Expand Down
7 changes: 4 additions & 3 deletions pytorch_ner/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import yaml

from pytorch_ner.onnx import onnx_export_and_check
from pytorch_ner.utils import mkdir, rmdir
from pytorch_ner.utils import mkdir


def save_model(
Expand All @@ -18,8 +18,9 @@ def save_model(
config: Dict,
export_onnx: bool = False,
):
# make empty dir
rmdir(path_to_folder)
# check existence of save path_to_folder
if os.path.exists(path_to_folder):
raise FileExistsError("save directory already exists")
mkdir(path_to_folder)

model.cpu()
Expand Down
10 changes: 5 additions & 5 deletions pytorch_ner/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def masking(lengths: torch.Tensor) -> torch.Tensor:


# TODO: clip_grad_norm
def train_loop(
def train_epoch(
model: nn.Module,
dataloader: DataLoader,
criterion: Callable,
Expand Down Expand Up @@ -82,7 +82,7 @@ def train_loop(
return metrics


def validate_loop(
def validate_epoch(
model: nn.Module,
dataloader: DataLoader,
criterion: Callable,
Expand Down Expand Up @@ -156,7 +156,7 @@ def train(
if verbose:
print(f"epoch [{epoch+1}/{n_epoch}]\n")

train_metrics = train_loop(
train_metrics = train_epoch(
model=model,
dataloader=trainloader,
criterion=criterion,
Expand All @@ -170,7 +170,7 @@ def train(
print(f"train {metric_name}: {np.mean(metric_list)}")
print()

val_metrics = validate_loop(
val_metrics = validate_epoch(
model=model,
dataloader=valloader,
criterion=criterion,
Expand All @@ -185,7 +185,7 @@ def train(

if testloader is not None:

test_metrics = validate_loop(
test_metrics = validate_epoch(
model=model,
dataloader=testloader,
criterion=criterion,
Expand Down
1 change: 1 addition & 0 deletions pytorch_ner/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def set_global_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

Expand Down
4 changes: 4 additions & 0 deletions tests/test_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
import yaml

from pytorch_ner.save import save_model
from pytorch_ner.utils import rmdir
from tests.test_train import label2idx, model, token2idx

path_to_folder = "models/test_save/"
path_to_onnx_folder = "models/test_onnx_save/"

rmdir(path_to_folder)
rmdir(path_to_onnx_folder)


with open("config.yaml", "r") as fp:
config = yaml.safe_load(fp)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
get_token2idx,
prepare_conll_data_format,
)
from pytorch_ner.train import train, validate_loop
from pytorch_ner.train import train, validate_epoch
from tests.test_nn_modules.test_architecture import model_bilstm as model

device = torch.device("cpu")
Expand Down Expand Up @@ -55,7 +55,7 @@

# VALIDATE

metrics_before = validate_loop(
metrics_before = validate_epoch(
model=model.to(device),
dataloader=dataloader,
criterion=criterion,
Expand All @@ -82,7 +82,7 @@
class TestTrain(unittest.TestCase):
def test_val_metrics(self):

metrics_after = validate_loop(
metrics_after = validate_epoch(
model=model.to(device),
dataloader=dataloader,
criterion=criterion,
Expand Down