Skip to content

Commit

Permalink
update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
morningsky committed Jun 13, 2022
1 parent 21dc3e5 commit 41e1ed2
Showing 1 changed file with 41 additions and 10 deletions.
51 changes: 41 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#稳定版
pip install torch-rechub

#最新版
#最新版(推荐)
1. git clone https://github.com/datawhalechina/torch-rechub.git
2. cd torch-rechub
3. python setup.py install
Expand Down Expand Up @@ -79,23 +79,27 @@ pip install torch-rechub

## 快速使用

### 单任务排序
### 使用案例

- 所有模型使用案例参考 `/examples`

- 202206 Datawhale-RecHub推荐课程 组队学习期间notebook教程参考 `/tutorials`

### 精排(CTR预测)

```python
from torch_rechub.models.ranking import WideDeep, DeepFM, DIN
from torch_rechub.models.ranking import DeepFM
from torch_rechub.trainers import CTRTrainer
from torch_rechub.basic.utils import DataGenerator
from torch_rechub.utils.data import DataGenerator

dg = DataGenerator(x, y)
train_dataloader, val_dataloader, test_dataloader = dg.generate_dataloader()
train_dataloader, val_dataloader, test_dataloader = dg.generate_dataloader(split_ratio=[0.7, 0.1], batch_size=256)

model = DeepFM(deep_features=deep_features, fm_features=fm_features, mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"})

ctr_trainer = CTRTrainer(model)
ctr_trainer.fit(train_dataloader, val_dataloader)
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)


```

### 多任务排序
Expand All @@ -104,9 +108,36 @@ auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)
from torch_rechub.models.multi_task import SharedBottom, ESMM, MMOE, PLE, AITM
from torch_rechub.trainers import MTLTrainer

model = MMOE(features, task_types, n_expert=3, expert_params={"dims": [64,32,16]}, tower_params_list=[{"dims": [8]}, {"dims": [8]}])
task_types = ["classification", "classification"]
model = MMOE(features, task_types, 8, expert_params={"dims": [32,16]}, tower_params_list=[{"dims": [32, 16]}, {"dims": [32, 16]}])

ctr_trainer = MTLTrainer(model)
ctr_trainer.fit(train_dataloader, val_dataloader)
mtl_trainer = MTLTrainer(model)
mtl_trainer.fit(train_dataloader, val_dataloader)
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)
```

### 召回模型

```python
from torch_rechub.models.matching import DSSM
from torch_rechub.trainers import MatchTrainer
from torch_rechub.utils.data import MatchDataGenerator

dg = MatchDataGenerator(x y)
train_dl, test_dl, item_dl = dg.generate_dataloader(test_user, all_item, batch_size=256)

model = DSSM(user_features, item_features, temperature=0.02,
user_params={
"dims": [256, 128, 64],
"activation": 'prelu',
},
item_params={
"dims": [256, 128, 64],
"activation": 'prelu',
})

match_trainer = MatchTrainer(model)
match_trainer.fit(train_dl)

```

0 comments on commit 41e1ed2

Please sign in to comment.