''''''
"""
POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
import os
import ray
import json
import math
import psutil
import torch
import warnings
import argparse
import importlib
import traceback
import numpy as np
from ray import tune
from loguru import logger
from copy import deepcopy
from abc import abstractmethod
from concurrent import futures
from ray import train
from ray.train.torch import TorchTrainer
from revive.utils.raysgd_utils import NUM_SAMPLES, AverageMeterCollection
from revive.computation.inference import *
from revive.computation.graph import FunctionDecisionNode
from revive.data.batch import Batch
from revive.data.dataset import data_creator
from ray.train.torch import TorchTrainer
from torch.utils.tensorboard import SummaryWriter
from ray.train import ScalingConfig
from ray.air import session
from revive.utils.common_utils import *
from revive.utils.tune_utils import get_tune_callbacks, CustomSearchGenerator, CustomBasicVariantGenerator, CLIReporter, VALID_SUMMARY_TYPES
from revive.utils.auth_utils import customer_uploadTrainLog
from revive.computation.modules import EnsembleLinear
warnings.filterwarnings('ignore')
[docs]
def catch_error(func):
    '''Decorator that pushes a training error message to the data buffer.

    The wrapped method is called unchanged. If it raises, the traceback is logged,
    the trial is marked as ``'error'`` in ``self._data_buffer``, ``self._stop_flag``
    is set, an upload of the training log is attempted (best effort) and a result
    dict asking the trial to stop is returned instead of re-raising.

    Args:
        :func: the method to protect; its first positional argument is the instance.

    Return:
        the wrapped method, carrying the name and docstring of ``func``
    '''
    import functools

    @functools.wraps(func)
    def wrapped_func(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except Exception as e:
            error_message = traceback.format_exc()
            # `loguru_logger` is only assigned part-way through `__init__`, which is
            # itself decorated; fall back to the module logger if it is not there yet.
            active_logger = getattr(self, 'loguru_logger', logger)
            active_logger.error('Detect error:{}, Error Message: {}'.format(e,error_message))
            ray.get(self._data_buffer.update_status.remote(self._traj_id, 'error', error_message))
            self._stop_flag = True
            try:
                customer_uploadTrainLog(self.config["trainId"],
                                        os.path.join(os.path.abspath(self._workspace),"revive.log"),
                                        "train.policy",
                                        "fail",
                                        self._max_reward,
                                        self.config["accessToken"])
            except Exception as upload_error:
                # Uploading the log is best effort; never mask the original failure.
                logger.info(f"{upload_error}")

            return {
                'stop_flag' : True,
                'reward_trainPolicy_on_valEnv' : - np.inf,
            }
    return wrapped_func
[docs]
class PolicyOperator():
@property
def env(self):
return self.envs_train[0]
@property
def train_policy(self):
return self.train_models[:len(self.policy_name)]
@property
def val_policy(self):
return self.val_models[:len(self.policy_name)]
@property
def policy(self):
return self.train_policy
@property
def other_train_models(self):
return self.train_models[len(self.policy_name):]
@property
def other_val_models(self):
return self.val_models[len(self.policy_name):]
# NOTE: subclasses must either fill in `PARAMETER_DESCRIPTION` or override `get_parameters` and `get_tune_parameters`.
PARAMETER_DESCRIPTION = [] # a list of dicts, each describing one parameter of the algorithm
[docs]
@classmethod
def get_parameters(cls, command=None, **kargs):
    r"""
    Parse the algorithm parameters declared in `PARAMETER_DESCRIPTION`.

    Args:
        :command: argument list handed to the parser's `parse_known_args`
        :kargs: accepted but not used here

    Return:
        the attribute dict of the namespace holding the recognised arguments
    """
    description_parser = list2parser(cls.PARAMETER_DESCRIPTION)
    parse_result = description_parser.parse_known_args(command)
    recognised = parse_result[0]
    return vars(recognised)
# @classmethod
# def get_parameters(cls, command=None, **kargs):
# parser = argparse.ArgumentParser()
# for description in cls.PARAMETER_DESCRIPTION:
# names = ['--' + description['name']]
# if "abbreviation" in description.keys(): names.append('-' + description["abbreviation"])
# if type(description['type']) is str:
# try:
# description['type'] = eval(description['type'])
# except Exception as e:
# logger.error(f"{e}")
# logger.error(f"Please check the .json file for the input type {description['type']} which is not supported!")
# logger.error('Detect Error: {}'.format(e))
# raise e
# parser.add_argument(*names, type=description['type'], default=description['default'])
# return parser.parse_known_args(command)[0].__dict__
[docs]
@classmethod
def get_tune_parameters(cls, config : Dict[str, Any], **kargs):
    r"""
    Use ray.tune to wrap the parameters to be searched.

    Only entries of `PARAMETER_DESCRIPTION` that carry both `search_mode` and
    `search_values` take part in the search. `config['policy_search_algo']`
    selects the strategy: `random`, `zoopt`, anything containing `grid`, or
    anything containing `bayes`.

    Args:
        :config: configuration parameters; `total_num_of_trials` is written into it,
            and `parallel_num` as well when the zoopt strategy is selected
        :kargs: accepted but not used here

    Return:
        a dict of settings including the search space, search algorithm,
        number of samples and progress reporter

    Raises:
        ValueError: the requested search algorithm is not supported
    """
    _search_algo = config['policy_search_algo'].lower()

    def searchable_descriptions():
        # Lazily yield the descriptions that declare how they should be searched.
        for candidate in cls.PARAMETER_DESCRIPTION:
            if 'search_mode' in candidate.keys() and 'search_values' in candidate.keys():
                yield candidate

    def skipped_message(candidate, search_label):
        # Text of the warning emitted when a strategy cannot handle a search mode.
        return (
            f"Detect parameter {candidate['name']} is define as searchable in `PARAMETER_DESCRIPTION`. "
            f"However, since {search_label} search does not support search type {candidate['search_mode']}, the parameter is skipped. "
            f"If this parameter is important to the performance, you should consider other search algorithms. "
        )

    tune_params = {
        "name": "policy_tune",
        "reuse_actors": config["reuse_actors"],
        "local_dir": config["workspace"],
        "callbacks": get_tune_callbacks(),
        "stop": {
            "stop_flag": True
        },
        "verbose": config["verbose"],
    }

    if _search_algo == 'random':
        random_search_config = {}
        for description in searchable_descriptions():
            search_mode = description['search_mode']
            if search_mode == 'continuous':
                sampler = tune.uniform(*description['search_values'])
            elif search_mode in ('grid', 'discrete'):
                sampler = tune.choice(description['search_values'])
            else:
                continue
            random_search_config[description['name']] = sampler

        config["total_num_of_trials"] = config['train_policy_trials']
        tune_params['config'] = random_search_config
        tune_params['num_samples'] = config["total_num_of_trials"]
        tune_params['search_alg'] = CustomBasicVariantGenerator()

    elif _search_algo == 'zoopt':
        #from ray.tune.search.zoopt import ZOOptSearch
        from revive.utils.tune_utils import ZOOptSearch
        from zoopt import ValueType

        # Work out how many trials may run in parallel.
        if config['parallel_num'] == 'auto':
            if config['use_gpu']:
                num_of_gpu = int(ray.available_resources()['GPU'])
                num_of_cpu = int(ray.available_resources()['CPU'])
                parallel_num = min(int(num_of_gpu / config['policy_gpus_per_worker']), num_of_cpu)
            else:
                parallel_num = int(ray.available_resources()['CPU'])
        else:
            parallel_num = int(config['parallel_num'])
        assert parallel_num > 0
        config['parallel_num'] = parallel_num

        non_continuous_types = {
            'discrete': ValueType.DISCRETE,
            'grid': ValueType.GRID,
        }
        dim_dict = {}
        for description in searchable_descriptions():
            search_mode = description['search_mode']
            if search_mode == 'continuous':
                dim_dict[description['name']] = (ValueType.CONTINUOUS, description['search_values'], min(description['search_values']))
            elif search_mode in non_continuous_types:
                dim_dict[description['name']] = (non_continuous_types[search_mode], description['search_values'])

        config["total_num_of_trials"] = config['train_policy_trials']
        tune_params['search_alg'] = ZOOptSearch(
            algo="Asracos",  # only support Asracos currently
            budget=config["total_num_of_trials"],
            dim_dict=dim_dict,
            metric='reward_trainPolicy_on_valEnv',
            mode="max",
            parallel_num=config['parallel_num'],
        )
        tune_params['config'] = dim_dict
        tune_params['search_alg'] = CustomSearchGenerator(tune_params['search_alg'])  # wrap with our generator
        tune_params['num_samples'] = config["total_num_of_trials"]

    elif 'grid' in _search_algo:
        grid_search_config = {}
        config["total_num_of_trials"] = 1
        for description in searchable_descriptions():
            if description['search_mode'] != 'grid':
                warnings.warn(skipped_message(description, 'grid'))
                continue
            grid_search_config[description['name']] = tune.grid_search(description['search_values'])
            config["total_num_of_trials"] *= len(description['search_values'])

        tune_params['config'] = grid_search_config
        tune_params['num_samples'] = config["total_num_of_trials"]
        tune_params['search_alg'] = CustomBasicVariantGenerator()

    elif 'bayes' in _search_algo:
        from ray.tune.search.bayesopt import BayesOptSearch

        bayes_search_config = {}
        for description in searchable_descriptions():
            if description['search_mode'] != 'continuous':
                warnings.warn(skipped_message(description, 'bayesian'))
                continue
            bayes_search_config[description['name']] = description['search_values']

        config["total_num_of_trials"] = config['train_policy_trials']
        tune_params['config'] = bayes_search_config
        tune_params['search_alg'] = BayesOptSearch(bayes_search_config, metric="reward_trainPolicy_on_valEnv", mode="max")
        tune_params['search_alg'] = CustomSearchGenerator(tune_params['search_alg'])  # wrap with our generator
        tune_params['num_samples'] = config["total_num_of_trials"]

    else:
        raise ValueError(f'search algorithm {_search_algo} is not supported!')

    # The progress table shows one column per searched parameter plus the tracked reward.
    searched_names = list(tune_params['config'].keys())
    reporter = CLIReporter(parameter_columns=searched_names, max_progress_rows=50, max_report_frequency=10, sort_by_metric=True)
    reporter.add_metric_column("reward_trainPolicy_on_valEnv")
    tune_params["progress_reporter"] = reporter

    return tune_params
[docs]
def optimizer_creator(self, models : List[torch.nn.Module], config : Dict[str, Any]) -> List[torch.optim.Optimizer]:
r"""
Define optimizers for the created models.
Args:
:pmodels: list of all the models
:config: configuration parameters
Return:
a list of optimizers
"""
raise NotImplementedError
[docs]
def data_creator(self):
    r"""
    Create DataLoaders.

    Delegates to the module-level `data_creator`, passing `self.config`,
    `config['test_horizon']` as the validation horizon and
    `self.double_validation` as the `double` flag.

    Return:
        whatever the module-level `data_creator` returns; `_setup_data` unpacks
        it as two loaders without double validation and four loaders with it
    """
    # raise NotImplementedError
    logger.warning('data_creator is using the test_horizon' )
    validation_horizon = self.config['test_horizon']
    return data_creator(self.config, val_horizon=validation_horizon, double=self.double_validation)
def _setup_data(self):
    ''' Setup the datasets and data loaders used in training.

    Fetches the train/val datasets from the ray object store and builds the
    data loaders through `self.data_creator()`. Without double validation only
    the `*_train` loaders are populated and the `*_val` loaders are `None`.
    '''
    self.train_dataset = ray.get(self.config['dataset'])
    self.val_dataset = ray.get(self.config['val_dataset'])

    if not self.double_validation:
        # Both loaders returned here feed the single (train) policy. The second one
        # must be bound to `val_loader_train`, the name used right below.
        train_loader_train, val_loader_train = self.data_creator()
        self._train_loader_train = train.torch.prepare_data_loader(train_loader_train, move_to_device=False)
        self._val_loader_train = train.torch.prepare_data_loader(val_loader_train, move_to_device=False)
        self._train_loader_val = None
        self._val_loader_val = None
    else:
        train_loader_train, val_loader_train, train_loader_val, val_loader_val = self.data_creator()
        try:
            self._train_loader_train = train.torch.prepare_data_loader(train_loader_train, move_to_device=False)
            self._val_loader_train = train.torch.prepare_data_loader(val_loader_train, move_to_device=False)
            self._train_loader_val = train.torch.prepare_data_loader(train_loader_val, move_to_device=False)
            self._val_loader_val = train.torch.prepare_data_loader(val_loader_val, move_to_device=False)
        except Exception:
            # Preparing the loaders failed; fall back to the raw loaders.
            self._train_loader_train = train_loader_train
            self._val_loader_train = val_loader_train
            self._train_loader_val = train_loader_val
            self._val_loader_val = val_loader_val
def _setup_models(self, config : Dict[str, Any]):
    r'''Setup the policy models and their optimizers.

    Builds the training copy of every target policy node, its models and its
    optimizers. With double validation a second, independent set is built for
    validation; otherwise every `val_*` attribute is set to `None`.

    Args:
        :config: configuration parameters (not read here; `self.config` is used
            by the creators)
    '''
    logger.info(f'Init model and move model to {self._device}')

    # Init train env models and optimizers
    self.train_nodes = {policy_name: deepcopy(self._graph.get_node(policy_name)) for policy_name in self.policy_name}
    self.train_models = self.model_creator(self.train_nodes)
    for model_index, model in enumerate(self.train_models):
        try:
            self.train_models[model_index] = train.torch.prepare_model(model,move_to_device=False).to(self._device)
        except Exception:
            # Preparing the model failed; use the plain model instead.
            self.train_models[model_index] = model.to(self._device)
    self.train_optimizers = self.optimizer_creator(scope="train")
    self.train_bc_optimizer = self.bc_optimizer_creator(scope="train")

    for train_node,policy in zip(self.train_nodes.values(), self.train_policy):
        train_node.set_network(policy)

    if self.double_validation:
        # Init val env models and optimizers
        self.val_nodes = {policy_name: deepcopy(self._graph.get_node(policy_name)) for policy_name in self.policy_name}
        self.val_models = self.model_creator(self.val_nodes)
        for model_index, model in enumerate(self.val_models):
            try:
                self.val_models[model_index] = train.torch.prepare_model(model,move_to_device=False).to(self._device)
            except Exception:
                self.val_models[model_index] = model.to(self._device)
        self.val_optimizers = self.optimizer_creator(scope="val")
        self.val_bc_optimizer = self.bc_optimizer_creator(scope="val")

        for val_node,policy in zip(self.val_nodes.values(), self.val_policy):
            val_node.set_network(policy)
    else:
        self.val_models = None
        self.val_optimizers = None
        # Only the validation side is absent here; the train BC optimizer created
        # above must stay available.
        self.val_bc_optimizer = None
[docs]
def model_creator(self, nodes):
graph = self._graph
policy_name = self.policy_name
total_dims = self.config['total_dims']
models = []
# Init policy by behavioral policy
if self.config['behavioral_policy_init'] and self.behaviour_policys:
for behaviour_policy in self.behaviour_policys:
target_policy = deepcopy(behaviour_policy)
target_policy.requires_grad_(True)
models.append(target_policy)
# Init new policy network
else:
if self.config['behavioral_policy_init'] and self.config["policy_bc_epoch"] == 0:
self.config["policy_bc_epoch"] = 200
for policy_name, node in nodes.items():
node_input_dim = get_input_dim_from_graph(graph, policy_name, total_dims)
if hasattr(node, 'custom_node'):
node_input_dim = get_input_dim_dict_from_graph(graph, policy_name, total_dims)
additional_kwargs = {}
if self.config.get('ts_conv_nodes', "auto") == 'auto':
pass
elif policy_name in self.config['ts_conv_nodes'] :
temp_input_dim = get_input_dim_dict_from_graph(graph, policy_name, total_dims)
has_ts = any('ts' in element for element in temp_input_dim)
if not has_ts:
pass
else:
input_dim, input_dim_config, ts_conv_net_config = \
dict_ts_conv(graph, temp_input_dim, total_dims, self.config,
net_hidden_features=self.config['policy_hidden_features'])
additional_kwargs['ts_conv_config'] = input_dim_config
additional_kwargs['ts_conv_net_config'] = ts_conv_net_config
logger.info(f"Policy Backbone: {self.config['policy_backbone'],}")
input_dim_dict = get_input_dim_dict_from_graph(graph, policy_name, total_dims)
node.initialize_network(node_input_dim, total_dims[policy_name]['output'],
hidden_features=self.config['policy_hidden_features'],
hidden_layers=self.config['policy_hidden_layers'],
backbone_type=self.config['policy_backbone'],
dist_config=self.config['dist_configs'][policy_name],
is_transition=False,
input_dim_dict=input_dim_dict,
**additional_kwargs)
target_policy = node.get_network()
models.append(target_policy)
return models
[docs]
def bc_optimizer_creator(self, scope):
if scope == "train":
target_policys = self.train_policy
else:
target_policys = self.val_policy
models_params = []
for target_policy in target_policys:
models_params.append({'params': target_policy.parameters(), 'lr': 1e-3})
bc_policy_optimizer = torch.optim.Adam(models_params)
return bc_policy_optimizer
@catch_error
def __init__(self, config : Dict[str, Any]):
    '''Setup everything for training.

    Parses the trial configuration, registers the trial in the data buffer,
    locates the trial directory, collects and prepares the virtual environments,
    computes the average reward of the train/val data, builds data loaders,
    models and optimizers, optionally creates the disturbing networks, saves the
    initial models and records the dimension names of every node.

    Args:
        :config: configuration parameters of this trial
    '''
    # parse information from config
    self.config = config
    self._data_buffer = self.config['policy_data_buffer']
    self._workspace = self.config["workspace"]
    self._user_func = self.config['user_func']
    self._graph = self.config['graph']
    self._processor = self._graph.processor
    self.policy_name = self.config['target_policy_name']
    if isinstance(self.policy_name, str):
        self.policy_name = [self.policy_name, ]
    # sort the policy by graph: walk the graph keys so the target policies follow
    # the order of the graph rather than the order they were listed in the config
    requested_policies = self.policy_name
    self.policy_name = [policy_name for policy_name in self._graph.keys() if policy_name in requested_policies]
    self.double_validation = self.config['policy_double_validation']
    self._filename = os.path.join(self._workspace, "train_policy.json")
    self._data_buffer.set_total_trials.remote(self.config.get("total_num_of_trials", 1))
    self._data_buffer.inc_trial.remote()
    # Set early so the error handler can always report a reward.
    self._max_reward = -np.inf

    # get id
    self._ip = ray._private.services.get_node_ip_address()
    logger.add(os.path.join(os.path.abspath(self._workspace),"revive.log"))
    self.loguru_logger = logger

    # set trail seed
    setup_seed(self.config["global_seed"])

    if 'tag' in self.config: # create by tune
        tag = self.config['tag']
        logger.info("{} is running".format(tag))
        self._traj_id = int(tag.split('_')[0])
        experiment_dir = os.path.join(self._workspace, 'policy_tune')
        for traj_name in filter(lambda x: "ReviveLog" in x, os.listdir(experiment_dir)):
            if len(traj_name.split('_')[1]) == 5: # create by random search or grid search
                id_index = 3
            else:
                id_index = 2
            if int(traj_name.split('_')[id_index]) == self._traj_id:
                self._traj_dir = os.path.join(experiment_dir, traj_name)
                break
    else:
        self._traj_id = 1
        self._traj_dir = os.path.join(self._workspace, 'policy_train')
    update_env_vars("pattern", "policy")

    # setup constant
    self._stop_flag = False
    self._batch_cnt = 0
    self._epoch_cnt = 0
    self._use_gpu = self.config["use_gpu"] and torch.cuda.is_available()
    self._device = 'cuda' if self._use_gpu else 'cpu' # fix problem introduced in ray 1.1

    # set default config
    self._speedup = self.config.get("speedup",False)

    # collect venv
    best_venv = ray.get(self.config['venv_data_buffer'].get_best_venv.remote())
    self.envs = best_venv.env_list
    for env in self.envs:
        env.to(self._device)
        env.requires_grad_(False)
        env.set_target_policy_name(self.policy_name)
        if hasattr(env, "train_algo") and env.train_algo == "REVIVE_FILTER":
            assert "filter" in self.config and "candidate_num" in self.config, "venv trained from REVIVE_FILTER requires 'filter' and 'candidate_num'"
            assert "ensemble_size" in self.config and "ensemble_choosing_interval" in self.config, "venv trained from REVIVE_FILTER requires 'ensemble_size' and 'ensemble_choosing_interval'"
            env.reset_ensemble_matcher(ensemble_size=self.config['ensemble_size'], ensemble_choosing_interval=self.config['ensemble_choosing_interval'])

    logger.info(f"Find {len(self.envs)} envs.")
    if len(self.envs) == 1:
        logger.warning('Only one venv found, use it in both training and validation!')
        # Independent lists: the shifted copies appended below must not end up in
        # both sets (and in `self.envs`) through a shared list object.
        self.envs_train = list(self.envs)
        self.envs_val = list(self.envs)
    else:
        self.envs_train = self.envs[:len(self.envs)//2]
        self.envs_val = self.envs[len(self.envs)//2:]

    #if self.config['num_venv_in_use'] > len(self.envs_train):
    #    warnings.warn(f"Config requires {self.config['num_venv_in_use']} venvs, but only {len(self.envs_train)} venvs are available.")
    if self.config['num_venv_in_use'] > len(self.envs_train):
        logger.info("Adjusting the distribution to generate multiple env models.")
        mu_shift_list = np.linspace(-0.15, 0.15, num=(self.config['num_venv_in_use']-1))
        for mu_shift in mu_shift_list:
            if mu_shift == 0:
                continue
            env_train = deepcopy(self.envs_train[0])
            for node in env_train.graph.nodes.values():
                if node.node_type == 'network':
                    node.network.dist_mu_shift = mu_shift
            self.envs_train.append(env_train)
            # The validation copy is derived from the validation venv and is the
            # object stored in the validation set.
            env_val = deepcopy(self.envs_val[0])
            for node in env_val.graph.nodes.values():
                if node.node_type == 'network':
                    node.network.dist_mu_shift = mu_shift
            self.envs_val.append(env_val)
    else:
        self.envs_train = self.envs_train[:int(self.config['num_venv_in_use'])]
        self.envs_val = self.envs_val[:int(self.config['num_venv_in_use'])]

    self.behaviour_nodes = [self.envs[0].graph.get_node(policy_name) for policy_name in self.policy_name]
    self.behaviour_policys = [behaviour_node.get_network() for behaviour_node in self.behaviour_nodes if behaviour_node.get_network()]
    self.input_is_dict = False
    for policy in self.behaviour_policys:
        if policy.__class__.__name__.startswith("Ts"):
            self.input_is_dict = True

    # find average reward from expert data
    dataset = ray.get(self.config['dataset'])
    train_data = dataset.data
    train_data.to_torch()
    train_reward = self._user_func(train_data)
    self.train_average_reward = float(train_reward.mean())

    dataset = ray.get(self.config['val_dataset'])
    val_data = dataset.data
    val_data.to_torch()
    # The validation average must come from the validation data.
    val_reward = self._user_func(val_data)
    self.val_average_reward = float(val_reward.mean())

    # initialize FQE
    self.fqe_evaluator = None

    # prepare for training
    self._setup_data()
    self._setup_models(self.config)

    self.disturbing_transition_function = self.config['disturbing_transition_function']
    # define numbers of turbu networks
    if self.disturbing_transition_function:
        self.rnd_network_nums = self.config['disturbing_net_num']
        self.disturbing_node_name = []
        self.disturbing_node = {}
        for node in self._graph.nodes.keys():
            if node in self.policy_name:
                continue
            if self._graph.nodes[node].node_type == 'network':
                # Unless set to 'auto', only the explicitly listed nodes are disturbed.
                if self.config['disturbing_nodes'] != 'auto' and node not in self.config['disturbing_nodes']:
                    continue

                self.disturbing_node_name.append(node)
                _input_nodes = self._graph[node]
                input_dims = sum([self.config['total_dims'][_node_name]['input'] for _node_name in _input_nodes])
                output_dims = self.config['total_dims'][node]['output']
                # Seed inside a forked RNG state so the global random state is untouched.
                with torch.random.fork_rng():
                    torch.manual_seed(21312)
                    _rnd_networks = torch.nn.Sequential(
                        EnsembleLinear(input_dims, 32, self.rnd_network_nums),
                        torch.nn.LeakyReLU(),
                        EnsembleLinear(32, 32, self.rnd_network_nums),
                        torch.nn.LeakyReLU(),
                        EnsembleLinear(32, output_dims // 2, self.rnd_network_nums),
                        torch.nn.Tanh()
                    )
                self.disturbing_node[node] = {'network': _rnd_networks,
                                              'input_nodes': _input_nodes,
                                              'disturb_weight': self.config['disturbing_weight']}
                logger.info(f'disturbing_transition_function for node name: {node}, input nodes: {_input_nodes}, input dims: {input_dims}, output dims: {output_dims}')

    self._save_models(self._traj_dir)
    self._update_metric()

    # prepare for ReplayBuffer
    # `setup` is an optional hook; it stays best effort, but a failure is now
    # reported instead of being swallowed silently.
    setup_hook = getattr(self, 'setup', None)
    if setup_hook is not None:
        try:
            setup_hook(self.config)
        except Exception as setup_error:
            logger.warning(f"Optional `setup` hook failed and is skipped: {setup_error}")

    self.action_dims = {}
    for policy_name in self.policy_name:
        action_dims = []
        for action_dim in self._graph.descriptions[policy_name]:
            action_dims.append(list(action_dim.keys())[0])
        self.action_dims[policy_name] = action_dims

    self.nodes_map = {}
    for node_name in list(self._graph.nodes) + list(self._graph.leaf):
        node_dims = []
        for node_dim in self._graph.descriptions[node_name]:
            node_dims.append(list(node_dim.keys())[0])
        self.nodes_map[node_name] = node_dims

    self.global_step = 0
def _early_stop(self, info : Dict[str, Any]):
info["stop_flag"] = self._stop_flag
return info
def _update_metric(self):
    """Report the current best reward of this trial to the data buffer.

    Loads `policy.pt` from the trial directory (or uses `None` when it cannot be
    loaded), pushes reward, node ip, policy and trial directory to the data
    buffer and asks it to write its state to `self._filename`. When this trial
    holds the maximum reward known to the buffer, the models are also saved to
    the workspace.
    """
    try:
        policy = torch.load(os.path.join(self._traj_dir, 'policy.pt'), map_location='cpu')
    except Exception:
        # The policy file may be missing or unreadable; report without it.
        # KeyboardInterrupt/SystemExit are deliberately not swallowed.
        policy = None

    self._data_buffer.update_metric.remote(self._traj_id, {"reward" : self._max_reward, "ip" : self._ip, "policy" : policy, "traj_dir": self._traj_dir})
    self._data_buffer.write.remote(self._filename)
    if self._max_reward == ray.get(self._data_buffer.get_max_reward.remote()):
        self._save_models(self._workspace, with_policy=False)
def _wrap_policy(self, policys : List[torch.nn.Module,], device = None) -> PolicyModelDev:
    r"""
    Attach copies of the given policies to copies of the behaviour nodes.

    Args:
        :policys: policy networks, paired with `self.behaviour_nodes` in order
        :device: when truthy, every copied node is moved to this device

    Return:
        a `PolicyModelDev` built from the copied nodes
    """
    node_copies = deepcopy(self.behaviour_nodes)
    policy_copies = deepcopy(policys)
    for node_copy, policy_copy in zip(node_copies, policy_copies):
        node_copy.set_network(policy_copy)
        if device:
            node_copy.to(device)
    return PolicyModelDev(node_copies)
def _save_models(self, path : str, with_policy : bool = True):
    r"""
    Save the current policy to disk.

    Args:
        :path: directory receiving `tuned_policy.pt` (and `policy.pt`)
        :with_policy: when True, also save the wrapped policy as `policy.pt`, keep
            the inference policy on `self.infer_policy` and pickle it to
            `policy.pkl` in the trial directory
    """
    torch.save(self.policy, os.path.join(path, "tuned_policy.pt"))
    if not with_policy:
        return

    wrapped_policy = self._wrap_policy(self.policy, "cpu")
    torch.save(wrapped_policy, os.path.join(path, "policy.pt"))

    inference_policy = PolicyModel(wrapped_policy)
    self.infer_policy = inference_policy
    # The pickle always goes to the trial directory, not to `path`.
    pickle_file = os.path.join(self._traj_dir, "policy.pkl")
    with open(pickle_file, 'wb') as f:
        pickle.dump(inference_policy, f)
def _get_original_actions(self, batch_data : Batch) -> Batch:
    r"""
    Deprocess every leaf and graph node of `batch_data`.

    Args:
        :batch_data: batch holding one entry per leaf and per graph key

    Return:
        a new Batch with the deprocessed value of each node
    """
    # return a list with all actions [o, a_1, a_2, ... o']
    # NOTE: we assume key `next_obs` in the data.
    # NOTE: the leading dimensions will be flattened.
    original_data = Batch()
    node_names = list(self._graph.leaf) + list(self._graph.keys())
    for node_name in node_names:
        node_value = batch_data[node_name]
        flattened = node_value.view(-1, node_value.shape[-1])
        original_data[node_name] = self._processor.deprocess_single_torch(flattened, node_name)
    return original_data
[docs]
def get_ope_dataset(self):
    r''' Convert the training dataset to the OPEDataset used in d3pe.

    The dataset is processed, rewards are generated with the user reward
    function applied to the deprocessed data, and the result is flattened into
    `obs`, `action`, `reward`, `done` and `next_obs` arrays.

    NOTE(review): `self.policy_name` is turned into a list in `__init__`, while
    this method passes it to `get_node` and uses it as a data key — confirm that
    a list is accepted here.

    Return:
        the OPEDataset built from the training data
    '''
    from d3pe.utils.data import OPEDataset

    dataset = ray.get(self.config['dataset'])

    expert_data = dataset.processor.process(dataset.data)
    expert_data = expert_data
    # The data is switched to torch for the reward computation and back to numpy afterwards.
    expert_data.to_torch()
    expert_data = generate_rewards(expert_data, reward_fn=lambda data: self._user_func(self._get_original_actions(data)))
    expert_data.to_numpy()

    policy_input_names = self._graph.get_node(self.policy_name).input_names
    if all([node_name in self._graph.transition_map.keys() for node_name in policy_input_names]):
        # Every policy input has a transition, so `next_*` entries are read directly.
        obs = np.concatenate([expert_data[node_name] for node_name in policy_input_names], axis=-1)
        next_obs = np.concatenate([expert_data['next_' + node_name] for node_name in policy_input_names], axis=-1)
        ope_dataset = OPEDataset(dict(
            obs=obs.reshape((-1, obs.shape[-1])),
            action=expert_data[self.policy_name].reshape((-1, expert_data[self.policy_name].shape[-1])),
            reward=expert_data['reward'].reshape((-1, expert_data['reward'].shape[-1])),
            done=expert_data['done'].reshape((-1, expert_data['done'].shape[-1])),
            next_obs=next_obs.reshape((-1, next_obs.shape[-1])),
        ), start_indexes=dataset._start_indexes)
    else:
        # Otherwise the next observation is the observation shifted by one step,
        # so the last step is dropped from every array.
        data = dict(
            action=expert_data[self.policy_name][:-1].reshape((-1, expert_data[self.policy_name].shape[-1])),
            reward=expert_data['reward'][:-1].reshape((-1, expert_data['reward'].shape[-1])),
            done=expert_data['done'][:-1].reshape((-1, expert_data['done'].shape[-1])),
        )
        expert_data.to_torch()
        obs = get_input_from_graph(self._graph, self.policy_name, expert_data).numpy()
        expert_data.to_numpy()
        data['obs'] = obs[:-1].reshape((-1, obs.shape[-1]))
        data['next_obs'] = obs[1:].reshape((-1, obs.shape[-1]))
        # Start indexes are shifted by their position in the index array.
        # NOTE(review): presumably this accounts for the step dropped above — confirm.
        ope_dataset = OPEDataset(data, dataset._start_indexes - np.arange(dataset._start_indexes.shape[0]))

    return ope_dataset
[docs]
def venv_test(self, expert_data : Batch, target_policy, traj_length=None, scope : str = 'trainPolicy_on_valEnv'):
    r""" Use the virtual env model to test the policy model.

    Rolls the policy out in every selected virtual env, sums the reward over the
    first dimension with discount `config['test_gamma']` and averages the result.

    Args:
        :expert_data: sampled data from the dataset
        :target_policy: the policy to evaluate
        :traj_length: rollout length handed to `_run_rollout`
        :scope: name of the test; the validation envs are used when it contains "valEnv"

    Return:
        a dict with the single key `reward_{scope}` holding the mean over the envs
    """
    rewards = []
    envs = self.envs_val if "valEnv" in scope else self.envs_train
    for env in envs:
        generated_data, info = self._run_rollout(expert_data, target_policy, env, traj_length, deterministic=self.config['deterministic_test'], clip=True)

        if 'done' in generated_data.keys():
            temp_done = self._processor.deprocess_single_torch(generated_data['done'], 'done')
            not_done = ~temp_done.bool()
            temp_reward = not_done * generated_data.reward
        else:
            not_done = torch.ones_like(generated_data.reward)
            temp_reward = not_done * generated_data.reward

        # NOTE(review): this assignment replaces the `temp_reward` computed in the
        # branch above — confirm which masking is intended.
        temp_reward = generated_data.reward * generated_data.valid_masks
        reward = temp_reward.squeeze(dim=-1)
        # Discount factor gamma ** t for every index t of the first dimension.
        t = torch.arange(0, reward.shape[0]).to(reward)
        discount = self.config['test_gamma'] ** t
        discount_reward = torch.sum(discount.unsqueeze(dim=-1) * reward, dim=0)

        if "done" in generated_data.keys():
            rewards.append(discount_reward.mean().item())
        else:
            # Without a `done` signal the return is divided by the rollout length.
            # rewards.append(discount_reward.mean().item() / self.config['test_horizon'])
            rewards.append(discount_reward.mean().item() / reward.shape[0])

    return {f'reward_{scope}' : np.mean(rewards)}
def _fqe_test(self, target_policy, scope='trainPolicy_on_valEnv'):
if 'offlinerl' in str(type(target_policy)): target_policy.get_action = lambda x: target_policy(x)
if self.fqe_evaluator is None: # initialize FQE evaluator in the first run
from d3pe.evaluator.fqe import FQEEvaluator
self.fqe_evaluator = FQEEvaluator()
self.fqe_evaluator.initialize(self.get_ope_dataset(), verbose=True)
with torch.enable_grad():
info = self.fqe_evaluator(target_policy)
for k in list(info.keys()): info[f'{k}_{scope}'] = info.pop(k)
return info
def _env_test(self, target_policy, scope='trainPolicy_on_valEnv'):
    r"""
    Test the policy on the env created for `config['task']`.

    Args:
        :target_policy: the policy to evaluate; a copy moved to cpu is used
        :scope: suffix of the returned keys

    Return:
        an empty dict when no env can be created, otherwise a dict with
        `real_reward_{scope}` and `real_length_{scope}`
    """
    env = create_env(self.config['task'])
    if env is None:
        return {}

    cpu_policy = deepcopy(target_policy).to('cpu')
    inference_policy = PolicyModel(self._wrap_policy(cpu_policy))

    reward, length = test_on_real_env(env, inference_policy)

    results = {}
    results[f"real_reward_{scope}"] = reward
    results[f"real_length_{scope}"] = length
    return results
def _run_rollout(self,
                 expert_data : Batch,
                 target_policy,
                 env : Union[VirtualEnvDev, List[VirtualEnvDev]],
                 traj_length : int = None,
                 maintain_grad_flow : bool = False,
                 deterministic : bool = True,
                 clip : bool = False):
    r"""
    Generate a rollout and post-process it.

    Args:
        :expert_data: sampled data from the dataset
        :target_policy: the policy to roll out
        :env: a virtual env or a list of virtual envs
        :traj_length: rollout length; defaults to the length of `expert_data.obs`
        :maintain_grad_flow: forwarded to `generate_rollout`
        :deterministic: forwarded to `generate_rollout`
        :clip: forwarded to `generate_rollout`

    Return:
        (generated_data, info), where `info` holds a scalar `reward` summary
    """
    traj_length = traj_length or expert_data.obs.shape[0]
    if traj_length >= expert_data.shape[0] and len(self._graph.leaf) > 1:
        traj_length = expert_data.shape[0]
        warnings.warn(f'leaf node detected, connot run over the expert trajectory! Reset `traj_length` to {traj_length}!')

    generated_data = self.generate_rollout(expert_data, target_policy, env, traj_length, maintain_grad_flow, deterministic, clip)

    # Drop the leading `pre_horzion` steps of the rollout.
    if self.config["pre_horzion"] > 0:
        generated_data = generated_data[self.config["pre_horzion"]:]

    # if "uncertainty" in generated_data.keys():
    #     generated_data["reward"] -= self.config["reward_uncertainty_weight"] * generated_data["uncertainty"]

    # If use the action for multi-steps in env.
    action_steps = self.config["action_steps"]
    if action_steps >= 2:
        # Keep the first step of every window of `action_steps` steps.
        index = torch.arange(0, generated_data["reward"].shape[0], action_steps)
        actions_step_data = Batch()
        for k, v in generated_data.items():
            if k == "reward":
                continue
            actions_step_data[k] = v[index]

        # Pad the reward with its last step so its length is a multiple of `action_steps`.
        reward = generated_data["reward"]
        reward_pad_steps = math.ceil(reward.shape[0] / action_steps) * action_steps - reward.shape[0]
        reward_pad = torch.cat([reward, reward[-1:].repeat(reward_pad_steps, 1, 1)], dim=0)
        # Average the rewards of each window of consecutive steps, so that entry i
        # lines up with the data kept at `index[i]`. Window index first, then the
        # step inside the window.
        actions_step_data["reward"] = reward_pad.reshape(-1, action_steps, reward_pad.shape[1], reward_pad.shape[2]).mean(dim=1)
        generated_data = actions_step_data

    if "done" in generated_data.keys():
        info = {
            "reward": ((generated_data.reward*generated_data.valid_masks).sum()/generated_data.reward.shape[1]).item()
        }
    else:
        info = {
            "reward": generated_data.reward.mean().item()
        }

    return generated_data, info
[docs]
def generate_rollout(self,
expert_data : Batch,
target_policy,
env : Union[VirtualEnvDev, List[VirtualEnvDev]],
traj_length : int,
maintain_grad_flow : bool = False,
deterministic : bool = True,
clip : bool = False):
r"""Generate trajectories based on current policy.

Starting from one step of ``expert_data``, each step runs the environment
pre-computation, the policies and the environment post-computation; the
collected steps are stacked and rewards plus termination masks are attached.

Args:
:expert_data: sampled data from the dataset.
:target_policy: iterable of policies, indexed in the same order as ``self.policy_name``.
:env: a single virtual environment, or a list from which at most 7 are sampled.
:traj_length: number of steps to generate.
:maintain_grad_flow: sample actions with ``rsample()`` instead of ``sample()``.
:deterministic: forwarded to the environment pre/post computation.
:clip: forwarded to the environment pre/post computation.
Return:
batch trajectories, with ``done_mask``, ``valid_masks`` and
``valid_done_masks`` entries added
"""
# Reset every policy and every environment before rolling out.
for policy in target_policy:
policy.reset()
if isinstance(env, list):
for _env in env:
_env.reset()
else:
env.reset()
# Leaf values are read from expert_data below, so with more than one leaf the
# rollout may not be longer than the expert data.
assert len(self._graph.leaf) == 1 or traj_length <= expert_data.shape[0], \
'There is leaf node on the graph, cannot generate trajectory beyond expert data'
generated_data = []
# Expert step the rollout starts from: step 0 unless a random start is configured.
random_select_start_index = self.config.get("random_select_start_index", False)
if random_select_start_index:
select_index = random.randint(0, expert_data.shape[0] - 1)
else:
select_index =0
# The initial batch holds only the graph's leaf nodes.
current_batch = Batch({k : expert_data[select_index][k] for k in self._graph.leaf})
random_select_external_factors = self.config.get("random_select_external_factors", False)
if random_select_external_factors:
# Leaves listed in graph.external_factors are kept as they are; every other
# leaf is shuffled along dim 0, each with its own independent permutation.
current_batch = Batch({k :current_batch[k] if k in self._graph.external_factors else current_batch[k][torch.randperm(current_batch[k].size(0))] for k in self._graph.leaf})
batch_size = current_batch.shape[0]
# rsample() when maintain_grad_flow is set, plain sample() otherwise.
sample_fn = lambda dist: dist.rsample() if maintain_grad_flow else dist.sample()
if isinstance(env, list):
# Use at most 7 randomly chosen environments and split the batch indices into
# consecutive chunks of size n.
# NOTE(review): the number of chunks can be smaller than sample_env_nums
# (e.g. batch_size=10 with 7 environments gives n=2 and only 5 chunks), in which
# case env_batch_index[use_env_id] in the post-computation would be out of range -- confirm.
sample_env_nums = min(min(7,len(env)),batch_size)
env_id = random.sample(range(len(env)), k=sample_env_nums)
n = int(math.ceil(batch_size / float(sample_env_nums)))
env_batch_index = [range(batch_size)[i:min(i + n,batch_size)] for i in range(0, batch_size, n)]
# uncertainty_list = []
done = False
if self.disturbing_transition_function:
# Work on a copy of the disturbing-node description and attach one random
# integer per sample, drawn from [-1, rnd_network_nums), to every node.
dtf_dict = deepcopy(self.disturbing_node)
for _node in dtf_dict.keys():
dtf_dict[_node]['rnd_idx'] = np.random.randint(-1, self.rnd_network_nums, (batch_size,))
# NOTE(review): traj_length + 1 iterations are scheduled, but the
# `i == traj_length - 1` break further down appears to end the loop first.
for i in range(traj_length+1):
for policy_index, policy_name in enumerate(self.policy_name):
if isinstance(env, list):
result_batch = []
use_env_id = -1
for _env_id,_env in enumerate(env):
if _env_id not in env_id:
continue
use_env_id += 1
# Every selected environment pre-computes on a copy of the full batch ...
_current_batch = deepcopy(current_batch) #[env_batch_index[use_env_id],:]
_current_batch = _env.pre_computation(_current_batch,
deterministic,
clip,
policy_index,
{} if not self.disturbing_transition_function else dtf_dict,
filter=self.config['filter'],
candidate_num=self.config['candidate_num'])
result_batch.append(_current_batch)
# ... and only its own chunk of rows is kept when the results are concatenated.
current_batch = Batch.cat([_current_batch[_env_batch_index,:] for _current_batch,_env_batch_index in zip(result_batch,env_batch_index)])
# policy_inputs = [get_input_from_graph(self._graph, policy_name, _current_batch) for _current_batch in result_batch]
# policy_inputs = torch.stack(policy_inputs, dim=2)
# policy_inputs_mean = torch.mean(policy_inputs, dim=-1, keepdim=True)
# diff = policy_inputs - policy_inputs_mean
# uncertainty = torch.max(torch.norm(diff, dim=-1, keepdim=False), dim=1)[0].reshape(-1,1)
# if i > 0:
# uncertainty_list.append(uncertainty)
if i == traj_length:
done = True
break
else:
current_batch = env.pre_computation(current_batch,
deterministic,
clip,
policy_index,
{} if not self.disturbing_transition_function else dtf_dict,
filter=self.config['filter'],
candidate_num=self.config['candidate_num'])
policy_input = get_input_from_graph(self._graph, policy_name, current_batch) # [batch_size, dim]
# Custom nodes, dict-input operators and ts_conv nodes get the dict form of the input instead.
if hasattr(list(self.train_nodes.values())[policy_index], 'custom_node') or self.input_is_dict:
policy_input = get_input_dict_from_graph(self._graph, policy_name, current_batch)
if hasattr(list(self.train_nodes.values())[policy_index], 'ts_conv_node'):
policy_input = get_input_dict_from_graph(self._graph, policy_name, current_batch)
# Policies whose type name contains 'offlinerl' return the action directly.
if 'offlinerl' in str(type(target_policy[policy_index])):
action = target_policy[policy_index](policy_input)
current_batch[policy_name] = action
else:
# use policy infer
if i % self.config["action_steps"] == 0:
action_dist = target_policy[policy_index](policy_input)
action = sample_fn(action_dist)
# The log-probability is stored detached from the graph.
action_log_prob = (action_dist.log_prob(action).unsqueeze(dim=-1)).detach()
current_batch[policy_name + '_log_prob'] = action_log_prob
current_batch[policy_name] = action
# use the last step action
else:
# NOTE(review): `action` / `action_log_prob` are whatever was assigned last;
# with several policies that looks like the last policy's values -- confirm.
current_batch[policy_name + '_log_prob'] = action_log_prob
current_batch[policy_name] = deepcopy(action.detach())
if done:
break
# Post-computation: with several environments each one handles only its own chunk of rows.
if isinstance(env, list):
result_batch = []
use_env_id = -1
for _env_id,_env in enumerate(env):
if _env_id not in env_id:
continue
use_env_id += 1
_current_batch = deepcopy(current_batch)[env_batch_index[use_env_id],:]
_current_batch = _env.post_computation(_current_batch,
deterministic,
clip,
policy_index,
{} if not self.disturbing_transition_function else dtf_dict,
filter=self.config['filter'],
candidate_num=self.config['candidate_num'])
result_batch.append(_current_batch)
current_batch = Batch.cat(result_batch)
else:
current_batch = env.post_computation(current_batch,
deterministic,
clip,
policy_index,
{} if not self.disturbing_transition_function else dtf_dict,
filter=self.config['filter'],
candidate_num=self.config['candidate_num'])
generated_data.append(current_batch)
if i == traj_length - 1:
break
# Build the next step's inputs; leaves without a transition are read from the expert data.
current_batch = Batch(self._graph.state_transition(current_batch))
for k in self._graph.leaf:
# NOTE(review): expert_data[i+1] does not take select_index into account -- confirm this is intended.
if not k in self._graph.transition_map.keys(): current_batch[k] = expert_data[i+1][k]
# for current_batch, uncertainty in zip(generated_data, uncertainty_list):
# current_batch["uncertainty"] = uncertainty
# Stack the per-step batches and let generate_rewards attach rewards using the user function.
generated_data = Batch.stack(generated_data)
generated_data = generate_rewards(generated_data, reward_fn=lambda data: self._user_func(self._get_original_actions(data)))
if "done" in generated_data.keys():
# presumably maps a done signal in {-1, 1} onto {0, 1} -- TODO confirm
masks = (generated_data["done"] + 1) / 2
# For one column of masks: 1 up to and including the first entry equal to 1, 0 afterwards;
# when no entry equals 1 the column is returned as masks + 1.
def get_valid_masks(masks):
masks = masks.reshape(-1)
indices = (masks == 1).nonzero(as_tuple=False)
if len(indices) == 0:
return (masks+1).reshape(-1,1,1)
index = indices[0]
loss_mask = torch.zeros_like(masks).to(masks.device)
loss_mask[:index + 1] = 1
return loss_mask.reshape(-1,1,1)
valid_masks = torch.cat([get_valid_masks(masks[:,id]) for id in range(masks.shape[1])],axis=1)
# Same search, but only the first entry equal to 1 is marked;
# when no entry equals 1 the column is again returned as masks + 1.
def get_valid_done_masks(masks):
masks = masks.reshape(-1)
indices = (masks == 1).nonzero(as_tuple=False)
if len(indices) == 0:
return (masks+1).reshape(-1,1,1)
index = indices[0]
loss_mask = torch.zeros_like(masks).to(masks.device)
loss_mask[index] = 1
return loss_mask.reshape(-1,1,1)
valid_done_masks = torch.cat([get_valid_done_masks(masks[:,id]) for id in range(masks.shape[1])],axis=1)
#generated_data["reward"] -= valid_done_masks*10.
else:
# No done signal: only the last step is flagged in done_mask, every step is valid,
# and valid_done_masks is all zeros.
masks = torch.zeros_like(generated_data.reward)
masks[-1] = torch.ones_like(masks[-1])
valid_masks = torch.ones_like(masks)
valid_done_masks = torch.zeros_like(generated_data.reward)
generated_data["done_mask"] = masks
generated_data["valid_masks"] = valid_masks
generated_data["valid_done_masks"] = valid_done_masks
return generated_data
[docs]
@catch_error
def before_train_epoch(self):
    """Put every model held by the operator into training mode.

    Publishes the current epoch counter under ``"policy_epoch"`` via
    ``update_env_vars`` and then calls ``train()`` on ``self.model`` and on the
    contents of ``self.models``, ``self.train_models`` and ``self.val_models``.
    Attributes that do not exist are skipped.
    """
    update_env_vars("policy_epoch", self._epoch_cnt)
    if hasattr(self, "model"):
        self.model.train()
    for attr_name in ("models", "train_models", "val_models"):
        if not hasattr(self, attr_name):
            continue
        container = getattr(self, attr_name)
        try:
            for _model in container:
                if isinstance(_model, torch.nn.Module):
                    _model.train()
        except Exception:
            # Best-effort fallback: the attribute is not an iterable of modules
            # (or a member failed), so treat it as one module. `Exception`
            # rather than a bare `except` lets KeyboardInterrupt/SystemExit through.
            container.train()
[docs]
@catch_error
def after_train_epoch(self):
    """Hook called after a training epoch; this implementation does nothing."""
    return None
[docs]
@catch_error
def before_validate_epoch(self):
    """Put every model held by the operator into evaluation mode.

    Calls ``eval()`` on ``self.model`` and on the contents of ``self.models``,
    ``self.train_models`` and ``self.val_models``. Attributes that do not
    exist are skipped.
    """
    if hasattr(self, "model"):
        self.model.eval()
    for attr_name in ("models", "train_models", "val_models"):
        if not hasattr(self, attr_name):
            continue
        container = getattr(self, attr_name)
        try:
            for _model in container:
                if isinstance(_model, torch.nn.Module):
                    _model.eval()
        except Exception:
            # Best-effort fallback: the attribute is not an iterable of modules
            # (or a member failed), so treat it as one module. `Exception`
            # rather than a bare `except` lets KeyboardInterrupt/SystemExit through.
            container.eval()
[docs]
@catch_error
def after_validate_epoch(self):
    """Hook called after a validation epoch; this implementation does nothing."""
    return None
[docs]
@catch_error
def train_epoch(self):
    """Run one training epoch and return the averaged metrics.

    The policy is trained on ``self._train_loader_train`` (scope ``'train'``)
    and, with double validation, a second pass runs on
    ``self._val_loader_train`` (scope ``'val'``). While the epoch counter is
    within ``config["policy_bc_epoch"]`` (default ``0``) batches go through
    ``bc_train_batch``; afterwards through ``train_batch``.

    Return:
        dict of averaged metrics without the keys starting with ``'last'``.
    """
    info = dict()
    self._epoch_cnt += 1
    logger.info(f"Policy Train Epoch : {self._epoch_cnt} ")

    # Read once with a default so both scopes tolerate a config without this key.
    bc_epochs = self.config.get("policy_bc_epoch", 0)

    def run_loader(loader, scope):
        """Train on every batch of ``loader`` and return the averaged metrics."""
        meters = AverageMeterCollection()
        for batch_idx, batch in enumerate(iter(loader)):
            batch_info = {
                "batch_idx": batch_idx,
                "global_step": self.global_step
            }
            batch_info.update(info)
            if self._epoch_cnt <= bc_epochs:
                # batch = Batch({k:v.reshape(-1, v.shape[-1]) for k,v in batch.items()})
                batch.to_torch(device=self._device)
                metrics = self.bc_train_batch(batch, batch_info=batch_info, scope=scope)
            else:
                metrics = self.train_batch(batch, batch_info=batch_info, scope=scope)
            num_samples = metrics.pop(NUM_SAMPLES, 1)
            meters.update(metrics, n=num_samples)
            # NOTE(review): in the threaded path both workers increment this counter.
            self.global_step += 1
        return meters.summary()

    metric_meters_val = None
    if self._speedup and self.double_validation:
        # Train both policies concurrently.
        with futures.ThreadPoolExecutor() as executor:
            train_future = executor.submit(run_loader, self._train_loader_train, 'train')
            val_future = executor.submit(run_loader, self._val_loader_train, 'val')
            futures.wait([train_future, val_future])
            metric_meters_train = train_future.result()
            metric_meters_val = val_future.result()
    else:
        metric_meters_train = run_loader(self._train_loader_train, 'train')
        if self.double_validation:
            metric_meters_val = run_loader(self._val_loader_train, 'val')

    info.update(metric_meters_train)
    if self.double_validation:
        info.update(metric_meters_val)
    return {k: v for k, v in info.items() if not k.startswith('last')}
[docs]
@catch_error
def bc_train_batch(self, expert_data, batch_info, scope='train'):
    """Run one behaviour-cloning update on a batch of expert data.

    Each policy is fitted by minimising the negative log-likelihood of the
    expert action; the per-policy losses are summed and optimised in a single
    step.

    Args:
        :expert_data: batch holding the policy inputs and the expert actions.
            An optional ``"<policy>_isnan_index_"`` entry is applied as
            ``1 - value`` to weight the loss and the error metrics.
        :batch_info: not used by this method.
        :scope: ``'train'`` selects ``self.train_policy`` and
            ``self.train_bc_optimizer``; any other value selects the
            ``val`` counterparts.

    Return:
        dict of scalar metrics keyed by ``"<algo name>/..."``.
    """
    if scope == 'train':
        target_policys = self.train_policy
        optimizer = self.train_bc_optimizer
    else:
        target_policys = self.val_policy
        optimizer = self.val_bc_optimizer

    prefix = self.NAME.lower()
    info = {}
    loss = 0
    for policy_name, target_policy in zip(self.policy_name, target_policys):
        target_policy.reset()
        node = self.train_nodes[policy_name]
        with torch.no_grad():
            state = get_input_from_graph(self._graph, policy_name, expert_data)
            # Custom nodes and enabled ts_conv nodes take the dict form of the input.
            if hasattr(node, 'custom_node') or (hasattr(node, 'ts_conv_node') and node.ts_conv_node):
                state = get_input_dict_from_graph(self._graph, policy_name, expert_data)
        action_dist = target_policy(state, field='bc')
        expert_action = expert_data[policy_name]
        action_log_prob = action_dist.log_prob(expert_action)
        action_error = action_dist.mode - expert_action

        nan_key = policy_name + "_isnan_index_"
        if nan_key in expert_data.keys():
            # Apply the (1 - isnan) weights to the likelihood and to the error metrics.
            valid = 1 - expert_data[nan_key]
            action_log_prob = action_log_prob * valid
            action_error = action_error * valid

        policy_loss = - action_log_prob.mean()
        action_mae = action_error.abs().sum(dim=-1).mean().detach()
        action_mse = (action_error ** 2).sum(dim=-1).mean().detach()
        loss += policy_loss

        info[f"{prefix}/{policy_name}_policy_bc_nll"] = policy_loss.mean().item()
        info[f"{prefix}/{policy_name}_policy_bc_mae"] = action_mae.mean().item()
        info[f"{prefix}/{policy_name}_policy_bc_mse"] = action_mse.mean().item()
        info[f"{prefix}/{policy_name}_policy_bc_std"] = action_dist.std.mean().item()

    optimizer.zero_grad()
    loss.backward()

    # Mean absolute gradient over every policy that contributed to the loss
    # (not just the last one of the loop); parameters without a gradient are skipped.
    grads = [param.grad.reshape(-1)
             for policy in target_policys
             for param in policy.parameters()
             if param.grad is not None]
    grad_mean = torch.cat(grads).abs().mean().item() if grads else 0.0

    optimizer.step()
    info[f"{prefix}/total_policy_bc_loss"] = loss.mean().item()
    info[f"{prefix}/policy_bc_grad"] = grad_mean
    return info
[docs]
@catch_error
def validate(self):
    """Evaluate the policies, keep the best checkpoint and report metrics.

    The train policy is always evaluated; with double validation the val
    policy is evaluated as well (concurrently when ``self._speedup`` is set).
    When the tracked reward exceeds ``self._max_reward`` the models are saved.
    Once ``self._stop_flag`` is set, the training log is uploaded and the
    final plots / policy tree are written (failures there are only logged).

    Return:
        dict of metrics (keys starting with ``'last'`` removed) plus ``"stop_flag"``.
    """
    def run_eval(loader, scope, known_info):
        """Validate every batch of ``loader`` under ``scope`` and average the metrics."""
        meters = AverageMeterCollection()
        with torch.no_grad():
            for batch_idx, batch in enumerate(iter(loader)):
                batch_info = {"batch_idx": batch_idx}
                batch_info.update(known_info)
                metrics = self.validate_batch(batch, batch_info, scope=scope)
                meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
        return meters.summary()

    def is_due(frequency_key):
        """True when the periodic test configured under ``frequency_key`` falls on this epoch."""
        return (not self.config[frequency_key] == 0) and \
            self._epoch_cnt % self.config[frequency_key] == 0

    def evaluate_train_policy():
        result = run_eval(self._val_loader_val, 'trainPolicy_on_valEnv', dict())
        if self.double_validation:
            result.update(run_eval(self._train_loader_val, 'trainPolicy_on_trainEnv', result))
        if is_due('real_env_test_frequency'):
            result.update(self._env_test(self.policy))
        if is_due('fqe_test_frequency'):
            result.update(self._fqe_test(self.policy))
        return result

    def evaluate_val_policy():
        result = dict()
        result.update(run_eval(self._train_loader_val, 'valPolicy_on_trainEnv', result))
        result.update(run_eval(self._val_loader_val, 'valPolicy_on_valEnv', result))
        # NOTE(review): these are `env_test` / `fqe_test`, while the train policy
        # uses `_env_test` / `_fqe_test` -- confirm both spellings exist.
        if is_due('real_env_test_frequency'):
            result.update(self.env_test(self.val_policy))
        if is_due('fqe_test_frequency'):
            result.update(self.fqe_test(self.val_policy))
        return result

    use_val_policy = self.double_validation
    if self._speedup and use_val_policy:
        with futures.ThreadPoolExecutor() as executor:
            train_future = executor.submit(evaluate_train_policy)
            val_future = executor.submit(evaluate_val_policy)
            futures.wait([train_future, val_future])
            train_policy_info = train_future.result()
            val_policy_info = val_future.result()
    else:
        train_policy_info = evaluate_train_policy()
        if self.double_validation:
            val_policy_info = evaluate_val_policy()

    merged = dict()
    merged.update(train_policy_info)
    if self.double_validation:
        merged.update(val_policy_info)
    info = {key: value for key, value in merged.items() if not key.startswith('last')}

    # Metric that decides whether the current models are the best so far.
    reward_flag = "reward_trainPolicy_on_valEnv" if self.double_validation else "reward_trainenv"
    if info[reward_flag] > self._max_reward:
        self._max_reward = info[reward_flag]
        self._save_models(self._traj_dir)

    self._update_metric()
    info["stop_flag"] = self._stop_flag
    if not self._stop_flag:
        return info

    # revive online server
    try:
        log_path = os.path.join(os.path.abspath(self._workspace), "revive.log")
        customer_uploadTrainLog(self.config["trainId"],
                                log_path,
                                "train.policy",
                                "success",
                                self._max_reward,
                                self.config["accessToken"])
    except Exception as e:
        logger.info(f"{e}")

    # save double validation plot after training
    if self.double_validation:
        plt_double_venv_validation(self._traj_dir,
                                   self.train_average_reward,
                                   self.val_average_reward,
                                   os.path.join(self._traj_dir, 'double_validation.png'))

    try:
        graph = self.envs_train[0].graph
        for policy_index, policy_name in enumerate(self.policy_name):
            graph.nodes[policy_name] = deepcopy(self.infer_policy)._policy_model.nodes[policy_index]

        # save rolllout action image
        rollout_save_path = os.path.join(self._traj_dir, 'rollout_images')
        save_rollout_action(rollout_save_path, graph, self._device, self.train_dataset, self.nodes_map,
                            pre_horzion=self.config["pre_horzion"], train_type='Policy')

        # policy to tree and plot the tree
        tree_save_path = os.path.join(self._traj_dir, 'policy_tree')
        net_to_tree(tree_save_path, graph, self._device, self.train_dataset, self.action_dims,
                    rnn_dataset=self.config['rnn_dataset'])
    except Exception as e:
        logger.info(f"{e}")
    return info
[docs]
def train_batch(self, expert_data : Batch, batch_info : Dict[str, float], scope : str = 'train'):
    """Train on one batch; this implementation always raises.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
[docs]
def validate_batch(self, expert_data : Batch, batch_info : Dict[str, float], scope : str = 'trainPolicy_on_valEnv'):
    """Evaluate one batch with ``venv_test`` using the policy selected by ``scope``.

    Args:
        :expert_data: batch to evaluate; moved to ``self._device`` in place.
        :batch_info: not used by this method.
        :scope: a name containing ``"trainPolicy"`` selects ``self.policy`` and
            one containing ``"valPolicy"`` selects ``self.val_policy``. Without
            double validation the scope is ignored and ``"trainenv"`` is used.

    Return:
        whatever ``self.venv_test`` returns.

    Raises:
        ValueError: with double validation, when ``scope`` names neither policy.
    """
    expert_data.to_torch(device=self._device)
    if not self.double_validation:
        policy, scope = self.policy, "trainenv"
    elif "trainPolicy" in scope:
        policy = self.policy
    elif "valPolicy" in scope:
        policy = self.val_policy
    else:
        # Previously this fell through to `return info` with `info` unbound.
        raise ValueError(f"Unknown validation scope: {scope!r}")
    return self.venv_test(expert_data, policy, traj_length=self.config['test_horizon'], scope=scope)
[docs]
class PolicyAlgorithm:
def __init__(self, algo : str, workspace: str =None):
    """Load the policy algorithm module and locate its operator class.

    Args:
        :algo: algorithm name; only the part before the first ``"."`` is used
            as the module name.
        :workspace: stored on the instance as ``self.workspace``.

    ``revive.dist.algo.policy.<module>`` is tried first, with
    ``revive.algo.policy.<module>`` as the fallback.
    """
    self.algo = algo
    self.workspace = workspace
    module_name = self.algo.split(".")[0]
    try:
        self.algo_module = importlib.import_module(f'revive.dist.algo.policy.{module_name}')
        logger.info(f"Import encryption policy algorithm module -> {self.algo}!")
    except Exception:
        # Any import failure falls back to the plain module; `Exception` rather
        # than a bare `except` keeps KeyboardInterrupt/SystemExit propagating.
        self.algo_module = importlib.import_module(f'revive.algo.policy.{module_name}')
        logger.info(f"Import policy algorithm module -> {self.algo}!")

    # Assume there is only one operator other than PolicyOperator
    # (if several names match, the last one listed by the module wins).
    for k in self.algo_module.__dir__():
        if 'Operator' in k and not k == 'PolicyOperator':
            self.operator = getattr(self.algo_module, k)
    self.operator_config = {}
[docs]
def get_train_func(self, config=None):
    """Build the training loop that is handed to the Ray trainer.

    Args:
        :config: not used here; kept so existing callers keep working.

    Return:
        ``train_func(config)``, which trains until the operator's validation
        reports ``stop_flag``, reporting metrics to Ray and TensorBoard.
    """
    def train_func(config):
        # Fill in every operator setting the trial config does not define itself.
        for k, v in self.operator_config.items():
            if k not in config:
                config[k] = v

        algo_operator = self.operator(config)
        policy_val_freq = config.get("policy_val_freq", 1)
        writer = SummaryWriter(algo_operator._traj_dir)
        while True:
            algo_operator.before_train_epoch()
            train_stats = algo_operator.train_epoch()
            algo_operator.after_train_epoch()
            epoch = algo_operator._epoch_cnt

            # Validate an epoch, support set model validation frequency by config["policy_val_freq"]
            # (epoch 1 is always validated, so `val_stats` exists from then on).
            if (epoch - 1) % policy_val_freq == 0 or epoch == 1:
                with torch.no_grad():
                    algo_operator.before_validate_epoch()
                    val_stats = algo_operator.validate()
                    algo_operator.after_validate_epoch()

            # Report model metric to ray
            # `validate` tracks "reward_trainenv" when double validation is off,
            # so fall back to it if the double-validation key is missing.
            reward_key = "reward_trainPolicy_on_valEnv"
            if reward_key not in val_stats and "reward_trainenv" in val_stats:
                reward_key = "reward_trainenv"
            session.report({"mean_accuracy": algo_operator._max_reward,
                            "reward_trainPolicy_on_valEnv": val_stats[reward_key],
                            "stop_flag": val_stats["stop_flag"]})

            # Write log to tensorboard
            for k, v in [*train_stats.items(), *val_stats.items()]:
                if type(v) in VALID_SUMMARY_TYPES:
                    writer.add_scalar(k, v, global_step=epoch)
                elif isinstance(v, torch.Tensor):
                    v = v.view(-1)
                    writer.add_histogram(k, v, global_step=epoch)
            writer.flush()

            # Check stop_flag
            train_stats.update(val_stats)
            if train_stats.get('stop_flag', False):
                break
        # Release the event-file handle once training is over.
        writer.close()
        torch.cuda.empty_cache()
    return train_func
[docs]
def get_trainer(self, config):
    """Create the ``TorchTrainer`` that runs this algorithm's training loop.

    Args:
        :config: trial configuration; ``workers_per_trial`` and ``use_gpu``
            are read here and the whole dict becomes ``train_loop_config``.

    Return:
        the configured ``TorchTrainer``.

    Any exception is logged and re-raised.
    """
    try:
        train_loop = self.get_train_func(config)
        # Explicit trainer resources are currently disabled. The switched-off
        # variant derived {"GPU": ..., "CPU": ...} from ray.cluster_resources()
        # and config['policy_gpus_per_worker'].
        scaling = ScalingConfig(trainer_resources=None,
                                num_workers=config['workers_per_trial'],
                                use_gpu=config['use_gpu'])
        return TorchTrainer(
            train_loop,
            train_loop_config=config,
            scaling_config=scaling,
        )
    except Exception as e:
        logger.error('Detect Error: {}'.format(e))
        raise e
[docs]
def get_trainable(self, config):
    """Return ``self.get_trainer(config).as_trainable()``.

    Any exception is logged and re-raised.
    """
    try:
        return self.get_trainer(config).as_trainable()
    except Exception as e:
        logger.error('Detect Error: {}'.format(e))
        raise e
[docs]
def get_parameters(self, command=None):
    """Return ``self.operator.get_parameters(command)``.

    Args:
        :command: passed through to the operator unchanged.

    Raises:
        AttributeError: when the call fails with an ``AttributeError``
            (e.g. the operator has no ``get_parameters``); the original error
            is chained as ``__cause__``.
        Exception: any other failure is logged and re-raised.
    """
    try:
        return self.operator.get_parameters(command)
    except AttributeError as e:
        raise AttributeError("Custom algorithm need to implement `get_parameters`") from e
    except Exception as e:
        # The placeholder was missing before, so the error text was never logged.
        logger.error('Detect Unknown Error: {}'.format(e))
        raise e
[docs]
def get_tune_parameters(self, config):
    """Store ``config`` as ``self.operator_config`` and return the operator's tune parameters.

    Args:
        :config: configuration passed to ``self.operator.get_tune_parameters``.

    Raises:
        AttributeError: when the call fails with an ``AttributeError``
            (e.g. the operator has no ``get_tune_parameters``); the original
            error is chained as ``__cause__``.
        Exception: any other failure is logged and re-raised.
    """
    try:
        self.operator_config = config
        return self.operator.get_tune_parameters(config)
    except AttributeError as e:
        raise AttributeError("Custom algorithm need to implement `get_tune_parameters`") from e
    except Exception as e:
        # The placeholder was missing before, so the error text was never logged.
        logger.error('Detect Unknown Error: {}'.format(e))
        raise e