""" 本文件只为用来生成中文文档,请不要另作它用 """
"""
POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
import os
import sys
import json
import uuid
import socket
import pickle
import warnings
from copy import deepcopy
from typing import Dict, Union, Optional, Tuple
import ray
import numpy as np
from loguru import logger
from revive.utils.common_utils import get_reward_fn, list2parser, setup_seed
from revive.computation.inference import PolicyModel, VirtualEnv, VirtualEnvDev
from revive.conf.config import DEBUG_CONFIG, DEFAULT_CONFIG
from revive.data.dataset import OfflineDataset
from revive.utils.server_utils import DataBufferEnv, DataBufferPolicy, DataBufferTuner, Logger, VenvTrain, TuneVenvTrain, PolicyTrain, TunePolicyTrain, ParameterTuner
warnings.filterwarnings('ignore')
class ReviveServer:
r"""
ReviveServer is the training entry point of the REVIVE SDK; it launches and manages all training tasks.
`ReviveServer` performs four steps to complete its initialization:
1. Create or connect to a ray cluster. The cluster address is controlled by the `address` parameter. If `address` is `None`, it creates its own cluster.
If `address` is specified, it is used to connect to the existing cluster.
2. Load the training configuration file. The default configuration is provided in `config.json`; the default parameters can be changed by editing that file.
3. Load the decision flow graph, the npz data and the functions. The data files are specified by the `dataset_file_path`, `dataset_desc_file_path` and `val_file_path` parameters.
4. Create the log folders that store the training results. The top-level folder of these logs is controlled by the `log_dir` parameter; if it is not provided, a `logs` folder is generated by default.
The second-level folder is controlled by the `run_id` parameter in the training configuration; if it is not specified, a random id is generated for the folder. All training logs and models are placed in the second-level folder.
Parameters:
:dataset_file_path (str):
File path of the training data (an ``.npz`` or ``.h5`` file).
:dataset_desc_file_path (str):
File path of the decision flow graph ( ``.yaml`` ).
:val_file_path (str):
File path of the validation data (optional).
:reward_file_path (str):
Path of the file that defines the reward function.
:target_policy_name (str):
Name of the policy node to optimize. If None, the first network node in the decision flow graph is chosen as the policy node.
:log_dir (str):
Folder where models and training logs are stored.
:run_id (str):
Experiment id, used to generate the log folder name and to distinguish different experiments. If not provided, it is generated automatically.
:address (str):
Address of the ray cluster. If `address` is `None`, a new cluster is created; if `address` is specified, it is used to connect to the existing cluster.
:venv_mode ("tune", "once", "None"):
Training mode of the virtual environment:
`tune` trains the virtual environment model with hyperparameter search, which consumes considerably more compute and time in order to find hyperparameters that yield better models.
`once` trains the virtual environment model with the default parameters.
`None` does not train the virtual environment model.
:policy_mode ("tune", "once", "None"):
Training mode of the policy model:
`tune` trains the policy model with hyperparameter search, which consumes considerably more compute and time in order to find hyperparameters that yield better models.
`once` trains the policy model with the default parameters.
`None` does not train the policy model.
:revive_config_file_path (str):
Path of a hyperparameter configuration file that can be used to override the default parameters.
:kwargs:
Keyword arguments that can be used to override the default parameters.
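A minimal usage sketch, assuming ``ReviveServer`` is imported from ``revive.server`` and that the file paths below are illustrative placeholders:

.. code-block:: python

    from revive.server import ReviveServer

    # Train a virtual environment and a policy once with the default hyperparameters.
    server = ReviveServer(
        dataset_file_path="data/example.npz",        # placeholder training data
        dataset_desc_file_path="data/example.yaml",  # placeholder decision flow graph
        reward_file_path="data/example_reward.py",   # placeholder reward function
        log_dir="logs",
        venv_mode="once",
        policy_mode="once")
    server.train()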
"""
def __init__(self,
dataset_file_path : str,
dataset_desc_file_path : str,
val_file_path : Optional[str] = None,
reward_file_path : Optional[str] = None,
target_policy_name : Optional[str] = None,
log_dir : Optional[str] = None,
run_id : Optional[str] = None,
address : Optional[str] = None,
venv_mode : str = 'tune',
policy_mode : str = 'tune',
tuning_mode : str = 'None',
tune_initial_state : Optional[Dict[str, np.ndarray]] = None,
debug : bool = False,
revive_config_file_path : Optional[str] = None,
**kwargs):
assert policy_mode == 'None' or tuning_mode == 'None', 'Cannot perform both policy training and parameter tuning!'
# ray.init(local_mode=True) # debug only
''' get config '''
config = DEBUG_CONFIG if debug else DEFAULT_CONFIG
parser = list2parser(config)
self.config = parser.parse_known_args()[0].__dict__
self.run_id = run_id or uuid.uuid4().hex
self.workspace = os.path.abspath(os.path.join(log_dir, self.run_id))
self.config['workspace'] = self.workspace
os.makedirs(self.workspace, mode=0o777, exist_ok=True)
assert os.path.exists(self.workspace)
self.revive_log_path = os.path.join(os.path.abspath(self.workspace),"revive.log")
self.config["revive_log_path"] = self.revive_log_path
logger.add(self.revive_log_path)
if revive_config_file_path is not None:
with open(revive_config_file_path, 'r') as f:
custom_config = json.load(f)
self.config.update(custom_config)
for parameter_description in custom_config.get('base_config', {}):
self.config[parameter_description['name']] = parameter_description['default']
revive_config_save_file_path = os.path.join(self.workspace, "config.json")
with open(revive_config_save_file_path, 'w') as f:
json.dump(self.config,f)
self.revive_config_file_path = revive_config_save_file_path
''' preprocess config'''
# NOTE: in crypto mode, each trial is fixed to use one GPU.
self.config['is_crypto'] = os.environ.get('REVIVE_CRYPTO', 0)
setup_seed(self.config['global_seed'])
self.venv_mode = venv_mode
self.policy_mode = policy_mode
self.tuning_mode = tuning_mode
self.tune_initial_state = tune_initial_state
self.reward_func = get_reward_fn(reward_file_path, dataset_desc_file_path)
self.config['user_func'] = self.reward_func
''' create dataset '''
self.data_file = dataset_file_path
self.config_file = dataset_desc_file_path
self.val_file = val_file_path
self.dataset = OfflineDataset(self.data_file, self.config_file, self.config['ignore_check'])
self._check_license()
self.runtime_env = {"env_vars": {"PYTHONPATH":os.pathsep.join(sys.path), "PYARMOR_LICENSE": sys.PYARMOR_LICENSE}}
ray.init(address=address, runtime_env=self.runtime_env)
if self.val_file:
self.val_dataset = OfflineDataset(self.val_file, self.config_file, self.config['ignore_check'])
self.val_dataset.processor = self.dataset.processor # make sure dataprocessing is the same
self.config['val_dataset'] = ray.put(self.val_dataset)
else: # split the training set if validation set is not provided
self.dataset, self.val_dataset = self.dataset.split(self.config['val_split_ratio'], self.config['val_split_mode'])
self.config['val_dataset'] = ray.put(self.val_dataset)
self.config['dataset'] = ray.put(self.dataset)
self.config['graph'] = self.dataset.graph
self.graph = self.config['graph']
if not tuning_mode == 'None': assert len(self.dataset.graph.tunable) > 0, 'No tunable parameter detected, please check the config yaml!'
self.config['learning_nodes_num'] = self.dataset.learning_nodes_num
if target_policy_name is None:
target_policy_name = list(self.config['graph'].keys())[0]
logger.warning(f"Target policy name [{target_policy_name}] is chosen as default")
self.config['target_policy_name'] = target_policy_name
''' save a copy of the base graph '''
with open(os.path.join(self.workspace, 'graph.pkl'), 'wb') as f:
pickle.dump(self.config['graph'], f)
''' setup data buffers '''
self.driver_ip = socket.gethostbyname(socket.gethostname())
self.venv_data_buffer = ray.remote(DataBufferEnv).options(resources={}).remote(venv_max_num=self.config['num_venv_store'])
self.policy_data_buffer = ray.remote(DataBufferPolicy).options(resources={}).remote()
self.tuner_data_buffer = ray.remote(DataBufferTuner).options(resources={}).remote(self.tuning_mode, self.config['parameter_tuning_budget'])
self.config['venv_data_buffer'] = self.venv_data_buffer
self.config['policy_data_buffer'] = self.policy_data_buffer
self.config['tuner_data_buffer'] = self.tuner_data_buffer
''' try to load existing venv and policy '''
self.env_save_path = kwargs.get("env_save_path", None)
self.policy_save_path = kwargs.get("policy_save_path", None)
#self._reload_venv(os.path.join(self.workspace, 'env.pkl'))
#self._reload_policy(os.path.join(self.workspace, 'policy.pkl'))
self.venv_acc = - float('inf')
self.policy_acc = - float('inf')
self.venv_logger = None
self.policy_logger = None
self.tuner_logger = None
data = {"REVIVE_STOP" : False, "LOG_DIR":os.path.join(os.path.abspath(self.workspace),"revive.log")}
with open(os.path.join(self.workspace, ".env.json"), 'w') as f:
json.dump(data, f)
def _reload_venv(self, path: str, return_graph: bool = False):
r'''Reload the virtual environment from the given path.'''
try:
with open(path, 'rb') as f:
self.venv = pickle.load(f)
self.venv.check_version()
if not self.graph.is_equal_venv(self.venv.graph,self.config['target_policy_name']):
logger.error('Detected a different graph between the loaded venv and the data config; this is most likely caused by a change of the config file, trying to rebuild ...')
logger.error('Please check whether anything changed between the config files used for learning the Environment and the Policy!')
sys.exit()
if not self.graph.is_equal_structure(self.venv.graph):
logger.warning('graph.is_equal_structure detected a different graph between the loaded venv and the data config; this is most likely caused by a change of the config file, trying to rebuild ...')
venv_list = []
for _venv in self.venv.env_list:
graph = deepcopy(self.graph)
graph.copy_graph_node(_venv.graph)
venv_list.append(VirtualEnvDev(graph))
if return_graph:
return graph
self.venv = VirtualEnv(venv_list)
ray.get(self.venv_data_buffer.set_best_venv.remote(self.venv))
except Exception as e:
logger.info(f"Don't load venv -> {e}")
self.venv = None
def _reload_policy(self, path : str):
r'''Reload the policy from the given path.'''
try:
with open(path, 'rb') as f:
self.policy = pickle.load(f)
self.policy.check_version()
ray.get(self.policy_data_buffer.set_best_policy.remote(self.policy))
except Exception as e:
logger.info(f"Don't load policy -> {e}")
self.policy = None
def train(self, env_save_path : Optional[str] = None):
r"""
Train the virtual environment and the policy.
Steps:
1. Load the data and the configuration, and launch a ray actor to train the virtual environment;
2. Load the data, the configuration and the trained virtual environment, and launch a ray actor to train the policy.
If a tuning mode is set, parameter tuning is launched on the trained virtual environment as well.
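The two steps can also be launched separately; a minimal sketch, assuming ``server`` is an already initialized ``ReviveServer``:

.. code-block:: python

    server.train_venv()     # step 1: train the virtual environment
    server.train_policy()   # step 2: train the policy against the trained virtual environment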
"""
self.train_venv()
self.train_policy(env_save_path)
self.tune_parameter(env_save_path)
def train_venv(self):
r"""
Load the data and the configuration, and launch a ray actor to train the virtual environment.
"""
if self.env_save_path and os.path.exists(self.env_save_path):
graph = self._reload_venv(self.env_save_path, return_graph=True)
self.config['graph'] = graph
self.graph = graph
self.venv_logger = ray.remote(Logger).remote()
self.venv_logger.update.remote(key="task_state", value="Wait")
if self.venv_mode == 'None':
self.venv_logger.update.remote(key="task_state", value="End")
else:
if 'wdist' in self.config['venv_metric']:
self.config['max_distance'] = 2
self.config['min_distance'] = 0
elif 'mae' in self.config['venv_metric']:
self.config['max_distance'] = np.log(2)
self.config['min_distance'] = np.log(2) - 15
elif 'mse' in self.config['venv_metric']:
self.config['max_distance'] = np.log(4)
self.config['min_distance'] = np.log(4) - 15
elif 'nll' in self.config['venv_metric']:
self.config['max_distance'] = 0.5 * np.log(2 * np.pi)
self.config['min_distance'] = 0.5 * np.log(2 * np.pi) - 10
logger.info(f"Distance is between {self.config['min_distance']} and {self.config['max_distance']}")
if self.config["venv_algo"] == "revive":
self.config["venv_algo"] = "revive_p"
logger.remove()
if self.venv_mode == 'once':
venv_trainer = ray.remote(VenvTrain).remote(self.config, self.venv_logger, command=sys.argv[1:])
venv_trainer.train.remote()
# NOTE: after task finish, the actor will be automatically killed by ray, since there is no reference to it
elif self.venv_mode == 'tune':
self.venv_trainer = ray.remote(TuneVenvTrain).remote(self.config, self.venv_logger, command=sys.argv[1:])
self.venv_trainer.train.remote()
logger.add(self.revive_log_path)
def train_policy(self, env_save_path : Optional[str] = None):
r"""
Load the data, the configuration and the trained virtual environment, and launch a ray actor to train the policy.
Parameters:
:env_save_path: Save path of the virtual environment. Defaults to None, in which case the virtual environment file is located automatically based on the run_id.
.. note:: A trained virtual environment model and a reward function should be available before training the policy.
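A minimal sketch of reusing a previously trained virtual environment, assuming ``server`` is an initialized ``ReviveServer`` and the path below is a placeholder:

.. code-block:: python

    # Point the policy training at an env.pkl produced by an earlier run.
    server.train_policy(env_save_path="logs/previous_run/env.pkl")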
"""
if not env_save_path:
env_save_path = os.path.join(self.workspace, 'env.pkl')
self._reload_venv(env_save_path)
if self.venv is None:
logger.warning(f"Can't load the exist env model.")
self.policy_logger = ray.remote(Logger).remote()
self.policy_logger.update.remote(key="task_state", value="Wait")
logger.remove()
if self.policy_mode == 'None':
self.policy_logger.update.remote(key="task_state", value="End")
elif self.policy_mode == 'once':
assert self.reward_func is not None, 'policy training needs a reward function'
policy_trainer = ray.remote(PolicyTrain).remote(self.config, self.policy_logger, self.venv_logger, command=sys.argv[1:])
policy_trainer.train.remote()
# NOTE: after task finish, the actor will be automatically killed by ray, since there is no reference to it
elif self.policy_mode == 'tune':
assert self.reward_func is not None, 'policy training needs a reward function'
self.policy_trainer = ray.remote(TunePolicyTrain).remote(self.config, self.policy_logger, self.venv_logger, command=sys.argv[1:])
self.policy_trainer.train.remote()
logger.add(self.revive_log_path)
def tune_parameter(self, env_save_path : Optional[str] = None):
if env_save_path is not None:
self._reload_venv(env_save_path)
self.config['user_func'] = self.reward_func
self.tuner_logger = ray.remote(Logger).remote()
self.tuner_logger.update.remote(key="task_state", value="Wait")
if self.tuning_mode == 'None':
self.tuner_logger.update.remote(key="task_state", value="End")
else:
assert self.reward_func is not None, 'parameter tuning needs a reward function'
self.tuner = ray.remote(ParameterTuner).remote(self.config, self.tuning_mode, self.tune_initial_state, self.tuner_logger, self.venv_logger)
self.tuner.run.remote()
def stop_train(self) -> None:
r"""停止所有训练任务
"""
_data = {"REVIVE_STOP" : True}
with open(os.path.join(self.workspace, ".env.json"), 'w') as f:
json.dump(_data, f)
if self.venv_logger is not None:
venv_logger = self.venv_logger.get_log.remote()
venv_logger = ray.get(venv_logger)
if venv_logger["task_state"] != "End":
self.venv_logger.update.remote(key="task_state", value="Shutdown")
if self.policy_logger is not None:
policy_logger = self.policy_logger.get_log.remote()
policy_logger = ray.get(policy_logger)
if policy_logger["task_state"] != "End":
self.policy_logger.update.remote(key="task_state", value="Shutdown")
def get_virtualenv_env(self) -> Tuple[VirtualEnv, Dict[str, Union[str, float]], Dict[int, Tuple[str, str]], str]:
r"""获取实时最佳虚拟环境模型和训练日志
"""
assert self.dataset is not None
train_log = {}
if self.venv_logger is not None:
try:
venv_logger = self.venv_logger.get_log.remote()
venv_logger = ray.get(venv_logger)
train_log.update({"task_state": venv_logger["task_state"],})
except AttributeError:
train_log.update({"task_state": "Shutdown"})
metric = ray.get(self.venv_data_buffer.get_dict.remote())
venv_acc = float(metric["max_acc"])
current_num_of_trials = int(metric["num_of_trial"])
total_num_of_trials = int(metric["total_num_of_trials"])
train_log.update({
"venv_acc" : venv_acc,
"current_num_of_trials" : current_num_of_trials,
"total_num_of_trials" : total_num_of_trials,
})
self.venv_acc = max(self.venv_acc, venv_acc)
self.venv = ray.get(self.venv_data_buffer.get_best_venv.remote())
best_model_workspace = ray.get(self.venv_data_buffer.get_best_model_workspace.remote())
if self.venv is not None:
with open(os.path.join(self.workspace, 'env.pkl'), 'wb') as f:
pickle.dump(self.venv, f)
try:
self.venv.export2onnx(os.path.join(self.workspace, 'env.onnx'), verbose=False)
except Exception as e:
logger.info(f"Can't export venv to ONNX. -> {e}")
status_message = ray.get(self.venv_data_buffer.get_status.remote())
return self.venv, train_log, status_message, best_model_workspace
def get_policy_model(self) -> Tuple[PolicyModel, Dict[str, Union[str, float]], Dict[int, Tuple[str, str]], str]:
r"""获取实时最佳策略模型和训练日志
"""
assert self.dataset is not None
train_log = {}
if self.policy_logger is not None:
try:
policy_logger = self.policy_logger.get_log.remote()
policy_logger = ray.get(policy_logger)
train_log.update({"task_state": policy_logger["task_state"],})
except AttributeError:
train_log.update({"task_state": "Shutdown"})
metric = ray.get(self.policy_data_buffer.get_dict.remote())
policy_acc = float(metric["max_reward"])
current_num_of_trials = int(metric["num_of_trial"])
total_num_of_trials = int(metric["total_num_of_trials"])
train_log.update({
"policy_acc" : policy_acc,
"current_num_of_trials" : current_num_of_trials,
"total_num_of_trials" : total_num_of_trials,
})
self.policy_acc = max(self.policy_acc, policy_acc)
self.policy = ray.get(self.policy_data_buffer.get_best_policy.remote())
best_model_workspace = ray.get(self.policy_data_buffer.get_best_model_workspace.remote())
if self.policy is not None:
with open(os.path.join(self.workspace, 'policy.pkl'), 'wb') as f:
pickle.dump(self.policy, f)
try:
tmp_policy = deepcopy(self.policy)
tmp_policy.reset()
tmp_policy.export2onnx(os.path.join(self.workspace, 'policy.onnx'), verbose=False)
except Exception as e:
logger.info(f"Can't to export venv to ONNX. -> {e}")
status_message = ray.get(self.policy_data_buffer.get_status.remote())
return self.policy, train_log, status_message, best_model_workspace
def get_parameter(self) -> Tuple[np.ndarray, Dict[str, Union[str, float]]]:
train_log = {}
if self.tuner_logger is not None:
try:
tuner_logger = self.tuner_logger.get_log.remote()
tuner_logger = ray.get(tuner_logger)
train_log.update({"task_state": tuner_logger["task_state"],})
except AttributeError:
train_log.update({"task_state": "Shutdown"})
metric = ray.get(self.tuner_data_buffer.get_state.remote())
train_log.update(metric)
self.best_parameter = train_log.pop('best_parameter')
return self.best_parameter, train_log
def _check_license(self):
from revive.utils.auth_utils import check_license
check_license(self)