Customize Network Decision Nodes

Sometimes users want to replace the network of certain network decision nodes in the decision flow with one of their own. The REVIVE SDK supports this. Taking Pendulum as an example, suppose we want the actions node to be a custom network node. The steps are as follows:

1. Configure custom_nodes in the .yaml file:

metadata:

  graph:
    actions:
    - states
    next_states:
    - states
    - actions

  columns:
  - obs_states_0:
      dim: states
      type: continuous
  - obs_states_1:
      dim: states
      type: continuous
  - obs_states_2:
      dim: states
      type: continuous
  - action:
      dim: actions
      type: continuous

  custom_nodes:
    actions: 'custom_node_file.ActionNode'

In the .yaml file, a custom network node is usually specified as <node_name>: <filename>.<node_class_name>. The configuration above makes the actions node use the custom network node class ActionNode, which should be defined in custom_node_file.py in the same directory as the .yaml file. The data directory then looks like this:

data/
|-- Env-GAIL-pendulum_custom_node.yaml
|-- config.json
|-- custom_node_file.py
|-- expert_data.npz
`-- pendulum-reward.py
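To illustrate the <node_name>: <filename>.<node_class_name> convention, here is a minimal stdlib-only sketch of how such an entry could be resolved. It assumes the .yaml metadata has already been loaded into a Python dict, and it assumes that a customizable node must appear as a key (an output node) of graph. The helpers parse_custom_nodes and load_node_class are hypothetical; REVIVE's own loader may resolve these strings differently.

```python
import importlib
import sys


def parse_custom_nodes(metadata: dict) -> dict:
    """Hypothetical helper: split each '<filename>.<node_class_name>' entry.

    Returns {node_name: (module_name, class_name)}. Assumes every customized
    node must be a decision node, i.e. a key of metadata['graph'].
    """
    graph = metadata.get("graph", {})
    parsed = {}
    for node_name, spec in metadata.get("custom_nodes", {}).items():
        if node_name not in graph:
            raise ValueError(f"'{node_name}' is not a decision node in the graph")
        # rpartition splits on the last dot: 'custom_node_file.ActionNode'
        # -> ('custom_node_file', '.', 'ActionNode')
        module_name, sep, class_name = spec.rpartition(".")
        if not sep or not module_name or not class_name:
            raise ValueError(f"expected '<filename>.<node_class_name>', got '{spec}'")
        parsed[node_name] = (module_name, class_name)
    return parsed


def load_node_class(module_name: str, class_name: str, search_dir: str = "."):
    """Hypothetical helper: import <module_name> from search_dir and return the class."""
    if search_dir not in sys.path:
        sys.path.insert(0, search_dir)  # make e.g. data/custom_node_file.py importable
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Mirrors the metadata section of the .yaml file above (columns omitted).
metadata = {
    "graph": {"actions": ["states"], "next_states": ["states", "actions"]},
    "custom_nodes": {"actions": "custom_node_file.ActionNode"},
}
print(parse_custom_nodes(metadata))  # {'actions': ('custom_node_file', 'ActionNode')}
```

With the directory layout above, load_node_class('custom_node_file', 'ActionNode', 'data') would import data/custom_node_file.py (which itself requires REVIVE to be installed) and return the ActionNode class.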

2. Edit the custom network node file:

Next, edit custom_node_file.py as follows:

import torch
from torch import nn
from revive.computation.modules import DistributionWrapper, ReviveDistribution
from revive.computation.graph import NetworkDecisionNode


class Net(torch.nn.Module):
    def __init__(self,
                 in_features : dict,
                 out_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 dist_config : list):
        """
        in_features: Dict.
            - Keys are input node names; values are the dimensions of the corresponding inputs.
            - In general: in_features = {'obs_node_1': obs_node_1_dim, 'obs_node_2': obs_node_2_dim, ...}
            - e.g. In the Pendulum example: in_features = {'states': 3}
            - It gives you access to the dimension of each input.
        """
        super().__init__()

        # ====================== edit network in here ======================
        in_features_dims = sum(in_features.values())  # total dimension of all inputs
        net = []
        for i in range(hidden_layers):
            net.append(nn.Linear(in_features_dims if i == 0 else hidden_features, hidden_features))
            net.append(nn.LeakyReLU(negative_slope=0.1, inplace=True))
        # Output layer; takes the raw input directly when hidden_layers == 0
        net.append(nn.Linear(hidden_features if hidden_layers > 0 else in_features_dims, out_features))
        net.append(nn.Identity())
        self.net = nn.Sequential(*net)
        # =================================================================

        self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)  # Do NOT modify it !

    def forward(self, state : dict, adapt_std : torch.Tensor = None, **kwargs) -> ReviveDistribution:
        #========== edit forward method in here =========
        x = state['states']  # state maps input node names to tensors
        output = self.net(x)
        # ===============================================

        dist = self.dist_wrapper(output, adapt_std)  # Do NOT modify it !
        return dist

    # used to reset necessary network variables: e.g. RNN hidden states
    def reset(self):
        pass


class ActionNode(NetworkDecisionNode):
    custom_node = True  # Do NOT modify it !

    def initialize_network(self,
                          inputs: dict,
                          output_dim: int,
                          dist_config: list,
                          *args, **kwargs):

        # self.network is an essential attribute
        self.network = Net(inputs, output_dim,
                            hidden_features=256, hidden_layers=2,
                            dist_config=dist_config)

Take care when editing custom network node files. A custom network node file usually defines two classes: a network class and a node class. In the example above, the network class Net defines the network structure and implements the forward and reset methods. The node class ActionNode builds the network in initialize_network and can hold other methods for maintaining the node.
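Because the layer stack of Net is plain PyTorch, its shapes can be checked outside REVIVE before DistributionWrapper is wired in. The sketch below reproduces only the nn.Sequential part of Net as a hypothetical build_mlp function (omitting the trailing nn.Identity(), which is a no-op). The out_features value used here is arbitrary; in REVIVE the real value is the output_dim passed to initialize_network.

```python
import torch
from torch import nn


def build_mlp(in_features: dict, out_features: int,
              hidden_features: int, hidden_layers: int) -> nn.Sequential:
    """Hypothetical helper with the same layer stack as Net.net."""
    in_dim = sum(in_features.values())  # inputs are concatenated, so dims add up
    layers = []
    for i in range(hidden_layers):
        layers.append(nn.Linear(in_dim if i == 0 else hidden_features, hidden_features))
        layers.append(nn.LeakyReLU(negative_slope=0.1, inplace=True))
    # Output layer; takes the raw input directly when there are no hidden layers
    layers.append(nn.Linear(hidden_features if hidden_layers > 0 else in_dim, out_features))
    return nn.Sequential(*layers)


# Pendulum: a single input node 'states' of dimension 3 (see in_features in Net).
mlp = build_mlp({"states": 3}, out_features=2, hidden_features=256, hidden_layers=2)
batch = {"states": torch.randn(16, 3)}
out = mlp(batch["states"])
print(out.shape)  # torch.Size([16, 2])
```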

Note the following:

  1. in_features in the network class is a Python dictionary: each key is the name of an input node and each value is that node's dimension. For example, in Pendulum, in_features = {'states': 3}. This lets you process different input nodes differently.

  2. dist_wrapper in the network class wraps the network output into a distribution, an internal REVIVE structure. Do not modify it. Likewise, do not modify the lines that compute and return dist in the forward method.

  3. The reset method resets network variables when necessary, such as the hidden states of an RNN.

  4. The node class must set the additional class attribute custom_node = True. Only then is inputs passed to initialize_network as a Python dictionary.
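Since in_features is keyed by input node name, a network can give each input node its own layers. The sketch below shows one hypothetical way to do this inside the editable sections of a Net-like class: it builds one nn.Linear encoder per input node with nn.ModuleDict and concatenates the encoded features in a fixed order. The class name PerInputEncoder and the concatenation design are assumptions for illustration, not part of REVIVE. The example inputs match the next_states node of the Pendulum graph, whose inputs are states (3 columns) and actions (1 column).

```python
import torch
from torch import nn


class PerInputEncoder(nn.Module):
    """Hypothetical encoder: one Linear layer per input node, outputs concatenated."""

    def __init__(self, in_features: dict, hidden_features: int):
        super().__init__()
        # Fix the iteration order once so forward() always concatenates consistently.
        self.input_names = list(in_features.keys())
        # ModuleDict registers each encoder's parameters; its keys must not contain '.'.
        self.encoders = nn.ModuleDict({
            name: nn.Linear(dim, hidden_features) for name, dim in in_features.items()
        })
        self.out_dim = hidden_features * len(self.input_names)

    def forward(self, state: dict) -> torch.Tensor:
        # state maps each input node name to a tensor of shape (batch, dim).
        encoded = [self.encoders[name](state[name]) for name in self.input_names]
        return torch.cat(encoded, dim=-1)


# Inputs of the next_states node in the Pendulum graph.
encoder = PerInputEncoder({"states": 3, "actions": 1}, hidden_features=32)
features = encoder({"states": torch.randn(8, 3), "actions": torch.randn(8, 1)})
print(features.shape)  # torch.Size([8, 64])
```

In a full Net, the concatenated features (of size out_dim) would feed the remaining layers, and the final output would still go through dist_wrapper unchanged.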

Note

For more details, see revive/revive/computation/graph.py and revive/revive/computation/modules.py.