Skip to content

Commit

Permalink
gru4rec
Browse files Browse the repository at this point in the history
  • Loading branch information
HaSai666 committed Jun 6, 2022
1 parent 3df58fc commit ef31153
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,4 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
133 changes: 133 additions & 0 deletions examples/matching/run_ml_gru4rec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import sys

sys.path.append("../..")

import os
import numpy as np
import pandas as pd
import torch

from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from torch_rechub.models.matching import GRU4Rec
from torch_rechub.trainers import MatchTrainer
from torch_rechub.basic.features import DenseFeature, SparseFeature, SequenceFeature
from torch_rechub.utils.match import generate_seq_feature_match, gen_model_input
from torch_rechub.utils.data import df_to_dict, MatchDataGenerator
from movielens_utils import match_evaluation


def get_movielens_data(data_path, load_cache=False):
data = pd.read_csv(data_path)
data["cate_id"] = data["genres"].apply(lambda x: x.split("|")[0])
sparse_features = ['user_id', 'movie_id', 'gender', 'age', 'occupation', 'zip', "cate_id"]
user_col, item_col, label_col = "user_id", "movie_id", "label"

feature_max_idx = {}
for feature in sparse_features:
lbe = LabelEncoder()
data[feature] = lbe.fit_transform(data[feature]) + 1
feature_max_idx[feature] = data[feature].max() + 1
if feature == user_col:
user_map = {encode_id + 1: raw_id for encode_id, raw_id in enumerate(lbe.classes_)} #encode user id: raw user id
if feature == item_col:
item_map = {encode_id + 1: raw_id for encode_id, raw_id in enumerate(lbe.classes_)} #encode item id: raw item id
np.save("./data/ml-1m/saved/raw_id_maps.npy", (user_map, item_map))

user_profile = data[["user_id", "gender", "age", "occupation", "zip"]].drop_duplicates('user_id')
item_profile = data[["movie_id", "cate_id"]].drop_duplicates('movie_id')

if load_cache: #if you have run this script before and saved the preprocessed data
x_train, y_train, x_test = np.load("./data/ml-1m/saved/data_cache.npy", allow_pickle=True)
else:
#Note: mode=2 means list-wise negative sample generate, saved in last col "neg_items"
df_train, df_test = generate_seq_feature_match(data,
user_col,
item_col,
time_col="timestamp",
item_attribute_cols=[],
sample_method=1,
mode=2,
neg_ratio=3,
min_item=0)
x_train = gen_model_input(df_train, user_profile, user_col, item_profile, item_col, seq_max_len=50, padding='post', truncating='post')
y_train = np.array([0] * df_train.shape[0]) #label=0 means the first pred value is positiva sample
x_test = gen_model_input(df_test, user_profile, user_col, item_profile, item_col, seq_max_len=50, padding='post', truncating='post')
np.save("./data/ml-1m/saved/data_cache.npy", (x_train, y_train, x_test))

user_cols = ['user_id', 'gender', 'age', 'occupation', 'zip']

user_features = [SparseFeature(name, vocab_size=feature_max_idx[name], embed_dim=16) for name in user_cols]
history_features = [
SequenceFeature("hist_movie_id",
vocab_size=feature_max_idx["movie_id"],
embed_dim=16,
pooling="concat",
shared_with="movie_id")
]

item_features = [SparseFeature('movie_id', vocab_size=feature_max_idx['movie_id'], embed_dim=16)]
neg_item_feature = [
SequenceFeature('neg_items',
vocab_size=feature_max_idx['movie_id'],
embed_dim=16,
pooling="concat",
shared_with="movie_id")
]

all_item = df_to_dict(item_profile)
test_user = x_test
return user_features, history_features, item_features, neg_item_feature, x_train, y_train, all_item, test_user


def main(dataset_path, model_name, epoch, learning_rate, batch_size, weight_decay, device, save_dir, seed):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
torch.manual_seed(seed)
user_features, history_features, item_features, neg_item_feature, x_train, y_train, all_item, test_user = get_movielens_data(dataset_path)
dg = MatchDataGenerator(x=x_train, y=y_train)

model = GRU4Rec(user_features, history_features, item_features, neg_item_feature, user_params={"dims": [128, 64, 16]}, temperature=0.02)

#mode=1 means pair-wise learning
trainer = MatchTrainer(model,
mode=2,
optimizer_params={
"lr": learning_rate,
"weight_decay": weight_decay
},
n_epoch=epoch,
device=device,
model_path=save_dir,
gpus=[0])

train_dl, test_dl, item_dl = dg.generate_dataloader(test_user, all_item, batch_size=batch_size, num_workers=0)
trainer.fit(train_dl)

print("inference embedding")
user_embedding = trainer.inference_embedding(model=model, mode="user", data_loader=test_dl, model_path=save_dir)
item_embedding = trainer.inference_embedding(model=model, mode="item", data_loader=item_dl, model_path=save_dir)
print(user_embedding.shape, item_embedding.shape)
#torch.save(user_embedding.data.cpu(), save_dir + "user_embedding.pth")
#torch.save(item_embedding.data.cpu(), save_dir + "item_embedding.pth")
match_evaluation(user_embedding, item_embedding, test_user, all_item, topk=10)


if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_path', default="./data/ml-1m/ml-1m_sample.csv")
parser.add_argument('--model_name', default='gru4rec')
parser.add_argument('--epoch', type=int, default=100) #10
parser.add_argument('--learning_rate', type=float, default=1e-3)
parser.add_argument('--batch_size', type=int, default=256) #4096
parser.add_argument('--weight_decay', type=float, default=1e-6)
parser.add_argument('--device', default='cuda:0') #cuda:0
parser.add_argument('--save_dir', default='./data/ml-1m/saved/')
parser.add_argument('--seed', type=int, default=2022)

args = parser.parse_args()
main(args.dataset_path, args.model_name, args.epoch, args.learning_rate, args.batch_size, args.weight_decay, args.device,
args.save_dir, args.seed)
"""
python run_ml_youtube_dnn.py
"""
3 changes: 2 additions & 1 deletion torch_rechub/models/matching/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .dssm import DSSM
from .youtube_dnn import YoutubeDNN
from .youtube_sbc import YoutubeSBC
from .dssm_facebook import FaceBookDSSM
from .dssm_facebook import FaceBookDSSM
from .gru4rec import GRU4Rec
89 changes: 89 additions & 0 deletions torch_rechub/models/matching/gru4rec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
Date: create on 03/06/2022
References:
paper: SESSION-BASED RECOMMENDATIONS WITH RECURRENT NEURAL NETWORKS
url: http://arxiv.org/abs/1511.06939
Authors: Kai Wang, [email protected]
"""

import torch

from ...basic.layers import MLP, EmbeddingLayer
from torch import nn


class GRU4Rec(torch.nn.Module):
"""The match model mentioned in `Deep Neural Networks for YouTube Recommendations` paper.
It's a DSSM match model trained by global softmax loss on list-wise samples.
Note in origin paper, it's without item dnn tower and train item embedding directly.
Args:
user_features (list[Feature Class]): training by the user tower module.
history_features (list[Feature Class]): training history
item_features (list[Feature Class]): training by the embedding table, it's the item id feature.
neg_item_feature (list[Feature Class]): training by the embedding table, it's the negative items id feature.
user_params (dict): the params of the User Tower module, keys include:`{"dims":list, "activation":str, "dropout":float, "output_layer":bool`}.
sim_func (str): similarity function, includes `["cosine", "dot"]`, default to "cosine".
temperature (float): temperature factor for similarity score, default to 1.0.
"""

def __init__(self, user_features, history_features, item_features, neg_item_feature, user_params, sim_func="cosine", temperature=1.0):
super().__init__()
self.user_features = user_features
self.item_features = item_features
self.history_features = history_features
self.neg_item_feature = neg_item_feature
self.sim_func = sim_func
self.temperature = temperature
self.user_dims = sum([fea.embed_dim for fea in user_features+history_features])

self.embedding = EmbeddingLayer(user_features + item_features + history_features)
self.gru = nn.GRU(input_size = history_features[0].embed_dim,
hidden_size = history_features[0].embed_dim,
num_layers = user_params.get('num_layers',2),
batch_first = True,
bias = False)
self.user_mlp = MLP(self.user_dims, output_layer=False, **user_params)
self.mode = None

def forward(self, x):
user_embedding = self.user_tower(x)
item_embedding = self.item_tower(x)
if self.mode == "user":
return user_embedding
if self.mode == "item":
return item_embedding
if self.sim_func == "cosine":
y = torch.cosine_similarity(user_embedding, item_embedding, dim=-1) #[batch_size, 1+n_neg_items, embed_dim]
elif self.sim_func == "dot":
y = torch.mul(user_embedding, item_embedding).sum(dim=1)
else:
raise ValueError("similarity function only support %s, but got %s" % (["cosine", "dot"], self.sim_func))
y = y / self.temperature
return y

def user_tower(self, x):
if self.mode == "item":
return None
input_user = self.embedding(x, self.user_features, squeeze_dim=True) #[batch_size, num_features*deep_dims]

history_emb = self.embedding(x, self.history_features).squeeze(1)
_, history_emb = self.gru(history_emb)
history_emb = history_emb[-1]

input_user = torch.cat([input_user,history_emb],dim=-1)

user_embedding = self.user_mlp(input_user).unsqueeze(1) #[batch_size, 1, embed_dim]
if self.mode == "user":
return user_embedding.squeeze(1) #inference embedding mode -> [batch_size, embed_dim]
return user_embedding

def item_tower(self, x):
if self.mode == "user":
return None
pos_embedding = self.embedding(x, self.item_features, squeeze_dim=False) #[batch_size, 1, embed_dim]
if self.mode == "item": #inference embedding mode
return pos_embedding.squeeze(1) #[batch_size, embed_dim]
neg_embeddings = self.embedding(x, self.neg_item_feature,
squeeze_dim=False).squeeze(1) #[batch_size, n_neg_items, embed_dim]
return torch.cat((pos_embedding, neg_embeddings), dim=1) #[batch_size, 1+n_neg_items, embed_dim]
4 changes: 2 additions & 2 deletions torch_rechub/utils/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from .data import pad_sequences, df_to_dict


def gen_model_input(df, user_profile, user_col, item_profile, item_col, seq_max_len):
def gen_model_input(df, user_profile, user_col, item_profile, item_col, seq_max_len, padding='pre', truncating='pre'):
#merge user_profile and item_profile, pad history seuence feature
df = pd.merge(df, user_profile, on=user_col, how='left') # how=left to keep samples order same as the input
df = pd.merge(df, item_profile, on=item_col, how='left')
for col in df.columns.to_list():
if col.startswith("hist_"):
df[col] = pad_sequences(df[col], maxlen=seq_max_len, value=0).tolist()
df[col] = pad_sequences(df[col], maxlen=seq_max_len, value=0, padding=padding, truncating=truncating).tolist()
input_dict = df_to_dict(df)
return input_dict

Expand Down

0 comments on commit ef31153

Please sign in to comment.