Source code for revive.algo.policy.sac

"""
    POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""

import torch
import numpy as np
from copy import deepcopy

from revive.computation.dists import kl_divergence
from revive.computation.modules import *
from revive.utils.common_utils import *
from revive.data.batch import Batch
from revive.utils.raysgd_utils import BATCH_SIZE
from revive.data.dataset import data_creator
from revive.algo.policy.base import PolicyOperator, catch_error


class ReplayBuffer:
    """A simple FIFO experience replay buffer for SAC agents."""

    def __init__(self, buffer_size):
        self.data = None
        self.buffer_size = int(buffer_size)

    def put(self, batch_data: Batch):
        batch_data.to_torch(device='cpu')
        if self.data is None:
            self.data = batch_data
        else:
            self.data.cat_(batch_data)
        if len(self) > self.buffer_size:
            self.data = self.data[len(self) - self.buffer_size:]

    def __len__(self):
        if self.data is None:
            return 0
        return self.data.shape[0]

    def sample(self, batch_size):
        assert len(self) > 0, 'Cannot sample from an empty buffer!'
        indexes = np.random.randint(0, len(self), size=batch_size)
        return self.data[indexes]
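
# ---------------------------------------------------------------------------
# NOTE: Illustrative usage sketch of ReplayBuffer (not part of the original
# module). The field names mirror what ``SACOperator.buffer_process`` stores;
# the tensor shapes below are assumptions chosen only for the example.
#
#     buffer = ReplayBuffer(buffer_size=1e6)
#     buffer.put(Batch({
#         'obs': torch.zeros(128, 11), 'next_obs': torch.zeros(128, 11),
#         'action': torch.zeros(128, 3), 'reward': torch.zeros(128, 1),
#         'done': torch.zeros(128, 1),
#     }))
#     minibatch = buffer.sample(batch_size=32)  # Batch of 32 random transitions
# ---------------------------------------------------------------------------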
class SACOperator(PolicyOperator):
    """A class used to train the platform policy with SAC."""

    NAME = "SAC"
    critic_pre_trained = False
    PARAMETER_DESCRIPTION = [
        {
            "name": "sac_batch_size",
            "description": "Batch size of the training process.",
            "abbreviation": "pbs",
            "type": int,
            "default": 1024,
            'doc': True,
        },
        {
            "name": "policy_bc_epoch",
            "description": "Number of epochs for pre-training the policy with behavior cloning.",
            "type": int,
            "default": 0,
            "doc": True,
        },
        {
            "name": "sac_epoch",
            "description": "Number of epochs for the training process.",
            "abbreviation": "bep",
            "type": int,
            "default": 1000,
            'doc': True,
        },
        {
            "name": "sac_steps_per_epoch",
            "description": "The number of SAC update rounds in each epoch.",
            "abbreviation": "sspe",
            "type": int,
            "default": 200,
            'doc': True,
        },
        {
            "name": "sac_rollout_horizon",
            "abbreviation": "srh",
            "type": int,
            "default": 20,
            'doc': True,
        },
        {
            "name": "policy_hidden_features",
            "description": "Number of neurons per layer of the policy network.",
            "abbreviation": "phf",
            "type": int,
            "default": 256,
            'doc': True,
        },
        {
            "name": "policy_hidden_layers",
            "description": "Depth of the policy network.",
            "abbreviation": "phl",
            "type": int,
            "default": 4,
            'doc': True,
        },
        {
            "name": "policy_backbone",
            "description": "Backbone of the policy network. [mlp, res, ft_transformer]",
            "abbreviation": "pb",
            "type": str,
            "default": "res",
            'doc': True,
        },
        {
            "name": "policy_hidden_activation",
            "description": "Hidden activation of the policy network.",
            "abbreviation": "pha",
            "type": str,
            "default": "leakyrelu",
            'doc': True,
        },
        {
            "name": "q_hidden_features",
            "abbreviation": "qhf",
            "type": int,
            "default": 256,
        },
        {
            "name": "q_hidden_layers",
            "abbreviation": "qhl",
            "type": int,
            "default": 2,
        },
        {
            "name": "num_q_net",
            "abbreviation": "nqn",
            "type": int,
            "default": 4,
        },
        {
            "name": "buffer_size",
            "description": "Size of the buffer to store data.",
            "abbreviation": "bfs",
            "type": int,
            "default": 1e6,
            'doc': True,
        },
        {
            "name": "w_kl",
            "type": float,
            "default": 1.0,
        },
        {
            "name": "gamma",
            "type": float,
            "default": 0.99,
        },
        {
            "name": "alpha",
            "type": float,
            "default": 0.2,
        },
        {
            "name": "polyak",
            "type": float,
            "default": 0.99,
        },
        {
            "name": "batch_ratio",
            "type": float,
            "default": 1,
        },
        {
            "name": "g_lr",
            "description": "Initial learning rate of the training process.",
            "type": float,
            "default": 4e-5,
            "search_mode": "continuous",
            "search_values": [1e-6, 1e-3],
            'doc': True,
        },
        {
            "name": "interval",
            "description": "Interval step for index removing.",
            "type": int,
            "default": 0,
        },
        {
            "name": "generate_deter",
            "description": "Whether the generator rollout is deterministic.",
            "type": int,
            "default": 1,
        },
        {
            "name": "filter",
            "type": bool,
            "default": False,
        },
        {
            "name": "candidate_num",
            "type": int,
            "default": 50,
        },
        {
            "name": "ensemble_size",
            "type": int,
            "default": 50,
        },
        {
            "name": "ensemble_choosing_interval",
            "type": int,
            "default": 10,
        },
        {
            "name": "disturbing_transition_function",
            "description": "Whether to disturb the transition network nodes in policy learning.",
            "type": bool,
            "default": False,
        },
        {
            "name": "disturbing_nodes",
            "description": "Which network nodes to disturb in policy learning.",
            "type": list,
            "default": 'auto',
        },
        {
            "name": "disturbing_net_num",
            "description": "Number of disturbing networks used in policy learning.",
            "type": int,
            "default": 100,
        },
        {
            "name": "disturbing_weight",
            "description": "Weight of the disturbance applied in policy learning.",
            "type": float,
            "default": 0.05,
        },
        {
            "name": "critic_pretrain",
            "type": bool,
            "default": True,
        },
        {
            "name": "reward_uncertainty_weight",
            "description": "Reward uncertainty weight (MOPO).",
            "type": float,
            "default": 0,
            "search_mode": "continuous",
        },
        {
            "name": "penalty_sample_num",
            "type": int,
            "default": 20,
        },
        {
            "name": "penalty_type",
            "type": str,
            "default": "None",
        },
        {
            "name": "ts_conv_nodes",
            "type": list,
            "default": "auto",
        },
    ]
    def model_creator(self, nodes):
        """Create the models, including the platform policy and the value net.

        :return: env model, platform policy, value net
        """
        models = super().model_creator(nodes)
        assert len(models) == 1, f"{self.NAME} does not support multiple target policies."
        graph = self._graph

        # TODO: currently only a single target policy; multi-policy support may be added later.
        policy_name = self.policy_name[0]
        total_dims = self.config['total_dims']
        input_dim = get_input_dim_from_graph(graph, policy_name, total_dims)
        input_dim += total_dims[policy_name]['input']

        additional_kwargs = {}
        if self.config['ts_conv_nodes'] != 'auto' and policy_name in self.config['ts_conv_nodes']:
            temp_input_dim = get_input_dim_dict_from_graph(graph, policy_name, total_dims)
            temp_input_dim[policy_name] = total_dims[policy_name]['input']
            has_ts = any('ts' in element for element in temp_input_dim)
            if not has_ts:
                input_dim = temp_input_dim
            else:
                input_dim, input_dim_config, ts_conv_net_config = \
                    dict_ts_conv(graph, temp_input_dim, total_dims, self.config,
                                 net_hidden_features=self.config['q_hidden_features'])
                additional_kwargs['ts_conv_config'] = input_dim_config
                additional_kwargs['ts_conv_net_config'] = ts_conv_net_config
        else:
            input_dim = get_input_dim_dict_from_graph(graph, policy_name, total_dims)
            input_dim[policy_name] = total_dims[policy_name]['input']

        q_net = Value_Net_VectorizedCritic(input_dim,
                                           self.config['q_hidden_features'],
                                           self.config['q_hidden_layers'],
                                           self.config['num_q_net'],
                                           **additional_kwargs)
        target_q_net = deepcopy(q_net)
        target_q_net.requires_grad_(False)

        models = models + [q_net, target_q_net]
        return models
    def optimizer_creator(self, scope):
        """Create the generator optimizers.

        :return: a list containing the platform policy (actor) optimizer and the value net (critic) optimizer
        """
        # TODO: currently only a single target policy; multi-policy support may be added later.
        if scope == "train":
            target_policy = self.train_policy[0]
            q_net, _ = self.other_train_models
        else:
            target_policy = self.val_policy[0]
            q_net, _ = self.other_val_models

        optimizers = []
        actor_optimizer = torch.optim.Adam(target_policy.parameters(), lr=self.config['g_lr'])
        critic_optimizer = torch.optim.Adam(q_net.parameters(), lr=self.config['g_lr'])
        optimizers.append(actor_optimizer)
        optimizers.append(critic_optimizer)
        return optimizers
    def data_creator(self):
        self.config[BATCH_SIZE] = self.config['sac_batch_size']
        return data_creator(self.config,
                            training_mode='trajectory',
                            training_horizon=self.config['sac_rollout_horizon'],
                            training_is_sample=True,
                            val_horizon=self.config['test_horizon'],
                            double=self.double_validation)
    @catch_error
    def setup(self, config):
        # super(SACOperator, self).setup(config)
        self.replay_buffer_train = ReplayBuffer(config['buffer_size'])
        self.expert_buffer_train = ReplayBuffer(config['buffer_size'])
        if self.double_validation:
            self.replay_buffer_val = ReplayBuffer(config['buffer_size'])
            self.expert_buffer_val = ReplayBuffer(config['buffer_size'])

    @catch_error
    def before_validate_epoch(self):
        if self._epoch_cnt >= self.config['policy_bc_epoch'] + self.config['sac_epoch']:
            self._stop_flag = True
    def done_process(self, expert_data=None, generated_data=None, generate_done=True, expert_done=True):
        if 'done' in self._graph.graph_dict.keys():
            if expert_done:
                # expert data process
                temp_done_expert = self._processor.deprocess_single_torch(expert_data['done'], 'done')
                assert temp_done_expert.shape == expert_data['done'].shape
                expert_data['done'] = temp_done_expert
                assert torch.sum(expert_data['done']) >= 0
            else:
                expert_data = None
            if generate_done:
                # generated data process
                temp_done_generate = self._processor.deprocess_single_torch(generated_data['done'], 'done')
                assert temp_done_generate.shape == generated_data['done'].shape
                generated_data['done'] = temp_done_generate
                assert torch.sum(generated_data['done']) >= 0
            else:
                generated_data = None
        else:
            if expert_done:
                # expert data process
                expert_data['done'] = torch.zeros(expert_data.shape[:-1] + [1]).to(expert_data[self._graph.leaf[0]].device)
                assert list(expert_data['done'].shape)[:-1] == expert_data.shape[:-1]
                # done is not in the decision flow, so it should be all zeros
                assert torch.sum(expert_data['done']).item() == 0
            else:
                expert_data = None
            if generate_done:
                # generated data process
                generated_data['done'] = torch.zeros(generated_data.shape[:-1] + [1]).to(generated_data[self._graph.leaf[0]].device)
                assert list(generated_data['done'].shape)[:-1] == generated_data.shape[:-1]
                # done is not in the decision flow, so it should be all zeros
                assert torch.sum(generated_data['done']).item() == 0
            else:
                generated_data = None
        return expert_data, generated_data
    def buffer_process(self, generated_data=None, buffer=None, model_index=None,
                       expert_data=None, buffer_expert=None, expert_index=None,
                       generate_buffer=True, expert_buffer=True):
        if generate_buffer:
            for i in range(generated_data.shape[0] - 1):
                # TODO: policy name should be a list to support multiple policies
                next_remove_index = np.array(torch.where(generated_data['done'][i] != 0)[0].tolist(),
                                             dtype=model_index.dtype)
                interval_remove_index = None
                try:
                    if self.config.get("interval", 0) != 0:
                        interval_remove_index = np.array(
                            torch.where(generated_data['done'][i:i + self.config['interval']] != 0)[1].tolist(),
                            dtype=model_index.dtype)
                        __interval = self.config['interval']
                        logger.warning(f'interval-done, in the epoch: {self._epoch_cnt}, at step: {i}')
                except Exception:
                    interval_remove_index = None

                buffer.put(Batch({
                    'obs': torch.cat(list(self._graph.get_node(self.policy_name[0]).get_inputs(generated_data[i]).values()), dim=-1)[model_index],
                    'next_obs': torch.cat(list(self._graph.get_node(self.policy_name[0]).get_inputs(generated_data[i + 1]).values()), dim=-1)[model_index],
                    'action': generated_data[i][self.policy_name[0]][model_index],
                    'reward': generated_data[i]['reward'][model_index],
                    'done': generated_data[i]['done'][model_index],
                }))
                self.buffer_input_dim = get_input_dim_dict_from_graph(self._graph, self.policy_name[0], self.config['total_dims'])

                pre_size = model_index.shape[0]
                model_index = np.setdiff1d(model_index, next_remove_index)
                if interval_remove_index is not None:
                    model_index = np.setdiff1d(model_index, interval_remove_index)
                    logger.info(f'Detected interval-done with interval {__interval}; removing {interval_remove_index.shape[0]} indexes')
                post_size = model_index.shape[0]
                if post_size < pre_size:
                    logger.info(f'Size of model_index decreased, in the epoch: {self._epoch_cnt}, at step: {i}')
        else:
            buffer = None

        if expert_buffer:
            # generate rewards for the expert data
            expert_rewards = generate_rewards(expert_data,
                                              reward_fn=lambda data: self._user_func(self._get_original_actions(data)))
            for i in range(expert_data.shape[0] - 1):
                # TODO: policy name should be a list to support multiple policies
                next_remove_index = np.array(torch.where(expert_data[i]['done'] != 0)[0].tolist(),
                                             dtype=expert_index.dtype)
                if next_remove_index.shape[0] != 0:
                    logger.info(f'Detected expert data traj-done, in the epoch: {self._epoch_cnt}, at step: {i}')

                buffer_expert.put(Batch({
                    'obs': torch.cat(list(self._graph.get_node(self.policy_name[0]).get_inputs(expert_data[i]).values()), dim=-1)[expert_index],
                    'next_obs': torch.cat(list(self._graph.get_node(self.policy_name[0]).get_inputs(expert_data[i + 1]).values()), dim=-1)[expert_index],
                    'action': expert_data[i][self.policy_name[0]][expert_index],
                    'reward': expert_rewards[i]['reward'][expert_index],
                    'done': expert_data[i]['done'][expert_index],
                }))
                self.buffer_input_dim = get_input_dim_dict_from_graph(self._graph, self.policy_name[0], self.config['total_dims'])
                self.buffer_input_policy_name = self.policy_name[0]

                pre_size = expert_index.shape[0]
                expert_index = np.setdiff1d(expert_index, next_remove_index)
                post_size = expert_index.shape[0]
                if post_size < pre_size:
                    logger.info(f'Size of expert_index decreased, in the epoch: {self._epoch_cnt}, at step: {i}')
        else:
            buffer_expert = None

        return buffer, buffer_expert
    def pretrain_critic(self, expert_data, scope):
        logger.info('Pretraining the critic')
        if scope == 'train':
            target_policy = self.train_models[:-2]
            q_net, target_q_net = self.train_models[-2:]
            actor_optimizer = self.train_optimizers[0]
            critic_optimizer = self.train_optimizers[1]
            envs = self.envs_train
            buffer = self.replay_buffer_train
            buffer_expert = self.expert_buffer_train
        else:
            assert self.double_validation
            target_policy = self.val_models[:-2]
            q_net, target_q_net = self.val_models[-2:]
            actor_optimizer = self.val_optimizers[0]
            critic_optimizer = self.val_optimizers[1]
            envs = self.envs_val
            buffer = self.replay_buffer_val
            buffer_expert = self.expert_buffer_val

        expert_data.to_torch(device=self._device)
        # the expert_data should have three dimensions: [horizon, batch, features]
        assert len(expert_data.shape) == 3

        expert_data, _ = self.done_process(expert_data, generated_data=None,
                                           generate_done=False, expert_done=True)

        all_index = np.arange(0, min(self.config[BATCH_SIZE], expert_data.shape[1]), 1)
        expert_index = all_index

        _, buffer_expert = self.buffer_process(generated_data=None, buffer=None, model_index=None,
                                               expert_data=expert_data, buffer_expert=buffer_expert,
                                               expert_index=expert_index,
                                               generate_buffer=False, expert_buffer=True)

        # SAC currently supports only a single policy, not multiple policies.
        assert len(target_policy) == 1
        target_policy = target_policy[0]

        _info = self.sac(buffer, buffer_expert, target_policy, q_net, target_q_net,
                         gamma=self.config['gamma'],
                         alpha=self.config['alpha'],
                         polyak=self.config['polyak'],
                         actor_optimizer=actor_optimizer,
                         critic_optimizer=critic_optimizer,
                         pre_train_critic=True)

        info = _info
        for k in list(info.keys()):
            info[f'{k}_{scope}'] = info.pop(k)

        try:
            self.bc_init_net = deepcopy(target_policy)
            if hasattr(list(self.train_nodes.values())[0], 'custom_node'):
                self.bc_init_net_custome_node = True
            else:
                self.bc_init_net_custome_node = False
        except Exception:
            logger.info('No behaviour policy can be used for calculating the KL loss')

        return info
    def bc_train_batch(self, expert_data, batch_info, scope='train'):
        # Decide whether the SAC critic pre-training step is needed, based on whether the
        # target policy network has been initialized.
        info = {}
        bc_metrics = super().bc_train_batch(expert_data, batch_info, scope)
        info.update(bc_metrics)
        if self.config['critic_pretrain']:
            _info = self.pretrain_critic(expert_data, scope)
            info.update(_info)
        return info
    @catch_error
    def train_batch(self, expert_data, batch_info, scope='train'):
        if scope == 'train':
            # models: [*target_policy, q_net, target_q_net]
            target_policy = self.train_models[:-2]
            q_net, target_q_net = self.train_models[-2:]
            actor_optimizer = self.train_optimizers[0]
            critic_optimizer = self.train_optimizers[1]
            envs = self.envs_train  # List[env_dev] of length 1
            buffer = self.replay_buffer_train
            buffer_expert = self.expert_buffer_train
        else:
            assert self.double_validation
            target_policy = self.val_models[:-2]
            q_net, target_q_net = self.val_models[-2:]
            actor_optimizer = self.val_optimizers[0]
            critic_optimizer = self.val_optimizers[1]
            envs = self.envs_val  # List[env_dev] of length 1
            buffer = self.replay_buffer_val
            buffer_expert = self.expert_buffer_val

        expert_data.to_torch(device=self._device)
        # the expert_data should have three dimensions: [horizon, batch, features]
        assert len(expert_data.shape) == 3

        generated_data, info = self._run_rollout(expert_data, target_policy, envs,
                                                 traj_length=self.config['sac_rollout_horizon'],
                                                 deterministic=bool(self.config['generate_deter']),
                                                 clip=True)
        expert_data, generated_data = self.done_process(expert_data, generated_data)

        all_index = np.arange(0, min(self.config[BATCH_SIZE], expert_data.shape[1]), 1)
        model_index = all_index
        expert_index = all_index

        buffer, buffer_expert = self.buffer_process(generated_data, buffer, model_index,
                                                    expert_data, buffer_expert, expert_index)

        # SAC currently supports only a single policy, not multiple policies.
        assert len(target_policy) == 1
        target_policy = target_policy[0]

        if not hasattr(self, 'bc_init_net'):
            # SAC uses the initial network to compute the KL loss, so keep a copy in self.bc_init_net.
            try:
                self.bc_init_net = deepcopy(target_policy)
                if hasattr(list(self.train_nodes.values())[0], 'custom_node'):
                    self.bc_init_net_custome_node = True
                else:
                    self.bc_init_net_custome_node = False
            except Exception:
                logger.info('No behaviour policy can be used for calculating the KL loss')

        _info = self.sac(buffer, buffer_expert, target_policy, q_net, target_q_net,
                         gamma=self.config['gamma'],
                         alpha=self.config['alpha'],
                         polyak=self.config['polyak'],
                         actor_optimizer=actor_optimizer,
                         critic_optimizer=critic_optimizer,
                         scope=scope)
        info = _info
        for k in list(info.keys()):
            info[f'{k}_{scope}'] = info.pop(k)
        return info
    @torch.no_grad()
    def compute_lcb(self, _gen_data: Batch, scope=None):
        if scope == 'train':
            envs = self.envs_train[0]  # List[env_dev] of length 1
        else:
            assert self.double_validation
            envs = self.envs_val[0]  # List[env_dev] of length 1

        gen_data = Batch({
            'obs': deepcopy(_gen_data['obs']),        # [batch_size, dim]
            'action': deepcopy(_gen_data['action']),  # [batch_size, dim]
        })
        gen_data.to_torch(device=self._device)

        if self.config["penalty_type"] == "filter":
            penalty = envs.filter_penalty(penalty_type="filter", data=gen_data,
                                          sample_num=1, clip=True)  # [batch_size, 1]
        elif self.config["penalty_type"] == "filter_score_std":
            penalty = envs.filter_penalty(penalty_type="filter_score_std", data=gen_data,
                                          sample_num=self.config['penalty_sample_num'],
                                          clip=True)  # [batch_size, 1]
        else:
            penalty = 0
        return penalty
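
    # NOTE: the penalty returned by ``compute_lcb`` enters the critic update in
    # ``sac`` below as a MOPO-style pessimistic reward:
    #     r_tilde  = reward - reward_uncertainty_weight * penalty
    #     target_q = r_tilde + gamma * (min_i Q_target_i(s', a') - alpha * log_prob(a'|s')) * (1 - done)
    # Transitions drawn from the real-data buffer are padded with zero penalty.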
    def sac(self, buffer, buffer_real, target_policy, q_net, target_q_net,
            gamma=0.99, alpha=0.2, polyak=0.99,
            actor_optimizer=None, critic_optimizer=None,
            pre_train_critic=False, scope=None):
        for _ in range(self.config['sac_steps_per_epoch']):
            self._batch_cnt += 1
            penalty = 0
            penalty_record = 0

            if pre_train_critic:
                data = buffer_real.sample(int(self.config['sac_batch_size']))
                data.to_torch(device=self._device)
                obs = data['obs']
                action = data['action']
            else:
                real_data_size = int(self.config['sac_batch_size'] * (1 - self.config['batch_ratio']))
                generated_data_size = int(self.config['sac_batch_size'] * self.config['batch_ratio'])
                _data_real = buffer_real.sample(real_data_size)
                _data_gen = buffer.sample(generated_data_size)  # [batch_size, dim]

                if self.config['reward_uncertainty_weight'] > 0:
                    penalty = self.compute_lcb(_data_gen, scope)
                    if isinstance(penalty, torch.Tensor):
                        penalty_record = (self.config['reward_uncertainty_weight'] * penalty).mean().item()
                        # pad the real data with zero penalty
                        penalty = torch.cat([torch.zeros((real_data_size, 1),
                                                         device=self._device,
                                                         dtype=penalty.dtype),
                                             penalty], dim=0)

                data = Batch()
                for k, _ in _data_gen.items():
                    data[k] = torch.cat([_data_real[k], _data_gen[k]], dim=0)
                data.to_torch(device=self._device)
                obs = data['obs']
                action = data['action']

            # update critic ----------------------------------------
            with torch.no_grad():
                next_obs = data['next_obs']
                if hasattr(list(self.train_nodes.values())[0], 'custom_node') or \
                        hasattr(list(self.train_nodes.values())[0], 'ts_conv_node') or self.input_is_dict:
                    input_dict = dict()
                    last_dim = 0
                    for k, v in self.buffer_input_dim.items():
                        input_dict[k] = next_obs[..., last_dim:last_dim + v]
                        last_dim += v
                    next_action_dist = target_policy(input_dict)
                else:
                    next_action_dist = target_policy(next_obs)
                next_action = next_action_dist.sample()
                next_action_log_prob = next_action_dist.log_prob(next_action).unsqueeze(dim=-1)
                next_obs_action = torch.cat([next_obs, next_action], dim=-1)
                try:
                    if hasattr(target_q_net, 'conv_ts_node'):
                        next_obs_action = Batch()
                        last_dim = 0
                        for k, v in self.buffer_input_dim.items():
                            next_obs_action[k] = next_obs[..., last_dim:v + last_dim]
                            last_dim += v
                        next_obs_action[self.buffer_input_policy_name] = next_action
                    next_q = target_q_net(next_obs_action).min(0).values.unsqueeze(-1)
                except Exception:
                    pass

                target_q = (data['reward'] - self.config['reward_uncertainty_weight'] * penalty) \
                    + gamma * (next_q - alpha * next_action_log_prob) * (1 - data['done'])
                out_target_q = next_q.mean()
                if torch.any(data['done']):
                    max_done_q = torch.max(next_q[torch.where(data['done'])[0]]).item()
                    max_q = torch.max(next_q[torch.where(1 - data['done'])[0]]).item()
                else:
                    max_done_q = -np.inf
                    max_q = torch.max(next_q[torch.where(1 - data['done'])[0]]).item()

            obs_action = torch.cat([obs, action], dim=-1)
            if hasattr(target_q_net, 'conv_ts_node'):
                obs_action = Batch()
                last_dim = 0
                for k, v in self.buffer_input_dim.items():
                    obs_action[k] = obs[..., last_dim:v + last_dim]
                    last_dim += v
                obs_action[self.buffer_input_policy_name] = action
            q_values = q_net(obs_action)
            out_q_values = q_values.mean()
            critic_loss = ((q_values - target_q.view(-1)) ** 2).mean(dim=1).sum(dim=0)

            critic_optimizer.zero_grad()
            critic_loss.backward()
            for key, para in q_net.named_parameters():
                c_grad = torch.norm(para.grad)
                break
            critic_optimizer.step()

            # update target networks by polyak averaging.
            with torch.no_grad():
                for p, p_targ in zip(q_net.parameters(), target_q_net.parameters()):
                    # NB: use the in-place operations "mul_" and "add_" to update the target
                    # params, as opposed to "mul" and "add", which would create new tensors.
                    p_targ.data.mul_(polyak)
                    p_targ.data.add_((1 - polyak) * p.data)

            # update actor ----------------------------------------
            if pre_train_critic:
                actor_loss = torch.tensor([0.0])
                a_grad = torch.tensor([0.0])
                kl_loss = torch.tensor([0.0])
                q = torch.tensor([0.0])
            else:
                # SAC actor update
                if hasattr(list(self.train_nodes.values())[0], 'custom_node') or \
                        hasattr(list(self.train_nodes.values())[0], 'ts_conv_node') or self.input_is_dict:
                    input_dict = dict()
                    last_dim = 0
                    for k, v in self.buffer_input_dim.items():
                        input_dict[k] = obs[..., last_dim:v + last_dim]
                        last_dim += v
                    action_dist = target_policy(input_dict)
                else:
                    action_dist = target_policy(obs)
                new_action = action_dist.rsample()
                action_log_prob = action_dist.log_prob(new_action)
                new_obs_action = torch.cat([obs, new_action], dim=-1)
                if hasattr(target_q_net, 'conv_ts_node'):
                    new_obs_action = Batch()
                    last_dim = 0
                    for k, v in self.buffer_input_dim.items():
                        new_obs_action[k] = obs[..., last_dim:v + last_dim]
                        last_dim += v
                    new_obs_action[self.buffer_input_policy_name] = new_action
                q = q_net(new_obs_action).min(0).values.unsqueeze(-1)
                actor_loss = - q.mean() + alpha * action_log_prob.mean()

                # compute the KL loss against the BC-initialized policy
                if self.config['w_kl'] > 0:
                    if hasattr(list(self.train_nodes.values())[0], 'custom_node') or \
                            hasattr(list(self.train_nodes.values())[0], 'ts_conv_node') or self.input_is_dict:
                        input_dict = dict()
                        last_dim = 0
                        for k, v in self.buffer_input_dim.items():
                            input_dict[k] = obs[..., last_dim:v + last_dim]
                            last_dim += v
                        _kl_q = self.bc_init_net(input_dict)
                        _kl_p = target_policy(input_dict)
                    else:
                        _kl_q = self.bc_init_net(obs)
                        _kl_p = target_policy(obs)
                    kl_loss = kl_divergence(_kl_q, _kl_p).mean() * self.config['w_kl']
                    kl_loss_mean = kl_loss.mean()
                    if torch.isinf(kl_loss_mean) or torch.isnan(kl_loss_mean):
                        logger.info(f'Detected nan or inf in kl_loss, skipping the KL term for this batch! '
                                    f'(loss: {kl_loss}, epoch: {self._epoch_cnt})')
                        actor_kl_loss = actor_loss
                    else:
                        actor_kl_loss = actor_loss + kl_loss
                else:
                    actor_kl_loss = actor_loss
                    kl_loss = torch.zeros((1,))

                actor_optimizer.zero_grad()
                actor_kl_loss.backward()
                for key, para in target_policy.named_parameters():
                    a_grad = torch.norm(para.grad)
                    break
                actor_optimizer.step()

        info = {
            "SAC/critic_loss": critic_loss.item(),
            "SAC/actor_loss": actor_loss.item(),
            "SAC/critic_grad": c_grad.item(),
            "SAC/actor_grad": a_grad.item(),
            "SAC/kl_loss": kl_loss.item(),
            "SAC/critic_q_values": out_q_values.item(),
            "SAC/critic_target_q": out_target_q.item(),
            "SAC/actor_q": q.mean().item(),
            "SAC/batch_reward": data['reward'].mean().item(),
            "SAC/batch_reward_max": data['reward'].max().item(),
            "SAC/max_done_q": max_done_q,
            "SAC/max_q": max_q,
            "SAC/batch_penalty": penalty_record,
        }
        return info
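

# ---------------------------------------------------------------------------
# NOTE: The block below is an illustrative, self-contained sketch of the core
# SAC update performed in ``SACOperator.sac`` above: an entropy-regularised TD
# target, the minimum over an ensemble of Q heads, and Polyak averaging of the
# target network. It deliberately uses plain torch modules instead of REVIVE's
# graph-based policies and vectorized critics, so the names ``toy_actor`` and
# ``toy_critics`` are assumptions for demonstration only and this code is not
# part of the training pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    obs_dim, act_dim, num_q, gamma, alpha, polyak = 4, 2, 2, 0.99, 0.2, 0.99

    toy_actor = torch.nn.Sequential(torch.nn.Linear(obs_dim, 32), torch.nn.ReLU(),
                                    torch.nn.Linear(32, act_dim))
    toy_critics = torch.nn.ModuleList(
        [torch.nn.Sequential(torch.nn.Linear(obs_dim + act_dim, 32), torch.nn.ReLU(),
                             torch.nn.Linear(32, 1)) for _ in range(num_q)])
    toy_target_critics = deepcopy(toy_critics).requires_grad_(False)
    critic_opt = torch.optim.Adam(toy_critics.parameters(), lr=3e-4)

    # fake transition batch: (obs, action, reward, next_obs, done)
    obs = torch.randn(8, obs_dim)
    action = torch.randn(8, act_dim)
    reward = torch.randn(8, 1)
    next_obs = torch.randn(8, obs_dim)
    done = torch.zeros(8, 1)

    with torch.no_grad():
        # the toy "policy" is deterministic, so the entropy term uses a zero log-prob;
        # REVIVE's policies return distributions and sample next_action instead.
        next_action = toy_actor(next_obs)
        next_log_prob = torch.zeros(8, 1)
        next_sa = torch.cat([next_obs, next_action], dim=-1)
        next_q = torch.stack([c(next_sa) for c in toy_target_critics]).min(0).values
        target_q = reward + gamma * (next_q - alpha * next_log_prob) * (1 - done)

    sa = torch.cat([obs, action], dim=-1)
    q_values = torch.stack([c(sa) for c in toy_critics])          # [num_q, batch, 1]
    critic_loss = ((q_values - target_q) ** 2).mean(dim=1).sum()  # summed over Q heads
    critic_opt.zero_grad()
    critic_loss.backward()
    critic_opt.step()

    # Polyak averaging of the target critic, as in the training loop above.
    with torch.no_grad():
        for p, p_targ in zip(toy_critics.parameters(), toy_target_critics.parameters()):
            p_targ.data.mul_(polyak)
            p_targ.data.add_((1 - polyak) * p.data)
    print(f"toy critic_loss: {critic_loss.item():.4f}")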