Source code for revive.algo.policy.base

''''''
"""
    POLIXIR REVIVE, copyright (C) 2021-2023 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""

import os
import ray
import math
import torch
import warnings
import argparse
import importlib
import traceback
import numpy as np
from ray import tune
from loguru import logger
from copy import deepcopy
from ray import train
from ray.train.torch import TorchTrainer
from revive.utils.raysgd_utils import NUM_SAMPLES, AverageMeterCollection
from revive.computation.inference import *
from revive.computation.graph import FunctionDecisionNode
from revive.data.batch import Batch
from revive.data.dataset import data_creator
from ray.train.torch import TorchTrainer
from revive.utils.common_utils import *
from revive.utils.tune_utils import get_tune_callbacks, CustomSearchGenerator, CustomBasicVariantGenerator, CLIReporter
from revive.utils.auth_utils import customer_uploadTrainLog

warnings.filterwarnings('ignore')

[docs]def catch_error(func): '''push the training error message to data buffer''' def wrapped_func(self, *args, **kwargs): try: return func(self, *args, **kwargs) except Exception as e: error_message = traceback.format_exc() logger.error('Detect error:{}, Error Message: {}'.format(e,error_message)) ray.get(self._data_buffer.update_status.remote(self._traj_id, 'error', error_message)) self._stop_flag = True try: customer_uploadTrainLog(self.config["trainId"], os.path.join(os.path.abspath(self._workspace),"revive.log"), "train.policy", "fail", self._max_reward, self.config["accessToken"]) except Exception as e: logger.info(f"{e}") return { 'stop_flag' : True, 'reward_trainPolicy_on_valEnv' : - np.inf, } return wrapped_func
[docs]class PolicyOperator(): @property def env(self): return self.envs_train[0] @property def policy(self): if isinstance(self.train_models, list) or isinstance(self.train_models, tuple): return self.train_models[:-1] else: return self.train_models @property def val_policy(self): if isinstance(self.val_models, list) or isinstance(self.val_models, tuple): return self.val_models[:-1] else: return self.val_models @property def other_models(self): if isinstance(self.models, list): return self.models[1:] else: return [] # NOTE: you need either write the `PARAMETER_DESCRIPTION` or overwrite `get_parameters` and `get_tune_parameters`. PARAMETER_DESCRIPTION = [] # a list of dict to describe the parameter of the algorithm
[docs] @classmethod def get_parameters(cls, command=None, **kargs): parser = argparse.ArgumentParser() for description in cls.PARAMETER_DESCRIPTION: names = ['--' + description['name']] if "abbreviation" in description.keys(): names.append('-' + description["abbreviation"]) if type(description['type']) is str: if 'str' in description['type']: description['type'] = str elif 'int' in description['type']: description['type'] = int elif 'float' in description['type']: description['type'] = float elif 'bool' in description['type']: description['type'] = bool parser.add_argument(*names, type=description['type'], default=description['default']) return parser.parse_known_args(command)[0].__dict__
[docs] @classmethod def get_tune_parameters(cls, config : Dict[str, Any], **kargs): r""" Use ray.tune to wrap the parameters to be searched. """ _search_algo = config['policy_search_algo'].lower() tune_params = { "name": "policy_tune", "reuse_actors": config["reuse_actors"], "local_dir": config["workspace"], "callbacks": get_tune_callbacks(), "stop": { "stop_flag": True }, "verbose": config["verbose"], } if _search_algo == 'random': random_search_config = {} for description in cls.PARAMETER_DESCRIPTION: if 'search_mode' in description.keys() and 'search_values' in description.keys(): if description['search_mode'] == 'continuous': random_search_config[description['name']] = tune.uniform(*description['search_values']) elif description['search_mode'] == 'grid' or description['search_mode'] == 'discrete': random_search_config[description['name']] = tune.choice(description['search_values']) config["total_num_of_trials"] = config['train_policy_trials'] tune_params['config'] = random_search_config tune_params['num_samples'] = config["total_num_of_trials"] tune_params['search_alg'] = CustomBasicVariantGenerator() elif _search_algo == 'zoopt': #from ray.tune.search.zoopt import ZOOptSearch from revive.utils.tune_utils import ZOOptSearch from zoopt import ValueType if config['parallel_num'] == 'auto': if config['use_gpu']: num_of_gpu = int(ray.available_resources()['GPU']) num_of_cpu = int(ray.available_resources()['CPU']) parallel_num = min(int(num_of_gpu / config['policy_gpus_per_worker']), num_of_cpu) else: num_of_cpu = int(ray.available_resources()['CPU']) parallel_num = num_of_cpu else: parallel_num = int(config['parallel_num']) assert parallel_num > 0 config['parallel_num'] = parallel_num dim_dict = {} for description in cls.PARAMETER_DESCRIPTION: if 'search_mode' in description.keys() and 'search_values' in description.keys(): if description['search_mode'] == 'continuous': dim_dict[description['name']] = (ValueType.CONTINUOUS, description['search_values'], min(description['search_values'])) elif description['search_mode'] == 'discrete': dim_dict[description['name']] = (ValueType.DISCRETE, description['search_values']) elif description['search_mode'] == 'grid': dim_dict[description['name']] = (ValueType.GRID, description['search_values']) zoopt_search_config = { "parallel_num": config['parallel_num'] } config["total_num_of_trials"] = config['train_policy_trials'] tune_params['search_alg'] = ZOOptSearch( algo="Asracos", # only support Asracos currently budget=config["total_num_of_trials"], dim_dict=dim_dict, metric='reward_trainPolicy_on_valEnv', mode="max", **zoopt_search_config ) tune_params['config'] = dim_dict tune_params['search_alg'] = CustomSearchGenerator(tune_params['search_alg']) # wrap with our generator tune_params['num_samples'] = config["total_num_of_trials"] elif 'grid' in _search_algo: grid_search_config = {} config["total_num_of_trials"] = 1 for description in cls.PARAMETER_DESCRIPTION: if 'search_mode' in description.keys() and 'search_values' in description.keys(): if description['search_mode'] == 'grid': grid_search_config[description['name']] = tune.grid_search(description['search_values']) config["total_num_of_trials"] *= len(description['search_values']) else: warnings.warn(f"Detect parameter {description['name']} is define as searchable in `PARAMETER_DESCRIPTION`. " + \ f"However, since grid search does not support search type {description['search_mode']}, the parameter is skipped. " + \ f"If this parameter is important to the performance, you should consider other search algorithms. ") tune_params['config'] = grid_search_config tune_params['num_samples'] = config["total_num_of_trials"] tune_params['search_alg'] = CustomBasicVariantGenerator() elif 'bayes' in _search_algo: from ray.tune.search.bayesopt import BayesOptSearch bayes_search_config = {} for description in cls.PARAMETER_DESCRIPTION: if 'search_mode' in description.keys() and 'search_values' in description.keys(): if description['search_mode'] == 'continuous': bayes_search_config[description['name']] = description['search_values'] else: warnings.warn(f"Detect parameter {description['name']} is define as searchable in `PARAMETER_DESCRIPTION`. " + \ f"However, since bayesian search does not support search type {description['search_mode']}, the parameter is skipped. " + \ f"If this parameter is important to the performance, you should consider other search algorithms. ") config["total_num_of_trials"] = config['train_policy_trials'] tune_params['config'] = bayes_search_config tune_params['search_alg'] = BayesOptSearch(bayes_search_config, metric="reward_trainPolicy_on_valEnv", mode="max") tune_params['search_alg'] = CustomSearchGenerator(tune_params['search_alg']) # wrap with our generator tune_params['num_samples'] = config["total_num_of_trials"] else: raise ValueError(f'search algorithm {_search_algo} is not supported!') reporter = CLIReporter(parameter_columns=list(tune_params['config'].keys()), max_progress_rows=50, max_report_frequency=10, sort_by_metric=True) reporter.add_metric_column("reward_trainPolicy_on_valEnv") tune_params["progress_reporter"] = reporter return tune_params
[docs] def model_creator(self, config : Dict[str, Any], node : FunctionDecisionNode) -> List[torch.nn.Module]: r""" Create all the models. The algorithm needs to define models for the nodes to be learned. Args: :config: configuration parameters Return: a list of models """ raise NotImplementedError
[docs] def optimizer_creator(self, models : List[torch.nn.Module], config : Dict[str, Any]) -> List[torch.optim.Optimizer]: r""" Define optimizers for the created models. Args: :pmodels: list of all the models :config: configuration parameters Return: a list of optimizers """ raise NotImplementedError
[docs] def data_creator(self, config : Dict[str, Any]): r""" Create DataLoaders. Args: :config: configuration parameters Return: (train_loader, val_loader) """ # raise NotImplementedError logger.warning('data_creator is using the test_horizon' ) return data_creator(config, val_horizon=config['test_horizon'], double=self.double_validation)
def _setup_data(self, config : Dict[str, Any]): ''' setup data used in training ''' self.train_dataset = ray.get(config['dataset']) self.val_dataset = ray.get(config['val_dataset']) if not self.double_validation: train_loader_train, val_loader_val = self.data_creator(config) self._train_loader_train = train.torch.prepare_data_loader(train_loader_train, move_to_device=False) self._val_loader_train = train.torch.prepare_data_loader(val_loader_train, move_to_device=False) self._train_loader_val = None self._val_loader_val = None else: train_loader_train, val_loader_train, train_loader_val, val_loader_val = self.data_creator(config) try: self._train_loader_train = train.torch.prepare_data_loader(train_loader_train, move_to_device=False) self._val_loader_train = train.torch.prepare_data_loader(val_loader_train, move_to_device=False) self._train_loader_val = train.torch.prepare_data_loader(train_loader_val, move_to_device=False) self._val_loader_val = train.torch.prepare_data_loader(val_loader_val, move_to_device=False) except: self._train_loader_train = train_loader_train self._val_loader_train = val_loader_train self._train_loader_val = train_loader_val self._val_loader_val = val_loader_val def _setup_models(self, config : Dict[str, Any]): r'''setup models, optimizers and dataloaders.''' if not self.double_validation: self.train_nodes = {policy_name:deepcopy(self._graph.get_node(policy_name)) for policy_name in self.policy_name} self.train_models = self.model_creator(config, self.train_nodes) for model_index, model in enumerate(self.train_models): try: self.train_models[model_index] = train.torch.prepare_model(model) except: self.train_models[model_index] = model.to(self._device) self.train_optimizers = self.optimizer_creator(self.train_models, config) self.val_models = None self.val_optimizers = None for train_node,policy in zip(self.train_nodes.values(), self.policy): train_node.set_network(policy) else: self.train_nodes = {policy_name:deepcopy(self._graph.get_node(policy_name)) for policy_name in self.policy_name} self.train_models = self.model_creator(config, self.train_nodes) for model_index, model in enumerate(self.train_models): try: self.train_models[model_index] = train.torch.prepare_model(model) except: self.train_models[model_index] = model.to(self._device) self.train_optimizers = self.optimizer_creator(self.train_models, config) self.val_nodes = {policy_name:deepcopy(self._graph.get_node(policy_name)) for policy_name in self.policy_name} self.val_models = self.model_creator(config, self.val_nodes) for model_index, model in enumerate(self.val_models): try: self.val_models[model_index] = train.torch.prepare_model(model) except: self.val_models[model_index] = model.to(self._device) self.val_optimizers = self.optimizer_creator(self.val_models, config) for train_node,policy in zip(self.train_nodes.values(), self.policy): train_node.set_network(policy) for val_node,policy in zip(self.val_nodes.values(), self.val_policy): val_node.set_network(policy) @catch_error def __init__(self, config : Dict[str, Any]): '''setup everything for training''' # parse information from config self.config = config self._data_buffer = config['policy_data_buffer'] self._workspace = config["workspace"] self._user_func = config['user_func'] self._graph = config['graph'] self._processor = self._graph.processor self.policy_name = self.config['target_policy_name'] if isinstance(self.policy_name, str): self.policy_name = [self.policy_name, ] # sort the policy by graph self.policy_name = [policy_name for policy_name in self.policy_name if policy_name in self._graph.keys()] self.double_validation = self.config['policy_double_validation'] self._filename = os.path.join(self._workspace, "train_policy.json") self._data_buffer.set_total_trials.remote(config.get("total_num_of_trials", 1)) self._data_buffer.inc_trial.remote() self._max_reward = -np.inf # get id self._ip = ray._private.services.get_node_ip_address() logger.add(os.path.join(os.path.abspath(self._workspace),"revive.log")) # set trail seed setup_seed(config["global_seed"]) if 'tag' in self.config: # create by tune tag = self.config['tag'] logger.info("{} is running".format(tag)) self._traj_id = int(tag.split('_')[0]) experiment_dir = os.path.join(self._workspace, 'policy_tune') for traj_name in filter(lambda x: "ReviveLog" in x, os.listdir(experiment_dir)): if len(traj_name.split('_')[1]) == 5: # create by random search or grid search id_index = 3 else: id_index = 2 if int(traj_name.split('_')[id_index]) == self._traj_id: self._traj_dir = os.path.join(experiment_dir, traj_name) break else: self._traj_id = 1 self._traj_dir = os.path.join(self._workspace, 'policy_train') update_env_vars("pattern", "policy") # setup constant self._stop_flag = False self._batch_cnt = 0 self._epoch_cnt = 0 self._max_reward = - np.inf self._use_gpu = self.config["use_gpu"] and torch.cuda.is_available() self._device = 'cuda' if self._use_gpu else 'cpu' # fix problem introduced in ray 1.1 # collect venv env = ray.get(config['venv_data_buffer'].get_best_venv.remote()) self.envs = env.env_list for env in self.envs: env.to(self._device) env.requires_grad_(False) env.set_target_policy_name(self.policy_name) logger.info(f"Find {len(self.envs)} envs.") if len(self.envs) == 1: logger.warning('Only one venv found, use it in both training and validation!') self.envs_train = self.envs self.envs_val = self.envs else: self.envs_train = self.envs[:len(self.envs)//2] self.envs_val = self.envs[len(self.envs)//2:] #if self.config['num_venv_in_use'] > len(self.envs_train): # warnings.warn(f"Config requires {self.config['num_venv_in_use']} venvs, but only {len(self.envs_train)} venvs are available.") if self.config['num_venv_in_use'] > len(self.envs_train): logger.info("Adjusting the distribution to generate multiple env models.") mu_shift_list = np.linspace(-0.15, 0.15, num=(self.config['num_venv_in_use']-1)) for mu_shift in mu_shift_list: if mu_shift == 0: continue env_train = deepcopy(self.envs_train[0]) for node in env_train.graph.nodes.values(): if node.node_type == 'network': node.network.dist_mu_shift = mu_shift self.envs_train.append(env_train) env_val = deepcopy(self.envs_train[0]) for node in env_val.graph.nodes.values(): if node.node_type == 'network': node.network.dist_mu_shift = mu_shift self.envs_val.append(env_train) else: self.envs_train = self.envs_train[:int(self.config['num_venv_in_use'])] self.envs_val = self.envs_val[:int(self.config['num_venv_in_use'])] self.behaviour_nodes = [self.envs[0].graph.get_node(policy_name) for policy_name in self.policy_name] self.behaviour_policys = [behaviour_node.get_network() for behaviour_node in self.behaviour_nodes if behaviour_node.get_network()] # find average reward from expert data dataset = ray.get(self.config['dataset']) train_data = dataset.data train_data.to_torch() train_reward = self._user_func(train_data) self.train_average_reward = float(train_reward.mean()) dataset = ray.get(self.config['val_dataset']) val_data = dataset.data val_data.to_torch() val_reward = self._user_func(train_data) self.val_average_reward = float(val_reward.mean()) # initialize FQE self.fqe_evaluator = None # prepare for training self._setup_data(config) self._setup_models(config) self._save_models(self._traj_dir) self._update_metric() # prepare for ReplayBuffer try: self.setup(config) except: pass self.action_dims = {} for policy_name in self.policy_name: action_dims = [] for action_dim in self._graph.descriptions[policy_name]: action_dims.append(list(action_dim.keys())[0]) self.action_dims[policy_name] = action_dims self.nodes_map = {} for node_name in list(self._graph.nodes) + list(self._graph.leaf): node_dims = [] for node_dim in self._graph.descriptions[node_name]: node_dims.append(list(node_dim.keys())[0]) self.nodes_map[node_name] = node_dims self.global_step = 0 def _early_stop(self, info : Dict[str, Any]): info["stop_flag"] = self._stop_flag return info def _update_metric(self): try: policy = torch.load(os.path.join(self._traj_dir, 'policy.pt'), map_location='cpu') except: policy = None self._data_buffer.update_metric.remote(self._traj_id, {"reward" : self._max_reward, "ip" : self._ip, "policy" : policy, "traj_dir": self._traj_dir}) self._data_buffer.write.remote(self._filename) if self._max_reward == ray.get(self._data_buffer.get_max_reward.remote()): self._save_models(self._workspace, with_policy=False) def _wrap_policy(self, policys : List[torch.nn.Module,], device = None) -> PolicyModelDev: policy_nodes = deepcopy(self.behaviour_nodes) policys = deepcopy(policys) for policy_node,policy in zip(policy_nodes,policys): policy_node.set_network(policy) if device: policy_node.to(device) policys = PolicyModelDev(policy_nodes) return policys def _save_models(self, path : str, with_policy : bool = True): torch.save(self.policy, os.path.join(path, "tuned_policy.pt")) if with_policy: policy = self._wrap_policy(self.policy, "cpu") torch.save(policy, os.path.join(path, "policy.pt")) policy = PolicyModel(policy) self.infer_policy = policy with open(os.path.join(self._traj_dir, "policy.pkl"), 'wb') as f: pickle.dump(policy, f) def _get_original_actions(self, batch_data : Batch) -> Batch: # return a list with all actions [o, a_1, a_2, ... o'] # NOTE: we assume key `next_obs` in the data. # NOTE: the leading dimensions will be flattened. original_data = Batch() for policy_name in list(self._graph.leaf) + list(self._graph.keys()): action = batch_data[policy_name].view(-1, batch_data[policy_name].shape[-1]) original_data[policy_name] = self._processor.deprocess_single_torch(action, policy_name) return original_data
[docs] def get_ope_dataset(self): r''' convert the dataset to OPEDataset used in d3pe ''' from d3pe.utils.data import OPEDataset dataset = ray.get(self.config['dataset']) expert_data = dataset.processor.process(dataset.data) expert_data = expert_data expert_data.to_torch() expert_data = generate_rewards(expert_data, reward_fn=lambda data: self._user_func(self._get_original_actions(data))) expert_data.to_numpy() policy_input_names = self._graph.get_node(self.policy_name).input_names if all([node_name in self._graph.transition_map.keys() for node_name in policy_input_names]): obs = np.concatenate([expert_data[node_name] for node_name in policy_input_names], axis=-1) next_obs = np.concatenate([expert_data['next_' + node_name] for node_name in policy_input_names], axis=-1) ope_dataset = OPEDataset(dict( obs=obs.reshape((-1, obs.shape[-1])), action=expert_data[self.policy_name].reshape((-1, expert_data[self.policy_name].shape[-1])), reward=expert_data['reward'].reshape((-1, expert_data['reward'].shape[-1])), done=expert_data['done'].reshape((-1, expert_data['done'].shape[-1])), next_obs=next_obs.reshape((-1, next_obs.shape[-1])), ), start_indexes=dataset._start_indexes) else: data = dict( action=expert_data[self.policy_name][:-1].reshape((-1, expert_data[self.policy_name].shape[-1])), reward=expert_data['reward'][:-1].reshape((-1, expert_data['reward'].shape[-1])), done=expert_data['done'][:-1].reshape((-1, expert_data['done'].shape[-1])), ) expert_data.to_torch() obs = get_input_from_graph(self._graph, self.policy_name, expert_data).numpy() expert_data.to_numpy() data['obs'] = obs[:-1].reshape((-1, obs.shape[-1])) data['next_obs'] = obs[1:].reshape((-1, obs.shape[-1])) ope_dataset = OPEDataset(data, dataset._start_indexes - np.arange(dataset._start_indexes.shape[0])) return ope_dataset
[docs] def venv_test(self, expert_data : Batch, target_policy, traj_length=None, scope : str = 'trainPolicy_on_valEnv'): r""" Use the virtual env model to test the policy model""" rewards = [] envs = self.envs_val if "valEnv" in scope else self.envs_train for env in envs: generated_data, info = self._run_rollout(expert_data, target_policy, env, traj_length, deterministic=self.config['deterministic_test'], clip=True) if 'done' in generated_data.keys(): temp_done = self._processor.deprocess_single_torch(generated_data['done'], 'done') not_done = ~temp_done.bool() temp_reward = not_done * generated_data.reward else: not_done = torch.ones_like(generated_data.reward) temp_reward = not_done * generated_data.reward reward = temp_reward.squeeze(dim=-1) t = torch.arange(0, reward.shape[0]).to(reward) discount = self.config['test_gamma'] ** t discount_reward = torch.sum(discount.unsqueeze(dim=-1) * reward, dim=0) rewards.append(discount_reward.mean().item() / self.config['test_horizon']) return {f'reward_{scope}' : np.mean(rewards)}
def _fqe_test(self, target_policy, scope='trainPolicy_on_valEnv'): if 'offlinerl' in str(type(target_policy)): target_policy.get_action = lambda x: target_policy(x) if self.fqe_evaluator is None: # initialize FQE evaluator in the first run from d3pe.evaluator.fqe import FQEEvaluator self.fqe_evaluator = FQEEvaluator() self.fqe_evaluator.initialize(self.get_ope_dataset(), verbose=True) with torch.enable_grad(): info = self.fqe_evaluator(target_policy) for k in list(info.keys()): info[f'{k}_{scope}'] = info.pop(k) return info def _env_test(self, target_policy, scope='trainPolicy_on_valEnv'): env = create_env(self.config['task']) if env is None: return {} policy = deepcopy(target_policy) policy = policy.to('cpu') policy = self._wrap_policy(policy) policy = PolicyModel(policy) reward, length = test_on_real_env(env, policy) return { f"real_reward_{scope}" : reward, f"real_length_{scope}" : length, } def _run_rollout(self, expert_data : Batch, target_policy, env : Union[VirtualEnvDev, List[VirtualEnvDev]], traj_length : int = None, maintain_grad_flow : bool = False, deterministic : bool = True, clip : bool = False): traj_length = traj_length or expert_data.obs.shape[0] if traj_length >= expert_data.shape[0] and len(self._graph.leaf) > 1: traj_length = expert_data.shape[0] warnings.warn(f'leaf node detected, connot run over the expert trajectory! Reset `traj_length` to {traj_length}!') generated_data = self.generate_rollout(expert_data, target_policy, env, traj_length, maintain_grad_flow, deterministic, clip) generated_data = generate_rewards(generated_data, reward_fn=lambda data: self._user_func(self._get_original_actions(data))) if "uncertainty" in generated_data.keys(): generated_data["reward"] -= self.config["reward_uncertainty_weight"]*generated_data["uncertainty"] # If use the action for multi-steps in env. if self.config["action_steps"] >= 2: index = torch.arange(0,generated_data["reward"].shape[0],self.config["action_steps"]) actions_step_data = Batch() for k,v in generated_data.items(): if k == "reward": continue actions_step_data[k] = v[index] reward = generated_data["reward"] reward_pad_steps = math.ceil(reward.shape[0] / self.config["action_steps"]) * self.config["action_steps"] - reward.shape[0] reward_pad = torch.cat([reward, reward[-1:].repeat(reward_pad_steps,1,1)],axis=0) actions_step_data["reward"] = reward_pad.reshape(self.config["action_steps"], -1, reward_pad.shape[1], reward_pad.shape[2]).mean(axis=0) generated_data = actions_step_data info = { "reward": generated_data.reward.mean().item() } return generated_data, info
[docs] def generate_rollout(self, expert_data : Batch, target_policy, env : Union[VirtualEnvDev, List[VirtualEnvDev]], traj_length : int, maintain_grad_flow : bool = False, deterministic : bool = True, clip : bool = False): r"""Generate trajectories based on current policy. Args: :expert_data: sampled data from the dataset. :target_policy: target_policy :env: env :traj_length: traj_length :maintain_grad_flow: maintain_grad_flow Return: batch trajectories """ for policy in target_policy: policy.reset() if isinstance(env, list): for _env in env: _env.reset() else: env.reset() assert len(self._graph.leaf) == 1 or traj_length <= expert_data.shape[0], \ 'There is leaf node on the graph, cannot generate trajectory beyond expert data' generated_data = [] current_batch = Batch({k : expert_data[0][k] for k in self._graph.leaf}) batch_size = current_batch.shape[0] sample_fn = lambda dist: dist.rsample() if maintain_grad_flow else dist.sample() if isinstance(env, list): sample_env_nums = min(min(7,len(env)),batch_size) env_id = random.sample(range(len(env)), k=sample_env_nums) n = int(math.ceil(batch_size / float(sample_env_nums))) env_batch_index = [range(batch_size)[i:min(i + n,batch_size)] for i in range(0, batch_size, n)] uncertainty_list = [] done = False for i in range(traj_length+1): for policy_index, policy_name in enumerate(self.policy_name): if isinstance(env, list): result_batch = [] use_env_id = -1 for _env_id,_env in enumerate(env): if _env_id not in env_id: continue use_env_id += 1 _current_batch = deepcopy(current_batch) #[env_batch_index[use_env_id],:] _current_batch = _env.pre_computation(_current_batch, deterministic, clip, policy_index) result_batch.append(_current_batch) current_batch = Batch.cat([_current_batch[_env_batch_index,:] for _current_batch,_env_batch_index in zip(result_batch,env_batch_index)]) policy_inputs = [get_input_from_graph(self._graph, policy_name, _current_batch) for _current_batch in result_batch] policy_inputs = torch.stack(policy_inputs, dim=2) policy_inputs_mean = torch.mean(policy_inputs, dim=-1, keepdim=True) diff = policy_inputs - policy_inputs_mean uncertainty = torch.max(torch.norm(diff, dim=-1, keepdim=False), dim=1)[0].reshape(-1,1) if i > 0: uncertainty_list.append(uncertainty) if i == traj_length: done = True break else: current_batch = env.pre_computation(current_batch, deterministic, clip, policy_index) policy_input = get_input_from_graph(self._graph, policy_name, current_batch) if 'offlinerl' in str(type(target_policy[policy_index])): action = target_policy[policy_index](policy_input) current_batch[policy_name] = action else: # use policy infer if i % self.config["action_steps"] == 0: action_dist = target_policy[policy_index](policy_input) action = sample_fn(action_dist) action_log_prob = (action_dist.log_prob(action).unsqueeze(dim=-1)).detach() current_batch[policy_name + '_log_prob'] = action_log_prob current_batch[policy_name] = action # use the last step action else: current_batch[policy_name + '_log_prob'] = action_log_prob current_batch[policy_name] = deepcopy(action.detach()) if done: break if isinstance(env, list): result_batch = [] use_env_id = -1 for _env_id,_env in enumerate(env): if _env_id not in env_id: continue use_env_id += 1 _current_batch = deepcopy(current_batch)[env_batch_index[use_env_id],:] _current_batch = _env.post_computation(_current_batch, deterministic, clip, policy_index) result_batch.append(_current_batch) current_batch = Batch.cat(result_batch) else: current_batch = env.post_computation(current_batch, deterministic, clip, policy_index) generated_data.append(current_batch) if i == traj_length - 1 : break current_batch = Batch(self._graph.state_transition(current_batch)) for k in self._graph.leaf: if not k in self._graph.transition_map.keys(): current_batch[k] = expert_data[i+1][k] for current_batch, uncertainty in zip(generated_data, uncertainty_list): current_batch["uncertainty"] = uncertainty generated_data = Batch.stack(generated_data) return generated_data
[docs] @catch_error def before_train_epoch(self): update_env_vars("policy_epoch",self._epoch_cnt)
[docs] @catch_error def train_epoch(self): info = dict() self._epoch_cnt += 1 if hasattr(self, "model"): self.model.train() if hasattr(self, "models"): try: for _model in self.models: if isinstance(_model, torch.nn.Module): _model.train() except: self.models.train() metric_meters_train = AverageMeterCollection() for batch_idx, batch in enumerate(iter(self._train_loader_train)): batch_info = { "batch_idx": batch_idx, "global_step": self.global_step } batch_info.update(info) if self._epoch_cnt <= self.config.get("policy_bc_epoch", 0): metrics = self.bc_train_batch(batch, batch_info=batch_info, scope='train') else: metrics = self.train_batch(batch, batch_info=batch_info, scope='train') metric_meters_train.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) self.global_step += 1 info = metric_meters_train.summary() if self.double_validation: metric_meters_val = AverageMeterCollection() for batch_idx, batch in enumerate(iter(self._val_loader_train)): batch_info = { "batch_idx": batch_idx, "global_step": self.global_step } batch_info.update(info) if self._epoch_cnt <= self.config["policy_bc_epoch"]: metrics = self.bc_train_batch(batch, batch_info=batch_info, scope='val') else: metrics = self.train_batch(batch, batch_info=batch_info, scope='val') metric_meters_val.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) self.global_step += 1 info.update(metric_meters_val.summary()) if os.path.exists(os.path.join(self._workspace,'.env.json')): import json with open(os.path.join(self._workspace,'.env.json'), 'r') as f: _data = json.load(f) if _data["REVIVE_STOP"]: self._stop_flag = True info["stop_flag"] = self._stop_flag return {k : info[k] for k in filter(lambda k: not k.startswith('last'), info.keys())}
[docs] @catch_error def validate(self): logger.info(f"Epoch : {self._epoch_cnt} ") info = dict() # switch to evaluate mode if hasattr(self, "model"): self.model.eval() if hasattr(self, "models"): try: for _model in self.models: if isinstance(_model, torch.nn.Module): _model.eval() except: self.models.eval() metric_meters_train = AverageMeterCollection() with torch.no_grad(): for batch_idx, batch in enumerate(iter(self._val_loader_val)): batch_info = {"batch_idx": batch_idx} batch_info.update(info) metrics = self.validate_batch(batch, batch_info, scope='trainPolicy_on_valEnv') metric_meters_train.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) info = metric_meters_train.summary() if self.double_validation: metric_meters_train = AverageMeterCollection() with torch.no_grad(): for batch_idx, batch in enumerate(iter(self._train_loader_val)): batch_info = {"batch_idx": batch_idx} batch_info.update(info) metrics = self.validate_batch(batch, batch_info, scope='trainPolicy_on_trainEnv') metric_meters_train.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) info.update(metric_meters_train.summary()) if (not self.config['real_env_test_frequency'] == 0) and \ self._epoch_cnt % self.config['real_env_test_frequency'] == 0: info.update(self._env_test(self.policy)) if (not self.config['fqe_test_frequency'] == 0) and \ self._epoch_cnt % self.config['fqe_test_frequency'] == 0: info.update(self._fqe_test(self.policy)) if self.double_validation: metric_meters_val = AverageMeterCollection() with torch.no_grad(): for batch_idx, batch in enumerate(iter(self._train_loader_val)): batch_info = {"batch_idx": batch_idx} batch_info.update(info) metrics = self.validate_batch(batch, batch_info, scope='valPolicy_on_trainEnv') metric_meters_val.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) info.update(metric_meters_val.summary()) metric_meters_val = AverageMeterCollection() with torch.no_grad(): for batch_idx, batch in enumerate(iter(self._val_loader_val)): batch_info = {"batch_idx": batch_idx} batch_info.update(info) metrics = self.validate_batch(batch, batch_info, scope='valPolicy_on_valEnv') metric_meters_val.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) info.update(metric_meters_val.summary()) if (not self.config['real_env_test_frequency'] == 0) and \ self._epoch_cnt % self.config['real_env_test_frequency'] == 0: info.update(self.env_test(self.val_policy)) if (not self.config['fqe_test_frequency'] == 0) and \ self._epoch_cnt % self.config['fqe_test_frequency'] == 0: info.update(self.fqe_test(self.val_policy)) info = {k : info[k] for k in filter(lambda k: not k.startswith('last'), info.keys())} if self.double_validation: reward_flag = "reward_trainPolicy_on_valEnv" else: reward_flag = "reward_trainenv" if info[reward_flag] > self._max_reward: self._max_reward = info[reward_flag] self._save_models(self._traj_dir) self._update_metric() info["stop_flag"] = self._stop_flag if self._stop_flag: # revive online server try: customer_uploadTrainLog(self.config["trainId"], os.path.join(os.path.abspath(self._workspace),"revive.log"), "train.policy", "success", self._max_reward, self.config["accessToken"]) except Exception as e: logger.info(f"{e}") # save double validation plot after training if self.double_validation: plt_double_venv_validation(self._traj_dir, self.train_average_reward, self.val_average_reward, os.path.join(self._traj_dir, 'double_validation.png')) graph = self.envs_train[0].graph for policy_index, policy_name in enumerate(self.policy_name): graph.nodes[policy_name] = deepcopy(self.infer_policy)._policy_model.nodes[policy_index] # save rolllout action image rollout_save_path = os.path.join(self._traj_dir, 'rollout_images') save_rollout_action(rollout_save_path, graph, self._device, self.train_dataset, self.nodes_map) # policy to tree and plot the tree tree_save_path = os.path.join(self._traj_dir, 'policy_tree') net_to_tree(tree_save_path, graph, self._device, self.train_dataset, self.action_dims ) return info
[docs] def train_batch(self, expert_data : Batch, batch_info : Dict[str, float], scope : str = 'train'): raise NotImplementedError
[docs] def validate_batch(self, expert_data : Batch, batch_info : Dict[str, float], scope : str = 'trainPolicy_on_valEnv'): expert_data.to_torch(device=self._device) if not self.double_validation: info = self.venv_test(expert_data, self.policy, traj_length=self.config['test_horizon'], scope="trainenv") elif "trainPolicy" in scope: info = self.venv_test(expert_data, self.policy, traj_length=self.config['test_horizon'], scope=scope) elif "valPolicy" in scope: info = self.venv_test(expert_data, self.val_policy, traj_length=self.config['test_horizon'], scope=scope) return info
[docs]class PolicyAlgorithm: def __init__(self, algo : str, workspace: str =None): self.algo = algo self.workspace = workspace try: self.algo_module = importlib.import_module(f'revive.dist.algo.policy.{self.algo.split(".")[0]}') logger.info(f"Import encryption policy algorithm module -> {self.algo}!") except: self.algo_module = importlib.import_module(f'revive.algo.policy.{self.algo.split(".")[0]}') logger.info(f"Import policy algorithm module -> {self.algo}!") # Assume there is only one operator other than PolicyOperator for k in self.algo_module.__dir__(): if 'Operator' in k and not k == 'PolicyOperator': self.operator = getattr(self.algo_module, k) self.operator_config = {}
[docs] def get_train_func(self, config): from revive.utils.tune_utils import VALID_SUMMARY_TYPES from torch.utils.tensorboard import SummaryWriter from ray.air import session def train_func(config): for k,v in self.operator_config.items(): if not k in config.keys(): config[k] = v algo_operator = self.operator(config) writer = SummaryWriter(algo_operator._traj_dir) epoch = 0 while True: algo_operator.before_train_epoch() train_stats = algo_operator.train_epoch() val_stats = algo_operator.validate() session.report({"mean_accuracy": algo_operator._max_reward, "reward_trainPolicy_on_valEnv": val_stats["reward_trainPolicy_on_valEnv"], "stop_flag": val_stats["stop_flag"]}) # write tensorboard for k, v in [*train_stats.items(), *val_stats.items()]: if type(v) in VALID_SUMMARY_TYPES: writer.add_scalar(k, v, global_step=epoch) elif isinstance(v, torch.Tensor): v = v.view(-1) writer.add_histogram(k, v, global_step=epoch) writer.flush() # check stop_flag train_stats.update(val_stats) if train_stats.get('stop_flag', False): break epoch += 1 return train_func
[docs] def get_trainer(self, config): try: train_func = self.get_train_func(config) from ray.air.config import ScalingConfig trainer = TorchTrainer( train_func, train_loop_config=config, scaling_config=ScalingConfig(num_workers=config['workers_per_trial'], use_gpu=config['use_gpu']), ) return trainer except Exception as e: logger.error('Detect Error: {}'.format(e)) raise e
[docs] def get_trainable(self, config): try: trainer = self.get_trainer(config) trainable = trainer.as_trainable() return trainable except Exception as e: logger.error('Detect Error: {}'.format(e)) raise e
[docs] def get_parameters(self, command=None): try: return self.operator.get_parameters(command) except AttributeError: raise AttributeError("Custom algorithm need to implement `get_parameters`") except Exception as e: logger.error('Detect Unknown Error:'.format(e)) raise e
[docs] def get_tune_parameters(self, config): try: self.operator_config = config return self.operator.get_tune_parameters(config) except AttributeError: raise AttributeError("Custom algorithm need to implement `get_tune_parameters`") except Exception as e: logger.error('Detect Unknown Error:'.format(e)) raise e