# Source code for revive.algo.venv.bc

''''''
"""
    POLIXIR REVIVE, copyright (C) 2021-2023 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""

import torch
from loguru import logger

from revive.computation.graph import *
from revive.computation.modules import *
from revive.utils.common_utils import *
from revive.data.batch import Batch
from revive.data.dataset import data_creator
from revive.utils.raysgd_utils import BATCH_SIZE
from revive.algo.venv.base import VenvOperator, catch_error


class BCOperator(VenvOperator):
    """Behaviour-cloning (BC) operator for training a virtual environment.

    Every node of the decision graph whose ``node_type`` is ``'network'``
    (policy nodes and transition nodes) is fitted by supervised learning on
    the expert data. The loss is chosen by the ``loss_type`` hyper-parameter,
    unless ``graph.nodes_loss_type`` provides a per-node override.
    """

    NAME = "BC"
    PARAMETER_DESCRIPTION = [
        {
            "name": "bc_batch_size",
            "description": "Batch size of training process.",
            "abbreviation": "bbs",
            "type": int,
            "default": 256,
            'doc': True,
        },
        {
            "name": "bc_epoch",
            "description": "Number of epcoh for the training process",
            "abbreviation": "bep",
            "type": int,
            "default": 500,
            'doc': True,
        },
        {
            "name": "bc_horizon",
            "abbreviation": "bh",
            "type": int,
            "default": 10,
        },
        {
            "name": "policy_hidden_features",
            "description": "Number of neurons per layer of the policy network.",
            "abbreviation": "phf",
            "type": int,
            "default": 256,
            'doc': True,
        },
        {
            "name": "policy_hidden_layers",
            "description": "Depth of policy network.",
            "abbreviation": "phl",
            "type": int,
            "default": 4,
            "search_mode": "grid",
            "search_values": [3, 4, 5],
            'doc': True,
        },
        {
            "name": "policy_activation",
            "abbreviation": "pa",
            "type": str,
            "default": 'leakyrelu',
        },
        {
            "name": "policy_normalization",
            "abbreviation": "pn",
            "type": str,
            "default": None,
        },
        {
            "name": "policy_backbone",
            "description": "Backbone of policy network.",
            "abbreviation": "pb",
            "type": str,
            "default": "res",
            "search_mode": "grid",
            "search_values": ['mlp', 'res'],
            'doc': True,
        },
        {
            "name": "transition_hidden_features",
            "abbreviation": "thf",
            "type": int,
            "default": 256,
        },
        {
            "name": "transition_hidden_layers",
            "abbreviation": "thl",
            "type": int,
            "default": 3,
        },
        {
            "name": "transition_activation",
            "abbreviation": "ta",
            "type": str,
            "default": 'leakyrelu',
        },
        {
            "name": "transition_normalization",
            "abbreviation": "tn",
            "type": str,
            "default": 'ln',
        },
        {
            "name": "transition_backbone",
            "description": "Backbone of Transition network.",
            "abbreviation": "tb",
            "type": str,
            "default": "res",
        },
        {
            "name": "g_lr",
            "description": "Initial learning rate of the training process.",
            "type": float,
            "default": 1e-4,
            "search_mode": "continuous",
            "search_values": [1e-6, 1e-3],
            'doc': True,
        },
        {
            "name": "weight_decay",
            "abbreviation": "wd",
            "type": float,
            "default": 1e-4,
        },
        {
            "name": "lr_decay",
            "abbreviation": "ld",
            "type": float,
            "default": 0.99,
        },
        {
            "name": "loss_type",
            "description": 'Bc support different loss function("log_prob", "mae", "mse").',
            "type": str,
            "default": "log_prob",
            'doc': True,
        },
    ]

    def _setup_componects(self, config):
        """Extend the base setup with LR schedulers and the default loss type.

        :param config: configuration parameters; reads ``lr_decay`` and ``loss_type``.
        """
        super()._setup_componects(config)
        # Exponential LR decay for the train and val optimizers; both schedulers
        # are stepped once per epoch in ``train_epoch``.
        self.train_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self.train_optimizers[0], gamma=config['lr_decay'])
        self.val_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self.val_optimizers[0], gamma=config['lr_decay'])
        # Default loss; ``graph.nodes_loss_type`` may override it per node.
        self.loss_type = config["loss_type"]

    def model_creator(self, config : dict, graph : DesicionGraph):
        """Create policies and transition, if needed.

        Non-transition nodes are initialized with the ``policy_*`` settings and
        transition nodes with the ``transition_*`` settings. Nodes whose
        ``node_type`` is not ``'network'`` are skipped.

        :param config: configuration parameters
        :param graph: decision graph whose network nodes are initialized
        :return: list of all models
        :raises AssertionError: if no node of the graph is a network
        """
        import warnings  # used for the transition-mode fallback notice below

        total_dims = config['total_dims']
        transition_nodes = set(graph.transition_map.values())
        networks = []

        # Policy (non-transition) nodes.
        for node_name in list(graph.keys()):
            if node_name in transition_nodes:
                continue
            node = graph.get_node(node_name)
            if not node.node_type == 'network':
                logger.info(f'skip {node_name}, the node is already registered as {node.node_type}.')
                continue
            input_dim = get_input_dim_from_graph(graph, node_name, total_dims)
            node.initialize_network(input_dim, total_dims[node_name]['output'],
                                    hidden_features=config['policy_hidden_features'],
                                    hidden_layers=config['policy_hidden_layers'],
                                    hidden_activation=config['policy_activation'],
                                    norm=config['policy_normalization'],
                                    backbone_type=config['policy_backbone'],
                                    dist_config=config['dist_configs'][node_name],
                                    is_transition=False)
            networks.append(node.get_network())

        # Transition nodes.
        for node_name in graph.transition_map.values():
            node = graph.get_node(node_name)
            if not node.node_type == 'network':
                logger.info(f'skip {node_name}, the node is already registered as {node.node_type}.')
                continue
            input_dim = get_input_dim_from_graph(graph, node_name, total_dims)
            # NOTE(review): assumes transition node names carry a 5-character
            # prefix (e.g. 'next_') in front of the underlying variable name.
            if node_name[5:] not in node.input_names:
                transition_mode = 'global'
                if config['transition_mode']:
                    warnings.warn('Fallback to global transition mode since the transition variable is not provided as an input!')
            else:
                transition_mode = config['transition_mode']
            node.initialize_network(input_dim, total_dims[node_name]['output'],
                                    hidden_features=config['transition_hidden_features'],
                                    hidden_layers=config['transition_hidden_layers'],
                                    hidden_activation=config['transition_activation'],
                                    norm=config['transition_normalization'],
                                    backbone_type=config['transition_backbone'],
                                    dist_config=config['dist_configs'][node_name],
                                    is_transition=True,
                                    transition_mode=transition_mode,
                                    obs_dim=total_dims[node_name]['input'])
            networks.append(node.get_network())

        assert len(networks) > 0, 'at least one node need to be a network to run training!'
        return networks

    def optimizer_creator(self, models, config):
        """Create a single Adam optimizer over the parameters of all ``models``.

        :param models: models returned by ``model_creator``
        :param config: configuration parameters; reads ``g_lr`` and ``weight_decay``.
        :return: one-element list holding the optimizer
        """
        optimizer = torch.optim.Adam(get_models_parameters(*models),
                                     lr=config['g_lr'],
                                     weight_decay=config['weight_decay'])
        return [optimizer]

    def data_creator(self, config : dict):
        """Build the training/validation data using the BC batch size.

        With a recurrent backbone (``'lstm'``/``'gru'``) for the policy or the
        transition, training samples span ``bc_horizon`` steps. Otherwise the
        ``'transition'`` training mode is used. Validation always uses
        ``venv_rollout_horizon``.

        :param config: configuration parameters; ``BATCH_SIZE`` is overwritten in place.
        """
        config[BATCH_SIZE] = config['bc_batch_size']
        recurrent = ('lstm', 'gru')
        # ``data_creator`` below resolves to the module-level import, not this method.
        if config['policy_backbone'] in recurrent or config['transition_backbone'] in recurrent:
            return data_creator(config,
                                training_horizon=config['bc_horizon'],
                                training_is_sample=False,
                                val_horizon=config['venv_rollout_horizon'],
                                double=True)
        return data_creator(config,
                            training_mode='transition',
                            training_is_sample=False,
                            val_horizon=config['venv_rollout_horizon'],
                            double=True)

    @catch_error
    def train_epoch(self):
        """Run one base-class training epoch, then decay both learning rates."""
        results = super().train_epoch()
        self.train_scheduler.step()
        self.val_scheduler.step()
        return results

    def _compute_node_loss(self, node_name, action_dist, expert_data, isnan_index, graph):
        """Return the supervised loss of one network node.

        :param node_name: name of the node being fitted
        :param action_dist: output of ``graph.compute_node`` for this node
        :param expert_data: batch of expert data (holds the node's targets)
        :param isnan_index: mask multiplied into the per-element loss, or None
        :param graph: decision graph (for per-node loss overrides / user losses)
        :raises AssertionError: if a ``user_module.*`` loss is not registered
        :raises NotImplementedError: for an unsupported loss type
        """
        loss_type = graph.nodes_loss_type.get(node_name, self.loss_type)

        if loss_type in ("mae", "mse"):
            error = action_dist.mode - expert_data[node_name]
            if isnan_index is not None:
                error = error * isnan_index
            error = error.abs() if loss_type == "mae" else error ** 2
            return error.sum(dim=-1).mean()

        if loss_type in ("nll", "log_prob"):
            log_prob = action_dist.log_prob(expert_data[node_name])
            if isnan_index is not None:
                log_prob = log_prob * isnan_index
            return - log_prob.mean()

        if loss_type.startswith("user_module."):
            loss_name = loss_type[len("user_module."):]
            loss_function = self.config["user_module"].get(loss_name, None)
            assert loss_function is not None, \
                f"user loss '{loss_name}' for node '{node_name}' is not defined in config['user_module']"
            kwargs = {
                "node_dist": action_dist,
                "node_name": node_name,
                "isnan_index": isnan_index,
                "graph": graph,
                "expert_data": expert_data,
            }
            return loss_function(kwargs)

        raise NotImplementedError(f"Unsupported loss type '{loss_type}' for node '{node_name}'.")

    def train_batch(self, expert_data, batch_info, scope='train'):
        """Fit every network node on one batch of expert data.

        :param expert_data: batch of expert data; moved to ``self._device``
        :param batch_info: unused here, kept for interface compatibility
        :param scope: ``'train'`` uses the train models/graph/optimizer, any
            other value uses the validation ones
        :return: dict of logged losses, the gradient norm and ``stop_flag``
        """
        self._batch_cnt += 1
        if scope == 'train':
            models = self.train_models
            graph = self.graph_train
            optimizer = self.train_optimizers[0]
        else:
            models = self.val_models
            graph = self.graph_val
            optimizer = self.val_optimizers[0]

        graph.reset()
        expert_data.to_torch(device=self._device)
        info = {}
        loss = 0

        for node_name in graph.keys():
            node = graph.get_node(node_name)
            if node.node_type != 'network':
                continue
            # Optional per-node mask: the stored index is inverted with ``1 - x``
            # before being multiplied into the loss.
            isnan_key = node_name + "_isnan_index_"
            isnan_index = 1 - expert_data[isnan_key] if isnan_key in expert_data.keys() else None

            action_dist = graph.compute_node(node_name, expert_data)
            policy_loss = self._compute_node_loss(node_name, action_dist, expert_data, isnan_index, graph)
            loss += policy_loss
            info[f"{self.NAME}/{node_name}_loss_{scope}"] = policy_loss.item()

        info[f"{self.NAME}/total_loss_{scope}"] = loss.item()

        optimizer.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad.clip_grad_norm_(get_models_parameters(*models), 50)
        # An inf norm makes the clipping coefficient 0, which turns inf gradients
        # into NaN, so any non-finite norm must skip the update, not only NaN.
        if not torch.isfinite(grad_norm).all():
            logger.info(f'Detect nan/inf in gradient, skip this batch! (loss : {loss}, grad_norm : {grad_norm})')
        else:
            optimizer.step()
        info[f"{self.NAME}/grad_norm"] = grad_norm.item()

        info = self._early_stop(info)
        if self._epoch_cnt >= self.config['bc_epoch']:
            self._stop_flag = True
        info["stop_flag"] = self._stop_flag
        return info