# Source code for revive.server_cn

''''''
""" 本文件只为用来生成中文文档,请不要另作它用 """
"""
    POLIXIR REVIVE, copyright (C) 2021-2023 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""
import os
import sys
import json
import uuid
import socket
import pickle
import warnings

from copy import deepcopy
from typing import Dict, Union, Optional, Tuple

import ray
import numpy as np
from loguru import logger

from revive.utils.common_utils import get_reward_fn, list2parser, setup_seed
from revive.computation.inference import PolicyModel, VirtualEnv, VirtualEnvDev
from revive.conf.config import DEBUG_CONFIG, DEFAULT_CONFIG
from revive.data.dataset import OfflineDataset
from revive.utils.server_utils import DataBufferEnv, DataBufferPolicy, DataBufferTuner, Logger, VenvTrain, TuneVenvTrain, PolicyTrain, TunePolicyTrain, ParameterTuner

# NOTE(review): this runs at import time and, with no category/module filter,
# silences *all* warnings for the whole process, not only those raised by REVIVE.
warnings.filterwarnings('ignore')


class ReviveServer:
    r"""Training entry point of the REVIVE SDK; starts and manages all training tasks.

    ``ReviveServer`` performs four steps during initialization:

    1. Create or connect to a ray cluster. The cluster address is controlled by the
       ``address`` argument: if it is ``None`` a new cluster is started, otherwise the
       server connects to the existing cluster at that address.
    2. Load the training configuration. Defaults come from ``DEFAULT_CONFIG`` (or
       ``DEBUG_CONFIG`` when ``debug=True``) and can be overridden by the JSON file
       given in ``revive_config_file_path``. The effective configuration is saved as
       ``config.json`` inside the workspace.
    3. Load the decision flow graph, the data and the reward function. The data files are
       given by ``dataset_file_path``, ``dataset_desc_file_path`` and ``val_file_path``.
    4. Create the log folder that stores training results. The top-level folder is
       controlled by ``log_dir`` (a ``logs`` folder in the current working directory when
       not given). The second-level folder is named after ``run_id`` (a random id when
       not given). All training logs and models are placed in the second-level folder.

    Args:
        dataset_file_path (str): path of the training data (``.npz`` or ``.h5`` file).
        dataset_desc_file_path (str): path of the decision flow graph (``.yaml``).
        val_file_path (str): path of the validation data (optional). When omitted, the
            training data is split using ``val_split_ratio`` / ``val_split_mode``.
        reward_file_path (str): path of the file defining the reward function.
        target_policy_name (str): name of the policy node to optimize. If ``None``, the
            first node of the decision flow graph is chosen as the policy node.
        log_dir (str): folder where models and training logs are stored.
        run_id (str): experiment id, used as the log sub-folder name to distinguish
            experiments. Generated automatically when not provided.
        address (str): ray cluster address; see step 1 above.
        venv_mode ("tune", "once", "None"): how to train the virtual environment:
            ``tune`` runs a hyper-parameter search (costly in compute and time, but may
            yield a better model), ``once`` trains with the default parameters and
            ``None`` skips virtual environment training.
        policy_mode ("tune", "once", "None"): how to train the policy; same meaning as
            ``venv_mode``.
        tuning_mode (str): parameter tuning mode forwarded to the tuner; ``"None"``
            disables tuning. Cannot be combined with policy training.
        tune_initial_state (dict): initial state forwarded to the parameter tuner.
        debug (bool): use ``DEBUG_CONFIG`` instead of ``DEFAULT_CONFIG``.
        revive_config_file_path (str): path of a JSON hyper-parameter file overriding
            the defaults.
        kwargs: extra keyword arguments; only ``env_save_path`` and
            ``policy_save_path`` are read.
    """

    def __init__(self,
                 dataset_file_path : str,
                 dataset_desc_file_path : str,
                 val_file_path : Optional[str] = None,
                 reward_file_path : Optional[str] = None,
                 target_policy_name : Optional[str] = None,
                 log_dir : Optional[str] = None,
                 run_id : Optional[str] = None,
                 address : Optional[str] = None,
                 venv_mode : str = 'tune',
                 policy_mode : str = 'tune',
                 tuning_mode : str = 'None',
                 tune_initial_state : Optional[Dict[str, np.ndarray]] = None,
                 debug : bool = False,
                 revive_config_file_path : Optional[str] = None,
                 **kwargs):
        assert policy_mode == 'None' or tuning_mode == 'None', 'Cannot perform both policy training and parameter tuning!'

        # --- get config ---
        config = DEBUG_CONFIG if debug else DEFAULT_CONFIG
        parser = list2parser(config)
        self.config = parser.parse_known_args()[0].__dict__
        self.run_id = run_id or uuid.uuid4().hex
        # Without this default, os.path.join(None, ...) raised TypeError.
        if log_dir is None:
            log_dir = 'logs'
        self.workspace = os.path.abspath(os.path.join(log_dir, self.run_id))
        self.config['workspace'] = self.workspace
        os.makedirs(self.workspace, mode=0o777, exist_ok=True)
        assert os.path.exists(self.workspace)
        self.revive_log_path = os.path.join(self.workspace, "revive.log")
        self.config["revive_log_path"] = self.revive_log_path
        logger.add(self.revive_log_path)

        if revive_config_file_path is not None:
            with open(revive_config_file_path, 'r') as f:
                custom_config = json.load(f)
            self.config.update(custom_config)
            # Each `base_config` entry describes one parameter; its `default` is applied.
            for parameter_description in custom_config.get('base_config', {}):
                self.config[parameter_description['name']] = parameter_description['default']

        # Persist the effective configuration inside the workspace.
        revive_config_save_file_path = os.path.join(self.workspace, "config.json")
        with open(revive_config_save_file_path, 'w') as f:
            json.dump(self.config, f)
        self.revive_config_file_path = revive_config_save_file_path

        # --- preprocess config ---
        # NOTE: in crypto mode, each trial is fixed to use one GPU.
        self.config['is_crypto'] = os.environ.get('REVIVE_CRYPTO', 0)
        setup_seed(self.config['global_seed'])
        self.venv_mode = venv_mode
        self.policy_mode = policy_mode
        self.tuning_mode = tuning_mode
        self.tune_initial_state = tune_initial_state
        self.reward_func = get_reward_fn(reward_file_path, dataset_desc_file_path)
        self.config['user_func'] = self.reward_func

        # --- create dataset ---
        self.data_file = dataset_file_path
        self.config_file = dataset_desc_file_path
        self.val_file = val_file_path
        self.dataset = OfflineDataset(self.data_file, self.config_file, self.config['ignore_check'])

        self._check_license()
        # Forward the driver's import path and license to every ray worker.
        # NOTE(review): sys.PYARMOR_LICENSE is presumably set by _check_license(); confirm.
        self.runtime_env = {"env_vars": {"PYTHONPATH": os.pathsep.join(sys.path),
                                         "PYARMOR_LICENSE": sys.PYARMOR_LICENSE}}
        ray.init(address=address, runtime_env=self.runtime_env)

        if self.val_file:
            self.val_dataset = OfflineDataset(self.val_file, self.config_file, self.config['ignore_check'])
            self.val_dataset.processor = self.dataset.processor  # make sure dataprocessing is the same
        else:
            # split the training set if validation set is not provided
            self.dataset, self.val_dataset = self.dataset.split(self.config['val_split_ratio'], self.config['val_split_mode'])
        self.config['val_dataset'] = ray.put(self.val_dataset)
        self.config['dataset'] = ray.put(self.dataset)
        self.config['graph'] = self.dataset.graph
        self.graph = self.config['graph']
        if tuning_mode != 'None':
            assert len(self.dataset.graph.tunable) > 0, 'No tunable parameter detected, please check the config yaml!'

        self.config['learning_nodes_num'] = self.dataset.learning_nodes_num

        if target_policy_name is None:
            target_policy_name = list(self.config['graph'].keys())[0]
            logger.warning(f"Target policy name [{target_policy_name}] is chosen as default")
        self.config['target_policy_name'] = target_policy_name

        # --- save a copy of the base graph ---
        with open(os.path.join(self.workspace, 'graph.pkl'), 'wb') as f:
            pickle.dump(self.config['graph'], f)

        # --- setup data buffers (ray actors shared with the trainers through the config) ---
        self.driver_ip = socket.gethostbyname(socket.gethostname())
        self.venv_data_buffer = ray.remote(DataBufferEnv).options(resources={}).remote(venv_max_num=self.config['num_venv_store'])
        self.policy_data_buffer = ray.remote(DataBufferPolicy).options(resources={}).remote()
        self.tuner_data_buffer = ray.remote(DataBufferTuner).options(resources={}).remote(self.tuning_mode, self.config['parameter_tuning_budget'])
        self.config['venv_data_buffer'] = self.venv_data_buffer
        self.config['policy_data_buffer'] = self.policy_data_buffer
        self.config['tuner_data_buffer'] = self.tuner_data_buffer

        # --- paths of previously saved models (read by train_venv) ---
        self.env_save_path = kwargs.get("env_save_path", None)
        self.policy_save_path = kwargs.get("policy_save_path", None)

        # Initialise every attribute other methods read, so none of them hits AttributeError.
        self.venv = None
        self.policy = None
        self.best_parameter = None
        self.venv_acc = - float('inf')
        self.policy_acc = - float('inf')
        self.venv_logger = None
        self.policy_logger = None
        self.tuner_logger = None

        self._write_env_state(stop=False)

    def _write_env_state(self, stop: bool) -> None:
        r"""Write the ``.env.json`` control file in the workspace.

        The file holds the ``REVIVE_STOP`` flag and the ``LOG_DIR`` path. The readers are
        presumably the training workers (not visible here).
        """
        data = {"REVIVE_STOP" : stop, "LOG_DIR": self.revive_log_path}
        with open(os.path.join(self.workspace, ".env.json"), 'w') as f:
            json.dump(data, f)

    def _reload_venv(self, path: str, return_graph: bool = False):
        r"""Reload a virtual environment from ``path`` and publish it to the venv buffer.

        The process exits if the loaded venv's graph disagrees with the data config
        (``is_equal_venv``). If only the structure differs, each sub-environment is
        rebuilt on a copy of the current graph. Any other failure is logged and leaves
        ``self.venv`` as ``None``.

        Args:
            path: pickle file of a ``VirtualEnv``.
            return_graph: when True and a rebuild is needed, return the rebuilt graph of
                the first sub-environment right away, without publishing the venv.

        Returns:
            The rebuilt graph in the ``return_graph`` case described above, else ``None``.
        """
        try:
            with open(path, 'rb') as f:
                # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
                self.venv = pickle.load(f)
            self.venv.check_version()
            if not self.graph.is_equal_venv(self.venv.graph, self.config['target_policy_name']):
                logger.error('Detect different graph between loaded venv and data config, it is mostly cased by change of config file, trying to rebuild ...')
                logger.error('Please check if there are some changes between config files of learing Environment and Policy!')
                sys.exit()

            if not self.graph.is_equal_structure(self.venv.graph):
                logger.warning('graph.is_equal_structure Detect different graph between loaded venv and data config, it is mostly cased by change of config file, trying to rebuild ...')
                venv_list = []
                for _venv in self.venv.env_list:
                    graph = deepcopy(self.graph)
                    graph.copy_graph_node(_venv.graph)
                    venv_list.append(VirtualEnvDev(graph))
                    if return_graph:
                        return graph
                self.venv = VirtualEnv(venv_list)

            ray.get(self.venv_data_buffer.set_best_venv.remote(self.venv))
        except Exception as e:
            # Best-effort: a missing or incompatible file simply means "no venv".
            logger.info(f"Don't load venv -> {e}")
            self.venv = None
        return None

    def _reload_policy(self, path : str):
        r"""Reload a policy from ``path`` and publish it to the policy buffer.

        Any failure is logged and leaves ``self.policy`` as ``None``.
        """
        try:
            with open(path, 'rb') as f:
                # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
                self.policy = pickle.load(f)
            self.policy.check_version()
            ray.get(self.policy_data_buffer.set_best_policy.remote(self.policy))
        except Exception as e:
            logger.info(f"Don't load policy -> {e}")
            self.policy = None

    def train(self, env_save_path : Optional[str] = None):
        r"""Train the virtual environment and the policy, then run parameter tuning.

        Steps:
            1. Load the data and config and start ray actors that train the virtual environment.
            2. Load the data, the parameters and the trained virtual environment, and start
               ray actors that train the policy.
            3. Start parameter tuning; this is a no-op when ``tuning_mode == 'None'``.

        Args:
            env_save_path: forwarded to :meth:`train_policy` and :meth:`tune_parameter`.
        """
        self.train_venv()
        self.train_policy(env_save_path)
        self.tune_parameter(env_save_path)

    def train_venv(self):
        r"""Load the data and config and start a ray actor that trains the virtual environment."""
        if self.env_save_path and os.path.exists(self.env_save_path):
            graph = self._reload_venv(self.env_save_path, return_graph=True)
            # _reload_venv returns None when no rebuild was needed or loading failed;
            # keep the current graph then instead of overwriting it with None.
            if graph is not None:
                self.config['graph'] = graph
                self.graph = graph

        self.venv_logger = ray.remote(Logger).remote()
        self.venv_logger.update.remote(key="task_state", value="Wait")

        if self.venv_mode == 'None':
            self.venv_logger.update.remote(key="task_state", value="End")
            return

        # Metric-dependent bounds of the distance used by the venv trainer.
        if 'wdist' in self.config['venv_metric']:
            self.config['max_distance'] = 2
            self.config['min_distance'] = 0
        elif 'mae' in self.config['venv_metric']:
            self.config['max_distance'] = np.log(2)
            self.config['min_distance'] = np.log(2) - 15
        elif 'mse' in self.config['venv_metric']:
            self.config['max_distance'] = np.log(4)
            self.config['min_distance'] = np.log(4) - 15
        elif 'nll' in self.config['venv_metric']:
            self.config['max_distance'] = 0.5 * np.log(2 * np.pi)
            self.config['min_distance'] = 0.5 * np.log(2 * np.pi) - 10
        logger.info(f"Distance is between {self.config['min_distance']} and {self.config['max_distance']}")

        if self.config["venv_algo"] == "revive":
            self.config["venv_algo"] = "revive_p"

        # Sinks are detached while the actors are created; `finally` guarantees the
        # file sink comes back even if actor creation raises.
        logger.remove()
        try:
            if self.venv_mode == 'once':
                venv_trainer = ray.remote(VenvTrain).remote(self.config, self.venv_logger, command=sys.argv[1:])
                venv_trainer.train.remote()
                # NOTE: after task finish, the actor will be automatically killed by ray, since there is no reference to it
            elif self.venv_mode == 'tune':
                self.venv_trainer = ray.remote(TuneVenvTrain).remote(self.config, self.venv_logger, command=sys.argv[1:])
                self.venv_trainer.train.remote()
        finally:
            logger.add(self.revive_log_path)

    def train_policy(self, env_save_path : Optional[str] = None):
        r"""Load the data, parameters and trained venv and start a ray actor that trains the policy.

        Args:
            env_save_path: where the virtual environment is saved. Defaults to
                ``<workspace>/env.pkl``, i.e. the venv of this ``run_id``.

        .. note:: A trained virtual environment and a reward function must be available
            before training the policy.
        """
        if not env_save_path:
            env_save_path = os.path.join(self.workspace, 'env.pkl')
        self._reload_venv(env_save_path)
        if self.venv is None:
            logger.warning("Can't load the exist env model.")
        self.policy_logger = ray.remote(Logger).remote()
        self.policy_logger.update.remote(key="task_state", value="Wait")

        # `finally` re-adds the file sink even when the reward-function assert fires.
        logger.remove()
        try:
            if self.policy_mode == 'None':
                self.policy_logger.update.remote(key="task_state", value="End")
            elif self.policy_mode == 'once':
                assert self.reward_func is not None, 'policy training need reward function'
                policy_trainer = ray.remote(PolicyTrain).remote(self.config, self.policy_logger, self.venv_logger, command=sys.argv[1:])
                policy_trainer.train.remote()
                # NOTE: after task finish, the actor will be automatically killed by ray, since there is no reference to it
            elif self.policy_mode == 'tune':
                assert self.reward_func is not None, 'policy training need reward function'
                self.policy_trainer = ray.remote(TunePolicyTrain).remote(self.config, self.policy_logger, self.venv_logger, command=sys.argv[1:])
                self.policy_trainer.train.remote()
        finally:
            logger.add(self.revive_log_path)

    def tune_parameter(self, env_save_path : Optional[str] = None):
        r"""Start parameter tuning on a ray actor.

        Args:
            env_save_path: if given, the virtual environment is reloaded from this path first.

        Does nothing except mark the tuner log as ``End`` when ``tuning_mode == 'None'``.
        Otherwise a reward function is required.
        """
        if env_save_path is not None:
            self._reload_venv(env_save_path)
        self.config['user_func'] = self.reward_func
        self.tuner_logger = ray.remote(Logger).remote()
        self.tuner_logger.update.remote(key="task_state", value="Wait")
        if self.tuning_mode == 'None':
            self.tuner_logger.update.remote(key="task_state", value="End")
        else:
            assert self.reward_func is not None, 'tuning parameter needs reward function'
            self.tuner = ray.remote(ParameterTuner).remote(self.config, self.tuning_mode, self.tune_initial_state, self.tuner_logger, self.venv_logger)
            self.tuner.run.remote()

    @staticmethod
    def _shutdown_logger(log_actor) -> None:
        r"""Mark a task's log actor as ``Shutdown`` unless it already reached ``End``."""
        if log_actor is None:
            return
        log = ray.get(log_actor.get_log.remote())
        if log["task_state"] != "End":
            log_actor.update.remote(key="task_state", value="Shutdown")

    def stop_train(self) -> None:
        r"""Stop all training tasks (venv, policy and parameter tuning)."""
        # Keep LOG_DIR in the control file; previously it was dropped here.
        self._write_env_state(stop=True)
        for log_actor in (self.venv_logger, self.policy_logger, self.tuner_logger):
            self._shutdown_logger(log_actor)

    @staticmethod
    def _task_state_log(log_actor) -> Dict[str, str]:
        r"""Return ``{"task_state": ...}`` read from a log actor, or ``{}`` if there is none.

        An ``AttributeError`` while reading is reported as ``"Shutdown"``.
        """
        if log_actor is None:
            return {}
        try:
            return {"task_state": ray.get(log_actor.get_log.remote())["task_state"]}
        except AttributeError:
            return {"task_state": "Shutdown"}

    def get_virtualenv_env(self) -> Tuple[VirtualEnv, Dict[str, Union[str, float]], Dict[int, Tuple[str, str]], object]:
        r"""Get the current best virtual environment and its training log.

        The best venv is also saved as ``env.pkl`` in the workspace. An ONNX export to
        ``env.onnx`` is attempted on a best-effort basis.

        Returns:
            ``(venv, train_log, status_message, best_model_workspace)``. ``venv`` is
            ``None`` when no model is available yet.
        """
        assert self.dataset is not None
        train_log = self._task_state_log(self.venv_logger)

        metric = ray.get(self.venv_data_buffer.get_dict.remote())
        venv_acc = float(metric["max_acc"])
        train_log.update({
            "venv_acc" : venv_acc,
            "current_num_of_trials" : int(metric["num_of_trial"]),
            "total_num_of_trials" : int(metric["total_num_of_trials"]),
        })
        self.venv_acc = max(self.venv_acc, venv_acc)

        self.venv = ray.get(self.venv_data_buffer.get_best_venv.remote())
        best_model_workspace = ray.get(self.venv_data_buffer.get_best_model_workspace.remote())
        if self.venv is not None:
            with open(os.path.join(self.workspace, 'env.pkl'), 'wb') as f:
                pickle.dump(self.venv, f)
            try:
                self.venv.export2onnx(os.path.join(self.workspace, 'env.onnx'), verbose=False)
            except Exception as e:
                logger.info(f"Can't to export venv to ONNX. -> {e}")
        status_message = ray.get(self.venv_data_buffer.get_status.remote())
        return self.venv, train_log, status_message, best_model_workspace

    def get_policy_model(self) -> Tuple[PolicyModel, Dict[str, Union[str, float]], Dict[int, Tuple[str, str]], object]:
        r"""Get the current best policy model and its training log.

        The best policy is also saved as ``policy.pkl`` in the workspace. An ONNX export
        of a reset copy to ``policy.onnx`` is attempted on a best-effort basis.

        Returns:
            ``(policy, train_log, status_message, best_model_workspace)``. ``policy`` is
            ``None`` when no model is available yet.
        """
        assert self.dataset is not None
        train_log = self._task_state_log(self.policy_logger)

        metric = ray.get(self.policy_data_buffer.get_dict.remote())
        policy_acc = float(metric["max_reward"])
        train_log.update({
            "policy_acc" : policy_acc,
            "current_num_of_trials" : int(metric["num_of_trial"]),
            "total_num_of_trials" : int(metric["total_num_of_trials"]),
        })
        self.policy_acc = max(self.policy_acc, policy_acc)

        self.policy = ray.get(self.policy_data_buffer.get_best_policy.remote())
        best_model_workspace = ray.get(self.policy_data_buffer.get_best_model_workspace.remote())
        if self.policy is not None:
            with open(os.path.join(self.workspace, 'policy.pkl'), 'wb') as f:
                pickle.dump(self.policy, f)
            try:
                # Export a reset copy so the live policy object is left untouched.
                tmp_policy = deepcopy(self.policy)
                tmp_policy.reset()
                tmp_policy.export2onnx(os.path.join(self.workspace, 'policy.onnx'), verbose=False)
            except Exception as e:
                logger.info(f"Can't to export policy to ONNX. -> {e}")
        status_message = ray.get(self.policy_data_buffer.get_status.remote())
        return self.policy, train_log, status_message, best_model_workspace

    def get_parameter(self) -> Tuple[np.ndarray, Dict[str, Union[str, float]]]:
        r"""Get the best tuned parameter and the tuner's log.

        Returns:
            ``(best_parameter, train_log)``. ``train_log`` holds the tuner buffer state,
            without its ``best_parameter`` entry, plus ``task_state`` when a tuner exists.
        """
        train_log = self._task_state_log(self.tuner_logger)
        metric = ray.get(self.tuner_data_buffer.get_state.remote())
        train_log.update(metric)
        self.best_parameter = train_log.pop('best_parameter')
        return self.best_parameter, train_log

    def _check_license(self):
        r"""Validate the REVIVE license via ``revive.utils.auth_utils.check_license``."""
        from revive.utils.auth_utils import check_license
        check_license(self)