Source code for revive.algo.venv.base

"""
    POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""

import os 
import ray
import json
import psutil
import shutil
import torch
import pickle
import warnings
import importlib
import traceback
import numpy as np
from scipy import stats
from loguru import logger
from copy import deepcopy
from concurrent import futures
from ray import tune
from ray import train
from ray.train.torch import TorchTrainer
from revive.computation.graph import DesicionGraph
from revive.computation.inference import *
from revive.data.batch import Batch
from revive.data.dataset import data_creator
from revive.utils.raysgd_utils import NUM_SAMPLES, AverageMeterCollection
from revive.utils.tune_utils import get_tune_callbacks, CustomSearchGenerator, CustomBasicVariantGenerator, CLIReporter
from revive.utils.common_utils import *
from revive.utils.auth_utils import customer_uploadTrainLog

from revive.utils.tune_utils import VALID_SUMMARY_TYPES
from torch.utils.tensorboard import SummaryWriter
from ray.train import ScalingConfig
from ray.air import session

import time

warnings.filterwarnings('ignore')

def catch_error(func):
    '''Push the training error message to the data buffer.'''
    def wrapped_func(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except Exception as e:
            error_message = traceback.format_exc()
            self.logru_logger.error('Detect error: {}, Error Message: {}'.format(e, error_message))
            ray.get(self._data_buffer.update_status.remote(self._traj_id, 'error', error_message))
            self._stop_flag = True
            try:
                customer_uploadTrainLog(self.config["trainId"],
                                        os.path.join(os.path.abspath(self._workspace), "revive.log"),
                                        "train.policy", "fail", self._max_reward, self.config["accessToken"])
            except Exception as e:
                logger.info(f"{e}")
            return {
                'stop_flag': True,
                'reward_trainPolicy_on_valEnv': -np.inf,
            }
    return wrapped_func
class VenvOperator():
    r"""The base venv class."""

    NAME = None  # this needs to be set in any subclass
    r"""Name of the used algorithm."""

    @property
    def metric_name(self):
        r"""This defines the metric we try to minimize with hyperparameter search."""
        return f"{self.NAME}/average_{self.config['venv_metric']}"

    @property
    def nodes_models_train(self):
        return self.train_models[:self.config['learning_nodes_num']]

    @property
    def other_models_train(self):
        return self.train_models[self.config['learning_nodes_num']:]

    @property
    def nodes_models_val(self):
        return self.val_models[:self.config['learning_nodes_num']]

    @property
    def other_models_val(self):
        return self.val_models[self.config['learning_nodes_num']:]

    # NOTE: you need to either write `PARAMETER_DESCRIPTION` or overwrite `get_parameters` and `get_tune_parameters`.
    PARAMETER_DESCRIPTION = []  # a list of dicts describing the parameters of the algorithm
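    # A hedged illustration (not from the original source): `get_tune_parameters`
    # below reads the keys 'name', 'tune', 'search_mode' and 'search_values' from
    # each description dict, so a hypothetical subclass entry might look like:
    #
    #     PARAMETER_DESCRIPTION = [
    #         {'name': 'venv_lr', 'tune': True,
    #          'search_mode': 'continuous', 'search_values': [1e-5, 1e-3]},
    #         {'name': 'venv_hidden_layers', 'tune': True,
    #          'search_mode': 'grid', 'search_values': [2, 3, 4]},
    #     ]
    #
    # Any other keys a real operator uses (types, defaults, abbreviations) are
    # handled by `list2parser` and are not shown here.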
    @classmethod
    def get_parameters(cls, command=None, **kargs):
        parser = list2parser(cls.PARAMETER_DESCRIPTION)
        return parser.parse_known_args(command)[0].__dict__
    @classmethod
    def get_tune_parameters(cls, config : dict, **kargs):
        r"""Use ray.tune to wrap the parameters to be searched."""
        _search_algo = config['venv_search_algo'].lower()
        tune_params = {
            "name": "venv_tune",
            "reuse_actors": config["reuse_actors"],
            "local_dir": config["workspace"],
            "callbacks": get_tune_callbacks(),
            "stop": {"stop_flag": True},
            "verbose": config["verbose"],
        }

        if _search_algo == 'random':
            random_search_config = {}
            for description in cls.PARAMETER_DESCRIPTION:
                if 'tune' in description.keys() and not description["tune"]:
                    continue
                if 'search_mode' in description.keys() and 'search_values' in description.keys():
                    if description['search_mode'] == 'continuous':
                        random_search_config[description['name']] = tune.uniform(*description['search_values'])
                    elif description['search_mode'] == 'grid' or description['search_mode'] == 'discrete':
                        random_search_config[description['name']] = tune.choice(description['search_values'])
            config["total_num_of_trials"] = config['train_venv_trials']
            tune_params['config'] = random_search_config
            tune_params['num_samples'] = config["total_num_of_trials"]
            tune_params['search_alg'] = CustomBasicVariantGenerator()
        elif _search_algo == 'zoopt':
            # from ray.tune.search.zoopt import ZOOptSearch
            from revive.utils.tune_utils import ZOOptSearch
            from zoopt import ValueType

            if config['parallel_num'] == 'auto':
                if config['use_gpu']:
                    num_of_gpu = int(ray.available_resources()['GPU'])
                    num_of_cpu = int(ray.available_resources()['CPU'])
                    parallel_num = min(int(num_of_gpu / config['venv_gpus_per_worker']), num_of_cpu)
                else:
                    num_of_cpu = int(ray.available_resources()['CPU'])
                    parallel_num = num_of_cpu
            else:
                parallel_num = int(config['parallel_num'])
            assert parallel_num > 0
            config['parallel_num'] = parallel_num

            dim_dict = {}
            for description in cls.PARAMETER_DESCRIPTION:
                if 'tune' in description.keys() and not description["tune"]:
                    continue
                if 'search_mode' in description.keys() and 'search_values' in description.keys():
                    if description['search_mode'] == 'continuous':
                        dim_dict[description['name']] = (ValueType.CONTINUOUS,
                                                         description['search_values'],
                                                         min(description['search_values']))
                    elif description['search_mode'] == 'discrete':
                        dim_dict[description['name']] = (ValueType.DISCRETE, description['search_values'])
                    elif description['search_mode'] == 'grid':
                        dim_dict[description['name']] = (ValueType.GRID, description['search_values'])

            zoopt_search_config = {"parallel_num": config['parallel_num']}
            config["total_num_of_trials"] = config['train_venv_trials']
            tune_params['search_alg'] = ZOOptSearch(
                algo="Asracos",  # only Asracos is supported currently
                budget=config["total_num_of_trials"],
                dim_dict=dim_dict,
                metric='least_metric',
                mode="min",
                **zoopt_search_config
            )
            tune_params['config'] = dim_dict
            tune_params['search_alg'] = CustomSearchGenerator(tune_params['search_alg'])  # wrap with our generator
            tune_params['num_samples'] = config["total_num_of_trials"]
        elif 'grid' in _search_algo:
            grid_search_config = {}
            config["total_num_of_trials"] = 1
            for description in cls.PARAMETER_DESCRIPTION:
                if 'tune' in description.keys() and not description["tune"]:
                    continue
                if 'search_mode' in description.keys() and 'search_values' in description.keys():
                    if description['search_mode'] == 'grid':
                        grid_search_config[description['name']] = tune.grid_search(description['search_values'])
                        config["total_num_of_trials"] *= len(description['search_values'])
                    else:
                        warnings.warn(f"Parameter {description['name']} is defined as searchable in `PARAMETER_DESCRIPTION`. " + \
                                      f"However, since grid search does not support search type {description['search_mode']}, the parameter is skipped. " + \
                                      f"If this parameter is important to the performance, you should consider other search algorithms.")
            tune_params['config'] = grid_search_config
            tune_params['num_samples'] = config["total_num_of_trials"]
            tune_params['search_alg'] = CustomBasicVariantGenerator()
        elif 'bayes' in _search_algo:
            from ray.tune.search.bayesopt import BayesOptSearch

            bayes_search_config = {}
            for description in cls.PARAMETER_DESCRIPTION:
                if 'tune' in description.keys() and not description["tune"]:
                    continue
                if 'search_mode' in description.keys() and 'search_values' in description.keys():
                    if description['search_mode'] == 'continuous':
                        bayes_search_config[description['name']] = description['search_values']
                    else:
                        warnings.warn(f"Parameter {description['name']} is defined as searchable in `PARAMETER_DESCRIPTION`. " + \
                                      f"However, since bayesian search does not support search type {description['search_mode']}, the parameter is skipped. " + \
                                      f"If this parameter is important to the performance, you should consider other search algorithms.")
            config["total_num_of_trials"] = config['train_venv_trials']
            tune_params['config'] = bayes_search_config
            tune_params['search_alg'] = BayesOptSearch(bayes_search_config, metric="least_metric", mode="min")
            tune_params['search_alg'] = CustomSearchGenerator(tune_params['search_alg'])  # wrap with our generator
            tune_params['num_samples'] = config["total_num_of_trials"]
        else:
            raise ValueError(f'search algorithm {_search_algo} is not supported!')

        reporter = CLIReporter(parameter_columns=list(tune_params['config'].keys()),
                               max_progress_rows=50,
                               max_report_frequency=10,
                               sort_by_metric=True)
        reporter.add_metric_column("least_metric", representation='least metric loss')
        reporter.add_metric_column("now_metric", representation='current metric loss')
        tune_params["progress_reporter"] = reporter

        return tune_params
    def model_creator(self, config : dict, graph : DesicionGraph):
        r"""Create all the models. The algorithm needs to define models for the nodes to be learned.

        Args:
            :config: configuration parameters

        Return:
            a list of models
        """
        raise NotImplementedError
    def optimizer_creator(self, models : List[torch.nn.Module], config : dict):
        r"""Define optimizers for the created models.

        Args:
            :models: list of all the models
            :config: configuration parameters

        Return:
            a list of optimizers
        """
        raise NotImplementedError
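    # A hedged illustration (not from the original source) of how a subclass might
    # implement the two hooks above. The real operators build distribution-producing
    # networks from the decision-graph description; the plain torch modules, the
    # 'output' dimension key and the 'venv_lr' config key below are assumptions
    # made only for this sketch ('total_dims[...]["input"]' does appear in __init__):
    #
    #     def model_creator(self, config, graph):
    #         models = []
    #         for node_name in graph.keys():
    #             if graph.get_node(node_name).node_type == 'network':
    #                 in_dim = config['total_dims'][node_name]['input']
    #                 out_dim = config['total_dims'][node_name]['output']  # assumed key
    #                 models.append(torch.nn.Sequential(
    #                     torch.nn.Linear(in_dim, 256), torch.nn.ReLU(),
    #                     torch.nn.Linear(256, out_dim)))
    #         return models
    #
    #     def optimizer_creator(self, models, config):
    #         return [torch.optim.Adam(m.parameters(), lr=config.get('venv_lr', 1e-4))
    #                 for m in models]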
    def data_creator(self):
        r"""Create DataLoaders.

        Return:
            (train_loader_train, val_loader_train, train_loader_val, val_loader_val)
        """
        return data_creator(self.config,
                            training_mode='transition',
                            val_horizon=self.config['venv_rollout_horizon'],
                            double=True)
def _setup_componects(self): r'''setup models, optimizers and dataloaders.''' # register data loader for double venv training policy_backbone = self.config['policy_backbone'] transition_backbone = self.config['transition_backbone'] if len(self.graph_train.ts_node_frames): logger.warning("If the network defines frame splicing nodes as inputs for network nodes, \ it is recommended to use 'gru' or 'lstm' backbone, which usually results in better performance.") if self.config["venv_train_dataset_mode"] != "trajectory" and len(self.graph_train.ts_node_frames)==0: if 'lstm' in policy_backbone or 'gru' in policy_backbone or 'rnn' in policy_backbone: self.config["venv_train_dataset_mode"] = "trajectory" logger.warning(f'Discovered the use of policy backbone -> {policy_backbone}, automatically switching to trajectory data loading mode') elif 'lstm' in transition_backbone or 'gru' in transition_backbone or 'rnn' in transition_backbone: self.config["venv_train_dataset_mode"] = "trajectory" logger.warning(f'Discovered the use of transition backbone -> {transition_backbone}, automatically switching to trajectory data loading mode') elif self.config["pre_horzion"] > 0: logger.warning(f'The "pre_horzion" parameter has been configured. Switch dataloader to trajectory mode!') self.config["venv_train_dataset_mode"] = "trajectory" else: pass train_loader_train, val_loader_train, train_loader_val, val_loader_val = self.data_creator() try: self._train_loader_train = train.torch.prepare_data_loader(train_loader_train, move_to_device=False) self._val_loader_train = train.torch.prepare_data_loader(val_loader_train, move_to_device=False) self._train_loader_val = train.torch.prepare_data_loader(train_loader_val, move_to_device=False) self._val_loader_val = train.torch.prepare_data_loader(val_loader_val, move_to_device=False) except: self._train_loader_train = train_loader_train self._val_loader_train = val_loader_train self._train_loader_val = train_loader_val self._val_loader_val = val_loader_val self.train_models = self.model_creator(self.config, self.graph_train) logger.info(f'Move model to {self._device}') for model_index, model in enumerate(self.train_models): try: self.train_models[model_index] = train.torch.prepare_model(model,move_to_device=False).to(self._device) except: self.train_models[model_index] = model.to(self._device) self.train_optimizers = self.optimizer_creator(self.train_models, self.config) self.val_models = self.model_creator(self.config, self.graph_val) for model_index, model in enumerate(self.val_models): try: self.val_models[model_index] = train.torch.prepare_model(model,move_to_device=False).to(self._device) except: self.val_models[model_index] = model.to(self._device) self.val_optimizers = self.optimizer_creator(self.val_models, self.config) def _register_models_to_graph(self, graph : DesicionGraph, models : List[torch.nn.Module]): index = 0 # register policy nodes for node_name in list(graph.keys()): if node_name in graph.transition_map.values(): continue node = graph.get_node(node_name) if node.node_type == 'network': node.set_network(models[index]) index += 1 # register transition nodes for node_name in graph.transition_map.values(): node = graph.get_node(node_name) if node.node_type == 'network': node.set_network(models[index]) index += 1 assert len(models) == index, f'Some models are not registered. Total models: {len(models)}, Registered: {index}.' @catch_error def __init__(self, config : dict): r'''setup everything for training. 
Args: :config: configuration parameters ''' # parse information from config self.config = config self.train_dataset = ray.get(self.config['dataset']) self.val_dataset = ray.get(self.config['val_dataset']) if self.config['rnn_dataset']: self.rnn_train_dataset = ray.get(self.config['rnn_dataset']) self.rnn_val_dataset = ray.get(self.config['rnn_val_dataset']) self._data_buffer = self.config['venv_data_buffer'] self._workspace = self.config["workspace"] self._graph = deepcopy(self.config['graph']) self._filename = os.path.join(self._workspace, "train_venv.json") self._data_buffer.set_total_trials.remote(self.config.get("total_num_of_trials", 1)) self._data_buffer.inc_trial.remote() self._least_metric_train = [np.inf] * len(self._graph.metric_nodes) self._least_metric_val = [np.inf] * len(self._graph.metric_nodes) self.least_val_metric = np.inf self.least_train_metric = np.inf self._acc = 0 self._num_venv_list = [] # get id self._ip = ray._private.services.get_node_ip_address() logger.add(os.path.join(os.path.abspath(self._workspace),"revive.log")) self.logru_logger = logger # set trail seed setup_seed(self.config["global_seed"]) if 'tag' in self.config: # create by tune tag = self.config['tag'] logger.info("{} is running".format(tag)) self._traj_id = int(tag.split('_')[0]) experiment_dir = os.path.join(self._workspace, 'venv_tune') traj_name_list = sorted(os.listdir(experiment_dir),key=lambda x:x[-19:], reverse=True) for traj_name in filter(lambda x: "ReviveLog" in x, traj_name_list): if len(traj_name.split('_')[1]) == 5: # create by random search or grid search id_index = 3 else: id_index = 2 if int(traj_name.split('_')[id_index]) == self._traj_id: self._traj_dir = os.path.join(experiment_dir, traj_name) break else: self._traj_id = 1 self._traj_dir = os.path.join(self._workspace, 'venv_train') update_env_vars("pattern", "venv") # setup constant self._stop_flag = False self._batch_cnt = 0 self._epoch_cnt = 0 self._last_wdist_epoch = 0 self._wdist_id_train = None self._wdist_id_val = None self._wdist_ready_train = False self._wdist_ready_val = False self._wdist_ready = False self._use_gpu = self.config["use_gpu"] and torch.cuda.is_available() self._device = 'cuda' if self._use_gpu else 'cpu' self._speedup = self.config.get("speedup",False) # if "shooting_" in self.config['venv_metric']: # self.config['venv_metric'] = self.config['venv_metric'].replace("shooting_", "rollout_") if self.config['venv_metric'] in ["mae","mse","nll"]: self.config['venv_metric'] = "rollout_" + self.config['venv_metric'] # prepare for training self.graph_train = deepcopy(self._graph) self.graph_val = deepcopy(self._graph) self._setup_componects() # register models to graph self._register_models_to_graph(self.graph_train, self.nodes_models_train) self._register_models_to_graph(self.graph_val, self.nodes_models_val) self.total_dim = 0 for node_name in self._graph.metric_nodes: self.total_dim += self.config['total_dims'][node_name]['input'] self.nodes_dim_name_map = {} for node_name in list(self._graph.nodes) + list(self._graph.leaf): node_dims = [] for node_dim in self._graph.descriptions[node_name]: node_dims.append(list(node_dim.keys())[0]) self.nodes_dim_name_map[node_name] = node_dims logger.info(f"Nodes : {self.nodes_dim_name_map}") self.graph_to_save_train = self.graph_train self.graph_to_save_val = self.graph_val self.best_graph_train = deepcopy(self.graph_train) self.best_graph_val = deepcopy(self.graph_val) if self._traj_dir.endswith("venv_train"): if "venv.pkl" in os.listdir(self._traj_dir): logger.info("Find 
existing checkpoint, Back up existing model.") try: self._traj_dir_bak = self._traj_dir+"_bak" self._filename_bak = self._filename+"_bak" shutil.copytree(self._traj_dir, self._traj_dir_bak ) shutil.copy(self._filename, self._filename_bak) except: pass self.best_graph_train_info = {} self.best_graph_val_info = [] self._save_models(self._traj_dir) self._update_metric() self.global_step = 0
    def nan_in_grad(self):
        if hasattr(self, "nums_nan_in_grad"):
            self.nums_nan_in_grad += 1
        else:
            self.nums_nan_in_grad = 1
        if self.nums_nan_in_grad > 100:
            self._stop_flag = True
            logger.warning('Found too many nan values in the loss. Early stop.')
def _early_stop(self, info : dict): info["stop_flag"] = self._stop_flag return info def _update_metric(self): ''' update metric to data buffer ''' # NOTE: Very large model (e.g. 1024 x 4 LSTM) cannot be updated. try: venv_train = torch.load(os.path.join(self._traj_dir, 'venv_train.pt'), map_location='cpu') venv_val = torch.load(os.path.join(self._traj_dir, 'venv_val.pt'), map_location='cpu') except: venv_train = venv_val = None try: metric = self.least_metric except: metric = np.sum(self._least_metric_train) / self.total_dim if 'mse' in self.metric_name or 'mae' in self.metric_name: acc = (self.config['max_distance'] - np.log(metric)) / (self.config['max_distance'] - self.config['min_distance']) else: acc = (self.config['max_distance'] - metric) / (self.config['max_distance'] - self.config['min_distance']) acc = min(max(acc,0),1) acc = 0.5 + (0.4*acc) if acc == 0.5: acc = 0 self._acc = acc self._data_buffer.update_metric.remote(self._traj_id, { "metric": metric, "acc" : acc, "ip": self._ip, "venv_train" : venv_train, "venv_val" : venv_val, "traj_dir" : self._traj_dir, "epoch" : self._epoch_cnt, }) self._data_buffer.write.remote(self._filename) # if metric == ray.get(self._data_buffer.get_least_metric.remote()): # self._save_models(self._workspace, with_env=False) def _save_models(self, path: str, with_env:bool=True, save_best_graph:bool=True, model_prefixes:str=""): """ param: path, where to save the models param: with_env, whether to save venv along with the models """ if model_prefixes: model_prefixes = model_prefixes + "_" if save_best_graph: # logger.info(f"Save the best model: {os.path.join(path, model_prefixes + 'venv.pkl')} ") graph_train = deepcopy(self.best_graph_train).to("cpu") graph_val = deepcopy(self.best_graph_val).to("cpu") else: # logger.info(f"Save the latest model: {os.path.join(path, model_prefixes + 'venv.pkl')} ") graph_train = deepcopy(self.graph_train).to("cpu") graph_val = deepcopy(self.graph_val).to("cpu") graph_train.reset() graph_val.reset() # Save train model for checkpoint if save_best_graph: # Save train model for node_name in graph_train.keys(): node = graph_train.get_node(node_name) if node.node_type == 'network': network = deepcopy(node.get_network()).cpu() torch.save(network, os.path.join(path, node_name + '_train.pt')) # Save val model for node_name in graph_val.keys(): node = graph_val.get_node(node_name) if node.node_type == 'network': network = deepcopy(node.get_network()).cpu() torch.save(network, os.path.join(path, node_name + '_val.pt')) if with_env: venv_train = VirtualEnvDev(graph_train, train_algo=self.NAME, info=self.best_graph_train_info) torch.save(venv_train, os.path.join(path, "venv_train.pt")) venv_val = VirtualEnvDev(graph_val, train_algo=self.NAME, info=self.best_graph_val_info) torch.save(venv_val, os.path.join(path, "venv_val.pt")) venv = VirtualEnv([venv_train, venv_val]) with open(os.path.join(path, model_prefixes + 'venv.pkl'), 'wb') as f: pickle.dump(venv, f) venv_list = ray.get(self._data_buffer.get_best_venv.remote()) with open(os.path.join(path, model_prefixes +'ensemble_env.pkl'), 'wb') as f: pickle.dump(venv_list, f) def _load_best_models(self): best_graph_train = deepcopy(self.graph_train) best_graph_val = deepcopy(self.graph_val) for node_name in best_graph_train.keys(): best_node = best_graph_train.get_node(node_name) current_node = self.graph_train.get_node(node_name) if best_node.node_type == 'network': current_node.get_network().load_state_dict(best_node.get_network().state_dict()) for node_name in best_graph_val.keys(): 
best_node = best_graph_val.get_node(node_name) current_node = self.graph_val.get_node(node_name) if best_node.node_type == 'network': current_node.get_network().load_state_dict(best_node.get_network().state_dict()) @torch.no_grad() def _log_histogram(self, expert_data, generated_data, scope='valEnv_on_trainData'): if scope == 'valEnv_on_trainData': graph = self.graph_val else: graph = self.graph_train info = {} for node_name in graph.keys(): # compute values if graph.get_node(node_name).node_type == 'network': node_dist = graph.compute_node(node_name, expert_data) expert_action = expert_data[node_name] generated_action = node_dist.sample() policy_std = node_dist.std else: expert_action = expert_data[node_name] generated_action = graph.compute_node(node_name, expert_data) policy_std = torch.zeros_like(generated_action) # make logs error = expert_action - generated_action expert_action = expert_action.cpu() generated_action = generated_action.cpu() error = error.cpu() policy_std = policy_std.cpu() for i in range(error.shape[-1]): info[f'{node_name}_dim{i}_{scope}/error'] = error.select(-1, i) info[f'{node_name}_dim{i}_{scope}/expert'] = expert_action.select(-1, i) info[f'{node_name}_dim{i}_{scope}/sampled'] = generated_action.select(-1, i) info[f'{node_name}_dim{i}_{scope}/sampled_std'] = policy_std.select(-1, i) return info def _env_test(self, scope : str = 'train'): # pick target policy graph = self.graph_train if scope == 'train' else self.graph_val node = graph.get_node(self.config['target_policy_name']) env = create_env(self.config['task']) if env is None: return {} graph.reset() node = deepcopy(node) node = node.to('cpu') policy = PolicyModelDev(node) policy = PolicyModel(policy) reward, length = test_on_real_env(env, policy) return { f"{self.NAME}/real_reward_{scope}" : reward, f"{self.NAME}/real_length_{scope}" : length, } def _mse_test(self, expert_data, generated_data, scope='valEnv_on_trainData', loss_mask=None): info = {} if not self.config['mse_test'] and not 'mse' in self.metric_name: return info if 'mse' in self.metric_name: self.graph_to_save_train = self.graph_train self.graph_to_save_val = self.graph_val if scope == 'valEnv_on_trainData': graph = self.graph_val else: graph = self.graph_train graph.reset() if loss_mask is None: loss_mask = 1. new_data = Batch({name : expert_data[name] for name in graph.leaf}) total_mse = 0 for node_name in graph.keys(): if node_name + "_isnan_index_" in expert_data.keys(): isnan_index = 1 - torch.mean(expert_data[node_name + "_isnan_index_"]) else: isnan_index = 1. if graph.get_node(node_name).node_type == 'network': node_dist = graph.compute_node(node_name, new_data) new_data[node_name] = node_dist.mode else: new_data[node_name] = graph.compute_node(node_name, new_data) continue if node_name in graph.metric_nodes: # if isnan_index is not None: node_mse = (((new_data[node_name] - expert_data[node_name])*isnan_index*loss_mask) ** 2).sum(dim=-1).mean() # else: # node_mse = ((new_data[node_name] - expert_data[node_name]) ** 2).sum(dim=-1).mean() total_mse += node_mse.item() info[f"{self.NAME}/{node_name}_one_step_mse_{scope}"] = node_mse.item() info[f"{self.NAME}/average_one_step_mse_{scope}"] = total_mse / self.total_dim mse_error = 0 for node_name in graph.metric_nodes: if node_name + "_isnan_index_" in expert_data.keys(): isnan_index = 1 - torch.mean(expert_data[node_name + "_isnan_index_"]) else: isnan_index = 1. 
# if isnan_index is not None: policy_rollout_mse = (((expert_data[node_name] - generated_data[node_name])*isnan_index*loss_mask) ** 2).sum(dim=-1).mean() # else: # policy_rollout_mse = ((expert_data[node_name] - generated_data[node_name]) ** 2).sum(dim=-1).mean() mse_error += policy_rollout_mse.item() info[f"{self.NAME}/{node_name}_rollout_mse_{scope}"] = policy_rollout_mse.item() info[f"{self.NAME}/average_rollout_mse_{scope}"] = mse_error / self.total_dim return info def _nll_test(self, expert_data, generated_data, scope='valEnv_on_trainData', loss_mask=None): info = {} if not self.config['nll_test'] and not 'nll' in self.metric_name: return info if 'nll' in self.metric_name: self.graph_to_save_train = self.graph_train self.graph_to_save_val = self.graph_val if scope == 'valEnv_on_trainData': graph = self.graph_val else: graph = self.graph_train graph.reset() if loss_mask is None: loss_mask = 1. new_data = Batch({name : expert_data[name] for name in graph.leaf}) total_nll = 0 for node_name in graph.keys(): if node_name + "_isnan_index_" in expert_data.keys(): isnan_index = 1 - torch.mean(expert_data[node_name + "_isnan_index_"],axis=-1) else: isnan_index = 1. if node_name in graph.learnable_node_names and node_name in graph.metric_nodes: node_dist = graph.compute_node(node_name, new_data) new_data[node_name] = node_dist.mode else: new_data[node_name] = expert_data[node_name] continue # if isnan_index is not None: node_nll = - (node_dist.log_prob(expert_data[node_name])*isnan_index*loss_mask.squeeze(-1)).mean() # else: # node_nll = - node_dist.log_prob(expert_data[node_name]).mean() total_nll += node_nll.item() info[f"{self.NAME}/{node_name}_one_step_nll_{scope}"] = node_nll.item() info[f"{self.NAME}/average_one_step_nll_{scope}"] = total_nll / self.total_dim total_nll = 0 for node_name in graph.metric_nodes: if node_name + "_isnan_index_" in expert_data.keys(): isnan_index = 1 - expert_data[node_name + "_isnan_index_"] else: isnan_index = 1. if node_name in graph.learnable_node_names: node_dist = graph.compute_node(node_name, expert_data) # if isnan_index is not None: policy_nll = - (node_dist.log_prob(expert_data[node_name])) if isinstance(isnan_index, torch.Tensor) and isnan_index.shape != policy_nll.shape: isnan_index = isnan_index.reshape(*policy_nll.shape) policy_nll = - (node_dist.log_prob(expert_data[node_name])*isnan_index*loss_mask.squeeze(-1)).mean() # else: # policy_nll = - node_dist.log_prob(expert_data[node_name]).mean() total_nll += policy_nll.item() info[f"{self.NAME}/{node_name}_rollout_nll_{scope}"] = policy_nll.item() else: total_nll += 0 info[f"{self.NAME}/{node_name}_rollout_nll_{scope}"] = 0 info[f"{self.NAME}/average_rollout_nll_{scope}"] = total_nll / self.total_dim return info def _mae_test(self, expert_data, generated_data, scope='valEnv_on_trainData', loss_mask=None): info = {} if not self.config['mae_test'] and not 'mae' in self.metric_name: return info if 'mae' in self.metric_name: self.graph_to_save_train = self.graph_train self.graph_to_save_val = self.graph_val if scope == 'valEnv_on_trainData': graph = self.graph_val else: graph = self.graph_train graph.reset() if loss_mask is None: loss_mask = 1. new_data = Batch({name : expert_data[name] for name in graph.leaf}) total_mae = 0 for node_name in graph.keys(): if node_name + "_isnan_index_" in expert_data.keys(): isnan_index = 1 - torch.mean(expert_data[node_name + "_isnan_index_"]) else: isnan_index = 1. 
if graph.get_node(node_name).node_type == 'network': node_dist = graph.compute_node(node_name, new_data) new_data[node_name] = node_dist.mode else: new_data[node_name] = graph.compute_node(node_name, new_data) continue if node_name in graph.metric_nodes: # if isnan_index is not None: node_mae = ((new_data[node_name] - expert_data[node_name])*isnan_index*loss_mask).abs().sum(dim=-1).mean() # else: # node_mae = (new_data[node_name] - expert_data[node_name]).abs().sum(dim=-1).mean() total_mae += node_mae.item() info[f"{self.NAME}/{node_name}_one_step_mae_{scope}"] = node_mae.item() info[f"{self.NAME}/average_one_step_mae_{scope}"] = total_mae / self.total_dim mae_error = 0 for node_name in graph.keys(): if node_name in graph.metric_nodes: if node_name + "_isnan_index_" in expert_data.keys(): isnan_index = 1 - torch.mean(expert_data[node_name + "_isnan_index_"]) else: isnan_index = 1. # if isnan_index is not None: policy_shooting_error = (torch.abs(expert_data[node_name] - generated_data[node_name])*isnan_index*loss_mask).sum(dim=-1).mean() # else: # policy_shooting_error = torch.abs(expert_data[node_name] - generated_data[node_name]).sum(dim=-1).mean() mae_error += policy_shooting_error.item() info[f"{self.NAME}/{node_name}_rollout_mae_{scope}"] = policy_shooting_error.item() # TODO: plot rollout error # rollout_error = torch.abs(expert_data[node_name] - generated_data[node_name]).reshape(expert_data.shape[0],-1).mean(dim=-1) info[f"{self.NAME}/average_rollout_mae_{scope}"] = mae_error / self.total_dim return info def _wdist_test(self, expert_data, generated_data, scope='valEnv_on_trainData', loss_mask=None): info = {} if not self.config['wdist_test'] and not 'wdist' in self.metric_name: return info if 'wdist' in self.metric_name: self.graph_to_save_train = self.graph_train self.graph_to_save_val = self.graph_val if scope == 'valEnv_on_trainData': graph = self.graph_val else: graph = self.graph_train graph.reset() if loss_mask is None: loss_mask = 1. wdist_error = [] for node_name in graph.keys(): if node_name in graph.metric_nodes: # TODO: support isnan_index node_dim = expert_data[node_name].shape[-1] if node_name + "_isnan_index_" in expert_data.keys(): isnan_index = 1 - torch.mean(expert_data[node_name + "_isnan_index_"],axis=-1) else: isnan_index = None wdist = [stats.wasserstein_distance(expert_data[node_name].reshape(-1, expert_data[node_name].shape[-1])[..., dim].cpu().numpy(), generated_data[node_name].reshape(-1, generated_data[node_name].shape[-1])[..., dim].cpu().numpy()) for dim in range(node_dim)] wdist = np.sum(wdist) info[f"{self.NAME}/{node_name}_wdist_{scope}"] = wdist wdist_error.append(wdist) info[f"{self.NAME}/average_wdist_{scope}"] = np.sum(wdist_error) / self.total_dim return info
    @catch_error
    def before_train_epoch(self):
        # Update env vars
        update_env_vars("venv_epoch", self._epoch_cnt)
        for model in self.train_models:
            model.train()
        for model in self.val_models:
            model.train()
    @catch_error
    def after_train_epoch(self):
        pass
    @catch_error
    def before_validate_epoch(self):
        for model in self.train_models:
            model.eval()
        for model in self.val_models:
            model.eval()
    @catch_error
    def after_validate_epoch(self):
        pass
    @catch_error
    def train_epoch(self):
        info = dict()
        self._epoch_cnt += 1
        logger.info(f"Venv Train epoch : {self._epoch_cnt} ")

        def algo_on_train_data(self):
            # update the `train` venv on the training split
            metric_meters_train = AverageMeterCollection()
            for batch_idx, batch in enumerate(iter(self._train_loader_train)):
                batch_info = {"batch_idx": batch_idx, "global_step": self.global_step}
                batch_info.update(info)
                batch.to_torch(device=self._device)
                metrics = self.train_batch(batch, batch_info=batch_info, scope='train')
                metric_meters_train.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
                self.global_step += 1
            torch.cuda.empty_cache()
            return metric_meters_train

        def algo_on_val_data(self):
            # update the `val` venv on the validation split
            metric_meters_val = AverageMeterCollection()
            for batch_idx, batch in enumerate(iter(self._val_loader_train)):
                batch_info = {"batch_idx": batch_idx, "global_step": self.global_step}
                batch_info.update(info)
                batch.to_torch(device=self._device)
                metrics = self.train_batch(batch, batch_info=batch_info, scope='val')
                metric_meters_val.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
                self.global_step += 1
            torch.cuda.empty_cache()
            return metric_meters_val

        if self._speedup:
            with futures.ThreadPoolExecutor() as executor:
                algo_on_train_data_result = executor.submit(algo_on_train_data, self)
                algo_on_val_data_result = executor.submit(algo_on_val_data, self)
                futures.wait([algo_on_train_data_result, algo_on_val_data_result])
                metric_meters_train = algo_on_train_data_result.result()
                metric_meters_val = algo_on_val_data_result.result()
        else:
            metric_meters_train = algo_on_train_data(self)
            metric_meters_val = algo_on_val_data(self)

        info = metric_meters_train.summary()
        info.update(metric_meters_val.summary())

        return {k: info[k] for k in filter(lambda k: not k.startswith('last'), info.keys())}
[docs] @catch_error def validate_epoch(self): # logger.info(f"Validate Epoch : {self._epoch_cnt} ") info = dict() self.logged_histogram = { 'valEnv_on_trainData' : False, 'trainEnv_on_valData' : False, } def train_model_validate_on_val_data(self): with torch.no_grad(): metric_meters_train = AverageMeterCollection() for batch_idx, batch in enumerate(iter(self._train_loader_val)): if self.config['rnn_dataset']: batch, loss_mask = batch else: loss_mask = torch.tensor([1]) loss_mask = loss_mask.to(self._device) batch_info = {"batch_idx": batch_idx} batch_info.update(info) metrics = self.validate_batch(batch, batch_info, 'valEnv_on_trainData', loss_mask=loss_mask) metric_meters_train.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) if batch_idx > 128: break torch.cuda.empty_cache() return metric_meters_train def val_model_validate_on_train_data(self): with torch.no_grad(): metric_meters_val = AverageMeterCollection() for batch_idx, batch in enumerate(iter(self._val_loader_val)): if self.config['rnn_dataset']: batch, loss_mask = batch else: loss_mask = torch.tensor([1]) loss_mask = loss_mask.to(self._device) batch_info = {"batch_idx": batch_idx} batch_info.update(info) metrics = self.validate_batch(batch, batch_info, 'trainEnv_on_valData', loss_mask=loss_mask) metric_meters_val.update(metrics, n=metrics.pop(NUM_SAMPLES, 1)) if batch_idx > 128: break torch.cuda.empty_cache() return metric_meters_val if self._speedup: with futures.ThreadPoolExecutor() as executor: train_model_validate_on_val_data_result = executor.submit(train_model_validate_on_val_data, self) val_model_validate_on_train_data_result = executor.submit(val_model_validate_on_train_data, self) futures.wait([train_model_validate_on_val_data_result, val_model_validate_on_train_data_result]) metric_meters_train = train_model_validate_on_val_data_result.result() metric_meters_val = val_model_validate_on_train_data_result.result() else: metric_meters_train = train_model_validate_on_val_data(self) metric_meters_val = val_model_validate_on_train_data(self) info = metric_meters_train.summary() info.update(metric_meters_val.summary()) info = {k : info[k] for k in filter(lambda k: not k.startswith('last'), info.keys())} # Run test on real environment if (not self.config['real_env_test_frequency'] == 0) and \ self._epoch_cnt % self.config['real_env_test_frequency'] == 0: info.update(self._env_test('train')) info.update(self._env_test('val')) # Save model by metric """ # Save the model by node metric if self._epoch_cnt >= self.config['save_start_epoch']: need_update = [] for i, node_name in enumerate(self.graph_train.metric_nodes): i = self.graph_train.metric_nodes.index(node_name) if info[f"{self.NAME}/{node_name}_{self.config['venv_metric']}_trainEnv_on_valData"] < self._least_metric_train[i]: self._least_metric_train[i] = info[f"{self.NAME}/{node_name}_{self.config['venv_metric']}_trainEnv_on_valData"] self.best_graph_train.nodes[node_name] = deepcopy(self.graph_to_save_train.get_node(node_name)) need_update.append(True) else: need_update.append(False) for i, node_name in enumerate(self.graph_val.metric_nodes): i = self.graph_val.metric_nodes.index(node_name) if info[f"{self.NAME}/{node_name}_{self.config['venv_metric']}_valEnv_on_trainData"] < self._least_metric_val[i]: self._least_metric_val[i] = info[f"{self.NAME}/{node_name}_{self.config['venv_metric']}_valEnv_on_trainData"] self.best_graph_val.nodes[node_name] = deepcopy(self.graph_to_save_val.get_node(node_name)) need_update.append(True) else: need_update.append(False) if 
self.config["save_by_node"]: info["least_metric"] = np.sum(self._least_metric_train) / self.total_dim self.least_metric = info["least_metric"] if True in need_update: self._save_models(self._traj_dir) self._update_metric() """ info["stop_flag"] = self._stop_flag info['now_metric'] = info[f'{self.metric_name}_trainEnv_on_valData'] info['now_val_metric'] = info[f'{self.metric_name}_valEnv_on_trainData'] if not self.config["save_by_node"]: # now_val_metric = info[f'{self.metric_name}_valEnv_on_trainData'] # Save the best model if self._epoch_cnt >= self.config['save_start_epoch']: if info["now_metric"] <= self.least_train_metric or info["now_val_metric"] <= self.least_val_metric: if info["now_val_metric"] <= self.least_val_metric: self.least_val_metric = info['now_val_metric'] self.best_graph_val = deepcopy(self.graph_val) self.best_graph_val_info = { "id" : self._traj_id, "epoch" : self._epoch_cnt, } if info["now_metric"] <= self.least_train_metric: self.least_train_metric = info["now_metric"] self.best_graph_train = deepcopy(self.graph_train) self.best_graph_train_info = { "id" : self._traj_id, "epoch" : self._epoch_cnt, } self._save_models(self._traj_dir, save_best_graph=True) self._update_metric() info["least_metric"] = self.least_train_metric self.least_metric = self.least_train_metric # Periodic preservation model if self.config["venv_save_frequency"]: if self._epoch_cnt % self.config["venv_save_frequency"] == 0: self._save_models(self._traj_dir, save_best_graph=False, model_prefixes = str(self._epoch_cnt)) # Save the k env for every task _k = 1 while self._epoch_cnt % _k == 0: if len(self._num_venv_list) < self.config['num_venv_store']: self._num_venv_list.append(info["now_metric"]) elif info["now_metric"] < np.max(self._num_venv_list): _del_index = np.argmax(self._num_venv_list) self._data_buffer.delet_deque_item.remote(self._traj_id, _del_index) self._num_venv_list.pop(_del_index) self._num_venv_list.append(info["now_metric"]) else: break _graph_train = deepcopy(self.graph_train).to("cpu") _graph_val = deepcopy(self.graph_val).to("cpu") _graph_train.reset() _graph_val.reset() _venv_train = VirtualEnvDev(_graph_train, train_algo=self.NAME) _venv_val = VirtualEnvDev(_graph_val, train_algo=self.NAME) self._data_buffer.update_venv_deque_dict.remote(self._traj_id, _venv_train, _venv_val) break # Record validation metrics for k in list(info.keys()): if self.NAME in k: v = info.pop(k) info['VAL_' + k] = v if self._epoch_cnt >= self._total_epoch: self._stop_flag = True info["stop_flag"] = self._stop_flag '''plot histogram when training is finished''' # [ OTHER ] more frequent valuation if self._stop_flag or self._epoch_cnt % self.config["rollout_plt_frequency"] == 0: if self.config["rollout_plt_frequency"] > 0: histogram_path = os.path.join(self._traj_dir, 'histogram') if not os.path.exists(histogram_path): os.makedirs(histogram_path) try: save_histogram(histogram_path, self.best_graph_train, self._train_loader_val, device=self._device, scope='train', rnn_dataset=self.config['rnn_dataset']) save_histogram(histogram_path, self.best_graph_val, self._val_loader_val, device=self._device, scope='val', rnn_dataset=self.config['rnn_dataset']) if self.config['rnn_dataset']: if self.config["rollout_dataset_mode"] == "validate": rollout_dataset = self.rnn_val_dataset else: rollout_dataset = self.rnn_train_dataset else: if self.config["rollout_dataset_mode"] == "validate": rollout_dataset = self.val_dataset else: rollout_dataset = self.train_dataset # save rolllout action image rollout_save_path = 
os.path.join(self._traj_dir, 'rollout_images') nodes_map = deepcopy(self.nodes_dim_name_map) # del step_node if "step_node_" in nodes_map.keys(): nodes_map.pop("step_node_") save_rollout_action(rollout_save_path, self.best_graph_train, self._device, rollout_dataset, deepcopy(nodes_map), rollout_plt_length=self.config.get("rollout_plt_length",None), pre_horzion=self.config["pre_horzion"], train_type='Venv') rollout_kwargs = { "graph": self.best_graph_train, "rollout_save_path": rollout_save_path, "device": self._device, "nodes_map": nodes_map, "pre_horizon": self.config["pre_horzion"], "rollout_plt_length": self.config.get("rollout_plt_length",None), } with open(os.path.join(self._traj_dir, 'rollout_kwargs.pkl'), 'wb') as file: pickle.dump(rollout_kwargs, file) # [ OTHER ] not only plotting the best model, but also plotting the result of the current model rollout_save_path = os.path.join(self._traj_dir, 'rollout_images_current') save_rollout_action(rollout_save_path, self.graph_train, self._device, rollout_dataset, deepcopy(nodes_map), rollout_plt_length=self.config.get("rollout_plt_length",None), pre_horzion=self.config["pre_horzion"], train_type='Venv') except Exception as e: logger.warning(e) else: logger.info("Don't plot images.") # Attempt to load trained model parameters as initialization parameters if self._epoch_cnt == 1: try: info = self._load_checkpoint(info) except: logger.info("Don't Load checkpoint!") if hasattr(self, "_traj_dir_bak") and os.path.exists(self._traj_dir_bak): shutil.rmtree(self._traj_dir_bak) os.remove(self._filename_bak) if self._stop_flag: # Draw a response curve images if self.config["plt_response_curve"]: response_curve_path = os.path.join(self._traj_dir, 'response_curve') if not os.path.exists(response_curve_path): os.makedirs(response_curve_path) dataset = ray.get(self.config['dataset']) plot_response_curve(response_curve_path, self.best_graph_train, self.best_graph_val, dataset=dataset.data, device=self._device) try: customer_uploadTrainLog(self.config["trainId"], os.path.join(os.path.abspath(self._workspace),"revive.log"), "train.simulator", "success", self._acc, self.config["accessToken"]) except Exception as e: error_message = traceback.format_exc() error_message = "" logger.info('Detect warning:{}, Warning Message: {}'.format(e,error_message)) return info
def _load_checkpoint(self,info): if self._traj_dir.endswith("venv_train"): self._load_models(self._traj_dir_bak) with open(self._filename_bak, 'r') as f: train_log = json.load(f) metric = train_log["metrics"]["1"] venv_train = torch.load(os.path.join(self._traj_dir_bak, 'venv_train.pt'), map_location='cpu') venv_val = torch.load(os.path.join(self._traj_dir_bak, 'venv_val.pt'), map_location='cpu') self._data_buffer.update_metric.remote(self._traj_id, { "metric": metric["metric"], "acc" : metric["acc"], "ip": self._ip, "venv_train" : venv_train, "venv_val" : venv_val, "traj_dir" : self._traj_dir, }) self._data_buffer.write.remote(self._filename) self._save_models(self._workspace, with_env=False) return info else: with open(os.path.join(self._traj_dir,"params.json"), 'r') as f: params = json.load(f) experiment_dir = os.path.dirname(self._traj_dir) dir_name_list = [dir_name for dir_name in os.listdir(experiment_dir) if dir_name.startswith("ReviveLog")] dir_name_list = sorted(dir_name_list,key=lambda x:x[-19:], reverse=False) for dir_name in dir_name_list: dir_path = os.path.join(experiment_dir, dir_name) if dir_path == self._traj_dir: break if os.path.isdir(dir_path): params_json_path = os.path.join(dir_path, "params.json") if os.path.exists(params_json_path): with open(params_json_path, 'r') as f: history_params = json.load(f) if history_params == params: result_json_path = os.path.join(dir_path, "result.json") with open(result_json_path, 'r') as f: history_result = [] for line in f.readlines(): history_result.append(json.loads(line)) if history_result:# and history_result[-1]["stop_flag"]: for k in info.keys(): if k in history_result[-1].keys(): info[k] = history_result[-1][k] self.least_metric = info["least_metric"] # load model logger.info("Find exist checkpoint, Load the model.") self._load_models(dir_path) self._save_models(self._traj_dir) self._update_metric() # check early stop self._stop_flag = history_result[-1]["stop_flag"] info["stop_flag"] = self._stop_flag logger.info("Load checkpoint success!") break return info def _load_models(self, path : str, with_env : bool = True): """ param: path, where to load the models param: with_env, whether to load venv along with the models """ for node_name in self.best_graph_train.keys(): best_node = self.best_graph_train.get_node(node_name) if best_node.node_type == 'network': best_node.get_network().load_state_dict(torch.load(os.path.join(path, node_name + '_train.pt')).state_dict()) for node_name in self.best_graph_val.keys(): best_node = self.best_graph_val.get_node(node_name) if best_node.node_type == 'network': best_node.get_network().load_state_dict(torch.load(os.path.join(path, node_name + '_train.pt')).state_dict())
    def train_batch(self, expert_data, batch_info, scope='train'):
        r"""Define the training process for a batch of data."""
        raise NotImplementedError
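    # A hedged sketch (not part of this module) of what a subclass's `train_batch`
    # might look like: the operator is expected to return a dict of scalar metrics,
    # optionally including NUM_SAMPLES so AverageMeterCollection can weight the
    # running averages. The single-node NLL loss below is purely illustrative.
    #
    #     def train_batch(self, expert_data, batch_info, scope='train'):
    #         graph = self.graph_train if scope == 'train' else self.graph_val
    #         optimizers = self.train_optimizers if scope == 'train' else self.val_optimizers
    #         node_name = graph.learnable_node_names[0]
    #         dist = graph.compute_node(node_name, expert_data)      # forward pass
    #         loss = -dist.log_prob(expert_data[node_name]).mean()   # illustrative NLL loss
    #         optimizers[0].zero_grad()
    #         loss.backward()
    #         optimizers[0].step()
    #         return {f"{self.NAME}/{node_name}_nll_{scope}": loss.item(),
    #                 NUM_SAMPLES: expert_data.shape[0]}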
    def validate_batch(self, expert_data, batch_info, scope='valEnv_on_trainData', loss_mask=None):
        r"""Define the validation process for a batch of data.

        Args:
            expert_data: the batch of offline data.
            batch_info: a batch info dict.
            scope: ``scope='valEnv_on_trainData'`` means the training data is tested on the model trained with the validation dataset.
        """
        info = {}

        if scope == 'valEnv_on_trainData':
            graph = self.graph_val
        else:
            graph = self.graph_train
        graph.reset()

        expert_data.to_torch(device=self._device)
        traj_length = expert_data.shape[0]

        # data is generated by taking the most likely action
        sample_fn = lambda dist: dist.mode
        generated_data = generate_rollout(expert_data, graph, traj_length, sample_fn, clip=True)

        if self.config["pre_horzion"] > 0:
            expert_data = expert_data[self.config["pre_horzion"]:]
            generated_data = generated_data[self.config["pre_horzion"]:]

        for node_name in graph.nodata_node_names:
            assert node_name not in expert_data.keys()
            assert node_name in generated_data.keys()
            expert_data[node_name] = generated_data[node_name]

        info.update(self._nll_test(expert_data, generated_data, scope, loss_mask=loss_mask))
        info.update(self._mse_test(expert_data, generated_data, scope, loss_mask=loss_mask))
        info.update(self._mae_test(expert_data, generated_data, scope, loss_mask=loss_mask))
        # info.update(self._wdist_test(expert_data, generated_data, scope))

        # log histogram info
        if (not self.config['histogram_log_frequency'] == 0) and \
           self._epoch_cnt % self.config['histogram_log_frequency'] == 0 and \
           not self.logged_histogram[scope]:
            info.update(self._log_histogram(expert_data, generated_data, scope))
            self.logged_histogram[scope] = True

        return info
class VenvAlgorithm:
    '''Class used to manage venv algorithms.'''

    def __init__(self, algo : str, workspace: str = None):
        self.algo = algo
        self.workspace = workspace

        if self.algo == "revive" or self.algo == "revive_p" or self.algo == "revivep" or self.algo == "revive_ppo":
            self.algo = "revive_p"
        elif self.algo == "revive_f" or self.algo == "revivef" or self.algo == "revive_filter":
            self.algo = "revive_f"
        elif self.algo == "revive_t" or self.algo == "revivet" or self.algo == "revive_td3":
            self.algo = "revive_t"
        elif self.algo == "bc":
            self.algo = "bc"
        else:
            raise NotImplementedError

        try:
            self.algo_module = importlib.import_module(f'revive.dist.algo.venv.{self.algo}')
            logger.info(f"Import encrypted venv algorithm module -> {self.algo}!")
        except:
            self.algo_module = importlib.import_module(f'revive.algo.venv.{self.algo}')
            logger.info(f"Import venv algorithm module -> {self.algo}!")

        # Assume there is only one operator other than VenvOperator
        for k in self.algo_module.__dir__():
            if 'Operator' in k and not k == 'VenvOperator':
                self.operator = getattr(self.algo_module, k)

        self.operator_config = {}
    def get_train_func(self, config={}):
        # The training function to execute on each worker
        def train_func(config):
            for k, v in self.operator_config.items():
                if not k in config.keys():
                    config[k] = v
            algo_operator = self.operator(config)
            venv_val_freq = config.get("venv_val_freq", 1)
            writer = SummaryWriter(algo_operator._traj_dir)

            while True:
                # Train an epoch
                algo_operator.before_train_epoch()
                train_stats = algo_operator.train_epoch()
                algo_operator.after_train_epoch()
                epoch = algo_operator._epoch_cnt

                # Validate an epoch; the validation frequency can be set via config["venv_val_freq"]
                if (epoch - 1) % venv_val_freq == 0 or epoch == 1:
                    with torch.no_grad():
                        algo_operator.before_validate_epoch()
                        val_stats = algo_operator.validate_epoch()
                        algo_operator.after_validate_epoch()

                    # Report model metrics to ray
                    session.report({"mean_accuracy": algo_operator._acc,
                                    "least_metric": val_stats["least_metric"],
                                    "now_metric": val_stats["now_metric"],
                                    "stop_flag": val_stats["stop_flag"]})

                    # Write logs to tensorboard
                    for k, v in [*train_stats.items(), *val_stats.items()]:
                        if type(v) in VALID_SUMMARY_TYPES:
                            writer.add_scalar(k, v, global_step=epoch)
                        elif isinstance(v, torch.Tensor):
                            v = v.view(-1)
                            writer.add_histogram(k, v, global_step=epoch)
                    writer.flush()

                # Check stop flag
                train_stats.update(val_stats)
                if train_stats.get('stop_flag', False):
                    break

                torch.cuda.empty_cache()

        return train_func
    def get_trainer(self, config):
        try:
            train_func = self.get_train_func(config)
            trainer_resources = None
            """
            try:
                cluster_resources = ray.cluster_resources()
                num_gpus = cluster_resources["GPU"]
                num_cpus = cluster_resources["CPU"]
                num_gpus_blocks = int(1/config['venv_gpus_per_worker']) * num_gpus
                resources_per_trial_cpu = max(int(num_cpus/num_gpus_blocks), 1)
                trainer_resources = {"GPU": config['venv_gpus_per_worker'], "CPU": resources_per_trial_cpu}
            except:
                trainer_resources = None
            """
            trainer = TorchTrainer(
                train_func,
                train_loop_config=config,
                scaling_config=ScalingConfig(trainer_resources=trainer_resources,
                                             num_workers=config['workers_per_trial'],
                                             use_gpu=config['use_gpu']),
            )
            return trainer
        except Exception as e:
            logger.error('Detect Error: {}'.format(e))
            raise e
    def get_trainable(self, config):
        try:
            trainer = self.get_trainer(config)
            # Convert the trainer to a tune.Trainable object
            trainable = trainer.as_trainable()
            return trainable
        except Exception as e:
            logger.error('Detect Error: {}'.format(e))
            raise e
    def get_parameters(self, command=None):
        try:
            return self.operator.get_parameters(command)
        except AttributeError:
            raise AttributeError("Custom algorithms need to implement `get_parameters`")
        except Exception as e:
            logger.error('Detect Unknown Error: {}'.format(e))
            raise e
    def get_tune_parameters(self, config):
        try:
            self.operator_config = config
            return self.operator.get_tune_parameters(config)
        except AttributeError:
            raise AttributeError("Custom algorithms need to implement `get_tune_parameters`")
        except Exception as e:
            logger.error('Detect Unknown Error: {}'.format(e))
            raise e
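
# A hedged usage sketch (not part of the original module), assuming a prepared
# REVIVE `config` dict with 'dataset', 'graph', 'workspace', etc. already set up
# by the training entry point elsewhere in the package; the actual launch logic
# does not live in this file:
#
#     from ray import tune
#     algo = VenvAlgorithm("revive_p")
#     tune_params = algo.get_tune_parameters(config)   # wraps searchable params for ray.tune
#     trainable = algo.get_trainable(config)           # TorchTrainer converted to a tune.Trainable
#     tune.run(trainable, **tune_params)               # hypothetical launch for illustration only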