冻结部分节点的网络参数

在虚拟环境训练过程中,可以冻结某些节点的网络参数不进行梯度更新。这在加载现有环境模型时非常有用。

例如在机器人控制任务上,假设我们已经在一个A型号的机器人任务上训练完成了一个较好的虚拟环境模型, 现在我们希望快速将A型号的虚拟环境模型迁移到新的B型号机器人任务上,两个型号的机器人系统共用了大部分相同的机械结构, 但是也存在一些不同。这时我们可以将A型号机器人的虚拟环境模型作为预训练模型,然后通过参数冻结功能将模型中与A型号机器人相同的部分保护起来, 这样训练过程中就不会破坏这些节点的参数设置了。 然后,我们可以在B型号机器人上继续训练,只对那些与A型号机器人不同的部分进行微调或者重新学习。 这样做的好处是可以节省训练时间和计算资源,同时也可以在一定程度上避免过拟合,因为预训练模型已经学会了一些通用的特征, 可以避免从头开始训练过程中的过拟合现象。同时,由于模型已经预训练过,所以迁移学习过程中需要拟合的新数据集的大小不需要太大。

需要注意的是,在进行参数冻结时,我们需要根据具体问题自行决定哪些节点需要被冻结,哪些节点需要重新训练。 如果冻结太多的节点,那么新任务的机器人可能无法得到足够的学习;如果冻结太少的节点, 那么新任务的机器人可能无法发挥预训练模型的优势,训练时间也会很长。因此,在具体操作时需要合理权衡。

下面通过这个示例展示如何编写 .yaml 文件来实现此功能。

metadata:

   graph:
     action:
     - observation
     next_observation:
     - action
     - observation

   columns:
   ...

在下面的 .yaml 示例中,我们将 action 节点的 freeze 属性配置为 true 。在训练期间, action 节点的网络参数将不会更新。

metadata:

   graph:
     action:
     - observation
     next_observation:
     - action
     - observation

   columns:
   ...

   nodes:
     action:
       freeze: true

Important

不能同时冻结所有网络节点。至少需要存在一个网络节点可以进行梯度更新,否则REVIVE将报错。