训练模型¶
当我们准备好训练数据集( .npz
或 .h5
文件)、决策流图描述文件( .yaml
) 和奖励函数( reward.py
)后。
我们可以使用 python train.py
命令开启虚拟环境模型和策略模型训练。该脚本将实例化 revive.server.ReviveServer
并开启训练。
训练脚本示例
python train.py -df <训练数据文件路径> -cf <决策流图文件路径> -rf <奖励函数文件路径> -vm <训练虚拟环境模式> -pm <训练策略模型模式> --run_id <训练实验名称>
运行 train.py
脚本可定义的命令行参数如下:
-df: 训练数据的文件路径(
.npz
或.h5
文件)。-vf: 验证数据的文件路径(可选)。
-cf: 决策流图的文件路径(
.yaml
)。-rf: 定义的奖励函数的文件路径(
reward.py
)(仅在训练策略时需要)。-rcf: 支持进行超参配置的
.json
文件(可选)。-tpn: 策略节点的名称。必须是决策流图中定义的节点;如果未指定,在默认情况下,排在拓扑顺序第一位的节点将作为策略节点。
-vm: 训练虚拟环境的不同模式, 包括:
once
,tune
,None
。once
模式: REVIVE SDK将使用默认参数训练模型。tune
模式: REVIVE SDK将使用超参数搜索来训练模型,需要消耗大量的算力和时间,以搜寻超参数来获得更优的模型结果。None
模式: REVIVE SDK不会训练虚拟环境,它适用于调用已有虚拟环境进行策略训练。
-pm: 策略模型的训练模式, 包括:
once
,tune
,None
。once
模式: REVIVE SDK将使用默认参数训练模型。tune
模式: REVIVE SDK将使用超参数搜索来训练模型,需要消耗大量的算力和时间,以搜寻超参数来获得更优的模型结果。None
模式: REVIVE SDK不会训练策略,它适用于只训练虚拟环境而不进行策略训练的情况。
--run_id: 用户为训练实验提供的名称。REVIVE将创建
logs/<run_id>
作为日志目录。如果未提供,REVIVE将随机生成名称。
训练虚拟环境和策略模型
python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm once -pm once --run_id once
python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm tune -pm tune --run_id tune
只训练虚拟环境
python train.py -df test_data.npz -cf test.yaml -vm once -pm None --run_id venv_once
python train.py -df test_data.npz -cf test.yaml -vm tune -pm None --run_id venv_tune
Note
需要定义的奖励函数仅在策略训练的时候需要被提供。
在已有的环境模型基础上训练策略模型
Important
当单独训练策略时,REVIVE SDK将根据索引 run_id
查找完成训练的虚拟环境模型。
python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm None -pm once --run_id venv_once
python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm None -pm tune --run_id venv_tune
Note
train.py
是一个调用REVIVE SDK进行模型训练的通用脚本。 我们也可以基于 recovery.server.ReviveServer
类进行自定义训练方法的开发。
请参阅API中的 revive.server.ReviveServer
( REVIVE API ) 了解更多信息。
在训练过程中,我们可以随时使用tensorboard打开日志目录以监控训练过程。当REVIVE SDK完成虚拟环境模型训练和策略模型训练后。
我们可以在日志文件夹( logs/<run_id>
)下找到保存的模型( .pkl
或 .onnx
)。
REVIVE提供了一个超参数调优工具,使用 tune
模式开始训练将切换到超参调优模式。
在该模式下,REVIVE将从预设超参数空间中采样多组超参数用于模型训练,通常使用超参调优模式可以获得更好的模型。
有关超参调优模式的详细说明,请参阅文档中的REVIVE API(revive.conf)部分。
我们也可以通过修改 config.json
中的相关配置来调整超参搜索空间和搜索方法。
python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -rcf config.json --run_id test
Note
示例数据 test_data.npz
, test.yaml
, test_reward.py
和 config.json
存储在 revive/data
文件夹中。