"""
POLIXIR REVIVE, copyright (C) 2021-2023 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
import warnings
import torch
from loguru import logger
from revive.computation.graph import *
from revive.computation.modules import *
from revive.utils.common_utils import *
from revive.data.batch import Batch
from revive.data.dataset import data_creator
from revive.utils.raysgd_utils import BATCH_SIZE
from revive.algo.venv.base import VenvOperator, catch_error
class BCOperator(VenvOperator):
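"""Behavior cloning (BC) operator for training virtual-environment models."""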
NAME = "BC"
PARAMETER_DESCRIPTION = [
{
"name" : "bc_batch_size",
"description": "Batch size of training process.",
"abbreviation" : "bbs",
"type" : int,
"default" : 256,
'doc': True,
},
{
"name" : "bc_epoch",
"description": "Number of epcoh for the training process",
"abbreviation" : "bep",
"type" : int,
"default" : 500,
'doc': True,
},
{
"name" : "bc_horizon",
"abbreviation" : "bh",
"type" : int,
"default" : 10,
},
{
"name" : "policy_hidden_features",
"description": "Number of neurons per layer of the policy network.",
"abbreviation" : "phf",
"type" : int,
"default" : 256,
'doc': True,
},
{
"name" : "policy_hidden_layers",
"description": "Depth of policy network.",
"abbreviation" : "phl",
"type" : int,
"default" : 4,
"search_mode" : "grid",
"search_values" : [3, 4, 5],
'doc': True,
},
{
"name" : "policy_activation",
"abbreviation" : "pa",
"type" : str,
"default" : 'leakyrelu',
},
{
"name" : "policy_normalization",
"abbreviation" : "pn",
"type" : str,
"default" : None,
},
{
"name" : "policy_backbone",
"description": "Backbone of policy network.",
"abbreviation" : "pb",
"type" : str,
"default" : "res",
"search_mode" : "grid",
"search_values" : ['mlp', 'res'],
'doc': True,
},
{
"name" : "transition_hidden_features",
"abbreviation" : "thf",
"type" : int,
"default" : 256,
},
{
"name" : "transition_hidden_layers",
"abbreviation" : "thl",
"type" : int,
"default" : 3,
},
{
"name" : "transition_activation",
"abbreviation" : "ta",
"type" : str,
"default" : 'leakyrelu',
},
{
"name" : "transition_normalization",
"abbreviation" : "tn",
"type" : str,
"default" : 'ln',
},
{
"name" : "transition_backbone",
"description": "Backbone of Transition network.",
"abbreviation" : "tb",
"type" : str,
"default" : "res",
},
{
"name" : "g_lr",
"description": "Initial learning rate of the training process.",
"type" : float,
"default" : 1e-4,
"search_mode" : "continuous",
"search_values" : [1e-6, 1e-3],
'doc': True,
},
{
"name" : "weight_decay",
"abbreviation" : "wd",
"type" : float,
"default" : 1e-4,
},
{
"name" : "lr_decay",
"abbreviation" : "ld",
"type" : float,
"default" : 0.99,
},
{
"name" : "loss_type",
"description": 'Bc support different loss function("log_prob", "mae", "mse").',
"type" : str,
"default" : "log_prob",
'doc': True,
},
]
def _setup_componects(self, config):
super()._setup_componects(config)
# Attach exponential learning-rate decay to the training and validation optimizers.
self.train_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.train_optimizers[0], gamma=config['lr_decay'])
self.val_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.val_optimizers[0], gamma=config['lr_decay'])
self.loss_type = config["loss_type"]
def model_creator(self, config : dict, graph : DesicionGraph):
"""
Create policy networks for all network nodes and, where needed, transition networks.
:param config: configuration parameters
:param graph: decision graph describing the nodes and their dependencies
:return: list of all created models
"""
total_dims = config['total_dims']
networks = []
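# First pass: initialize policy networks for every non-transition node registered as a network.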
for node_name in list(graph.keys()):
if node_name in graph.transition_map.values(): continue
node = graph.get_node(node_name)
if node.node_type != 'network':
logger.info(f'Skip {node_name}: the node is already registered as {node.node_type}.')
continue
input_dim = get_input_dim_from_graph(graph, node_name, total_dims)
node.initialize_network(input_dim, total_dims[node_name]['output'],
hidden_features=config['policy_hidden_features'],
hidden_layers=config['policy_hidden_layers'],
hidden_activation=config['policy_activation'],
norm=config['policy_normalization'],
backbone_type=config['policy_backbone'],
dist_config=config['dist_configs'][node_name],
is_transition=False)
networks.append(node.get_network())
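# Second pass: initialize transition networks for the nodes listed in the transition map.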
for node_name in graph.transition_map.values():
node = graph.get_node(node_name)
if node.node_type != 'network':
logger.info(f'Skip {node_name}: the node is already registered as {node.node_type}.')
continue
input_dim = get_input_dim_from_graph(graph, node_name, total_dims)
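# Transition nodes are named 'next_<variable>'; if the node's own current-step
# variable is not among its inputs, only the global transition mode is possible.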
if node_name[5:] not in node.input_names:
transition_mode = 'global'
if config['transition_mode']:
warnings.warn('Fallback to global transition mode since the transition variable is not provided as an input!')
else:
transition_mode = config['transition_mode']
node.initialize_network(input_dim, total_dims[node_name]['output'],
hidden_features=config['transition_hidden_features'],
hidden_layers=config['transition_hidden_layers'],
hidden_activation=config['transition_activation'],
norm=config['transition_normalization'],
backbone_type=config['transition_backbone'],
dist_config=config['dist_configs'][node_name],
is_transition=True,
transition_mode=transition_mode,
obs_dim=total_dims[node_name]['input'])
networks.append(node.get_network())
assert len(networks) > 0, 'at least one node needs to be a network to run training!'
return networks
def optimizer_creator(self, models, config):
optimizer = torch.optim.Adam(get_models_parameters(*models), lr=config['g_lr'], weight_decay=config['weight_decay'])
return [optimizer]
def data_creator(self, config : dict):
config[BATCH_SIZE] = config['bc_batch_size']
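# Recurrent backbones (lstm/gru) need trajectory segments of length bc_horizon;
# feed-forward backbones can be trained on individual transitions.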
if config['policy_backbone'] in ['lstm', 'gru'] or config['transition_backbone'] in ['lstm', 'gru']:
return data_creator(config, training_horizon=config['bc_horizon'], training_is_sample=False, val_horizon=config['venv_rollout_horizon'], double=True)
else:
return data_creator(config, training_mode='transition', training_is_sample=False, val_horizon=config['venv_rollout_horizon'], double=True)
@catch_error
def train_epoch(self):
results = super().train_epoch()
self.train_scheduler.step()
self.val_scheduler.step()
return results
def train_batch(self, expert_data, batch_info, scope='train'):
self._batch_cnt += 1
if scope == 'train':
models = self.train_models
graph = self.graph_train
optimizer = self.train_optimizers[0]
else:
models = self.val_models
graph = self.graph_val
optimizer = self.val_optimizers[0]
graph.reset()
expert_data.to_torch(device=self._device)
info = {}
loss = 0
generated_data = Batch()
for node_name in graph.keys():
node = graph.get_node(node_name)
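# Build a validity mask from the NaN indicator so missing values do not contribute to the loss.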
if node_name + "_isnan_index_" in expert_data.keys():
isnan_index = 1 - expert_data[node_name + "_isnan_index_"]
else:
isnan_index = None
if node.node_type == 'network':
action_dist = graph.compute_node(node_name, expert_data)
_loss_type = graph.nodes_loss_type.get(node_name, self.loss_type)
if _loss_type == "mae":
if isnan_index is not None:
policy_loss = ((action_dist.mode - expert_data[node_name])*isnan_index).abs().sum(dim=-1).mean()
else:
policy_loss = (action_dist.mode - expert_data[node_name]).abs().sum(dim=-1).mean()
elif _loss_type == "mse":
if isnan_index is not None:
policy_loss = (((action_dist.mode - expert_data[node_name])*isnan_index)**2).sum(dim=-1).mean()
else:
policy_loss = ((action_dist.mode - expert_data[node_name])**2).sum(dim=-1).mean()
elif _loss_type == "nll" or _loss_type == "log_prob" :
generated_data[node_name + '_log_prob'] = action_dist.log_prob(expert_data[node_name])
generated_data[node_name] = action_dist.sample()
if isnan_index is not None:
policy_loss = - (generated_data[node_name + '_log_prob']*isnan_index).mean()
else:
policy_loss = - generated_data[node_name + '_log_prob'].mean()
elif _loss_type.startswith("user_module."):
loss_name = _loss_type[len("user_module."):]
loss_function = self.config["user_module"].get(loss_name, None)
assert loss_function is not None
kwargs = {
"node_dist" : action_dist,
"node_name" : node_name,
"isnan_index" : isnan_index,
"graph" : graph,
"expert_data" : expert_data,
}
policy_loss = loss_function(kwargs)
else:
raise NotImplementedError
loss += policy_loss
info[f"{self.NAME}/{node_name}_loss_{scope}"] = policy_loss.item()
info[f"{self.NAME}/total_loss_{scope}"] = loss.item()
optimizer.zero_grad()
loss.backward()
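# Clip the global gradient norm and skip the parameter update if any gradient became NaN.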
grad_norm = torch.nn.utils.clip_grad.clip_grad_norm_(get_models_parameters(*models), 50)
if torch.any(torch.isnan(grad_norm)):
logger.info(f'Detected NaN in gradient, skipping this batch! (loss: {loss}, grad_norm: {grad_norm})')
else:
optimizer.step()
info[f"{self.NAME}/grad_norm"] = grad_norm.item()
info = self._early_stop(info)
if self._epoch_cnt >= self.config['bc_epoch']:
self._stop_flag = True
info["stop_flag"] = self._stop_flag
return info
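if __name__ == "__main__":
    # Minimal illustrative sketch, not part of the REVIVE API: it only shows how the
    # three supported loss_type options ("log_prob"/"nll", "mae", "mse") relate for a
    # diagonal Gaussian output distribution. The tensors below are random stand-ins
    # for a node's predicted distribution and the corresponding expert data.
    dist = torch.distributions.Normal(loc=torch.zeros(8, 3), scale=torch.ones(8, 3))
    expert = torch.randn(8, 3)
    nll = -dist.log_prob(expert).sum(dim=-1).mean()         # "log_prob" / "nll"
    mae = (dist.mean - expert).abs().sum(dim=-1).mean()     # "mae" (mode == mean for a Gaussian)
    mse = ((dist.mean - expert) ** 2).sum(dim=-1).mean()    # "mse"
    logger.info(f"nll={nll.item():.3f} mae={mae.item():.3f} mse={mse.item():.3f}")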