update ranking model BST version 0.1
1 parent b5e5db3, commit 4b9be91
Showing 5 changed files with 493 additions and 128 deletions.
@@ -0,0 +1,77 @@
""" | ||
Date: create on 26/02/2024, update on 30/04/2022 | ||
References: | ||
paper: Behavior Sequence Transformer for E-commerce Recommendation in Alibaba | ||
url: https://arxiv.org/pdf/1905.06874 | ||
code: https://github.com/jiwidi/Behavior-Sequence-Transformer-Pytorch/blob/master/pytorch_bst.ipynb | ||
Authors: Tao Fan, [email protected] | ||
""" | ||
|
||
import torch | ||
# import torch.utils.data as data | ||
# from torchvision import transforms | ||
# import ast | ||
# from torch.nn.utils.rnn import pad_sequence | ||
import torch.nn as nn | ||
from ...basic.layers import EmbeddingLayer, MLP | ||
|
||
|
||
|
||
|
||
class BST(nn.Module):
    """Behavior Sequence Transformer

    Args:
        features (list): the list of `Feature Class`, fed directly into the final MLP. These are the user profile and context features of the original paper, excluding history and target features.
        history_features (list): the list of `Feature Class` encoded by the Transformer layer. These are the user behaviour sequence features, e.g. item id sequence, shop id sequence.
        target_features (list): the list of `Feature Class` for the target item, whose embeddings are concatenated with the pooled history representation before the MLP.
        mlp_params (dict): the params of the final MLP module, keys include: `{"dims": list, "activation": str, "dropout": float, "output_layer": bool}`.
    """
    def __init__(self, features, history_features, target_features, mlp_params):
        super().__init__()
        self.features = features
        self.history_features = history_features
        self.target_features = target_features
        self.num_history_features = len(history_features)
        # self.positional_embedding = PositionalEmbedding(8, 9)
        self.all_dims = sum([fea.embed_dim for fea in features + history_features + target_features])
        self.embedding = EmbeddingLayer(features + history_features + target_features)
        # A single Transformer encoder layer encodes each behaviour sequence.
        # d_model=64 and nhead=8 are hard-coded, so the history features' embed_dim must be 64;
        # batch_first=True matches the (batch_size, seq_length, emb_dim) layout used in forward().
        self.attention_layers = nn.TransformerEncoderLayer(64, 8, dropout=0.2, batch_first=True)
        # The embedding layer builds the model input from the feature classes; mlp_params defines the structure of the following DNN.
        self.mlp = MLP(self.all_dims, activation="leakyrelu", **mlp_params)

    def forward(self, x):
        embed_x_features = self.embedding(x, self.features)  # (batch_size, num_features, emb_dim)
        embed_x_history = self.embedding(x, self.history_features)  # (batch_size, num_history_features, seq_length, emb_dim)
        embed_x_target = self.embedding(x, self.target_features)  # (batch_size, num_target_features, emb_dim)
        # positional_embedding = self.positional_embedding(torch.cat([embed_x_history, embed_x_target], dim=2))
        # embed_x_history = torch.cat((embed_x_history, positional_embedding), dim=2)
        # Encode each behaviour sequence with the Transformer layer, then mean-pool over sequence positions.
        attention_pooling = []
        for i in range(self.num_history_features):
            attention_seq = self.attention_layers(embed_x_history[:, i, :, :])
            attention_pooling.append(attention_seq)  # (batch_size, seq_length, emb_dim)
        attention_pooling = torch.stack(attention_pooling, dim=1).mean(dim=2)  # (batch_size, num_history_features, emb_dim)
        mlp_in = torch.cat([
            attention_pooling.flatten(start_dim=1),
            embed_x_target.flatten(start_dim=1),
            embed_x_features.flatten(start_dim=1)
        ], dim=1)  # (batch_size, N)
        y = self.mlp(mlp_in)
        return torch.sigmoid(y.squeeze(1))

class PositionalEmbedding(nn.Module):
    """Learned positional embedding.

    Unlike the fixed sinusoidal encoding of "Attention Is All You Need",
    each position index is mapped to a trainable embedding vector.
    """

    def __init__(self, max_len, d_model):
        super().__init__()
        # One trainable embedding per position, up to max_len positions.
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        batch_size = x.size(0)
        # (max_len, d_model) -> (batch_size, max_len, d_model)
        return self.pe.weight.unsqueeze(0).repeat(batch_size, 1, 1)
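For reference, the following is a minimal standalone sketch of the sequence-encoding step in BST.forward: each behaviour sequence is run through the Transformer encoder layer, mean-pooled over positions, and concatenated with the target and context embeddings to form the MLP input. The tensors are random stand-ins for the repo's EmbeddingLayer output, and the shapes (batch size 4, two history features, sequence length 20, embedding size 64) are illustrative assumptions only.

import torch
import torch.nn as nn

# Illustrative shapes only; emb_dim must match the hard-coded d_model=64 in BST.__init__.
batch_size, num_history_features, seq_length, emb_dim = 4, 2, 20, 64
num_target_features, num_features = 1, 3

encoder = nn.TransformerEncoderLayer(emb_dim, 8, dropout=0.2, batch_first=True)

# Random stand-ins for the EmbeddingLayer outputs used in forward().
embed_x_history = torch.randn(batch_size, num_history_features, seq_length, emb_dim)
embed_x_target = torch.randn(batch_size, num_target_features, emb_dim)
embed_x_features = torch.randn(batch_size, num_features, emb_dim)

# Encode each behaviour sequence, then mean-pool over sequence positions.
pooled = torch.stack([encoder(embed_x_history[:, i, :, :]) for i in range(num_history_features)], dim=1).mean(dim=2)

# Flatten and concatenate, exactly as the MLP input is built in forward().
mlp_in = torch.cat([pooled.flatten(start_dim=1), embed_x_target.flatten(start_dim=1), embed_x_features.flatten(start_dim=1)], dim=1)
print(mlp_in.shape)  # torch.Size([4, 384]), i.e. (2 + 1 + 3) * 64 per sample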
@@ -274,3 +274,74 @@ def array_replace_with_dict(array, dic):
    # Get argsort indices
    idx = k.argsort()
    return v[idx[np.searchsorted(k, array, sorter=idx)]]

# Temporarily reserved for testing purposes ([email protected])
def create_seq_features(data, seq_feature_col=['item_id', 'cate_id'], max_len=50, drop_short=3, shuffle=True):
    """Build each user's behaviour sequence ordered by time.

    Args:
        data (pd.DataFrame): must contain the columns `user_id, item_id, cate_id, time`.
        seq_feature_col (list): the columns to build sequence features from; sequences are grouped by user_id.
        max_len (int): the max length of a user's history sequence.
        drop_short (int): drop inactive users whose sequence length is shorter than drop_short.
        shuffle (bool): shuffle the generated samples if True.

    Returns:
        train (pd.DataFrame): the target item is each item before the last two items.
        val (pd.DataFrame): the target item is the second-to-last item of the user's history sequence.
        test (pd.DataFrame): the target item is the last item of the user's history sequence.
    """
    for feat in data:
        le = LabelEncoder()
        data[feat] = le.fit_transform(data[feat])
        data[feat] = data[feat].apply(lambda x: x + 1)  # 0 is reserved as the padding index
    data = data.astype('int32')

    n_items = data["item_id"].max()

    item_cate_map = data[['item_id', 'cate_id']]
    item2cate_dict = item_cate_map.set_index(['item_id'])['cate_id'].to_dict()

    data = data.sort_values(['user_id', 'time']).groupby('user_id').agg(click_hist_list=('item_id', list), cate_hist_list=('cate_id', list)).reset_index()

    # Slide a window over each user's history to build positive and negative samples
    train_data, val_data, test_data = [], [], []
    for item in data.itertuples():
        if len(item[2]) < drop_short:
            continue
        user_id = item[1]
        click_hist_list = item[2][:max_len]
        cate_hist_list = item[3][:max_len]

        neg_list = [neg_sample(click_hist_list, n_items) for _ in range(len(click_hist_list))]
        hist_list = []
        cate_list = []
        for i in range(1, len(click_hist_list)):
            hist_list.append(click_hist_list[i - 1])
            cate_list.append(cate_hist_list[i - 1])
            hist_list_pad = hist_list + [0] * (max_len - len(hist_list))
            cate_list_pad = cate_list + [0] * (max_len - len(cate_list))
            if i == len(click_hist_list) - 1:
                test_data.append([user_id, hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
                test_data.append([user_id, hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])
            elif i == len(click_hist_list) - 2:
                val_data.append([user_id, hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
                val_data.append([user_id, hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])
            else:
                train_data.append([user_id, hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
                train_data.append([user_id, hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])

    # shuffle
    if shuffle:
        random.shuffle(train_data)
        random.shuffle(val_data)
        random.shuffle(test_data)

    col_name = ['user_id', 'history_item', 'history_cate', 'target_item', 'target_cate', 'label']
    train = pd.DataFrame(train_data, columns=col_name)
    val = pd.DataFrame(val_data, columns=col_name)
    test = pd.DataFrame(test_data, columns=col_name)

    return train, val, test
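To make the split behaviour concrete, here is a toy sketch, independent of the function above and with padding and negative sampling omitted, of how the sliding window assigns each prefix of a user's click history to train, validation, or test.

# A user with five clicks; each prefix predicts the next item.
clicks = ["A", "B", "C", "D", "E"]

train, val, test = [], [], []
for i in range(1, len(clicks)):
    history, target = clicks[:i], clicks[i]
    if i == len(clicks) - 1:      # last item -> test
        test.append((history, target))
    elif i == len(clicks) - 2:    # second-to-last item -> validation
        val.append((history, target))
    else:                         # everything earlier -> train
        train.append((history, target))

print(train)  # [(['A'], 'B'), (['A', 'B'], 'C')]
print(val)    # [(['A', 'B', 'C'], 'D')]
print(test)   # [(['A', 'B', 'C', 'D'], 'E')]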