Source code for revive.algo.venv.bc

"""
    POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""

import warnings

import torch
from loguru import logger

from revive.computation.graph import *
from revive.computation.modules import *
from revive.utils.common_utils import *
from revive.data.batch import Batch
from revive.data.dataset import data_creator
from revive.utils.raysgd_utils import BATCH_SIZE
from revive.algo.venv.base import VenvOperator, catch_error


class BCOperator(VenvOperator):
    NAME = "REVIVE_VENV"  # "BC"

    PARAMETER_DESCRIPTION = [
        {
            "name" : "bc_batch_size",
            "description": "Batch size of training process.",
            "abbreviation" : "bbs",
            "type" : int,
            "default" : 256,
            'doc': True,
        },
        {
            "name" : "bc_epoch",
            "description": "Number of epochs for the training process.",
            "abbreviation" : "bep",
            "type" : int,
            "default" : 500,
            'doc': True,
        },
        {
            "name" : "bc_horizon",
            "abbreviation" : "bh",
            "type" : int,
            "default" : 10,
        },
        {
            "name" : "policy_hidden_features",
            "description": "Number of neurons per layer of the policy network.",
            "abbreviation" : "phf",
            "type" : int,
            "default" : 256,
            'doc': True,
        },
        {
            "name" : "policy_hidden_layers",
            "description": "Depth of policy network.",
            "abbreviation" : "phl",
            "type" : int,
            "default" : 4,
            "search_mode" : "grid",
            "search_values" : [3, 4, 5],
            'doc': True,
        },
        {
            "name" : "policy_activation",
            "abbreviation" : "pa",
            "type" : str,
            "default" : 'leakyrelu',
        },
        {
            "name" : "policy_normalization",
            "abbreviation" : "pn",
            "type" : str,
            "default" : None,
        },
        {
            "name" : "policy_backbone",
            "description": "Backbone of policy network. Supports selecting from [mlp, res, ft_transformer, lstm, gru].",
            "abbreviation" : "pb",
            "type" : str,
            "default" : "res",
            'doc': True,
        },
        {
            "name" : "transition_hidden_features",
            "abbreviation" : "thf",
            "type" : int,
            "default" : 256,
            'doc': True,
        },
        {
            "name" : "transition_hidden_layers",
            "abbreviation" : "thl",
            "type" : int,
            "default" : 3,
            'doc': True,
        },
        {
            "name" : "transition_activation",
            "abbreviation" : "ta",
            "type" : str,
            "default" : 'leakyrelu',
        },
        {
            "name" : "transition_normalization",
            "abbreviation" : "tn",
            "type" : str,
            "default" : None,
        },
        {
            "name" : "transition_backbone",
            "description": "Backbone of transition network. Supports selecting from [mlp, res, ft_transformer, lstm, gru].",
            "abbreviation" : "tb",
            "type" : str,
            "default" : "res",
            'doc': True,
        },
        {
            "name" : "g_lr",
            "description": "Initial learning rate of the training process.",
            "type" : float,
            "default" : 1e-4,
            "search_mode" : "continuous",
            "search_values" : [1e-6, 1e-3],
            'doc': True,
        },
        {
            "name" : "weight_decay",
            "abbreviation" : "wd",
            "type" : float,
            "default" : 1e-4,
        },
        {
            "name" : "lr_decay",
            "abbreviation" : "ld",
            "type" : float,
            "default" : 0.99,
        },
        {
            "name" : "loss_type",
            "description": 'BC supports different loss functions ("nll", "mae", "mse").',
            "type" : str,
            "default" : "nll",
            'doc': True,
        },
        {
            "name" : "bc_l2_coef",
            "type" : float,
            "default" : 5e-5,
        },
        {
            "name" : "logstd_loss_coef",
            "type" : float,
            "default" : 0.01,
        },
    ]

    @catch_error
    def __init__(self, config):
        super().__init__(config)
        self._total_epoch = self.config['bc_epoch']

    def _setup_componects(self):
        super()._setup_componects()
        # add lr decay here
        # self.train_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.train_optimizers[0], gamma=self.config['lr_decay'])
        # self.val_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.val_optimizers[0], gamma=self.config['lr_decay'])
        self.loss_type = self.config["loss_type"]

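    # Illustrative note (not part of the REVIVE source): each PARAMETER_DESCRIPTION
    # entry above surfaces as a key in ``self.config``. With the defaults above, a
    # resolved config would contain hypothetical values such as
    # {"bc_batch_size": 256, "bc_epoch": 500, "g_lr": 1e-4, "loss_type": "nll"},
    # which are read back in ``__init__`` (self.config['bc_epoch']),
    # ``data_creator`` (self.config['bc_batch_size']) and
    # ``optimizer_creator`` (config['g_lr'], config['weight_decay']).
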
    def model_creator(self, config : dict, graph : DesicionGraph):
        """
        Create policies and transition, if needed.

        :param config: configuration parameters
        :param graph: decision graph describing the environment
        :return: list of all models
        """
        total_dims = config['total_dims']

        networks = []

        # initialize policy node networks
        for node_name in list(graph.keys()):
            if node_name in graph.transition_map.values():
                continue
            node = graph.get_node(node_name)
            if not node.node_type == 'network':
                logger.info(f'skip {node_name}, the node is already registered as {node.node_type}.')
                continue
            input_dim = get_input_dim_from_graph(graph, node_name, total_dims)
            if hasattr(node, 'custom_node'):
                input_dim = get_input_dim_dict_from_graph(graph, node_name, total_dims)
                input_dim_dict = input_dim
            else:
                input_dim_dict = get_input_dim_dict_from_graph(graph, node_name, total_dims)
            node.initialize_network(input_dim,
                                    total_dims[node_name]['output'],
                                    hidden_features=config['policy_hidden_features'],
                                    hidden_layers=config['policy_hidden_layers'],
                                    hidden_activation=config['policy_activation'],
                                    norm=config['policy_normalization'],
                                    backbone_type=config['policy_backbone'],
                                    dist_config=config['dist_configs'][node_name],
                                    is_transition=False,
                                    input_dim_dict=input_dim_dict)
            networks.append(node.get_network())

        # initialize transition node networks
        for node_name in graph.transition_map.values():
            node = graph.get_node(node_name)
            if not node.node_type == 'network':
                logger.info(f'skip {node_name}, the node is already registered as {node.node_type}.')
                continue
            if node_name[5:] not in node.input_names:
                transition_mode = 'global'
                if config['transition_mode']:
                    warnings.warn('Falling back to global transition mode since the transition variable is not provided as an input!')
            else:
                transition_mode = config['transition_mode']
            input_dim = get_input_dim_from_graph(graph, node_name, total_dims)
            input_dim_dict = get_input_dim_dict_from_graph(graph, node_name, total_dims)
            if hasattr(node, 'custom_node'):
                input_dim = get_input_dim_dict_from_graph(graph, node_name, total_dims)
            node.initialize_network(input_dim,
                                    total_dims[node_name]['output'],
                                    hidden_features=config['transition_hidden_features'],
                                    hidden_layers=config['transition_hidden_layers'],
                                    hidden_activation=config['transition_activation'],
                                    norm=config['transition_normalization'],
                                    backbone_type=config['transition_backbone'],
                                    dist_config=config['dist_configs'][node_name],
                                    is_transition=True,
                                    transition_mode=transition_mode,
                                    obs_dim=total_dims[node_name]['input'],
                                    input_dim_dict=input_dim_dict)
            networks.append(node.get_network())

        assert len(networks) > 0, 'at least one node needs to be a network to run training!'

        return networks

    def optimizer_creator(self, models, config):
        optimizer = torch.optim.Adam(get_models_parameters(*models),
                                     lr=config['g_lr'],
                                     weight_decay=config['weight_decay'])
        return [optimizer]

    def data_creator(self):
        self.config[BATCH_SIZE] = self.config['bc_batch_size']
        if self.config["venv_train_dataset_mode"] == "trajectory":
            return data_creator(self.config,
                                training_mode="trajectory",
                                training_horizon=self.config['bc_horizon'],
                                training_is_sample=False,
                                val_horizon=self.config['venv_rollout_horizon'],
                                pre_horzion=self.config['pre_horzion'],
                                double=True)
        else:
            return data_creator(self.config,
                                training_mode='transition',
                                training_is_sample=False,
                                val_horizon=self.config['venv_rollout_horizon'],
                                pre_horzion=self.config['pre_horzion'],
                                double=True)

    @catch_error
    def train_epoch(self):
        results = super().train_epoch()
        # self.train_scheduler.step()
        # self.val_scheduler.step()
        return results

    def train_batch(self, expert_data, batch_info, scope='train'):
        self._batch_cnt += 1
        if scope == 'train':
            models = self.train_models
            graph = self.graph_train
            optimizer = self.train_optimizers[0]
        else:
            models = self.val_models
            graph = self.graph_val
            optimizer = self.val_optimizers[0]

        graph.reset()

        info = {}
        loss = 0

        dataset_mode = self.config["venv_train_dataset_mode"]
        if dataset_mode == "trajectory":
            sample_fn = lambda dist: dist.mode
            generated_data, dist_dict = generate_rollout(expert_data, graph, expert_data.shape[0], sample_fn,
                                                         clip=1.5, return_dist=True, mode="train")
            if self.config["pre_horzion"] > 0:
                expert_data = expert_data[self.config["pre_horzion"]:]
                generated_data = generated_data[self.config["pre_horzion"]:]

        for node_name in graph.keys():
            node = graph.get_node(node_name)
            if node.node_type == 'network':
                # Missing value processing by isnan_index
                isnan_index_list = []
                isnan_index = 1.
                # check whether nan is in inputs
                for node_name_ in node.input_names:
                    if (node_name_ + "_isnan_index_") in expert_data.keys():
                        isnan_index_list.append(expert_data[node_name_ + "_isnan_index_"])
                # check whether nan is in outputs
                if (node_name + "_isnan_index_") in expert_data.keys():
                    isnan_index_list.append(expert_data[node_name + "_isnan_index_"])

                if isnan_index_list:
                    isnan_index, _ = torch.max(torch.cat(isnan_index_list, axis=-1), axis=-1, keepdim=True)
                    loss_mask = 1 - isnan_index
                else:
                    loss_mask = 1.

                if dataset_mode == "trajectory":
                    action_dist = dist_dict[node_name]
                    action = generated_data[node_name]
                else:
                    action_dist = graph.compute_node(node_name, expert_data)
                    action = action_dist.mode

                # Empty node processing, use rollout data as expert data for nodata nodes
                if node_name in self._graph.nodata_node_names:
                    if dataset_mode == "trajectory":
                        expert_data[node_name] = generated_data[node_name]
                    else:
                        expert_data[node_name] = action_dist.mode
                    continue

                # Calculate loss
                _loss_type = graph.nodes_loss_type.get(node_name, self.loss_type)
                if _loss_type == "mae":
                    policy_loss = ((action - expert_data[node_name]) * loss_mask).abs().sum(dim=-1).mean()
                elif _loss_type == "mse":
                    policy_loss = (((action - expert_data[node_name]) * loss_mask) ** 2).sum(dim=-1).mean()
                elif _loss_type == "nll" or _loss_type == "log_prob":
                    if dataset_mode == "trajectory":
                        policy_loss = 0
                        for traj_index in range(expert_data.shape[0]):
                            if isinstance(loss_mask, float):
                                policy_loss -= (dist_dict[node_name][traj_index].log_prob(expert_data[node_name][traj_index])).mean()
                            else:
                                policy_loss -= (dist_dict[node_name][traj_index].log_prob(expert_data[node_name][traj_index]) * loss_mask[traj_index]).mean()
                    else:
                        policy_loss = -(action_dist.log_prob(expert_data[node_name]) * loss_mask).mean()
                    # policy_loss += self.net_l2_norm(node.network) * self.config["bc_l2_coef"]
                    # policy_loss += self.config["logstd_loss_coef"] * node.network.dist_wrapper.wrapper_list[0].max_logstd.sum() \
                    #                - self.config["logstd_loss_coef"] * node.network.dist_wrapper.wrapper_list[0].min_logstd.sum()
                elif _loss_type == "gaussian_loss":
                    # assert not rnn_flag, "gaussian_loss does not support RNN yet"
                    target = expert_data[node_name]
                    mean, std = action_dist.mode, action_dist.std
                    var = torch.pow(std, 2)
                    # Average over batch and dim, sum over ensembles.
                    mse_loss_inv = (torch.pow(mean - target, 2) / torch.maximum(var, torch.ones_like(var, device=self._device) * 1e-6) * loss_mask).mean()
                    var_loss = torch.log(torch.maximum(var, torch.ones_like(var, device=self._device) * 1e-6)).mean()
                    policy_loss = mse_loss_inv + var_loss
                    # policy_loss += self.net_l2_norm(node.network) * self.config["bc_l2_coef"]
                    # policy_loss += self.config["logstd_loss_coef"] * node.network.dist_wrapper.wrapper_list[0].max_logstd.sum() \
                    #                - self.config["logstd_loss_coef"] * node.network.dist_wrapper.wrapper_list[0].min_logstd.sum()
                    info[f"{self.NAME}/{node_name}_mse_loss_inv_{scope}"] = mse_loss_inv.item()
                    info[f"{self.NAME}/{node_name}_var_loss_{scope}"] = var_loss.item()
                elif _loss_type.startswith("user_module."):
                    loss_name = _loss_type[len("user_module."):]
                    loss_function = self.config["user_module"].get(loss_name, None)
                    assert loss_function is not None
                    kwargs = {
                        "node_dist" : action_dist,
                        "node_name" : node_name,
                        "isnan_index_list" : isnan_index_list,
                        "isnan_index" : isnan_index,
                        "loss_mask" : loss_mask,
                        "graph" : graph,
                        "expert_data" : expert_data,
                    }
                    policy_loss = loss_function(kwargs)
                else:
                    raise NotImplementedError

                loss += policy_loss
                info[f"{self.NAME}/{node_name}_bc_loss_{scope}"] = policy_loss.item()
                try:
                    info[f"{self.NAME}/{node_name}_bc_std_{scope}"] = action_dist.std.mean().item()
                except:
                    # debugging hook: drop into the debugger if the distribution exposes no std
                    breakpoint()

        info[f"{self.NAME}/total_loss_{scope}"] = loss.item()

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad.clip_grad_norm_(get_models_parameters(*models), 50)
        if torch.isnan(grad_norm.mean()):
            self.nan_in_grad()
            logger.info(f'Detected nan in gradient, skipping this batch! (loss : {loss}, grad_norm : {grad_norm})')
        else:
            optimizer.step()
        info[f"{self.NAME}/grad_norm"] = grad_norm.item()

        return info
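
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the REVIVE source: the snippet below mirrors
# how `train_batch` above computes the per-node behavior cloning loss for the
# supported `loss_type` values ("mae", "mse", "nll", "gaussian_loss"), using a
# toy torch.distributions.Normal in place of a node's action distribution. All
# tensors, shapes and values here are hypothetical.

import torch
from torch.distributions import Normal

expert_action = torch.randn(256, 4)                      # hypothetical expert batch
dist = Normal(torch.zeros(256, 4), torch.ones(256, 4))   # stand-in for action_dist
action = dist.mean                                       # stands in for dist.mode
loss_mask = 1.0                                          # no missing values in this toy batch

# "mae" / "mse": direct regression on the distribution mode
mae_loss = ((action - expert_action) * loss_mask).abs().sum(dim=-1).mean()
mse_loss = (((action - expert_action) * loss_mask) ** 2).sum(dim=-1).mean()

# "nll": negative log-likelihood of the expert action under the node distribution
nll_loss = -(dist.log_prob(expert_action) * loss_mask).mean()

# "gaussian_loss": inverse-variance weighted MSE plus a log-variance penalty
var = dist.stddev.pow(2).clamp_min(1e-6)
gaussian_loss = ((action - expert_action).pow(2) / var * loss_mask).mean() + torch.log(var).mean()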