训练模型

当我们准备好训练数据集( .npz.h5 文件)、决策流图描述文件( .yaml ) 和奖励函数( reward.py )后。 我们可以使用 python train.py 命令开启虚拟环境模型和策略模型训练。该脚本将实例化 revive.server.ReviveServer 并开启训练。

训练脚本示例

python train.py -df <训练数据文件路径> -cf <决策流图文件路径> -rf <奖励函数文件路径> -vm <训练虚拟环境模式> -pm <训练策略模型模式> --run_id <训练实验名称>

运行 train.py 脚本可定义的命令行参数如下:

  • -df: 训练数据的文件路径( .npz.h5 文件)。

  • -vf: 验证数据的文件路径(可选)。

  • -cf: 决策流图的文件路径( .yaml )。

  • -rf: 定义的奖励函数的文件路径( reward.py )(仅在训练策略时需要)。

  • -rcf: 支持进行超参配置的 .json 文件(可选)。

  • -tpn: 策略节点的名称。必须是决策流图中定义的节点;如果未指定,在默认情况下,排在拓扑顺序第一位的节点将作为策略节点。

  • -vm: 训练虚拟环境的不同模式, 包括: once,tune,None

    • once 模式: REVIVE SDK将使用默认参数训练模型。

    • tune 模式: REVIVE SDK将使用超参数搜索来训练模型,需要消耗大量的算力和时间,以搜寻超参数来获得更优的模型结果。

    • None 模式: REVIVE SDK不会训练虚拟环境,它适用于调用已有虚拟环境进行策略训练。

  • -pm: 策略模型的训练模式, 包括: once,tune,None

    • once 模式: REVIVE SDK将使用默认参数训练模型。

    • tune 模式: REVIVE SDK将使用超参数搜索来训练模型,需要消耗大量的算力和时间,以搜寻超参数来获得更优的模型结果。

    • None 模式: REVIVE SDK不会训练策略,它适用于只训练虚拟环境而不进行策略训练的情况。

  • --run_id: 用户为训练实验提供的名称。REVIVE将创建 logs/<run_id> 作为日志目录。如果未提供,REVIVE将随机生成名称。

训练虚拟环境和策略模型

python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm once -pm once --run_id once

python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm tune -pm tune --run_id tune

只训练虚拟环境

python train.py -df test_data.npz -cf test.yaml -vm once -pm None --run_id venv_once

python train.py -df test_data.npz -cf test.yaml -vm tune -pm None --run_id venv_tune

Note

需要定义的奖励函数仅在策略训练的时候需要被提供。

在已有的环境模型基础上训练策略模型

Important

当单独训练策略时,REVIVE SDK将根据索引 run_id 查找完成训练的虚拟环境模型。

python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm None -pm once --run_id venv_once

python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm None -pm tune --run_id venv_tune

Note

train.py 是一个调用REVIVE SDK进行模型训练的通用脚本。 我们也可以基于 recovery.server.ReviveServer 类进行自定义训练方法的开发。

请参阅API中的 revive.server.ReviveServerREVIVE API ) 了解更多信息。

在训练过程中,我们可以随时使用tensorboard打开日志目录以监控训练过程。当REVIVE SDK完成虚拟环境模型训练和策略模型训练后。 我们可以在日志文件夹( logs/<run_id>)下找到保存的模型( .pkl.onnx)。

REVIVE提供了一个超参数调优工具,使用 tune 模式开始训练将切换到超参调优模式。 在该模式下,REVIVE将从预设超参数空间中采样多组超参数用于模型训练,通常使用超参调优模式可以获得更好的模型。 有关超参调优模式的详细说明,请参阅文档中的REVIVE API(revive.conf)部分。 我们也可以通过修改 config.json 中的相关配置来调整超参搜索空间和搜索方法。

python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -rcf config.json --run_id test

Note

示例数据 test_data.npz, test.yaml, test_reward.pyconfig.json 存储在 revive/data 文件夹中。