自定义节点的损失函数

REVIVE SDK在训练虚拟环境时,支持为每个节点配置自定义的损失函数以约束学习到的节点模型。下面展示一个 示例通过自定义模块为节点配置自定义损失函数。

首先需要通过自定义模块文件( user_module.py )定义损失函数。

import torch
from typing import Dict

def mae_loss(kwargs) -> torch.Tensor:
    # 获得当前节点名
    node_name = kwargs["node_name"]
    # 节点网络输出的分布
    node_dist = kwargs["node_dist"]
    # 当前节点对应的专家数据
    expert_data = kwargs["expert_data"]
    # 获得决策流图
    graph = kwargs["graph"]

    # get network output data -> node_dist.mode
    # get node expert data -> expert_data[node_name]
    # reverse normalization data -> graph.nodes[node_name].processor.deprocess_torch({node_name:expert_data[node_name]})
    policy_loss = (node_dist.mode - expert_data[node_name]).abs().sum(dim=-1).mean()

    return policy_loss

然后在 .yaml 文件中配置节点要使用的损失函数,配置的函数名 = user_module. + 自定义模块中定义的函数名称

metadata:
   graph:
   act:
   - obs
   next_obs:
   - obs
   - act
   rew:
   - obs
   - act
   - next_obs
   expert_functions:
   rew:
       'node_function' : 'expert_function.reward_node_function'
   nodes:
   act:
       loss_type: 'user_module.mae_loss'

最后在训练时需要使用 -umf 参数指定自定义模块文件,示例如下:

python train.py -df test_data.npz -cf test.yaml -umf user_module.py -rf test_reward.py -vm once -pm once --run_id once