Source code for revive.algo.policy.ppo

"""
    POLIXIR REVIVE, copyright (C) 2021-2023 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""

from collections import OrderedDict
import torch
from copy import deepcopy

from revive.computation.dists import kl_divergence
from revive.computation.modules import *
from revive.utils.common_utils import *
from revive.data.dataset import data_creator
from revive.utils.raysgd_utils import BATCH_SIZE
from revive.algo.policy.base import PolicyOperator, catch_error


class PPOOperator(PolicyOperator):
    """A class used to train the target policy with PPO."""

    PARAMETER_DESCRIPTION = [
        {
            "name" : "ppo_batch_size",
            "description" : "Batch size of the training process.",
            "abbreviation" : "pbs",
            "type" : int,
            "default" : 256,
            "doc" : True,
        },
        {
            "name" : "policy_bc_epoch",
            "type" : int,
            "default" : 0,
        },
        {
            "name" : "ppo_epoch",
            "description" : "Number of epochs of the training process.",
            "abbreviation" : "bep",
            "type" : int,
            "default" : 200,
            "doc" : True,
        },
        {
            "name" : "ppo_rollout_horizon",
            "description" : "Rollout horizon used during policy training.",
            "abbreviation" : "prh",
            "type" : int,
            "default" : 100,
            "doc" : True,
        },
        {
            "name" : "policy_hidden_features",
            "description" : "Number of neurons per layer of the policy network.",
            "abbreviation" : "phf",
            "type" : int,
            "default" : 256,
            "doc" : True,
        },
        {
            "name" : "policy_hidden_layers",
            "description" : "Depth of the policy network.",
            "abbreviation" : "phl",
            "type" : int,
            "default" : 4,
            "doc" : True,
        },
        {
            "name" : "policy_backbone",
            "description" : "Backbone of the policy network.",
            "abbreviation" : "pb",
            "type" : str,
            "default" : "mlp",
            "doc" : True,
        },
        {
            "name" : "value_hidden_features",
            "abbreviation" : "vhf",
            "type" : int,
            "default" : 256,
        },
        {
            "name" : "value_hidden_layers",
            "abbreviation" : "vhl",
            "type" : int,
            "default" : 4,
        },
        {
            "name" : "ppo_runs",
            "type" : int,
            "default" : 2,
        },
        {
            "name" : "epsilon",
            "type" : float,
            "default" : 0.2,
        },
        {
            "name" : "w_vl2",
            "type" : float,
            "default" : 0.001,
        },
        {
            "name" : "w_ent",
            "type" : float,
            "default" : 0.0,
        },
        {
            "name" : "w_kl",
            "type" : float,
            "default" : 1.0,
        },
        {
            "name" : "gae_gamma",
            "type" : float,
            "default" : 0.99,
        },
        {
            "name" : "gae_lambda",
            "type" : float,
            "default" : 0.95,
        },
        {
            "name" : "g_lr",
            "description" : "Initial learning rate of the training process.",
            "type" : float,
            "default" : 4e-5,
            "search_mode" : "continuous",
            "search_values" : [1e-6, 1e-3],
            "doc" : True,
        },
        {
            "name" : "reward_uncertainty_weight",
            "description" : "Reward uncertainty weight (MOPO).",
            "type" : float,
            "default" : 0,
            "search_mode" : "continuous",
        },
    ]
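
    # Illustrative note (not part of the original source): the parameters above are
    # read from the training config dict; a hypothetical override such as
    # ``{"ppo_batch_size": 512, "ppo_epoch": 400, "g_lr": 1e-4}`` would change the
    # batch size, the number of PPO epochs and the generator learning rate.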
    def model_creator(self, config, nodes):
        """Create the models, i.e. the target policy networks and the value net.

        :return: list of models, [*target_policies, value_net]
        """
        graph = config['graph']
        policy_name = config['target_policy_name']
        total_dims = config['total_dims']

        models = []

        if config['behavioral_policy_init'] and self.behaviour_policys:
            # Initialize the target policies from the pretrained behaviour policies.
            for behaviour_policy in self.behaviour_policys:
                target_policy = deepcopy(behaviour_policy)
                target_policy.requires_grad_(True)
                models.append(target_policy)
        else:
            # Otherwise build fresh policy networks from the decision graph.
            for policy_name, node in nodes.items():
                node_input_dim = get_input_dim_from_graph(graph, policy_name, total_dims)
                logger.info(f"Policy Backbone: {config['policy_backbone']}")
                node.initialize_network(node_input_dim,
                                        total_dims[policy_name]['output'],
                                        hidden_features=config['policy_hidden_features'],
                                        hidden_layers=config['policy_hidden_layers'],
                                        backbone_type=config['policy_backbone'],
                                        dist_config=config['dist_configs'][policy_name],
                                        is_transition=False)
                target_policy = node.get_network()
                models.append(target_policy)

        # The value net takes the concatenation of all leaf nodes of the graph as input.
        value_net_input_dim = 0
        for node_name in graph.get_leaf(graph):
            value_net_input_dim += get_node_dim_from_dist_configs(config["dist_configs"], node_name)
        value_net = MLP(value_net_input_dim, 1,
                        config['value_hidden_features'],
                        config['value_hidden_layers'])
        models.append(value_net)

        return models
    def optimizer_creator(self, models, config):
        """
        :return: generator optimizers, including the target policy optimizer and the value net optimizer
        """
        target_policys = models[:-1]
        value_net = models[-1]

        models_params = []
        optimizers = []
        for target_policy in target_policys:
            models_params.append({'params': target_policy.parameters(), 'lr': config['g_lr']})
        models_params.append({'params': value_net.parameters(), 'lr': config['g_lr']})
        generator_optimizer = torch.optim.Adam(models_params)
        optimizers.append(generator_optimizer)

        if self.config["policy_bc_epoch"] >= 1:
            models_params = []
            for target_policy in target_policys:
                models_params.append({'params': target_policy.parameters(), 'lr': 1e-3})
            bc_policy_optimizer = torch.optim.Adam(models_params)
            optimizers.append(bc_policy_optimizer)

        return optimizers
    def data_creator(self, config):
        config[BATCH_SIZE] = config['ppo_batch_size']
        return data_creator(config,
                            training_mode='trajectory',
                            training_horizon=config['ppo_rollout_horizon'],
                            training_is_sample=True,
                            val_horizon=config['test_horizon'],
                            double=self.double_validation)
    @catch_error
    def train_batch(self, expert_data, batch_info, scope='train'):
        if scope == 'train':
            target_policy = self.train_models[:-1]
            value_net = self.train_models[-1]
            generator_optimizer = self.train_optimizers[0]
            envs = self.envs_train
        else:
            assert self.double_validation
            target_policy = self.val_models[:-1]
            value_net = self.val_models[-1]
            generator_optimizer = self.val_optimizers[0]
            envs = self.envs_val

        expert_data.to_torch(device=self._device)

        # Roll out the current policies in the learned environment models ...
        with torch.no_grad():
            generated_data, info = self._run_rollout(expert_data,
                                                     target_policy,
                                                     envs,
                                                     traj_length=self.config['ppo_rollout_horizon'],
                                                     deterministic=False,
                                                     clip=True)
        # ... and run the PPO update on the generated trajectories.
        info.update(self._run_update(expert_data, generated_data, target_policy, value_net, generator_optimizer))

        for k in list(info.keys()):
            info[f'{k}_{scope}'] = info.pop(k)

        self._stop_flag = self._stop_flag or self._epoch_cnt >= self.config['ppo_epoch']
        info = self._early_stop(info)

        return info

    @catch_error
    def bc_train_batch(self, expert_data, batch_info, scope='train'):
        if scope == 'train':
            target_policy = self.train_models[:-1]
            optimizer = self.train_optimizers[-1]
        else:
            assert self.double_validation
            target_policy = self.val_models[:-1]
            optimizer = self.val_optimizers[-1]

        expert_data.to_torch(device=self._device)

        # Behaviour cloning warm-up: maximize the log-likelihood of the expert actions.
        loss = 0
        for policy_name, _target_policy in zip(self.policy_name, target_policy):
            _target_policy.reset()
            with torch.no_grad():
                state = get_input_from_graph(self._graph, policy_name, expert_data)
            action_dist = _target_policy(state, field='bc')
            action_log_prob = action_dist.log_prob(expert_data[policy_name])
            if policy_name + "_isnan_index_" in expert_data.keys():
                # Mask out the timesteps where the expert action is NaN.
                isnan_index = 1 - expert_data[policy_name + "_isnan_index_"]
                policy_loss = - (action_log_prob * isnan_index).mean()
            else:
                isnan_index = None
                policy_loss = - action_log_prob.mean()
            loss += policy_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        info = {
            "PPO/policy_bc_loss" : loss.mean().item(),
        }

        return info

    def _run_update(self, expert_data, generated_data, target_policy, value_net, generator_optimizer=None):
        states = OrderedDict()
        actions = OrderedDict()
        action_log_probs = OrderedDict()
        for policy_name in self.policy_name:
            states[policy_name] = get_input_from_graph(self._graph, policy_name, generated_data)
            actions[policy_name] = generated_data[policy_name]
            action_log_probs[policy_name] = generated_data[policy_name + '_log_prob']

        # The value net scores the concatenation of all leaf nodes of the graph.
        value_net_inputs = OrderedDict()
        for node_name in self._graph.get_leaf(self._graph):
            value_net_inputs[node_name] = generated_data[node_name]
        value_net_states = torch.cat(list(value_net_inputs.values()), dim=-1)

        rewards = generated_data.reward
        _values = value_net(value_net_states)
        # Only the last step of each rollout is treated as terminal.
        _masks = torch.zeros_like(rewards)
        _masks[-1] = torch.ones_like(_masks[-1])
        advs, returns = self.ADV(rewards.detach(), _masks, _values.detach(),
                                 gamma=self.config['gae_gamma'],
                                 lam=self.config['gae_lambda'])

        _repeat = self.config['ppo_runs']
        for k in range(_repeat):
            v_loss, p_loss, e_loss, kl_loss, total_loss = self.ppo_step(target_policy, value_net, generator_optimizer,
                                                                        states, actions, returns, advs,
                                                                        action_log_probs, expert_data, value_net_states)

        info = {
            "PPO/v_loss" : v_loss.mean().item(),
            "PPO/p_loss" : p_loss.mean().item(),
            "PPO/e_loss" : e_loss.mean().item(),
            "PPO/kl_loss" : kl_loss.mean().item(),
            "PPO/total_loss" : total_loss.mean().item(),
        }

        return info
    def ADV(self, reward, mask, value, gamma, lam, use_gae=True):
        """Compute the advantage estimates used by PPO.

        :param reward: reward of each step
        :param mask: 1 if the trajectory is done at this step, else 0
        :param value: value estimate for each state
        :param gamma: discount factor
        :param lam: GAE lambda
        :param use_gae: whether to use GAE or plain one-step advantages

        :return: advantages and returns
        """
        advantages = torch.zeros_like(reward)
        if not use_gae:
            pre_value, pre_adv = 0, 0
            for t in reversed(range(reward.shape[0])):
                advantages[t] = reward[t] + gamma * pre_value * (1 - mask[t]) - value[t]
                pre_value = value[t]
            returns = value + advantages
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-10)
        else:
            td_error = torch.zeros_like(reward)
            pre_value, pre_adv = 0, 0
            for t in reversed(range(reward.shape[0])):
                td_error[t] = reward[t] + gamma * pre_value * (1 - mask[t]) - value[t]
                advantages[t] = td_error[t] + gamma * lam * pre_adv * (1 - mask[t])
                pre_adv = advantages[t]
                pre_value = value[t]
            returns = value + advantages
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-10)

        return advantages, returns
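
    # Sketch of the GAE recursion implemented in ``ADV`` above (reference only):
    #     delta_t = r_t + gamma * V_{t+1} * (1 - mask_t) - V_t
    #     A_t     = delta_t + gamma * lam * (1 - mask_t) * A_{t+1}
    #     R_t     = V_t + A_t
    # followed by the normalization A <- (A - mean(A)) / (std(A) + 1e-10).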
    def ppo_step(self, target_policy, value_net, generator_optimizer,
                 states, actions, ret, advantages, action_log_probs,
                 expert_data, value_net_states):
        """Run one PPO update on the target policies and the value net.

        :param target_policy: list of target policies
        :param value_net: value net
        :param generator_optimizer: optimizer for the target policies and the value net
        :param states: inputs of each target policy
        :param actions: actions of each target policy
        :param ret: GAE returns
        :param advantages: advantage estimates
        :param action_log_probs: action log probabilities under the rollout policies
        :param expert_data: batch of expert data
        :param value_net_states: inputs of the value net

        :return: v_loss, p_loss, e_loss, kl_loss, total_loss
        """
        # update critic
        value_o = value_net(value_net_states.detach())
        value_o = value_o.view(-1, value_o.shape[-1])
        returns = ret.view(-1, ret.shape[-1])
        v_loss = (value_o - returns.detach()).pow(2).mean()
        for param in value_net.parameters():
            v_loss += param.pow(2).sum() * self.config['w_vl2']

        # compute the KL penalty against the behaviour policies
        kl_loss = 0
        for policy, behaviour_policy, state in zip(target_policy, self.behaviour_policys, states.values()):
            policy.reset()
            behaviour_policy.reset()
            q = behaviour_policy(state)
            p = policy(state)
            kl_loss += kl_divergence(q, p).mean()
        if isinstance(kl_loss, int):
            # No behaviour policy is available, so the KL term is zero.
            kl_loss = torch.tensor([0], dtype=torch.float32).to(v_loss.device)

        # update actor with the clipped surrogate objective
        prob_olds = []
        prob_news = []
        e_loss = 0
        for policy, state, action_log_prob, action in zip(target_policy, states.values(),
                                                           action_log_probs.values(), actions.values()):
            policy.reset()
            prob_old = action_log_prob
            action_dist = policy(state)
            policy_entropy = action_dist.entropy()
            prob_new = action_dist.log_prob(action).unsqueeze(dim=-1)
            prob_olds.append(prob_old)
            prob_news.append(prob_new)
            e_loss += policy_entropy.mean()

        prob_old = torch.cat(prob_olds, dim=-1)
        prob_new = torch.cat(prob_news, dim=-1)
        ratio = torch.exp(prob_new - prob_old.detach())
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.config['epsilon'], 1 + self.config['epsilon']) * advantages
        p_loss = - torch.min(surr1, surr2).sum()

        total_loss = v_loss + p_loss - e_loss * self.config['w_ent'] + kl_loss * self.config['w_kl']

        if generator_optimizer is not None:
            generator_optimizer.zero_grad()
            total_loss.backward()
            grad_norm = get_grad_norm(get_models_parameters(*list(target_policy) + [value_net, ]))
            generator_optimizer.step()

        return v_loss, p_loss, e_loss, kl_loss, total_loss
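

# ---------------------------------------------------------------------------
# Minimal illustrative sketch (not part of the original library): the guarded
# block below reproduces the clipped-surrogate objective used in ``ppo_step``
# on toy tensors. All values are hypothetical and only ``torch`` is assumed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    _advantages = torch.tensor([0.5, -0.2, 1.0])
    _log_prob_old = torch.tensor([-1.0, -0.7, -1.2])
    _log_prob_new = torch.tensor([-0.9, -0.8, -1.1])
    _epsilon = 0.2

    # Probability ratio between the new and old policies.
    _ratio = torch.exp(_log_prob_new - _log_prob_old)
    _surr1 = _ratio * _advantages
    _surr2 = torch.clamp(_ratio, 1 - _epsilon, 1 + _epsilon) * _advantages
    _p_loss = -torch.min(_surr1, _surr2).sum()
    print(f"toy clipped-surrogate policy loss: {_p_loss.item():.4f}")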