引入专家函数¶
引入专家函数可以为模型的训练和推断提供专业领域知识支持,从而提高模型的表达能力和鲁棒性,减少任务难度,并提高模型预测精度。 专家函数可以将专业知识在模型中进行嵌入或调用,使模型可以快速学习核心问题并减少模型的调参时间,从而降低了任务难度。 同时,使用预定义的专家函数会提高模型的精度,因为它们被经过精心证明和实践,并被用于生产环境中。
在机械控制任务中,通常使用 PID 控制器进行控制,这是一种较为常见的控制方法。可以将该控制方法定义为专家函数节点, 以将专业领域的知识引入到模型学习中,以加速模型训练并提高模型精度。在自动驾驶任务中,交通规则是一个非常重要的部分。 可以将交通规则定义为专家函数节点。
REVIVE SDK支持引入专家函数。在构建复杂的虚拟环境时,专家知识非常有用。REVIVE SDK支持用户自定义专家函数 expert_function
从而将专家知识引入到虚拟环境模型学习中。
如果我们知道决策流图中某个节点的计算方法,那么就可以使用关键字 expert_function
将其定义为专家函数节点。
下面是一个专家函数的定义示例:
metadata:
graph:
action:
- obs
next_obs:
- obs
- action
expert_functions:
next_obs:
'node_function' : 'dynamics.transition'
...
...
专家函数的配置通常是 node_function:<filename>.<function_name>
。
上述的 .yaml
文件为 next_obs
节点引入了名为 transition
的专家函数,专家函数应该定义在同级目录的 dynamics.py
文件中。
对应的 transition
函数源代码如下:
import torch
from typing import Dict
def transition(data: Dict[str, torch.Tensor]) -> torch.Tensor:
obs = data["obs"]
return obs + 1
如果某节点上没有绑定对应的专家函数,那么REVIVE将通过神经网络自动对该节点进行建模。
Note
.yaml
文件中 graph
的任意输出节点都支持绑定专家函数,但是至少需要存在一个可学习的节点(即通过神经网络初始化的节点)。
下面是一个绑定多个专家函数节点,其中一个专家函数的输出可以作为另一个专家函数的输入:
metadata:
graph:
action:
- obs_1
- obs_2
next_obs_1:
- obs_1
- obs_2
- action
next_obs_2:
- obs_1
- obs_2
- action
- next_obs_1
expert_functions:
next_obs_1:
'node_function' : 'dynamics.transition_1'
next_obs_2:
'node_function' : 'dynamics.transition_2'
...
...
需要注意,在使用专家函数对数据进行处理时,通常会将多个数据按批量(batch)组织起来进行一次性运算处理。这种方式可以提高代码的运行效率。
因此,在编写奖励函数时,需要注意保证函数能够处理与输入张量形状相对应的多维数据。此外,在计算专家函数输出时,我们通常会关注最后一维的特征维度。
为方便处理,专家函数的计算维度通常都设在了最后一维。因此,在使用数据时需要使用切片([..., n:m ]
)的方式获取数据的最后一维的特征,
并对特征进行计算。
专家函数的输出应该是一个对应的Pytorch Tensor,其batch维度保持和输入数据一致, 最后一维特征的维度应该与 *.yaml
文件中该节点的定义一致。