Source code for revive.algo.policy.base

''''''
"""
    POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""

import os
import ray
import json
import math
import psutil
import torch
import warnings
import argparse
import importlib
import traceback
import numpy as np
from ray import tune
from loguru import logger
from copy import deepcopy
from abc import abstractmethod
from concurrent import futures
from ray import train
from ray.train.torch import TorchTrainer
from revive.utils.raysgd_utils import NUM_SAMPLES, AverageMeterCollection
from revive.computation.inference import *
from revive.computation.graph import FunctionDecisionNode
from revive.data.batch import Batch
from revive.data.dataset import data_creator
from ray.train.torch import TorchTrainer
from torch.utils.tensorboard import SummaryWriter
from ray.train import ScalingConfig
from ray.air import session
from revive.utils.common_utils import *
from revive.utils.tune_utils import get_tune_callbacks, CustomSearchGenerator, CustomBasicVariantGenerator, CLIReporter, VALID_SUMMARY_TYPES
from revive.utils.auth_utils import customer_uploadTrainLog
from revive.computation.modules import EnsembleLinear
warnings.filterwarnings('ignore')



[docs] def catch_error(func): '''push the training error message to data buffer''' def wrapped_func(self, *args, **kwargs): try: return func(self, *args, **kwargs) except Exception as e: error_message = traceback.format_exc() self.loguru_logger.error('Detect error:{}, Error Message: {}'.format(e,error_message)) ray.get(self._data_buffer.update_status.remote(self._traj_id, 'error', error_message)) self._stop_flag = True try: customer_uploadTrainLog(self.config["trainId"], os.path.join(os.path.abspath(self._workspace),"revive.log"), "train.policy", "fail", self._max_reward, self.config["accessToken"]) except Exception as e: logger.info(f"{e}") return { 'stop_flag' : True, 'reward_trainPolicy_on_valEnv' : - np.inf, } return wrapped_func
[docs] class PolicyOperator(): @property def env(self): return self.envs_train[0] @property def train_policy(self): return self.train_models[:len(self.policy_name)] @property def val_policy(self): return self.val_models[:len(self.policy_name)] @property def policy(self): return self.train_policy @property def other_train_models(self): return self.train_models[len(self.policy_name):] @property def other_val_models(self): return self.val_models[len(self.policy_name):] # NOTE: you need either write the `PARAMETER_DESCRIPTION` or overwrite `get_parameters` and `get_tune_parameters`. PARAMETER_DESCRIPTION = [] # a list of dict to describe the parameter of the algorithm
[docs] @classmethod def get_parameters(cls, command=None, **kargs): parser = list2parser(cls.PARAMETER_DESCRIPTION) return parser.parse_known_args(command)[0].__dict__
# @classmethod # def get_parameters(cls, command=None, **kargs): # parser = argparse.ArgumentParser() # for description in cls.PARAMETER_DESCRIPTION: # names = ['--' + description['name']] # if "abbreviation" in description.keys(): names.append('-' + description["abbreviation"]) # if type(description['type']) is str: # try: # description['type'] = eval(description['type']) # except Exception as e: # logger.error(f"{e}") # logger.error(f"Please check the .json file for the input type {description['type']} which is not supported!") # logger.error('Detect Error: {}'.format(e)) # raise e # parser.add_argument(*names, type=description['type'], default=description['default']) # return parser.parse_known_args(command)[0].__dict__
[docs] @classmethod def get_tune_parameters(cls, config : Dict[str, Any], **kargs): r""" Use ray.tune to wrap the parameters to be searched. """ _search_algo = config['policy_search_algo'].lower() tune_params = { "name": "policy_tune", "reuse_actors": config["reuse_actors"], "local_dir": config["workspace"], "callbacks": get_tune_callbacks(), "stop": { "stop_flag": True }, "verbose": config["verbose"], } if _search_algo == 'random': random_search_config = {} for description in cls.PARAMETER_DESCRIPTION: if 'search_mode' in description.keys() and 'search_values' in description.keys(): if description['search_mode'] == 'continuous': random_search_config[description['name']] = tune.uniform(*description['search_values']) elif description['search_mode'] == 'grid' or description['search_mode'] == 'discrete': random_search_config[description['name']] = tune.choice(description['search_values']) config["total_num_of_trials"] = config['train_policy_trials'] tune_params['config'] = random_search_config tune_params['num_samples'] = config["total_num_of_trials"] tune_params['search_alg'] = CustomBasicVariantGenerator() elif _search_algo == 'zoopt': #from ray.tune.search.zoopt import ZOOptSearch from revive.utils.tune_utils import ZOOptSearch from zoopt import ValueType if config['parallel_num'] == 'auto': if config['use_gpu']: num_of_gpu = int(ray.available_resources()['GPU']) num_of_cpu = int(ray.available_resources()['CPU']) parallel_num = min(int(num_of_gpu / config['policy_gpus_per_worker']), num_of_cpu) else: num_of_cpu = int(ray.available_resources()['CPU']) parallel_num = num_of_cpu else: parallel_num = int(config['parallel_num']) assert parallel_num > 0 config['parallel_num'] = parallel_num dim_dict = {} for description in cls.PARAMETER_DESCRIPTION: if 'search_mode' in description.keys() and 'search_values' in description.keys(): if description['search_mode'] == 'continuous': dim_dict[description['name']] = (ValueType.CONTINUOUS, description['search_values'], min(description['search_values'])) elif description['search_mode'] == 'discrete': dim_dict[description['name']] = (ValueType.DISCRETE, description['search_values']) elif description['search_mode'] == 'grid': dim_dict[description['name']] = (ValueType.GRID, description['search_values']) zoopt_search_config = { "parallel_num": config['parallel_num'] } config["total_num_of_trials"] = config['train_policy_trials'] tune_params['search_alg'] = ZOOptSearch( algo="Asracos", # only support Asracos currently budget=config["total_num_of_trials"], dim_dict=dim_dict, metric='reward_trainPolicy_on_valEnv', mode="max", **zoopt_search_config ) tune_params['config'] = dim_dict tune_params['search_alg'] = CustomSearchGenerator(tune_params['search_alg']) # wrap with our generator tune_params['num_samples'] = config["total_num_of_trials"] elif 'grid' in _search_algo: grid_search_config = {} config["total_num_of_trials"] = 1 for description in cls.PARAMETER_DESCRIPTION: if 'search_mode' in description.keys() and 'search_values' in description.keys(): if description['search_mode'] == 'grid': grid_search_config[description['name']] = tune.grid_search(description['search_values']) config["total_num_of_trials"] *= len(description['search_values']) else: warnings.warn(f"Detect parameter {description['name']} is define as searchable in `PARAMETER_DESCRIPTION`. " + \ f"However, since grid search does not support search type {description['search_mode']}, the parameter is skipped. " + \ f"If this parameter is important to the performance, you should consider other search algorithms. ") tune_params['config'] = grid_search_config tune_params['num_samples'] = config["total_num_of_trials"] tune_params['search_alg'] = CustomBasicVariantGenerator() elif 'bayes' in _search_algo: from ray.tune.search.bayesopt import BayesOptSearch bayes_search_config = {} for description in cls.PARAMETER_DESCRIPTION: if 'search_mode' in description.keys() and 'search_values' in description.keys(): if description['search_mode'] == 'continuous': bayes_search_config[description['name']] = description['search_values'] else: warnings.warn(f"Detect parameter {description['name']} is define as searchable in `PARAMETER_DESCRIPTION`. " + \ f"However, since bayesian search does not support search type {description['search_mode']}, the parameter is skipped. " + \ f"If this parameter is important to the performance, you should consider other search algorithms. ") config["total_num_of_trials"] = config['train_policy_trials'] tune_params['config'] = bayes_search_config tune_params['search_alg'] = BayesOptSearch(bayes_search_config, metric="reward_trainPolicy_on_valEnv", mode="max") tune_params['search_alg'] = CustomSearchGenerator(tune_params['search_alg']) # wrap with our generator tune_params['num_samples'] = config["total_num_of_trials"] else: raise ValueError(f'search algorithm {_search_algo} is not supported!') reporter = CLIReporter(parameter_columns=list(tune_params['config'].keys()), max_progress_rows=50, max_report_frequency=10, sort_by_metric=True) reporter.add_metric_column("reward_trainPolicy_on_valEnv") tune_params["progress_reporter"] = reporter return tune_params
[docs] def optimizer_creator(self, models : List[torch.nn.Module], config : Dict[str, Any]) -> List[torch.optim.Optimizer]: r""" Define optimizers for the created models. Args: :pmodels: list of all the models :config: configuration parameters Return: a list of optimizers """ raise NotImplementedError
[docs] def data_creator(self): r""" Create DataLoaders. Args: :config: configuration parameters Return: (train_loader, val_loader) """ # raise NotImplementedError logger.warning('data_creator is using the test_horizon' ) return data_creator(self.config, val_horizon=self.config['test_horizon'], double=self.double_validation)
def _setup_data(self): ''' setup data used in training ''' self.train_dataset = ray.get(self.config['dataset']) self.val_dataset = ray.get(self.config['val_dataset']) if not self.double_validation: train_loader_train, val_loader_val = self.data_creator() self._train_loader_train = train.torch.prepare_data_loader(train_loader_train, move_to_device=False) self._val_loader_train = train.torch.prepare_data_loader(val_loader_train, move_to_device=False) self._train_loader_val = None self._val_loader_val = None else: train_loader_train, val_loader_train, train_loader_val, val_loader_val = self.data_creator() try: self._train_loader_train = train.torch.prepare_data_loader(train_loader_train, move_to_device=False) self._val_loader_train = train.torch.prepare_data_loader(val_loader_train, move_to_device=False) self._train_loader_val = train.torch.prepare_data_loader(train_loader_val, move_to_device=False) self._val_loader_val = train.torch.prepare_data_loader(val_loader_val, move_to_device=False) except: self._train_loader_train = train_loader_train self._val_loader_train = val_loader_train self._train_loader_val = train_loader_val self._val_loader_val = val_loader_val def _setup_models(self, config : Dict[str, Any]): r'''setup models, optimizers and dataloaders.''' logger.info(f'Init model and move model to {self._device}') # Init train env models and optimizers self.train_nodes = {policy_name: deepcopy(self._graph.get_node(policy_name)) for policy_name in self.policy_name} self.train_models = self.model_creator(self.train_nodes) for model_index, model in enumerate(self.train_models): try: self.train_models[model_index] = train.torch.prepare_model(model,move_to_device=False).to(self._device) except: self.train_models[model_index] = model.to(self._device) self.train_optimizers = self.optimizer_creator(scope="train") self.train_bc_optimizer = self.bc_optimizer_creator(scope="train") for train_node,policy in zip(self.train_nodes.values(), self.train_policy): train_node.set_network(policy) if self.double_validation: # Init val env models and optimizers self.val_nodes = {policy_name: deepcopy(self._graph.get_node(policy_name)) for policy_name in self.policy_name} self.val_models = self.model_creator(self.val_nodes) for model_index, model in enumerate(self.val_models): try: self.val_models[model_index] = train.torch.prepare_model(model,move_to_device=False).to(self._device) except: self.val_models[model_index] = model.to(self._device) self.val_optimizers = self.optimizer_creator(scope="val") self.val_bc_optimizer = self.bc_optimizer_creator(scope="val") for val_node,policy in zip(self.val_nodes.values(), self.val_policy): val_node.set_network(policy) else: self.val_models = None self.val_optimizers = None self.train_bc_optimizer = None
[docs] def model_creator(self, nodes): graph = self._graph policy_name = self.policy_name total_dims = self.config['total_dims'] models = [] # Init policy by behavioral policy if self.config['behavioral_policy_init'] and self.behaviour_policys: for behaviour_policy in self.behaviour_policys: target_policy = deepcopy(behaviour_policy) target_policy.requires_grad_(True) models.append(target_policy) # Init new policy network else: if self.config['behavioral_policy_init'] and self.config["policy_bc_epoch"] == 0: self.config["policy_bc_epoch"] = 200 for policy_name, node in nodes.items(): node_input_dim = get_input_dim_from_graph(graph, policy_name, total_dims) if hasattr(node, 'custom_node'): node_input_dim = get_input_dim_dict_from_graph(graph, policy_name, total_dims) additional_kwargs = {} if self.config.get('ts_conv_nodes', "auto") == 'auto': pass elif policy_name in self.config['ts_conv_nodes'] : temp_input_dim = get_input_dim_dict_from_graph(graph, policy_name, total_dims) has_ts = any('ts' in element for element in temp_input_dim) if not has_ts: pass else: input_dim, input_dim_config, ts_conv_net_config = \ dict_ts_conv(graph, temp_input_dim, total_dims, self.config, net_hidden_features=self.config['policy_hidden_features']) additional_kwargs['ts_conv_config'] = input_dim_config additional_kwargs['ts_conv_net_config'] = ts_conv_net_config logger.info(f"Policy Backbone: {self.config['policy_backbone'],}") input_dim_dict = get_input_dim_dict_from_graph(graph, policy_name, total_dims) node.initialize_network(node_input_dim, total_dims[policy_name]['output'], hidden_features=self.config['policy_hidden_features'], hidden_layers=self.config['policy_hidden_layers'], backbone_type=self.config['policy_backbone'], dist_config=self.config['dist_configs'][policy_name], is_transition=False, input_dim_dict=input_dim_dict, **additional_kwargs) target_policy = node.get_network() models.append(target_policy) return models
[docs] def bc_optimizer_creator(self, scope): if scope == "train": target_policys = self.train_policy else: target_policys = self.val_policy models_params = [] for target_policy in target_policys: models_params.append({'params': target_policy.parameters(), 'lr': 1e-3}) bc_policy_optimizer = torch.optim.Adam(models_params) return bc_policy_optimizer
@catch_error def __init__(self, config : Dict[str, Any]): '''setup everything for training''' # parse information from config self.config = config self._data_buffer = self.config['policy_data_buffer'] self._workspace = self.config["workspace"] self._user_func = self.config['user_func'] self._graph = self.config['graph'] self._processor = self._graph.processor self.policy_name = self.config['target_policy_name'] if isinstance(self.policy_name, str): self.policy_name = [self.policy_name, ] # sort the policy by graph self.policy_name = [policy_name for policy_name in self.policy_name if policy_name in self._graph.keys()] self.double_validation = self.config['policy_double_validation'] self._filename = os.path.join(self._workspace, "train_policy.json") self._data_buffer.set_total_trials.remote(self.config.get("total_num_of_trials", 1)) self._data_buffer.inc_trial.remote() self._max_reward = -np.inf # get id self._ip = ray._private.services.get_node_ip_address() logger.add(os.path.join(os.path.abspath(self._workspace),"revive.log")) self.loguru_logger = logger # set trail seed setup_seed(self.config["global_seed"]) if 'tag' in self.config: # create by tune tag = self.config['tag'] logger.info("{} is running".format(tag)) self._traj_id = int(tag.split('_')[0]) experiment_dir = os.path.join(self._workspace, 'policy_tune') for traj_name in filter(lambda x: "ReviveLog" in x, os.listdir(experiment_dir)): if len(traj_name.split('_')[1]) == 5: # create by random search or grid search id_index = 3 else: id_index = 2 if int(traj_name.split('_')[id_index]) == self._traj_id: self._traj_dir = os.path.join(experiment_dir, traj_name) break else: self._traj_id = 1 self._traj_dir = os.path.join(self._workspace, 'policy_train') update_env_vars("pattern", "policy") # setup constant self._stop_flag = False self._batch_cnt = 0 self._epoch_cnt = 0 self._max_reward = - np.inf self._use_gpu = self.config["use_gpu"] and torch.cuda.is_available() self._device = 'cuda' if self._use_gpu else 'cpu' # fix problem introduced in ray 1.1 # set default config self._speedup = self.config.get("speedup",False) # collect venv env = ray.get(self.config['venv_data_buffer'].get_best_venv.remote()) self.envs = env.env_list for env in self.envs: env.to(self._device) env.requires_grad_(False) env.set_target_policy_name(self.policy_name) if hasattr(env, "train_algo") and env.train_algo == "REVIVE_FILTER": assert "filter" in self.config and "candidate_num" in self.config, f"venv trained from REVIVE_FILTER requires \'filter\' and \'candidate_num\'" assert "ensemble_size" in self.config and "ensemble_choosing_interval" in self.config, f"venv trained from REVIVE_FILTER requires \'ensemble_size\' and \'ensemble_choosing_interval\'" env.reset_ensemble_matcher(ensemble_size=self.config['ensemble_size'], ensemble_choosing_interval=self.config['ensemble_choosing_interval']) logger.info(f"Find {len(self.envs)} envs.") if len(self.envs) == 1: logger.warning('Only one venv found, use it in both training and validation!') self.envs_train = self.envs self.envs_val = self.envs else: self.envs_train = self.envs[:len(self.envs)//2] self.envs_val = self.envs[len(self.envs)//2:] #if self.config['num_venv_in_use'] > len(self.envs_train): # warnings.warn(f"Config requires {self.config['num_venv_in_use']} venvs, but only {len(self.envs_train)} venvs are available.") if self.config['num_venv_in_use'] > len(self.envs_train): logger.info("Adjusting the distribution to generate multiple env models.") mu_shift_list = np.linspace(-0.15, 0.15, num=(self.config['num_venv_in_use']-1)) for mu_shift in mu_shift_list: if mu_shift == 0: continue env_train = deepcopy(self.envs_train[0]) for node in env_train.graph.nodes.values(): if node.node_type == 'network': node.network.dist_mu_shift = mu_shift self.envs_train.append(env_train) env_val = deepcopy(self.envs_train[0]) for node in env_val.graph.nodes.values(): if node.node_type == 'network': node.network.dist_mu_shift = mu_shift self.envs_val.append(env_train) else: self.envs_train = self.envs_train[:int(self.config['num_venv_in_use'])] self.envs_val = self.envs_val[:int(self.config['num_venv_in_use'])] self.behaviour_nodes = [self.envs[0].graph.get_node(policy_name) for policy_name in self.policy_name] self.behaviour_policys = [behaviour_node.get_network() for behaviour_node in self.behaviour_nodes if behaviour_node.get_network()] self.input_is_dict = False for policy in self.behaviour_policys: if policy.__class__.__name__.startswith("Ts"): self.input_is_dict = True # find average reward from expert data dataset = ray.get(self.config['dataset']) train_data = dataset.data train_data.to_torch() train_reward = self._user_func(train_data) self.train_average_reward = float(train_reward.mean()) dataset = ray.get(self.config['val_dataset']) val_data = dataset.data val_data.to_torch() val_reward = self._user_func(train_data) self.val_average_reward = float(val_reward.mean()) # initialize FQE self.fqe_evaluator = None # prepare for training self._setup_data() self._setup_models(self.config) self.disturbing_transition_function = self.config['disturbing_transition_function'] # define numbers of turbu networks if self.disturbing_transition_function: self.rnd_network_nums = self.config['disturbing_net_num'] self.disturbing_node_name = [] self.disturbing_node = {} for node in self._graph.nodes.keys(): if node in self.policy_name: continue if self._graph.nodes[node].node_type == 'network': if self.config['disturbing_nodes']=='auto': pass else: if node in self.config['disturbing_nodes']: pass else: continue self.disturbing_node_name.append(node) _input_nodes = self._graph[node] input_dims = sum([self.config['total_dims'][_node_name]['input'] for _node_name in _input_nodes]) output_dims = self.config['total_dims'][node]['output'] with torch.random.fork_rng(): torch.manual_seed(21312) _rnd_networks = torch.nn.Sequential( EnsembleLinear(input_dims, 32, self.rnd_network_nums), torch.nn.LeakyReLU(), EnsembleLinear(32, 32, self.rnd_network_nums), torch.nn.LeakyReLU(), EnsembleLinear(32, output_dims // 2, self.rnd_network_nums), torch.nn.Tanh() ) self.disturbing_node[node] = {'network': _rnd_networks, 'input_nodes': _input_nodes, 'disturb_weight': self.config['disturbing_weight']} logger.info(f'disturbing_transition_function for node name: {node}, input nodes: {_input_nodes}, input dims: {input_dims}, output dims: {output_dims}') self._save_models(self._traj_dir) self._update_metric() # prepare for ReplayBuffer try: self.setup(self.config) except: pass self.action_dims = {} for policy_name in self.policy_name: action_dims = [] for action_dim in self._graph.descriptions[policy_name]: action_dims.append(list(action_dim.keys())[0]) self.action_dims[policy_name] = action_dims self.nodes_map = {} for node_name in list(self._graph.nodes) + list(self._graph.leaf): node_dims = [] for node_dim in self._graph.descriptions[node_name]: node_dims.append(list(node_dim.keys())[0]) self.nodes_map[node_name] = node_dims self.global_step = 0 def _early_stop(self, info : Dict[str, Any]): info["stop_flag"] = self._stop_flag return info def _update_metric(self): try: policy = torch.load(os.path.join(self._traj_dir, 'policy.pt'), map_location='cpu') except: policy = None self._data_buffer.update_metric.remote(self._traj_id, {"reward" : self._max_reward, "ip" : self._ip, "policy" : policy, "traj_dir": self._traj_dir}) self._data_buffer.write.remote(self._filename) if self._max_reward == ray.get(self._data_buffer.get_max_reward.remote()): self._save_models(self._workspace, with_policy=False) def _wrap_policy(self, policys : List[torch.nn.Module,], device = None) -> PolicyModelDev: policy_nodes = deepcopy(self.behaviour_nodes) policys = deepcopy(policys) for policy_node,policy in zip(policy_nodes,policys): policy_node.set_network(policy) if device: policy_node.to(device) policys = PolicyModelDev(policy_nodes) return policys def _save_models(self, path : str, with_policy : bool = True): torch.save(self.policy, os.path.join(path, "tuned_policy.pt")) if with_policy: policy = self._wrap_policy(self.policy, "cpu") torch.save(policy, os.path.join(path, "policy.pt")) policy = PolicyModel(policy) self.infer_policy = policy with open(os.path.join(self._traj_dir, "policy.pkl"), 'wb') as f: pickle.dump(policy, f) def _get_original_actions(self, batch_data : Batch) -> Batch: # return a list with all actions [o, a_1, a_2, ... o'] # NOTE: we assume key `next_obs` in the data. # NOTE: the leading dimensions will be flattened. original_data = Batch() for policy_name in list(self._graph.leaf) + list(self._graph.keys()): action = batch_data[policy_name].view(-1, batch_data[policy_name].shape[-1]) original_data[policy_name] = self._processor.deprocess_single_torch(action, policy_name) return original_data
[docs] def get_ope_dataset(self): r''' convert the dataset to OPEDataset used in d3pe ''' from d3pe.utils.data import OPEDataset dataset = ray.get(self.config['dataset']) expert_data = dataset.processor.process(dataset.data) expert_data = expert_data expert_data.to_torch() expert_data = generate_rewards(expert_data, reward_fn=lambda data: self._user_func(self._get_original_actions(data))) expert_data.to_numpy() policy_input_names = self._graph.get_node(self.policy_name).input_names if all([node_name in self._graph.transition_map.keys() for node_name in policy_input_names]): obs = np.concatenate([expert_data[node_name] for node_name in policy_input_names], axis=-1) next_obs = np.concatenate([expert_data['next_' + node_name] for node_name in policy_input_names], axis=-1) ope_dataset = OPEDataset(dict( obs=obs.reshape((-1, obs.shape[-1])), action=expert_data[self.policy_name].reshape((-1, expert_data[self.policy_name].shape[-1])), reward=expert_data['reward'].reshape((-1, expert_data['reward'].shape[-1])), done=expert_data['done'].reshape((-1, expert_data['done'].shape[-1])), next_obs=next_obs.reshape((-1, next_obs.shape[-1])), ), start_indexes=dataset._start_indexes) else: data = dict( action=expert_data[self.policy_name][:-1].reshape((-1, expert_data[self.policy_name].shape[-1])), reward=expert_data['reward'][:-1].reshape((-1, expert_data['reward'].shape[-1])), done=expert_data['done'][:-1].reshape((-1, expert_data['done'].shape[-1])), ) expert_data.to_torch() obs = get_input_from_graph(self._graph, self.policy_name, expert_data).numpy() expert_data.to_numpy() data['obs'] = obs[:-1].reshape((-1, obs.shape[-1])) data['next_obs'] = obs[1:].reshape((-1, obs.shape[-1])) ope_dataset = OPEDataset(data, dataset._start_indexes - np.arange(dataset._start_indexes.shape[0])) return ope_dataset
[docs] def venv_test(self, expert_data : Batch, target_policy, traj_length=None, scope : str = 'trainPolicy_on_valEnv'): r""" Use the virtual env model to test the policy model""" rewards = [] envs = self.envs_val if "valEnv" in scope else self.envs_train for env in envs: generated_data, info = self._run_rollout(expert_data, target_policy, env, traj_length, deterministic=self.config['deterministic_test'], clip=True) if 'done' in generated_data.keys(): temp_done = self._processor.deprocess_single_torch(generated_data['done'], 'done') not_done = ~temp_done.bool() temp_reward = not_done * generated_data.reward else: not_done = torch.ones_like(generated_data.reward) temp_reward = not_done * generated_data.reward temp_reward = generated_data.reward * generated_data.valid_masks reward = temp_reward.squeeze(dim=-1) t = torch.arange(0, reward.shape[0]).to(reward) discount = self.config['test_gamma'] ** t discount_reward = torch.sum(discount.unsqueeze(dim=-1) * reward, dim=0) if "done" in generated_data.keys(): rewards.append(discount_reward.mean().item()) else: # rewards.append(discount_reward.mean().item() / self.config['test_horizon']) rewards.append(discount_reward.mean().item() / reward.shape[0]) return {f'reward_{scope}' : np.mean(rewards)}
def _fqe_test(self, target_policy, scope='trainPolicy_on_valEnv'): if 'offlinerl' in str(type(target_policy)): target_policy.get_action = lambda x: target_policy(x) if self.fqe_evaluator is None: # initialize FQE evaluator in the first run from d3pe.evaluator.fqe import FQEEvaluator self.fqe_evaluator = FQEEvaluator() self.fqe_evaluator.initialize(self.get_ope_dataset(), verbose=True) with torch.enable_grad(): info = self.fqe_evaluator(target_policy) for k in list(info.keys()): info[f'{k}_{scope}'] = info.pop(k) return info def _env_test(self, target_policy, scope='trainPolicy_on_valEnv'): env = create_env(self.config['task']) if env is None: return {} policy = deepcopy(target_policy) policy = policy.to('cpu') policy = self._wrap_policy(policy) policy = PolicyModel(policy) reward, length = test_on_real_env(env, policy) return { f"real_reward_{scope}" : reward, f"real_length_{scope}" : length, } def _run_rollout(self, expert_data : Batch, target_policy, env : Union[VirtualEnvDev, List[VirtualEnvDev]], traj_length : int = None, maintain_grad_flow : bool = False, deterministic : bool = True, clip : bool = False): traj_length = traj_length or expert_data.obs.shape[0] if traj_length >= expert_data.shape[0] and len(self._graph.leaf) > 1: traj_length = expert_data.shape[0] warnings.warn(f'leaf node detected, connot run over the expert trajectory! Reset `traj_length` to {traj_length}!') generated_data = self.generate_rollout(expert_data, target_policy, env, traj_length, maintain_grad_flow, deterministic, clip) if self.config["pre_horzion"] > 0: expert_data = expert_data[self.config["pre_horzion"]:] generated_data = generated_data[self.config["pre_horzion"]:] # if "uncertainty" in generated_data.keys(): # generated_data["reward"] -= self.config["reward_uncertainty_weight"] * generated_data["uncertainty"] # If use the action for multi-steps in env. if self.config["action_steps"] >= 2: index = torch.arange(0,generated_data["reward"].shape[0],self.config["action_steps"]) actions_step_data = Batch() for k,v in generated_data.items(): if k == "reward": continue actions_step_data[k] = v[index] reward = generated_data["reward"] reward_pad_steps = math.ceil(reward.shape[0] / self.config["action_steps"]) * self.config["action_steps"] - reward.shape[0] reward_pad = torch.cat([reward, reward[-1:].repeat(reward_pad_steps,1,1)],axis=0) actions_step_data["reward"] = reward_pad.reshape(self.config["action_steps"], -1, reward_pad.shape[1], reward_pad.shape[2]).mean(axis=0) generated_data = actions_step_data if "done" in generated_data.keys(): info = { "reward": ((generated_data.reward*generated_data.valid_masks).sum()/generated_data.reward.shape[1]).item() } else: info = { "reward": generated_data.reward.mean().item() } return generated_data, info
[docs] def generate_rollout(self, expert_data : Batch, target_policy, env : Union[VirtualEnvDev, List[VirtualEnvDev]], traj_length : int, maintain_grad_flow : bool = False, deterministic : bool = True, clip : bool = False): r"""Generate trajectories based on current policy. Args: :expert_data: sampled data from the dataset. :target_policy: target_policy :env: env :traj_length: traj_length :maintain_grad_flow: maintain_grad_flow Return: batch trajectories """ for policy in target_policy: policy.reset() if isinstance(env, list): for _env in env: _env.reset() else: env.reset() assert len(self._graph.leaf) == 1 or traj_length <= expert_data.shape[0], \ 'There is leaf node on the graph, cannot generate trajectory beyond expert data' generated_data = [] random_select_start_index = self.config.get("random_select_start_index", False) if random_select_start_index: select_index = random.randint(0, expert_data.shape[0] - 1) else: select_index =0 current_batch = Batch({k : expert_data[select_index][k] for k in self._graph.leaf}) random_select_external_factors = self.config.get("random_select_external_factors", False) if random_select_external_factors: current_batch = Batch({k :current_batch[k] if k in self._graph.external_factors else current_batch[k][torch.randperm(current_batch[k].size(0))] for k in self._graph.leaf}) batch_size = current_batch.shape[0] sample_fn = lambda dist: dist.rsample() if maintain_grad_flow else dist.sample() if isinstance(env, list): sample_env_nums = min(min(7,len(env)),batch_size) env_id = random.sample(range(len(env)), k=sample_env_nums) n = int(math.ceil(batch_size / float(sample_env_nums))) env_batch_index = [range(batch_size)[i:min(i + n,batch_size)] for i in range(0, batch_size, n)] # uncertainty_list = [] done = False if self.disturbing_transition_function: dtf_dict = deepcopy(self.disturbing_node) for _node in dtf_dict.keys(): dtf_dict[_node]['rnd_idx'] = np.random.randint(-1, self.rnd_network_nums, (batch_size,)) for i in range(traj_length+1): for policy_index, policy_name in enumerate(self.policy_name): if isinstance(env, list): result_batch = [] use_env_id = -1 for _env_id,_env in enumerate(env): if _env_id not in env_id: continue use_env_id += 1 _current_batch = deepcopy(current_batch) #[env_batch_index[use_env_id],:] _current_batch = _env.pre_computation(_current_batch, deterministic, clip, policy_index, {} if not self.disturbing_transition_function else dtf_dict, filter=self.config['filter'], candidate_num=self.config['candidate_num']) result_batch.append(_current_batch) current_batch = Batch.cat([_current_batch[_env_batch_index,:] for _current_batch,_env_batch_index in zip(result_batch,env_batch_index)]) # policy_inputs = [get_input_from_graph(self._graph, policy_name, _current_batch) for _current_batch in result_batch] # policy_inputs = torch.stack(policy_inputs, dim=2) # policy_inputs_mean = torch.mean(policy_inputs, dim=-1, keepdim=True) # diff = policy_inputs - policy_inputs_mean # uncertainty = torch.max(torch.norm(diff, dim=-1, keepdim=False), dim=1)[0].reshape(-1,1) # if i > 0: # uncertainty_list.append(uncertainty) if i == traj_length: done = True break else: current_batch = env.pre_computation(current_batch, deterministic, clip, policy_index, {} if not self.disturbing_transition_function else dtf_dict, filter=self.config['filter'], candidate_num=self.config['candidate_num']) policy_input = get_input_from_graph(self._graph, policy_name, current_batch) # [batch_size, dim] if hasattr(list(self.train_nodes.values())[policy_index], 'custom_node') or self.input_is_dict: policy_input = get_input_dict_from_graph(self._graph, policy_name, current_batch) if hasattr(list(self.train_nodes.values())[policy_index], 'ts_conv_node'): policy_input = get_input_dict_from_graph(self._graph, policy_name, current_batch) if 'offlinerl' in str(type(target_policy[policy_index])): action = target_policy[policy_index](policy_input) current_batch[policy_name] = action else: # use policy infer if i % self.config["action_steps"] == 0: action_dist = target_policy[policy_index](policy_input) action = sample_fn(action_dist) action_log_prob = (action_dist.log_prob(action).unsqueeze(dim=-1)).detach() current_batch[policy_name + '_log_prob'] = action_log_prob current_batch[policy_name] = action # use the last step action else: current_batch[policy_name + '_log_prob'] = action_log_prob current_batch[policy_name] = deepcopy(action.detach()) if done: break if isinstance(env, list): result_batch = [] use_env_id = -1 for _env_id,_env in enumerate(env): if _env_id not in env_id: continue use_env_id += 1 _current_batch = deepcopy(current_batch)[env_batch_index[use_env_id],:] _current_batch = _env.post_computation(_current_batch, deterministic, clip, policy_index, {} if not self.disturbing_transition_function else dtf_dict, filter=self.config['filter'], candidate_num=self.config['candidate_num']) result_batch.append(_current_batch) current_batch = Batch.cat(result_batch) else: current_batch = env.post_computation(current_batch, deterministic, clip, policy_index, {} if not self.disturbing_transition_function else dtf_dict, filter=self.config['filter'], candidate_num=self.config['candidate_num']) generated_data.append(current_batch) if i == traj_length - 1: break current_batch = Batch(self._graph.state_transition(current_batch)) for k in self._graph.leaf: if not k in self._graph.transition_map.keys(): current_batch[k] = expert_data[i+1][k] # for current_batch, uncertainty in zip(generated_data, uncertainty_list): # current_batch["uncertainty"] = uncertainty generated_data = Batch.stack(generated_data) generated_data = generate_rewards(generated_data, reward_fn=lambda data: self._user_func(self._get_original_actions(data))) if "done" in generated_data.keys(): masks = (generated_data["done"] + 1) / 2 def get_valid_masks(masks): masks = masks.reshape(-1) indices = (masks == 1).nonzero(as_tuple=False) if len(indices) == 0: return (masks+1).reshape(-1,1,1) index = indices[0] loss_mask = torch.zeros_like(masks).to(masks.device) loss_mask[:index + 1] = 1 return loss_mask.reshape(-1,1,1) valid_masks = torch.cat([get_valid_masks(masks[:,id]) for id in range(masks.shape[1])],axis=1) def get_valid_done_masks(masks): masks = masks.reshape(-1) indices = (masks == 1).nonzero(as_tuple=False) if len(indices) == 0: return (masks+1).reshape(-1,1,1) index = indices[0] loss_mask = torch.zeros_like(masks).to(masks.device) loss_mask[index] = 1 return loss_mask.reshape(-1,1,1) valid_done_masks = torch.cat([get_valid_done_masks(masks[:,id]) for id in range(masks.shape[1])],axis=1) #generated_data["reward"] -= valid_done_masks*10. else: masks = torch.zeros_like(generated_data.reward) masks[-1] = torch.ones_like(masks[-1]) valid_masks = torch.ones_like(masks) valid_done_masks = torch.zeros_like(generated_data.reward) generated_data["done_mask"] = masks generated_data["valid_masks"] = valid_masks generated_data["valid_done_masks"] = valid_done_masks return generated_data
[docs] @catch_error def before_train_epoch(self): update_env_vars("policy_epoch",self._epoch_cnt) if hasattr(self, "model"): self.model.train() if hasattr(self, "models"): try: for _model in self.models: if isinstance(_model, torch.nn.Module): _model.train() except: self.models.train() if hasattr(self, "train_models"): try: for _model in self.train_models: if isinstance(_model, torch.nn.Module): _model.train() except: self.train_models.train() if hasattr(self, "val_models"): try: for _model in self.val_models: if isinstance(_model, torch.nn.Module): _model.train() except: self.val_models.train()
[docs] @catch_error def after_train_epoch(self): pass
[docs] @catch_error def before_validate_epoch(self): if hasattr(self, "model"): self.model.eval() if hasattr(self, "models"): try: for _model in self.models: if isinstance(_model, torch.nn.Module): _model.eval() except: self.models.eval() if hasattr(self, "train_models"): try: for _model in self.train_models: if isinstance(_model, torch.nn.Module): _model.eval() except: self.train_models.eval() if hasattr(self, "val_models"): try: for _model in self.val_models: if isinstance(_model, torch.nn.Module): _model.eval() except: self.val_models.eval()
[docs] @catch_error def after_validate_epoch(self): pass
[docs] @catch_error def train_epoch(self): info = dict() self._epoch_cnt += 1 logger.info(f"Policy Train Epoch : {self._epoch_cnt} ") def train_policy_on_train_env(self): metric_meters_train = AverageMeterCollection() for batch_idx, batch in enumerate(iter(self._train_loader_train)): batch_info = { "batch_idx": batch_idx, "global_step": self.global_step } batch_info.update(info) if self._epoch_cnt <= self.config.get("policy_bc_epoch", 0): # batch = Batch({k:v.reshape(-1, v.shape[-1]) for k,v in batch.items()}) batch.to_torch(device=self._device) metrics = self.bc_train_batch(batch, batch_info=batch_info, scope='train') else: metrics = self.train_batch(batch, batch_info=batch_info, scope='train') metric_meters_train.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) self.global_step += 1 return metric_meters_train.summary() if self.double_validation: def train_policy_on_val_env(self): metric_meters_val = AverageMeterCollection() for batch_idx, batch in enumerate(iter(self._val_loader_train)): batch_info = { "batch_idx": batch_idx, "global_step": self.global_step } batch_info.update(info) if self._epoch_cnt <= self.config["policy_bc_epoch"]: # batch = Batch({k:v.reshape(-1, v.shape[-1]) for k,v in batch.items()}) batch.to_torch(device=self._device) metrics = self.bc_train_batch(batch, batch_info=batch_info, scope='val') else: metrics = self.train_batch(batch, batch_info=batch_info, scope='val') metric_meters_val.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) self.global_step += 1 return metric_meters_val.summary() if self._speedup and self.double_validation: with futures.ThreadPoolExecutor() as executor: metric_meters_train = executor.submit(train_policy_on_train_env, self) metric_meters_val = executor.submit(train_policy_on_val_env, self) futures.wait([metric_meters_train, metric_meters_val]) metric_meters_train = metric_meters_train.result() metric_meters_val = metric_meters_val.result() else: metric_meters_train = train_policy_on_train_env(self) if self.double_validation: metric_meters_val = train_policy_on_val_env(self) info.update(metric_meters_train) if self.double_validation: info.update(metric_meters_val) return {k : info[k] for k in filter(lambda k: not k.startswith('last'), info.keys())}
[docs] @catch_error def bc_train_batch(self, expert_data, batch_info, scope='train'): if scope == 'train': target_policys = self.train_policy optimizer = self.train_bc_optimizer else: target_policys = self.val_policy optimizer = self.val_bc_optimizer info = {} loss = 0 total_action_mae, action_mse = 0, 0 for policy_name, target_policy in zip(self.policy_name, target_policys): target_policy.reset() with torch.no_grad(): state = get_input_from_graph(self._graph, policy_name , expert_data) if hasattr(self.train_nodes[policy_name], 'custom_node'): state = get_input_dict_from_graph(self._graph, policy_name , expert_data) if hasattr(self.train_nodes[policy_name], 'ts_conv_node') and self.train_nodes[policy_name].ts_conv_node : state = get_input_dict_from_graph(self._graph, policy_name , expert_data) action_dist = target_policy(state, field='bc') action_log_prob = action_dist.log_prob(expert_data[policy_name]) if policy_name + "_isnan_index_" in expert_data.keys(): isnan_index = 1 - expert_data[policy_name + "_isnan_index_"] policy_loss = - (action_log_prob*isnan_index).mean() action_mae = ((action_dist.mode - expert_data[policy_name])*isnan_index).abs().sum(dim=-1).mean().detach() action_mse = (((action_dist.mode - expert_data[policy_name])*isnan_index)**2).sum(dim=-1).mean().detach() else: isnan_index = None policy_loss = - action_log_prob.mean() action_mae = ((action_dist.mode - expert_data[policy_name])).abs().sum(dim=-1).mean().detach() action_mse = ((action_dist.mode - expert_data[policy_name])**2).sum(dim=-1).mean().detach() loss += policy_loss info[f"{self.NAME.lower()}/{policy_name}_policy_bc_nll"] = policy_loss.mean().item() info[f"{self.NAME.lower()}/{policy_name}_policy_bc_mae"] = action_mae.mean().item() info[f"{self.NAME.lower()}/{policy_name}_policy_bc_mse"] = action_mse.mean().item() info[f"{self.NAME.lower()}/{policy_name}_policy_bc_std"] = action_dist.std.mean().item() optimizer.zero_grad() loss.backward() grads = [] for param in target_policy.parameters(): try: grads.append(param.grad.view(-1)) except: continue grads = torch.cat(grads).abs().mean().item() optimizer.step() info[f"{self.NAME.lower()}/total_policy_bc_loss"] = loss.mean().item() info[f"{self.NAME.lower()}/policy_bc_grad"] = grads return info
[docs] @catch_error def validate(self): info = dict() def validate_train_policy(self): info = dict() metric_meters_train = AverageMeterCollection() with torch.no_grad(): for batch_idx, batch in enumerate(iter(self._val_loader_val)): batch_info = {"batch_idx": batch_idx} batch_info.update(info) metrics = self.validate_batch(batch, batch_info, scope='trainPolicy_on_valEnv') metric_meters_train.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) info = metric_meters_train.summary() if self.double_validation: metric_meters_train = AverageMeterCollection() with torch.no_grad(): for batch_idx, batch in enumerate(iter(self._train_loader_val)): batch_info = {"batch_idx": batch_idx} batch_info.update(info) metrics = self.validate_batch(batch, batch_info, scope='trainPolicy_on_trainEnv') metric_meters_train.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) info.update(metric_meters_train.summary()) if (not self.config['real_env_test_frequency'] == 0) and \ self._epoch_cnt % self.config['real_env_test_frequency'] == 0: info.update(self._env_test(self.policy)) if (not self.config['fqe_test_frequency'] == 0) and \ self._epoch_cnt % self.config['fqe_test_frequency'] == 0: info.update(self._fqe_test(self.policy)) return info if self.double_validation: def validate_val_policy(self): info = dict() metric_meters_val = AverageMeterCollection() with torch.no_grad(): for batch_idx, batch in enumerate(iter(self._train_loader_val)): batch_info = {"batch_idx": batch_idx} batch_info.update(info) metrics = self.validate_batch(batch, batch_info, scope='valPolicy_on_trainEnv') metric_meters_val.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) info.update(metric_meters_val.summary()) metric_meters_val = AverageMeterCollection() with torch.no_grad(): for batch_idx, batch in enumerate(iter(self._val_loader_val)): batch_info = {"batch_idx": batch_idx} batch_info.update(info) metrics = self.validate_batch(batch, batch_info, scope='valPolicy_on_valEnv') metric_meters_val.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) info.update(metric_meters_val.summary()) if (not self.config['real_env_test_frequency'] == 0) and \ self._epoch_cnt % self.config['real_env_test_frequency'] == 0: info.update(self.env_test(self.val_policy)) if (not self.config['fqe_test_frequency'] == 0) and \ self._epoch_cnt % self.config['fqe_test_frequency'] == 0: info.update(self.fqe_test(self.val_policy)) return info if self._speedup and self.double_validation: with futures.ThreadPoolExecutor() as executor: validate_train_policy_info = executor.submit(validate_train_policy, self) validate_val_policy_info = executor.submit(validate_val_policy, self) futures.wait([validate_train_policy_info, validate_val_policy_info]) validate_train_policy_info = validate_train_policy_info.result() validate_val_policy_info = validate_val_policy_info.result() else: validate_train_policy_info = validate_train_policy(self) if self.double_validation: validate_val_policy_info = validate_val_policy(self) info.update(validate_train_policy_info) if self.double_validation: info.update(validate_val_policy_info) info = {k : info[k] for k in filter(lambda k: not k.startswith('last'), info.keys())} if self.double_validation: reward_flag = "reward_trainPolicy_on_valEnv" else: reward_flag = "reward_trainenv" if info[reward_flag] > self._max_reward: self._max_reward = info[reward_flag] self._save_models(self._traj_dir) self._update_metric() info["stop_flag"] = self._stop_flag if self._stop_flag: # revive online server try: customer_uploadTrainLog(self.config["trainId"], os.path.join(os.path.abspath(self._workspace),"revive.log"), "train.policy", "success", self._max_reward, self.config["accessToken"]) except Exception as e: logger.info(f"{e}") # save double validation plot after training if self.double_validation: plt_double_venv_validation(self._traj_dir, self.train_average_reward, self.val_average_reward, os.path.join(self._traj_dir, 'double_validation.png')) try: graph = self.envs_train[0].graph for policy_index, policy_name in enumerate(self.policy_name): graph.nodes[policy_name] = deepcopy(self.infer_policy)._policy_model.nodes[policy_index] # save rolllout action image rollout_save_path = os.path.join(self._traj_dir, 'rollout_images') save_rollout_action(rollout_save_path, graph, self._device, self.train_dataset, self.nodes_map, pre_horzion=self.config["pre_horzion"], train_type='Policy') # policy to tree and plot the tree tree_save_path = os.path.join(self._traj_dir, 'policy_tree') net_to_tree(tree_save_path, graph, self._device, self.train_dataset, self.action_dims, rnn_dataset=self.config['rnn_dataset']) except Exception as e: logger.info(f"{e}") return info
[docs] def train_batch(self, expert_data : Batch, batch_info : Dict[str, float], scope : str = 'train'): raise NotImplementedError
[docs] def validate_batch(self, expert_data : Batch, batch_info : Dict[str, float], scope : str = 'trainPolicy_on_valEnv'): expert_data.to_torch(device=self._device) if not self.double_validation: info = self.venv_test(expert_data, self.policy, traj_length=self.config['test_horizon'], scope="trainenv") elif "trainPolicy" in scope: info = self.venv_test(expert_data, self.policy, traj_length=self.config['test_horizon'], scope=scope) elif "valPolicy" in scope: info = self.venv_test(expert_data, self.val_policy, traj_length=self.config['test_horizon'], scope=scope) return info
[docs] class PolicyAlgorithm: def __init__(self, algo : str, workspace: str =None): self.algo = algo self.workspace = workspace try: self.algo_module = importlib.import_module(f'revive.dist.algo.policy.{self.algo.split(".")[0]}') logger.info(f"Import encryption policy algorithm module -> {self.algo}!") except: self.algo_module = importlib.import_module(f'revive.algo.policy.{self.algo.split(".")[0]}') logger.info(f"Import policy algorithm module -> {self.algo}!") # Assume there is only one operator other than PolicyOperator for k in self.algo_module.__dir__(): if 'Operator' in k and not k == 'PolicyOperator': self.operator = getattr(self.algo_module, k) self.operator_config = {}
[docs] def get_train_func(self, config={}): def train_func(config): for k,v in self.operator_config.items(): if not k in config.keys(): config[k] = v algo_operator = self.operator(config) policy_val_freq = config.get("policy_val_freq", 1) writer = SummaryWriter(algo_operator._traj_dir) while True: algo_operator.before_train_epoch() train_stats = algo_operator.train_epoch() algo_operator.after_train_epoch() epoch = algo_operator._epoch_cnt # Validate an epoch, support set model validation frequency by config["policy_val_freq"] if (epoch - 1) % policy_val_freq == 0 or epoch == 1: with torch.no_grad(): algo_operator.before_validate_epoch() val_stats = algo_operator.validate() algo_operator.after_validate_epoch() # Report model metric to ray session.report({"mean_accuracy": algo_operator._max_reward, "reward_trainPolicy_on_valEnv": val_stats["reward_trainPolicy_on_valEnv"], "stop_flag": val_stats["stop_flag"]}) # Write log to tensorboard for k, v in [*train_stats.items(), *val_stats.items()]: if type(v) in VALID_SUMMARY_TYPES: writer.add_scalar(k, v, global_step=epoch) elif isinstance(v, torch.Tensor): v = v.view(-1) writer.add_histogram(k, v, global_step=epoch) writer.flush() # Check stop_flag train_stats.update(val_stats) if train_stats.get('stop_flag', False): break torch.cuda.empty_cache() return train_func
[docs] def get_trainer(self, config): try: train_func = self.get_train_func(config) trainer_resources = None """ try: cluster_resources = ray.cluster_resources() num_gpus = cluster_resources["GPU"] num_cpus = cluster_resources["CPU"] num_gpus_blocks = int(1/config['policy_gpus_per_worker']) * num_gpus resources_per_trial_cpu = max(int(num_cpus/num_gpus_blocks),1) trainer_resources = {"GPU": config['policy_gpus_per_worker'],"CPU":resources_per_trial_cpu} except: trainer_resources = None """ trainer = TorchTrainer( train_func, train_loop_config=config, scaling_config=ScalingConfig(trainer_resources=trainer_resources, num_workers=config['workers_per_trial'], use_gpu=config['use_gpu']), ) return trainer except Exception as e: logger.error('Detect Error: {}'.format(e)) raise e
[docs] def get_trainable(self, config): try: trainer = self.get_trainer(config) trainable = trainer.as_trainable() return trainable except Exception as e: logger.error('Detect Error: {}'.format(e)) raise e
[docs] def get_parameters(self, command=None): try: return self.operator.get_parameters(command) except AttributeError: raise AttributeError("Custom algorithm need to implement `get_parameters`") except Exception as e: logger.error('Detect Unknown Error:'.format(e)) raise e
[docs] def get_tune_parameters(self, config): try: self.operator_config = config return self.operator.get_tune_parameters(config) except AttributeError: raise AttributeError("Custom algorithm need to implement `get_tune_parameters`") except Exception as e: logger.error('Detect Unknown Error:'.format(e)) raise e