训练模型 ======== 当我们准备好训练数据集( ``.npz`` 或 ``.h5`` 文件)、决策流图描述文件( ``.yaml`` ) 和奖励函数( ``reward.py`` )后。 我们可以使用 ``python train.py`` 命令开启虚拟环境模型和策略模型训练。该脚本将实例化 ``revive.server.ReviveServer`` 并开启训练。 **训练脚本示例** .. code:: bash python train.py -df <训练数据文件路径> -cf <决策流图文件路径> -rf <奖励函数文件路径> -vm <训练虚拟环境模式> -pm <训练策略模型模式> --run_id <训练实验名称> 运行 ``train.py`` 脚本可定义的命令行参数如下: - -df: 训练数据的文件路径( ``.npz`` 或 ``.h5`` 文件)。 - -vf: 验证数据的文件路径(可选)。 - -cf: 决策流图的文件路径( ``.yaml`` )。 - -rf: 定义的奖励函数的文件路径( ``reward.py`` )(仅在训练策略时需要)。 - -rcf: 支持进行超参配置的 ``.json`` 文件(可选)。 - -tpn: 策略节点的名称。必须是决策流图中定义的节点;如果未指定,在默认情况下,排在拓扑顺序第一位的节点将作为策略节点。 - -vm: 训练虚拟环境的不同模式, 包括: ``once``\ ,\ ``tune``\ ,\ ``None``。 - ``once`` 模式: REVIVE SDK将使用默认参数训练模型。 - ``tune`` 模式: REVIVE SDK将使用超参数搜索来训练模型,需要消耗大量的算力和时间,以搜寻超参数来获得更优的模型结果。 - ``None`` 模式: REVIVE SDK不会训练虚拟环境,它适用于调用已有虚拟环境进行策略训练。 - -pm: 策略模型的训练模式, 包括: ``once``\ ,\ ``tune``\ ,\ ``None``。 - ``once`` 模式: REVIVE SDK将使用默认参数训练模型。 - ``tune`` 模式: REVIVE SDK将使用超参数搜索来训练模型,需要消耗大量的算力和时间,以搜寻超参数来获得更优的模型结果。 - ``None`` 模式: REVIVE SDK不会训练策略,它适用于只训练虚拟环境而不进行策略训练的情况。 - \--run_id: 用户为训练实验提供的名称。REVIVE将创建 ``logs/`` 作为日志目录。如果未提供,REVIVE将随机生成名称。 **训练虚拟环境和策略模型** :: python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm once -pm once --run_id once python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm tune -pm tune --run_id tune **只训练虚拟环境** :: python train.py -df test_data.npz -cf test.yaml -vm once -pm None --run_id venv_once python train.py -df test_data.npz -cf test.yaml -vm tune -pm None --run_id venv_tune .. note:: 需要定义的奖励函数仅在策略训练的时候需要被提供。 **在已有的环境模型基础上训练策略模型** .. important:: 当单独训练策略时,REVIVE SDK将根据索引 ``run_id`` 查找完成训练的虚拟环境模型。 .. code:: bash python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm None -pm once --run_id venv_once python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm None -pm tune --run_id venv_tune .. note:: ``train.py`` 是一个调用REVIVE SDK进行模型训练的通用脚本。 我们也可以基于 ``recovery.server.ReviveServer`` 类进行自定义训练方法的开发。 请参阅API中的 ``revive.server.ReviveServer`` ( :doc:`REVIVE API <../revive_server_cn>` ) 了解更多信息。 在训练过程中,我们可以随时使用tensorboard打开日志目录以监控训练过程。当REVIVE SDK完成虚拟环境模型训练和策略模型训练后。 我们可以在日志文件夹( ``logs/``)下找到保存的模型( ``.pkl`` 或 ``.onnx``)。 REVIVE提供了一个超参数调优工具,使用 ``tune`` 模式开始训练将切换到超参调优模式。 在该模式下,REVIVE将从预设超参数空间中采样多组超参数用于模型训练,通常使用超参调优模式可以获得更好的模型。 有关超参调优模式的详细说明,请参阅文档中的REVIVE API(:doc:`revive.conf <../revive_conf_cn>`)部分。 我们也可以通过修改 ``config.json`` 中的相关配置来调整超参搜索空间和搜索方法。 .. code:: bash python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -rcf config.json --run_id test .. note:: 示例数据 ``test_data.npz``, ``test.yaml``, ``test_reward.py`` 和 ``config.json`` 存储在 ``revive/data`` 文件夹中。