fix bugs
update setup.py
update comments in EmbeddingLayer
morningsky committed Jun 11, 2022
1 parent b8a8974 commit 3ca917b
Showing 3 changed files with 33 additions and 21 deletions.
11 changes: 5 additions & 6 deletions README.md
@@ -18,14 +18,13 @@
 ## Installation
 
 ```python
-#Latest version (currently the recommended installation method)
+#Stable release
+pip install torch-rechub
+
+#Latest version
 1. git clone https://github.com/datawhalechina/torch-rechub.git
 2. cd torch-rechub
-3. python setup.py install
+3. pip install -e . --verbose
 
-#Stable release (not updated yet)
-#pip install torch-rechub
 ```
 
 ## Core positioning
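As a quick sanity check after either install path above, the installed package version can be queried. This is a minimal, illustrative snippet, not part of the commit:

```python
# Minimal post-install check (illustrative; not part of this commit).
from importlib.metadata import version  # available on Python 3.8+

print(version("torch-rechub"))  # e.g. 0.0.2 after this change
```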
13 changes: 6 additions & 7 deletions setup.py
@@ -1,19 +1,19 @@
 from distutils.core import setup
 from setuptools import find_packages
 
-with open("README.md", "r") as f:
+with open("README.md", "r", encoding='utf-8') as f:
     long_description = f.read()
 
 setup(
     name='torch-rechub',
-    version='0.0.1',
+    version='0.0.2',
     description='A Lighting Pytorch Framework for Recommendation System, Easy-to-use and Easy-to-extend.',
     long_description=long_description,
     long_description_content_type="text/markdown",
-    author='Mincai Lai',
-    author_email='[email protected]',
-    url='https://github.com/morningsky/Torch-RecHub',
-    install_requires=['numpy>=1.21.5', 'torch>=1.7.0', 'pandas>=1.0.5', 'tqdm>=4.64.0', 'scikit_learn>=0.23.2'],
+    author='Datawhale',
+    author_email='[email protected]',
+    url='https://github.com/datawhalechina/torch-rechub',
+    install_requires=['numpy>=1.19.0', 'torch>=1.7.0', 'pandas>=1.0.5', 'tqdm>=4.64.0', 'scikit_learn>=0.23.2', 'annoy>=1.17.0'],
     packages=find_packages(),
     platforms=["all"],
     classifiers=[
@@ -22,7 +22,6 @@
"Intended Audience :: Science/Research",
'Operating System :: OS Independent',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Topic :: Scientific/Engineering',
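The dependency list now also pins `annoy>=1.17.0`. Annoy is an approximate nearest-neighbor search library, commonly used in the matching stage to retrieve candidate items by embedding similarity. A minimal, illustrative sketch of that pattern (not torch-rechub code, and not part of this commit):

```python
import numpy as np
from annoy import AnnoyIndex

dim = 8
index = AnnoyIndex(dim, "angular")                       # angular ~ cosine distance
item_embs = np.random.rand(100, dim).astype("float32")   # stand-in for learned item embeddings
for item_id, emb in enumerate(item_embs):
    index.add_item(item_id, emb)
index.build(10)                                          # 10 trees; more trees = better recall, slower build

user_emb = np.random.rand(dim).astype("float32")         # stand-in for a user/query embedding
print(index.get_nns_by_vector(user_emb, 5))              # ids of the 5 nearest candidate items
```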
30 changes: 22 additions & 8 deletions torch_rechub/basic/layers.py
@@ -26,11 +26,12 @@ def forward(self, x):


 class EmbeddingLayer(nn.Module):
-    """General Embedding Layer. We init each embedding layer by `xavier_normal_`.
+    """General Embedding Layer.
+    We save all the feature embeddings in embed_dict: `{feature_name : embedding table}`.
     Args:
         features (list): the list of `Feature Class`. It means all the features for which we want to create an embedding table.
+        embed_dict (dict): the embedding dict, `{feature_name : embedding table}`.
     Shape:
         - Input:
@@ -78,11 +79,13 @@ def forward(self, x, features, squeeze_dim=False):
                 elif fea.pooling == "concat":
                     pooling_layer = ConcatPooling()
                 else:
-                    raise ValueError("Sequence pooling method supports only pooling in %s, got %s." % (["sum", "mean"], fea.pooling))
+                    raise ValueError("Sequence pooling method supports only pooling in %s, got %s." %
+                                     (["sum", "mean"], fea.pooling))
                 if fea.shared_with == None:
                     sparse_emb.append(pooling_layer(self.embed_dict[fea.name](x[fea.name].long())).unsqueeze(1))
                 else:
-                    sparse_emb.append(pooling_layer(self.embed_dict[fea.shared_with](x[fea.name].long())).unsqueeze(1))  #shared specific sparse feature embedding
+                    sparse_emb.append(pooling_layer(self.embed_dict[fea.shared_with](
+                        x[fea.name].long())).unsqueeze(1))  #shared specific sparse feature embedding
             else:
                 dense_values.append(x[fea.name].float().unsqueeze(1))  #.unsqueeze(1).unsqueeze(1)

@@ -99,14 +102,17 @@ def forward(self, x, features, squeeze_dim=False):
             elif not dense_exists and sparse_exists:
                 return sparse_emb.flatten(start_dim=1)  #squeeze dim to: [batch_size, num_features*embed_dim]
             elif dense_exists and sparse_exists:
-                return torch.cat((sparse_emb.flatten(start_dim=1), dense_values), dim=1)  #concat dense value with sparse embedding
+                return torch.cat((sparse_emb.flatten(start_dim=1), dense_values),
+                                 dim=1)  #concat dense value with sparse embedding
             else:
                 raise ValueError("The input features cannot be empty")
         else:
             if sparse_exists:
                 return sparse_emb  #[batch_size, num_features, embed_dim]
             else:
-                raise ValueError("If keep the original shape:[batch_size, num_features, embed_dim], expected %s in feature list, got %s" % ("SparseFeatures", features))
+                raise ValueError(
+                    "If keep the original shape:[batch_size, num_features, embed_dim], expected %s in feature list, got %s" %
+                    ("SparseFeatures", features))


 class LR(nn.Module):
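To make the embed_dict pattern above concrete, here is a small self-contained sketch of the same idea in plain PyTorch: one embedding table per feature name, sequence features pooled to a single vector, and `shared_with` features reusing another feature's table. Class and feature names are hypothetical; this is not the torch-rechub implementation.

```python
import torch
import torch.nn as nn

class TinyEmbeddingDict(nn.Module):
    """Illustrative-only version of the embed_dict idea (not torch-rechub code)."""

    def __init__(self, vocab_sizes, embed_dim, shared_with=None):
        super().__init__()
        self.shared_with = shared_with or {}
        # One table per feature name; features listed in shared_with reuse another table.
        self.embed_dict = nn.ModuleDict({
            name: nn.Embedding(size, embed_dim)
            for name, size in vocab_sizes.items() if name not in self.shared_with
        })

    def forward(self, x):
        embs = []
        for name, values in x.items():
            table = self.embed_dict[self.shared_with.get(name, name)]
            emb = table(values.long())   # [B, D] for id features, [B, L, D] for sequences
            if emb.dim() == 3:           # sequence feature -> mean pooling over the sequence
                emb = emb.mean(dim=1)
            embs.append(emb.unsqueeze(1))
        return torch.cat(embs, dim=1)    # [B, num_features, D]

x = {"item_id": torch.tensor([1, 2]),
     "hist_item_ids": torch.tensor([[1, 3, 4], [2, 0, 0]])}
layer = TinyEmbeddingDict({"item_id": 100, "hist_item_ids": 100}, embed_dim=8,
                          shared_with={"hist_item_ids": "item_id"})
print(layer(x).shape)  # torch.Size([2, 2, 8])
```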
@@ -406,9 +412,17 @@ def forward(self, item_eb, mask):
             item_eb_hat_iter = item_eb_hat
 
         if self.bilinear_type > 0:
-            capsule_weight = torch.zeros(item_eb_hat.shape[0], self.interest_num, self.seq_len, device=item_eb.device, requires_grad=False)
+            capsule_weight = torch.zeros(item_eb_hat.shape[0],
+                                         self.interest_num,
+                                         self.seq_len,
+                                         device=item_eb.device,
+                                         requires_grad=False)
         else:
-            capsule_weight = torch.randn(item_eb_hat.shape[0], self.interest_num, self.seq_len, device=item_eb.device, requires_grad=False)
+            capsule_weight = torch.randn(item_eb_hat.shape[0],
+                                         self.interest_num,
+                                         self.seq_len,
+                                         device=item_eb.device,
+                                         requires_grad=False)
 
         for i in range(self.routing_times):  # run dynamic routing 3 times
             atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
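For context on the capsule_weight tensor initialized above: it holds the routing logits used by the multi-interest capsule routing loop. Below is a simplified sketch of one routing iteration; masking details and the softmax axis vary between implementations, and the function names here are hypothetical, not the library's API.

```python
import torch

def squash(z, eps=1e-9):
    # Capsule "squash" non-linearity: keeps the direction, maps the norm into [0, 1).
    sq_norm = (z * z).sum(dim=-1, keepdim=True)
    return (sq_norm / (1.0 + sq_norm)) * z / torch.sqrt(sq_norm + eps)

def routing_step(item_eb_hat, capsule_weight, mask):
    # item_eb_hat:    [B, interest_num, seq_len, D]  projected behavior capsules
    # capsule_weight: [B, interest_num, seq_len]     routing logits
    # mask:           [B, seq_len]                   1 for real items, 0 for padding
    atten_mask = mask.unsqueeze(1).expand_as(capsule_weight)
    logits = capsule_weight.masked_fill(atten_mask == 0, -1e9)       # ignore padded positions
    c = torch.softmax(logits, dim=-1)                                # routing coefficients over the sequence
    interests = squash((c.unsqueeze(-1) * item_eb_hat).sum(dim=2))   # [B, interest_num, D]
    # Agreement update: behaviors aligned with an interest are routed to it more strongly next round.
    new_weight = capsule_weight + (item_eb_hat * interests.unsqueeze(2)).sum(dim=-1)
    return interests, new_weight

# Shapes only; random tensors stand in for real embeddings.
B, K, L, D = 2, 4, 6, 8
interests, w = routing_step(torch.randn(B, K, L, D), torch.randn(B, K, L), torch.ones(B, L))
print(interests.shape, w.shape)  # torch.Size([2, 4, 8]) torch.Size([2, 4, 6])
```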
