"""
POLIXIR REVIVE, copyright (C) 2021-2023 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
import torch
import numpy as np
from copy import deepcopy
from revive.computation.dists import kl_divergence
from revive.computation.modules import *
from revive.utils.common_utils import *
from revive.data.batch import Batch
from revive.utils.raysgd_utils import BATCH_SIZE
from revive.data.dataset import data_creator
from revive.algo.policy.base import PolicyOperator, catch_error
class ReplayBuffer:
"""
A simple FIFO experience replay buffer for SAC agents.
"""
def __init__(self, buffer_size):
self.data = None
self.buffer_size = int(buffer_size)
    def put(self, batch_data: Batch):
batch_data.to_torch(device='cpu')
if self.data is None:
self.data = batch_data
else:
self.data.cat_(batch_data)
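        # FIFO eviction: once the buffer exceeds its capacity, drop the oldest transitions.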
if len(self) > self.buffer_size:
self.data = self.data[len(self) - self.buffer_size : ]
    def __len__(self):
if self.data is None: return 0
return self.data.shape[0]
    def sample(self, batch_size):
assert len(self) > 0, 'Cannot sample from an empty buffer!'
indexes = np.random.randint(0, len(self), size=(batch_size))
return self.data[indexes]
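# A minimal usage sketch of ReplayBuffer (illustrative only; the SAC code below stores
# Batch objects with the keys 'obs', 'next_obs', 'action', 'reward' and 'done'):
#     buffer = ReplayBuffer(buffer_size=1e6)
#     buffer.put(Batch({'obs': obs, 'next_obs': next_obs, 'action': act, 'reward': rew, 'done': done}))
#     minibatch = buffer.sample(batch_size=256)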
class SACOperator(PolicyOperator):
"""
    A class used to train the platform policy with SAC.
"""
PARAMETER_DESCRIPTION = [
{
"name" : "sac_batch_size",
"description": "Batch size of training process.",
"abbreviation" : "pbs",
"type" : int,
"default" : 1024,
'doc': True,
},
{
"name" : "policy_bc_epoch",
"type" : int,
"default" : 0,
},
{
"name" : "sac_epoch",
"description": "Number of epcoh for the training process.",
"abbreviation" : "bep",
"type" : int,
"default" : 200,
'doc': True,
},
{
"name" : "sac_steps_per_epoch",
"description": "The number of update rounds of sac in each epoch.",
"abbreviation" : "sspe",
"type" : int,
"default" : 200,
'doc': True,
},
{
"name" : "sac_rollout_horizon",
"abbreviation" : "srh",
"type" : int,
"default" : 20,
'doc': True,
},
{
"name" : "policy_hidden_features",
"description": "Number of neurons per layer of the policy network.",
"abbreviation" : "phf",
"type" : int,
"default" : 256,
'doc': True,
},
{
"name" : "policy_hidden_layers",
"description": "Depth of policy network.",
"abbreviation" : "phl",
"type" : int,
"default" : 4,
'doc': True,
},
{
"name" : "policy_backbone",
"description": "Backbone of policy network.",
"abbreviation" : "pb",
"type" : str,
"default" : "mlp",
'doc': True,
},
{
"name" : "policy_hidden_activation",
"description": "hidden_activation of policy network.",
"abbreviation" : "pha",
"type" : str,
"default" : "leakyrelu",
'doc': True,
},
{
"name" : "q_hidden_features",
"abbreviation" : "qhf",
"type" : int,
"default" : 256,
},
{
"name" : "q_hidden_layers",
"abbreviation" : "qhl",
"type" : int,
"default" : 2,
},
{
"name" : "num_q_net",
"abbreviation" : "nqn",
"type" : int,
"default" : 4,
},
{
"name" : "buffer_size",
"description": "Size of the buffer to store data.",
"abbreviation" : "bfs",
"type" : int,
"default" : 1e6,
'doc': True,
},
{
"name" : "w_kl",
"type" : float,
"default" : 1.0,
},
{
"name" : "gamma",
"type" : float,
"default" : 0.99,
},
{
"name" : "alpha",
"type" : float,
"default" : 0.2,
},
{
"name" : "polyak",
"type" : float,
"default" : 0.99,
},
{
"name" : "batch_ratio",
"type" : float,
"default" : 1,
},
{
"name" : "g_lr",
"description": "Initial learning rate of the training process.",
"type" : float,
"default" : 4e-5,
"search_mode" : "continuous",
"search_values" : [1e-6, 1e-3],
'doc': True,
},
{
"name" : "interval",
"description": "interval step for index removing.",
"type" : int,
"default" : 0
},
{
"name" : "generate_deter",
"description": "deterministic of generator rollout",
"type" : int,
"default" : 0
},
{
"name" : "reward_uncertainty_weight",
"description": "Reward uncertainty weight(MOPO)",
"type" : float,
"default" : 0,
"search_mode" : "continuous",
},
]
@property
def policy(self):
        if isinstance(self.train_models, (list, tuple)):
assert len(self.train_models[:-2])==1
return self.train_models[:-2]
else:
assert len(self.train_models[:-2])==1
return self.train_models
@property
def val_policy(self):
        if isinstance(self.val_models, (list, tuple)):
assert len(self.val_models[:-2])==1
return self.val_models[:-2]
else:
assert len(self.val_models[:-2])==1
return self.val_models
    def model_creator(self, config, nodes):
"""
        Create models, including the platform policy and the value networks.
        :return: target policy, q net, target q net
"""
graph = config['graph']
        policy_name = config['target_policy_name'][0]  # TODO: only a single target policy is supported for now; multi-policy support is future work
total_dims = config['total_dims']
input_dim = get_input_dim_from_graph(graph, policy_name, total_dims)
if config['behavioral_policy_init'] and self.behaviour_policys:
for behaviour_policy in self.behaviour_policys:
target_policy = deepcopy(behaviour_policy)
target_policy.requires_grad_(True)
else:
if self.config.get("policy_bc_epoch", 0) == 0:
                logger.info('The user initialized a brand-new policy network, so policy_bc_epoch is set to 50')
# logger.warning(f'as policy_bc_epoch=0 and behaviour_policys=[], w_kl is forcefully rested to 0')
self.config['policy_bc_epoch'] = 50
        assert len(nodes.keys()) == 1  # SAC currently supports only a single policy, not multi-policy.
for (policy_name, node) in nodes.items():
node_input_dim = get_input_dim_from_graph(graph, policy_name, total_dims)
logger.info(f"Policy Backbone: {config['policy_backbone'],}")
node.initialize_network(node_input_dim, total_dims[policy_name]['output'],
hidden_features=config['policy_hidden_features'],
hidden_layers=config['policy_hidden_layers'],
hidden_activation=config['policy_hidden_activation'],
backbone_type=config['policy_backbone'],
dist_config=config['dist_configs'][policy_name],
is_transition=False)
target_policy = node.get_network()
input_dim += total_dims[policy_name]['input']
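        # The critic is an ensemble of `num_q_net` Q-networks over the concatenated
        # (policy inputs, action) vector; the SAC update below takes the minimum over
        # the ensemble (clipped double-Q style).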
q_net = VectorizedCritic(input_dim, 1, config['q_hidden_features'], config['q_hidden_layers'], config['num_q_net'])
target_q_net = deepcopy(q_net)
target_q_net.requires_grad_(False)
# q_net1 = MLP(input_dim, 1, config['q_hidden_features'], config['q_hidden_layers'])
# q_net2 = MLP(input_dim, 1, config['q_hidden_features'], config['q_hidden_layers'])
# target_q_net1 = deepcopy(q_net1)
# target_q_net2 = deepcopy(q_net2)
# target_q_net1.requires_grad_(False)
# target_q_net2.requires_grad_(False)
# num_q_net = config['num_q_net']
# logger.warning(f'{num_q_net} critic nets are generated!')
return [target_policy, q_net, target_q_net] #q_net1, q_net2, target_q_net1, target_q_net2 ] #q_net, target_q_net]
    def optimizer_creator(self, models, config):
"""
        :return: optimizers, including the platform policy (actor) optimizer, the value net (critic) optimizer and, when policy_bc_epoch >= 1, a BC policy optimizer
"""
        target_policy, q_net, target_q_net = models
        actor_optimizer = torch.optim.Adam(target_policy.parameters(), lr=config['g_lr'])
        critic_optimizer = torch.optim.Adam(q_net.parameters(), lr=config['g_lr'])
        # return actor_optimizer, critic_optimizer
        target_policys = models[:-2]
        optimizers = [actor_optimizer, critic_optimizer]
if self.config["policy_bc_epoch"] >=1:
models_params = []
for target_policy in target_policys:
models_params.append({'params': target_policy.parameters(), 'lr': 1e-3})
bc_policy_optimizer = torch.optim.Adam(models_params)
optimizers.append(bc_policy_optimizer)
# target_policy, q_net1, q_net2, _, _ = models
# actor_optimizor = torch.optim.Adam(target_policy.parameters(), lr=config['g_lr'])
# critic_optimizor = torch.optim.Adam([*q_net1.parameters(), *q_net2.parameters()], lr=config['g_lr'])
# optimizers = [actor_optimizor, critic_optimizor]
return optimizers
    def data_creator(self, config):
config[BATCH_SIZE] = config['sac_batch_size']
return data_creator(config,
training_mode='trajectory',
training_horizon=config['sac_rollout_horizon'],
training_is_sample=True,
val_horizon=config['test_horizon'],
double=self.double_validation)
@catch_error
def setup(self, config):
# super(SACOperator, self).setup(config)
self.replay_buffer_train = ReplayBuffer(config['buffer_size'])
self.expert_buffer_train = ReplayBuffer(config['buffer_size'])
if self.double_validation:
self.replay_buffer_val = ReplayBuffer(config['buffer_size'])
self.expert_buffer_val = ReplayBuffer(config['buffer_size'])
@catch_error
def train_batch(self, expert_data, batch_info, scope='train'):
if scope == 'train':
#[target_policy, q_net1, q_net2, q_net3, q_net4, target_q_net1, target_q_net2, target_q_net3, target_q_net4]
target_policy = self.train_models[:-2]
q_net, target_q_net = self.train_models[-2:]
# target_policy = self.train_models[:-4]
# _, q_net1, q_net2, target_q_net1, target_q_net2 = self.train_models
if self._epoch_cnt <= self.config.get("policy_bc_epoch", 0):
actor_optimizer = self.train_optimizers[2]
else:
actor_optimizer = self.train_optimizers[0]
critic_optimizer = self.train_optimizers[1]
envs = self.envs_train
buffer = self.replay_buffer_train
buffer_expert = self.expert_buffer_train
else:
assert self.double_validation
target_policy = self.val_models[:-2]
q_net, target_q_net = self.val_models[-2:]
# target_policy = self.val_models[:-4]
# _, q_net1, q_net2, target_q_net1, target_q_net2 = self.val_models
# target_policy, q_net1, q_net2, target_q_net1, target_q_net2 = self.val_models
if self._epoch_cnt <= self.config.get("policy_bc_epoch", 0):
actor_optimizer = self.val_optimizers[2]
else:
actor_optimizer = self.val_optimizers[0]
critic_optimizer = self.val_optimizers[1]
envs = self.envs_val
buffer = self.replay_buffer_val
buffer_expert = self.expert_buffer_val
expert_data.to_torch(device=self._device)
assert len(expert_data.shape) == 3 # the expert_data should be with three dimensions as [horizon, batch, features]
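        # Roll out the current target policy in the environment models for
        # `sac_rollout_horizon` steps to generate synthetic transitions.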
generated_data, info = self._run_rollout(expert_data, target_policy, envs, traj_length=self.config['sac_rollout_horizon'], deterministic=bool(self.config['generate_deter']), clip=True)
if 'done' in self._graph.graph_dict.keys():
#expert data process
temp_done_expert = self._processor.deprocess_single_torch(expert_data['done'], 'done')
assert temp_done_expert.shape == expert_data['done'].shape
expert_data['done'] = temp_done_expert
#generate data process
temp_done_generate = self._processor.deprocess_single_torch(generated_data['done'], 'done')
assert temp_done_generate.shape == generated_data['done'].shape
generated_data['done'] = temp_done_generate
assert torch.sum(expert_data['done'])>=0
assert torch.sum(generated_data['done'])>=0
else:
#expert data process
expert_data['done'] = torch.zeros(expert_data.shape[:-1]+[1]).to(expert_data[self._graph.leaf[0]].device)
#generate data process
generated_data['done'] = torch.zeros(generated_data.shape[:-1]+[1]).to(generated_data[self._graph.leaf[0]].device)
assert list(expert_data['done'].shape)[:-1] == expert_data.shape[:-1]
assert list(generated_data['done'].shape)[:-1] == generated_data.shape[:-1]
assert torch.sum(expert_data['done']).item()==0 # done is not in the decision-flow, and should all be zeros
assert torch.sum(generated_data['done']).item()==0 # done is not in the decision-flow, and should all be zeros
all_index = np.arange(0, self.config[BATCH_SIZE], 1)
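        # model_index / expert_index track which trajectories in the batch are still "alive";
        # indexes are removed once a trajectory hits done. During the BC warm-up epochs only
        # expert data is used and the generated-data buffer is reset.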
if self._epoch_cnt <= self.config.get("policy_bc_epoch", 0):
model_index = np.random.choice(all_index,0)
expert_index= all_index
buffer = ReplayBuffer(self.config['buffer_size'])
else:
model_index = all_index #np.random.choice(all_index,int(self.config[BATCH_SIZE]*1), replace=False) #np.random.choice(all_index,int(self.config[BATCH_SIZE]*(self.config['batch_ratio'])), replace=False)
expert_index= all_index #np.random.choice(all_index,int(self.config[BATCH_SIZE]*1), replace=False)#np.random.choice(all_index,self.config[BATCH_SIZE]-int(self.config[BATCH_SIZE]*(self.config['batch_ratio'])), replace=False)
if True:# if model_index.shape[0]:
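            # Slice the generated rollout into (obs, next_obs, action, reward, done) transitions,
            # where obs is the concatenation of the policy node's inputs from the decision-flow graph.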
for i in range(generated_data.shape[0] - 1):
#TODO: policy name should be set as list of multi policy name
                next_remove_index = np.array(torch.where(generated_data['done'][i] != 0)[0].tolist(), dtype=model_index.dtype)
                interval_remove_index = None
                try:
                    if self.config.get("interval", 0) != 0:
                        interval_remove_index = np.array(torch.where(generated_data['done'][i:i + self.config['interval']] != 0)[1].tolist(), dtype=model_index.dtype)
                        __interval = self.config['interval']
                        logger.warning(f'interval-done, in the epoch: {self._epoch_cnt}, at step: {i}')
                except Exception:
                    interval_remove_index = None
buffer.put(Batch({
'obs' : torch.cat(list(self._graph.get_node(self.policy_name[0]).get_inputs(generated_data[ i ]).values()), dim=-1)[model_index],
'next_obs' : torch.cat(list(self._graph.get_node(self.policy_name[0]).get_inputs(generated_data[i+1]).values()), dim=-1)[model_index],
'action' : generated_data[i][self.policy_name[0]][model_index],
'reward' : generated_data[i]['reward'][model_index],
'done' : generated_data[i]['done'][model_index],
}))
pre_size = model_index.shape[0]
model_index = np.setdiff1d(model_index, next_remove_index)
                if interval_remove_index is not None:
                    model_index = np.setdiff1d(model_index, interval_remove_index)
                    logger.info(f'Detect interval_remove_index done with interval {__interval}, with shape {interval_remove_index.shape[0]}')
post_size = model_index.shape[0]
if post_size<pre_size:
                    logger.info(f'size of model_index has decreased, in the epoch: {self._epoch_cnt}, at step: {i}')
if True: #if expert_index.shape[0]:
#generate rewards for expert data
expert_rewards = generate_rewards(expert_data, reward_fn=lambda data: self._user_func(self._get_original_actions(data)))
for i in range(expert_data.shape[0] - 1):
#TODO: policy name should be set as list of multi policy name
next_remove_index = np.array(torch.where(expert_data[i]['done']!=0)[0].tolist(), dtype=expert_index.dtype)
if next_remove_index.shape[0] != 0:
logger.info(f'Detect expert data traj-done, in the epoch: {self._epoch_cnt}, at step: {i}')
buffer_expert.put(Batch({
'obs' : torch.cat(list(self._graph.get_node(self.policy_name[0]).get_inputs(expert_data[ i ]).values()), dim=-1)[expert_index],
'next_obs' : torch.cat(list(self._graph.get_node(self.policy_name[0]).get_inputs(expert_data[i+1]).values()), dim=-1)[expert_index],
'action' : expert_data[i][self.policy_name[0]][expert_index],
'reward' : expert_rewards[i]['reward'][expert_index],
'done' : expert_data[i]['done'][expert_index],
}))
pre_size = expert_index.shape[0]
expert_index = np.setdiff1d(expert_index, next_remove_index)
post_size = expert_index.shape[0]
if post_size<pre_size:
                    logger.info(f'size of expert_index has decreased, in the epoch: {self._epoch_cnt}, at step: {i}')
# logger.info(f'buffer_gen length : {len(buffer)}, buffer_real length : {len(buffer_expert)}')
        assert len(target_policy) == 1  # SAC currently supports only a single policy, not multi-policy.
target_policy = target_policy[0]
_info = self.sac(buffer, buffer_expert, target_policy,
q_net, target_q_net, #q_net1, q_net2, target_q_net1, target_q_net2, #q_net, target_q_net, #q_net3, q_net4, target_q_net3, target_q_net4,
gamma=self.config['gamma'], alpha=self.config['alpha'], polyak=self.config['polyak'],
actor_optimizer=actor_optimizer, critic_optimizer=critic_optimizer)
info = _info
for k in list(info.keys()): info[f'{k}_{scope}'] = info.pop(k)
self._stop_flag = self._stop_flag or self._epoch_cnt >= self.config['sac_epoch']
info = self._early_stop(info)
return info
@catch_error
def bc_train_batch(self, expert_data, batch_info, scope='train'):
if scope == 'train':
return self.train_batch(expert_data, batch_info=batch_info, scope='train')
else:
assert self.double_validation
return self.train_batch(expert_data, batch_info=batch_info, scope='val')
'''
# if scope == 'train':
# target_policy = self.train_models[:-4]
# optimizer = self.train_optimizers[-1]
# else:
# assert self.double_validation
# target_policy = self.val_models[:-4]
# optimizer = self.val_optimizers[-1]
# expert_data.to_torch(device=self._device)
# loss = 0
# for policy_name, _target_policy in zip(self.policy_name, target_policy):
# _target_policy.reset()
# with torch.no_grad():
# state = get_input_from_graph(self._graph, policy_name , expert_data)
# action_dist = _target_policy(state)
# action_log_prob = action_dist.log_prob(expert_data[policy_name])
# if policy_name + "_isnan_index_" in expert_data.keys():
# isnan_index = 1 - expert_data[policy_name + "_isnan_index_"]
# policy_loss = - (action_log_prob*isnan_index).mean()
# else:
# isnan_index = None
# policy_loss = - action_log_prob.mean()
# loss += policy_loss
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
# info = {
# "SAC/policy_bc_loss" : loss.mean().item(),
# }
# return info
'''
    def sac(self,
buffer, buffer_real, target_policy,
q_net, target_q_net,
gamma=0.99, alpha=0.2, polyak=0.99,
actor_optimizer=None, critic_optimizer=None):
if self._epoch_cnt <= (1 + self.config.get("policy_bc_epoch", 0)):
            logger.warning('BC has pre-trained the policy; critic pre-training is finished!')
# for behaviour_policy in self.behaviour_policys:
try:
self.bc_init_net = deepcopy(target_policy) #deepcopy(self.behaviour_policys[0]) #
            except Exception:
                logger.info('No behaviour policies can be used for calculating the KL loss')
for _ in range(self.config['sac_steps_per_epoch']):
self._batch_cnt += 1
if self._epoch_cnt <= self.config.get("policy_bc_epoch", 0):
data = buffer_real.sample(int(self.config['sac_batch_size']))
data.to_torch(device=self._device)
obs = data['obs']
action = data['action']
else:
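                # After BC warm-up, each SAC batch mixes real (expert) and model-generated
                # transitions according to `batch_ratio`.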
_data_real = buffer_real.sample(int(self.config['sac_batch_size']*(1-self.config['batch_ratio'])))
_data_gen = buffer.sample(int(self.config['sac_batch_size']*(self.config['batch_ratio'])))
data = Batch()
for k, _ in _data_gen.items():
data[k] = torch.cat([_data_real[k], _data_gen[k]], dim=0)
data.to_torch(device=self._device)
obs = data['obs']
action = data['action']
# update critic ----------------------------------------
with torch.no_grad():
next_obs = data['next_obs']
next_action_dist = target_policy(next_obs)
next_action = next_action_dist.sample()
next_action_log_prob = next_action_dist.log_prob(next_action).unsqueeze(dim=-1)
next_obs_action = torch.cat([next_obs, next_action], dim=-1)
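                # Clipped double-Q with entropy bonus: take the minimum over the critic ensemble
                # and form the soft Bellman target
                #   target_q = r + gamma * (min_i Q_target_i(s', a') - alpha * log pi(a'|s')) * (1 - done)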
next_q = target_q_net(next_obs_action).min(0).values.unsqueeze(-1)
target_q = data['reward'] + gamma * (next_q - alpha * next_action_log_prob) * (1 - data['done'])
out_target_q = next_q.mean()
if torch.any(data['done']):
max_done_q = torch.max(next_q[torch.where(data['done'])[0]]).item()
max_q = torch.max(next_q[torch.where(1-data['done'])[0]]).item()
else:
max_done_q = -np.inf
max_q = torch.max(next_q[torch.where(1-data['done'])[0]]).item()
obs_action = torch.cat([obs, action], dim=-1)
q_values = q_net(obs_action)
out_q_values = q_values.mean()
# q_limit_loss = (q_values.mean() - 10)**2
critic_loss = ((q_values - target_q.view(-1)) ** 2).mean(dim=1).sum(dim=0) #+ 0.5 * q_limit_loss
critic_optimizer.zero_grad()
critic_loss.backward()
for key, para in q_net.named_parameters():
c_grad = torch.norm(para.grad)
break
critic_optimizer.step()
# update target networks by polyak averaging.
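            # i.e. p_targ <- polyak * p_targ + (1 - polyak) * p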
with torch.no_grad():
for p, p_targ in zip(q_net.parameters(), target_q_net.parameters()):
                    # NB: we use the in-place operations "mul_" and "add_" to update the target
                    # params, as opposed to "mul" and "add", which would create new tensors.
p_targ.data.mul_(polyak)
p_targ.data.add_((1 - polyak) * p.data)
# update actor ----------------------------------------
if self._epoch_cnt <= self.config.get("policy_bc_epoch", 0):
# bc method
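                # BC warm-up: regress the policy mean onto the dataset action with an MSE
                # weighted by the inverse of the predicted spread, plus a penalty on the
                # spread itself (a Gaussian-NLL-style loss).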
action_dist = target_policy(obs)
mean, var = action_dist.mode, action_dist.std
inv_var = 1/var
# Average over batch and dim, sum over ensembles.
mse_loss_inv = (torch.pow(mean - data['action'], 2) * inv_var).mean(dim=(0,1))
var_loss = var.mean(dim=(0,1))
actor_loss = mse_loss_inv.sum() + var_loss.sum()
actor_optimizer.zero_grad()
actor_loss.backward()
for key, para in target_policy.named_parameters():
a_grad = torch.norm(para.grad)
break
actor_optimizer.step()
kl_loss = torch.tensor([0.0])
q = torch.tensor([0.0])
else:
# SAC update actor
action_dist = target_policy(obs)
new_action = action_dist.rsample()
action_log_prob = action_dist.log_prob(new_action)
new_obs_action = torch.cat([obs, new_action], dim=-1)
q = q_net(new_obs_action).min(0).values.unsqueeze(-1)
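                # Entropy-regularized policy improvement: maximize min_i Q_i(s, pi(s)) - alpha * log pi(a|s),
                # i.e. minimize  alpha * E[log pi] - E[min_i Q_i].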
actor_loss = - q.mean() + alpha * action_log_prob.mean()
# compute kl loss
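                # When w_kl > 0, regularize the policy towards the BC-initialized snapshot
                # (self.bc_init_net) with a KL term weighted by w_kl.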
if self.config['w_kl']>0:
_kl_q = self.bc_init_net(obs)
_kl_p = target_policy(obs)
kl_loss = kl_divergence(_kl_q, _kl_p).mean() * self.config['w_kl']
if torch.any(torch.isinf(kl_loss)) or torch.any(torch.isnan(kl_loss)):
logger.info(f'Detect nan or inf in kl_loss, skip this batch! (loss : {kl_loss}, epoch : {self._epoch_cnt})')
actor_kl_loss = actor_loss
else:
actor_kl_loss = actor_loss + kl_loss
else:
actor_kl_loss = actor_loss
kl_loss = torch.zeros((1,))
actor_optimizer.zero_grad()
actor_kl_loss.backward()
for key, para in target_policy.named_parameters():
a_grad = torch.norm(para.grad)
# if a_grad<1e-10:
# breakpoint()
break
actor_optimizer.step()
info = {
"SAC/critic_loss": critic_loss.item(),
"SAC/actor_loss": actor_loss.item(),
"SAC/critic_grad": c_grad.item(),
"SAC/actor_grad": a_grad.item(),
"SAC/kl_loss": kl_loss.item(),
"SAC/critic_q_values": out_q_values.item(),
"SAC/critic_target_q": out_target_q.item(),
"SAC/actor_q": q.mean().item(),
"SAC/batch_reward": data['reward'].mean().item(),
"SAC/batch_reward_max": data['reward'].max().item(),
"SAC/max_done_q": max_done_q,
"SAC/max_q": max_q
}
return info