"""
POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
import torch
from loguru import logger
from revive.computation.graph import *
from revive.computation.modules import *
from revive.utils.common_utils import *
from revive.data.batch import Batch
from revive.data.dataset import data_creator
from revive.utils.raysgd_utils import BATCH_SIZE
from revive.algo.venv.base import VenvOperator, catch_error
class BCOperator(VenvOperator):
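"""Train every network node of the decision graph by behavioral cloning, i.e. supervised
regression of each node's predicted distribution onto the offline expert data."""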
NAME = "REVIVE_VENV" #"BC"
PARAMETER_DESCRIPTION = [
{
"name" : "bc_batch_size",
"description": "Batch size of training process.",
"abbreviation" : "bbs",
"type" : int,
"default" : 256,
'doc': True,
},
{
"name" : "bc_epoch",
"description": "Number of epcoh for the training process",
"abbreviation" : "bep",
"type" : int,
"default" : 500,
'doc': True,
},
{
"name" : "bc_horizon",
"abbreviation" : "bh",
"type" : int,
"default" : 10,
},
{
"name" : "policy_hidden_features",
"description": "Number of neurons per layer of the policy network.",
"abbreviation" : "phf",
"type" : int,
"default" : 256,
'doc': True,
},
{
"name" : "policy_hidden_layers",
"description": "Depth of policy network.",
"abbreviation" : "phl",
"type" : int,
"default" : 4,
"search_mode" : "grid",
"search_values" : [3, 4, 5],
'doc': True,
},
{
"name" : "policy_activation",
"abbreviation" : "pa",
"type" : str,
"default" : 'leakyrelu',
},
{
"name" : "policy_normalization",
"abbreviation" : "pn",
"type" : str,
"default" : None,
},
{
"name" : "policy_backbone",
"description": "Backbone of policy network. Support selecting from [mlp, res, ft_transformer, lstm, gru].",
"abbreviation" : "pb",
"type" : str,
"default" : "res",
'doc': True,
},
{
"name" : "transition_hidden_features",
"abbreviation" : "thf",
"type" : int,
"default" : 256,
'doc': True,
},
{
"name" : "transition_hidden_layers",
"abbreviation" : "thl",
"type" : int,
"default" : 3,
'doc': True,
},
{
"name" : "transition_activation",
"abbreviation" : "ta",
"type" : str,
"default" : 'leakyrelu',
},
{
"name" : "transition_normalization",
"abbreviation" : "tn",
"type" : str,
"default" : None,
},
{
"name" : "transition_backbone",
"description": "Backbone of Transition network. Support selecting from [mlp, res, ft_transformer, lstm, gru].",
"abbreviation" : "tb",
"type" : str,
"default" : "res",
'doc': True,
},
{
"name" : "g_lr",
"description": "Initial learning rate of the training process.",
"type" : float,
"default" : 1e-4,
"search_mode" : "continuous",
"search_values" : [1e-6, 1e-3],
'doc': True,
},
{
"name" : "weight_decay",
"abbreviation" : "wd",
"type" : float,
"default" : 1e-4,
},
{
"name" : "lr_decay",
"abbreviation" : "ld",
"type" : float,
"default" : 0.99,
},
{
"name" : "loss_type",
"description": 'Bc support different loss function("nll", "mae", "mse").',
"type" : str,
"default" : "nll",
'doc': True,
},
{
"name" : "bc_l2_coef",
"type" : float,
"default" : 5e-5,
},
{
"name" : "logstd_loss_coef",
"type" : float,
"default" : 0.01,
},
]
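# Each entry above declares one tunable hyperparameter: its config name, optional
# abbreviation, type, default value and, where given, the search space used by automatic
# hyperparameter tuning ('doc': True presumably marks parameters exposed in the user docs).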
@catch_error
def __init__(self, config):
super().__init__(config)
self._total_epoch = self.config['bc_epoch']
def _setup_componects(self):
super()._setup_componects()
# add lr decay here
# self.train_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.train_optimizers[0], gamma=self.config['lr_decay'])
# self.val_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.val_optimizers[0], gamma=self.config['lr_decay'])
self.loss_type = self.config["loss_type"]
def model_creator(self, config : dict, graph : DesicionGraph):
"""
Create policies and transition, if needed.
:param config: configuration parameters
:return: list of all models
"""
total_dims = config['total_dims']
networks = []
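# Non-transition network nodes (e.g. policies) are initialized first with the policy_*
# hyperparameters; the next_* transition nodes are initialized afterwards with the
# transition_* hyperparameters.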
for node_name in list(graph.keys()):
if node_name in graph.transition_map.values():
continue
node = graph.get_node(node_name)
if node.node_type != 'network':
logger.info(f'skip {node_name}, the node is already registered as {node.node_type}.')
continue
input_dim = get_input_dim_from_graph(graph, node_name, total_dims)
if hasattr(node, 'custom_node'):
input_dim = get_input_dim_dict_from_graph(graph, node_name, total_dims)
input_dim_dict = input_dim
else:
input_dim_dict = get_input_dim_dict_from_graph(graph, node_name, total_dims)
node.initialize_network(input_dim, total_dims[node_name]['output'],
hidden_features=config['policy_hidden_features'],
hidden_layers=config['policy_hidden_layers'],
hidden_activation=config['policy_activation'],
norm=config['policy_normalization'],
backbone_type=config['policy_backbone'],
dist_config=config['dist_configs'][node_name],
is_transition=False,
input_dim_dict=input_dim_dict,)
networks.append(node.get_network())
# initialize transition node networks
for node_name in graph.transition_map.values():
node = graph.get_node(node_name)
if node.node_type != 'network':
logger.info(f'skip {node_name}, the node is already registered as {node.node_type}.')
continue
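# Transition nodes are named 'next_<var>'; node_name[5:] strips that prefix. If the
# current-step variable is not among the node's inputs, only the global transition mode
# is possible, so any configured mode is overridden with a warning.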
if node_name[5:] not in node.input_names:
transition_mode = 'global'
if config['transition_mode']:
warnings.warn('Falling back to global transition mode since the transition variable is not provided as an input!')
else:
transition_mode = config['transition_mode']
input_dim = get_input_dim_from_graph(graph, node_name, total_dims)
input_dim_dict = get_input_dim_dict_from_graph(graph, node_name, total_dims)
if hasattr(node, 'custom_node'):
input_dim = get_input_dim_dict_from_graph(graph, node_name, total_dims)
node.initialize_network(input_dim, total_dims[node_name]['output'],
hidden_features=config['transition_hidden_features'],
hidden_layers=config['transition_hidden_layers'],
hidden_activation=config['transition_activation'],
norm=config['transition_normalization'],
backbone_type=config['transition_backbone'],
dist_config=config['dist_configs'][node_name],
is_transition=True,
transition_mode=transition_mode,
obs_dim=total_dims[node_name]['input'],
input_dim_dict=input_dim_dict,)
networks.append(node.get_network())
assert len(networks) > 0, 'at least one node needs to be a network to run training!'
return networks
def optimizer_creator(self, models, config):
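"""Build a single Adam optimizer over the parameters of all created models."""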
optimizer = torch.optim.Adam(get_models_parameters(*models), lr=config['g_lr'], weight_decay=config['weight_decay'])
return [optimizer]
def data_creator(self):
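"""Create the train/validation data loaders. When venv_train_dataset_mode is
'trajectory', training batches are sub-trajectories of length bc_horizon; otherwise
single-step transitions are used."""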
self.config[BATCH_SIZE] = self.config['bc_batch_size']
if self.config["venv_train_dataset_mode"] == "trajectory":
return data_creator(self.config,
training_mode="trajectory",
training_horizon=self.config['bc_horizon'],
training_is_sample=False,
val_horizon=self.config['venv_rollout_horizon'],
pre_horzion=self.config['pre_horzion'],
double=True)
else:
return data_creator(self.config,
training_mode='transition',
training_is_sample=False,
val_horizon=self.config['venv_rollout_horizon'],
pre_horzion=self.config['pre_horzion'],
double=True)
@catch_error
def train_epoch(self):
results = super().train_epoch()
# self.train_scheduler.step()
# self.val_scheduler.step()
return results
def train_batch(self, expert_data, batch_info, scope='train'):
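"""Run one behavioral-cloning update on a batch of expert data and return a dict of
logging metrics; with scope='train' the training models are updated, otherwise the
validation copies are used."""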
self._batch_cnt += 1
if scope == 'train':
models = self.train_models
graph = self.graph_train
optimizer = self.train_optimizers[0]
else:
models = self.val_models
graph = self.graph_val
optimizer = self.val_optimizers[0]
graph.reset()
info = {}
loss = 0
dataset_mode = self.config["venv_train_dataset_mode"]
if dataset_mode == "trajectory":
sample_fn = lambda dist: dist.mode
generated_data, dist_dict = generate_rollout(expert_data,
graph,
expert_data.shape[0],
sample_fn,
clip=1.5,
return_dist=True,
mode="train")
if self.config["pre_horzion"] > 0:
expert_data = expert_data[self.config["pre_horzion"]:]
generated_data = generated_data[self.config["pre_horzion"]:]
for node_name in graph.keys():
node = graph.get_node(node_name)
if node.node_type == 'network':
# Handle missing values: build a loss mask from the *_isnan_index_ entries of the batch
isnan_index_list = []
isnan_index = 1.
# check whether nan is in inputs
for node_name_ in node.input_names:
if (node_name_ + "_isnan_index_") in expert_data.keys():
isnan_index_list.append(expert_data[node_name_ + "_isnan_index_"])
# check whether nan is in outputs
if (node_name + "_isnan_index_") in expert_data.keys():
isnan_index_list.append(expert_data[node_name + "_isnan_index_"])
if isnan_index_list:
isnan_index, _ = torch.max(torch.cat(isnan_index_list, dim=-1), dim=-1, keepdim=True)
loss_mask = 1 - isnan_index
else:
loss_mask = 1.
if dataset_mode == "trajectory":
action_dist = dist_dict[node_name]
action = generated_data[node_name]
else:
action_dist = graph.compute_node(node_name, expert_data)
action = action_dist.mode
# Nodes without recorded data: use the rollout output as a pseudo expert target and skip the loss
if node_name in self._graph.nodata_node_names:
if dataset_mode == "trajectory":
expert_data[node_name] = generated_data[node_name]
else:
expert_data[node_name] = action_dist.mode
continue
# Calculate loss
_loss_type = graph.nodes_loss_type.get(node_name, self.loss_type)
if _loss_type == "mae":
policy_loss = ((action - expert_data[node_name])*loss_mask).abs().sum(dim=-1).mean()
elif _loss_type == "mse":
policy_loss = (((action - expert_data[node_name])*loss_mask)**2).sum(dim=-1).mean()
elif _loss_type == "nll" or _loss_type == "log_prob":
if dataset_mode == "trajectory":
policy_loss = 0
for traj_index in range(expert_data.shape[0]):
if isinstance(loss_mask, float):
policy_loss -= (dist_dict[node_name][traj_index].log_prob(expert_data[node_name][traj_index])).mean()
else:
policy_loss -= (dist_dict[node_name][traj_index].log_prob(expert_data[node_name][traj_index])*loss_mask[traj_index]).mean()
else:
policy_loss = - (action_dist.log_prob(expert_data[node_name])*loss_mask).mean()
# policy_loss += self.net_l2_norm(node.network) * self.config["bc_l2_coef"]
# policy_loss += self.config["logstd_loss_coef"] * node.network.dist_wrapper.wrapper_list[0].max_logstd.sum() \
# - self.config["logstd_loss_coef"] * node.network.dist_wrapper.wrapper_list[0].min_logstd.sum()
elif _loss_type == "gaussian_loss":
# assert not rnn_flag, "gaussian_loss does not support RNN yet"
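# Up to additive constants this is the Gaussian negative log-likelihood:
# (mean - target)^2 / var + log(var), with var clamped to at least 1e-6 for stability.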
target = expert_data[node_name]
mean, std = action_dist.mode, action_dist.std
var = torch.pow(std, 2)
# Average the inverse-variance-weighted MSE and the log-variance over batch and feature dims.
mse_loss_inv = (torch.pow(mean - target, 2) / torch.maximum(var, torch.ones_like(var, device=self._device) * 1e-6) * loss_mask).mean()
var_loss = torch.log(torch.maximum(var, torch.ones_like(var, device=self._device) * 1e-6)).mean()
policy_loss = mse_loss_inv + var_loss
# policy_loss += self.net_l2_norm(node.network) * self.config["bc_l2_coef"]
# policy_loss += self.config["logstd_loss_coef"] * node.network.dist_wrapper.wrapper_list[0].max_logstd.sum() \
# - self.config["logstd_loss_coef"] * node.network.dist_wrapper.wrapper_list[0].min_logstd.sum()
info[f"{self.NAME}/{node_name}_mse_loss_inv_{scope}"] = mse_loss_inv.item()
info[f"{self.NAME}/{node_name}_var_loss_{scope}"] = var_loss.item()
elif _loss_type.startswith("user_module."):
loss_name = _loss_type[len("user_module."):]
loss_function = self.config["user_module"].get(loss_name, None)
assert loss_function is not None
kwargs = {
"node_dist" : action_dist,
"node_name" : node_name,
"isnan_index_list" : isnan_index_list,
"isnan_index" : isnan_index,
"loss_mask" : loss_mask,
"graph" : graph,
"expert_data" : expert_data,
}
policy_loss = loss_function(kwargs)
else:
raise NotImplementedError
loss += policy_loss
info[f"{self.NAME}/{node_name}_bc_loss_{scope}"] = policy_loss.item()
try:
info[f"{self.NAME}/{node_name}_bc_std_{scope}"] = action_dist.std.mean().item()
except Exception:
# not every distribution exposes a usable std attribute; skip the std logging in that case
pass
info[f"{self.NAME}/total_loss_{scope}"] = loss.item()
optimizer.zero_grad(set_to_none=True)
loss.backward()
grad_norm = torch.nn.utils.clip_grad.clip_grad_norm_(get_models_parameters(*models), 50)
if torch.isnan(grad_norm.mean()):
self.nan_in_grad()
logger.info(f'Detected NaN in gradient, skipping this batch! (loss : {loss}, grad_norm : {grad_norm})')
else:
optimizer.step()
info[f"{self.NAME}/grad_norm"] = grad_norm.item()
return info