Customize Network Decision Nodes
Sometimes users want to customize the network used by certain network decision nodes in the decision flow, and the REVIVE SDK supports this. Taking Pendulum as an example, suppose we want the actions node to use a custom network. The steps are as follows:
1. Configure custom_nodes in the .yaml file:
metadata:
  graph:
    actions:
    - states
    next_states:
    - states
    - actions

  columns:
  - obs_states_0:
      dim: states
      type: continuous
  - obs_states_1:
      dim: states
      type: continuous
  - obs_states_2:
      dim: states
      type: continuous
  - action:
      dim: actions
      type: continuous

  custom_nodes:
    actions: 'custom_node_file.ActionNode'
In the .yaml file, each custom network node is configured as <node_name>: <filename>.<node_class_name>.
The configuration above assigns the custom network node class ActionNode to the actions node.
ActionNode must therefore be defined in the file custom_node_file.py, placed in the same directory as the .yaml file.
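The <filename>.<node_class_name> string names a Python file and a class inside it. The sketch below shows one way such a string can be resolved to a class using Python's standard importlib machinery. The helper load_custom_node_class is hypothetical and only illustrates the naming convention; it is not REVIVE's own loader, which may work differently.

```python
import importlib.util
import os
import tempfile


def load_custom_node_class(spec: str, config_dir: str):
    """Resolve '<filename>.<node_class_name>' to a class object.

    Hypothetical helper for illustration only; REVIVE's own loader may differ.
    """
    # Split on the last dot: 'custom_node_file.ActionNode' -> ('custom_node_file', 'ActionNode')
    module_name, class_name = spec.rsplit('.', 1)
    # Assumption: the module file sits next to the .yaml file in config_dir.
    path = os.path.join(config_dir, module_name + '.py')
    module_spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)  # runs the file's top-level code
    return getattr(module, class_name)      # AttributeError if the class is missing


if __name__ == '__main__':
    # Usage: write a stand-in custom_node_file.py into a temporary directory and load it.
    with tempfile.TemporaryDirectory() as d:
        with open(os.path.join(d, 'custom_node_file.py'), 'w') as f:
            f.write('class ActionNode:\n    custom_node = True\n')
        cls = load_custom_node_class('custom_node_file.ActionNode', d)
        print(cls.__name__, cls.custom_node)  # ActionNode True
```

Splitting on the last dot keeps the convention unambiguous: everything before it is the file name, and the final segment is the class name.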
The data directory then looks like this:
data/
|-- Env-GAIL-pendulum_custom_node.yaml
|-- config.json
|-- custom_node_file.py
|-- expert_data.npz
`-- pendulum-reward.py
2. Edit the custom network node file:
Next, write custom_node_file.py as follows:
import torch
from torch import nn
from revive.computation.modules import DistributionWrapper, ReviveDistribution
from revive.computation.graph import NetworkDecisionNode


class Net(torch.nn.Module):
    def __init__(self,
                 in_features : dict,
                 out_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 dist_config : list):
        """
        in_features: dict
            - Keys are input node names; values are the dimensions of the corresponding inputs.
            - In general: in_features = {'obs_node_1': obs_node_1_dim, 'obs_node_2': obs_node_2_dim, ...}
            - E.g. in the Pendulum example: in_features = {'states': 3}
            - This gives you access to the dimension of each input.
        """
        super().__init__()

        # ====================== edit network here ======================
        in_features_dims = sum(in_features.values())
        net = []
        for i in range(hidden_layers):
            net.append(nn.Linear(in_features_dims if i == 0 else hidden_features, hidden_features))
            net.append(nn.LeakyReLU(negative_slope=0.1, inplace=True))
        net.append(nn.Linear(hidden_features, out_features))
        net.append(nn.Identity())
        self.net = nn.Sequential(*net)
        # ================================================================

        self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)  # Do NOT modify it!

    def forward(self, state : dict, adapt_std : torch.Tensor = None, **kwargs) -> ReviveDistribution:
        # ========== edit forward method here ==========
        x = state['states']
        output = self.net(x)
        # ==============================================

        dist = self.dist_wrapper(output, adapt_std)  # Do NOT modify it!
        return dist

    # Used to reset stateful network variables, e.g. RNN hidden states.
    def reset(self):
        pass


class ActionNode(NetworkDecisionNode):
    custom_node = True  # Do NOT modify it!

    def initialize_network(self,
                           inputs: dict,
                           output_dim: int,
                           dist_config: list,
                           *args, **kwargs):
        # self.network is a required attribute
        self.network = Net(inputs, output_dim,
                           hidden_features=256, hidden_layers=2,
                           dist_config=dist_config)
Keep the following in mind when editing a custom network node file. It usually defines two classes: a network class and a node class.
As shown above, the network class Net implements the network structure together with the forward and reset methods.
The node class ActionNode defines the network the node uses (in initialize_network) and any other methods needed to maintain the node.
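The reset method of Net above does nothing because its MLP keeps no state between calls. For a recurrent network, reset is where the hidden state would be cleared. The following is a minimal sketch assuming a plain nn.GRUCell backbone; the class name RecurrentNet is hypothetical, and the REVIVE dist_wrapper step is deliberately omitted so the snippet runs without REVIVE installed.

```python
import torch
from torch import nn


class RecurrentNet(nn.Module):
    """Hypothetical recurrent backbone that carries a hidden state across steps.

    In a real REVIVE network class, the output of forward would still be
    passed through self.dist_wrapper(output, adapt_std).
    """

    def __init__(self, in_dim: int, out_dim: int, hidden: int = 32):
        super().__init__()
        self.cell = nn.GRUCell(in_dim, hidden)
        self.head = nn.Linear(hidden, out_dim)
        self.h = None  # hidden state, created lazily on the first step

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.h is None:
            # Start each trajectory from a zero hidden state matching x's batch, dtype and device.
            self.h = x.new_zeros((x.shape[0], self.cell.hidden_size))
        self.h = self.cell(x, self.h)  # update the hidden state with the current input
        return self.head(self.h)

    def reset(self):
        # Drop the hidden state so the next trajectory starts from zeros again.
        self.h = None
```

Without reset, the hidden state from the end of one rollout would leak into the first step of the next, which is the situation the reset method exists to prevent.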
Note the following points:
- In the network class, in_features is a Python dictionary: each key is the name of an input node and each value is that node's dimension. In Pendulum, for example, in_features = {'states': 3}. This lets you process different input nodes differently.
- In the network class, dist_wrapper wraps the network output into a distribution, which is part of REVIVE's internal structure. Do not change it, and do not change the dist-related lines in the forward method either.
- The reset method resets network variables when necessary, such as the hidden state of an RNN.
- In the node class, you must set the additional attribute custom_node = True. Only then is inputs (in initialize_network) a Python dictionary.
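Because in_features maps each input node name to its dimension, a network can give every input node its own sub-network instead of concatenating everything up front. The sketch below assumes one small encoder per input node followed by a shared linear head. The class MultiInputNet and the extra 'goal' node used in the usage line are hypothetical (Pendulum itself only has 'states'), and the REVIVE dist_wrapper step is omitted so the snippet runs standalone.

```python
import torch
from torch import nn


class MultiInputNet(nn.Module):
    """Hypothetical backbone: one encoder per input node, then a shared head.

    Returns the raw output tensor; in a real REVIVE network class this tensor
    would be passed to self.dist_wrapper(output, adapt_std).
    """

    def __init__(self, in_features: dict, out_features: int, hidden_features: int = 64):
        super().__init__()
        # nn.ModuleDict registers one encoder per input node name, sized by that node's dimension.
        self.encoders = nn.ModuleDict({
            name: nn.Sequential(nn.Linear(dim, hidden_features), nn.LeakyReLU(0.1))
            for name, dim in in_features.items()
        })
        self.head = nn.Linear(hidden_features * len(in_features), out_features)

    def forward(self, state: dict) -> torch.Tensor:
        # Encode each input with its own encoder, in a fixed (sorted) order, then concatenate.
        encoded = [self.encoders[name](state[name]) for name in sorted(self.encoders.keys())]
        return self.head(torch.cat(encoded, dim=-1))


if __name__ == '__main__':
    net = MultiInputNet({'states': 3}, out_features=1)
    print(net({'states': torch.zeros(5, 3)}).shape)  # torch.Size([5, 1])
```

Iterating over sorted node names keeps the concatenation order independent of how the dictionary happened to be built.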
Note: For more information, refer to revive/revive/computation/graph.py and revive/revive/computation/modules.py.