Merged
19 changes: 17 additions & 2 deletions NLP/CPT/README.md
@@ -14,7 +14,22 @@ We implemented some extra functions here:

<!-- - `LayerNorm` in `models/dev_ops.py` -->
- `tensor_unique` in `models/bart_utils.py`
- `position_scores` in `models/bert_utils.py`

## Requirements

This project uses the nightly build of OneFlow. You can install it with one of the following commands.
CPU:
```bash
python3 -m pip install -f https://staging.oneflow.info/branch/master/cpu --pre oneflow
```
GPU:
```bash
python3 -m pip install -f https://staging.oneflow.info/branch/master/cu112 --pre oneflow
```
You can install other dependencies using the following command.
```bash
pip install -r requirements.txt
```
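
After installing, a quick sanity check (a minimal sketch, assuming OneFlow's PyTorch-style `oneflow.cuda` namespace) confirms that the wheel imports and whether it was built with CUDA support:

```python
# Minimal install check: import the nightly wheel and report CUDA availability.
import oneflow as flow

print(flow.__version__)          # nightly versions carry a dev/pre-release suffix
print(flow.cuda.is_available())  # False for the CPU-only wheel
```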

## Train

@@ -41,7 +56,7 @@ or the pretrained model `cpt-large`, which contains 24 encoder layers and 4 decoder layers

```bash
wget https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/nlp/CPT/cpt-large.tar.gz
tar -xzf cpt-base.tar.gz
tar -xzf cpt-large.tar.gz
```

Finally, running `bash train.sh` trains the model. Remember to change the `CPT_PRETRAIN_DIR` parameter so that the pretrained parameters are loaded correctly.
18 changes: 11 additions & 7 deletions NLP/CPT/classifier_flow.py
@@ -1,5 +1,6 @@
import json
import os

import oneflow as flow
import oneflow.nn as nn

@@ -10,16 +11,19 @@ class ClueAFQMCCPT(nn.Module):
def __init__(self, pretrain_dir, num_labels, is_train):
super(ClueAFQMCCPT, self).__init__()
kwargs_path = os.path.join(pretrain_dir, "parameters.json")
with open(kwargs_path, "r") as f:
kwargs = json.load(f)
model = CPT(**kwargs)
if is_train == True:
model.load_state_dict(flow.load(os.path.join(pretrain_dir, "weights")))
self.cpt = model
self.classifier = nn.Linear(model.d_model, num_labels)
self.cpt = self.load_model(pretrain_dir, kwargs_path, is_train)
self.classifier = nn.Linear(self.cpt.d_model, num_labels)

def forward(self, inputs, masks):
outputs = self.cpt(inputs, masks)
outputs = outputs[0][:, 0, :]
outputs = self.classifier(outputs)
return outputs

def load_model(self, pretrain_dir, kwargs_path, is_train):
with open(kwargs_path, "r") as f:
kwargs = json.load(f)
model = CPT(**kwargs)
if is_train == True:
model.load_state_dict(flow.load(os.path.join(pretrain_dir, "weights")))
return model
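
The refactor above only moves the loading logic into `load_model`; behavior is unchanged. A minimal usage sketch (assuming a local pretrained directory such as `cpt-base` containing `parameters.json` and a `weights` folder; the shapes are dummy values):

```python
import oneflow as flow

from classifier_flow import ClueAFQMCCPT

# Build the classifier; is_train=True restores the pretrained CPT weights.
model = ClueAFQMCCPT(pretrain_dir="cpt-base", num_labels=2, is_train=True)
model.eval()

input_ids = flow.ones(1, 128, dtype=flow.int64)  # dummy token ids
masks = flow.ones(1, 128, dtype=flow.int64)      # dummy attention mask
logits = model(input_ids, masks)                 # shape: (1, num_labels)
```
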
2 changes: 1 addition & 1 deletion NLP/CPT/infer_flow.py
@@ -1,9 +1,9 @@
import argparse

import oneflow as flow
from transformers import BertTokenizer

from classifier_flow import ClueAFQMCCPT
from tokenizer.tokenization_bert import BertTokenizer


def inference_afqmc(args):
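
This change removes the runtime dependency on `transformers` by importing the vendored tokenizer instead. A minimal sketch of the replacement (assuming the vendored copy keeps the standard `from_pretrained`, `tokenize`, and `convert_tokens_to_ids` interface, and that the model directory contains a `vocab.txt`):

```python
from tokenizer.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("cpt-base")  # reads cpt-base/vocab.txt
tokens = tokenizer.tokenize("今天天气怎么样")            # WordPiece tokens
ids = tokenizer.convert_tokens_to_ids(tokens)          # vocabulary ids
```
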
51 changes: 20 additions & 31 deletions NLP/CPT/models/CPT.py
@@ -1,11 +1,10 @@
from logging import log
import math
import random
from typing import Optional, Tuple

import oneflow as flow
import oneflow.nn as nn
from oneflow.nn import CrossEntropyLoss, MSELoss
from oneflow.nn import CrossEntropyLoss

from .bert import Bert
from .bart_utils import (
@@ -15,20 +14,7 @@
init_weights,
tensor_unique, # for tensor.unique
)

ACT2FN = {
"relu": flow._C.relu,
# "silu": silu,
# "swish": silu,
"gelu": flow._C.gelu,
"tanh": flow.tanh,
# "gelu_new": gelu_new,
# "gelu_fast": gelu_fast,
# "quick_gelu": quick_gelu,
# "mish": mish,
# "linear": linear_act,
"sigmoid": flow.sigmoid,
}
from .utils import ACT2FN


class BartLearnedPositionalEmbedding(nn.Embedding):
@@ -55,7 +41,7 @@ def forward(self, input_ids_shape: flow.Size, past_key_values_length: int = 0):


class BartAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
"""Multi-headed attention from 'Attention Is All You Need' paper.See https://doi.org/10.48550/arXiv.1706.03762 """

def __init__(
self,
@@ -132,10 +118,12 @@ def forward(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)

proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
proj_shape = [bsz * self.num_heads, -1, self.head_dim]
query_states = flow.reshape(
self._shape(query_states, tgt_len, bsz), shape=proj_shape
)
key_states = flow.reshape(key_states, shape=proj_shape)
value_states = flow.reshape(value_states, shape=proj_shape)

src_len = key_states.size(1)
attn_weights = flow.bmm(query_states, key_states.transpose(1, 2))
@@ -189,7 +177,7 @@ def forward(
# attn_probs = flow.F.dropout(attn_weights, p=prob)
# attn_output = flow.bmm(attn_probs, value_states)
if self.training:
attn_weights = flow._C.dropout(attn_weights, p=self.dropout)
attn_weights = flow.nn.functional.dropout(attn_weights, p=self.dropout)
attn_output = flow.bmm(attn_weights, value_states)

assert attn_output.size() == (
@@ -271,7 +259,7 @@ def forward(
output_attentions,
)
if self.training:
hidden_states = flow._C.dropout(hidden_states, p=self.dropout)
hidden_states = flow.nn.functional.dropout(hidden_states, p=self.dropout)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)

@@ -298,7 +286,9 @@ def forward(
output_attentions,
)
if self.training:
hidden_states = flow._C.dropout(hidden_states, p=self.dropout)
hidden_states = flow.nn.functional.dropout(
hidden_states, p=self.dropout
)
hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)

@@ -309,10 +299,12 @@
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
if self.training:
hidden_states = flow._C.dropout(hidden_states, p=self.activation_dropout)
hidden_states = flow.nn.functional.dropout(
hidden_states, p=self.activation_dropout
)
hidden_states = self.fc2(hidden_states)
if self.training:
hidden_states = flow._C.dropout(hidden_states, p=self.dropout)
hidden_states = flow.nn.functional.dropout(hidden_states, p=self.dropout)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)

@@ -472,7 +464,7 @@ def forward(
hidden_states = self.layernorm_embedding(hidden_states)

if self.training:
hidden_states = flow._C.dropout(hidden_states, p=self.dropout)
hidden_states = flow.nn.functional.dropout(hidden_states, p=self.dropout)

# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -1494,10 +1486,7 @@ def forward(
else:
raise NotImplementedError

# start_logits, end_logits = logits.split(1, dim=-1)
# oneflow does not support split.
split_half = logits.shape[-1] // 2
start_logits, end_logits = logits[:, :, :split_half], logits[:, :, split_half:]
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)

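Two patterns recur in this file: private `flow._C.*` calls are replaced with public `flow.nn.functional` equivalents, and the manual slicing workaround for question-answering start/end logits is reverted now that OneFlow supports `Tensor.split`. A minimal sketch of the restored split (dummy shapes, not the repo's real inputs):

```python
import oneflow as flow

logits = flow.randn(2, 7, 2)  # (batch, seq_len, 2): start/end scores per token
start_logits, end_logits = logits.split(1, dim=-1)  # two (2, 7, 1) tensors
start_logits = start_logits.squeeze(-1)  # (2, 7)
end_logits = end_logits.squeeze(-1)      # (2, 7)
```
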
40 changes: 13 additions & 27 deletions NLP/CPT/models/bert.py
@@ -1,31 +1,17 @@
# reference to transformers bert model
import oneflow as flow
import math
from typing import Tuple, Optional

import oneflow as flow
import oneflow.nn as nn

import math
from typing import Tuple, Optional
from .bert_utils import (
init_weights,
find_pruneable_heads_and_indices,
prune_linear_layer,
apply_chunking_to_forward,
position_scores, # replace einsum
)

ACT2FN = {
"relu": flow._C.relu,
# "silu": silu,
# "swish": silu,
"gelu": flow._C.gelu,
"tanh": flow.tanh,
# "gelu_new": gelu_new,
# "gelu_fast": gelu_fast,
# "quick_gelu": quick_gelu,
# "mish": mish,
# "linear": linear_act,
"sigmoid": flow.sigmoid,
}
from .utils import ACT2FN


class BertEmbeddings(nn.Module):
@@ -52,7 +38,7 @@ def __init__(
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = position_embedding_type
self.register_buffer(
"position_ids", flow.arange(max_position_embeddings).expand((1, -1))
"position_ids", flow.arange(max_position_embeddings).expand(1, -1)
)
self.register_buffer(
"token_type_ids",
@@ -224,17 +210,18 @@ def forward(
) # fp16 compatibility

if self.position_embedding_type == "relative_key":
relative_position_scores = position_scores(
query_layer, positional_embedding
relative_position_scores = flow.einsum(
"bhld,lrd->bhlr", query_layer, positional_embedding
)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = position_scores(
query_layer, positional_embedding
relative_position_scores_query = flow.einsum(
"bhld,lrd->bhlr", query_layer, positional_embedding
)
relative_position_scores_key = position_scores(
key_layer, positional_embedding
relative_position_scores_key = flow.einsum(
"bhld,lrd->bhlr", key_layer, positional_embedding
)

attention_scores = (
attention_scores
+ relative_position_scores_query
@@ -263,8 +250,7 @@ def forward(
new_context_layer_shape = tuple(context_layer.size()[:-2]) + (
self.all_head_size,
)
context_layer = context_layer.view(*new_context_layer_shape)

context_layer = flow.reshape(context_layer, shape=new_context_layer_shape)
outputs = (
(context_layer, attention_probs) if output_attentions else (context_layer,)
)
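
Here the hand-rolled `position_scores` workaround gives way to `flow.einsum`, which OneFlow now supports directly. A minimal sketch of the `bhld,lrd->bhlr` contraction used for relative-position attention scores (dummy sizes):

```python
import oneflow as flow

b, h, l, r, d = 2, 4, 8, 8, 16  # batch, heads, query len, key range, head dim
query_layer = flow.randn(b, h, l, d)
positional_embedding = flow.randn(l, r, d)

scores = flow.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
print(scores.shape)  # (2, 4, 8, 8)
```
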
17 changes: 0 additions & 17 deletions NLP/CPT/models/bert_utils.py
@@ -104,20 +104,3 @@ def apply_chunking_to_forward(
return flow.cat(output_chunks, dim=chunk_dim)

return forward_fn(*input_tensors)


def position_scores(layer, embed):
# replace flow.einsum when
# position_embedding_type == "relative_key" or "relative_key_query"

assert layer.dim() == 4
assert embed.dim() == 3
assert layer.shape[3] == embed.shape[2]
assert layer.shape[2] == embed.shape[0]
b, h, l, d = layer.shape
l, r, d = embed.shape

layer = layer.unsqueeze(-2)
embed = embed.transpose(-2, -1).unsqueeze(0).unsqueeze(0).expand(b, h, l, d, r)

return flow.matmul(layer, embed).squeeze(-2)
15 changes: 15 additions & 0 deletions NLP/CPT/models/utils.py
@@ -0,0 +1,15 @@
import oneflow as flow


def gelu_new(x):
gelu = flow.nn.GELU(approximate="tanh")
return gelu(x)


ACT2FN = {
"relu": flow.nn.functional.relu,
"gelu": flow.nn.functional.gelu,
"tanh": flow.nn.functional.tanh,
"gelu_new": gelu_new,
"sigmoid": flow.nn.functional.sigmoid,
}
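
This new module consolidates the two duplicated `ACT2FN` tables from `CPT.py` and `bert.py`, with lookups now going through public `flow.nn.functional` ops. A minimal usage sketch (the import path is assumed from this PR's layout):

```python
import oneflow as flow

from models.utils import ACT2FN

act = ACT2FN["gelu_new"]  # tanh-approximated GELU built on flow.nn.GELU
x = flow.randn(2, 3)
y = act(x)
print(y.shape)  # (2, 3)
```
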
7 changes: 7 additions & 0 deletions NLP/CPT/requirements.txt
@@ -0,0 +1,7 @@
filelock==3.7.1
huggingface_hub==0.8.1
numpy==1.21.2
packaging==21.0
requests==2.25.1
tqdm==4.62.1
transformers==4.21.1
1 change: 0 additions & 1 deletion NLP/CPT/tokenizer/__init__.py

This file was deleted.
