Multi-Time-Steps Data Nodes

The multi-time-step node concatenates data from multiple historical time steps into a single input in order to improve model accuracy and robustness. Because the node receives richer input information, the model can better capture the underlying dynamics and perform better. Take a driving task as an example: the current speed alone carries limited information, so to improve prediction accuracy and robustness it is usually necessary to concatenate speed readings from multiple time steps. Combining speed information across several historical time steps reveals more about the car’s motion, such as whether it is accelerating or decelerating. This makes it possible to predict the car’s future speed and driving state more accurately, and improves the precision and feedback capability of the vehicle control system.
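To make the driving example concrete, here is a minimal sketch in plain Python, not part of REVIVE: a single speed reading says nothing about acceleration, but a window of consecutive readings does. The sampling interval `DT`, the helper names, and the sample speeds are hypothetical values chosen for illustration.

```python
from typing import List

DT = 0.1  # hypothetical sampling interval in seconds


def stack_speeds(speeds: List[float], t: int, ts: int) -> List[float]:
    """Return the ts most recent speeds up to and including step t, oldest first."""
    if t < ts - 1:
        raise ValueError("not enough history for this step")
    return speeds[t - ts + 1 : t + 1]


def accelerations(window: List[float], dt: float = DT) -> List[float]:
    """Finite-difference accelerations between consecutive speeds in the window."""
    return [(b - a) / dt for a, b in zip(window, window[1:])]


speeds = [10.0, 10.5, 11.0, 11.5, 12.0, 12.0]  # hypothetical speed log in m/s
window = stack_speeds(speeds, t=4, ts=5)
print(window)  # [10.0, 10.5, 11.0, 11.5, 12.0]
print(accelerations(window))
```

A model that only sees `speeds[4]` cannot distinguish a car cruising at 12 m/s from one accelerating through 12 m/s; the stacked window makes that difference visible.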

Therefore, concatenating data from multiple historical time steps as input is usually helpful for learning environment models. The REVIVE SDK provides a simple method for configuring multi-time-step nodes.

The following example shows how to concatenate data from multiple time steps using a multi-time-step node. We start from an ordinary decision flow that does not yet use one:

metadata:

   graph:
     action:
     - observation
     next_observation:
     - action
     - observation

   columns:
   ...

Below, we obtain a multi-time-step data node for observation by adding a nodes entry and setting the ts property of observation to 5. The resulting node is named ts_observation, and REVIVE automatically concatenates 5 time steps of observation data into it. The ts_observation node should then be used when defining the decision flow graph in the .yaml file.

Important

ts_ is the only prefix REVIVE uses to recognize which nodes should use the multi-time-step function. In the example above, ts_observation can be thought of as a queue of length 5 that follows the FIFO (first in, first out) principle: the latest observation is appended at the end along the feature dimension, and the oldest one is dropped.
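The queue analogy can be illustrated with Python's `collections.deque` and `maxlen=5`. This is only a mental model of the FIFO behavior described above, not REVIVE's internal implementation; the 2-dimensional observations and the `flatten` helper are hypothetical.

```python
from collections import deque
from typing import Deque, List

TS = 5  # queue length, matching ts: 5 in the configuration


def flatten(frames: Deque[List[float]]) -> List[float]:
    """Concatenate queued frames along the feature dimension, oldest first."""
    return [x for frame in frames for x in frame]


queue: Deque[List[float]] = deque(maxlen=TS)
for step in range(7):
    obs = [float(step), float(step) * 10]  # hypothetical 2-dimensional observation
    queue.append(obs)  # once the queue is full, appending drops the oldest frame

ts_observation = flatten(queue)
print(ts_observation)  # [2.0, 20.0, 3.0, 30.0, 4.0, 40.0, 5.0, 50.0, 6.0, 60.0]
```

After seven steps only the frames from steps 2 to 6 remain, and the newest frame (step 6) occupies the last two feature positions.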

metadata:

   graph:
     action:
     - ts_observation
     next_ts_observation:
     - action
     - ts_observation

   columns:
   ...

   nodes:
     observation:
       ts: 5
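Conceptually, setting `ts: 5` turns each `observation` of shape `(obs_dim,)` into a `ts_observation` of shape `(5 * obs_dim,)`. The NumPy sketch below builds such windows for a whole trajectory. How REVIVE fills the window for the first few steps of a trajectory is not described in this section, so the sketch simply assumes the first observation is repeated as padding; `make_ts_windows` is a hypothetical helper.

```python
import numpy as np


def make_ts_windows(obs: np.ndarray, ts: int) -> np.ndarray:
    """Stack the `ts` most recent frames for every step of a trajectory.

    obs: array of shape (T, obs_dim).
    Returns an array of shape (T, ts * obs_dim) in which the latest frame is last.
    Assumption (not REVIVE's documented behavior): steps before the start of the
    trajectory are padded with the first frame.
    """
    pad = np.repeat(obs[:1], ts - 1, axis=0)      # (ts - 1, obs_dim) padding rows
    padded = np.concatenate([pad, obs], axis=0)   # (T + ts - 1, obs_dim)
    # Row t of the result covers original steps t - ts + 1 .. t, flattened frame by frame.
    windows = [padded[t : t + ts].reshape(-1) for t in range(len(obs))]
    return np.stack(windows)


obs = np.arange(8 * 3, dtype=np.float32).reshape(8, 3)  # hypothetical: T=8, obs_dim=3
ts_obs = make_ts_windows(obs, ts=5)
print(ts_obs.shape)  # (8, 15)
```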

In some specialized tasks, the output of the action node must also account for the influence of historical actions. In that case, frame concatenation is applied to the action node as well:

metadata:

   graph:
     action:
     - ts_action
     - ts_observation
     next_ts_action:
     - action
     - ts_action
     next_ts_observation:
     - next_ts_action
     - ts_observation
   columns:
   ...

   nodes:
     observation:
       ts: 5
       ts_repeat: false
     action:
       ts: 5
       endpoint: 5
       ts_repeat: false

In the above example, we have added a new ts_action node and also set its endpoint property. The endpoint controls where the data in a ts node ends. By default, when endpoint is not set, the system uses ts+1, so the ts node contains both the concatenated historical time steps and the current time step. Here, however, the current action is exactly what we need to predict, so setting endpoint ensures that ts_action contains only historical action data.
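The difference between the two kinds of windows can be sketched in plain Python. This illustrates the idea only: the helper `window` and its `include_current` flag are hypothetical and do not reproduce REVIVE's exact `endpoint` indexing. If the current step were included, the input would contain the very action the node is asked to predict; the history-only window avoids that.

```python
from typing import List, Sequence


def window(seq: Sequence[int], t: int, ts: int, include_current: bool) -> List[int]:
    """Return `ts` frames ending at step t (inclusive) or at step t - 1 (history only)."""
    end = t + 1 if include_current else t
    start = end - ts
    if start < 0:
        raise ValueError("not enough history for this step")
    return list(seq[start:end])


actions = [0, 1, 2, 3, 4, 5, 6]  # hypothetical action ids for steps 0..6
print(window(actions, t=6, ts=5, include_current=True))   # [2, 3, 4, 5, 6]
print(window(actions, t=6, ts=5, include_current=False))  # [1, 2, 3, 4, 5]
```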

Alternatively, multi-time-step data nodes can be used as follows. This approach makes the network easier for REVIVE to learn, at the cost of a slightly more complex configuration.

metadata:

   graph:
     action:
     - ts_observation
     next_observation:
     - ts_observation
     - action
     next_ts_observation:
     - next_observation
     - ts_observation

   columns:
   ...

   nodes:
     observation:
       ts: 5

   expert_functions:
     next_ts_observation:
       'node_function' : 'expert_function.next_ts_observation_func'

Here, next_observation is the node that learns the environment transition. For example, if observation has 3 dimensions, then after REVIVE concatenates 5 frames, ts_observation has 3*5 = 15 dimensions. Predicting next_observation instead of next_ts_observation reduces the output size from 15 to 3, so REVIVE only has to learn a network with a 3-dimensional output rather than a 15-dimensional one. The frame concatenation is then handed over to the next_ts_observation node, which is implemented by attaching an expert function as follows.

import torch
from typing import Dict


def next_ts_observation_func(data: Dict[str, torch.Tensor]) -> torch.Tensor:
    # data["next_observation"]: (batch_size, obs_dim), output of the learned transition node
    next_obs = data["next_observation"]
    # data["ts_observation"]: (batch_size, 5 * obs_dim), 5 frames concatenated, oldest first
    ts_obs = data["ts_observation"]
    obs_dim = next_obs.shape[-1]

    # Append the new frame at the end, then drop the oldest frame (FIFO shift).
    next_ts_obs = torch.cat([ts_obs, next_obs], dim=-1)[..., obs_dim:]
    return next_ts_obs
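To sanity-check the shift logic without PyTorch, the same operation can be mirrored in NumPy. The check below assumes plain sliding windows (oldest frame first, latest last) and verifies that shifting ts_observation at step t with next_observation yields the window at step t + 1. `next_ts_observation_np` and the random trajectory are hypothetical.

```python
import numpy as np


def next_ts_observation_np(ts_obs: np.ndarray, next_obs: np.ndarray) -> np.ndarray:
    """NumPy mirror of next_ts_observation_func: append the new frame, drop the oldest."""
    obs_dim = next_obs.shape[-1]
    return np.concatenate([ts_obs, next_obs], axis=-1)[..., obs_dim:]


TS, OBS_DIM, T = 5, 3, 10
traj = np.random.default_rng(0).normal(size=(T, OBS_DIM))  # hypothetical trajectory

# ts_observation for steps t = TS-1 .. T-2, each row of shape (TS * OBS_DIM,)
ts_windows = np.stack([traj[t - TS + 1 : t + 1].reshape(-1) for t in range(TS - 1, T - 1)])
next_obs = traj[TS:T]  # next_observation for the same steps

next_ts = next_ts_observation_np(ts_windows, next_obs)
# The shifted window should equal the ts_observation of the following step.
expected = np.stack([traj[t - TS + 2 : t + 2].reshape(-1) for t in range(TS - 1, T - 1)])
print(np.allclose(next_ts, expected))  # True
```

This consistency is what lets the learned network output only 3 dimensions: the expert function rebuilds the full 15-dimensional window deterministically.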