Custom loss function for a node
When training the virtual environment, the REVIVE SDK lets you configure a custom loss function for each node to constrain the node model being learned. The following example shows how to attach a custom loss function to a node through a user-defined module.
First, define the loss function in a user-defined module file (user_module.py). The function receives a single dictionary argument and returns a scalar loss tensor:
```python
import torch
from typing import Dict


def mae_loss(kwargs: Dict) -> torch.Tensor:
    # Name of the current node
    node_name = kwargs["node_name"]
    # Distribution output by the node's network
    node_dist = kwargs["node_dist"]
    # Expert data for the current batch
    expert_data = kwargs["expert_data"]
    # The decision flow graph
    graph = kwargs["graph"]
    # Network output data        -> node_dist.mode
    # Expert data for this node  -> expert_data[node_name]
    # De-normalize data          -> graph.nodes[node_name].processor.deprocess_torch({node_name: expert_data[node_name]})

    # Mean absolute error: sum |output - expert| over the feature dimension, then average over the batch
    policy_loss = (node_dist.mode - expert_data[node_name]).abs().sum(dim=-1).mean()
    return policy_loss
```
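The last line of mae_loss packs three reductions together: take absolute differences element-wise, sum them over the feature dimension (dim=-1), then average over the batch. The sketch below mirrors that reduction in plain Python so the arithmetic can be checked by hand; mae_reference is a hypothetical helper for illustration only, with lists of lists standing in for (batch, features) tensors. It is not part of the REVIVE SDK.

```python
from typing import List


def mae_reference(pred: List[List[float]], target: List[List[float]]) -> float:
    """Plain-Python mirror of (pred - target).abs().sum(dim=-1).mean()."""
    # Per-sample loss: sum of absolute errors across the feature dimension
    per_sample = [
        sum(abs(p - t) for p, t in zip(p_row, t_row))
        for p_row, t_row in zip(pred, target)
    ]
    # Batch loss: mean of the per-sample losses
    return sum(per_sample) / len(per_sample)


pred = [[0.5, 1.0], [2.0, -1.0]]
target = [[0.0, 1.5], [1.0, 0.0]]
print(mae_reference(pred, target))  # → 1.5
```

Summing over features (rather than averaging) means the loss scale grows with the node's dimensionality, which matters if you compare or combine losses across nodes of different sizes.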
Next, set the loss function for the node in the .yaml file. The configured name is user_module. followed by the name of the function defined in the custom module (here, user_module.mae_loss):
```yaml
metadata:
  graph:
    act:
    - obs
    next_obs:
    - obs
    - act
    rew:
    - obs
    - act
    - next_obs
  expert_functions:
    rew:
      'node_function': 'expert_function.reward_node_function'
  nodes:
    act:
      loss_type: 'user_module.mae_loss'
```
Finally, pass the custom module file with the -umf argument when training, for example:

```shell
python train.py -df test_data.npz -cf test.yaml -umf user_module.py -rf test_reward.py -vm once -pm once --run_id once
```
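Because a typo in loss_type or in the function name only surfaces once training starts, it can help to check the pairing beforehand. The sketch below is a hypothetical pre-flight check, not the REVIVE SDK's own loader (which may resolve names differently). It assumes loss_type has the form module.function, that the module part matches the stem of the file passed to -umf, and it uses only the standard importlib machinery to load the file and look up the function.

```python
import importlib.util
from pathlib import Path
from typing import Callable


def resolve_loss(loss_type: str, module_file: str) -> Callable:
    """Hypothetical check: map e.g. 'user_module.mae_loss' to the function
    mae_loss defined in the file given by module_file."""
    # Split 'user_module.mae_loss' into module name and function name
    module_name, func_name = loss_type.rsplit(".", 1)
    path = Path(module_file)
    if path.stem != module_name:
        raise ValueError(
            f"loss_type refers to module {module_name!r}, but the file is {path.name!r}"
        )
    # Load the module directly from its file path
    spec = importlib.util.spec_from_file_location(module_name, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    # Look up the loss function by name and make sure it is callable
    func = getattr(module, func_name, None)
    if not callable(func):
        raise AttributeError(f"{path.name} defines no function {func_name!r}")
    return func
```

For the configuration above, resolve_loss("user_module.mae_loss", "user_module.py") should return mae_loss (loading user_module.py also imports torch, so torch must be installed).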