Demo
Rank模型和多任务模型对外暴露的接口十分相似,下面分别给出Rank模型和多任务模型的demo。
1.Rank Demo
# Import dependencies
import torch
from rec_pangu.dataset import get_dataloader
from rec_pangu.models.ranking import WDL, DeepFM, NFM, FiBiNet, AFM, AFN, AOANet, AutoInt, CCPM, LR, FM, xDeepFM
from rec_pangu.trainer import RankTrainer
import pandas as pd
if __name__ == '__main__':
    # Rank demo: load the sample CSV, train a single-task ranking model,
    # save its weights and report metrics on the test loader.
    df = pd.read_csv('sample_data/ranking_sample_data.csv')
    print(df.head())

    # Declare the data schema: sparse columns, dense columns and the label column.
    schema = {
        "sparse_cols": ['user_id', 'item_id', 'item_type', 'dayofweek', 'is_workday', 'city', 'county',
                        'town', 'village', 'lbs_city', 'lbs_district', 'hardware_platform', 'hardware_ischarging',
                        'os_type', 'network_type', 'position'],
        "dense_cols": ['item_expo_1d', 'item_expo_7d', 'item_expo_14d', 'item_expo_30d', 'item_clk_1d',
                       'item_clk_7d', 'item_clk_14d', 'item_clk_30d', 'use_duration'],
        "label_col": 'click',
    }

    # Prepare the data. Only 100 rows were selected for this sample, so the
    # same frame is reused for train/valid/test instead of splitting it.
    train_df = df
    valid_df = df
    test_df = df

    # Declare the device to use.
    device = torch.device('cpu')

    # Build the dataloaders (get_dataloader also returns enc_dict, which the model needs).
    train_loader, valid_loader, test_loader, enc_dict = get_dataloader(train_df, valid_df, test_df, schema)

    # Declare the model. Ranking models currently supported:
    # WDL, DeepFM, NFM, FiBiNet, AFM, AFN, AOANet, AutoInt, CCPM, LR, FM, xDeepFM
    model = xDeepFM(enc_dict=enc_dict)

    # Declare the Trainer (single task).
    trainer = RankTrainer(num_task=1)

    # Train the model.
    trainer.fit(model, train_loader, valid_loader, epoch=5, lr=1e-3, device=device)

    # Save the model weights.
    trainer.save_model(model, './model_ckpt')

    # Evaluate the model on the test loader.
    test_metric = trainer.evaluate_model(model, test_loader, device=device)
    print(f'Test metric:{test_metric}')
2.多任务模型Demo
import torch
from rec_pangu.dataset import get_dataloader
from rec_pangu.models.multi_task import AITM, ShareBottom, ESSM, MMOE, OMOE, MLMMOE
from rec_pangu.trainer import RankTrainer
import pandas as pd
if __name__ == '__main__':
    # Multi-task demo: load the sample CSV, train a two-task model,
    # save its weights and report metrics on the test loader.
    df = pd.read_csv('sample_data/multi_task_sample_data.csv')
    print(df.head())

    # Declare the data schema: sparse columns, dense columns and the label columns
    # (one label column per task).
    schema = {
        "sparse_cols": ['user_id', 'item_id', 'item_type', 'dayofweek', 'is_workday', 'city', 'county',
                        'town', 'village', 'lbs_city', 'lbs_district', 'hardware_platform', 'hardware_ischarging',
                        'os_type', 'network_type', 'position'],
        "dense_cols": ['item_expo_1d', 'item_expo_7d', 'item_expo_14d', 'item_expo_30d', 'item_clk_1d',
                       'item_clk_7d', 'item_clk_14d', 'item_clk_30d', 'use_duration'],
        "label_col": ['click', 'scroll'],
    }

    # Prepare the data. Only 100 rows were selected for this sample, so the
    # same frame is reused for train/valid/test instead of splitting it.
    train_df = df
    valid_df = df
    test_df = df

    # Declare the device to use.
    device = torch.device('cpu')

    # Build the dataloaders (get_dataloader also returns enc_dict, which the model needs).
    train_loader, valid_loader, test_loader, enc_dict = get_dataloader(train_df, valid_df, test_df, schema)

    # Declare the model. Multi-task models currently supported:
    # AITM, ShareBottom, ESSM, MMOE, OMOE, MLMMOE
    model = AITM(enc_dict=enc_dict)

    # Declare the Trainer; num_task matches the two label columns above.
    trainer = RankTrainer(num_task=2)

    # Train the model.
    trainer.fit(model, train_loader, valid_loader, epoch=5, lr=1e-3, device=device)

    # Save the model weights.
    trainer.save_model(model, './model_ckpt')

    # Evaluate the model on the test loader.
    test_metric = trainer.evaluate_model(model, test_loader, device=device)
    print(f'Test metric:{test_metric}')