自定义节点的损失函数¶
REVIVE SDK在训练虚拟环境时,支持为每个节点配置自定义的损失函数以约束学习到的节点模型。下面展示一个 示例通过自定义模块为节点配置自定义损失函数。
首先需要通过自定义模块文件( user_module.py
)定义损失函数。
import torch
from typing import Dict
def mae_loss(kwargs) -> torch.Tensor:
# 获得当前节点名
node_name = kwargs["node_name"]
# 节点网络输出的分布
node_dist = kwargs["node_dist"]
# 当前节点对应的专家数据
expert_data = kwargs["expert_data"]
# 获得决策流图
graph = kwargs["graph"]
# get network output data -> node_dist.mode
# get node expert data -> expert_data[node_name]
# reverse normalization data -> graph.nodes[node_name].processor.deprocess_torch({node_name:expert_data[node_name]})
policy_loss = (node_dist.mode - expert_data[node_name]).abs().sum(dim=-1).mean()
return policy_loss
然后在 .yaml
文件中配置节点要使用的损失函数,配置的函数名 = user_module.
+ 自定义模块中定义的函数名称
。
metadata:
graph:
act:
- obs
next_obs:
- obs
- act
rew:
- obs
- act
- next_obs
expert_functions:
rew:
'node_function' : 'expert_function.reward_node_function'
nodes:
act:
loss_type: 'user_module.mae_loss'
最后在训练时需要使用 -umf
参数指定自定义模块文件,示例如下:
python train.py -df test_data.npz -cf test.yaml -umf user_module.py -rf test_reward.py -vm once -pm once --run_id once