自定义网络节点¶
有时候,用户希望自定义决策流图中某些网络节点的网络。REVIVE SDK 提供了这样的功能。 我们以 摆杆 为例,我们希望摆杆的 动作节点 使用自定义网络节点。 下面介绍如何使用这个功能:
1. 在 .yaml
文件中配置 custom_nodes
:¶
metadata:
graph:
actions:
- states
next_states:
- states
- actions
columns:
- obs_states_0:
dim: states
type: continuous
- obs_states_1:
dim: states
type: continuous
- obs_states_2:
dim: states
type: continuous
- action:
dim: actions
type: continuous
custom_nodes:
actions: 'custom_node_file.ActionNode'
自定义网络节点的配置通常是 <node_name>: <filename>.<node_class_name>
。
上述的 .yaml
文件为 actions
节点使用自定义网络节点 ActionNode
,
自定义网络节点应该定义在同级目录的 custom_node_file.py
文件中。
此时 data
目录结构如下:
data/
|-- Env-GAIL-pendulum_custom_node.yaml
|-- config.json
|-- custom_node_file.py
|-- expert_data.npz
`-- pendulum-reward.py
2. 编辑自定义网络节点文件:¶
我们还需要编辑 custom_node_file.py
文件,如下所示:
import torch
from torch import nn
from revive.computation.modules import DistributionWrapper, ReviveDistribution
from revive.computation.graph import NetworkDecisionNode
class Net(torch.nn.Module):
def __init__(self,
in_features : dict,
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list):
super().__init__()
"""
in_features: Dict.
- key for input_names, value for dimension of corresponding input.
- In general: in_features = {'obs_node_1': obs_node_1_dim, 'obs_node_2': obs_node_2_dim, ...}
- e.g. In Pendulum example: in_features = {'states': 3}
- It enables you to access to dimensions of each input.
"""
# ====================== edit network in here ======================
in_features_dims = sum(in_features.values())
net = []
for i in range(hidden_layers):
net.append(nn.Linear(in_features_dims if i == 0 else hidden_features, hidden_features))
net.append(nn.LeakyReLU(negative_slope=0.1, inplace=True))
net.append(nn.Linear(hidden_features, out_features))
net.append(nn.Identity())
self.net = nn.Sequential(*net)
# =================================================================
self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config) # Do NOT modify it !
def forward(self, state : dict, adapt_std : torch.Tensor = None, **kwargs) -> ReviveDistribution:
#========== edit forward method in here =========
x = state['states']
output = self.net(x)
# ===============================================
dist = self.dist_wrapper(output, adapt_std) # Do NOT modify it !
return dist
# used to reset necessary network variables: e.g. RNN hidden states
def reset(self):
pass
class ActionNode(NetworkDecisionNode):
custom_node = True # Do NOT modify it !
def initialize_network(self,
inputs: dict,
output_dim: int,
dist_config: list,
*args, **kwargs):
# self.network is a essential property
self.network = Net(inputs, output_dim,
hidden_features=256, hidden_layers=2,
dist_config=dist_config)
在编辑自定义网络节点文件时,需要尤为注意。通常我们会定义 2 个类: 网络类
, 节点类
。
如上所示, 网络类 Net
主要实现 网络结构
与 forward
, reset
方法。
节点类 ActionNode
主要实现 网络定义
以及其他维护节点的方法。
需要注意的是:
网络结构
中in_feature
是python 字典类型
。key
代表输入节点的名字,value
代表输入节点的维度。比如 摆杆 例子中:in_features = {'states': 3}
。这样我们可以针对不同输入节点,做不同处理。网络结构
中dist_wrapper
是为了把输出包装成分布
,这是 REVIVE 的内部结构, 用户不要更改!同样,在forward
方法中,用户也不要更改dist
相关部分。reset
方法是为了必要时候重置网络中的某些变量,比如 RNN 的隐变量。节点类
中用户需要增加额外属性custom_node=True
。只有这样,inputs
才是python 字典类型
。
Note
更多信息用户可以参考 revive/revive/computation/graph.py
与 revive/revive/computation/modules.py
。