自定义节点的损失函数 ===================== REVIVE SDK在训练虚拟环境时,支持为每个节点配置自定义的损失函数以约束学习到的节点模型。下面展示一个 示例通过自定义模块为节点配置自定义损失函数。 首先需要通过自定义模块文件( ``user_module.py`` )定义损失函数。 .. code:: python import torch from typing import Dict def mae_loss(kwargs) -> torch.Tensor: # 获得当前节点名 node_name = kwargs["node_name"] # 节点网络输出的分布 node_dist = kwargs["node_dist"] # 当前节点对应的专家数据 expert_data = kwargs["expert_data"] # 获得决策流图 graph = kwargs["graph"] # get network output data -> node_dist.mode # get node expert data -> expert_data[node_name] # reverse normalization data -> graph.nodes[node_name].processor.deprocess_torch({node_name:expert_data[node_name]}) policy_loss = (node_dist.mode - expert_data[node_name]).abs().sum(dim=-1).mean() return policy_loss 然后在 ``.yaml`` 文件中配置节点要使用的损失函数,配置的函数名 = ``user_module.`` + ``自定义模块中定义的函数名称``。 .. code:: yaml metadata: graph: act: - obs next_obs: - obs - act rew: - obs - act - next_obs expert_functions: rew: 'node_function' : 'expert_function.reward_node_function' nodes: act: loss_type: 'user_module.mae_loss' 最后在训练时需要使用 ``-umf`` 参数指定自定义模块文件,示例如下: :: python train.py -df test_data.npz -cf test.yaml -umf user_module.py -rf test_reward.py -vm once -pm once --run_id once