"""
POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
from typing import OrderedDict
import torch
from copy import deepcopy
from revive.computation.dists import kl_divergence
from revive.computation.modules import *
from revive.utils.common_utils import *
from revive.data.dataset import data_creator
from revive.utils.raysgd_utils import BATCH_SIZE
from revive.algo.policy.base import PolicyOperator, catch_error
class PPOOperator(PolicyOperator):
"""
A class used to train the target policy with the PPO algorithm.
"""
NAME = "PPO"
PARAMETER_DESCRIPTION = [
{
"name" : "ppo_batch_size",
"description": "Batch size of training process.",
"abbreviation" : "pbs",
"type" : int,
"default" : 256,
'doc': True,
},
{
"name" : "policy_bc_epoch",
"description": "pre-train policy with setting epoch",
"type" : int,
"default" : 0,
'doc': True,
},
{
"name" : "ppo_epoch",
"description": "Number of epcoh for the training process",
"abbreviation" : "bep",
"type" : int,
"default" : 1000,
'doc': True,
},
{
"name" : "ppo_rollout_horizon",
"description": "Rollout length of the policy train.",
"abbreviation" : "prh",
"type" : int,
"default" : 100,
'doc': True,
},
{
"name" : "policy_hidden_features",
"description": "Number of neurons per layer of the policy network.",
"abbreviation" : "phf",
"type" : int,
"default" : 256,
'doc': True,
},
{
"name" : "policy_hidden_layers",
"description": "Depth of policy network.",
"abbreviation" : "phl",
"type" : int,
"default" : 4,
'doc': True,
},
{
"name" : "policy_backbone",
"description": "Backbone of policy network.[mlp, res, ft_transformer]",
"abbreviation" : "pb",
"type" : str,
"default" : "res",
'doc': True,
},
{
"name" : "value_hidden_features",
"abbreviation" : "vhf",
"type" : int,
"default" : 256,
},
{
"name" : "value_hidden_layers",
"abbreviation" : "vhl",
"type" : int,
"default" : 4,
},
{
"name" : "ppo_runs",
"type" : int,
"default" : 2,
},
{
"name" : "epsilon",
"type" : float,
"default" : 0.2,
},
{
"name" : "w_vl2",
"type" : float,
"default" : 0.001,
},
{
"name" : "w_ent",
"type" : float,
"default" : 0.0,
},
{
"name" : "w_kl",
"type" : float,
"default" : 1.0,
},
{
"name" : "gae_gamma",
"type" : float,
"default" : 0.99,
},
{
"name" : "gae_lambda",
"type" : float,
"default" : 0.95,
},
{
"name" : "g_lr",
"description": "Initial learning rate of the training process.",
"type" : float,
"default" : 4e-5,
"search_mode" : "continuous",
"search_values" : [1e-6, 1e-3],
'doc': True,
},
{
"name" : "reward_uncertainty_weight",
"description": "Reward uncertainty weight(MOPO)",
"type" : float,
"default" : 0,
"search_mode" : "continuous",
},
{
"name" : "filter",
"type" : bool,
"default" : False,
},
{
"name" : "candidate_num",
"type" : int,
"default" : 50,
},
{
"name" : "ensemble_size",
"type" : int,
"default" : 50,
},
{
"name" : "ensemble_choosing_interval",
"type" : int,
"default" : 10,
},
{
"name" : "disturbing_transition_function",
"description": "Disturbing the network node in policy learning",
"type": bool,
"default": False,
},
{
"name" : "disturbing_nodes",
"description": "Disturbing the network node in policy learning",
"type": list,
"default": 'auto',
},
{
"name" : "disturbing_net_num",
"description": "Disturbing the network node in policy learning",
"type": int,
"default": 100,
},
{
"name" : "disturbing_weight",
"description": "Disturbing the network node in policy learning",
"type": float,
"default": 0.05,
},
{
"name" : "generate_deter",
"description": "deterministic of generator rollout",
"type" : int,
"default" : 1
},
]
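# A minimal sketch (hypothetical values) of how the hyper-parameters declared in
# PARAMETER_DESCRIPTION above can be overridden through the operator's config;
# the exact entry point depends on how the surrounding REVIVE training run is configured:
#
#     config_overrides = {
#         "ppo_batch_size": 512,        # abbreviation "pbs"
#         "ppo_rollout_horizon": 50,    # abbreviation "prh"
#         "g_lr": 1e-4,                 # searched in [1e-6, 1e-3]
#     }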
@property
def policy_bc_optimizer(self):
return self.val_models[:-1]
def model_creator(self, nodes):
"""
Create models including target policy and value net.
:return: env models, target policy, value net
"""
models = super().model_creator(nodes)
graph = self._graph
value_net_input_dim = 0
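# The value net scores the concatenation of all leaf-node features of the
# decision graph (see _run_update), so its input dim is the sum of their dims.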
for node_name in graph.get_leaf(graph):
value_net_input_dim += get_node_dim_from_dist_configs(self.config["dist_configs"], node_name)
value_net = MLP(value_net_input_dim, 1, self.config['value_hidden_features'], self.config['value_hidden_layers'])
models.append(value_net)
return models
def optimizer_creator(self, scope):
"""
:return: a list containing a single generator optimizer that jointly updates the target policies and the value net
"""
if scope == "train":
target_policys = self.train_policy
assert len(self.other_train_models) == 1
value_net = self.other_train_models[0]
else:
target_policys = self.val_policy
assert len(self.other_val_models) == 1
value_net = self.other_val_models[0]
models_params = []
optimizers = []
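# A single Adam optimizer jointly updates every target policy together with the value net.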
for target_policy in target_policys:
models_params.append({'params': target_policy.parameters(), 'lr': self.config['g_lr']})
models_params.append({'params': value_net.parameters(), 'lr': self.config['g_lr']})
generator_optimizer = torch.optim.Adam(models_params)
optimizers.append(generator_optimizer)
return optimizers
def data_creator(self):
self.config[BATCH_SIZE] = self.config['ppo_batch_size']
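# PPO trains on sampled sub-trajectories of length `ppo_rollout_horizon`
# rather than on independent transitions.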
return data_creator(self.config,
training_mode='trajectory',
training_horizon=self.config['ppo_rollout_horizon'],
training_is_sample=True,
val_horizon=self.config['test_horizon'],
double=self.double_validation)
@catch_error
def after_train_epoch(self):
if self._epoch_cnt >= self.config['ppo_epoch'] + self.config['policy_bc_epoch']:
self._stop_flag = True
@catch_error
def train_batch(self, expert_data, batch_info, scope='train'):
if scope == 'train':
target_policy = self.train_models[:-1]
value_net = self.train_models[-1]
generator_optimizer = self.train_optimizers[0]
envs = self.envs_train
else:
assert self.double_validation
target_policy = self.val_models[:-1]
value_net = self.val_models[-1]
generator_optimizer = self.val_optimizers[0]
envs = self.envs_val
expert_data.to_torch(device=self._device)
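# Roll out the current target policy inside the learned environment models
# (without gradients), then update it with PPO on the generated trajectories.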
with torch.no_grad():
#generated_data, info = self._run_rollout(expert_data, target_policy, envs, traj_length=self.config['ppo_rollout_horizon'], deterministic=False, clip=True)
generated_data, info = self._run_rollout(expert_data,
target_policy,
envs,
traj_length=self.config['ppo_rollout_horizon'],
deterministic=bool(self.config['generate_deter']),
clip=True)
info.update(self._run_update(expert_data, generated_data, target_policy, value_net, generator_optimizer))
for k in list(info.keys()): info[f'{k}_{scope}'] = info.pop(k)
info = self._early_stop(info)
return info
def _run_update(self, expert_data, generated_data, target_policy, value_net, generator_optimizer=None):
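# Collect per-policy inputs, actions and old log-probs from the rollout,
# estimate advantages with the value net (GAE), then run `ppo_runs` PPO updates.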
states = OrderedDict()
actions = OrderedDict()
action_log_probs = OrderedDict()
for policy_name in self.policy_name:
states[policy_name] = get_input_from_graph(self._graph, policy_name , generated_data)
if hasattr(self.train_nodes[policy_name], 'custom_node') or self.input_is_dict:
states[policy_name] = get_input_dict_from_graph(self._graph, policy_name , generated_data)
actions[policy_name] = generated_data[policy_name]
action_log_probs[policy_name] = generated_data[policy_name + '_log_prob']
value_net_inputs = OrderedDict()
for node_name in self._graph.get_leaf(self._graph):
value_net_inputs[node_name] = generated_data[node_name]
value_net_states = torch.cat(list(value_net_inputs.values()), dim=-1)
rewards = generated_data.reward
_values = value_net(value_net_states)
_masks = generated_data["done_mask"]
valid_masks = generated_data["valid_masks"]
advs, returns = self.ADV(rewards.detach(), _masks, _values.detach(), gamma=self.config['gae_gamma'], lam=self.config['gae_lambda'])
_repeat = self.config['ppo_runs']
for k in range(_repeat):
v_loss, p_loss, e_loss, kl_loss, total_loss = self.ppo_step(target_policy, value_net, generator_optimizer,
states, actions, returns, advs, action_log_probs, expert_data, value_net_states, valid_masks)
info = {
"PPO/v_loss" : v_loss.mean().item(),
"PPO/p_loss" : p_loss.mean().item(),
"PPO/e_loss" : e_loss.mean().item(),
"PPO/kl_loss" : kl_loss.mean().item(),
"PPO/total_loss": total_loss.mean().item(),
"PPO/mean_reward": rewards.mean().item(),
"PPO/min_reward": rewards.min().item(),
"PPO/max_reward": rewards.max().item(),
}
return info
def ADV(self, reward, mask, value, gamma, lam, use_gae=True):
"""
Compute advantage estimates and value targets (returns) for PPO.
:param reward: reward of each step
:param mask: 1 if the trajectory is done at this step, else 0
:param value: value estimate for each state
:param gamma: discount factor
:param lam: GAE lambda
:param use_gae: whether to use GAE; if False, plain one-step advantages are used
:return: normalized advantages and the corresponding returns
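
The implementation below follows the standard GAE recursion (a sketch of the
math, written with the argument names above):

    delta_t  = reward_t + gamma * value_{t+1} * (1 - mask_t) - value_t
    adv_t    = delta_t + gamma * lam * (1 - mask_t) * adv_{t+1}
    return_t = value_t + adv_t

Advantages are normalized to zero mean and unit variance before being returned.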
"""
advantages = torch.zeros_like(reward)
if not use_gae:
pre_value, pre_adv = 0, 0
for t in reversed(range(reward.shape[0])):
advantages[t] = reward[t] + gamma * pre_value * (1 - mask[t]) - value[t]
pre_value = value[t]
returns = value + advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-10)
else:
td_error = torch.zeros_like(reward)
pre_value, pre_adv = 0, 0
for t in reversed(range(reward.shape[0])):
td_error[t] = reward[t] + gamma * pre_value * (1 - mask[t]) - value[t]
advantages[t] = td_error[t] + gamma * lam * pre_adv * (1 - mask[t])
pre_adv = advantages[t]
pre_value = value[t]
returns = value + advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-10)
return advantages, returns
def ppo_step(self, target_policy, value_net, generator_optimizer,
states, actions, ret, advantages, action_log_probs, expert_data, value_net_states, valid_masks=1):
"""
Train target_policy and value_net with the PPO algorithm.
:param target_policy: list of target policy nodes
:param value_net: value net
:param generator_optimizer: the optimizer used to update the target policies and the value net
:param states: inputs of each target policy
:param actions: actions of each target policy
:param ret: returns estimated by GAE (targets for the value net)
:param advantages: advantage estimates
:param action_log_probs: log probabilities of the rollout actions under the old policies
:param expert_data: batch of expert data
:param value_net_states: concatenated leaf-node features fed to the value net
:param valid_masks: mask of valid timesteps
:return: v_loss, p_loss, e_loss, kl_loss, total_loss
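
The actor update uses the standard PPO clipped surrogate objective, and the
total loss combines the individual terms as implemented below (terms are
masked by valid_masks where appropriate):

    ratio_t    = exp(log_prob_new_t - log_prob_old_t)
    p_loss     = -sum_t min(ratio_t * adv_t, clip(ratio_t, 1 - epsilon, 1 + epsilon) * adv_t)
    total_loss = v_loss + p_loss - w_ent * e_loss + w_kl * kl_loss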
"""
# update critic
value_o = value_net(value_net_states.detach())
value_o = value_o.view(-1, value_o.shape[-1])
returns = ret.view(-1, ret.shape[-1])
v_loss = ((value_o - returns.detach())*valid_masks.view(-1, valid_masks.shape[-1])).pow(2).mean()
for param in value_net.parameters():
v_loss += param.pow(2).sum() * self.config['w_vl2']
# compute kl loss
kl_loss = 0
for policy, behaviour_policy, state in zip(target_policy, self.behaviour_policys, states.values()):
policy.reset()
behaviour_policy.reset()
q = behaviour_policy(state)
p = policy(state)
kl_loss += (kl_divergence(q, p)* torch.squeeze(valid_masks, dim=-1)).mean()
if isinstance(kl_loss,int):
kl_loss = torch.tensor([0],dtype=torch.float32).to(v_loss.device)
# update actor
prob_olds = []
prob_news = []
e_loss = 0
for policy, state, action_log_prob, action in zip(target_policy,states.values(),action_log_probs.values(),actions.values()):
policy.reset()
prob_old = action_log_prob
action_dist = policy(state)
policy_entropy = action_dist.entropy().unsqueeze(dim=-1)
prob_new = action_dist.log_prob(action).unsqueeze(dim=-1)
prob_olds.append(prob_old)
prob_news.append(prob_new)
e_loss += (policy_entropy*valid_masks).mean()
prob_old = torch.cat(prob_olds, dim=-1)
prob_new = torch.cat(prob_news, dim=-1)
ratio = torch.exp(prob_new - prob_old.detach())
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - self.config['epsilon'], 1 + self.config['epsilon']) * advantages
p_loss = - (torch.min(surr1, surr2)*valid_masks).sum()
total_loss = v_loss + p_loss - e_loss * self.config['w_ent'] + kl_loss * self.config['w_kl']
if generator_optimizer is not None:
generator_optimizer.zero_grad()
total_loss.backward()
grad_norm = get_grad_norm(get_models_parameters(*list(target_policy)+[value_net,]))
generator_optimizer.step()
return v_loss, p_loss, e_loss, kl_loss, total_loss