Source code for revive.algo.policy.ppo

''''''
"""
    POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""

from collections import OrderedDict
import torch
from copy import deepcopy

from revive.computation.dists import kl_divergence
from revive.computation.modules import *
from revive.utils.common_utils import *
from revive.data.dataset import data_creator
from revive.utils.raysgd_utils import BATCH_SIZE
from revive.algo.policy.base import PolicyOperator, catch_error


class PPOOperator(PolicyOperator):
    """A class used to train the target policy with PPO."""

    NAME = "PPO"
    PARAMETER_DESCRIPTION = [
        {
            "name": "ppo_batch_size",
            "description": "Batch size of the training process.",
            "abbreviation": "pbs",
            "type": int,
            "default": 256,
            "doc": True,
        },
        {
            "name": "policy_bc_epoch",
            "description": "Number of epochs used to pre-train the policy (behavior cloning).",
            "type": int,
            "default": 0,
            "doc": True,
        },
        {
            "name": "ppo_epoch",
            "description": "Number of epochs of the training process.",
            "abbreviation": "bep",
            "type": int,
            "default": 1000,
            "doc": True,
        },
        {
            "name": "ppo_rollout_horizon",
            "description": "Rollout horizon used during policy training.",
            "abbreviation": "prh",
            "type": int,
            "default": 100,
            "doc": True,
        },
        {
            "name": "policy_hidden_features",
            "description": "Number of neurons per layer of the policy network.",
            "abbreviation": "phf",
            "type": int,
            "default": 256,
            "doc": True,
        },
        {
            "name": "policy_hidden_layers",
            "description": "Depth of the policy network.",
            "abbreviation": "phl",
            "type": int,
            "default": 4,
            "doc": True,
        },
        {
            "name": "policy_backbone",
            "description": "Backbone of the policy network. [mlp, res, ft_transformer]",
            "abbreviation": "pb",
            "type": str,
            "default": "res",
            "doc": True,
        },
        {
            "name": "value_hidden_features",
            "abbreviation": "vhf",
            "type": int,
            "default": 256,
        },
        {
            "name": "value_hidden_layers",
            "abbreviation": "vhl",
            "type": int,
            "default": 4,
        },
        {
            "name": "ppo_runs",
            "type": int,
            "default": 2,
        },
        {
            "name": "epsilon",
            "type": float,
            "default": 0.2,
        },
        {
            "name": "w_vl2",
            "type": float,
            "default": 0.001,
        },
        {
            "name": "w_ent",
            "type": float,
            "default": 0.0,
        },
        {
            "name": "w_kl",
            "type": float,
            "default": 1.0,
        },
        {
            "name": "gae_gamma",
            "type": float,
            "default": 0.99,
        },
        {
            "name": "gae_lambda",
            "type": float,
            "default": 0.95,
        },
        {
            "name": "g_lr",
            "description": "Initial learning rate of the training process.",
            "type": float,
            "default": 4e-5,
            "search_mode": "continuous",
            "search_values": [1e-6, 1e-3],
            "doc": True,
        },
        {
            "name": "reward_uncertainty_weight",
            "description": "Reward uncertainty weight (MOPO).",
            "type": float,
            "default": 0,
            "search_mode": "continuous",
        },
        {
            "name": "filter",
            "type": bool,
            "default": False,
        },
        {
            "name": "candidate_num",
            "type": int,
            "default": 50,
        },
        {
            "name": "ensemble_size",
            "type": int,
            "default": 50,
        },
        {
            "name": "ensemble_choosing_interval",
            "type": int,
            "default": 10,
        },
        {
            "name": "disturbing_transition_function",
            "description": "Whether to disturb the transition network nodes during policy learning.",
            "type": bool,
            "default": False,
        },
        {
            "name": "disturbing_nodes",
            "description": "Which network nodes to disturb during policy learning.",
            "type": list,
            "default": 'auto',
        },
        {
            "name": "disturbing_net_num",
            "description": "Number of disturbing networks used during policy learning.",
            "type": int,
            "default": 100,
        },
        {
            "name": "disturbing_weight",
            "description": "Weight of the disturbance applied during policy learning.",
            "type": float,
            "default": 0.05,
        },
        {
            "name": "generate_deter",
            "description": "Whether the generator rollout is deterministic.",
            "type": int,
            "default": 1,
        },
    ]

    @property
    def policy_bc_optimizer(self):
        return self.val_models[:-1]
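
    # Note: each PARAMETER_DESCRIPTION entry above is exposed through
    # ``self.config`` under its "name" key -- e.g. ``self.config['ppo_epoch']``
    # in ``after_train_epoch`` and ``self.config['gae_gamma']`` in
    # ``_run_update`` below -- with the listed "default" used when no override
    # is provided.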
    def model_creator(self, nodes):
        """
        Create the models, including the target policy and the value net.

        :return: env models, target policy, value net
        """
        models = super().model_creator(nodes)

        graph = self._graph
        value_net_input_dim = 0
        for node_name in graph.get_leaf(graph):
            value_net_input_dim += get_node_dim_from_dist_configs(self.config["dist_configs"], node_name)
        value_net = MLP(value_net_input_dim, 1,
                        self.config['value_hidden_features'],
                        self.config['value_hidden_layers'])
        models.append(value_net)

        return models
    def optimizer_creator(self, scope):
        """
        :return: generator optimizers, covering the target policy parameters and the value net parameters
        """
        if scope == "train":
            target_policys = self.train_policy
            assert len(self.other_train_models) == 1
            value_net = self.other_train_models[0]
        else:
            target_policys = self.val_policy
            assert len(self.other_val_models) == 1
            value_net = self.other_val_models[0]

        models_params = []
        optimizers = []
        for target_policy in target_policys:
            models_params.append({'params': target_policy.parameters(), 'lr': self.config['g_lr']})
        models_params.append({'params': value_net.parameters(), 'lr': self.config['g_lr']})
        generator_optimizer = torch.optim.Adam(models_params)
        optimizers.append(generator_optimizer)

        return optimizers
    def data_creator(self):
        self.config[BATCH_SIZE] = self.config['ppo_batch_size']
        return data_creator(self.config,
                            training_mode='trajectory',
                            training_horizon=self.config['ppo_rollout_horizon'],
                            training_is_sample=True,
                            val_horizon=self.config['test_horizon'],
                            double=self.double_validation)
    @catch_error
    def after_train_epoch(self):
        if self._epoch_cnt >= self.config['ppo_epoch'] + self.config['policy_bc_epoch']:
            self._stop_flag = True

    @catch_error
    def train_batch(self, expert_data, batch_info, scope='train'):
        if scope == 'train':
            target_policy = self.train_models[:-1]
            value_net = self.train_models[-1]
            generator_optimizer = self.train_optimizers[0]
            envs = self.envs_train
        else:
            assert self.double_validation
            target_policy = self.val_models[:-1]
            value_net = self.val_models[-1]
            generator_optimizer = self.val_optimizers[0]
            envs = self.envs_val

        expert_data.to_torch(device=self._device)

        # roll out the current policies in the learned environment models
        with torch.no_grad():
            # generated_data, info = self._run_rollout(expert_data, target_policy, envs, traj_length=self.config['ppo_rollout_horizon'], deterministic=False, clip=True)
            generated_data, info = self._run_rollout(expert_data, target_policy, envs,
                                                     traj_length=self.config['ppo_rollout_horizon'],
                                                     deterministic=bool(self.config['generate_deter']),
                                                     clip=True)

        info.update(self._run_update(expert_data, generated_data, target_policy, value_net, generator_optimizer))

        for k in list(info.keys()):
            info[f'{k}_{scope}'] = info.pop(k)

        info = self._early_stop(info)

        return info

    def _run_update(self, expert_data, generated_data, target_policy, value_net, generator_optimizer=None):
        states = OrderedDict()
        actions = OrderedDict()
        action_log_probs = OrderedDict()
        for policy_name in self.policy_name:
            states[policy_name] = get_input_from_graph(self._graph, policy_name, generated_data)
            if hasattr(self.train_nodes[policy_name], 'custom_node') or self.input_is_dict:
                states[policy_name] = get_input_dict_from_graph(self._graph, policy_name, generated_data)
            actions[policy_name] = generated_data[policy_name]
            action_log_probs[policy_name] = generated_data[policy_name + '_log_prob']

        # the value net takes the concatenation of all leaf nodes of the graph as input
        value_net_inputs = OrderedDict()
        for node_name in self._graph.get_leaf(self._graph):
            value_net_inputs[node_name] = generated_data[node_name]
        value_net_states = torch.cat(list(value_net_inputs.values()), dim=-1)

        rewards = generated_data.reward
        _values = value_net(value_net_states)
        _masks = generated_data["done_mask"]
        valid_masks = generated_data["valid_masks"]

        advs, returns = self.ADV(rewards.detach(), _masks, _values.detach(),
                                 gamma=self.config['gae_gamma'], lam=self.config['gae_lambda'])

        _repeat = self.config['ppo_runs']
        for k in range(_repeat):
            v_loss, p_loss, e_loss, kl_loss, total_loss = self.ppo_step(
                target_policy, value_net, generator_optimizer,
                states, actions, returns, advs, action_log_probs,
                expert_data, value_net_states, valid_masks)

        info = {
            "PPO/v_loss": v_loss.mean().item(),
            "PPO/p_loss": p_loss.mean().item(),
            "PPO/e_loss": e_loss.mean().item(),
            "PPO/kl_loss": kl_loss.mean().item(),
            "PPO/total_loss": total_loss.mean().item(),
            "PPO/mean_reward": rewards.mean().item(),
            "PPO/min_reward": rewards.min().item(),
            "PPO/max_reward": rewards.max().item(),
        }

        return info
    def ADV(self, reward, mask, value, gamma, lam, use_gae=True):
        """
        Compute the advantage function for PPO.

        :param reward: reward of each step
        :param mask: 1 if the trajectory is done at this step, else 0
        :param value: value estimate of each state
        :param gamma: discount factor
        :param lam: GAE lambda
        :param use_gae: whether to use GAE; otherwise plain one-step TD advantages are used

        :return: advantages and returns
        """
        advantages = torch.zeros_like(reward)
        if not use_gae:
            pre_value, pre_adv = 0, 0
            for t in reversed(range(reward.shape[0])):
                advantages[t] = reward[t] + gamma * pre_value * (1 - mask[t]) - value[t]
                pre_value = value[t]
            returns = value + advantages
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-10)
        else:
            td_error = torch.zeros_like(reward)
            pre_value, pre_adv = 0, 0
            for t in reversed(range(reward.shape[0])):
                td_error[t] = reward[t] + gamma * pre_value * (1 - mask[t]) - value[t]
                advantages[t] = td_error[t] + gamma * lam * pre_adv * (1 - mask[t])
                pre_adv = advantages[t]
                pre_value = value[t]
            returns = value + advantages
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-10)
        return advantages, returns
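
    # The GAE recurrence implemented in ``ADV`` above, written out per step t
    # (processed in reverse time order):
    #
    #     delta_t = r_t + gamma * (1 - done_t) * V(s_{t+1}) - V(s_t)
    #     A_t     = delta_t + gamma * lam * (1 - done_t) * A_{t+1}
    #     R_t     = A_t + V(s_t)
    #
    # The returned advantages are additionally normalized to zero mean and
    # unit standard deviation before being used in the clipped surrogate
    # objective.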
    def ppo_step(self, target_policy, value_net, generator_optimizer,
                 states, actions, ret, advantages, action_log_probs,
                 expert_data, value_net_states, valid_masks=1):
        """
        Update the target policies and the value net with the PPO objective.

        :param target_policy: target policies
        :param value_net: value net
        :param generator_optimizer: optimizer of the target policies and the value net
        :param states: states fed to the target policies
        :param actions: actions taken by the target policies during the rollout
        :param ret: returns computed by GAE
        :param advantages: advantages computed by GAE
        :param action_log_probs: action log probabilities recorded during the rollout
        :param expert_data: batch of expert data
        :param value_net_states: concatenated leaf-node inputs of the value net
        :param valid_masks: mask of valid transitions

        :return: v_loss, p_loss, e_loss, kl_loss, total_loss
        """
        # update the critic with an L2-regularized value regression loss
        value_o = value_net(value_net_states.detach())
        value_o = value_o.view(-1, value_o.shape[-1])
        returns = ret.view(-1, ret.shape[-1])
        v_loss = ((value_o - returns.detach()) * valid_masks.view(-1, valid_masks.shape[-1])).pow(2).mean()
        for param in value_net.parameters():
            v_loss += param.pow(2).sum() * self.config['w_vl2']

        # compute the KL loss between the behaviour policies and the current policies
        kl_loss = 0
        for policy, behaviour_policy, state in zip(target_policy, self.behaviour_policys, states.values()):
            policy.reset()
            behaviour_policy.reset()
            q = behaviour_policy(state)
            p = policy(state)
            kl_loss += (kl_divergence(q, p) * torch.squeeze(valid_masks, dim=-1)).mean()
        if isinstance(kl_loss, int):
            kl_loss = torch.tensor([0], dtype=torch.float32).to(v_loss.device)

        # update the actor with the clipped surrogate objective
        prob_olds = []
        prob_news = []
        e_loss = 0
        for policy, state, action_log_prob, action in zip(target_policy, states.values(),
                                                          action_log_probs.values(), actions.values()):
            policy.reset()
            prob_old = action_log_prob
            action_dist = policy(state)
            policy_entropy = action_dist.entropy().unsqueeze(dim=-1)
            prob_new = action_dist.log_prob(action).unsqueeze(dim=-1)
            prob_olds.append(prob_old)
            prob_news.append(prob_new)
            e_loss += (policy_entropy * valid_masks).mean()

        prob_old = torch.cat(prob_olds, dim=-1)
        prob_new = torch.cat(prob_news, dim=-1)
        ratio = torch.exp(prob_new - prob_old.detach())
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.config['epsilon'], 1 + self.config['epsilon']) * advantages
        p_loss = - (torch.min(surr1, surr2) * valid_masks).sum()
        e_loss = (action_dist.entropy().unsqueeze(dim=-1) * valid_masks).mean()

        total_loss = v_loss + p_loss - e_loss * self.config['w_ent'] + kl_loss * self.config['w_kl']

        if generator_optimizer is not None:
            generator_optimizer.zero_grad()
            total_loss.backward()
            grad_norm = get_grad_norm(get_models_parameters(*list(target_policy) + [value_net]))
            generator_optimizer.step()

        return v_loss, p_loss, e_loss, kl_loss, total_loss
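

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal, self-
# contained demonstration of the two core computations above -- the GAE
# recurrence from ``ADV`` and the clipped surrogate loss from ``ppo_step`` --
# run on toy tensors. All shapes and values below are made up purely for
# illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    horizon, batch = 5, 3
    reward = torch.rand(horizon, batch, 1)
    value = torch.rand(horizon, batch, 1)
    done = torch.zeros(horizon, batch, 1)   # 1 would mark a terminated step
    gamma, lam, epsilon = 0.99, 0.95, 0.2

    # GAE recurrence, mirroring PPOOperator.ADV(use_gae=True)
    advantages = torch.zeros_like(reward)
    pre_value, pre_adv = 0, 0
    for t in reversed(range(horizon)):
        td_error = reward[t] + gamma * pre_value * (1 - done[t]) - value[t]
        advantages[t] = td_error + gamma * lam * pre_adv * (1 - done[t])
        pre_adv = advantages[t]
        pre_value = value[t]
    returns = value + advantages
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-10)

    # clipped surrogate objective on toy old/new log-probabilities,
    # mirroring the actor update in PPOOperator.ppo_step
    prob_old = torch.randn(horizon, batch, 1)
    prob_new = prob_old + 0.1 * torch.randn(horizon, batch, 1)
    ratio = torch.exp(prob_new - prob_old.detach())
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    p_loss = -torch.min(surr1, surr2).sum()

    print("returns shape:", tuple(returns.shape), "policy loss:", p_loss.item())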