"""
POLIXIR REVIVE, copyright (C) 2021-2023 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
import os
import sys
import json
import uuid
import socket
import pickle
import inspect
import warnings
from copy import deepcopy
from typing import Dict, Union, Optional, Tuple
import ray
import numpy as np
from loguru import logger
from revive.utils.common_utils import get_reward_fn, get_module, list2parser, setup_seed
from revive.computation.inference import PolicyModel, VirtualEnv, VirtualEnvDev
from revive.conf.config import DEBUG_CONFIG, DEFAULT_CONFIG
from revive.data.dataset import OfflineDataset
from revive.utils.server_utils import DataBufferEnv, DataBufferPolicy, DataBufferTuner, Logger, VenvTrain, TuneVenvTrain, PolicyTrain, TunePolicyTrain, ParameterTuner
warnings.filterwarnings('ignore')
class ReviveServer:
r"""
A class that uses `ray` to manage all the training tasks. It can automatic search for optimal hyper-parameters.
`ReviveServer` will do five steps to initialize:
1. Create or connect to a ray cluster. The behavior is controlled by `address` parameter. If the `address`
parameter is `None`, it will create its own cluster. If the `address` parameter is specified, it will
connect to the existing cluster.
2. Load config for training. The config is stored in `revive/config.py`. You can change these parameters
by editing the file, passing through command line or through `custom_config`.
3. Load data and its config, register reward function. The data files are specified by parameters
`dataset_file_path`, `dataset_desc_file_path` and `val_file_path`. Note the `val_file_path` is optional.
If it is not specified, revive will split the training data. All the data will be put into the ray
object store to share among the whole cluster.
4. Create the folder to store results. The top level folder of these logs are controlled by `log_dir` parameter.
If it is not provided, the default value is the `logs` folder under the revive repertory. The second-level folder
is controlled by the `run_id` parameter in the training config. If it is not specified, we will generate a random id
for the folder. All the training results will be placed in the second-level folder.
5. Create result server as ray actor, and try to load existing results in the log folder. This class is very useful when
you want to train a policy or tune parameters from an already trained simulator.
Initialization a Revive Server.
Args:
:dataset_file_path (str):
The file path where the training dataset is stored (e.g., "/data/data.npz"). If `val_file_path`
is None, some data will be split off from the training dataset as the validation dataset.
:dataset_desc_file_path (str):
The file path where the data description file is stored. (e.g., "/data/test.yaml" )
:val_file_path (str):
The file path where the validation dataset is stored. If it is None,
the validation dataset will be split off from the training dataset.
:user_module_file_path (str):
The file path of an optional user-defined module; its functions are collected and made available to training.
:matcher_reward_file_path (str):
The file path of an optional module defining a rule-based matcher reward; `get_reward`, `normalize`,
`weight` and `matching_nodes` are read from it when present.
:reward_file_path (str):
The storage path of the file that defines the reward function.
:target_policy_name (str):
Name of target policy to be optimized. Maximize the defined reward by optimizing the policy.
If it is None, the first policy in the graph will be chosen.
:log_dir (str):
Training log and saved model storage folder.
:run_id (str):
The ID of the current experiment, used to distinguish different training runs.
If it is not provided, an ID will be generated automatically.
:address (str):
The address of the ray cluster. If `address` is `None`, a new cluster will be created.
:venv_mode ("tune","once","None"):
Control the mode of venv training.
`tune` means conducting hyper-parameter search;
`once` means training with the default hyper-parameters;
`None` means skip.
:policy_mode ("tune","once","None"):
Control the mode of policy training.
`tune` means conducting hyper-parameter search;
`once` means training with the default hyper-parameters;
`None` means skip.
:tuning_mode ("max","min","None"):
Control the mode of parameter tuning.
`max` and `min` enable tuning and set the optimization direction;
`None` means skip.
This feature is currently unstable.
:tune_initial_state (str):
Initial state of parameter tuning, needed when tuning mode is enabled.
:debug (bool):
If True, run in debug mode (use the debug config).
:revive_config_file_path (str):
JSON file path. The file content can be used to override the default parameters.
:kwargs:
Keyword parameters can be used to override default parameters
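Example (a minimal usage sketch; the file paths are placeholders and must point to
your own data file, YAML description and reward function):

.. code-block:: python

    server = ReviveServer(
        dataset_file_path='/data/data.npz',
        dataset_desc_file_path='/data/test.yaml',
        reward_file_path='/data/reward.py',
        venv_mode='once',
        policy_mode='once',
        log_dir='./logs')
    server.train()  # train the virtual environment, then the policy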
"""
def __init__(self,
dataset_file_path : str,
dataset_desc_file_path : str,
val_file_path : Optional[str] = None,
user_module_file_path : Optional[str] = None,
matcher_reward_file_path : Optional[str] = None,
reward_file_path : Optional[str] = None,
target_policy_name : str = None,
log_dir : str = None,
run_id : Optional[str] = None,
address : Optional[str] = None,
venv_mode : str = 'tune',
policy_mode : str = 'tune',
tuning_mode : str = 'None',
tune_initial_state : Optional[Dict[str, np.ndarray]] = None,
debug : bool = False,
revive_config_file_path : Optional[str] = None,
**kwargs):
assert policy_mode == 'None' or tuning_mode == 'None', 'Cannot perform both policy training and parameter tuning!'
# ray.init(local_mode=True) # debug only
''' get config '''
config = DEBUG_CONFIG if debug else DEFAULT_CONFIG
parser = list2parser(config)
self.config = parser.parse_known_args()[0].__dict__
self.run_id = run_id or uuid.uuid4().hex
self.workspace = os.path.abspath(os.path.join(log_dir, self.run_id))
self.config['workspace'] = self.workspace
os.makedirs(self.workspace, mode=0o777, exist_ok=True)
assert os.path.exists(self.workspace)
self.revive_log_path = os.path.join(os.path.abspath(self.workspace),"revive.log")
self.config["revive_log_path"] = self.revive_log_path
logger.add(self.revive_log_path)
if revive_config_file_path is not None:
with open(revive_config_file_path, 'r') as f:
custom_config = json.load(f)
self.config.update(custom_config)
for parameter_description in custom_config.get('base_config', {}):
self.config[parameter_description['name']] = parameter_description['default']
revive_config_save_file_path = os.path.join(self.workspace, "config.json")
with open(revive_config_save_file_path, 'w') as f:
json.dump(self.config,f)
self.revive_config_file_path = revive_config_save_file_path
''' preprocess config'''
# NOTE: in crypto mode, each trial is fixed to use one GPU.
self.config['is_crypto'] = os.environ.get('REVIVE_CRYPTO', 0)
setup_seed(self.config['global_seed'])
self.venv_mode = venv_mode
self.policy_mode = policy_mode
self.tuning_mode = tuning_mode
self.tune_initial_state = tune_initial_state
self.user_module = get_module(user_module_file_path, dataset_desc_file_path)
if self.user_module is not None:
functions = inspect.getmembers(self.user_module, inspect.isfunction)
function_dict = {name: func for name, func in functions}
self.config['user_module'] = function_dict
else:
self.config['user_module'] = None
self.rule_reward_module = get_module(matcher_reward_file_path, dataset_desc_file_path)
self.config['rule_reward_func'] = getattr(self.rule_reward_module, "get_reward", None)
self.config['rule_reward_func_normalize'] = getattr(self.rule_reward_module, "normalize", False)
self.config['rule_reward_func_weight'] = getattr(self.rule_reward_module, "weight", 1.0)
self.config['rule_reward_matching_nodes'] = getattr(self.rule_reward_module, "matching_nodes", [])
self.reward_func = get_reward_fn(reward_file_path, dataset_desc_file_path)
self.config['user_func'] = self.reward_func
''' create dataset '''
self.data_file = dataset_file_path
self.config_file = dataset_desc_file_path
self.val_file = val_file_path
self.dataset = OfflineDataset(self.data_file, self.config_file, revive_config=self.config, ignore_check=self.config['ignore_check'])
self._check_license()
self.runtime_env = {"env_vars": {"PYTHONPATH":os.pathsep.join(sys.path), "PYARMOR_LICENSE": sys.PYARMOR_LICENSE}}
ray.init(address=address, runtime_env=self.runtime_env)
if self.val_file:
self.val_dataset = OfflineDataset(self.val_file, self.config_file, revive_config=self.config, ignore_check=self.config['ignore_check'])
self.val_dataset.processor = self.dataset.processor # make sure data processing is the same
self.config['val_dataset'] = ray.put(self.val_dataset)
else: # split the training set if validation set is not provided
self.dataset, self.val_dataset = self.dataset.split(self.config['val_split_ratio'], self.config['val_split_mode'])
self.config['val_dataset'] = ray.put(self.val_dataset)
self.config['dataset'] = ray.put(self.dataset)
self.config['graph'] = self.dataset.graph
self.graph = self.config['graph']
if not tuning_mode == 'None': assert len(self.dataset.graph.tunable) > 0, 'No tunable parameter detected, please check the config yaml!'
self.config['learning_nodes_num'] = self.dataset.learning_nodes_num
if target_policy_name is None:
target_policy_name = list(self.config['graph'].keys())[0]
logger.warning(f"Target policy name [{target_policy_name}] is chosen as default")
self.config['target_policy_name'] = target_policy_name.split(',')
logger.info(f"Target policy name {self.config['target_policy_name']} is chosen as default")
''' save a copy of the base graph '''
with open(os.path.join(self.workspace, 'graph.pkl'), 'wb') as f:
pickle.dump(self.config['graph'], f)
''' setup data buffers '''
self.driver_ip = socket.gethostbyname(socket.gethostname())
self.venv_data_buffer = ray.remote(DataBufferEnv).options(resources={}).remote(venv_max_num=self.config['num_venv_store'])
self.policy_data_buffer = ray.remote(DataBufferPolicy).options(resources={}).remote()
self.tuner_data_buffer = ray.remote(DataBufferTuner).options(resources={}).remote(self.tuning_mode, self.config['parameter_tuning_budget'])
self.config['venv_data_buffer'] = self.venv_data_buffer
self.config['policy_data_buffer'] = self.policy_data_buffer
self.config['tuner_data_buffer'] = self.tuner_data_buffer
''' try to load existing venv and policy '''
self.env_save_path = kwargs.get("env_save_path", None)
self.policy_save_path = kwargs.get("policy_save_path", None)
#self._reload_venv(os.path.join(self.workspace, 'env.pkl'))
#self._reload_policy(os.path.join(self.workspace, 'policy.pkl'))
self.venv_acc = - float('inf')
self.policy_acc = - float('inf')
self.venv_logger = None
self.policy_logger = None
self.tuner_logger = None
# heterogeneous_process init setting
self.heterogeneous_process = False
data = {"REVIVE_STOP" : False, "LOG_DIR":os.path.join(os.path.abspath(self.workspace),"revive.log")}
with open(os.path.join(self.workspace, ".env.json"), 'w') as f:
json.dump(data, f)
def _reload_venv(self, path: str, return_graph: bool = False):
r'''Reload a venv from the given path'''
try:
with open(path, 'rb') as f:
self.venv = pickle.load(f)
self.venv.check_version()
if not self.graph.is_equal_venv(self.venv.graph,self.config['target_policy_name']):
logger.error('Detected a different graph between the loaded venv and the data config; this is usually caused by a change of the config file.')
logger.error('Please check whether the config files used for learning the environment and the policy differ!')
sys.exit()
if not self.graph.is_equal_structure(self.venv.graph):
logger.warning('Detected a different graph structure between the loaded venv and the data config; this is usually caused by a change of the config file. Trying to rebuild ...')
venv_list = []
for _venv in self.venv.env_list:
graph = deepcopy(self.graph)
graph.copy_graph_node(_venv.graph)
venv_list.append(VirtualEnvDev(graph))
self.venv = VirtualEnv(venv_list)
# if return_graph:
# return graph
if self.venv_mode == 'None' and self.policy_mode != 'None':
self.heterogeneous_process = True
else:
self.heterogeneous_process = False
if return_graph:
graph = deepcopy(self.graph)
graph.copy_graph_node(self.venv.graph)
return graph
ray.get(self.venv_data_buffer.set_best_venv.remote(self.venv))
except Exception as e:
logger.info(f"Don't load venv -> {e}")
self.venv = None
def _reload_policy(self, path : str):
r'''Reload a policy from the given path'''
try:
with open(path, 'rb') as f:
self.policy = pickle.load(f)
self.policy.check_version()
ray.get(self.policy_data_buffer.set_best_policy.remote(self.policy))
except Exception as e:
logger.info(f"Don't load policy -> {e}")
self.policy = None
def train(self, env_save_path : Optional[str] = None):
r"""
Train the virtual environment and policy.
Steps
1. Start ray worker train the virtual environment based on the data;
2. Start ray worker train train policy based on the virtual environment.
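The same pipeline can also be run step by step (a sketch; when `env_save_path` is not
given, policy training falls back to `env.pkl` in the run workspace):

.. code-block:: python

    server.train_venv()        # train the virtual environment
    server.train_policy()      # train the policy in the virtual environment
    server.tune_parameter()    # only active when tuning_mode is enabled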
"""
self.train_venv()
logger.info(f"venv training finished !")
self.train_policy(env_save_path)
logger.info(f"policy training finished !")
self.tune_parameter(env_save_path)
def train_venv(self):
r"""
Start a ray worker to train the virtual environment on the data.
"""
if self.env_save_path and os.path.exists(self.env_save_path):
graph = self._reload_venv(self.env_save_path, return_graph=True)
self.config['graph'] = graph
self.graph = graph
self.venv_logger = ray.remote(Logger).remote()
self.venv_logger.update.remote(key="task_state", value="Wait")
if self.venv_mode == 'None':
self.venv_logger.update.remote(key="task_state", value="End")
else:
if 'wdist' in self.config['venv_metric']:
self.config['max_distance'] = 2
self.config['min_distance'] = 0
elif 'mae' in self.config['venv_metric']:
self.config['max_distance'] = np.log(2)
self.config['min_distance'] = np.log(2) - 15
elif 'mse' in self.config['venv_metric']:
self.config['max_distance'] = np.log(4)
self.config['min_distance'] = np.log(4) - 15
elif 'nll' in self.config['venv_metric']:
self.config['max_distance'] = 0.5 * np.log(2 * np.pi)
self.config['min_distance'] = 0.5 * np.log(2 * np.pi) - 10
logger.info(f"Distance is between {self.config['min_distance']} and {self.config['max_distance']}")
if self.config["venv_algo"] == "revive":
self.config["venv_algo"] = "revive_p"
logger.remove()
logger.info(f"Are you done ?")
if self.venv_mode == 'once':
venv_trainer = ray.remote(VenvTrain).remote(self.config, self.venv_logger, command=sys.argv[1:])
venv_trainer.train.remote()
# NOTE: after task finish, the actor will be automatically killed by ray, since there is no reference to it
elif self.venv_mode == 'tune':
self.venv_trainer = ray.remote(TuneVenvTrain).remote(self.config, self.venv_logger, command=sys.argv[1:])
self.venv_trainer.train.remote()
logger.add(self.revive_log_path)
# breakpoint()
# self.venv_mode = None
def train_policy(self, env_save_path : Optional[str] = None):
r"""
Start a ray worker to train the policy in the virtual environment.
Args:
:env_save_path: virtual environment path
.. note:: Before training the policy, environment models and a reward function must be provided.
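A minimal sketch of training a policy from an already trained simulator (the path below
is a placeholder for a previously saved `env.pkl`; without it, `env.pkl` from the current
run workspace is used):

.. code-block:: python

    server.train_policy(env_save_path='./logs/<run_id>/env.pkl')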
"""
if not env_save_path:
env_save_path = os.path.join(self.workspace, 'env.pkl')
self._reload_venv(env_save_path)
if self.venv is None:
logger.warning(f"Can't load the exist env model.")
self.policy_logger = ray.remote(Logger).remote()
self.policy_logger.update.remote(key="task_state", value="Wait")
logger.remove()
if self.policy_mode == 'None':
self.policy_logger.update.remote(key="task_state", value="End")
elif self.policy_mode == 'once':
assert self.reward_func is not None, 'policy training needs a reward function'
policy_trainer = ray.remote(PolicyTrain).remote(self.config, self.policy_logger, self.venv_logger, command=sys.argv[1:])
policy_trainer.train.remote()
# NOTE: after task finish, the actor will be automatically killed by ray, since there is no reference to it
elif self.policy_mode == 'tune':
assert self.reward_func is not None, 'policy training needs a reward function'
self.policy_trainer = ray.remote(TunePolicyTrain).remote(self.config, self.policy_logger, self.venv_logger, command=sys.argv[1:])
self.policy_trainer.train.remote()
logger.add(self.revive_log_path)
def tune_parameter(self, env_save_path : Optional[str] = None):
r"""
Tune parameters on specified virtual environments.
Args:
:env_save_path: virtual environment path
.. note:: This feature is currently unstable.
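A minimal sketch (the environment path is a placeholder; tuning only runs when the server
was constructed with `tuning_mode` set to "max" or "min"):

.. code-block:: python

    server.tune_parameter(env_save_path='./logs/<run_id>/env.pkl')
    best_parameter, tune_log = server.get_parameter()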
"""
if env_save_path is not None:
self._reload_venv(env_save_path)
self.config['user_func'] = self.reward_func
self.tuner_logger = ray.remote(Logger).remote()
self.tuner_logger.update.remote(key="task_state", value="Wait")
if self.tuning_mode == 'None':
self.tuner_logger.update.remote(key="task_state", value="End")
else:
assert self.reward_func is not None, 'tuning parameter needs reward function'
self.tuner = ray.remote(ParameterTuner).remote(self.config, self.tuning_mode, self.tune_initial_state, self.tuner_logger, self.venv_logger)
self.tuner.run.remote()
def stop_train(self) -> None:
r"""Stop all training tasks.
"""
_data = {"REVIVE_STOP" : True}
with open(os.path.join(self.workspace, ".env.json"), 'w') as f:
json.dump(_data, f)
if self.venv_logger is not None:
venv_logger = self.venv_logger.get_log.remote()
venv_logger = ray.get(venv_logger)
if venv_logger["task_state"] != "End":
self.venv_logger.update.remote(key="task_state", value="Shutdown")
if self.policy_logger is not None:
policy_logger = self.policy_logger.get_log.remote()
policy_logger = ray.get(policy_logger)
if policy_logger["task_state"] != "End":
self.policy_logger.update.remote(key="task_state", value="Shutdown")
def get_virtualenv_env(self) -> Tuple[VirtualEnv, Dict[str, Union[str, float]], Dict[int, Tuple[str, str]]]:
r"""Get virtual environment models and train log.
:Returns: the virtual environment models, the training log, a status message and the workspace of the best model
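A sketch of consuming the returned values (the keys follow the `train_log` dictionary
built by this method):

.. code-block:: python

    venv, train_log, status, best_workspace = server.get_virtualenv_env()
    print(train_log['venv_acc'], train_log['current_num_of_trials'], train_log['total_num_of_trials'])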
"""
assert self.dataset is not None
train_log = {}
if self.venv_logger is not None:
try:
venv_logger = self.venv_logger.get_log.remote()
venv_logger = ray.get(venv_logger)
train_log.update({"task_state": venv_logger["task_state"],})
except AttributeError:
train_log.update({"task_state": "Shutdown"})
metric = ray.get(self.venv_data_buffer.get_dict.remote())
venv_acc = float(metric["max_acc"])
current_num_of_trials = int(metric["num_of_trial"])
total_num_of_trials = int(metric["total_num_of_trials"])
train_log.update({
"venv_acc" : venv_acc,
"current_num_of_trials" : current_num_of_trials,
"total_num_of_trials" : total_num_of_trials,
})
self.venv_acc = max(self.venv_acc, venv_acc)
self.venv = ray.get(self.venv_data_buffer.get_best_venv.remote())
best_model_workspace = ray.get(self.venv_data_buffer.get_best_model_workspace.remote())
if self.venv is not None and train_log.get('task_state') != 'End':
with open(os.path.join(self.workspace, 'env.pkl'), 'wb') as f:
pickle.dump(self.venv, f)
try:
self.venv.export2onnx(os.path.join(self.workspace, 'env.onnx'), verbose=False)
except Exception as e:
logger.info(f"Can't to export venv to ONNX. -> {e}")
status_message = ray.get(self.venv_data_buffer.get_status.remote())
return self.venv, train_log, status_message, best_model_workspace
def get_policy_model(self) -> Tuple[PolicyModel, Dict[str, Union[str, float]], Dict[int, Tuple[str, str]]]:
r"""Get policy based on specified virtual environments.
:Return: the policy model, the training log, a status message and the workspace of the best model
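A sketch of retrieving the current best policy (when a best policy exists, `policy.pkl`
and, if export succeeds, `policy.onnx` are also written into the run workspace):

.. code-block:: python

    policy, train_log, status, best_workspace = server.get_policy_model()
    print(train_log['policy_acc'])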
"""
assert self.dataset is not None
train_log = {}
if self.policy_logger is not None:
try:
policy_logger = self.policy_logger.get_log.remote()
policy_logger = ray.get(policy_logger)
train_log.update({"task_state": policy_logger["task_state"],})
except AttributeError:
train_log.update({"task_state": "Shutdown"})
metric = ray.get(self.policy_data_buffer.get_dict.remote())
policy_acc = float(metric["max_reward"])
current_num_of_trials = int(metric["num_of_trial"])
total_num_of_trials = int(metric["total_num_of_trials"])
train_log.update({
"policy_acc" : policy_acc,
"current_num_of_trials" : current_num_of_trials,
"total_num_of_trials" : total_num_of_trials,
})
self.policy_acc = max(self.policy_acc, policy_acc)
self.policy = ray.get(self.policy_data_buffer.get_best_policy.remote())
best_model_workspace = ray.get(self.policy_data_buffer.get_best_model_workspace.remote())
if self.policy is not None:
with open(os.path.join(self.workspace, 'policy.pkl'), 'wb') as f:
pickle.dump(self.policy, f)
try:
tmp_policy = deepcopy(self.policy)
tmp_policy.reset()
tmp_policy.export2onnx(os.path.join(self.workspace, 'policy.onnx'), verbose=False)
except Exception as e:
logger.info(f"Can't to export venv to ONNX. -> {e}")
if self.policy_mode == 'tune': # create by tune
self.policy_traj_dir = os.path.join(self.workspace, 'policy_tune')
else:
self.policy_traj_dir = os.path.join(self.workspace, 'policy_train')
if self.venv is not None and self.heterogeneous_process:
try:
with open(os.path.join(self.policy_traj_dir, 'env_for_policy.pkl'), 'wb') as f:
pickle.dump(self.venv, f)
self.venv.export2onnx(os.path.join(self.policy_traj_dir, 'env_for_policy.onnx'), verbose=False)
except Exception as e:
logger.info(f"Can't to export venv_of_policy to ONNX. -> {e}")
status_message = ray.get(self.policy_data_buffer.get_status.remote())
return self.policy, train_log, status_message, best_model_workspace
def get_parameter(self) -> Tuple[np.ndarray, Dict[str, Union[str, float]]]:
r"""Get tuned parameters based on specified virtual environments.
:Return: current best parameters and training log
"""
train_log = {}
if self.tuner_logger is not None:
try:
tuner_logger = self.tuner_logger.get_log.remote()
tuner_logger = ray.get(tuner_logger)
train_log.update({"task_state": tuner_logger["task_state"],})
except AttributeError:
train_log.update({"task_state": "Shutdown"})
metric = ray.get(self.tuner_data_buffer.get_state.remote())
train_log.update(metric)
self.best_parameter = train_log.pop('best_parameter')
return self.best_parameter, train_log
def _check_license(self):
from revive.utils.auth_utils import check_license
check_license(self)