"""
POLIXIR REVIVE, copyright (C) 2021-2023 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
import os
import json  # used by _load_checkpoint
import ray
import shutil
import torch
import pickle
import warnings
import importlib
import traceback
import numpy as np
from scipy import stats
from loguru import logger
from copy import deepcopy
from ray import tune
from ray import train
from ray.train.torch import TorchTrainer
from revive.computation.graph import DesicionGraph
from revive.computation.inference import *
from revive.data.batch import Batch
from revive.data.dataset import data_creator
from revive.utils.raysgd_utils import NUM_SAMPLES, AverageMeterCollection
from revive.utils.tune_utils import get_tune_callbacks, CustomSearchGenerator, CustomBasicVariantGenerator, CLIReporter
from revive.utils.common_utils import *
from revive.utils.auth_utils import customer_uploadTrainLog
warnings.filterwarnings('ignore')
def catch_error(func):
'''Push the training error message to the data buffer.'''
def wrapped_func(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except Exception as e:
error_message = traceback.format_exc()
logger.warning('Detected error: {}, message: {}'.format(e, error_message))
ray.get(self._data_buffer.update_status.remote(self._traj_id, 'error', error_message))
self._stop_flag = True
try:
customer_uploadTrainLog(self.config["trainId"],
os.path.join(os.path.abspath(self._workspace), "revive.log"),
"train.simulator",
"fail",
self._acc,
self.config["accessToken"])
except Exception as upload_error:
logger.info(f"{upload_error}")
return {
'stop_flag' : True,
'now_metric' : np.inf,
'least_metric' : np.inf,
}
return wrapped_func
class VenvOperator():
r"""
The base venv class.
"""
NAME = None # this needs to be set in any subclass
r"""
Name of the used algorithm.
"""
@property
def metric_name(self):
r"""
The metric that the hyperparameter search tries to minimize.
"""
return f"{self.NAME}/average_{self.config['venv_metric']}"
@property
def nodes_models_train(self):
return self.train_models[:self.config['learning_nodes_num']]
@property
def other_models_train(self):
return self.train_models[self.config['learning_nodes_num']:]
@property
def nodes_models_val(self):
return self.val_models[:self.config['learning_nodes_num']]
@property
def other_models_val(self):
return self.val_models[self.config['learning_nodes_num']:]
# NOTE: you need either write the `PARAMETER_DESCRIPTION` or overwrite `get_parameters` and `get_tune_parameters`.
PARAMETER_DESCRIPTION = [] # a list of dict to describe the parameter of the algorithm
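# A hedged sketch of one entry. The keys 'name', 'tune', 'search_mode' and 'search_values'
# are the ones read by `get_tune_parameters` below; 'type' and 'default' are assumed to be
# consumed by `list2parser`, and the concrete values are purely illustrative:
#
# PARAMETER_DESCRIPTION = [
#     {
#         'name': 'policy_hidden_size',
#         'type': int,
#         'default': 256,
#         'search_mode': 'grid',            # 'continuous', 'discrete' or 'grid'
#         'search_values': [64, 128, 256],
#         'tune': True,                     # set False to exclude it from hyperparameter search
#     },
# ]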
@classmethod
def get_parameters(cls, command=None, **kargs):
parser = list2parser(cls.PARAMETER_DESCRIPTION)
return parser.parse_known_args(command)[0].__dict__
@classmethod
def get_tune_parameters(cls, config : dict, **kargs):
r"""
Use ray.tune to wrap the parameters to be searched.
"""
_search_algo = config['venv_search_algo'].lower()
tune_params = {
"name": "venv_tune",
"reuse_actors": config["reuse_actors"],
"local_dir": config["workspace"],
"callbacks": get_tune_callbacks(),
"stop": {
"stop_flag": True
},
"verbose": config["verbose"],
}
if _search_algo == 'random':
random_search_config = {}
for description in cls.PARAMETER_DESCRIPTION:
if 'tune' in description.keys() and not description["tune"]:
continue
if 'search_mode' in description.keys() and 'search_values' in description.keys():
if description['search_mode'] == 'continuous':
random_search_config[description['name']] = tune.uniform(*description['search_values'])
elif description['search_mode'] == 'grid' or description['search_mode'] == 'discrete':
random_search_config[description['name']] = tune.choice(description['search_values'])
config["total_num_of_trials"] = config['train_venv_trials']
tune_params['config'] = random_search_config
tune_params['num_samples'] = config["total_num_of_trials"]
tune_params['search_alg'] = CustomBasicVariantGenerator()
elif _search_algo == 'zoopt':
# from ray.tune.search.zoopt import ZOOptSearch
from revive.utils.tune_utils import ZOOptSearch
from zoopt import ValueType
if config['parallel_num'] == 'auto':
if config['use_gpu']:
num_of_gpu = int(ray.available_resources()['GPU'])
num_of_cpu = int(ray.available_resources()['CPU'])
parallel_num = min(int(num_of_gpu / config['venv_gpus_per_worker']), num_of_cpu)
else:
num_of_cpu = int(ray.available_resources()['CPU'])
parallel_num = num_of_cpu
else:
parallel_num = int(config['parallel_num'])
assert parallel_num > 0
config['parallel_num'] = parallel_num
dim_dict = {}
for description in cls.PARAMETER_DESCRIPTION:
if 'tune' in description.keys() and not description["tune"]:
continue
if 'search_mode' in description.keys():
if description['search_mode'] == 'continuous':
dim_dict[description['name']] = (ValueType.CONTINUOUS, description['search_values'], min(description['search_values']))
elif description['search_mode'] == 'discrete':
dim_dict[description['name']] = (ValueType.DISCRETE, description['search_values'])
elif description['search_mode'] == 'grid':
dim_dict[description['name']] = (ValueType.GRID, description['search_values'])
zoopt_search_config = {
"parallel_num": config['parallel_num']
}
config["total_num_of_trials"] = config['train_venv_trials']
tune_params['search_alg'] = ZOOptSearch(
algo="Asracos", # only support Asracos currently
budget=config["total_num_of_trials"],
dim_dict=dim_dict,
metric='least_metric',
mode="min",
**zoopt_search_config
)
tune_params['config'] = dim_dict
tune_params['search_alg'] = CustomSearchGenerator(tune_params['search_alg']) # wrap with our generator
tune_params['num_samples'] = config["total_num_of_trials"]
elif 'grid' in _search_algo:
grid_search_config = {}
config["total_num_of_trials"] = 1
for description in cls.PARAMETER_DESCRIPTION:
if 'tune' in description.keys() and not description["tune"]:
continue
if 'search_mode' in description.keys() and 'search_values' in description.keys():
if description['search_mode'] == 'grid':
grid_search_config[description['name']] = tune.grid_search(description['search_values'])
config["total_num_of_trials"] *= len(description['search_values'])
else:
warnings.warn(f"Detect parameter {description['name']} is define as searchable in `PARAMETER_DESCRIPTION`. " + \
f"However, since grid search does not support search type {description['search_mode']}, the parameter is skipped. " + \
f"If this parameter is important to the performance, you should consider other search algorithms. ")
tune_params['config'] = grid_search_config
tune_params['num_samples'] = config["total_num_of_trials"]
tune_params['search_alg'] = CustomBasicVariantGenerator()
elif 'bayes' in _search_algo:
from ray.tune.search.bayesopt import BayesOptSearch
bayes_search_config = {}
for description in cls.PARAMETER_DESCRIPTION:
if 'tune' in description.keys() and not description["tune"]:
continue
if 'search_mode' in description.keys() and 'search_values' in description.keys():
if description['search_mode'] == 'continuous':
bayes_search_config[description['name']] = description['search_values']
else:
warnings.warn(f"Detect parameter {description['name']} is define as searchable in `PARAMETER_DESCRIPTION`. " + \
f"However, since bayesian search does not support search type {description['search_mode']}, the parameter is skipped. " + \
f"If this parameter is important to the performance, you should consider other search algorithms. ")
config["total_num_of_trials"] = config['train_venv_trials']
tune_params['config'] = bayes_search_config
tune_params['search_alg'] = BayesOptSearch(bayes_search_config, metric="least_metric", mode="min")
tune_params['search_alg'] = CustomSearchGenerator(tune_params['search_alg']) # wrap with our generator
tune_params['num_samples'] = config["total_num_of_trials"]
else:
raise ValueError(f'search algorithm {_search_algo} is not supported!')
reporter = CLIReporter(parameter_columns=list(tune_params['config'].keys()), max_progress_rows=50, max_report_frequency=10, sort_by_metric=True)
reporter.add_metric_column("least_metric", representation='least metric loss')
reporter.add_metric_column("now_metric", representation='current metric loss')
tune_params["progress_reporter"] = reporter
return tune_params
def model_creator(self, config : dict, graph : DesicionGraph):
r"""
Create all the models. The algorithm needs to define models for the nodes to be learned.
Args:
:config: configuration parameters
:graph: the decision graph whose learnable nodes the models will serve
Return:
a list of models
"""
raise NotImplementedError
def optimizer_creator(self, models : List[torch.nn.Module], config : dict):
r"""
Define optimizers for the created models.
Args:
:models: list of all the models
:config: configuration parameters
Return:
a list of optimizers
"""
raise NotImplementedError
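# A minimal sketch of how a subclass might implement the two hooks above. It is purely
# illustrative: the helper `MLP`, the 'output' entry of `total_dims`, and the config keys
# 'venv_hidden_features', 'venv_hidden_layers' and 'g_lr' are assumptions, not part of this base class.
#
# class MyVenvOperator(VenvOperator):
#     NAME = 'my_algo'
#
#     def model_creator(self, config, graph):
#         # one network per learnable node, sized from the node's input/output dims
#         models = []
#         for node_name in graph.learnable_node_names:
#             input_dim = config['total_dims'][node_name]['input']
#             output_dim = config['total_dims'][node_name]['output']
#             models.append(MLP(input_dim, output_dim,
#                               config['venv_hidden_features'],
#                               config['venv_hidden_layers']))
#         return models
#
#     def optimizer_creator(self, models, config):
#         return [torch.optim.Adam(m.parameters(), lr=config['g_lr']) for m in models]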
def data_creator(self, config : dict):
r"""
Create DataLoaders.
Args:
:config: configuration parameters
Return:
(train_loader, val_loader)
"""
return data_creator(config, training_mode='transition', val_horizon=config['venv_rollout_horizon'], double=True)
def _setup_componects(self, config : dict):
r'''Set up models, optimizers, and dataloaders.'''
# register data loader for double venv training
train_loader_train, val_loader_train, train_loader_val, val_loader_val = self.data_creator(config)
try:
self._train_loader_train = train.torch.prepare_data_loader(train_loader_train, move_to_device=False)
self._val_loader_train = train.torch.prepare_data_loader(val_loader_train, move_to_device=False)
self._train_loader_val = train.torch.prepare_data_loader(train_loader_val, move_to_device=False)
self._val_loader_val = train.torch.prepare_data_loader(val_loader_val, move_to_device=False)
except:
self._train_loader_train = train_loader_train
self._val_loader_train = val_loader_train
self._train_loader_val = train_loader_val
self._val_loader_val = val_loader_val
self.train_models = self.model_creator(config, self.graph_train)
for model_index, model in enumerate(self.train_models):
try:
self.train_models[model_index] = train.torch.prepare_model(model)
except:
self.train_models[model_index] = model.to(self._device)
self.train_optimizers = self.optimizer_creator(self.train_models, config)
self.val_models = self.model_creator(config, self.graph_val)
for model_index, model in enumerate(self.val_models):
try:
self.val_models[model_index] = train.torch.prepare_model(model)
except:
self.val_models[model_index] = model.to(self._device)
self.val_optimizers = self.optimizer_creator(self.val_models, config)
def _register_models_to_graph(self, graph : DesicionGraph, models : List[torch.nn.Module]):
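# NOTE: `models` must follow the order produced by `model_creator`:
# networks for non-transition (policy) nodes first, then transition nodes.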
index = 0
# register policy nodes
for node_name in list(graph.keys()):
if node_name in graph.transition_map.values():
continue
node = graph.get_node(node_name)
if node.node_type == 'network':
node.set_network(models[index])
index += 1
# register transition nodes
for node_name in graph.transition_map.values():
node = graph.get_node(node_name)
if node.node_type == 'network':
node.set_network(models[index])
index += 1
assert len(models) == index, f'Some models are not registered. Total models: {len(models)}, Registered: {index}.'
@catch_error
def __init__(self, config : dict):
r'''Set up everything for training.
Args:
:config: configuration parameters
'''
# parse information from config
self.config = config
self.train_dataset = ray.get(config['dataset'])
self.val_dataset = ray.get(config['val_dataset'])
self._data_buffer = config['venv_data_buffer']
self._workspace = config["workspace"]
self._graph = deepcopy(config['graph'])
self._filename = os.path.join(self._workspace, "train_venv.json")
self._data_buffer.set_total_trials.remote(config.get("total_num_of_trials", 1))
self._data_buffer.inc_trial.remote()
self._least_metric_train = [np.inf] * len(self._graph.metric_nodes)
self._least_metric_val = [np.inf] * len(self._graph.metric_nodes)
self.least_val_metric = np.inf
self.least_train_metric = np.inf
self._acc = 0
self._num_venv_list = []
# get id
self._ip = ray._private.services.get_node_ip_address()
logger.add(os.path.join(os.path.abspath(self._workspace),"revive.log"))
# set trial seed
setup_seed(config["global_seed"])
if 'tag' in self.config: # created by tune
tag = self.config['tag']
logger.info("{} is running".format(tag))
self._traj_id = int(tag.split('_')[0])
experiment_dir = os.path.join(self._workspace, 'venv_tune')
traj_name_list = sorted(os.listdir(experiment_dir),key=lambda x:x[-19:], reverse=True)
for traj_name in filter(lambda x: "ReviveLog" in x, traj_name_list):
if len(traj_name.split('_')[1]) == 5: # created by random search or grid search
id_index = 3
else:
id_index = 2
if int(traj_name.split('_')[id_index]) == self._traj_id:
self._traj_dir = os.path.join(experiment_dir, traj_name)
break
else:
self._traj_id = 1
self._traj_dir = os.path.join(self._workspace, 'venv_train')
update_env_vars("pattern", "venv")
# setup constant
self._stop_flag = False
self._batch_cnt = 0
self._epoch_cnt = 0
self._last_wdist_epoch = 0
self._wdist_id_train = None
self._wdist_id_val = None
self._wdist_ready_train = False
self._wdist_ready_val = False
self._wdist_ready = False
self._use_gpu = self.config["use_gpu"] and torch.cuda.is_available()
self._device = 'cuda' if self._use_gpu else 'cpu' # fix problem introduced in ray 1.1
# if "shooting_" in self.config['venv_metric']:
# self.config['venv_metric'] = self.config['venv_metric'].replace("shooting_", "rollout_")
if self.config['venv_metric'] in ["mae","mse","nll"]:
self.config['venv_metric'] = "rollout_" + self.config['venv_metric']
# prepare for training
self.graph_train = deepcopy(self._graph)
self.graph_val = deepcopy(self._graph)
self._setup_componects(config)
# register models to graph
self._register_models_to_graph(self.graph_train, self.nodes_models_train)
self._register_models_to_graph(self.graph_val, self.nodes_models_val)
self.total_dim = 0
for node_name in self._graph.metric_nodes:
self.total_dim += self.config['total_dims'][node_name]['input']
self.nodes_dim_name_map = {}
for node_name in list(self._graph.nodes) + list(self._graph.leaf):
node_dims = []
for node_dim in self._graph.descriptions[node_name]:
node_dims.append(list(node_dim.keys())[0])
self.nodes_dim_name_map[node_name] = node_dims
logger.info(f"Nodes : {self.nodes_dim_name_map}")
self.graph_to_save_train = self.graph_train
self.graph_to_save_val = self.graph_val
self.best_graph_train = deepcopy(self.graph_train)
self.best_graph_val = deepcopy(self.graph_val)
if self._traj_dir.endswith("venv_train"):
if "venv.pkl" in os.listdir(self._traj_dir):
logger.info("Find existing checkpoint, Back up existing model.")
try:
self._traj_dir_bak = self._traj_dir+"_bak"
self._filename_bak = self._filename+"_bak"
shutil.copytree(self._traj_dir, self._traj_dir_bak )
shutil.copy(self._filename, self._filename_bak)
except:
pass
self._save_models(self._traj_dir)
self._update_metric()
self.global_step = 0
def nan_in_grad(self):
if hasattr(self, "nums_nan_in_grad"):
self.nums_nan_in_grad += 1
else:
self.nums_nan_in_grad = 1
if self.nums_nan_in_grad > 100:
self._stop_flag = True
logger.warning('Detected too many NaNs in gradients. Early stopping.')
def _early_stop(self, info : dict):
info["stop_flag"] = self._stop_flag
return info
def _update_metric(self):
''' Update the metric in the data buffer. '''
# NOTE: Very large model (e.g. 1024 x 4 LSTM) cannot be updated.
try:
venv_train = torch.load(os.path.join(self._traj_dir, 'venv_train.pt'), map_location='cpu')
venv_val = torch.load(os.path.join(self._traj_dir, 'venv_val.pt'), map_location='cpu')
except:
venv_train = venv_val = None
try:
metric = self.least_metric
except:
metric = np.sum(self._least_metric_train) / self.total_dim
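# Map the metric onto an accuracy-like score: mse/mae metrics are log-scaled first,
# the result is rescaled by the configured min/max distance into [0, 1], shifted into
# [0.5, 0.9], and reported as 0 when it sits exactly at the floor.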
if 'mse' in self.metric_name or 'mae' in self.metric_name:
acc = (self.config['max_distance'] - np.log(metric)) / (self.config['max_distance'] - self.config['min_distance'])
else:
acc = (self.config['max_distance'] - metric) / (self.config['max_distance'] - self.config['min_distance'])
acc = min(max(acc,0),1)
acc = 0.5 + (0.4*acc)
if acc == 0.5:
acc = 0
self._acc = acc
self._data_buffer.update_metric.remote(self._traj_id, {
"metric": metric,
"acc" : acc,
"ip": self._ip,
"venv_train" : venv_train,
"venv_val" : venv_val,
"traj_dir" : self._traj_dir,
})
self._data_buffer.write.remote(self._filename)
if metric == ray.get(self._data_buffer.get_least_metric.remote()):
self._save_models(self._workspace, with_env=False)
def _save_models(self, path: str, with_env:bool=True, model_prefixes:str=""):
"""
param: path, where to save the models
param: with_env, whether to save the venv along with the models
param: model_prefixes, optional prefix added to the saved venv file names
"""
if model_prefixes:
model_prefixes = model_prefixes + "_"
best_graph_train = deepcopy(self.graph_train)
best_graph_val = deepcopy(self.graph_val)
best_graph_train.reset()
for node_name in best_graph_train.keys():
node = best_graph_train.get_node(node_name)
if node.node_type == 'network':
network = deepcopy(node.get_network()).cpu()
torch.save(network, os.path.join(path, node_name + '_train.pt'))
best_graph_val.reset()
for node_name in best_graph_val.keys():
node = best_graph_val.get_node(node_name)
if node.node_type == 'network':
network = deepcopy(node.get_network()).cpu()
torch.save(network, os.path.join(path, node_name + '_val.pt'))
if with_env:
best_graph_train = deepcopy(best_graph_train).to("cpu")
venv_train = VirtualEnvDev(best_graph_train)
torch.save(venv_train, os.path.join(path, "venv_train.pt"))
best_graph_val = deepcopy(best_graph_val).to("cpu")
venv_val = VirtualEnvDev(best_graph_val)
torch.save(venv_val, os.path.join(path, "venv_val.pt"))
venv = VirtualEnv([venv_train, venv_val])
with open(os.path.join(path, model_prefixes + 'venv.pkl'), 'wb') as f:
pickle.dump(venv, f)
venv_list = ray.get(self._data_buffer.get_best_venv.remote())
with open(os.path.join(path, model_prefixes +'ensemble_env.pkl'), 'wb') as f:
pickle.dump(venv_list, f)
def _load_best_models(self):
best_graph_train = deepcopy(self.graph_train)
best_graph_val = deepcopy(self.graph_val)
for node_name in best_graph_train.keys():
best_node = best_graph_train.get_node(node_name)
current_node = self.graph_train.get_node(node_name)
if best_node.node_type == 'network':
current_node.get_network().load_state_dict(best_node.get_network().state_dict())
for node_name in best_graph_val.keys():
best_node = best_graph_val.get_node(node_name)
current_node = self.graph_val.get_node(node_name)
if best_node.node_type == 'network':
current_node.get_network().load_state_dict(best_node.get_network().state_dict())
@torch.no_grad()
def _log_histogram(self, expert_data, generated_data, scope='valEnv_on_trainData'):
if scope == 'valEnv_on_trainData':
graph = self.graph_val
else:
graph = self.graph_train
info = {}
for node_name in graph.keys():
# compute values
if graph.get_node(node_name).node_type == 'network':
node_dist = graph.compute_node(node_name, expert_data)
expert_action = expert_data[node_name]
generated_action = node_dist.sample()
policy_std = node_dist.std
else:
expert_action = expert_data[node_name]
generated_action = graph.compute_node(node_name, expert_data)
policy_std = torch.zeros_like(generated_action)
# make logs
error = expert_action - generated_action
expert_action = expert_action.cpu()
generated_action = generated_action.cpu()
error = error.cpu()
policy_std = policy_std.cpu()
for i in range(error.shape[-1]):
info[f'{node_name}_dim{i}_{scope}/error'] = error.select(-1, i)
info[f'{node_name}_dim{i}_{scope}/expert'] = expert_action.select(-1, i)
info[f'{node_name}_dim{i}_{scope}/sampled'] = generated_action.select(-1, i)
info[f'{node_name}_dim{i}_{scope}/sampled_std'] = policy_std.select(-1, i)
return info
def _env_test(self, scope : str = 'train'):
# pick target policy
graph = self.graph_train if scope == 'train' else self.graph_val
node = graph.get_node(self.config['target_policy_name'])
env = create_env(self.config['task'])
if env is None:
return {}
graph.reset()
node = deepcopy(node)
node = node.to('cpu')
policy = PolicyModelDev(node)
policy = PolicyModel(policy)
reward, length = test_on_real_env(env, policy)
return {
f"{self.NAME}/real_reward_{scope}" : reward,
f"{self.NAME}/real_length_{scope}" : length,
}
def _mse_test(self, expert_data, generated_data, scope='valEnv_on_trainData'):
info = {}
if not self.config['mse_test'] and not 'mse' in self.metric_name:
return info
if 'mse' in self.metric_name:
self.graph_to_save_train = self.graph_train
self.graph_to_save_val = self.graph_val
if scope == 'valEnv_on_trainData':
graph = self.graph_val
else:
graph = self.graph_train
graph.reset()
new_data = Batch({name : expert_data[name] for name in graph.leaf})
total_mse = 0
for node_name in graph.keys():
if node_name + "_isnan_index_" in expert_data.keys():
isnan_index = 1 - torch.mean(expert_data[node_name + "_isnan_index_"])
else:
isnan_index = None
if graph.get_node(node_name).node_type == 'network':
node_dist = graph.compute_node(node_name, new_data)
new_data[node_name] = node_dist.mode
else:
new_data[node_name] = graph.compute_node(node_name, new_data)
continue
if node_name in graph.metric_nodes:
if isnan_index is not None:
node_mse = (((new_data[node_name] - expert_data[node_name])*isnan_index) ** 2).sum(dim=-1).mean()
else:
node_mse = ((new_data[node_name] - expert_data[node_name]) ** 2).sum(dim=-1).mean()
total_mse += node_mse.item()
info[f"{self.NAME}/{node_name}_one_step_mse_{scope}"] = node_mse.item()
info[f"{self.NAME}/average_one_step_mse_{scope}"] = total_mse / self.total_dim
mse_error = 0
for node_name in graph.metric_nodes:
if node_name + "_isnan_index_" in expert_data.keys():
isnan_index = 1 - torch.mean(expert_data[node_name + "_isnan_index_"])
else:
isnan_index = None
if isnan_index is not None:
policy_rollout_mse = (((expert_data[node_name] - generated_data[node_name])*isnan_index) ** 2).sum(dim=-1).mean()
else:
policy_rollout_mse = ((expert_data[node_name] - generated_data[node_name]) ** 2).sum(dim=-1).mean()
mse_error += policy_rollout_mse.item()
info[f"{self.NAME}/{node_name}_rollout_mse_{scope}"] = policy_rollout_mse.item()
info[f"{self.NAME}/average_rollout_mse_{scope}"] = mse_error / self.total_dim
return info
def _nll_test(self, expert_data, generated_data, scope='valEnv_on_trainData'): # negative log likelihood
info = {}
if not self.config['nll_test'] and not 'nll' in self.metric_name:
return info
if 'nll' in self.metric_name:
self.graph_to_save_train = self.graph_train
self.graph_to_save_val = self.graph_val
if scope == 'valEnv_on_trainData':
graph = self.graph_val
else:
graph = self.graph_train
graph.reset()
new_data = Batch({name : expert_data[name] for name in graph.leaf})
total_nll = 0
for node_name in graph.keys():
if node_name + "_isnan_index_" in expert_data.keys():
isnan_index = 1 - torch.mean(expert_data[node_name + "_isnan_index_"],axis=-1)
else:
isnan_index = None
if node_name in graph.learnable_node_names and node_name in graph.metric_nodes:
node_dist = graph.compute_node(node_name, new_data)
new_data[node_name] = node_dist.mode
else:
new_data[node_name] = expert_data[node_name]
continue
if isnan_index is not None:
node_nll = - (node_dist.log_prob(expert_data[node_name])*isnan_index).mean()
else:
node_nll = - node_dist.log_prob(expert_data[node_name]).mean()
total_nll += node_nll.item()
info[f"{self.NAME}/{node_name}_one_step_nll_{scope}"] = node_nll.item()
info[f"{self.NAME}/average_one_step_nll_{scope}"] = total_nll / self.total_dim
total_nll = 0
for node_name in graph.metric_nodes:
if node_name + "_isnan_index_" in expert_data.keys():
isnan_index = 1 - expert_data[node_name + "_isnan_index_"]
else:
isnan_index = None
if node_name in graph.learnable_node_names:
node_dist = graph.compute_node(node_name, expert_data)
if isnan_index is not None:
policy_nll = - (node_dist.log_prob(expert_data[node_name]*isnan_index)).mean()
else:
policy_nll = - node_dist.log_prob(expert_data[node_name]).mean()
total_nll += policy_nll.item()
info[f"{self.NAME}/{node_name}_rollout_nll_{scope}"] = policy_nll.item()
else:
total_nll += 0
info[f"{self.NAME}/{node_name}_rollout_nll_{scope}"] = 0
info[f"{self.NAME}/average_rollout_nll_{scope}"] = total_nll / self.total_dim
return info
def _mae_test(self, expert_data, generated_data, scope='valEnv_on_trainData'):
info = {}
if not self.config['mae_test'] and not 'mae' in self.metric_name:
return info
if 'mae' in self.metric_name:
self.graph_to_save_train = self.graph_train
self.graph_to_save_val = self.graph_val
if scope == 'valEnv_on_trainData':
graph = self.graph_val
else:
graph = self.graph_train
graph.reset()
new_data = Batch({name : expert_data[name] for name in graph.leaf})
total_mae = 0
for node_name in graph.keys():
if node_name + "_isnan_index_" in expert_data.keys():
isnan_index = 1 - torch.mean(expert_data[node_name + "_isnan_index_"])
else:
isnan_index = None
if graph.get_node(node_name).node_type == 'network':
node_dist = graph.compute_node(node_name, new_data)
new_data[node_name] = node_dist.mode
else:
new_data[node_name] = graph.compute_node(node_name, new_data)
continue
if node_name in graph.metric_nodes:
if isnan_index is not None:
node_mae = ((new_data[node_name] - expert_data[node_name])*isnan_index).abs().sum(dim=-1).mean()
else:
node_mae = (new_data[node_name] - expert_data[node_name]).abs().sum(dim=-1).mean()
total_mae += node_mae.item()
info[f"{self.NAME}/{node_name}_one_step_mae_{scope}"] = node_mae.item()
info[f"{self.NAME}/average_one_step_mae_{scope}"] = total_mae / self.total_dim
mae_error = 0
for node_name in graph.keys():
if node_name in graph.metric_nodes:
if node_name + "_isnan_index_" in expert_data.keys():
isnan_index = 1 - torch.mean(expert_data[node_name + "_isnan_index_"])
else:
isnan_index = None
if isnan_index is not None:
policy_shooting_error = (torch.abs(expert_data[node_name] - generated_data[node_name])*isnan_index).sum(dim=-1).mean()
else:
policy_shooting_error = torch.abs(expert_data[node_name] - generated_data[node_name]).sum(dim=-1).mean()
mae_error += policy_shooting_error.item()
info[f"{self.NAME}/{node_name}_rollout_mae_{scope}"] = policy_shooting_error.item()
# TODO: plot rollout error
# rollout_error = torch.abs(expert_data[node_name] - generated_data[node_name]).reshape(expert_data.shape[0],-1).mean(dim=-1)
info[f"{self.NAME}/average_rollout_mae_{scope}"] = mae_error / self.total_dim
return info
def _wdist_test(self, expert_data, generated_data, scope='valEnv_on_trainData'):
info = {}
if not self.config['wdist_test'] and not 'wdist' in self.metric_name:
return info
if 'wdist' in self.metric_name:
self.graph_to_save_train = self.graph_train
self.graph_to_save_val = self.graph_val
if scope == 'valEnv_on_trainData':
graph = self.graph_val
else:
graph = self.graph_train
graph.reset()
wdist_error = []
for node_name in graph.keys():
if node_name in graph.metric_nodes:
# TODO: support isnan_index
node_dim = expert_data[node_name].shape[-1]
if node_name + "_isnan_index_" in expert_data.keys():
isnan_index = 1 - torch.mean(expert_data[node_name + "_isnan_index_"],axis=-1)
else:
isnan_index = None
wdist = [stats.wasserstein_distance(expert_data[node_name].reshape(-1, expert_data[node_name].shape[-1])[..., dim].cpu().numpy(),
generated_data[node_name].reshape(-1, generated_data[node_name].shape[-1])[..., dim].cpu().numpy())
for dim in range(node_dim)]
wdist = np.sum(wdist)
info[f"{self.NAME}/{node_name}_wdist_{scope}"] = wdist
wdist_error.append(wdist)
info[f"{self.NAME}/average_wdist_{scope}"] = np.sum(wdist_error) / self.total_dim
return info
@catch_error
def before_train_epoch(self):
update_env_vars("venv_epoch",self._epoch_cnt)
@catch_error
def train_epoch(self):
r"""Define the training process for one epoch."""
info = dict()
self._epoch_cnt += 1
logger.info(f"Train epoch : {self._epoch_cnt} ")
# switch to training mode
if hasattr(self, "model"):
self.model.train()
if hasattr(self, "models"):
for _model in self.models:
_model.train()
metric_meters_train = AverageMeterCollection()
for batch_idx, batch in enumerate(iter(self._train_loader_train)):
batch_info = {
"batch_idx": batch_idx,
"global_step": self.global_step
}
batch_info.update(info)
metrics = self.train_batch(batch, batch_info=batch_info, scope='train')
metric_meters_train.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
self.global_step += 1
metric_meters_val = AverageMeterCollection()
for batch_idx, batch in enumerate(iter(self._val_loader_train)):
batch_info = {
"batch_idx": batch_idx,
"global_step": self.global_step
}
batch_info.update(info)
metrics = self.train_batch(batch, batch_info=batch_info, scope='val')
metric_meters_val.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
self.global_step += 1
info = metric_meters_train.summary()
info.update(metric_meters_val.summary())
return {k : info[k] for k in filter(lambda k: not k.startswith('last'), info.keys())}
@catch_error
def validate(self):
logger.info(f"Epoch : {self._epoch_cnt} ")
info = dict()
r"""Define the validate process after train one epoch."""
if hasattr(self, "model"):
self.model.eval()
if hasattr(self, "models"):
for _model in self.models:
_model.eval()
self.logged_histogram = {
'valEnv_on_trainData' : False,
'trainEnv_on_valData' : False,
}
with torch.no_grad():
metric_meters_train = AverageMeterCollection()
for batch_idx, batch in enumerate(iter(self._train_loader_val)):
batch_info = {"batch_idx": batch_idx}
batch_info.update(info)
metrics = self.validate_batch(batch, batch_info, 'valEnv_on_trainData')
metric_meters_train.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
if batch_idx > 128:
break
metric_meters_val = AverageMeterCollection()
for batch_idx, batch in enumerate(iter(self._val_loader_val)):
batch_info = {"batch_idx": batch_idx}
batch_info.update(info)
metrics = self.validate_batch(batch, batch_info, 'trainEnv_on_valData')
metric_meters_val.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
if batch_idx > 128:
break
info = metric_meters_train.summary()
info.update(metric_meters_val.summary())
info = {k : info[k] for k in filter(lambda k: not k.startswith('last'), info.keys())}
# run test on real environment
if (not self.config['real_env_test_frequency'] == 0) and \
self._epoch_cnt % self.config['real_env_test_frequency'] == 0:
info.update(self._env_test('train'))
info.update(self._env_test('val'))
if self._epoch_cnt >= self.config['save_start_epoch']:
need_update = []
for i, node_name in enumerate(self.graph_train.metric_nodes):
i = self.graph_train.metric_nodes.index(node_name)
if info[f"{self.NAME}/{node_name}_{self.config['venv_metric']}_trainEnv_on_valData"] < self._least_metric_train[i]:
self._least_metric_train[i] = info[f"{self.NAME}/{node_name}_{self.config['venv_metric']}_trainEnv_on_valData"]
# [ SIGNIFICANT OTHER ] deepcopy is necessary; otherwise the best graph is always identical to the current one!
self.best_graph_train.nodes[node_name] = deepcopy(self.graph_to_save_train.get_node(node_name))
need_update.append(True)
else:
need_update.append(False)
for i, node_name in enumerate(self.graph_val.metric_nodes):
i = self.graph_val.metric_nodes.index(node_name)
if info[f"{self.NAME}/{node_name}_{self.config['venv_metric']}_valEnv_on_trainData"] < self._least_metric_val[i]:
self._least_metric_val[i] = info[f"{self.NAME}/{node_name}_{self.config['venv_metric']}_valEnv_on_trainData"]
# [ SIGNIFICANT OTHER ] deepcopy is necessary; otherwise the best graph is always identical to the current one!
self.best_graph_val.nodes[node_name] = deepcopy(self.graph_to_save_val.get_node(node_name))
need_update.append(True)
else:
need_update.append(False)
if self.config["save_by_node"]:
info["least_metric"] = np.sum(self._least_metric_train) / self.total_dim
self.least_metric = info["least_metric"]
if True in need_update:
self._save_models(self._traj_dir)
self._update_metric()
info["stop_flag"] = self._stop_flag
info['now_metric'] = info[f'{self.metric_name}_trainEnv_on_valData']
if not self.config["save_by_node"]:
now_val_metric = info[f'{self.metric_name}_valEnv_on_trainData']
if self._epoch_cnt >= self.config['save_start_epoch']:
#if (info["now_metric"] <= self.least_train_metric or now_val_metric <= self.least_val_metric):
if info["now_metric"] <= self.least_train_metric:
self._save_models(self._traj_dir)
self._update_metric()
if info["now_metric"] <= self.least_train_metric:
self.least_train_metric = info["now_metric"]
if now_val_metric <= self.least_val_metric:
self.least_val_metric = now_val_metric
info["least_metric"] = self.least_train_metric
self.least_metric = self.least_train_metric
if self.config["venv_save_frequency"]:
if self._epoch_cnt % self.config["venv_save_frequency"] == 0:
self._save_models(self._traj_dir, model_prefixes = str(self._epoch_cnt))
# Keep the best k venvs for every task
_k = 1
while self._epoch_cnt % _k == 0:
if len(self._num_venv_list) < self.config['num_venv_store']:
self._num_venv_list.append(info["now_metric"])
pass
elif info["now_metric"] < np.max(self._num_venv_list):
_del_index = np.argmax(self._num_venv_list)
self._data_buffer.delet_deque_item.remote(self._traj_id, _del_index)
self._num_venv_list.pop(_del_index)
self._num_venv_list.append(info["now_metric"])
pass
else:
break
_best_graph_train = deepcopy(self.graph_train).to("cpu")
_best_graph_val = deepcopy(self.graph_val).to("cpu")
_best_graph_train.reset()
_best_graph_val.reset()
_venv_train = VirtualEnvDev(_best_graph_train)
_venv_val = VirtualEnvDev(_best_graph_val)
self._data_buffer.update_venv_deque_dict.remote(self._traj_id, _venv_train, _venv_val)
break
for k in list(info.keys()):
if self.NAME in k:
v = info.pop(k)
info['VAL_' + k] = v
# plot histograms when training is finished
# [ OTHER ] more frequent evaluation
if self._stop_flag or self._epoch_cnt % self.config["rollout_plt_frequency"] == 0:
if self.config["rollout_plt_frequency"] > 0:
histogram_path = os.path.join(self._traj_dir, 'histogram')
if not os.path.exists(histogram_path):
os.makedirs(histogram_path)
try:
save_histogram(histogram_path, self.best_graph_train, self._train_loader_val, device=self._device, scope='train')
save_histogram(histogram_path, self.best_graph_val, self._val_loader_val, device=self._device, scope='val')
if self.config["rollout_dataset_mode"] == "validate":
rollout_dataset = self.val_dataset
else:
rollout_dataset = self.train_dataset
# save rollout action images
rollout_save_path = os.path.join(self._traj_dir, 'rollout_images')
nodes_map = deepcopy(self.nodes_dim_name_map)
# del step_node
if "step_node_" in nodes_map.keys():
nodes_map.pop("step_node_")
save_rollout_action(rollout_save_path, self.best_graph_train, self._device, rollout_dataset, deepcopy(nodes_map))
# [ OTHER ] plot the result of the current model as well as the best model
rollout_save_path = os.path.join(self._traj_dir, 'rollout_images_current')
save_rollout_action(rollout_save_path, self.graph_train, self._device, rollout_dataset, deepcopy(nodes_map))
except Exception as e:
logger.warning(e)
else:
logger.info("Don't plot images.")
if self._epoch_cnt == 1:
try:
info = self._load_checkpoint(info)
except:
logger.info("Don't Load checkpoint!")
if hasattr(self, "_traj_dir_bak") and os.path.exists(self._traj_dir_bak):
shutil.rmtree(self._traj_dir_bak)
os.remove(self._filename_bak)
if self._stop_flag:
if self.config["plt_response_curve"]:
response_curve_path = os.path.join(self._traj_dir, 'response_curve')
if not os.path.exists(response_curve_path):
os.makedirs(response_curve_path)
dataset = ray.get(self.config['dataset'])
plot_response_curve(response_curve_path, self.best_graph_train, self.best_graph_val, dataset=dataset.data, device=self._device)
try:
customer_uploadTrainLog(self.config["trainId"],
os.path.join(os.path.abspath(self._workspace),"revive.log"),
"train.simulator",
"success",
self._acc,
self.config["accessToken"])
except Exception as e:
error_message = traceback.format_exc()
logger.info('Detected error: {}, message: {}'.format(e, error_message))
return info
def _load_checkpoint(self,info):
if self._traj_dir.endswith("venv_train"):
self._load_models(self._traj_dir_bak)
with open(self._filename_bak, 'r') as f:
train_log = json.load(f)
metric = train_log["metrics"]["1"]
venv_train = torch.load(os.path.join(self._traj_dir_bak, 'venv_train.pt'), map_location='cpu')
venv_val = torch.load(os.path.join(self._traj_dir_bak, 'venv_val.pt'), map_location='cpu')
self._data_buffer.update_metric.remote(self._traj_id, {
"metric": metric["metric"],
"acc" : metric["acc"],
"ip": self._ip,
"venv_train" : venv_train,
"venv_val" : venv_val,
"traj_dir" : self._traj_dir,
})
self._data_buffer.write.remote(self._filename)
self._save_models(self._workspace, with_env=False)
return info
else:
with open(os.path.join(self._traj_dir,"params.json"), 'r') as f:
params = json.load(f)
experiment_dir = os.path.dirname(self._traj_dir)
dir_name_list = [dir_name for dir_name in os.listdir(experiment_dir) if dir_name.startswith("ReviveLog")]
dir_name_list = sorted(dir_name_list,key=lambda x:x[-19:], reverse=False)
for dir_name in dir_name_list:
dir_path = os.path.join(experiment_dir, dir_name)
if dir_path == self._traj_dir:
break
if os.path.isdir(dir_path):
params_json_path = os.path.join(dir_path, "params.json")
if os.path.exists(params_json_path):
with open(params_json_path, 'r') as f:
history_params = json.load(f)
if history_params == params:
result_json_path = os.path.join(dir_path, "result.json")
with open(result_json_path, 'r') as f:
history_result = []
for line in f.readlines():
history_result.append(json.loads(line))
if history_result:# and history_result[-1]["stop_flag"]:
for k in info.keys():
if k in history_result[-1].keys():
info[k] = history_result[-1][k]
self.least_metric = info["least_metric"]
# load model
logger.info("Find exist checkpoint, Load the model.")
self._load_models(dir_path)
self._save_models(self._traj_dir)
self._update_metric()
# check early stop
self._stop_flag = history_result[-1]["stop_flag"]
info["stop_flag"] = self._stop_flag
logger.info("Load checkpoint success!")
break
return info
def _load_models(self, path : str, with_env : bool = True):
"""
param: path, where to load the models
param: with_env, whether to load venv along with the models
"""
for node_name in self.best_graph_train.keys():
best_node = self.best_graph_train.get_node(node_name)
if best_node.node_type == 'network':
best_node.get_network().load_state_dict(torch.load(os.path.join(path, node_name + '_train.pt')).state_dict())
for node_name in self.best_graph_val.keys():
best_node = self.best_graph_val.get_node(node_name)
if best_node.node_type == 'network':
best_node.get_network().load_state_dict(torch.load(os.path.join(path, node_name + '_val.pt')).state_dict())
def train_batch(self, expert_data, batch_info, scope='train'):
r"""Define the training process for a batch of data."""
raise NotImplementedError
def validate_batch(self, expert_data, batch_info, scope='valEnv_on_trainData'):
r"""Define the validation process for a batch of data.
Args:
expert_data: The batch of offline data.
batch_info: A batch info dict.
scope: ``scope='valEnv_on_trainData'`` means the training data is evaluated on the model trained with the validation dataset.
"""
info = {}
if scope == 'valEnv_on_trainData':
graph = self.graph_val
else:
graph = self.graph_train
expert_data.to_torch(device=self._device)
traj_length = expert_data.shape[0]
# data is generated by taking the most likely action
sample_fn = lambda dist: dist.mode
generated_data = generate_rollout(expert_data, graph, traj_length, sample_fn, clip=True)
for node_name in graph.nodata_node_names:
assert node_name not in expert_data.keys()
assert node_name in generated_data.keys()
expert_data[node_name] = generated_data[node_name]
info.update(self._nll_test(expert_data, generated_data, scope))
info.update(self._mse_test(expert_data, generated_data, scope))
info.update(self._mae_test(expert_data, generated_data, scope))
info.update(self._wdist_test(expert_data, generated_data, scope))
# log histogram info
if (not self.config['histogram_log_frequency'] == 0) and \
self._epoch_cnt % self.config['histogram_log_frequency'] == 0 and \
not self.logged_histogram[scope]:
info.update(self._log_histogram(expert_data, generated_data, scope))
self.logged_histogram[scope] = True
return info
class VenvAlgorithm:
''' Class used to manage venv algorithms. '''
def __init__(self, algo : str, workspace: str =None):
self.algo = algo
self.workspace = workspace
if self.algo == "revive" or self.algo == "revive_p" or self.algo == "revivep" or self.algo == "revive_ppo":
self.algo = "revive_p"
elif self.algo == "revive_t" or self.algo == "revivet" or self.algo == "revive_td3":
self.algo = "revivet"
elif self.algo == "bc":
self.algo = "bc"
else:
raise NotImplementedError
try:
self.algo_module = importlib.import_module(f'revive.dist.algo.venv.{self.algo}')
logger.info(f"Import encryption venv algorithm module -> {self.algo}!")
except:
self.algo_module = importlib.import_module(f'revive.algo.venv.{self.algo}')
logger.info(f"Import venv algorithm module -> {self.algo}!")
# Assume there is only one operator other than VenvOperator
for k in self.algo_module.__dir__():
if 'Operator' in k and not k == 'VenvOperator':
self.operator = getattr(self.algo_module, k)
self.operator_config = {}
def get_train_func(self, config):
from revive.utils.tune_utils import VALID_SUMMARY_TYPES
from torch.utils.tensorboard import SummaryWriter
from ray.air import session
def train_func(config):
for k,v in self.operator_config.items():
if not k in config.keys():
config[k] = v
algo_operator = self.operator(config)
writer = SummaryWriter(algo_operator._traj_dir)
epoch = 0
while True:
algo_operator.before_train_epoch()
train_stats = algo_operator.train_epoch()
val_stats = algo_operator.validate()
session.report({"mean_accuracy": algo_operator._acc,
"least_metric": val_stats["least_metric"],
"now_metric": val_stats["now_metric"],
"stop_flag": val_stats["stop_flag"]})
# write tensorboard
for k, v in [*train_stats.items(), *val_stats.items()]:
if type(v) in VALID_SUMMARY_TYPES:
writer.add_scalar(k, v, global_step=epoch)
elif isinstance(v, torch.Tensor):
v = v.view(-1)
writer.add_histogram(k, v, global_step=epoch)
writer.flush()
# check stop_flag
train_stats.update(val_stats)
if train_stats.get('stop_flag', False):
break
epoch += 1
return train_func
def get_trainer(self, config):
try:
train_func = self.get_train_func(config)
from ray.air.config import ScalingConfig
trainer = TorchTrainer(
train_func,
train_loop_config=config,
scaling_config=ScalingConfig(num_workers=config['workers_per_trial'], use_gpu=config['use_gpu']),
)
return trainer
except Exception as e:
logger.error('Detect Error: {}'.format(e))
raise e
def get_trainable(self, config):
try:
trainer = self.get_trainer(config)
trainable = trainer.as_trainable()
return trainable
except Exception as e:
logger.error('Detect Error: {}'.format(e))
raise e
def get_parameters(self, command=None):
try:
return self.operator.get_parameters(command)
except AttributeError:
raise AttributeError("Custom algorithm need to implement `get_parameters`")
except Exception as e:
logger.error('Detected unknown error: {}'.format(e))
raise e
def get_tune_parameters(self, config):
try:
self.operator_config = config
return self.operator.get_tune_parameters(config)
except AttributeError:
raise AttributeError("Custom algorithm need to implement `get_tune_parameters`")
except Exception as e:
logger.error('Detected unknown error: {}'.format(e))
raise e
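# Illustrative end-to-end usage. This is a hedged sketch: the `config` dict shown here is a
# placeholder assembled elsewhere by the REVIVE training pipeline, and passing `tune_params`
# straight to `tune.run` is an assumption about how the returned dict is consumed.
#
# config = {...}  # dataset handles, workspace, venv_* options, etc.
# algo = VenvAlgorithm('revive_p', workspace=config['workspace'])
# config.update(algo.get_parameters())            # defaults from PARAMETER_DESCRIPTION
# tune_params = algo.get_tune_parameters(config)  # ray.tune search space + CLI reporter
# trainable = algo.get_trainable(config)          # TorchTrainer wrapped as a tune Trainable
# tune.run(trainable, **tune_params)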