多时间步节点拼接¶
多时间步节点用于拼接历史多个时间步的数据作为输入,以提高模型精度和鲁棒性。使用多时间步节点可以为节点增加更丰富的输入信息, 从而增加模型的捕捉能力和性能。以汽车驾驶任务为例,考虑到当前时刻的速度只包含有限的信息, 为了提高预测的准确性和鲁棒性,通常需要拼接多个时间步的速度信息。通过综合利用历史多个时间步的速度信息, 我们可以获取到汽车更丰富的速度信息,例如加速度、减速度等。这样,可以更精确地预测汽车的未来速度和行驶状态, 并提高车辆控制系统的精度和反馈能力。
因此,拼接历史多个时间步的数据作为输入通常对学习环境模型是很有帮助的。而REVIVE SDK提供了一种配置多时间步节点的简单方法。
下面使用一个示例展示如何进行多时间步节点的拼接.
metadata:
graph:
action:
- observation
next_observation:
- action
- observation
columns:
...
在下面的 .yaml
文件中,我们通过添加节点信息并将 observation
的 ts
属性配置为5,
以获得 observation
节点的历史5步拼帧数据。多时间步拼接后的节点名称为 ts_observation
。
ts_observation
节点将自动将历史5时间步的 observation
数据进行拼接。
Important
REVIVE SDK会自动检测存在 ts_
前缀的节点。用户应避免在自定义节点名时使用 ts_
前缀。此外,在上述例子中拼帧节点 ts_observation
可以类比为一个“长度为5的队列”, observation
在其中符合FIFO原则,最新的 observation
将会被拼在特征维度的最后。
metadata:
graph:
action:
- ts_observation
next_ts_observation:
- action
- ts_observation
columns:
...
nodes:
observation:
ts: 5
还可以增加专家函数以完成多时间步节点的转移函数。该方法虽然比较复杂,但是可以降低REVIVE学习转移节点的难度,提高虚拟环境模型精度。
metadata:
graph:
action:
- ts_observation
next_observation:
- ts_observation
- action
next_ts_observation:
- next_observation
- ts_observation
columns:
...
nodes:
observation:
ts: 5
expert_functions:
next_ts_observation:
'node_function' : 'expert_function.next_ts_observation_func'
import torch
from typing import Dict
def next_ts_observation_func(data: Dict[str, torch.Tensor]) -> torch.Tensor:
obs_dim = data["next_observation"].shape[-1]
next_obs = data["next_observation"]
ts_obs = data["ts_observation"]
next_ts_obs = torch.cat([ts_obs, next_obs], axis=-1)[..., obs_dim:]
return next_ts_obs