Custom loss functions for nodes

When training the virtual environment, the REVIVE SDK lets you configure a custom loss function for each node to constrain the learned node model. The example below configures a custom loss function for a node through a user-defined module.

First, define the loss function in the user-defined module file (user_module.py).

import torch
from typing import Dict

def mae_loss(kwargs: Dict) -> torch.Tensor:
    # Name of the current node
    node_name = kwargs["node_name"]
    # Distribution output by the node's network
    node_dist = kwargs["node_dist"]
    # Expert data for the current node
    expert_data = kwargs["expert_data"]
    # Decision flow graph
    graph = kwargs["graph"]

    # Network output: node_dist.mode
    # Expert data for this node: expert_data[node_name]
    # To de-normalize data: graph.nodes[node_name].processor.deprocess_torch({node_name: expert_data[node_name]})
    # Mean absolute error: sum over the last (feature) dimension, then average over the rest
    policy_loss = (node_dist.mode - expert_data[node_name]).abs().sum(dim=-1).mean()

    return policy_loss
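
For reference, the value returned by mae_loss can be written out as a formula. The notation here is introduced only for illustration and is not part of the SDK: B is the batch size, D is the size of the node's last dimension, \hat{y} stands for node_dist.mode, and y stands for expert_data[node_name]. The formula assumes both tensors have shape (B, D); with extra leading dimensions, the mean runs over all of them.

```latex
\mathcal{L}_{\mathrm{MAE}}
  = \frac{1}{B} \sum_{b=1}^{B} \sum_{d=1}^{D}
    \left| \hat{y}_{b,d} - y_{b,d} \right|
```

As a small worked case, take B = 2 and D = 2 with \hat{y} = [[1, 2], [3, 4]] and y = [[0, 2], [5, 4]]. The per-sample sums of absolute errors are 1 and 2, so the loss is (1 + 2) / 2 = 1.5. The loss is computed on the tensors as they are passed in; the comment about deprocess_torch suggests these are in normalized scale.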

Then configure the loss function for the node in the .yaml file. The configured name is user_module. followed by the name of the function defined in the custom module, for example user_module.mae_loss.

metadata:
   graph:
     act:
     - obs
     next_obs:
     - obs
     - act
     rew:
     - obs
     - act
     - next_obs
   expert_functions:
     rew:
       'node_function' : 'expert_function.reward_node_function'
   nodes:
     act:
       loss_type: 'user_module.mae_loss'
Finally, use the -umf option to specify the custom module file when training. For example:

python train.py -df test_data.npz -cf test.yaml -umf user_module.py -rf test_reward.py -vm once -pm once --run_id once