"""
POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
import torch
import numpy as np
from copy import deepcopy
from revive.computation.dists import kl_divergence
from revive.computation.modules import *
from revive.utils.common_utils import *
from revive.data.batch import Batch
from revive.utils.raysgd_utils import BATCH_SIZE
from revive.data.dataset import data_creator
from revive.algo.policy.base import PolicyOperator, catch_error
class ReplayBuffer:
"""
A simple FIFO experience replay buffer for SAC agents.
"""
def __init__(self, buffer_size):
self.data = None
self.buffer_size = int(buffer_size)
def put(self, batch_data : Batch):
batch_data.to_torch(device='cpu')
if self.data is None:
self.data = batch_data
else:
self.data.cat_(batch_data)
if len(self) > self.buffer_size:
self.data = self.data[len(self) - self.buffer_size : ]
def __len__(self):
if self.data is None: return 0
return self.data.shape[0]
def sample(self, batch_size):
assert len(self) > 0, 'Cannot sample from an empty buffer!'
indexes = np.random.randint(0, len(self), size=(batch_size))
return self.data[indexes]
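# Illustrative usage (a minimal sketch, not executed by this module): transitions are
# stored as `Batch` objects and the oldest entries are dropped once `buffer_size` is
# exceeded, e.g.
#   buf = ReplayBuffer(buffer_size=1e6)
#   buf.put(Batch({'obs': obs, 'action': act, 'reward': rew, 'next_obs': next_obs, 'done': done}))
#   sampled = buf.sample(1024)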
class SACOperator(PolicyOperator):
"""
A class used to train the platform policy with SAC.
"""
NAME = "SAC"
critic_pre_trained = False
PARAMETER_DESCRIPTION = [
{
"name" : "sac_batch_size",
"description": "Batch size of training process.",
"abbreviation" : "pbs",
"type" : int,
"default" : 1024,
'doc': True,
},
{
"name" : "policy_bc_epoch",
"description": "pre-train policy with setting epoch",
"type" : int,
"default" : 0,
"doc": True,
},
{
"name" : "sac_epoch",
"description": "Number of epcoh for the training process.",
"abbreviation" : "bep",
"type" : int,
"default" : 1000,
'doc': True,
},
{
"name" : "sac_steps_per_epoch",
"description": "The number of update rounds of sac in each epoch.",
"abbreviation" : "sspe",
"type" : int,
"default" : 200,
'doc': True,
},
{
"name" : "sac_rollout_horizon",
"abbreviation" : "srh",
"type" : int,
"default" : 20,
'doc': True,
},
{
"name" : "policy_hidden_features",
"description": "Number of neurons per layer of the policy network.",
"abbreviation" : "phf",
"type" : int,
"default" : 256,
'doc': True,
},
{
"name" : "policy_hidden_layers",
"description": "Depth of policy network.",
"abbreviation" : "phl",
"type" : int,
"default" : 4,
'doc': True,
},
{
"name" : "policy_backbone",
"description": "Backbone of policy network. [mlp, res, ft_transformer]",
"abbreviation" : "pb",
"type" : str,
"default" : "res",
'doc': True,
},
{
"name" : "policy_hidden_activation",
"description": "hidden_activation of policy network.",
"abbreviation" : "pha",
"type" : str,
"default" : "leakyrelu",
'doc': True,
},
{
"name" : "q_hidden_features",
"abbreviation" : "qhf",
"type" : int,
"default" : 256,
},
{
"name" : "q_hidden_layers",
"abbreviation" : "qhl",
"type" : int,
"default" : 2,
},
{
"name" : "num_q_net",
"abbreviation" : "nqn",
"type" : int,
"default" : 4,
},
{
"name" : "buffer_size",
"description": "Size of the buffer to store data.",
"abbreviation" : "bfs",
"type" : int,
"default" : 1e6,
'doc': True,
},
{
"name" : "w_kl",
"type" : float,
"default" : 1.0,
},
{
"name" : "gamma",
"type" : float,
"default" : 0.99,
},
{
"name" : "alpha",
"type" : float,
"default" : 0.2,
},
{
"name" : "polyak",
"type" : float,
"default" : 0.99,
},
{
"name" : "batch_ratio",
"type" : float,
"default" : 1,
},
{
"name" : "g_lr",
"description": "Initial learning rate of the training process.",
"type" : float,
"default" : 4e-5,
"search_mode" : "continuous",
"search_values" : [1e-6, 1e-3],
'doc': True,
},
{
"name" : "interval",
"description": "interval step for index removing.",
"type" : int,
"default" : 0
},
{
"name" : "generate_deter",
"description": "deterministic of generator rollout",
"type" : int,
"default" : 1
},
{
"name" : "filter",
"type" : bool,
"default" : False,
},
{
"name" : "candidate_num",
"type" : int,
"default" : 50,
},
{
"name" : "ensemble_size",
"type" : int,
"default" : 50,
},
{
"name" : "ensemble_choosing_interval",
"type" : int,
"default" : 10,
},
{
"name" : "disturbing_transition_function",
"description": "Disturbing the network node in policy learning",
"type": bool,
"default": False,
},
{
"name" : "disturbing_nodes",
"description": "Disturbing the network node in policy learning",
"type": list,
"default": 'auto',
},
{
"name" : "disturbing_net_num",
"description": "Disturbing the network node in policy learning",
"type": int,
"default": 100,
},
{
"name" : "disturbing_weight",
"description": "Disturbing the network node in policy learning",
"type": float,
"default": 0.05,
},
{
"name" : "critic_pretrain",
"type": bool,
"default": True,
},
{
"name" : "reward_uncertainty_weight",
"description": "Reward uncertainty weight(MOPO)",
"type" : float,
"default" : 0,
"search_mode" : "continuous",
},
{
"name" : "penalty_sample_num",
"type": int,
"default": 20,
},
{
"name" : "penalty_type",
"type": str,
"default": "None",
},
{
"name" : "ts_conv_nodes",
"type" : list,
"default" : "auto",
},
]
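# Rough summary of how the core hyperparameters above enter the SAC update in `sac()` below:
#   target_q = (r - reward_uncertainty_weight * penalty)
#              + gamma * (min_i Q_target_i(s', a') - alpha * log pi(a'|s')) * (1 - done)
#   theta_target <- polyak * theta_target + (1 - polyak) * theta
# `w_kl` additionally regularizes the actor towards the BC-initialized policy.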
def model_creator(self, nodes):
"""
Create model including platform policy and value net.
:return: env model, platform policy, value net
"""
models = super().model_creator(nodes)
assert len(models) == 1, f"{self.NAME} does not support multiple target policies."
graph = self._graph
# TODO: only a single target policy is supported; add multi-policy support in the future
policy_name = self.policy_name[0]
total_dims = self.config['total_dims']
input_dim = get_input_dim_from_graph(graph, policy_name, total_dims)
input_dim += total_dims[policy_name]['input']
additional_kwargs = {}
if self.config['ts_conv_nodes'] != 'auto' and policy_name in self.config['ts_conv_nodes']:
""""get"""
temp_input_dim = get_input_dim_dict_from_graph(graph, policy_name, total_dims)
temp_input_dim[policy_name] = total_dims[policy_name]['input']
has_ts = any('ts' in element for element in temp_input_dim)
if not has_ts:
input_dim = temp_input_dim
pass
else:
input_dim, input_dim_config, ts_conv_net_config = \
dict_ts_conv(graph, temp_input_dim, total_dims, self.config,
net_hidden_features=self.config['q_hidden_features'])
additional_kwargs['ts_conv_config'] = input_dim_config
additional_kwargs['ts_conv_net_config'] = ts_conv_net_config
else:
input_dim = get_input_dim_dict_from_graph(graph, policy_name, total_dims)
input_dim[policy_name] = total_dims[policy_name]['input']
pass
q_net = Value_Net_VectorizedCritic(input_dim,
self.config['q_hidden_features'],
self.config['q_hidden_layers'],
self.config['num_q_net'], **additional_kwargs)
# q_net = VectorizedCritic(input_dim,
# 1,
# self.config['q_hidden_features'],
# self.config['q_hidden_layers'],
# self.config['num_q_net'])
target_q_net = deepcopy(q_net)
target_q_net.requires_grad_(False)
models = models + [ q_net, target_q_net]
return models
def optimizer_creator(self, scope):
"""
:return: optimizers for the platform policy (actor) and the value net (critic)
"""
# TODO: only a single target policy is supported; add multi-policy support in the future
if scope == "train":
target_policy = self.train_policy[0]
q_net, _ = self.other_train_models
else:
target_policy = self.val_policy[0]
q_net, _ = self.other_val_models
optimizers = []
actor_optimizer = torch.optim.Adam(target_policy.parameters(), lr=self.config['g_lr'])
critic_optimizer = torch.optim.Adam(q_net.parameters(), lr=self.config['g_lr'])
optimizers.append(actor_optimizer)
optimizers.append(critic_optimizer)
return optimizers
def data_creator(self):
self.config[BATCH_SIZE] = self.config['sac_batch_size']
return data_creator(self.config,
training_mode='trajectory',
training_horizon=self.config['sac_rollout_horizon'],
training_is_sample=True,
val_horizon=self.config['test_horizon'],
double=self.double_validation)
@catch_error
def setup(self, config):
# super(SACOperator, self).setup(config)
self.replay_buffer_train = ReplayBuffer(config['buffer_size'])
self.expert_buffer_train = ReplayBuffer(config['buffer_size'])
if self.double_validation:
self.replay_buffer_val = ReplayBuffer(config['buffer_size'])
self.expert_buffer_val = ReplayBuffer(config['buffer_size'])
@catch_error
def before_validate_epoch(self):
if self._epoch_cnt >= self.config['policy_bc_epoch'] + self.config['sac_epoch']:
self._stop_flag = True
def done_process(self,
expert_data=None, generated_data=None,
generate_done=True, expert_done=True):
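"""
Post-process the 'done' signal for expert and generated data.
If 'done' is a node in the decision flow, its values are de-processed back to the raw
0/1 scale; otherwise an all-zero 'done' tensor is appended. Data for a side whose flag
(`expert_done` / `generate_done`) is False is dropped (returned as None).
"""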
if 'done' in self._graph.graph_dict.keys():
if expert_done:
#expert data process
temp_done_expert = self._processor.deprocess_single_torch(expert_data['done'], 'done')
assert temp_done_expert.shape == expert_data['done'].shape
expert_data['done'] = temp_done_expert
assert torch.sum(expert_data['done'])>=0
else:
expert_data = None
if generate_done:
#generate data process
temp_done_generate = self._processor.deprocess_single_torch(generated_data['done'], 'done')
assert temp_done_generate.shape == generated_data['done'].shape
generated_data['done'] = temp_done_generate
assert torch.sum(generated_data['done'])>=0
else:
generated_data = None
else:
if expert_done:
#expert data process
expert_data['done'] = torch.zeros(expert_data.shape[:-1]+[1]).to(expert_data[self._graph.leaf[0]].device)
assert list(expert_data['done'].shape)[:-1] == expert_data.shape[:-1]
assert torch.sum(expert_data['done']).item()==0 # done is not in the decision-flow, and should all be zeros
else:
expert_data = None
if generate_done:
#generate data process
generated_data['done'] = torch.zeros(generated_data.shape[:-1]+[1]).to(generated_data[self._graph.leaf[0]].device)
assert list(generated_data['done'].shape)[:-1] == generated_data.shape[:-1]
assert torch.sum(generated_data['done']).item()==0 # done is not in the decision-flow, and should all be zeros
else:
generated_data = None
return expert_data, generated_data
def buffer_process(self,
generated_data=None, buffer=None, model_index=None,
expert_data=None, buffer_expert=None, expert_index=None,
generate_buffer=True, expert_buffer=True):
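"""
Convert rollouts into (obs, action, reward, next_obs, done) transitions and push them
into the replay buffers. 'obs' is the concatenation of the policy node's graph inputs.
Indexes whose 'done' flag fires are removed step by step, so later steps of finished
trajectories are not stored.
"""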
if generate_buffer:# if model_index.shape[0]:
for i in range(generated_data.shape[0] - 1):
#TODO: policy name should be set as list of multi policy name
next_remove_index = np.array(torch.where(generated_data['done'][i]!=0)[0].tolist(), dtype=model_index.dtype)
interval_remove_index = None
try:
if self.config.get("interval", 0) != 0:
interval_remove_index = np.array(torch.where(generated_data['done'][i:i+self.config['interval']]!=0)[1].tolist(), dtype=model_index.dtype)
__interval = self.config['interval']
logger.warning(f'interval-done, in the epoch: {self._epoch_cnt}, at step: {i}')
interval_remove_index = None  # NOTE: this reset currently disables the interval-based removal below
except:
interval_remove_index = None
buffer.put(Batch({
'obs' : torch.cat(list(self._graph.get_node(self.policy_name[0]).get_inputs(generated_data[ i ]).values()), dim=-1)[model_index],
'next_obs' : torch.cat(list(self._graph.get_node(self.policy_name[0]).get_inputs(generated_data[i+1]).values()), dim=-1)[model_index],
'action' : generated_data[i][self.policy_name[0]][model_index],
'reward' : generated_data[i]['reward'][model_index],
'done' : generated_data[i]['done'][model_index],
}))
self.buffer_input_dim = get_input_dim_dict_from_graph(self._graph, self.policy_name[0], self.config['total_dims'])
pre_size = model_index.shape[0]
model_index = np.setdiff1d(model_index, next_remove_index)
if interval_remove_index is not None:
model_index = np.setdiff1d(model_index, interval_remove_index)
logger.info(f'Detect interval_remove_index done with interval {__interval}, with shape {interval_remove_index.shape[0]}')
post_size = model_index.shape[0]
if post_size<pre_size:
logger.info(f'size of model_index decreased, in the epoch: {self._epoch_cnt}, at step: {i}')
else:
buffer = None
if expert_buffer: #if expert_index.shape[0]:
#generate rewards for expert data
expert_rewards = generate_rewards(expert_data, reward_fn=lambda data: self._user_func(self._get_original_actions(data)))
for i in range(expert_data.shape[0] - 1):
#TODO: policy name should be set as list of multi policy name
next_remove_index = np.array(torch.where(expert_data[i]['done']!=0)[0].tolist(), dtype=expert_index.dtype)
if next_remove_index.shape[0] != 0:
logger.info(f'Detect expert data traj-done, in the epoch: {self._epoch_cnt}, at step: {i}')
buffer_expert.put(Batch({
'obs' : torch.cat(list(self._graph.get_node(self.policy_name[0]).get_inputs(expert_data[ i ]).values()), dim=-1)[expert_index],
'next_obs' : torch.cat(list(self._graph.get_node(self.policy_name[0]).get_inputs(expert_data[i+1]).values()), dim=-1)[expert_index],
'action' : expert_data[i][self.policy_name[0]][expert_index],
'reward' : expert_rewards[i]['reward'][expert_index],
'done' : expert_data[i]['done'][expert_index],
}))
self.buffer_input_dim = get_input_dim_dict_from_graph(self._graph, self.policy_name[0], self.config['total_dims'])
self.buffer_input_policy_name = self.policy_name[0]
pre_size = expert_index.shape[0]
expert_index = np.setdiff1d(expert_index, next_remove_index)
post_size = expert_index.shape[0]
if post_size<pre_size:
logger.info(f'size of expert_index decreased, in the epoch: {self._epoch_cnt}, at step: {i}')
else:
buffer_expert = None
return buffer, buffer_expert
def pretrain_critic(self, expert_data, scope):
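"""
Pre-train the critic (Q networks) on expert data only, during the BC phase and before
SAC policy updates start. Also snapshots the current policy as `self.bc_init_net` for
the later KL regularization.
"""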
logger.info('Pretraining critic')
if scope == 'train':
target_policy = self.train_models[:-2]
q_net, target_q_net = self.train_models[-2:]
actor_optimizer = self.train_optimizers[0]
critic_optimizer = self.train_optimizers[1]
envs = self.envs_train
buffer = self.replay_buffer_train
buffer_expert = self.expert_buffer_train
else:
assert self.double_validation
target_policy = self.val_models[:-2]
q_net, target_q_net = self.val_models[-2:]
actor_optimizer = self.val_optimizers[0]
critic_optimizer = self.val_optimizers[1]
envs = self.envs_val
buffer = self.replay_buffer_val
buffer_expert = self.expert_buffer_val
expert_data.to_torch(device=self._device)
assert len(expert_data.shape) == 3 # expert_data should have three dimensions: [horizon, batch, features]
# generated_data, info = self._run_rollout(expert_data, target_policy, envs, traj_length=self.config['sac_rollout_horizon'], deterministic=bool(self.config['generate_deter']), clip=True)
expert_data, _ = self.done_process(expert_data, generated_data=None, generate_done=False, expert_done=True)
all_index = np.arange(0, min(self.config[BATCH_SIZE], expert_data.shape[1]), 1)
expert_index= all_index
_, buffer_expert = self.buffer_process(generated_data=None, buffer=None, model_index=None,
expert_data=expert_data, buffer_expert=buffer_expert, expert_index=expert_index,
generate_buffer=False, expert_buffer=True)
# logger.info(f'buffer_gen length : {len(buffer)}, buffer_real length : {len(buffer_expert)}')
assert len(target_policy)==1 # SAC currently supports only a single policy, not multi-policy.
target_policy = target_policy[0]
_info = self.sac(buffer, buffer_expert, target_policy,
q_net, target_q_net,
gamma=self.config['gamma'], alpha=self.config['alpha'], polyak=self.config['polyak'],
actor_optimizer=actor_optimizer, critic_optimizer=critic_optimizer, pre_train_critic=True)
info = _info
for k in list(info.keys()): info[f'{k}_{scope}'] = info.pop(k)
try:
self.bc_init_net = deepcopy(target_policy) #deepcopy(self.behaviour_policys[0]) #
if hasattr(list(self.train_nodes.values())[0], 'custom_node'):
self.bc_init_net_custome_node = True
else:
self.bc_init_net_custome_node = False
except:
logger.info('No behaviour policy can be used for calculating the KL loss')
return info
def bc_train_batch(self, expert_data, batch_info, scope='train'):
# Decide whether the SAC critic pre-training step is needed, based on whether the target policy network has been initialized.
info = {}
bc_metrics = super().bc_train_batch(expert_data, batch_info, scope)
info.update(bc_metrics)
if self.config['critic_pretrain']:
_info = self.pretrain_critic(expert_data, scope)
info.update(_info)
return info
@catch_error
def train_batch(self, expert_data, batch_info, scope='train'):
if scope == 'train':
#[target_policy, q_net1, q_net2, q_net3, q_net4, target_q_net1, target_q_net2, target_q_net3, target_q_net4]
target_policy = self.train_models[:-2]
q_net, target_q_net = self.train_models[-2:]
# target_policy = self.train_models[:-4]
# _, q_net1, q_net2, target_q_net1, target_q_net2 = self.train_models
actor_optimizer = self.train_optimizers[0]
critic_optimizer = self.train_optimizers[1]
envs = self.envs_train # List[env_dev] of length 1
buffer = self.replay_buffer_train
buffer_expert = self.expert_buffer_train
else:
assert self.double_validation
target_policy = self.val_models[:-2]
q_net, target_q_net = self.val_models[-2:]
# target_policy = self.val_models[:-4]
# _, q_net1, q_net2, target_q_net1, target_q_net2 = self.val_models
# target_policy, q_net1, q_net2, target_q_net1, target_q_net2 = self.val_models
actor_optimizer = self.val_optimizers[0]
critic_optimizer = self.val_optimizers[1]
envs = self.envs_val # List[env_dev] of length 1
buffer = self.replay_buffer_val
buffer_expert = self.expert_buffer_val
expert_data.to_torch(device=self._device)
assert len(expert_data.shape) == 3 # expert_data should have three dimensions: [horizon, batch, features]
generated_data, info = self._run_rollout(expert_data, target_policy, envs, traj_length=self.config['sac_rollout_horizon'], deterministic=bool(self.config['generate_deter']), clip=True)
expert_data, generated_data = self.done_process(expert_data, generated_data)
all_index = np.arange(0, min(self.config[BATCH_SIZE], expert_data.shape[1]), 1)
model_index = all_index #np.random.choice(all_index,int(self.config[BATCH_SIZE]*1), replace=False) #np.random.choice(all_index,int(self.config[BATCH_SIZE]*(self.config['batch_ratio'])), replace=False)
expert_index= all_index #np.random.choice(all_index,int(self.config[BATCH_SIZE]*1), replace=False)#np.random.choice(all_index,self.config[BATCH_SIZE]-int(self.config[BATCH_SIZE]*(self.config['batch_ratio'])), replace=False)
buffer, buffer_expert = self.buffer_process(generated_data, buffer, model_index, expert_data, buffer_expert, expert_index)
# logger.info(f'buffer_gen length : {len(buffer)}, buffer_real length : {len(buffer_expert)}')
assert len(target_policy)==1 # SAC currently supports only a single policy, not multi-policy.
target_policy = target_policy[0]
if not hasattr(self, 'bc_init_net'):
# SAC needs the initial network to compute kl_loss, so maintain a self.bc_init_net.
try:
self.bc_init_net = deepcopy(target_policy) #deepcopy(self.behaviour_policys[0]) #
if hasattr(list(self.train_nodes.values())[0], 'custom_node'):
self.bc_init_net_custome_node = True
else:
self.bc_init_net_custome_node = False
except:
logger.info('No behaviour policy can be used for calculating the KL loss')
_info = self.sac(buffer, buffer_expert, target_policy,
q_net, target_q_net, #q_net1, q_net2, target_q_net1, target_q_net2, #q_net, target_q_net, #q_net3, q_net4, target_q_net3, target_q_net4,
gamma=self.config['gamma'], alpha=self.config['alpha'], polyak=self.config['polyak'],
actor_optimizer=actor_optimizer, critic_optimizer=critic_optimizer, scope=scope)
info = _info
for k in list(info.keys()): info[f'{k}_{scope}'] = info.pop(k)
return info
@ torch.no_grad()
def compute_lcb(self,
_gen_data: Batch,
scope=None):
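"""
Compute an uncertainty penalty for generated (obs, action) pairs via the learned
environment's `filter_penalty`, controlled by `penalty_type`. Returns a [batch_size, 1]
tensor, or 0 when no penalty is configured.
"""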
if scope == 'train':
envs = self.envs_train[0] # List[env_dev] of length 1
else:
assert self.double_validation
envs = self.envs_val[0] # List[env_dev] of length 1
gen_data = Batch({'obs': deepcopy(_gen_data['obs']), # [batch_size, dim]
'action': deepcopy(_gen_data['action']), # [batch_size, dim]
})
gen_data.to_torch(device=self._device)
if self.config["penalty_type"] == "filter":
penalty = envs.filter_penalty(penalty_type="filter", data=gen_data, sample_num=1, clip=True) # [batch_size, 1]
elif self.config["penalty_type"] == "filter_score_std":
penalty = envs.filter_penalty(penalty_type="filter_score_std", data=gen_data, sample_num=self.config['penalty_sample_num'], clip=True) # [batch_size, 1]
else:
penalty = 0
return penalty
def sac(self,
buffer, buffer_real, target_policy,
q_net, target_q_net,
gamma=0.99, alpha=0.2, polyak=0.99,
actor_optimizer=None, critic_optimizer=None, pre_train_critic=False,
scope=None):
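"""
Run `sac_steps_per_epoch` SAC updates: sample a batch of real and generated transitions
according to `batch_ratio`, update the critic towards the soft Bellman target (optionally
penalized by the MOPO-style reward uncertainty), Polyak-average the target critic, and,
unless `pre_train_critic` is True, update the actor with the entropy-regularized objective
plus an optional KL term towards the BC-initialized policy.
"""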
# if self._epoch_cnt <= (1 + self.config.get("policy_bc_epoch", 0)):
# logger.warning(f'BC pre-trained the policy! Critic_pretrain is finished!')
# # for behaviour_policy in self.behaviour_policys:
for _ in range(self.config['sac_steps_per_epoch']):
self._batch_cnt += 1
penalty = 0
penalty_record = 0
if pre_train_critic:
data = buffer_real.sample(int(self.config['sac_batch_size']))
data.to_torch(device=self._device)
obs = data['obs']
action = data['action']
else:
real_data_size = int(self.config['sac_batch_size']*(1-self.config['batch_ratio']))
generated_data_size = int(self.config['sac_batch_size']*(self.config['batch_ratio']))
_data_real = buffer_real.sample(real_data_size)
_data_gen = buffer.sample(generated_data_size) # [batch_size, dim]
if self.config['reward_uncertainty_weight'] > 0:
penalty = self.compute_lcb(_data_gen, scope)
if isinstance(penalty, torch.Tensor):
penalty_record = (self.config['reward_uncertainty_weight'] * penalty).mean().item()
penalty = torch.cat([torch.zeros((real_data_size, 1), device=self._device, dtype=penalty.dtype), penalty], dim=0) # pad real data with 0 penalty
data = Batch()
for k, _ in _data_gen.items():
data[k] = torch.cat([_data_real[k], _data_gen[k]], dim=0)
data.to_torch(device=self._device)
obs = data['obs']
action = data['action']
# update critic ----------------------------------------
with torch.no_grad():
next_obs = data['next_obs']
if hasattr(list(self.train_nodes.values())[0], 'custom_node') or hasattr(list(self.train_nodes.values())[0], 'ts_conv_node') or self.input_is_dict:
input_dict = dict()
last_dim = 0
for k, v in self.buffer_input_dim.items():
input_dict[k] = next_obs[..., last_dim:last_dim+v]
last_dim += v
next_action_dist = target_policy(input_dict)
else:
next_action_dist = target_policy(next_obs)
next_action = next_action_dist.sample()
next_action_log_prob = next_action_dist.log_prob(next_action).unsqueeze(dim=-1)
next_obs_action = torch.cat([next_obs, next_action], dim=-1)
try:
if hasattr(target_q_net, 'conv_ts_node'):
next_obs_action = Batch()
last_dim = 0
for k, v in self.buffer_input_dim.items():
next_obs_action[k] = next_obs[..., last_dim:v+last_dim]
last_dim += v
next_obs_action[self.buffer_input_policy_name] = next_action
next_q = target_q_net(next_obs_action).min(0).values.unsqueeze(-1)
except:
# next_q is required below; surface the underlying error instead of a later NameError
logger.exception('Failed to compute next_q for the target Q update')
raise
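# Soft Bellman target: the (optional) MOPO-style uncertainty penalty is subtracted from
# the reward, and the (1 - done) mask stops bootstrapping at terminal steps.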
target_q = (data['reward'] - self.config['reward_uncertainty_weight'] * penalty) + gamma * (next_q - alpha * next_action_log_prob) * (1 - data['done'])
out_target_q = next_q.mean()
if torch.any(data['done']):
max_done_q = torch.max(next_q[torch.where(data['done'])[0]]).item()
max_q = torch.max(next_q[torch.where(1-data['done'])[0]]).item()
else:
max_done_q = -np.inf
max_q = torch.max(next_q[torch.where(1-data['done'])[0]]).item()
obs_action = torch.cat([obs, action], dim=-1)
if hasattr(target_q_net, 'conv_ts_node'):
obs_action = Batch()
last_dim = 0
for k, v in self.buffer_input_dim.items():
obs_action[k] = obs[..., last_dim:v+last_dim]
last_dim += v
obs_action[self.buffer_input_policy_name] = action
q_values = q_net(obs_action)
out_q_values = q_values.mean()
# q_limit_loss = (q_values.mean() - 10)**2
critic_loss = ((q_values - target_q.view(-1)) ** 2).mean(dim=1).sum(dim=0) #+ 0.5 * q_limit_loss
critic_optimizer.zero_grad()
critic_loss.backward()
for key, para in q_net.named_parameters():
c_grad = torch.norm(para.grad)
break
critic_optimizer.step()
# update target networks by polyak averaging.
with torch.no_grad():
for p, p_targ in zip(q_net.parameters(), target_q_net.parameters()):
# NB: we use in-place operations "mul_" and "add_" to update target
# params, as opposed to "mul" and "add", which would create new tensors.
p_targ.data.mul_(polyak)
p_targ.data.add_((1 - polyak) * p.data)
# update actor ----------------------------------------
if pre_train_critic:
actor_loss = torch.tensor([0.0])
a_grad = torch.tensor([0.0])
kl_loss = torch.tensor([0.0])
q = torch.tensor([0.0])
else:
# SAC update actor
if hasattr(list(self.train_nodes.values())[0], 'custom_node') or hasattr(list(self.train_nodes.values())[0], 'ts_conv_node') or self.input_is_dict:
input_dict = dict()
last_dim = 0
for k, v in self.buffer_input_dim.items():
input_dict[k] = obs[..., last_dim:v+last_dim]
last_dim += v
action_dist = target_policy(input_dict)
else:
action_dist = target_policy(obs)
new_action = action_dist.rsample()
action_log_prob = action_dist.log_prob(new_action)
new_obs_action = torch.cat([obs, new_action], dim=-1)
if hasattr(target_q_net, 'conv_ts_node'):
new_obs_action = Batch()
last_dim = 0
for k, v in self.buffer_input_dim.items():
new_obs_action[k] = obs[..., last_dim:v+last_dim]
last_dim += v
new_obs_action[self.buffer_input_policy_name] = new_action
q = q_net(new_obs_action).min(0).values.unsqueeze(-1)
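# Entropy-regularized SAC actor objective: maximize E[min_i Q_i(s, a) - alpha * log pi(a|s)],
# i.e. minimize its negation below.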
actor_loss = - q.mean() + alpha * action_log_prob.mean()
# compute kl loss
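# When w_kl > 0, an extra KL term between the BC-initialized policy's action distribution
# and the current policy's keeps the SAC policy close to the behaviour-cloned initialization.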
if self.config['w_kl']>0:
if hasattr(list(self.train_nodes.values())[0], 'custom_node')or hasattr(list(self.train_nodes.values())[0], 'ts_conv_node') or self.input_is_dict:
input_dict = dict()
last_dim = 0
for k, v in self.buffer_input_dim.items():
input_dict[k] = obs[..., last_dim:v+last_dim]
last_dim += v
_kl_q = self.bc_init_net(input_dict)
_kl_p = target_policy(input_dict)
else:
_kl_q = self.bc_init_net(obs)
_kl_p = target_policy(obs)
kl_loss = kl_divergence(_kl_q, _kl_p).mean() * self.config['w_kl']
kl_loss_mean = kl_loss.mean()
if torch.isinf(kl_loss_mean) or torch.isnan(kl_loss_mean):
logger.info(f'Detect nan or inf in kl_loss, skip this batch! (loss : {kl_loss}, epoch : {self._epoch_cnt})')
actor_kl_loss = actor_loss
else:
actor_kl_loss = actor_loss + kl_loss
else:
actor_kl_loss = actor_loss
kl_loss = torch.zeros((1,))
actor_optimizer.zero_grad()
actor_kl_loss.backward()
for key, para in target_policy.named_parameters():
a_grad = torch.norm(para.grad)
break
actor_optimizer.step()
info = {
"SAC/critic_loss": critic_loss.item(),
"SAC/actor_loss": actor_loss.item(),
"SAC/critic_grad": c_grad.item(),
"SAC/actor_grad": a_grad.item(),
"SAC/kl_loss": kl_loss.item(),
"SAC/critic_q_values": out_q_values.item(),
"SAC/critic_target_q": out_target_q.item(),
"SAC/actor_q": q.mean().item(),
"SAC/batch_reward": data['reward'].mean().item(),
"SAC/batch_reward_max": data['reward'].max().item(),
"SAC/max_done_q": max_done_q,
"SAC/max_q": max_q,
f"SAC/batch_penalty": penalty_record,
}
return info