Multi-Time-Steps Data Nodes¶
The multi-time-step node concatenates data from multiple historical time steps and feeds it to a node as input, providing richer information and improving the model's accuracy and robustness. Take a driving task as an example: the current speed alone carries limited information, so to improve prediction accuracy and robustness it is usually necessary to concatenate speed readings from several time steps. Combining speed information from multiple historical time steps yields richer signals about the car, such as acceleration and deceleration, which lets us predict the car's future speed and driving state more accurately and improves the precision and responsiveness of the vehicle control system.
Therefore, concatenating data from multiple historical time steps as input is usually helpful for learning environment models. The REVIVE SDK provides a simple method for configuring multi-time-step nodes.
The following example shows how to concatenate data from multiple time steps using a multi-time-step node.
metadata:
  graph:
    action:
    - observation
    next_observation:
    - action
    - observation

  columns:
  ...
In the following configuration, we create a multi-time-step data node for the observation node by adding a nodes section and setting the ts property of observation to 5. The resulting multi-time-step node is named ts_observation, and it automatically stitches together the data of the 5 most recent time steps of observation. The ts_observation node should be defined when constructing the .yaml file of the decision flow.
Important

ts_ is the only prefix that tells REVIVE which nodes should use the multi-time-step function. In this example, ts_observation can be thought of as a queue of length 5 that obeys the FIFO principle: the latest observation is appended at the end of the feature dimension (see the sketch after the configuration below).
metadata:
  graph:
    action:
    - ts_observation
    next_ts_observation:
    - action
    - ts_observation

  columns:
  ...

  nodes:
    observation:
      ts: 5
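To make the FIFO behavior concrete, here is a minimal sketch in plain PyTorch (not part of the REVIVE API; the window of 5 frames and the dummy observations are assumptions for illustration) of how a ts_observation window is assembled, with the latest observation placed last in the feature dimension:

import torch

obs_dim, ts = 3, 5
# A rolling window holding the 5 most recent observations, oldest first.
window = [torch.zeros(obs_dim) for _ in range(ts)]

def push(window, new_obs):
    # FIFO: drop the oldest frame and append the newest one at the end.
    return window[1:] + [new_obs]

for t in range(8):
    window = push(window, torch.full((obs_dim,), float(t)))  # dummy observation for step t

# Flatten along the feature dimension: the latest observation occupies the last obs_dim entries.
ts_observation = torch.cat(window, dim=-1)   # shape: (5 * obs_dim,) = (15,)
print(ts_observation.shape)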
In some specialized tasks, the output of the action node needs to take into account the influence of historical actions. Therefore, we also need to perform frame concatenation on the action node:
metadata:
  graph:
    action:
    - ts_action
    - ts_observation
    next_ts_action:
    - action
    - ts_action
    next_ts_observation:
    - next_ts_action
    - ts_observation

  columns:
  ...

  nodes:
    observation:
      ts: 5
      ts_repeat: false
    action:
      ts: 5
      endpoint: 5
      ts_repeat: false
In the above example, we have added a new ts_action node. Note that we have also set the endpoint property, which controls where the data window of a ts node ends. By default, when endpoint is not set, the system automatically uses ts + 1, so the node retrieves the concatenated historical time steps together with the current time step. In this case, however, the current action is exactly what we need to predict, so setting the endpoint property ensures that ts_action contains only historical action data.
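As a rough illustration of the two windows, here is a sketch with dummy tensors (the exact indexing is handled internally by REVIVE; the assumption here is that ts: 5 covers the 5 most recent steps including the current one, while endpoint shifts the action window to end at the previous step):

import torch

T, obs_dim, act_dim, ts = 10, 3, 2, 5
obs_history = torch.randn(T, obs_dim)   # o_0 ... o_{T-1}
act_history = torch.randn(T, act_dim)   # a_0 ... a_{T-1}

t = 7  # current time step

# ts_observation: the 5 most recent observations including the current one,
# i.e. [o_{t-4}, ..., o_t], flattened along the feature dimension.
ts_observation = obs_history[t - ts + 1 : t + 1].reshape(-1)   # shape (5 * obs_dim,)

# ts_action with endpoint set: only historical actions [a_{t-5}, ..., a_{t-1}],
# because a_t is exactly what the policy node has to predict.
ts_action = act_history[t - ts : t].reshape(-1)                 # shape (5 * act_dim,)

print(ts_observation.shape, ts_action.shape)  # torch.Size([15]) torch.Size([10])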
Alternatively, the following method can be used for multi-time-step data nodes. It reduces the difficulty for REVIVE in learning the network, but the configuration is slightly more complex.
metadata:
  graph:
    action:
    - ts_observation
    next_observation:
    - ts_observation
    - action
    next_ts_observation:
    - next_observation
    - ts_observation

  columns:
  ...

  nodes:
    observation:
      ts: 5

  expert_functions:
    next_ts_observation:
      'node_function' : 'expert_function.next_ts_observation_func'
Here, next_observation is the node that learns the environment transition. For example, if the observation in the data has 3 dimensions, then after REVIVE stacks 5 frames, ts_observation has 3 * 5 = 15 dimensions. Using next_observation limits the learned output size from 15 to 3, which reduces the difficulty for REVIVE: it only has to learn a network with a 3-dimensional output rather than a 15-dimensional one. The job of assembling the 5 frames is handed over to the next_ts_observation node, which is achieved by attaching an expert function as follows.
import torch
from typing import Dict


def next_ts_observation_func(data: Dict[str, torch.Tensor]) -> torch.Tensor:
    # data["next_observation"] has shape (batch_size, obs_dim)
    # data["ts_observation"] has shape (batch_size, 5 * obs_dim)
    obs_dim = data["next_observation"].shape[-1]
    next_obs = data["next_observation"]
    ts_obs = data["ts_observation"]
    # Append the newest observation at the end of the feature dimension and
    # drop the oldest frame (the first obs_dim features), keeping the window at 5 frames.
    next_ts_obs = torch.cat([ts_obs, next_obs], dim=-1)[..., obs_dim:]
    return next_ts_obs
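For intuition, here is a quick usage sketch of the expert function above with dummy tensors (the batch size and dimensions are arbitrary assumptions, not values required by REVIVE):

import torch

batch_size, obs_dim, ts = 4, 3, 5
data = {
    "next_observation": torch.randn(batch_size, obs_dim),      # the newest frame
    "ts_observation": torch.randn(batch_size, ts * obs_dim),   # the current 5-frame window
}
next_ts_obs = next_ts_observation_func(data)
# The oldest frame is dropped and next_observation is appended last,
# so the shape stays (batch_size, 5 * obs_dim).
print(next_ts_obs.shape)  # torch.Size([4, 15])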