Source code for revive.algo.policy.sac

"""
    POLIXIR REVIVE, copyright (C) 2021-2023 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""

import torch
import numpy as np
from copy import deepcopy

from revive.computation.dists import kl_divergence
from revive.computation.modules import *
from revive.utils.common_utils import *
from revive.data.batch import Batch
from revive.utils.raysgd_utils import BATCH_SIZE
from revive.data.dataset import data_creator
from revive.algo.policy.base import PolicyOperator, catch_error


class ReplayBuffer:
    """A simple FIFO experience replay buffer for SAC agents."""

    def __init__(self, buffer_size):
        self.data = None
        self.buffer_size = int(buffer_size)
    def put(self, batch_data: Batch):
        batch_data.to_torch(device='cpu')
        if self.data is None:
            self.data = batch_data
        else:
            self.data.cat_(batch_data)
        if len(self) > self.buffer_size:
            self.data = self.data[len(self) - self.buffer_size:]
    def __len__(self):
        if self.data is None:
            return 0
        return self.data.shape[0]
    def sample(self, batch_size):
        assert len(self) > 0, 'Cannot sample from an empty buffer!'
        indexes = np.random.randint(0, len(self), size=batch_size)
        return self.data[indexes]
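
# Illustrative usage sketch (not part of the training pipeline; the field names
# mirror the Batches built in SACOperator.train_batch below, while the tensors
# themselves are hypothetical):
#
#     buffer = ReplayBuffer(buffer_size=1e6)
#     buffer.put(Batch({'obs': obs, 'action': action, 'reward': reward,
#                       'done': done, 'next_obs': next_obs}))
#     minibatch = buffer.sample(256)   # Batch of 256 uniformly sampled transitions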
class SACOperator(PolicyOperator):
    """A class used to train the platform policy with SAC."""

    PARAMETER_DESCRIPTION = [
        {
            "name": "sac_batch_size",
            "description": "Batch size of the training process.",
            "abbreviation": "pbs",
            "type": int,
            "default": 1024,
            "doc": True,
        },
        {
            "name": "policy_bc_epoch",
            "type": int,
            "default": 0,
        },
        {
            "name": "sac_epoch",
            "description": "Number of epochs for the training process.",
            "abbreviation": "bep",
            "type": int,
            "default": 200,
            "doc": True,
        },
        {
            "name": "sac_steps_per_epoch",
            "description": "Number of SAC update rounds in each epoch.",
            "abbreviation": "sspe",
            "type": int,
            "default": 200,
            "doc": True,
        },
        {
            "name": "sac_rollout_horizon",
            "abbreviation": "srh",
            "type": int,
            "default": 20,
            "doc": True,
        },
        {
            "name": "policy_hidden_features",
            "description": "Number of neurons per layer of the policy network.",
            "abbreviation": "phf",
            "type": int,
            "default": 256,
            "doc": True,
        },
        {
            "name": "policy_hidden_layers",
            "description": "Depth of the policy network.",
            "abbreviation": "phl",
            "type": int,
            "default": 4,
            "doc": True,
        },
        {
            "name": "policy_backbone",
            "description": "Backbone of the policy network.",
            "abbreviation": "pb",
            "type": str,
            "default": "mlp",
            "doc": True,
        },
        {
            "name": "policy_hidden_activation",
            "description": "Hidden activation of the policy network.",
            "abbreviation": "pha",
            "type": str,
            "default": "leakyrelu",
            "doc": True,
        },
        {
            "name": "q_hidden_features",
            "abbreviation": "qhf",
            "type": int,
            "default": 256,
        },
        {
            "name": "q_hidden_layers",
            "abbreviation": "qhl",
            "type": int,
            "default": 2,
        },
        {
            "name": "num_q_net",
            "abbreviation": "nqn",
            "type": int,
            "default": 4,
        },
        {
            "name": "buffer_size",
            "description": "Size of the buffer to store data.",
            "abbreviation": "bfs",
            "type": int,
            "default": 1e6,
            "doc": True,
        },
        {
            "name": "w_kl",
            "type": float,
            "default": 1.0,
        },
        {
            "name": "gamma",
            "type": float,
            "default": 0.99,
        },
        {
            "name": "alpha",
            "type": float,
            "default": 0.2,
        },
        {
            "name": "polyak",
            "type": float,
            "default": 0.99,
        },
        {
            "name": "batch_ratio",
            "type": float,
            "default": 1,
        },
        {
            "name": "g_lr",
            "description": "Initial learning rate of the training process.",
            "type": float,
            "default": 4e-5,
            "search_mode": "continuous",
            "search_values": [1e-6, 1e-3],
            "doc": True,
        },
        {
            "name": "interval",
            "description": "Interval step for index removal.",
            "type": int,
            "default": 0,
        },
        {
            "name": "generate_deter",
            "description": "Whether the generator rollout is deterministic.",
            "type": int,
            "default": 0,
        },
        {
            "name": "reward_uncertainty_weight",
            "description": "Reward uncertainty weight (MOPO).",
            "type": float,
            "default": 0,
            "search_mode": "continuous",
        },
    ]

    @property
    def policy(self):
        if isinstance(self.train_models, (list, tuple)):
            assert len(self.train_models[:-2]) == 1
            return self.train_models[:-2]
        else:
            assert len(self.train_models[:-2]) == 1
            return self.train_models

    @property
    def val_policy(self):
        if isinstance(self.val_models, (list, tuple)):
            assert len(self.val_models[:-2]) == 1
            return self.val_models[:-2]
        else:
            assert len(self.val_models[:-2]) == 1
            return self.val_models
    def model_creator(self, config, nodes):
        """
        Create the models, including the platform policy and the value net.

        :return: target policy, q net, target q net
        """
        graph = config['graph']
        policy_name = config['target_policy_name'][0]  # TODO: only one target policy is supported for now; multi-policy support is future work
        total_dims = config['total_dims']
        input_dim = get_input_dim_from_graph(graph, policy_name, total_dims)

        if config['behavioral_policy_init'] and self.behaviour_policys:
            for behaviour_policy in self.behaviour_policys:
                target_policy = deepcopy(behaviour_policy)
                target_policy.requires_grad_(True)
        else:
            if self.config.get("policy_bc_epoch", 0) == 0:
                logger.info('The user initialized a brand new policy network, so policy_bc_epoch is set to 50.')
                self.config['policy_bc_epoch'] = 50
            assert len(nodes.keys()) == 1  # SAC currently supports only a single policy.
            for (policy_name, node) in nodes.items():
                node_input_dim = get_input_dim_from_graph(graph, policy_name, total_dims)
                logger.info(f"Policy Backbone: {config['policy_backbone']}")
                node.initialize_network(node_input_dim,
                                        total_dims[policy_name]['output'],
                                        hidden_features=config['policy_hidden_features'],
                                        hidden_layers=config['policy_hidden_layers'],
                                        hidden_activation=config['policy_hidden_activation'],
                                        backbone_type=config['policy_backbone'],
                                        dist_config=config['dist_configs'][policy_name],
                                        is_transition=False)
                target_policy = node.get_network()

        # The critic takes the policy inputs concatenated with the action.
        input_dim += total_dims[policy_name]['input']
        q_net = VectorizedCritic(input_dim, 1, config['q_hidden_features'], config['q_hidden_layers'], config['num_q_net'])
        target_q_net = deepcopy(q_net)
        target_q_net.requires_grad_(False)

        return [target_policy, q_net, target_q_net]
    def optimizer_creator(self, models, config):
        """
        :return: generator optimizers, including the platform policy optimizer and the value net optimizer
        """
        target_policy, q_net, target_q_net = models
        actor_optimizor = torch.optim.Adam(target_policy.parameters(), lr=config['g_lr'])
        critic_optimizor = torch.optim.Adam(q_net.parameters(), lr=config['g_lr'])

        target_policys = models[:-2]
        optimizers = [actor_optimizor, critic_optimizor]

        if self.config["policy_bc_epoch"] >= 1:
            models_params = []
            for target_policy in target_policys:
                models_params.append({'params': target_policy.parameters(), 'lr': 1e-3})
            bc_policy_optimizer = torch.optim.Adam(models_params)
            optimizers.append(bc_policy_optimizer)

        return optimizers
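
    # Note on ordering (descriptive, based on the indices used in train_batch):
    # optimizers[0] is the SAC actor optimizer, optimizers[1] is the critic
    # optimizer, and, when policy_bc_epoch >= 1, optimizers[2] is the optimizer
    # used during the behaviour-cloning warm-up epochs.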
    def data_creator(self, config):
        config[BATCH_SIZE] = config['sac_batch_size']
        return data_creator(config,
                            training_mode='trajectory',
                            training_horizon=config['sac_rollout_horizon'],
                            training_is_sample=True,
                            val_horizon=config['test_horizon'],
                            double=self.double_validation)
    @catch_error
    def setup(self, config):
        self.replay_buffer_train = ReplayBuffer(config['buffer_size'])
        self.expert_buffer_train = ReplayBuffer(config['buffer_size'])
        if self.double_validation:
            self.replay_buffer_val = ReplayBuffer(config['buffer_size'])
            self.expert_buffer_val = ReplayBuffer(config['buffer_size'])

    @catch_error
    def train_batch(self, expert_data, batch_info, scope='train'):
        if scope == 'train':
            target_policy = self.train_models[:-2]
            q_net, target_q_net = self.train_models[-2:]
            if self._epoch_cnt <= self.config.get("policy_bc_epoch", 0):
                actor_optimizer = self.train_optimizers[2]
            else:
                actor_optimizer = self.train_optimizers[0]
            critic_optimizer = self.train_optimizers[1]
            envs = self.envs_train
            buffer = self.replay_buffer_train
            buffer_expert = self.expert_buffer_train
        else:
            assert self.double_validation
            target_policy = self.val_models[:-2]
            q_net, target_q_net = self.val_models[-2:]
            if self._epoch_cnt <= self.config.get("policy_bc_epoch", 0):
                actor_optimizer = self.val_optimizers[2]
            else:
                actor_optimizer = self.val_optimizers[0]
            critic_optimizer = self.val_optimizers[1]
            envs = self.envs_val
            buffer = self.replay_buffer_val
            buffer_expert = self.expert_buffer_val

        expert_data.to_torch(device=self._device)
        assert len(expert_data.shape) == 3  # expert_data should have three dimensions: [horizon, batch, features]

        generated_data, info = self._run_rollout(expert_data,
                                                 target_policy,
                                                 envs,
                                                 traj_length=self.config['sac_rollout_horizon'],
                                                 deterministic=bool(self.config['generate_deter']),
                                                 clip=True)

        if 'done' in self._graph.graph_dict.keys():
            # process the expert data
            temp_done_expert = self._processor.deprocess_single_torch(expert_data['done'], 'done')
            assert temp_done_expert.shape == expert_data['done'].shape
            expert_data['done'] = temp_done_expert
            # process the generated data
            temp_done_generate = self._processor.deprocess_single_torch(generated_data['done'], 'done')
            assert temp_done_generate.shape == generated_data['done'].shape
            generated_data['done'] = temp_done_generate
            assert torch.sum(expert_data['done']) >= 0
            assert torch.sum(generated_data['done']) >= 0
        else:
            # process the expert data
            expert_data['done'] = torch.zeros(expert_data.shape[:-1] + [1]).to(expert_data[self._graph.leaf[0]].device)
            # process the generated data
            generated_data['done'] = torch.zeros(generated_data.shape[:-1] + [1]).to(generated_data[self._graph.leaf[0]].device)
            assert list(expert_data['done'].shape)[:-1] == expert_data.shape[:-1]
            assert list(generated_data['done'].shape)[:-1] == generated_data.shape[:-1]
            assert torch.sum(expert_data['done']).item() == 0  # 'done' is not in the decision flow, so it should be all zeros
            assert torch.sum(generated_data['done']).item() == 0  # 'done' is not in the decision flow, so it should be all zeros

        all_index = np.arange(0, self.config[BATCH_SIZE], 1)
        if self._epoch_cnt <= self.config.get("policy_bc_epoch", 0):
            model_index = np.random.choice(all_index, 0)
            expert_index = all_index
            buffer = ReplayBuffer(self.config['buffer_size'])
        else:
            model_index = all_index
            expert_index = all_index

        # store model-generated transitions
        for i in range(generated_data.shape[0] - 1):
            # TODO: the policy name should be a list to support multiple policies
            next_remove_index = np.array(torch.where(generated_data['done'][i] != 0)[0].tolist(), dtype=model_index.dtype)
            interval_remove_index = None
            try:
                if self.config.get("interval", 0) != 0:
                    interval_remove_index = np.array(
                        torch.where(generated_data['done'][i:i + self.config['interval']] != 0)[1].tolist(),
                        dtype=model_index.dtype)
                    __interval = self.config['interval']
                    logger.warning(f'interval-done, in the epoch: {self._epoch_cnt}, at step: {i}')
                    # NOTE: the computed interval index is immediately discarded,
                    # so interval-based removal is effectively disabled.
                    interval_remove_index = None
            except:
                interval_remove_index = None
            buffer.put(Batch({
                'obs': torch.cat(list(self._graph.get_node(self.policy_name[0]).get_inputs(generated_data[i]).values()), dim=-1)[model_index],
                'next_obs': torch.cat(list(self._graph.get_node(self.policy_name[0]).get_inputs(generated_data[i + 1]).values()), dim=-1)[model_index],
                'action': generated_data[i][self.policy_name[0]][model_index],
                'reward': generated_data[i]['reward'][model_index],
                'done': generated_data[i]['done'][model_index],
            }))
            pre_size = model_index.shape[0]
            model_index = np.setdiff1d(model_index, next_remove_index)
            if interval_remove_index is not None:
                model_index = np.setdiff1d(model_index, interval_remove_index)
                logger.info(f'Detected interval_remove_index done with interval {__interval}, with shape {interval_remove_index.shape[0]}')
            post_size = model_index.shape[0]
            if post_size < pre_size:
                logger.info(f'The size of model_index decreased, in the epoch: {self._epoch_cnt}, at step: {i}')

        # generate rewards for the expert data and store expert transitions
        expert_rewards = generate_rewards(expert_data,
                                          reward_fn=lambda data: self._user_func(self._get_original_actions(data)))
        for i in range(expert_data.shape[0] - 1):
            # TODO: the policy name should be a list to support multiple policies
            next_remove_index = np.array(torch.where(expert_data[i]['done'] != 0)[0].tolist(), dtype=expert_index.dtype)
            if next_remove_index.shape[0] != 0:
                logger.info(f'Detected expert data traj-done, in the epoch: {self._epoch_cnt}, at step: {i}')
            buffer_expert.put(Batch({
                'obs': torch.cat(list(self._graph.get_node(self.policy_name[0]).get_inputs(expert_data[i]).values()), dim=-1)[expert_index],
                'next_obs': torch.cat(list(self._graph.get_node(self.policy_name[0]).get_inputs(expert_data[i + 1]).values()), dim=-1)[expert_index],
                'action': expert_data[i][self.policy_name[0]][expert_index],
                'reward': expert_rewards[i]['reward'][expert_index],
                'done': expert_data[i]['done'][expert_index],
            }))
            pre_size = expert_index.shape[0]
            expert_index = np.setdiff1d(expert_index, next_remove_index)
            post_size = expert_index.shape[0]
            if post_size < pre_size:
                logger.info(f'The size of expert_index decreased, in the epoch: {self._epoch_cnt}, at step: {i}')

        assert len(target_policy) == 1  # SAC currently supports only a single policy.
        target_policy = target_policy[0]

        _info = self.sac(buffer,
                         buffer_expert,
                         target_policy,
                         q_net,
                         target_q_net,
                         gamma=self.config['gamma'],
                         alpha=self.config['alpha'],
                         polyak=self.config['polyak'],
                         actor_optimizer=actor_optimizer,
                         critic_optimizer=critic_optimizer)

        info = _info
        for k in list(info.keys()):
            info[f'{k}_{scope}'] = info.pop(k)

        self._stop_flag = self._stop_flag or self._epoch_cnt >= self.config['sac_epoch']
        info = self._early_stop(info)
        return info

    @catch_error
    def bc_train_batch(self, expert_data, batch_info, scope='train'):
        if scope == 'train':
            return self.train_batch(expert_data, batch_info=batch_info, scope='train')
        else:
            assert self.double_validation
            return self.train_batch(expert_data, batch_info=batch_info, scope='val')
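
    # Data-flow recap (descriptive comment): train_batch rolls the current policy
    # through the learned environment, converts each step into transitions of the
    # form {'obs', 'action', 'reward', 'done', 'next_obs'}, stores model rollouts
    # in `buffer` and reward-labelled expert steps in `buffer_expert`, and then
    # hands both buffers to `sac()` below, which mixes them according to `batch_ratio`.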
    def sac(self,
            buffer,
            buffer_real,
            target_policy,
            q_net,
            target_q_net,
            gamma=0.99,
            alpha=0.2,
            polyak=0.99,
            actor_optimizer=None,
            critic_optimizer=None):
        """Run `sac_steps_per_epoch` SAC (or BC warm-up) updates on mini-batches
        drawn from the model-generated buffer and the expert buffer."""
        if self._epoch_cnt <= (1 + self.config.get("policy_bc_epoch", 0)):
            logger.warning('BC pre-training of the policy and critic pre-training are finished!')
            try:
                self.bc_init_net = deepcopy(target_policy)
            except:
                logger.info('No behaviour policy can be used for calculating the KL loss')

        for _ in range(self.config['sac_steps_per_epoch']):
            self._batch_cnt += 1
            if self._epoch_cnt <= self.config.get("policy_bc_epoch", 0):
                data = buffer_real.sample(int(self.config['sac_batch_size']))
                data.to_torch(device=self._device)
                obs = data['obs']
                action = data['action']
            else:
                _data_real = buffer_real.sample(int(self.config['sac_batch_size'] * (1 - self.config['batch_ratio'])))
                _data_gen = buffer.sample(int(self.config['sac_batch_size'] * self.config['batch_ratio']))
                data = Batch()
                for k, _ in _data_gen.items():
                    data[k] = torch.cat([_data_real[k], _data_gen[k]], dim=0)
                data.to_torch(device=self._device)
                obs = data['obs']
                action = data['action']

            # update critic ----------------------------------------
            with torch.no_grad():
                next_obs = data['next_obs']
                next_action_dist = target_policy(next_obs)
                next_action = next_action_dist.sample()
                next_action_log_prob = next_action_dist.log_prob(next_action).unsqueeze(dim=-1)
                next_obs_action = torch.cat([next_obs, next_action], dim=-1)
                next_q = target_q_net(next_obs_action).min(0).values.unsqueeze(-1)
                target_q = data['reward'] + gamma * (next_q - alpha * next_action_log_prob) * (1 - data['done'])
                out_target_q = next_q.mean()

                if torch.any(data['done']):
                    max_done_q = torch.max(next_q[torch.where(data['done'])[0]]).item()
                    max_q = torch.max(next_q[torch.where(1 - data['done'])[0]]).item()
                else:
                    max_done_q = -np.inf
                    max_q = torch.max(next_q[torch.where(1 - data['done'])[0]]).item()

            obs_action = torch.cat([obs, action], dim=-1)
            q_values = q_net(obs_action)
            out_q_values = q_values.mean()
            critic_loss = ((q_values - target_q.view(-1)) ** 2).mean(dim=1).sum(dim=0)

            critic_optimizer.zero_grad()
            critic_loss.backward()
            for key, para in q_net.named_parameters():
                # record the gradient norm of the first parameter tensor for logging
                c_grad = torch.norm(para.grad)
                break
            critic_optimizer.step()

            # update the target networks by polyak averaging
            with torch.no_grad():
                for p, p_targ in zip(q_net.parameters(), target_q_net.parameters()):
                    # NB: use the in-place operations "mul_" and "add_" to update the target
                    # params, as opposed to "mul" and "add", which would create new tensors.
                    p_targ.data.mul_(polyak)
                    p_targ.data.add_((1 - polyak) * p.data)

            # update actor ----------------------------------------
            if self._epoch_cnt <= self.config.get("policy_bc_epoch", 0):
                # behaviour-cloning warm-up
                action_dist = target_policy(obs)
                mean, var = action_dist.mode, action_dist.std
                inv_var = 1 / var
                # Average over batch and dim, sum over ensembles.
                mse_loss_inv = (torch.pow(mean - data['action'], 2) * inv_var).mean(dim=(0, 1))
                var_loss = var.mean(dim=(0, 1))
                actor_loss = mse_loss_inv.sum() + var_loss.sum()

                actor_optimizer.zero_grad()
                actor_loss.backward()
                for key, para in target_policy.named_parameters():
                    a_grad = torch.norm(para.grad)
                    break
                actor_optimizer.step()

                kl_loss = torch.tensor([0.0])
                q = torch.tensor([0.0])
            else:
                # SAC actor update
                action_dist = target_policy(obs)
                new_action = action_dist.rsample()
                action_log_prob = action_dist.log_prob(new_action)
                new_obs_action = torch.cat([obs, new_action], dim=-1)
                q = q_net(new_obs_action).min(0).values.unsqueeze(-1)
                actor_loss = - q.mean() + alpha * action_log_prob.mean()

                # compute the KL loss
                if self.config['w_kl'] > 0:
                    _kl_q = self.bc_init_net(obs)
                    _kl_p = target_policy(obs)
                    kl_loss = kl_divergence(_kl_q, _kl_p).mean() * self.config['w_kl']
                    if torch.any(torch.isinf(kl_loss)) or torch.any(torch.isnan(kl_loss)):
                        logger.info(f'Detected nan or inf in kl_loss, skip this batch! (loss: {kl_loss}, epoch: {self._epoch_cnt})')
                        actor_kl_loss = actor_loss
                    else:
                        actor_kl_loss = actor_loss + kl_loss
                else:
                    actor_kl_loss = actor_loss
                    kl_loss = torch.zeros((1,))

                actor_optimizer.zero_grad()
                actor_kl_loss.backward()
                for key, para in target_policy.named_parameters():
                    a_grad = torch.norm(para.grad)
                    break
                actor_optimizer.step()

        info = {
            "SAC/critic_loss": critic_loss.item(),
            "SAC/actor_loss": actor_loss.item(),
            "SAC/critic_grad": c_grad.item(),
            "SAC/actor_grad": a_grad.item(),
            "SAC/kl_loss": kl_loss.item(),
            "SAC/critic_q_values": out_q_values.item(),
            "SAC/critic_target_q": out_target_q.item(),
            "SAC/actor_q": q.mean().item(),
            "SAC/batch_reward": data['reward'].mean().item(),
            "SAC/batch_reward_max": data['reward'].max().item(),
            "SAC/max_done_q": max_done_q,
            "SAC/max_q": max_q,
        }
        return info
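
# Summary of the update rules implemented in SACOperator.sac() (descriptive
# comment; gamma, alpha and polyak are the config values of the same names):
#
#   target_q   = r + gamma * (1 - done) * (min_i Q_target_i(s', a') - alpha * log pi(a'|s')),  a' ~ pi(.|s')
#   critic_loss = mean over the batch, summed over the Q ensemble, of (Q_i(s, a) - target_q)^2
#   actor_loss  = -min_i Q_i(s, a_new) + alpha * log pi(a_new|s),  a_new = rsample from pi(.|s)
#   target nets: theta_target <- polyak * theta_target + (1 - polyak) * theta
#   optional regularizer: w_kl * KL(pi_bc_init(.|s) || pi(.|s)) added to the actor loss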