update ranking models: add BST v0.1
1985312383 committed Feb 26, 2024
1 parent b5e5db3 commit 4b9be91
Showing 5 changed files with 493 additions and 128 deletions.
2 changes: 2 additions & 0 deletions torch_rechub/basic/activation.py
@@ -45,6 +45,8 @@ def activation_layer(act_name):
            act_layer = nn.PReLU()
        elif act_name.lower() == "softmax":
            act_layer = nn.Softmax(dim=1)
        elif act_name.lower() == 'leakyrelu':
            act_layer = nn.LeakyReLU()
    elif issubclass(act_name, nn.Module):
        act_layer = act_name()
    else:
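For reference, the new branch simply maps the string "leakyrelu" to `nn.LeakyReLU()` with PyTorch's default `negative_slope=0.01`. A minimal calling sketch (the import path is taken from this diff):

```python
import torch
from torch_rechub.basic.activation import activation_layer

act = activation_layer("leakyrelu")      # returns nn.LeakyReLU(negative_slope=0.01)
print(act(torch.tensor([-1.0, 2.0])))    # tensor([-0.0100,  2.0000])
```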
3 changes: 2 additions & 1 deletion torch_rechub/models/ranking/__init__.py
@@ -5,4 +5,5 @@
from .dcn_v2 import DCNv2
from .edcn import EDCN
from .deepffm import DeepFFM, FatDeepFFM
from .fibinet import FiBiNet
from .bst import BST
77 changes: 77 additions & 0 deletions torch_rechub/models/ranking/bst.py
@@ -0,0 +1,77 @@
"""
Date: create on 26/02/2024, update on 30/04/2022
References:
paper: Behavior Sequence Transformer for E-commerce Recommendation in Alibaba
url: https://arxiv.org/pdf/1905.06874
code: https://github.com/jiwidi/Behavior-Sequence-Transformer-Pytorch/blob/master/pytorch_bst.ipynb
Authors: Tao Fan, [email protected]
"""

import torch
import torch.nn as nn

from ...basic.layers import EmbeddingLayer, MLP




class BST(nn.Module):
"""Behavior Sequence Transformer
features (list): the list of `Feature Class`. training by MLP. It means the user profile features and context features in origin paper, exclude history and target features.
history_features (list): the list of `Feature Class`,training by ActivationUnit. It means the user behaviour sequence features, eg.item id sequence, shop id sequence.
target_features (list): the list of `Feature Class`, training by ActivationUnit. It means the target feature which will execute target-attention with history feature.
mlp_params (dict): the params of the last MLP module, keys include:`{"dims":list, "activation":str, "dropout":float, "output_layer":bool`}
"""
    def __init__(self, features, history_features, target_features, mlp_params):
        super().__init__()
        self.features = features
        self.history_features = history_features
        self.target_features = target_features
        self.num_history_features = len(history_features)
        # self.positional_embedding = PositionalEmbedding(8, 9)
        self.all_dims = sum([fea.embed_dim for fea in features + history_features + target_features])

        self.embedding = EmbeddingLayer(features + history_features + target_features)
        # a single Transformer encoder layer shared by all history sequences;
        # d_model=64 is hard-coded, so every history feature must use embed_dim=64.
        # batch_first=True so the (batch, seq_len, emb_dim) input is attended over seq_len.
        self.attention_layers = nn.TransformerEncoderLayer(d_model=64, nhead=8, dropout=0.2, batch_first=True)
        # the feature classes define the embedding/input layer; mlp_params configures the following DNN
        self.mlp = MLP(self.all_dims, activation="leakyrelu", **mlp_params)


    def forward(self, x):
        embed_x_features = self.embedding(x, self.features)  # (batch_size, num_features, emb_dim)
        embed_x_history = self.embedding(x, self.history_features)  # (batch_size, num_history_features, seq_length, emb_dim)
        embed_x_target = self.embedding(x, self.target_features)  # (batch_size, num_target_features, emb_dim)
        # positional_embedding = self.positional_embedding(torch.cat([embed_x_history, embed_x_target], dim=2))
        # embed_x_history = torch.cat((embed_x_history, positional_embedding), dim=2)
        attention_pooling = []
        for i in range(self.num_history_features):
            attention_seq = self.attention_layers(embed_x_history[:, i, :, :])  # (batch_size, seq_length, emb_dim)
            attention_pooling.append(attention_seq)
        # mean-pool each encoded sequence over the time dimension
        attention_pooling = torch.stack(attention_pooling, dim=1).mean(dim=2)  # (batch_size, num_history_features, emb_dim)
        mlp_in = torch.cat([
            attention_pooling.flatten(start_dim=1),
            embed_x_target.flatten(start_dim=1),
            embed_x_features.flatten(start_dim=1)
        ], dim=1)  # (batch_size, N)
        y = self.mlp(mlp_in)
        return torch.sigmoid(y.squeeze(1))
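For orientation, a minimal instantiation sketch follows. The SparseFeature/SequenceFeature classes, names, and vocabulary sizes are illustrative assumptions borrowed from the existing DIN-style usage rather than part of this commit; embed_dim is 64 for every feature because the encoder layer's d_model is hard-coded to 64, and mlp_params must not contain "activation" since the model already fixes it to "leakyrelu".

```python
from torch_rechub.basic.features import SparseFeature, SequenceFeature
from torch_rechub.models.ranking import BST

# illustrative vocabulary sizes; embed_dim=64 everywhere to match the hard-coded d_model
n_users, n_items, n_cates = 1000, 5000, 100
features = [SparseFeature("user_id", vocab_size=n_users + 1, embed_dim=64)]
target_features = [
    SparseFeature("target_item", vocab_size=n_items + 1, embed_dim=64),
    SparseFeature("target_cate", vocab_size=n_cates + 1, embed_dim=64),
]
history_features = [
    SequenceFeature("history_item", vocab_size=n_items + 1, embed_dim=64, pooling="concat", shared_with="target_item"),
    SequenceFeature("history_cate", vocab_size=n_cates + 1, embed_dim=64, pooling="concat", shared_with="target_cate"),
]
model = BST(features, history_features, target_features, mlp_params={"dims": [256, 128], "dropout": 0.2})
```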

class PositionalEmbedding(nn.Module):
    """Learned positional embedding.

    Unlike the fixed sinusoidal encoding of "Attention Is All You Need", this uses a
    trainable lookup table of shape (max_len, d_model). It is currently unused: the
    corresponding calls in BST.__init__ and BST.forward are commented out.
    """

    def __init__(self, max_len, d_model):
        super().__init__()
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        batch_size = x.size(0)
        # repeat the full position table for every batch element: (batch_size, max_len, d_model)
        return self.pe.weight.unsqueeze(0).repeat(batch_size, 1, 1)
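A small shape check of the class above, with illustrative sizes, showing how the learned table would typically be added to the encoder input:

```python
import torch

pe = PositionalEmbedding(max_len=50, d_model=64)
x = torch.zeros(32, 50, 64)   # (batch_size, seq_len, emb_dim)
pos = pe(x)                   # full position table repeated per batch element
print(pos.shape)              # torch.Size([32, 50, 64])
x = x + pos                   # typical additive use before a Transformer encoder
```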

71 changes: 71 additions & 0 deletions torch_rechub/utils/data.py
@@ -274,3 +274,74 @@ def array_replace_with_dict(array, dic):
    # Get argsort indices
    idx = k.argsort()
    return v[idx[np.searchsorted(k, array, sorter=idx)]]
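The unchanged context above is the tail of `array_replace_with_dict`, which remaps every element of an array through a dict by binary-searching the sorted keys. A self-contained sketch of the same trick; the `k`/`v` preamble is assumed, since it falls outside the shown hunk:

```python
import numpy as np

def array_replace_with_dict(array, dic):
    # assumed preamble: split the dict into parallel key/value arrays
    k = np.array(list(dic.keys()))
    v = np.array(list(dic.values()))
    # Get argsort indices
    idx = k.argsort()
    return v[idx[np.searchsorted(k, array, sorter=idx)]]

arr = np.array([3, 1, 2, 3])
print(array_replace_with_dict(arr, {1: 10, 2: 20, 3: 30}))  # [30 10 20 30]
```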


# Temporarily reserved for testing purposes([email protected])
def create_seq_features(data, seq_feature_col=['item_id', 'cate_id'], max_len=50, drop_short=3, shuffle=True):
    """Build each user's history sequence ordered by time.

    Args:
        data (pd.DataFrame): must contain the columns `user_id, item_id, cate_id, time`.
        seq_feature_col (list): the columns to build sequence features from, grouped by user_id (currently the item_id/cate_id sequences are always generated, so this argument is not used in the body).
        max_len (int): the max length of a user history sequence.
        drop_short (int): drop inactive users whose sequence length is shorter than drop_short.
        shuffle (bool): shuffle the generated samples if True.

    Returns:
        train (pd.DataFrame): the target item is each item before the last two items.
        val (pd.DataFrame): the target item is the second to last item of the user's history sequence.
        test (pd.DataFrame): the target item is the last item of the user's history sequence.
    """
    # label-encode every column; ids start from 1 so that 0 can be used as the padding symbol
    for feat in data:
        le = LabelEncoder()
        data[feat] = le.fit_transform(data[feat])
        data[feat] = data[feat].apply(lambda x: x + 1)
    data = data.astype('int32')

    n_items = data["item_id"].max()

    item_cate_map = data[['item_id', 'cate_id']]
    item2cate_dict = item_cate_map.set_index(['item_id'])['cate_id'].to_dict()

    data = data.sort_values(['user_id', 'time']).groupby('user_id').agg(click_hist_list=('item_id', list), cate_hist_list=('cate_id', list)).reset_index()

    # slide over each user's history to build (history, target) samples, one positive and one negative per step
    train_data, val_data, test_data = [], [], []
    for item in data.itertuples():
        if len(item[2]) < drop_short:
            continue
        user_id = item[1]
        click_hist_list = item[2][:max_len]
        cate_hist_list = item[3][:max_len]

        neg_list = [neg_sample(click_hist_list, n_items) for _ in range(len(click_hist_list))]
        hist_list = []
        cate_list = []
        for i in range(1, len(click_hist_list)):
            hist_list.append(click_hist_list[i - 1])
            cate_list.append(cate_hist_list[i - 1])
            hist_list_pad = hist_list + [0] * (max_len - len(hist_list))
            cate_list_pad = cate_list + [0] * (max_len - len(cate_list))
            if i == len(click_hist_list) - 1:
                test_data.append([user_id, hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
                test_data.append([user_id, hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])
            if i == len(click_hist_list) - 2:
                val_data.append([user_id, hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
                val_data.append([user_id, hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])
            else:
                train_data.append([user_id, hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
                train_data.append([user_id, hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])

    # shuffle
    if shuffle:
        random.shuffle(train_data)
        random.shuffle(val_data)
        random.shuffle(test_data)

    col_name = ['user_id', 'history_item', 'history_cate', 'target_item', 'target_cate', 'label']
    train = pd.DataFrame(train_data, columns=col_name)
    val = pd.DataFrame(val_data, columns=col_name)
    test = pd.DataFrame(test_data, columns=col_name)

    return train, val, test
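A minimal usage sketch for the new helper on a toy interaction log. The DataFrame contents are illustrative, and it assumes the `neg_sample` helper referenced inside the function is available in `torch_rechub.utils.data`:

```python
import pandas as pd
from torch_rechub.utils.data import create_seq_features

# toy log with the required columns: user_id, item_id, cate_id, time
df = pd.DataFrame({
    "user_id": [1, 1, 1, 1, 2, 2, 2],
    "item_id": [10, 11, 12, 13, 10, 12, 14],
    "cate_id": [1, 1, 2, 2, 1, 2, 3],
    "time":    [1, 2, 3, 4, 1, 2, 3],
})

train, val, test = create_seq_features(df, max_len=50, drop_short=3)
print(train.columns.tolist())
# ['user_id', 'history_item', 'history_cate', 'target_item', 'target_cate', 'label']
```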

