"""
POLIXIR REVIVE, copyright (C) 2021-2023 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
from typing import OrderedDict
import torch
from copy import deepcopy
from revive.computation.dists import kl_divergence
from revive.computation.modules import *
from revive.utils.common_utils import *
from revive.data.dataset import data_creator
from revive.utils.raysgd_utils import BATCH_SIZE
from revive.algo.policy.base import PolicyOperator, catch_error
class PPOOperator(PolicyOperator):
"""
    A class used to train the target policy.
"""
PARAMETER_DESCRIPTION = [
{
"name" : "ppo_batch_size",
"description": "Batch size of training process.",
"abbreviation" : "pbs",
"type" : int,
"default" : 256,
'doc': True,
},
{
"name" : "policy_bc_epoch",
"type" : int,
"default" : 0,
},
{
"name" : "ppo_epoch",
"description": "Number of epcoh for the training process",
"abbreviation" : "bep",
"type" : int,
"default" : 200,
'doc': True,
},
{
"name" : "ppo_rollout_horizon",
"description": "Rollout length of the policy train.",
"abbreviation" : "prh",
"type" : int,
"default" : 100,
'doc': True,
},
{
"name" : "policy_hidden_features",
"description": "Number of neurons per layer of the policy network.",
"abbreviation" : "phf",
"type" : int,
"default" : 256,
'doc': True,
},
{
"name" : "policy_hidden_layers",
"description": "Depth of policy network.",
"abbreviation" : "phl",
"type" : int,
"default" : 4,
'doc': True,
},
{
"name" : "policy_backbone",
"description": "Backbone of policy network.",
"abbreviation" : "pb",
"type" : str,
"default" : "mlp",
'doc': True,
},
{
"name" : "value_hidden_features",
"abbreviation" : "vhf",
"type" : int,
"default" : 256,
},
{
"name" : "value_hidden_layers",
"abbreviation" : "vhl",
"type" : int,
"default" : 4,
},
{
"name" : "ppo_runs",
"type" : int,
"default" : 2,
},
{
"name" : "epsilon",
"type" : float,
"default" : 0.2,
},
{
"name" : "w_vl2",
"type" : float,
"default" : 0.001,
},
{
"name" : "w_ent",
"type" : float,
"default" : 0.0,
},
{
"name" : "w_kl",
"type" : float,
"default" : 1.0,
},
{
"name" : "gae_gamma",
"type" : float,
"default" : 0.99,
},
{
"name" : "gae_lambda",
"type" : float,
"default" : 0.95,
},
{
"name" : "g_lr",
"description": "Initial learning rate of the training process.",
"type" : float,
"default" : 4e-5,
"search_mode" : "continuous",
"search_values" : [1e-6, 1e-3],
'doc': True,
},
{
"name" : "reward_uncertainty_weight",
"description": "Reward uncertainty weight(MOPO)",
"type" : float,
"default" : 0,
"search_mode" : "continuous",
},
]
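    # Each entry above declares a tunable hyperparameter of this operator; "abbreviation" is
    # presumably a short alias for the parameter name, and entries marked 'doc': True are
    # presumably the ones surfaced in the user documentation.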
    def model_creator(self, config, nodes):
"""
Create models including target policy and value net.
        :return: the target policy network(s) followed by the value net
"""
graph = config['graph']
policy_name = config['target_policy_name']
total_dims = config['total_dims']
models = []
value_net_input_dim = OrderedDict()
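        # When behavioural initialization is enabled, warm-start each target policy as a
        # trainable copy of the corresponding behaviour policy.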
if config['behavioral_policy_init'] and self.behaviour_policys:
for behaviour_policy in self.behaviour_policys:
target_policy = deepcopy(behaviour_policy)
target_policy.requires_grad_(True)
models.append(target_policy)
else:
for policy_name, node in nodes.items():
node_input_dim = get_input_dim_from_graph(graph, policy_name, total_dims)
logger.info(f"Policy Backbone: {config['policy_backbone'],}")
node.initialize_network(node_input_dim, total_dims[policy_name]['output'],
hidden_features=config['policy_hidden_features'],
hidden_layers=config['policy_hidden_layers'],
backbone_type=config['policy_backbone'],
dist_config=config['dist_configs'][policy_name],
is_transition=False)
target_policy = node.get_network()
models.append(target_policy)
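        # The value net scores the concatenation of all leaf nodes of the decision graph,
        # so its input width is the sum of their dimensions.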
value_net_input_dim = 0
for node_name in graph.get_leaf(graph):
value_net_input_dim += get_node_dim_from_dist_configs(config["dist_configs"], node_name)
value_net = MLP(value_net_input_dim, 1, config['value_hidden_features'], config['value_hidden_layers'])
models.append(value_net)
return models
    def optimizer_creator(self, models, config):
"""
        :return: a joint generator optimizer for the target policies and the value net, plus a behaviour-cloning optimizer when policy_bc_epoch >= 1
"""
target_policys = models[:-1]
value_net = models[-1]
models_params = []
optimizers = []
for target_policy in target_policys:
models_params.append({'params': target_policy.parameters(), 'lr': config['g_lr']})
models_params.append({'params': value_net.parameters(), 'lr': config['g_lr']})
generator_optimizer = torch.optim.Adam(models_params)
optimizers.append(generator_optimizer)
if self.config["policy_bc_epoch"] >=1:
models_params = []
for target_policy in target_policys:
models_params.append({'params': target_policy.parameters(), 'lr': 1e-3})
bc_policy_optimizer = torch.optim.Adam(models_params)
optimizers.append(bc_policy_optimizer)
return optimizers
    def data_creator(self, config):
config[BATCH_SIZE] = config['ppo_batch_size']
return data_creator(config,
training_mode='trajectory',
training_horizon=config['ppo_rollout_horizon'],
training_is_sample=True,
val_horizon=config['test_horizon'],
double=self.double_validation)
@catch_error
def train_batch(self, expert_data, batch_info, scope='train'):
if scope == 'train':
target_policy = self.train_models[:-1]
value_net = self.train_models[-1]
generator_optimizer = self.train_optimizers[0]
envs = self.envs_train
else:
assert self.double_validation
target_policy = self.val_models[:-1]
value_net = self.val_models[-1]
generator_optimizer = self.val_optimizers[0]
envs = self.envs_val
expert_data.to_torch(device=self._device)
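        # Roll out the target policy through the (learned) environment models to generate
        # on-policy trajectories; no gradients are needed for data generation.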
with torch.no_grad():
generated_data, info = self._run_rollout(expert_data, target_policy, envs, traj_length=self.config['ppo_rollout_horizon'], deterministic=False, clip=True)
info.update(self._run_update(expert_data, generated_data, target_policy, value_net, generator_optimizer))
for k in list(info.keys()): info[f'{k}_{scope}'] = info.pop(k)
self._stop_flag = self._stop_flag or self._epoch_cnt >= self.config['ppo_epoch']
info = self._early_stop(info)
return info
@catch_error
def bc_train_batch(self, expert_data, batch_info, scope='train'):
if scope == 'train':
target_policy = self.train_models[:-1]
optimizer = self.train_optimizers[-1]
else:
assert self.double_validation
target_policy = self.val_models[:-1]
optimizer = self.val_optimizers[-1]
expert_data.to_torch(device=self._device)
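        # Behaviour-cloning warm-up: maximize the log-likelihood of the dataset actions
        # under each target policy node.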
loss = 0
for policy_name, _target_policy in zip(self.policy_name, target_policy):
_target_policy.reset()
with torch.no_grad():
state = get_input_from_graph(self._graph, policy_name , expert_data)
action_dist = _target_policy(state, field='bc')
action_log_prob = action_dist.log_prob(expert_data[policy_name])
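            # Entries flagged by *_isnan_index_ were NaN in the raw data; mask them out of the loss.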
if policy_name + "_isnan_index_" in expert_data.keys():
isnan_index = 1 - expert_data[policy_name + "_isnan_index_"]
policy_loss = - (action_log_prob*isnan_index).mean()
else:
isnan_index = None
policy_loss = - action_log_prob.mean()
loss += policy_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
info = {
"PPO/policy_bc_loss" : loss.mean().item(),
}
return info
def _run_update(self, expert_data, generated_data, target_policy, value_net, generator_optimizer=None):
states = OrderedDict()
actions = OrderedDict()
action_log_probs = OrderedDict()
for policy_name in self.policy_name:
states[policy_name] = get_input_from_graph(self._graph, policy_name , generated_data)
actions[policy_name] = generated_data[policy_name]
action_log_probs[policy_name] = generated_data[policy_name + '_log_prob']
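        # The value net is evaluated on the concatenation of the graph's leaf-node outputs.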
value_net_inputs = OrderedDict()
for node_name in self._graph.get_leaf(self._graph):
value_net_inputs[node_name] = generated_data[node_name]
value_net_states = torch.cat(list(value_net_inputs.values()), dim=-1)
rewards = generated_data.reward
_values = value_net(value_net_states)
_masks = torch.zeros_like(rewards)
_masks[-1] = torch.ones_like(_masks[-1])
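        # _masks marks the last step of the rollout so the advantage recursion does not
        # bootstrap beyond the horizon; compute GAE advantages and returns.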
advs, returns = self.ADV(rewards.detach(), _masks, _values.detach(), gamma=self.config['gae_gamma'], lam=self.config['gae_lambda'])
_repeat = self.config['ppo_runs']
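        # Reuse the same rollout batch for ppo_runs successive PPO updates.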
for k in range(_repeat):
v_loss, p_loss, e_loss, kl_loss, total_loss = self.ppo_step(target_policy, value_net, generator_optimizer,
states, actions, returns, advs, action_log_probs, expert_data, value_net_states)
info = {
"PPO/v_loss" : v_loss.mean().item(),
"PPO/p_loss" : p_loss.mean().item(),
"PPO/e_loss" : e_loss.mean().item(),
"PPO/kl_loss" : kl_loss.mean().item(),
"PPO/total_loss": total_loss.mean().item()
}
return info
    def ADV(self, reward, mask, value, gamma, lam, use_gae=True):
"""
Compute advantage function for PPO.
:param reward: rewards of each step
        :param mask: 1 if the trajectory is done at this step, else 0
:param value: value for each state
:param gamma: discount factor
        :param lam: GAE lambda
        :param use_gae: whether to use GAE; if False, one-step TD errors are used as advantages
        :return: advantages and returns
"""
advantages = torch.zeros_like(reward)
if not use_gae:
pre_value, pre_adv = 0, 0
for t in reversed(range(reward.shape[0])):
advantages[t] = reward[t] + gamma * pre_value * (1 - mask[t]) - value[t]
pre_value = value[t]
returns = value + advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-10)
else:
td_error = torch.zeros_like(reward)
pre_value, pre_adv = 0, 0
for t in reversed(range(reward.shape[0])):
td_error[t] = reward[t] + gamma * pre_value * (1 - mask[t]) - value[t]
advantages[t] = td_error[t] + gamma * lam * pre_adv * (1 - mask[t])
pre_adv = advantages[t]
pre_value = value[t]
returns = value + advantages
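            # Normalize advantages to zero mean and unit variance for a more stable update.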
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-10)
return advantages, returns
    def ppo_step(self, target_policy, value_net, generator_optimizer,
states, actions, ret, advantages, action_log_probs, expert_data, value_net_states):
"""
Train target_policy and value_net by PPO algorithm.
:param target_policy: target policy
:param value_net: value net
:param generator_optimizer: the optimizers used to optimize target policy and value net
        :param states: per-node states of the target policies
        :param actions: per-node actions of the target policies
        :param ret: GAE returns
        :param advantages: GAE advantages
        :param action_log_probs: action log probabilities under the rollout (old) policy
        :param expert_data: batch of expert data
        :param value_net_states: concatenated input of the value net
        :return: v_loss, p_loss, e_loss, kl_loss, total_loss
"""
        # update critic
value_o = value_net(value_net_states.detach())
value_o = value_o.view(-1, value_o.shape[-1])
returns = ret.view(-1, ret.shape[-1])
v_loss = (value_o - returns.detach()).pow(2).mean()
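        # L2 regularization on the value-net parameters, weighted by w_vl2.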
for param in value_net.parameters():
v_loss += param.pow(2).sum() * self.config['w_vl2']
# compute kl loss
kl_loss = 0
for policy, behaviour_policy, state in zip(target_policy, self.behaviour_policys, states.values()):
policy.reset()
behaviour_policy.reset()
q = behaviour_policy(state)
p = policy(state)
kl_loss += kl_divergence(q, p).mean()
        # when there is no behaviour policy, kl_loss is still the Python int 0; convert it to a tensor
        if isinstance(kl_loss, int):
            kl_loss = torch.tensor([0], dtype=torch.float32).to(v_loss.device)
# update actor
prob_olds = []
prob_news = []
e_loss = 0
        for policy, state, action_log_prob, action in zip(target_policy, states.values(), action_log_probs.values(), actions.values()):
policy.reset()
prob_old = action_log_prob
action_dist = policy(state)
policy_entropy = action_dist.entropy()
prob_new = action_dist.log_prob(action).unsqueeze(dim=-1)
prob_olds.append(prob_old)
prob_news.append(prob_new)
e_loss += policy_entropy.mean()
prob_old = torch.cat(prob_olds, dim=-1)
prob_new = torch.cat(prob_news, dim=-1)
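        # Importance ratio between the updated policy and the policy that generated the rollout.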
ratio = torch.exp(prob_new - prob_old.detach())
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - self.config['epsilon'], 1 + self.config['epsilon']) * advantages
p_loss = - torch.min(surr1, surr2).sum()
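        # Total loss: value loss + clipped policy loss - entropy bonus (w_ent)
        # + KL penalty toward the behaviour policy (w_kl).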
total_loss = v_loss + p_loss - e_loss * self.config['w_ent'] + kl_loss * self.config['w_kl']
if generator_optimizer is not None:
generator_optimizer.zero_grad()
total_loss.backward()
grad_norm = get_grad_norm(get_models_parameters(*list(target_policy)+[value_net,]))
generator_optimizer.step()
return v_loss, p_loss, e_loss, kl_loss, total_loss