''''''
"""
POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
from copy import deepcopy
import torch
import numpy as np
from loguru import logger
import matplotlib.pyplot as plt
from collections import deque
from torch import nn
from torch.functional import F
from revive.computation.graph import *
from revive.computation.modules import *
from revive.utils.common_utils import *
from revive.data.dataset import data_creator
from revive.algo.venv.base import VenvOperator, catch_error
try:
    # Prefer the encrypted (distribution) implementation when it is installed.
    from revive.dist.algo.venv.revive import ReviveOperator
    logger.info(f"Import encryption venv algorithm module -> ReviveOperator!")
except Exception:
    # Fall back to the open-source implementation if the encrypted module is
    # missing or fails to load. ``Exception`` (not a bare ``except``) so that
    # KeyboardInterrupt / SystemExit raised during the import still propagate.
    from revive.algo.venv.revive import ReviveOperator
    logger.info(f"Import venv algorithm module -> ReviveOperator!")
[docs]
class PPOOperator(ReviveOperator):
    """
    Venv training operator whose generator update (``_run_generator``) uses PPO.

    ``PARAMETER_DESCRIPTION`` lists the tunable hyper-parameters of this operator.
    Every entry gives the config ``name``, its ``type`` and ``default``; entries may
    also carry an ``abbreviation``, a ``description``, a ``doc`` flag and
    ``search_mode`` / ``search_values`` for hyper-parameter search.
    """
    NAME = "REVIVE_VENV"
    PARAMETER_DESCRIPTION = [
        # ---- behaviour-cloning (BC) pre-training / fine-tuning ----
        {"name": "bc_batch_size", "abbreviation": "bbs", "type": int, "default": 256, "doc": True},
        {"name": "bc_epoch", "abbreviation": "bep", "type": int, "default": 0, "doc": True},
        {"name": "bc_lr", "type": float, "default": 1e-3},
        {"name": "bc_loss_type",
         "description": 'Bc support different loss function("nll", "mae", "mse").',
         "type": str, "default": "nll"},
        # ---- main training loop ----
        {"name": "revive_batch_size", "description": "Batch size of training process.",
         "abbreviation": "mbs", "type": int, "default": 1024, "doc": True},
        {"name": "revive_epoch", "description": "Number of epochs for the training process.",
         "abbreviation": "mep", "type": int, "default": 1000, "doc": True},
        # NOTE(review): "fintune" looks like a typo of "finetune"; kept unchanged
        # because renaming it would change the config key users set.
        {"name": "fintune", "abbreviation": "bet", "type": int, "default": 1, "doc": True},
        {"name": "finetune_fre", "abbreviation": "betfre", "type": int, "default": 1, "doc": True},
        {"name": "matcher_pretrain_epoch", "abbreviation": "dpe", "type": int, "default": 0},
        # ---- policy node networks ----
        {"name": "policy_hidden_features",
         "description": "Number of neurons per layer of the policy network.",
         "abbreviation": "phf", "type": int, "default": 256, "doc": True},
        {"name": "policy_hidden_layers", "description": "Depth of policy network.",
         "abbreviation": "phl", "type": int, "default": 4, "doc": True},
        {"name": "policy_activation", "abbreviation": "pa", "type": str, "default": 'leakyrelu'},
        {"name": "policy_normalization", "abbreviation": "pn", "type": str, "default": None},
        {"name": "policy_backbone",
         "description": "Backbone of policy network. Support selecting from [mlp, res, ft_transformer, lstm, gru].",
         "abbreviation": "pb", "type": str, "default": "res", "doc": True},
        # ---- transition node networks ----
        {"name": "transition_hidden_features",
         "description": "Number of neurons per layer of the transition network.",
         "abbreviation": "thf", "type": int, "default": 256, "doc": True},
        {"name": "transition_hidden_layers", "abbreviation": "thl", "type": int, "default": 4, "doc": True},
        {"name": "transition_activation", "abbreviation": "ta", "type": str, "default": 'leakyrelu'},
        {"name": "transition_normalization", "abbreviation": "tn", "type": str, "default": None},
        {"name": "transition_backbone",
         "description": "Backbone of Transition network. Support selecting from [mlp, res, ft_transformer, lstm, gru].",
         "abbreviation": "tb", "type": str, "default": "res", "doc": True},
        # ---- matcher (discriminator) ----
        {"name": "matching_nodes", "type": list, "default": 'auto'},
        {"name": "matching_fit_nodes", "type": list, "default": 'auto'},
        {"name": "matcher_hidden_features",
         "description": "Number of neurons per layer of the matcher network.",
         "abbreviation": "dhf", "type": int, "default": 256, "doc": True},
        {"name": "matcher_hidden_layers", "description": "Depth of the matcher network.",
         "abbreviation": "dhl", "type": int, "default": 4, "doc": True},
        {"name": "matcher_activation", "abbreviation": "da", "type": str, "default": 'leakyrelu'},
        {"name": "matcher_normalization", "abbreviation": "dn", "type": str, "default": None},
        # ---- value network ----
        {"name": "state_nodes", "type": list, "default": 'auto'},
        {"name": "value_hidden_features", "abbreviation": "vhf", "type": int, "default": 256},
        {"name": "value_hidden_layers", "abbreviation": "vhl", "type": int, "default": 4},
        {"name": "value_activation", "abbreviation": "va", "type": str, "default": 'leakyrelu'},
        {"name": "value_normalization", "abbreviation": "vn", "type": str, "default": None},
        # ---- generator / matcher types ----
        {"name": "generator_type", "abbreviation": "gt", "type": str, "default": "res"},
        {"name": "matcher_type", "abbreviation": "dt", "type": str, "default": "res"},
        {"name": "birnn", "type": bool, "default": False},
        {"name": "std_adapt_strategy", "abbreviation": "sas", "type": str, "default": None},
        # ---- generator optimisation (PPO / GAE) ----
        {"name": "generator_algo", "abbreviation": "ga", "type": str, "default": "ppo"},
        {"name": "ppo_runs", "type": int, "default": 2},
        {"name": "ppo_epsilon", "type": float, "default": 0.2},
        # NOTE(review): forwarded to PPO_step as ``lam``, which PPO_step does not
        # read (its l2 weight is ``generator_l2_norm_coef``) — confirm intent.
        {"name": "ppo_l2norm_cof", "type": float, "default": 0},
        {"name": "ppo_entropy_cof", "type": float, "default": 0},
        {"name": "generator_sup_cof", "type": float, "default": 0},
        {"name": "gae_gamma", "type": float, "default": 0.99},
        {"name": "gae_lambda", "type": float, "default": 0.95},
        {"name": "g_steps", "description": "The number of update rounds of the generator in each epoch.",
         "type": int, "default": 1, "search_mode": "grid", "search_values": [1, 3, 5], "doc": True},
        {"name": "d_steps", "description": "Number of update rounds of matcher in each epoch.",
         "type": int, "default": 1, "search_mode": "grid", "search_values": [1, 3, 5], "doc": True},
        {"name": "g_lr", "description": "Initial learning rate of the generator nodes nets.",
         "type": float, "default": 4e-5, "search_mode": "continuous", "search_values": [1e-6, 1e-4],
         "doc": True},
        {"name": "d_lr", "description": "Initial learning rate of the matcher.",
         "type": float, "default": 6e-4, "search_mode": "continuous", "search_values": [1e-6, 1e-3],
         "doc": True},
        # ---- matcher loss thresholds / sampling ----
        {"name": "matcher_loss_length", "description": "Matcher loss length.", "type": int, "default": 0},
        {"name": "matcher_loss_high",
         "description": "Upper threshold of the matcher loss. When the matcher loss exceeds this value, the generator stops training.",
         "type": float, "default": 1.2},
        {"name": "matcher_loss_low",
         "description": "Lower threshold of the matcher loss. When the matcher loss falls below this value, the matcher stops training.",
         "type": float, "default": 0.3},
        {"name": "matcher_sample", "description": "Sample the data for training the matcher.",
         "type": bool, "default": False},
        {"name": "mae_reward_weight",
         "description": "reward = (1-mae_reward_weight)*matcher_reward + mae_reward_weight*mae_reward.",
         "type": float, "default": 0.25},
        {"name": "generator_data_repeat", "description": "Repeat rollout more data to train generator.",
         "type": int, "default": 1},
        # ---- recurrent backbones ----
        {"name": "rnn_hidden_features", "description": "RNN hidden dims", "type": int, "default": 64},
        {"name": "window_size", "description": "length of the sliding_window in RNN", "type": int, "default": 0},
        {"name": "bc_weight_decay", "description": "weight_decay in bc finetune",
         "type": float, "default": 1e-4, "doc": True},
        # ---- misc regularisation / data mixing ----
        {"name": "mix_data", "type": bool, "default": False},
        {"name": "quantile", "type": float, "default": 0.95},
        {"name": "matcher_grad_norm_clip", "type": float, "default": 0.5},
        {"name": "mix_sample", "type": bool, "default": False},
        {"name": "mix_sample_ratio", "type": float, "default": 0.5},
        {"name": "gp_coef", "type": float, "default": 0.0},
        {"name": "discr_ent_coef", "type": float, "default": 0.01},
        {"name": "matcher_l2_norm_coeff", "type": float, "default": 0.0005},
        {"name": "value_l2_norm_coef", "type": float, "default": 1e-6},
        {"name": "generator_l2_norm_coef", "type": float, "default": 1e-6},
        {"name": "bc_l2_coef", "type": float, "default": 5e-5},
        {"name": "logstd_loss_coef", "type": float, "default": 0.0},
        {"name": "entropy_coef", "type": float, "default": 0.0},
        {"name": "bc_loss", "type": str, "default": "nll"},
        {"name": "ts_conv_nodes", "type": list, "default": "auto"},
        {"name": "controller_weight", "type": float, "default": 10},
        {"name": "bc_repeat", "type": int, "default": -1},
    ]
[docs]
def generator_model_creator(self, config, graph):
    """
    Create generator models.

    Initializes, in this order, the network of every learnable policy node,
    the network of every learnable transition node, and one ``Value_Net`` per
    entry of ``self.matching_nodes_list``.

    :param config: configuration parameters; reads ``total_dims``,
        ``dist_configs``, ``ts_conv_nodes``, ``transition_mode``,
        ``rnn_hidden_features``, ``window_size`` and the ``policy_*``,
        ``transition_*`` and ``value_*`` network options
    :param graph: decision graph; its ``network`` nodes are initialized in place
    :return: list of networks (policy nodes, then transition nodes, then value nets)
    :raises AssertionError: if no policy/transition node is a learnable network
    """
    # Local import: ``warnings`` is used below but is not explicitly imported
    # at the top of this module.
    import warnings

    total_dims = config['total_dims']
    ts_conv_nodes = config['ts_conv_nodes']
    networks = []

    def _node_kwargs():
        # Fresh extra keyword arguments for ``initialize_network`` (one dict per node).
        return {"rnn_hidden_features": config['rnn_hidden_features'],
                "window_size": config['window_size'],
                'ts_conv_config': None,  # ts node for conv layer
                }

    def _apply_ts_conv(input_dim, raw_input_dim, hidden_features, extra_kwargs):
        # If any key of ``raw_input_dim`` (presumably input node names) contains
        # 'ts', return the dims produced by ``dict_ts_conv`` and store its configs
        # in ``extra_kwargs``; otherwise return ``input_dim`` unchanged.
        if not any('ts' in element for element in raw_input_dim):
            return input_dim
        new_input_dim, input_dim_config, ts_conv_net_config = \
            dict_ts_conv(graph, raw_input_dim, total_dims, config,
                         net_hidden_features=hidden_features)
        extra_kwargs['ts_conv_config'] = input_dim_config
        extra_kwargs['ts_conv_net_config'] = ts_conv_net_config
        return new_input_dim

    def _node_input_dim(node, node_name, hidden_features, extra_kwargs):
        # Input dims of a graph node: ``get_input_dim_from_graph`` by default, the
        # per-input dict for custom nodes, the ts-conv layout for nodes listed in
        # ``config['ts_conv_nodes']`` (ignored when it is 'auto').
        input_dim = get_input_dim_from_graph(graph, node_name, total_dims)
        if hasattr(node, 'custom_node'):
            input_dim = get_input_dim_dict_from_graph(graph, node_name, total_dims)
        if ts_conv_nodes != 'auto' and node_name in ts_conv_nodes:
            raw_input_dim = get_input_dim_dict_from_graph(graph, node_name, total_dims)
            input_dim = _apply_ts_conv(input_dim, raw_input_dim, hidden_features, extra_kwargs)
        return input_dim

    # initialize policy node networks (transition nodes are handled below)
    for node_name in list(graph.keys()):
        if node_name in graph.transition_map.values():
            continue
        node = graph.get_node(node_name)
        if node.node_type != 'network':
            logger.info(f'skip {node_name}, the node is already registered as {node.node_type}.')
            continue
        extra_kwargs = _node_kwargs()
        input_dim = _node_input_dim(node, node_name, config['policy_hidden_features'], extra_kwargs)
        input_dim_dict = get_input_dim_dict_from_graph(graph, node_name, total_dims)
        node.initialize_network(input_dim, total_dims[node_name]['output'],
                                hidden_features=config['policy_hidden_features'],
                                hidden_layers=config['policy_hidden_layers'],
                                hidden_activation=config['policy_activation'],
                                norm=config['policy_normalization'],
                                backbone_type=config['policy_backbone'],
                                dist_config=config['dist_configs'][node_name],
                                is_transition=False,
                                input_dim_dict=input_dim_dict,
                                **extra_kwargs)
        networks.append(node.get_network())

    # initialize transition node networks
    for node_name in graph.transition_map.values():
        node = graph.get_node(node_name)
        if node.node_type != 'network':
            logger.info(f'skip {node_name}, the node is already registered as {node.node_type}.')
            continue
        extra_kwargs = _node_kwargs()
        # node_name[5:] strips a 5-character prefix (presumably 'next_' — confirm)
        # to get the transition variable; if it is not an input, only 'global'
        # transition mode is possible.
        if node_name[5:] not in node.input_names:
            transition_mode = 'global'
            if config['transition_mode']:
                warnings.warn('Fallback to global transition mode since the transition variable is not provided as an input!')
        else:
            transition_mode = config['transition_mode']
        input_dim = _node_input_dim(node, node_name, config['transition_hidden_features'], extra_kwargs)
        input_dim_dict = get_input_dim_dict_from_graph(graph, node_name, total_dims)
        node.initialize_network(input_dim, total_dims[node_name]['output'],
                                hidden_features=config['transition_hidden_features'],
                                hidden_layers=config['transition_hidden_layers'],
                                hidden_activation=config['transition_activation'],
                                norm=config['transition_normalization'],
                                backbone_type=config['transition_backbone'],
                                dist_config=config['dist_configs'][node_name],
                                is_transition=True,
                                transition_mode=transition_mode,
                                obs_dim=total_dims[node_name]['input'],
                                input_dim_dict=input_dim_dict,
                                **extra_kwargs)
        networks.append(node.get_network())
    assert len(networks) > 0, 'at least one node need to be a network to run training!'

    # create one value function per matching-node group; its inputs are the
    # group's nodes that are not learnable networks
    for nodes_list in self.matching_nodes_list:
        extra_kwargs = {}
        input_dim = {i_node: total_dims[i_node]['input']
                     for i_node in nodes_list
                     if not (i_node in graph.nodes and graph.nodes[i_node].node_type == 'network')}
        if ts_conv_nodes != 'auto' and any(item in ts_conv_nodes for item in nodes_list):
            input_dim = _apply_ts_conv(input_dim, dict(input_dim),
                                       config['value_hidden_features'], extra_kwargs)
        networks.append(Value_Net(input_dim,
                                  config['value_hidden_features'],
                                  config['value_hidden_layers'],
                                  config['value_normalization'],
                                  config['value_activation'],
                                  **extra_kwargs))
    return networks
[docs]
def optimizer_creator(self, models, config):
    """
    Build the optimizers used to train the venv.

    ``models`` is split from its end: the last ``self.matcher_num`` entries are
    the matchers, the ``self.matcher_num`` entries before them are the value
    nets, and everything in front are the node models. Sets
    ``self.node_model_num`` as a side effect.

    :param models: node models, then value nets, then matchers
    :param config: configuration parameters (``bc_lr``, ``bc_weight_decay``,
        ``g_lr``, ``d_lr``)
    :return: ``[bc_node_optimizer, generator_node_optimizer, None, matcher_optimizer]``;
        the third slot (a separate value optimizer) is always ``None`` because the
        value nets share ``generator_node_optimizer`` with the node models
    """
    tail = self.matcher_num
    node_models = models[:-tail * 2]
    value_nets = models[-tail * 2:-tail]
    matcher_nets = models[-tail:]

    self.node_model_num = len(node_models)
    assert self.node_model_num == self.config['learning_nodes_num'], "self.node_model_num != self.config['learning_nodes_num']"

    # BC fine-tuning: node models only, with weight decay.
    bc_node_optimizer = torch.optim.Adam(
        get_models_parameters(*node_models),
        lr=config['bc_lr'],
        weight_decay=config['bc_weight_decay'],
    )
    # Generator update: node models and value nets share one optimizer.
    generator_node_optimizer = torch.optim.Adam(
        get_models_parameters(*node_models, *value_nets),
        lr=config['g_lr'],
    )
    # [ TRAINING D ] RMSprop is used to damp oscillation of the matcher (discriminator).
    matcher_optimizer = torch.optim.RMSprop(
        get_models_parameters(*matcher_nets),
        lr=config['d_lr'],
    )
    return [bc_node_optimizer, generator_node_optimizer, None, matcher_optimizer]
def _run_generator(self, expert_data, graph, generator_other_net, matcher, generator_optimizer=None, other_generator_optimizers=None, matcher_index=None, scope=None):
    """
    Roll out the generator against ``matcher`` and update it with PPO.

    :param expert_data: expert trajectories passed to ``self._generate_rollout``
    :param graph: decision graph
    :param generator_other_net: value net paired with ``matcher``
    :param matcher: matcher used when generating the rollout rewards
    :param generator_optimizer: optimizer forwarded to ``PPO_step`` (None => no update)
    :param other_generator_optimizers: forwarded to ``PPO_step`` unchanged
    :param matcher_index: suffix used by ``PPO_step`` for its log keys
    :param scope: "train" or "val"; validated by ``PPO_step``
    :return: ``(info, generated_data)``; ``info`` holds the per-key mean of all
        ``PPO_step`` logs plus ``adv_max`` / ``adv_min``, every key prefixed
        with ``self.NAME + '/'``
    :raises NotImplementedError: if ``config['generator_algo']`` is ``'svg'``
    :raises ValueError: for any other unknown ``generator_algo``, or if
        ``config['ppo_runs']`` < 1 (no PPO step would run)
    """
    generator_algo = self.config['generator_algo']
    if generator_algo == 'svg':
        # SVG (backprop through lambda-returns) never worked; the original TODO
        # suspected the reward normalization. Disabled until fixed.
        raise NotImplementedError
    if generator_algo != 'ppo':
        raise ValueError(f"Unknown generator_algo {generator_algo!r}; expected 'ppo'.")
    ppo_runs = self.config['ppo_runs']
    if ppo_runs < 1:
        raise ValueError(f"ppo_runs must be >= 1 to update the generator, got {ppo_runs}.")

    generated_data = self._generate_rollout(expert_data, graph, matcher, generate_reward=True, clip=1.5)
    # "pre_horzion" is the config key's spelling; drop that many leading steps.
    pre_horizon = self.config["pre_horzion"]
    if pre_horizon > 0:
        generated_data = generated_data[pre_horizon:]

    value_net = generator_other_net
    if isinstance(value_net, Value_Net):
        value_state = generated_data
    else:
        value_state = get_concat_traj(generated_data, self.state_nodes)

    reward = generated_data.reward.detach()
    # The rollout is truncated at its last step: mark that step as terminal.
    masks = torch.zeros_like(reward)
    masks[-1] = 1
    value = value_net(value_state).detach()
    generated_data.advs, generated_data.returns = self.ADV(reward,
                                                           masks,
                                                           value,
                                                           gamma=self.config['gae_gamma'],
                                                           lam=self.config['gae_lambda'])

    num_minibatches = 8
    info_list = []
    for _ in range(ppo_runs):
        # Shuffle along dim 1 (presumably the trajectory/batch axis — confirm)
        # and split into minibatches.
        perm = torch.randperm(generated_data.shape[1]).to(self._device)
        for index in torch.chunk(perm, num_minibatches):
            info_list.append(self.PPO_step(generated_data[:, index],
                                           graph,
                                           value_net,
                                           generator_optimizer,
                                           matcher,
                                           other_generator_optimizers,
                                           epsilon=self.config['ppo_epsilon'],
                                           lam=self.config['ppo_l2norm_cof'],
                                           w_ent=self.config['ppo_entropy_cof'],
                                           matcher_index=matcher_index,
                                           scope=scope))

    # [ OTHER ] add more logs: average every logged metric over all minibatches.
    info = {f"{self.NAME}/{key}": np.mean([step_info[key] for step_info in info_list])
            for key in info_list[-1]}
    info[f"{self.NAME}/adv_max"] = torch.max(generated_data.advs).item()
    info[f"{self.NAME}/adv_min"] = torch.min(generated_data.advs).item()
    return info, generated_data
[docs]
def ADV(self, reward, mask, value, gamma, lam, use_gae=True):
    """
    Compute standardized advantages and return targets for PPO.

    The recursion runs backwards over dim 0 of ``reward`` (the time axis).

    :param reward: per-step rewards, time on dim 0
    :param mask: 1 where the trajectory ends at that step, else 0
    :param value: value estimate of each step, same layout as ``reward``
    :param gamma: discount factor
    :param lam: GAE lambda (only used when ``use_gae`` is True)
    :param use_gae: True -> GAE advantages with Monte-Carlo discounted returns;
        False -> one-step TD advantages with ``value + advantage`` returns
    :return: ``(advantages, returns)``; advantages are standardized as
        ``(adv - mean) / (std + 1e-4)``, returns are left unnormalized
        (they are the value-function target)
    """
    horizon = reward.shape[0]
    adv = torch.zeros_like(reward)
    if use_gae:
        delta = torch.zeros_like(reward)
        ret = torch.zeros_like(reward)
        next_value, next_adv, next_ret = 0, 0, 0
        for step in range(horizon - 1, -1, -1):
            alive = 1 - mask[step]
            delta[step] = reward[step] + gamma * next_value * alive - value[step]
            adv[step] = delta[step] + gamma * lam * next_adv * alive
            # [ TRAIN POLICY ] real discounted return as the return estimate
            ret[step] = reward[step] + gamma * alive * next_ret
            next_value, next_adv, next_ret = value[step], adv[step], ret[step]
    else:
        next_value = 0
        for step in range(horizon - 1, -1, -1):
            adv[step] = reward[step] + gamma * next_value * (1 - mask[step]) - value[step]
            next_value = value[step]
        ret = value + adv
    adv = (adv - adv.mean()) / (adv.std() + 1e-4)
    return adv, ret
[docs]
def PPO_step(self,
             generated_data,
             graph,
             value_net,
             generator_optimizers,
             matcher,
             other_generator_optimizers,
             epsilon=0.1,
             lam=0,
             w_ent=0,
             matcher_index=None,
             scope=None):
    """
    Train policies, transitions and the value net on one minibatch with PPO.

    The critic regresses ``value_net`` onto ``generated_data.returns``; the actor
    uses the clipped-ratio objective on ``generated_data.advs`` for every graph
    node that is both a ``network`` node and listed in ``matcher.matching_fit_nodes``.

    :param generated_data: generated trajectory minibatch; must carry ``returns``,
        ``advs``, and ``<node>`` / ``<node>_log_prob`` entries for the fitted nodes
    :param graph: decision graph
    :param value_net: value net
    :param generator_optimizers: optimizer over node models and value net; if
        None, losses are computed and logged but no update is applied
    :param matcher: matcher whose ``matching_fit_nodes`` selects the nodes to train
    :param other_generator_optimizers: not used here; kept for interface compatibility
    :param epsilon: hyperparameter for clipping in the policy objective
    :param lam: not used here (the l2 weight is ``config['generator_l2_norm_coef']``);
        kept for interface compatibility
    :param w_ent: the weight of entropy loss
    :param matcher_index: suffix appended to the logged metric names
    :param scope: must be "train" or "val"
    :return: dict of logs: ``<node>_entropy`` per fitted node, ``<metric>_<matcher_index>``
        loss/ratio/kl/log-prob statistics, and ``ppo_grad_norm`` when an optimizer is given
    :raises RuntimeError: if no fitted node is a network (nothing to optimize)
    """
    assert scope in ["train", "val"], f"scope : {scope}"
    info = {}

    # update critic
    value_state = generated_data  # get_concat_traj(generated_data, self.state_nodes)
    value_o = value_net(value_state.detach())
    v_loss = (value_o - generated_data.returns.detach()).pow(2).mean()
    if self.config['value_l2_norm_coef'] > 0:
        v_loss += self.net_l2_norm(value_net) * self.config['value_l2_norm_coef']

    # update actor: re-evaluate the stored actions under the current networks
    graph.reset()
    prob_olds = []
    prob_news = []
    e_loss = 0
    l2_norm_loss = 0
    stds = []
    matching_fit_nodes = matcher.matching_fit_nodes
    for node_name in graph.keys():
        # Skip not fit nodes
        if node_name not in matching_fit_nodes:
            continue
        node = graph.get_node(node_name)
        if node.node_type != 'network':
            continue
        action_dist = graph.compute_node(node_name, generated_data)
        policy_entropy = action_dist.entropy()
        stds.append(action_dist.std.mean().item())
        e_loss += policy_entropy.mean()
        l2_norm_loss += self.net_l2_norm(node.network)
        info[node_name + '_entropy'] = policy_entropy.mean().item()
        prob_olds.append(generated_data[node_name + '_log_prob'])
        prob_news.append(action_dist.log_prob(generated_data[node_name]).unsqueeze(dim=-1))
    if not prob_news:
        # Without this, torch.cat([]) / int.mean() below fail with obscure errors.
        raise RuntimeError(f"PPO_step: none of matching_fit_nodes {matching_fit_nodes} is a network node; nothing to optimize.")

    prob_old = torch.cat(prob_olds, dim=-1)
    prob_new = torch.cat(prob_news, dim=-1)
    # Clamp the log-ratio before exp() so the ratio cannot overflow.
    clipped_prob = torch.clamp_max(prob_new - prob_old.detach(), 40)
    ratio = torch.exp(clipped_prob)
    # [ OTHER ] more logs
    kl = (prob_old - prob_new).detach()
    surr1 = ratio * generated_data.advs
    surr2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * generated_data.advs
    p_loss = - torch.min(surr1, surr2).mean()
    total_loss = v_loss + p_loss - e_loss * w_ent + self.config['generator_l2_norm_coef'] * l2_norm_loss
    info.update({
        f"v_loss_{matcher_index}" : v_loss.mean().item(),
        f"p_loss_{matcher_index}" : p_loss.mean().item(),
        f"e_loss_{matcher_index}" : e_loss.mean().item(),
        f"total_loss_{matcher_index}": total_loss.mean().item(),
        f"ratio_{matcher_index}" : ratio.mean().item(),
        f"value_{matcher_index}" : value_o.mean().item(),
        f"return_{matcher_index}" : generated_data.returns.mean().item(),
        f"advantages_{matcher_index}" : generated_data.advs.mean().item(),
        f"std_{matcher_index}": np.mean(stds),
        f"kl_{matcher_index}": kl.mean().item(),
        f"max_kl_{matcher_index}": kl.max().item(),
        f"min_kl_{matcher_index}": kl.min().item(),
        f"max_logp_new_{matcher_index}": prob_new.max().item(),
        f"min_logp_new_{matcher_index}": prob_new.min().item(),
        f"max_logp_old_{matcher_index}": prob_old.max().item(),
        f"min_logp_old_{matcher_index}": prob_old.min().item(),
        f'mean_logp_old_{matcher_index}': prob_old.mean().item(),
        f"mean_logp_new_{matcher_index}": prob_new.mean().item()
    })

    if generator_optimizers is not None:
        generator_optimizers.zero_grad(set_to_none=True)
        total_loss.backward()
        policy_grad_norm = nn.utils.clip_grad_norm_(get_models_parameters(*graph.collect_models()), 100)
        # Skip the update on NaN *or* inf norms: clipping with an infinite norm
        # would still push non-finite values into the parameters.
        if not torch.isfinite(policy_grad_norm).all():
            self.nan_in_grad()
            logger.info(f'Detect nan/inf in gradient, skip this batch! (loss : {total_loss}, grad_norm : {policy_grad_norm})')
        else:
            generator_optimizers.step()
        info["ppo_grad_norm"] = policy_grad_norm.item()
    return info