多时间步节点拼接

多时间步节点用于拼接历史多个时间步的数据作为输入,以提高模型精度和鲁棒性。使用多时间步节点可以为节点增加更丰富的输入信息, 从而增加模型的捕捉能力和性能。以汽车驾驶任务为例,考虑到当前时刻的速度只包含有限的信息, 为了提高预测的准确性和鲁棒性,通常需要拼接多个时间步的速度信息。通过综合利用历史多个时间步的速度信息, 我们可以获取到汽车更丰富的速度信息,例如加速度、减速度等。这样,可以更精确地预测汽车的未来速度和行驶状态, 并提高车辆控制系统的精度和反馈能力。

因此,拼接历史多个时间步的数据作为输入通常对学习环境模型是很有帮助的。而REVIVE SDK提供了一种配置多时间步节点的简单方法。

下面使用一个示例展示如何进行多时间步节点的拼接.

metadata:

   graph:
     action:
     - observation
     next_observation:
     - action
     - observation

     columns:
     ...

在下面的 .yaml 文件中,我们通过添加节点信息并将 observationts 属性配置为5, 以获得 observation 节点的历史5步拼帧数据。多时间步拼接后的节点名称为 ts_observationts_observation 节点将自动将历史5时间步的 observation 数据进行拼接。

Important

REVIVE SDK会自动检测存在 ts_ 前缀的节点。用户应避免在自定义节点名时使用 ts_ 前缀。此外,在上述例子中拼帧节点 ts_observation 可以类比为一个“长度为5的队列”, observation 在其中符合FIFO原则,最新的 observation 将会被拼在特征维度的最后。

metadata:

   graph:
     action:
     - ts_observation
     next_ts_observation:
     - action
     - ts_observation

   columns:
   ...

   nodes:
     observation:
       ts: 5

还可以增加专家函数以完成多时间步节点的转移函数。该方法虽然比较复杂,但是可以降低REVIVE学习转移节点的难度,提高虚拟环境模型精度。

metadata:

graph:
  action:
  - ts_observation
  next_observation:
  - ts_observation
  - action
  next_ts_observation:
  - next_observation
  - ts_observation

columns:
...

nodes:
  observation:
    ts: 5

expert_functions:
  next_ts_observation:
    'node_function' : 'expert_function.next_ts_observation_func'
import torch
from typing import Dict

def next_ts_observation_func(data: Dict[str, torch.Tensor]) -> torch.Tensor:
   obs_dim = data["next_observation"].shape[-1]
   next_obs = data["next_observation"]
   ts_obs = data["ts_observation"]
   next_ts_obs = torch.cat([ts_obs, next_obs], axis=-1)[..., obs_dim:]

   return next_ts_obs