Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 20 additions & 29 deletions RecommenderSystems/dlrm/models/dlrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,11 @@ def __init__(self, interaction_type='dot', interaction_itself=False, num_sparse_
offset = 1 if self.interaction_itself else 0
# indices = flow.tensor([i * 27 + j for i in range(27) for j in range(i + offset)])
# self.register_buffer("indices", indices)
li = flow.tensor([i for i in range(27) for j in range(i + offset)])
lj = flow.tensor([j for i in range(27) for j in range(i + offset)])
self.register_buffer("li", li)
self.register_buffer("lj", lj)
if interaction_type == 'dot':
li = flow.tensor([i for i in range(27) for j in range(i + offset)])
lj = flow.tensor([j for i in range(27) for j in range(i + offset)])
self.register_buffer("li", li)
self.register_buffer("lj", lj)

def forward(self, x:flow.Tensor, ly:flow.Tensor) -> flow.Tensor:
# x - dense fields, ly = embedding
Expand All @@ -71,7 +72,10 @@ def forward(self, x:flow.Tensor, ly:flow.Tensor) -> flow.Tensor:
# perform a dot product
Z = flow.matmul(T, T, transpose_b=True)
Zflat = Z[:, self.li, self.lj]
R = flow.cat([x, Zflat], dim=1)
R = flow.cat([x, Zflat], dim=1)
elif self.interaction_type == 'fused':
(batch_size, d) = x.shape
R = flow._C.fused_interaction(x, flow.reshape(ly, (batch_size, -1, d)))
else:
assert 0, 'dot or cat'
return R
Expand All @@ -86,6 +90,12 @@ def output_feature_size(self, embedding_vec_size, dense_feature_size):
return dense_feature_size + sum(range(n_cols))
elif self.interaction_type == 'cat':
return embedding_vec_size * self.num_sparse_fields + dense_feature_size
elif self.interaction_type == 'fused':
assert embedding_vec_size == dense_feature_size, "Embedding vector size must equle to dense feature size"
n_cols = self.num_sparse_fields + 1
if self.interaction_itself:
n_cols += 1
return dense_feature_size + sum(range(n_cols)) + 1
else:
assert 0, 'dot or cat'

Expand Down Expand Up @@ -133,7 +143,7 @@ def forward(self, ids):
else:
return embeddings

class OneEmbedding(nn.OneEmbeddingLookup):
class OneEmbedding(nn.Module):
def __init__(self, vocab_size, embed_size, args):
assert args.column_size_array is not None
scales = np.sqrt(1 / np.array(args.column_size_array))
Expand All @@ -159,6 +169,7 @@ def __init__(self, vocab_size, embed_size, args):
"dtype": flow.float,
"name": "my_embedding",
"embedding_dim": embed_size,
"scale_factor": 1,
"cache" : cache_list,
"kv_store": {
"persistent_table": {
Expand All @@ -168,38 +179,18 @@ def __init__(self, vocab_size, embed_size, args):
},
"default_initializer": {"type": "normal", "mean": 0, "std": 1},
"columns": initializer_list,
"optimizer": {
"lr": {
"base_lr": 24,
"decay": {
"type": "polynomial",
"decay_batches": 27772,
"end_lr": 0.0,
"power": 2.0,
"cycle": False,
},
"warmup": {
"type": "linear",
"warmup_batches": 2750,
"start_multiplier": 0.0,
},
},
"type": "sgd",
"momentum": 0.0,
"betas": [0.9, 0.999],
"eps": 1e-8,
},
}
super(OneEmbedding, self).__init__(options)
super(OneEmbedding, self).__init__()
column_id = flow.tensor(range(26), dtype=flow.int32).reshape(1,26)
self.register_buffer("column_id", column_id)
self.one_embedding = nn.OneEmbeddingLookup(options)

def forward(self, ids):
bsz = ids.shape[0]
column_id = flow.ones((bsz, 1), dtype=flow.int32, sbp=ids.sbp, placement=ids.placement) * self.column_id
if (ids.is_consistent):
column_id = column_id.to_consistent(sbp=ids.sbp, placement=ids.placement)
return super(OneEmbedding, self._origin).forward(ids, column_id)
return self.one_embedding.forward(ids, column_id)
def set_model_parallel(self, placement=None):
pass

Expand Down
2 changes: 1 addition & 1 deletion RecommenderSystems/dlrm/train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import oneflow as flow
import os
import sys
import pickle
Expand All @@ -7,7 +8,6 @@
)
import numpy as np
from sklearn.metrics import roc_auc_score
import oneflow as flow
from config import get_args
from models.data import make_data_loader
from models.dlrm import make_dlrm_module
Expand Down