revive.server¶
- class revive.server_cn.ReviveServer(dataset_file_path: str, dataset_desc_file_path: str, val_file_path: Optional[str] = None, reward_file_path: Optional[str] = None, target_policy_name: Optional[str] = None, log_dir: Optional[str] = None, run_id: Optional[str] = None, address: Optional[str] = None, venv_mode: str = 'tune', policy_mode: str = 'tune', tuning_mode: str = 'None', tune_initial_state: Optional[Dict[str, ndarray]] = None, debug: bool = False, revive_config_file_path: Optional[str] = None, **kwargs)[source]¶
Bases:
object
ReviveServer是Revive SDK的训练入口,负责启动并管理所有训练任务。
ReviveServer 执行四个步骤来完成初始化:
1. 创建或连接到ray集群。集群地址由`address`参数控制。如果`address`参数为`None`,它将创建自己的集群。 如果指定了`address`参数,它将使用参数连接到现有集群。
加载培训配置文件。提供的默认配置文件为`config.json`中。可以通过编辑文件来更改默认参数。
加载决策流图,npz数据和函数。数据文件由参数`dataset_file_path `、`dataset_desc_file_path和`val_file_path’参数指定。
4. 创建日志文件夹存储训练结果。这些日志的顶层文件夹由`log_dir`参数控制。如果未提供,则默认生成的`logs`文件夹。 第二级文件夹由训练配置中的`run_id`参数控制,如果未指定,将为文件夹生成一个随机id。所有训练日志和模型都将放在第二级文件夹中。
- 参数:
- dataset_file_path (str)
训练数据的文件路径(
.npz
或.h5
文件)。- dataset_desc_file_path (str)
决策流图的文件路径(
.yaml
)。- val_file_path (str)
验证数据的文件路径(可选)。
- reward_file_path (str)
定义奖励函数的文件的存储路径。
- target_policy_name (str)
要优化的策略节点的名称。如果为None,则将选择决策流图中的第一网络节点作为策略节点。
- log_dir (str)
模型和训练日志存储文件夹
- run_id (str)
实验ID,用于生成日志文件夹名称,区分不同的实验。如果未提供,系统会自动生成。
- address (str)
ray集群地址,集群地址由`address`参数控制。如果`address`参数为`None`,它将创建自己的集群。如果指定了`address`参数,它将使用参数连接到现有集群。
- venv_mode (“tune”,”once”,”None”)
训练虚拟环境的不同模式: tune 使用超参数搜索来训练虚拟环境模型,需要消耗大量的算力和时间,以搜寻超参数来获得更优的模型结果。 once 使用默认参数训练虚拟环境模型。 None 不训练虚拟环境模型。
- policy_mode (“tune”,”once”,”None”)
策略模型的训练模式: tune 使用超参数搜索来训练策略模型,需要消耗大量的算力和时间,以搜寻超参数来获得更优的模型结果。 once 使用默认参数训练策略模型。 None 不训练策略模型。
- custom_config
超参配置文件路径,可用于覆盖默认参数。
- kwargs
关键字参数,可用于覆盖默认参数。
- revive_config_file_path¶
preprocess config
- train(env_save_path: Optional[str] = None)[source]¶
训练虚拟环境和策略 步骤:
加载数据和参数配置启动ray actor训练虚拟环境;
加载数据,参数和已训练完成的虚拟环境启动ray actor训练策略。
- train_policy(env_save_path: Optional[str] = None)[source]¶
加载数据,参数和已训练完成的虚拟环境启动ray actor训练策略.
- 参数:
- env_save_path
虚拟环境的保存地址,默认为None,将会自动根据run_id查找虚拟环境文件
Note
在训练策略之前,应提供已训练完成的虚拟环境模型和奖励函数。
- get_virtualenv_env() Tuple[VirtualEnv, Dict[str, Union[str, float]], Dict[int, Tuple[str, str]]] [source]¶
获取实时最佳虚拟环境模型和训练日志
- get_policy_model() Tuple[PolicyModel, Dict[str, Union[str, float]], Dict[int, Tuple[str, str]]] [source]¶
获取实时最佳策略模型和训练日志