自定义网络节点

有时候,用户希望自定义决策流图中某些网络节点的网络。REVIVE SDK 提供了这样的功能。 我们以 摆杆 为例,我们希望摆杆的 动作节点 使用自定义网络节点。 下面介绍如何使用这个功能:

1. 在 .yaml 文件中配置 custom_nodes

metadata:

  graph:
    actions:
    - states
    next_states:
    - states
    - actions

  columns:
  - obs_states_0:
      dim: states
      type: continuous
  - obs_states_1:
      dim: states
      type: continuous
  - obs_states_2:
      dim: states
      type: continuous
  - action:
      dim: actions
      type: continuous

  custom_nodes:
    actions: 'custom_node_file.ActionNode'

自定义网络节点的配置通常是 <node_name>: <filename>.<node_class_name> 。 上述的 .yaml 文件为 actions 节点使用自定义网络节点 ActionNode, 自定义网络节点应该定义在同级目录的 custom_node_file.py 文件中。 此时 data 目录结构如下:

data/
|-- Env-GAIL-pendulum_custom_node.yaml
|-- config.json
|-- custom_node_file.py
|-- expert_data.npz
`-- pendulum-reward.py

2. 编辑自定义网络节点文件:

我们还需要编辑 custom_node_file.py 文件,如下所示:

import torch
from torch import nn
from revive.computation.modules import DistributionWrapper, ReviveDistribution
from revive.computation.graph import NetworkDecisionNode


class Net(torch.nn.Module):
    def __init__(self,
                in_features : dict,
                out_features : int,
                hidden_features : int,
                hidden_layers : int,
                dist_config : list):
        super().__init__()
        """
        in_features: Dict.
            - key for input_names, value for dimension of corresponding input.
            - In general: in_features = {'obs_node_1': obs_node_1_dim, 'obs_node_2': obs_node_2_dim, ...}
            - e.g. In Pendulum example: in_features = {'states': 3}
            - It enables you to access to dimensions of each input.
        """

        # ====================== edit network in here ======================
        in_features_dims = sum(in_features.values())
        net = []
        for i in range(hidden_layers):
            net.append(nn.Linear(in_features_dims if i == 0 else hidden_features, hidden_features))
            net.append(nn.LeakyReLU(negative_slope=0.1, inplace=True))
        net.append(nn.Linear(hidden_features, out_features))
        net.append(nn.Identity())
        self.net = nn.Sequential(*net)
        # =================================================================

        self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)  # Do NOT modify it !

    def forward(self, state : dict, adapt_std : torch.Tensor = None, **kwargs) -> ReviveDistribution:
        #========== edit forward method in here =========
        x = state['states']
        output = self.net(x)
        # ===============================================

        dist = self.dist_wrapper(output, adapt_std)  # Do NOT modify it !
        return dist

    # used to reset necessary network variables: e.g. RNN hidden states
    def reset(self):
        pass


class ActionNode(NetworkDecisionNode):
    custom_node = True  # Do NOT modify it !

    def initialize_network(self,
                          inputs: dict,
                          output_dim: int,
                          dist_config: list,
                          *args, **kwargs):

        # self.network is a essential property
        self.network = Net(inputs, output_dim,
                            hidden_features=256, hidden_layers=2,
                            dist_config=dist_config)

在编辑自定义网络节点文件时,需要尤为注意。通常我们会定义 2 个类: 网络类节点类。 如上所示, 网络类 Net 主要实现 网络结构forward, reset 方法。 节点类 ActionNode 主要实现 网络定义 以及其他维护节点的方法。

需要注意的是:

  1. 网络结构in_featurepython 字典类型key 代表输入节点的名字, value 代表输入节点的维度。比如 摆杆 例子中: in_features = {'states': 3}。这样我们可以针对不同输入节点,做不同处理。

  2. 网络结构dist_wrapper 是为了把输出包装成 分布,这是 REVIVE 的内部结构, 用户不要更改!同样,在 forward 方法中,用户也不要更改 dist 相关部分。

  3. reset 方法是为了必要时候重置网络中的某些变量,比如 RNN 的隐变量。

  4. 节点类 中用户需要增加额外属性 custom_node=True。只有这样, inputs 才是 python 字典类型

Note

更多信息用户可以参考 revive/revive/computation/graph.pyrevive/revive/computation/modules.py