自定义网络节点 ============================== 有时候,用户希望自定义决策流图中某些网络节点的网络。REVIVE SDK 提供了这样的功能。 我们以 :doc:`摆杆<../task_examples/Use_revive_to_play_pendulum_game_cn>` 为例,我们希望摆杆的 **动作节点** 使用自定义网络节点。 下面介绍如何使用这个功能: 1. 在 ``.yaml`` 文件中配置 ``custom_nodes``: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code:: yaml metadata: graph: actions: - states next_states: - states - actions columns: - obs_states_0: dim: states type: continuous - obs_states_1: dim: states type: continuous - obs_states_2: dim: states type: continuous - action: dim: actions type: continuous custom_nodes: actions: 'custom_node_file.ActionNode' 自定义网络节点的配置通常是 ``: .`` 。 上述的 ``.yaml`` 文件为 ``actions`` 节点使用自定义网络节点 ``ActionNode``, 自定义网络节点应该定义在同级目录的 ``custom_node_file.py`` 文件中。 此时 ``data`` 目录结构如下:: data/ |-- Env-GAIL-pendulum_custom_node.yaml |-- config.json |-- custom_node_file.py |-- expert_data.npz `-- pendulum-reward.py 2. 编辑自定义网络节点文件: ~~~~~~~~~~~~~~~~~~~~~~~ 我们还需要编辑 ``custom_node_file.py`` 文件,如下所示: .. code:: python import torch from torch import nn from revive.computation.modules import DistributionWrapper, ReviveDistribution from revive.computation.graph import NetworkDecisionNode class Net(torch.nn.Module): def __init__(self, in_features : dict, out_features : int, hidden_features : int, hidden_layers : int, dist_config : list): super().__init__() """ in_features: Dict. - key for input_names, value for dimension of corresponding input. - In general: in_features = {'obs_node_1': obs_node_1_dim, 'obs_node_2': obs_node_2_dim, ...} - e.g. In Pendulum example: in_features = {'states': 3} - It enables you to access to dimensions of each input. """ # ====================== edit network in here ====================== in_features_dims = sum(in_features.values()) net = [] for i in range(hidden_layers): net.append(nn.Linear(in_features_dims if i == 0 else hidden_features, hidden_features)) net.append(nn.LeakyReLU(negative_slope=0.1, inplace=True)) net.append(nn.Linear(hidden_features, out_features)) net.append(nn.Identity()) self.net = nn.Sequential(*net) # ================================================================= self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config) # Do NOT modify it ! def forward(self, state : dict, adapt_std : torch.Tensor = None, **kwargs) -> ReviveDistribution: #========== edit forward method in here ========= x = state['states'] output = self.net(x) # =============================================== dist = self.dist_wrapper(output, adapt_std) # Do NOT modify it ! return dist # used to reset necessary network variables: e.g. RNN hidden states def reset(self): pass class ActionNode(NetworkDecisionNode): custom_node = True # Do NOT modify it ! def initialize_network(self, inputs: dict, output_dim: int, dist_config: list, *args, **kwargs): # self.network is a essential property self.network = Net(inputs, output_dim, hidden_features=256, hidden_layers=2, dist_config=dist_config) 在编辑自定义网络节点文件时,需要尤为注意。通常我们会定义 2 个类: ``网络类``, ``节点类``。 如上所示, ``网络类 Net`` 主要实现 ``网络结构`` 与 ``forward``, ``reset`` 方法。 ``节点类 ActionNode`` 主要实现 ``网络定义`` 以及其他维护节点的方法。 需要注意的是: 1. ``网络结构`` 中 ``in_feature`` 是 ``python 字典类型``。 ``key`` 代表输入节点的名字, ``value`` 代表输入节点的维度。比如 :doc:`摆杆<../task_examples/Use_revive_to_play_pendulum_game_cn>` 例子中: ``in_features = {'states': 3}``。这样我们可以针对不同输入节点,做不同处理。 2. ``网络结构`` 中 ``dist_wrapper`` 是为了把输出包装成 ``分布``,这是 REVIVE 的内部结构, 用户不要更改!同样,在 ``forward`` 方法中,用户也不要更改 ``dist`` 相关部分。 3. ``reset`` 方法是为了必要时候重置网络中的某些变量,比如 RNN 的隐变量。 4. ``节点类`` 中用户需要增加额外属性 ``custom_node=True``。只有这样, ``inputs`` 才是 ``python 字典类型``。 .. note:: 更多信息用户可以参考 ``revive/revive/computation/graph.py`` 与 ``revive/revive/computation/modules.py``。