"""
POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
import os
import ray
import json
import time
import psutil
import traceback
import warnings
from ray import tune
from loguru import logger
from collections import deque
from ray.tune import CLIReporter
from collections import OrderedDict
from revive.computation.inference import *
from revive.utils.common_utils import update_description, setup_seed
from revive.algo.venv import VenvAlgorithm
from revive.algo.policy import PolicyAlgorithm
from revive.utils.tune_utils import CustomBasicVariantGenerator, CustomSearchGenerator, SysStopper, get_tune_callbacks
class DataBufferEnv:
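    """Shared buffer that tracks virtual-environment (venv) training trials.

    It records per-trial status and metrics, keeps trials ordered by their
    ``metric`` value, and maintains the current best ``VirtualEnv`` built
    from the top-ranked (train, val) model pairs.

    A minimal usage sketch, assuming the class is wrapped as a Ray actor
    (as the ``.remote()`` calls elsewhere in this module expect)::

        venv_buffer = ray.remote(DataBufferEnv).remote(venv_max_num=10)
        ray.get(venv_buffer.inc_trial.remote())
    """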
def __init__(self, venv_max_num : int = 10):
        '''
        :param venv_max_num: Maximum number of venvs kept per task in ``best_venv``.
        '''
self.venv_max_num = venv_max_num
self.num_of_trial = 0
self.least_metric = float('inf')
self.max_acc = - float('inf')
self.total_num_of_trials = -1
self.best_id = None
self.metric_dict = OrderedDict()
self.status_dict = OrderedDict()
self.best_venv = None
self.best_model_info = {}
self.best_model_workspace = None
self.venv_deque_dict = dict()
def update_status(self, task_id : int, status : str, message : str = ''):
        old_message = '' if task_id not in self.status_dict else self.status_dict[task_id][1]
self.status_dict[task_id] = (status, old_message + message)
def get_status(self) -> Dict[int, Tuple[str, str]]:
return self.status_dict
def set_best_venv(self, venv : VirtualEnv):
self.best_venv = venv
def get_best_venv(self) -> VirtualEnv:
return self.best_venv
def get_best_venv_info(self):
return self.best_model_info
def get_best_model_workspace(self) -> str:
return self.best_model_workspace
def set_total_trials(self, trials : int):
self.total_num_of_trials = trials
def inc_trial(self) -> int:
self.num_of_trial += 1
return self.num_of_trial
def get_num_of_trial(self) -> int:
return self.num_of_trial
def update_venv_deque_dict(self, task_id, venv_train, venv_val):
if task_id not in self.venv_deque_dict.keys():
self.venv_deque_dict[task_id] = deque(maxlen=self.venv_max_num)
self.venv_deque_dict[task_id].append((venv_train, venv_val))
    def delet_deque_item(self, task_id, index):
del self.venv_deque_dict[task_id][index]
    def update_metric(self, task_id : int, metric : Dict[str, Union[float, VirtualEnvDev]]):
self.metric_dict[task_id] = metric
self.metric_dict = OrderedDict(sorted(self.metric_dict.items(), key=lambda x: x[1]['metric']))
self.best_id, info = list(self.metric_dict.items())[0]
self.least_metric = info['metric']
self.max_acc = info['acc']
self.best_model_workspace = info['traj_dir']
self.best_model_info = info
"""
# Save the top-k env for every task
if task_id not in self.venv_deque_dict.keys():
self.venv_deque_dict[task_id] = deque(maxlen=self.venv_max_num)
self.venv_deque_dict[task_id].append((metric['venv_train'], metric['venv_val']))
"""
# self.update_venv_deque_dict(task_id, metric['venv_train'], metric['venv_val'])
venv_list = self.get_venv_list()
if len(self.metric_dict.values()) <= 1:
if (len(venv_list) <= self.venv_max_num) and (self.best_id in self.venv_deque_dict.keys()):
venv_list += [venv_pair for venv_pair in list(self.venv_deque_dict[self.best_id])[:-1][::-1]]
venv_list = venv_list[:self.venv_max_num]
else:
venv_list = venv_list[:self.venv_max_num]
venv = VirtualEnv([pair[0] for pair in venv_list] + [pair[1] for pair in venv_list])
self.set_best_venv(venv)
def get_max_acc(self) -> float:
return self.max_acc
def get_least_metric(self) -> float:
return self.least_metric
def get_best_id(self) -> int:
return self.best_id
def get_venv_list(self) -> List[VirtualEnvDev]:
return [(metric['venv_train'], metric['venv_val']) for metric in self.metric_dict.values() if metric['venv_train'] is not None]
def get_dict(self):
# clean out venv references
metric_dict = OrderedDict()
for id, mdict in self.metric_dict.items():
new_mdict = {k : v for k, v in mdict.items() if not isinstance(v, VirtualEnvDev)}
metric_dict[id] = new_mdict
return {
"num_of_trial" : self.num_of_trial,
"total_num_of_trials" : self.total_num_of_trials,
"least_metric" : self.least_metric,
"max_acc" : self.max_acc,
"best_id" : self.best_id,
"metrics" : metric_dict
}
def write(self, filename : str):
with open(filename, 'w') as f:
json.dump(self.get_dict(), f, indent=4)
class DataBufferPolicy:
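    """Shared buffer that tracks policy training trials.

    It records per-trial status and rewards, keeps trials ordered by
    ``reward`` (largest first), and stores the best ``PolicyModel`` found
    so far. Like ``DataBufferEnv``, it is typically used as a Ray actor.
    """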
def __init__(self):
self.num_of_trial = 0
self.max_reward = - float('inf')
self.total_num_of_trials = -1
self.best_id = None
self.reward_dict = OrderedDict()
self.status_dict = OrderedDict()
self.best_policy = None
self.best_model_workspace = None
def update_status(self, task_id : int, status : str, message : str = ''):
        old_message = '' if task_id not in self.status_dict else self.status_dict[task_id][1]
self.status_dict[task_id] = (status, old_message + message)
def get_status(self) -> Dict[int, Tuple[str, str]]:
return self.status_dict
def set_best_policy(self, policy : PolicyModel):
self.best_policy = policy
def get_best_policy(self) -> PolicyModel:
return self.best_policy
def get_best_model_workspace(self) -> str:
return self.best_model_workspace
def set_total_trials(self, trials : int):
self.total_num_of_trials = trials
def inc_trial(self) -> int:
self.num_of_trial += 1
return self.num_of_trial
def get_num_of_trial(self) -> int:
return self.num_of_trial
def update_metric(self, task_id : int, metric : Dict[str, Union[float, PolicyModelDev]]):
self.reward_dict[task_id] = metric
self.reward_dict = OrderedDict(sorted(self.reward_dict.items(), key=lambda x: x[1]['reward'], reverse=True))
self.best_id, info = list(self.reward_dict.items())[0]
self.max_reward = info['reward']
self.best_model_workspace = info['traj_dir']
self.set_best_policy(PolicyModel(self.reward_dict[self.best_id]['policy']))
def get_max_reward(self):
return self.max_reward
def get_best_id(self):
return self.best_id
def get_dict(self):
# clean out policy references
reward_dict = OrderedDict()
for id, mdict in self.reward_dict.items():
new_mdict = {k : v for k, v in mdict.items() if not isinstance(v, PolicyModelDev)}
reward_dict[id] = new_mdict
return {
"num_of_trial" : self.num_of_trial,
"total_num_of_trials" : self.total_num_of_trials,
"max_reward" : self.max_reward,
"best_id" : self.best_id,
"rewards" : reward_dict
}
def write(self, filename : str):
with open(filename, 'w') as f:
json.dump(self.get_dict(), f, indent=4)
class DataBufferTuner:
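    """Tracks the best parameter set found during parameter tuning.

    ``mode`` selects whether larger ('max') or smaller metric values are
    better; ``budget`` is the total number of tuning trials allowed.
    """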
def __init__(self, mode : str, budget : int):
self.mode = mode
self.current_trail = 0
self.budget = budget
self.best_metric = - float('inf') if self.mode == 'max' else float('inf')
self.best_parameter = None
def get_state(self):
return {
'best_metric' : self.best_metric,
'best_parameter' : self.best_parameter,
'searched_trail' : self.current_trail,
'budget' : self.budget
}
def update(self, parameter : Dict[str, np.ndarray], metric : float):
self.current_trail += 1
if self.mode == 'max':
if metric > self.best_metric:
self.best_metric = metric
self.best_parameter = parameter
else:
if metric < self.best_metric:
self.best_metric = metric
self.best_parameter = parameter
class Logger:
"""
    A minimal key-value logger used to share training state between Ray actors.
"""
def __init__(self):
self.log = {}
def get_log(self):
return self.log
def update(self, key, value):
self.log[key] = value
def trial_str_creator(trial):
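    """Name Ray Tune trials as ``ReviveLog_<trial_id>``."""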
return "{}_{}".format("ReviveLog", trial.trial_id)
def catch_error(func):
    '''Decorator that catches exceptions in training entry points, logs the traceback, and marks the task as finished.'''
def wrapped_func(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except Exception as e:
error_message = traceback.format_exc()
            self.logru_logger.error('Detected error: {}, error message: {}'.format(e, error_message))
self.logger.update.remote(key="task_state", value="End")
return wrapped_func
class TuneVenvTrain(object):
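    """Hyperparameter search for venv training, driven by Ray Tune.

    Builds the search space and train function from ``VenvAlgorithm`` and
    runs them through ``tune.Tuner``; task state is reported through the
    venv logger actor.
    """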
def __init__(self, config, venv_logger, command=None):
logger.add(config["revive_log_path"])
self.logru_logger = logger
self.config = config
self.logger = venv_logger
self.workspace = os.path.join(self.config["workspace"], 'venv_tune')
if not os.path.exists(self.workspace):
os.makedirs(self.workspace)
self.algo = VenvAlgorithm(self.config["venv_algo"],self.workspace)
if 'venv_algo_config' in config.keys() and self.config['venv_algo'] in config['venv_algo_config'].keys():
update_description(self.algo.operator.PARAMETER_DESCRIPTION, config['venv_algo_config'][self.config['venv_algo']])
self.config.update(self.algo.get_parameters(command))
        os.environ['TUNE_GLOBAL_CHECKPOINT_S'] = str(self.config['global_checkpoint_period'])
@catch_error
def train(self):
self.logger.update.remote(key="task_state", value="Run")
tune_params = self.algo.get_tune_parameters(self.config)
trainer = self.algo.get_train_func(self.config)
        # Set the seed
setup_seed(self.config["global_seed"])
tune_config = {"mode" : "max",
"metric" : "mean_accuracy",
"search_alg" : tune_params["search_alg"],
"num_samples" : tune_params["num_samples"],
"reuse_actors" : tune_params["reuse_actors"]}
run_config = {"name" : tune_params["name"],
"local_dir": tune_params["local_dir"],
"stop" : SysStopper(workspace = self.config['workspace']),
"callbacks" : tune_params["callbacks"],
"verbose": tune_params["verbose"]}
"""
try:
cluster_resources = ray.cluster_resources()
num_gpus = cluster_resources["GPU"]
num_cpus = cluster_resources["CPU"]
num_gpus_blocks = int(1/self.config['venv_gpus_per_worker']) * num_gpus
resources_per_trial_cpu = max(int(num_cpus/num_gpus_blocks),1)
except:
resources_per_trial_cpu = 1
"""
resources_per_trial_cpu = 1
_tuner_kwargs = {"trial_name_creator" : trial_str_creator,
"resources_per_trial":{"cpu": resources_per_trial_cpu, "gpu": self.config['venv_gpus_per_worker']},
"progress_reporter":tune_params["progress_reporter"]}
tuner = tune.Tuner(
trainer,
tune_config=tune.TuneConfig(**tune_config),
run_config=ray.air.config.RunConfig(**run_config),
_tuner_kwargs = _tuner_kwargs,
param_space=tune_params['config'])
results = tuner.fit()
self.logger.update.remote(key="task_state", value="End")
class TunePolicyTrain(object):
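    """Hyperparameter search for policy training, driven by Ray Tune.

    When a venv logger is provided, ``train`` blocks until the venv task
    reports 'End' and a trained ``env.pkl`` exists in the workspace before
    starting the policy search.
    """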
def __init__(self, config, policy_logger, venv_logger=None, command=None):
logger.add(config["revive_log_path"])
self.logru_logger = logger
self.config = config
self.logger = policy_logger
self.venv_logger = venv_logger
self.workspace = os.path.join(self.config["workspace"], 'policy_tune')
if not os.path.exists(self.workspace):
os.makedirs(self.workspace)
self.algo = PolicyAlgorithm(self.config['policy_algo'], self.workspace)
if 'policy_algo_config' in config.keys() and self.config['policy_algo'] in config['policy_algo_config'].keys():
update_description(self.algo.operator.PARAMETER_DESCRIPTION, config['policy_algo_config'][self.config['policy_algo']])
self.config.update(self.algo.get_parameters(command))
        os.environ['TUNE_GLOBAL_CHECKPOINT_S'] = str(self.config['global_checkpoint_period'])
@catch_error
    def train(self):
if self.venv_logger is not None:
while True: # block until venv train finish
log = ray.get(self.venv_logger.get_log.remote())
if log.get('task_state') == 'End':
break
time.sleep(10)
if not os.path.exists(os.path.join(self.config['workspace'], 'env.pkl')):
logger.error(f"Don't find env model.")
import sys
sys.exit()
self.logger.update.remote(key="task_state", value="Run")
tune_params = self.algo.get_tune_parameters(self.config)
trainer = self.algo.get_train_func(self.config)
        # Set the seed
setup_seed(self.config["global_seed"])
tune_config = {"mode" : "max",
"metric" : "mean_accuracy",
"search_alg" : tune_params["search_alg"],
"num_samples" : tune_params["num_samples"],
"reuse_actors" : tune_params["reuse_actors"]}
run_config = {"name" : tune_params["name"],
"local_dir": tune_params["local_dir"],
"stop" : SysStopper(workspace = self.config['workspace']),
"callbacks" : tune_params["callbacks"],
"verbose": tune_params["verbose"]}
"""
try:
cluster_resources = ray.cluster_resources()
num_gpus = cluster_resources["GPU"]
num_cpus = cluster_resources["CPU"]
num_gpus_blocks = int(1/self.config['policy_gpus_per_worker']) * num_gpus
resources_per_trial_cpu = max(int(num_cpus/num_gpus_blocks),1)
except:
resources_per_trial_cpu = 1
"""
resources_per_trial_cpu = 1
_tuner_kwargs = {"trial_name_creator" : trial_str_creator,
"resources_per_trial":{"cpu": resources_per_trial_cpu, "gpu": self.config['policy_gpus_per_worker']},
"progress_reporter":tune_params["progress_reporter"]}
tuner = tune.Tuner(
trainer,
tune_config=tune.TuneConfig(**tune_config),
run_config=ray.air.config.RunConfig(**run_config),
_tuner_kwargs = _tuner_kwargs)
results = tuner.fit()
self.logger.update.remote(key="task_state", value="End")
class VenvTrain(object):
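    """Single-run venv training without a Ray Tune hyperparameter search."""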
def __init__(self, config, venv_logger, command=None):
logger.add(config["revive_log_path"])
self.logru_logger = logger
self.config = config
self.logger = venv_logger
self.workspace = os.path.join(self.config["workspace"], 'venv_train')
if not os.path.exists(self.workspace):
os.makedirs(self.workspace)
self.algo = VenvAlgorithm(self.config["venv_algo"], self.workspace)
if 'venv_algo_config' in config.keys() and self.config['venv_algo'] in config['venv_algo_config'].keys():
update_description(self.algo.operator.PARAMETER_DESCRIPTION, config['venv_algo_config'][self.config['venv_algo']])
self.config.update(self.algo.get_parameters(command))
@catch_error
def train(self):
self.logger.update.remote(key="task_state", value="Run")
# Dynamically obtain the corresponding trainer object for the algorithm
trainer = self.algo.get_trainer(self.config)
trainer.run_config.verbose = 0
        # Set the seed
setup_seed(self.config["global_seed"])
trainer.fit()
self.logger.update.remote(key="task_state", value="End")
# trainer.shutdown() # Without this line, GPU memory will leak
class PolicyTrain(object):
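    """Single-run policy training without a Ray Tune hyperparameter search.

    When a venv logger is provided, ``train`` waits for the venv task to
    finish and for ``env.pkl`` to appear in the workspace before training.
    """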
def __init__(self, config, policy_logger, venv_logger=None, command=None):
logger.add(config["revive_log_path"])
self.logru_logger = logger
self.config = config
self.logger = policy_logger
self.venv_logger = venv_logger
self.workspace = os.path.join(self.config["workspace"], 'policy_train')
if not os.path.exists(self.workspace):
os.makedirs(self.workspace)
self.algo = PolicyAlgorithm(self.config['policy_algo'], self.workspace)
if 'policy_algo_config' in config.keys() and self.config['policy_algo'] in config['policy_algo_config'].keys():
update_description(self.algo.operator.PARAMETER_DESCRIPTION, config['policy_algo_config'][self.config['policy_algo']])
self.config.update(self.algo.get_parameters(command))
@catch_error
def train(self):
if self.venv_logger is not None:
while True: # block until venv train finish
log = ray.get(self.venv_logger.get_log.remote())
if log.get('task_state') == 'End':
break
time.sleep(10)
while True: # block until venv available
if os.path.exists(os.path.join(self.config['workspace'], 'env.pkl')):
break
else:
logger.error(f"Don't find env model.")
self.logger.update.remote(key="task_state", value="End")
self.logger.update.remote(key="task_state", value="Run")
# Dynamically obtain the corresponding trainer object for the algorithm
trainer = self.algo.get_trainer(self.config)
trainer.run_config.verbose = 0
        # Set the seed
setup_seed(self.config["global_seed"])
trainer.fit()
self.logger.update.remote(key="task_state", value="End")
# trainer.shutdown() # Without this line, GPU memory will leak
def default_evaluate(config):
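    """Evaluation function used by the parameter-tuning trials.

    Rebuilds the tunable-node parameters from the sampled ``config``, rolls
    the best virtual environment forward from the given initial state with
    ``infer_k_steps``, sums the user objective over the rollout, pushes the
    result to the tuner data buffer, and returns ``{'objective': value}``.
    """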
static = config.pop('static')
env = ray.get(static['venv_buffer'].get_best_venv.remote())
state = static['state']
objective = static['objective']
buffer = static['buffer']
graph = env.graph
parameter = {}
for tunable_name in graph.tunable:
parameter[tunable_name] = np.array(
[config[parameter_name] for parameter_name in sorted(filter(lambda x: tunable_name in x, config.keys()))]
)
state[0].update(parameter)
states = env.infer_k_steps(state)
value = sum([objective(s) for s in states])
buffer.update.remote(parameter, value)
return {'objective' : value}
class ParameterTuner(object):
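    """Searches tunable graph parameters against the best learned venv.

    The search algorithm is selected by ``parameter_tuning_algorithm`` in
    the config: 'random' uses grid/uniform sampling via Tune's variant
    generator, while 'zoopt' uses ``ZOOptSearch``. The user objective in
    ``config['user_func']`` is evaluated on each rolled-out state and
    summed (see ``default_evaluate``).
    """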
def __init__(self, config, mode, initial_state, logger, venv_logger=None):
self.config = config
self.mode = mode
self.initial_state = initial_state
self.logger = logger
self.venv_logger = venv_logger
def run(self):
if self.venv_logger is not None:
while True: # block until venv train finish
log = ray.get(self.venv_logger.get_log.remote())
if log.get('task_state') == 'End':
break
time.sleep(10)
while True: # block until venv available
if os.path.exists(os.path.join(self.config['workspace'], 'env.pkl')):
break
logger.info('Waiting for venv ...')
time.sleep(10)
self.logger.update.remote(key="task_state", value="Run")
env = ray.get(self.config['venv_data_buffer'].get_best_venv.remote())
graph = env.graph
dataset = ray.get(self.config['dataset'])
if len(graph.external_factors) - len(graph.tunable) > 0:
for k, v in self.initial_state.items():
if len(v.shape) == 2:
horizon = v.shape[0]
state = [{node_name : self.initial_state[node_name] for node_name in graph.transition_map.keys()}] + [{}] * (horizon - 1)
for i in range(horizon):
for k, v in self.initial_state.items():
if len(v.shape) == 2: state[i][k] = v[i]
            warnings.warn(f'Detected leaf node on the graph, resetting rollout horizon to {horizon}!')
else:
if self.config['parameter_tuning_rollout_horizon'] > dataset.max_length:
                warnings.warn('Detected rollout length greater than the max length in the dataset!')
state = [self.initial_state] + [{}] * (self.config['parameter_tuning_rollout_horizon'] - 1)
static_config = {'static' : {'venv_buffer' : self.config['venv_data_buffer'], 'state' : state, 'objective' : self.config['user_func'], 'buffer' : self.config['tuner_data_buffer']}}
reporter = CLIReporter(max_progress_rows=50)
reporter.add_metric_column("objective")
tune_params = {
"name": "parameter_tuning",
"progress_reporter": reporter,
'metric' : 'objective',
'mode' : self.mode,
"reuse_actors": self.config["reuse_actors"],
"local_dir": self.config["workspace"],
"loggers": get_tune_callbacks(),
"verbose": self.config["verbose"],
'num_samples' : self.config['parameter_tuning_budget']
}
if self.config['parameter_tuning_algorithm'] == 'random':
random_search_config = static_config
for tunable_name in graph.tunable:
for i, d in enumerate(dataset.raw_columns[tunable_name]):
name = list(d.keys())[0]
_config = d[name]
if _config['type'] == 'continuous':
random_search_config[f'{tunable_name}_{"%.09d" % i}'] = tune.uniform(_config['min'], _config['max'])
elif _config['type'] == 'discrete':
random_search_config[f'{tunable_name}_{"%.09d" % i}'] = tune.grid_search(np.linspace(_config['min'], _config['max'], _config['num']).tolist())
elif _config['type'] == 'category':
random_search_config[f'{tunable_name}_{"%.09d" % i}'] = tune.grid_search(_config['values'])
tune_params['config'] = random_search_config
tune_params['search_alg'] = CustomBasicVariantGenerator()
elif self.config['parameter_tuning_algorithm'] == 'zoopt':
from ray.tune.suggest.zoopt import ZOOptSearch
from zoopt import ValueType
num_of_cpu = int(ray.available_resources()['CPU'])
parallel_num = num_of_cpu
assert parallel_num > 0
dim_dict = {}
for tunable_name in graph.tunable:
for i, d in enumerate(dataset.raw_columns[tunable_name]):
name = list(d.keys())[0]
_config = d[name]
if _config['type'] == 'continuous':
dim_dict[f'{tunable_name}_{"%.09d" % i}'] = (ValueType.CONTINUOUS, [_config['min'], _config['max']], 1e-10)
elif _config['type'] == 'discrete':
dim_dict[f'{tunable_name}_{"%.09d" % i}'] = (ValueType.DISCRETE, np.linspace(_config['min'], _config['max'], _config['num']).tolist())
elif _config['type'] == 'category':
dim_dict[f'{tunable_name}_{"%.09d" % i}'] = (ValueType.GRID, _config['values'])
zoopt_search_config = {
"parallel_num": parallel_num
}
tune_params['search_alg'] = ZOOptSearch(
algo="Asracos", # only support Asracos currently
budget=self.config['parameter_tuning_budget'],
dim_dict=dim_dict,
metric='objective',
mode=self.mode,
**zoopt_search_config
)
tune_params['config'] = static_config
tune_params['search_alg'] = CustomSearchGenerator(tune_params['search_alg']) # wrap with our generator
analysis = tune.run(
default_evaluate,
**tune_params
)
self.logger.update.remote(key="task_state", value="End")