"""
POLIXIR REVIVE, copyright (C) 2021-2023 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
from copy import deepcopy
import torch
import numpy as np
from loguru import logger
import matplotlib.pyplot as plt
from collections import deque
from torch import nn
from torch.nn import functional as F
from ray import train
from revive.utils.raysgd_utils import BATCH_SIZE, NUM_SAMPLES, AverageMeterCollection
from revive.computation.graph import *
from revive.computation.modules import *
from revive.utils.common_utils import *
from revive.data.dataset import data_creator
from revive.algo.venv.base import VenvOperator, catch_error
class ReviveOperator(VenvOperator):
NAME = "REVIVE"
    def matcher_model_creator(self, config, graph):
"""
Create matcher models.
:param config: configuration parameters
:return: all the models.
"""
networks = []
self.matching_nodes = graph.get_relation_node_names() if config['matching_nodes'] == 'auto' else config['matching_nodes']
if isinstance(self.matching_nodes[0], str):
self.matching_nodes_list = [self.matching_nodes, ]
else:
self.matching_nodes_list = self.matching_nodes
logger.info(f"matching_nodes_list : {self.matching_nodes_list}")
self.matching_fit_nodes_list = self.matching_nodes_list if ("matching_fit_nodes" not in config.keys() or config['matching_fit_nodes'] == 'auto') else config['matching_fit_nodes']
logger.info(f"matching_fit_nodes_list : {self.matching_fit_nodes_list}")
self.matching_nodes_fit_index_list = []
self.matcher_num = len(self.matching_nodes_list)
        for matching_nodes, matching_fit_nodes in zip(self.matching_nodes_list, self.matching_fit_nodes_list):
input_dim = 0
            # remove nodes that have no data from the matching set
for node in graph.nodata_node_names:
if node in matching_nodes:
logger.info(f"{matching_nodes}, {node}")
matching_nodes.remove(node)
matching_nodes_fit_index = {}
for node_name in matching_nodes:
                original_node_name = node_name
                if node_name.startswith("next_"):
                    node_name = node_name[5:]
                assert node_name in self._graph.fit.keys(), f"Node name {original_node_name} not in {self._graph.fit.keys()}"
                input_dim += np.sum(self._graph.fit[node_name])
                matching_nodes_fit_index[original_node_name] = self._graph.fit[node_name]
self.matching_nodes_fit_index_list.append(matching_nodes_fit_index)
logger.info(f"Matcher nodes: {matching_nodes}, Input dim : {input_dim}")
if config['matcher_type'] in ['mlp', 'res', 'transformer', 'ft_transformer']:
matcher_network = FeedForwardMatcher(input_dim,
config['matcher_hidden_features'],
config['matcher_hidden_layers'],
norm=config['matcher_normalization'],
hidden_activation=config['matcher_activation'],
backbone_type=config['matcher_type'])
matcher_network.matching_nodes = matching_nodes # input to matcher
matcher_network.matching_fit_nodes = matching_fit_nodes
networks.append(matcher_network)
elif config['matcher_type'] in ['gru', 'lstm']:
matcher_network = RecurrentMatcher(input_dim,
config['matcher_hidden_features'],
config['matcher_hidden_layers'],
backbone_type=config['matcher_type'],
bidirect=config['birnn'])
matcher_network.matching_nodes = matching_nodes
matcher_network.matching_fit_nodes = matching_fit_nodes
networks.append(matcher_network)
            elif config['matcher_type'] == 'hierarchical':
                raise DeprecationWarning('This may not work correctly due to registration of functions and leaf nodes')
                # unreachable legacy code, kept for reference:
                # dims = []
                # dims.append(total_dims['obs']['input'])
                # for policy_name in list(graph.keys())[:-1]:
                #     dims.append(total_dims[policy_name]['input'])
                # dims.append(total_dims['obs']['input'])
                # networks.append(HierarchicalMatcher(dims, config['matcher_hidden_features'], config['matcher_hidden_layers'],
                #                                     norm=config['matcher_normalization'], hidden_activation=config['matcher_activation']))
return networks
    def model_creator(self, config, graph):
        """Create all models: generator networks first, then matcher networks."""
matcher_networks = self.matcher_model_creator(config, graph)
generator_networks = self.generator_model_creator(config, graph)
return generator_networks + matcher_networks
    def data_creator(self, config: dict):
        """Create the data loaders for the behavior-cloning stage."""
config[BATCH_SIZE] = config['bc_batch_size']
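        # recurrent backbones must be behavior-cloned on whole trajectories;
        # all other backbones can be trained transition-by-transition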
if config['policy_backbone'] in ['lstm', 'gru', 'contextual_lstm', 'contextual_gru'] or config['transition_backbone'] in ['lstm', 'gru']:
return data_creator(config, training_is_sample=False, val_horizon=config['venv_rollout_horizon'], double=True)
else:
return data_creator(config, training_mode='transition', training_is_sample=False, val_horizon=config['venv_rollout_horizon'], double=True)
    def switch_data_loader(self):
        """Switch from the BC loaders to the adversarial-training loaders (revive_batch_size)."""
self.config[BATCH_SIZE] = self.config['revive_batch_size']
train_loader_train, val_loader_train, train_loader_val, val_loader_val = \
data_creator(self.config, training_horizon=self.config['venv_rollout_horizon'], val_horizon=self.config['venv_rollout_horizon'], double=True)
try:
self._train_loader_train = train.torch.prepare_data_loader(train_loader_train, move_to_device=False)
self._val_loader_train = train.torch.prepare_data_loader(val_loader_train, move_to_device=False)
self._train_loader_val = train.torch.prepare_data_loader(train_loader_val, move_to_device=False)
self._val_loader_val = train.torch.prepare_data_loader(val_loader_val, move_to_device=False)
        except Exception:
            # not running under Ray Train; fall back to the raw data loaders
self._train_loader_train = train_loader_train
self._val_loader_train = val_loader_train
self._train_loader_val = train_loader_val
self._val_loader_val = val_loader_val
@catch_error
def __init__(self, config):
super().__init__(config)
self._v_loss_list = []
self.adapt_stds = [None] * len(self._graph)
# self.state_nodes = self._graph.get_relation_node_names() if config['state_nodes'] == 'auto' else config['state_nodes']
self.state_nodes = self._graph.get_leaf() if config['state_nodes'] == 'auto' else config['state_nodes']
self.history_matcher_train = deque(maxlen=config["history_matcher_num"])
self.history_matcher_val = deque(maxlen=config["history_matcher_num"])
self.matcher_loss_length = config["matcher_loss_length"]
if self.matcher_loss_length:
self.matcher_loss_low = config["matcher_loss_low"]
self.matcher_loss_high = config["matcher_loss_high"]
self.history_matcher_loss_train = deque(maxlen=self.matcher_loss_length)
self.history_matcher_loss_val = deque(maxlen=self.matcher_loss_length)
self.stop_matcher_update = False
self.stop_generator_update_train = False
self.stop_generator_update_val = False
logger.info(f'Using {self.matching_nodes} as matching nodes!')
logger.info(f'Using {self.state_nodes} as state nodes!')
def _early_stop(self, info):
        # early stop if the last 50 v_losses are all greater than 1e4
        if 'REVIVE/v_loss_train' in info.keys():
            self._v_loss_list.append(max(info['REVIVE/v_loss_train'], info['REVIVE/v_loss_val']))
            if len(self._v_loss_list) >= 50:
                _last_v_losses = np.array(self._v_loss_list[-50:])
                if np.all(_last_v_losses > 1e4):
                    logger.info('Early stop triggered by value explosion!')
                    self._stop_flag = True
return super()._early_stop(info)
    def bc_train_batch(self, expert_data, batch_info, scope='train', loss_type="nll"):
        """
        Run one behavior-cloning update on a batch.

        :param expert_data: batch of expert transitions or trajectories
        :param batch_info: auxiliary batch information
        :param scope: 'train' to update the train models, otherwise the val models
        :param loss_type: 'nll', 'mae', 'mse', or 'user_module.<fn>' (overridable per node)
        :return: logging info for the batch.
        """
self._batch_cnt += 1
if scope == 'train':
models = self.train_models
graph = self.graph_train
optimizer = self.train_optimizers[0]
else:
models = self.val_models
graph = self.graph_val
optimizer = self.val_optimizers[0]
graph.reset()
expert_data.to_torch(device=self._device)
info = {}
loss = 0
_generated_data = Batch()
        rnn_flag = self.config['policy_backbone'] in ['gru', 'lstm'] or self.config['transition_backbone'] in ['gru', 'lstm']
if rnn_flag:
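            # reparameterized sampling keeps the rollout differentiable for the BC loss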
sample_fn = lambda dist: dist.rsample()
generated_data = generate_rollout_bc(expert_data, graph, expert_data.shape[0], sample_fn, self.adapt_stds, clip=1.5)
for node_name in graph.keys():
node = graph.get_node(node_name)
if node.node_type == 'network':
isnan_index_list = []
isnan_index = 1.
# check whether nan is in inputs
for node_name_ in node.input_names:
if (node_name_ + "_isnan_index_") in expert_data.keys():
isnan_index_list.append(expert_data[node_name_ + "_isnan_index_"])
# check whether nan is in outputs
if (node_name + "_isnan_index_") in expert_data.keys():
isnan_index_list.append(expert_data[node_name + "_isnan_index_"])
if isnan_index_list:
isnan_index, _ = torch.max(torch.cat(isnan_index_list, axis=-1), axis=-1, keepdim=True)
isnan_index = 1 - isnan_index
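                    # after the inversion above, isnan_index is 1 for fully-observed
                    # entries and 0 wherever any input/output feature was NaN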
if not rnn_flag:
action_dist = graph.compute_node(node_name, expert_data)
# use rollout data as expert data for nodata nodes
if node_name in self._graph.nodata_node_names:
expert_data[node_name] = action_dist.mode
continue
_loss_type = graph.nodes_loss_type.get(node_name, loss_type)
if _loss_type == "mae":
if rnn_flag:
policy_loss = ((generated_data[node_name] - expert_data[node_name])*isnan_index).abs().sum(dim=-1).mean()
else:
policy_loss = ((action_dist.mode - expert_data[node_name])*isnan_index).abs().sum(dim=-1).mean()
elif _loss_type == "mse":
if rnn_flag:
policy_loss = (((generated_data[node_name] - expert_data[node_name])*isnan_index)**2).sum(dim=-1).mean()
else:
policy_loss = (((action_dist.mode - expert_data[node_name])*isnan_index)**2).sum(dim=-1).mean()
elif _loss_type == "nll":
if rnn_flag:
_generated_data_log_prob = []
for i in range(expert_data.shape[0]):
_generated_data_log_prob.append(generated_data[node_name + '_dist' + f"_{i}"][i].log_prob(expert_data[node_name][i, ...]).unsqueeze(0))
_generated_data_log_prob = torch.vstack(_generated_data_log_prob).to(expert_data[node_name])
if isnan_index_list:
isnan_index = (torch.mean(isnan_index, dim=-1) > 0).type_as(isnan_index)
policy_loss = - (_generated_data_log_prob*isnan_index).mean()
else:
_generated_data[node_name + '_log_prob'] = action_dist.log_prob(expert_data[node_name])
_generated_data[node_name] = action_dist.sample()
if isnan_index_list:
isnan_index = (torch.mean(isnan_index, dim=-1) > 0).type_as(isnan_index)
policy_loss = - (_generated_data[node_name + '_log_prob']*isnan_index).mean()
elif _loss_type.startswith("user_module."):
loss_name = _loss_type[len("user_module."):]
loss_function = self.config["user_module"].get(loss_name, None)
                    assert loss_function is not None, f"user loss function '{loss_name}' not found in user_module"
kwargs = {
"node_dist" : action_dist,
"node_name" : node_name,
"isnan_index_list" : isnan_index_list,
"isnan_index" : isnan_index,
"graph" : graph,
"expert_data" : expert_data,
}
policy_loss = loss_function(kwargs)
else:
raise NotImplementedError
loss += policy_loss
info[f"{self.NAME}/{node_name}_loss_{scope}"] = policy_loss.item()
info[f"{self.NAME}/total_loss_{scope}"] = loss.item()
optimizer.zero_grad(set_to_none=True)
loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(get_models_parameters(*models), 50)
if torch.any(torch.isnan(grad_norm)):
self.nan_in_grad()
logger.info(f'Detect nan in gradient, skip this batch! (loss : {loss}, grad_norm : {grad_norm})')
else:
optimizer.step()
info[f"{self.NAME}/grad_norm"] = grad_norm.item()
return info
@catch_error
def train_epoch(self):
info = {}
        # switch to training mode
if hasattr(self, "model"):
self.model.train()
if hasattr(self, "models"):
for _model in self.models:
_model.train()
if self._epoch_cnt == self.config['bc_epoch']:
self.switch_data_loader()
self._load_best_models()
self._epoch_cnt += 1
if self._epoch_cnt <= self.config['bc_epoch']:
# perform bc
metric_meters_train = AverageMeterCollection()
for batch_idx, batch in enumerate(iter(self._train_loader_train)):
batch_info = {
"batch_idx": batch_idx,
"global_step": self.global_step
}
batch_info.update(info)
metrics = self.bc_train_batch(batch, batch_info=batch_info, scope='train')
metric_meters_train.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
self.global_step += 1
metric_meters_val = AverageMeterCollection()
for batch_idx, batch in enumerate(iter(self._val_loader_train)):
batch_info = {
"batch_idx": batch_idx,
"global_step": self.global_step
}
batch_info.update(info)
metrics = self.bc_train_batch(batch, batch_info=batch_info, scope='val')
metric_meters_val.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
self.global_step += 1
info = metric_meters_train.summary()
info.update(metric_meters_val.summary())
return {k : info[k] for k in filter(lambda k: not k.startswith('last'), info.keys())}
else:
if hasattr(self, "model"):
self.model.train()
if hasattr(self, "models"):
for _model in self.models:
_model.train()
# training on training set
graph = self.graph_train
generator_other_nets = self.other_models_train[:-self.matcher_num] # value net(s)
matchers = self.other_models_train[-self.matcher_num:]
generator_optimizers = self.train_optimizers[1]
other_generator_optimizers = self.train_optimizers[2] # value net optim(s)
matcher_optimizer = self.train_optimizers[-1]
self.scope = "train"
            # [ TRAINING D ] extra matcher updates to avoid an oscillating matcher loss
for _ in range(10):
for i in range(self.config['d_steps']):
self.global_step += 1
expert_data = next(iter(self._train_loader_train))
expert_data.to_torch(device=self._device)
_, _info = self._run_matcher(expert_data, graph, matchers, matcher_optimizer, test=not self.config["matcher_sample"])
_info = {k + '_train' : v for k, v in _info.items()}
info.update(_info)
d_loss = info[f'{self.NAME}/matcher_loss_train']
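                # 1.35 is just below 2*ln(2) (about 1.386), the matcher's BCE loss at
                # chance level; presumably the extra D updates stop once the matcher
                # no longer separates real from generated data decisively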
if d_loss <= 1.35:
break
if self._epoch_cnt % self.config["history_matcher_save_epochs"] == 0:
logger.info(f"Epoch: {self._epoch_cnt}, Save train matcher in 'history_matcher_train'.")
self.history_matcher_train.append(deepcopy(matchers))
if self._epoch_cnt > self.config['bc_epoch'] + self.config['matcher_pretrain_epoch']:
for i in range(self.config['g_steps']):
self.global_step += 1
expert_data = next(iter(self._train_loader_train))
expert_data.to_torch(device=self._device)
for matcher_index in range(len(matchers)):
_info, generated_data = self._run_generator(deepcopy(expert_data), graph, generator_other_nets[matcher_index], matchers[matcher_index], generator_optimizers, other_generator_optimizers, matcher_index=matcher_index)
_info = {k + '_train' : v for k, v in _info.items()}
info.update(_info)
if self.config["fintune"] >= 1 and self._epoch_cnt%self.config['finetune_fre']==0:
                rnn_flag = self.config['policy_backbone'] in ['gru', 'lstm'] or self.config['transition_backbone'] in ['gru', 'lstm']
# print(f'BC is activated at epoch: {self._epoch_cnt}')
if not rnn_flag:
for _ in range(int(self.config["fintune"])):
expert_data = Batch({k:v.reshape(-1, v.shape[-1]) for k,v in expert_data.items()})
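                        # NOTE: max(min(n // 256, 1), 128) always evaluates to 128, so
                        # batch_nums is len(expert_data) // 128 here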
batch_nums = len(expert_data) // max(min(len(expert_data) // 256,1),128)
for batch_data in expert_data.split(batch_nums):
self.bc_train_batch(batch_data, {}, "train")
else:
for _ in range(int(self.config["fintune"])):
batch_nums = max(expert_data.shape[1] // max(min(expert_data.shape[1] // 256, 1), 128), 1)
_expert_data = dict()
for k,v in expert_data.items():
_expert_data[k] = torch.chunk(v, batch_nums, dim=1) # [tensor_1, tensor_2, ..., tensor_128]
iter_num = len(_expert_data[k])
for idx in range(iter_num):
self.bc_train_batch(Batch({k:v[idx] for k,v in _expert_data.items()}), {}, "train")
            # training on the validation set
graph = self.graph_val
generator_other_nets = self.other_models_val[:-self.matcher_num]
matchers = self.other_models_val[-self.matcher_num:]
generator_optimizers = self.val_optimizers[1]
other_generator_optimizers = self.val_optimizers[2]
matcher_optimizer = self.val_optimizers[-1]
self.scope = "val"
            # [ TRAINING D ] extra matcher updates to avoid an oscillating matcher loss
for _ in range(10):
for i in range(self.config['d_steps']):
self.global_step += 1
expert_data = next(iter(self._val_loader_train))
expert_data.to_torch(device=self._device)
_, _info = self._run_matcher(expert_data, graph, matchers, matcher_optimizer, test=not self.config["matcher_sample"])
_info = {k + '_val' : v for k, v in _info.items()}
info.update(_info)
d_loss = info[f'{self.NAME}/matcher_loss_val']
if d_loss <= 1.35:
break
if self._epoch_cnt % self.config["history_matcher_save_epochs"] == 0:
logger.info(f"Epoch: {self._epoch_cnt}, Save val matcher in 'history_matcher_train'.")
self.history_matcher_val.append(deepcopy(matchers))
if self._epoch_cnt > self.config['bc_epoch'] + self.config['matcher_pretrain_epoch']:
for i in range(self.config['g_steps']):
self.global_step += 1
expert_data = next(iter(self._val_loader_train))
expert_data.to_torch(device=self._device)
for matcher_index in range(len(matchers)):
_info, generated_data = self._run_generator(deepcopy(expert_data), graph, generator_other_nets[matcher_index], matchers[matcher_index], generator_optimizers, other_generator_optimizers, matcher_index=matcher_index)
_info = {k + '_val' : v for k, v in _info.items()}
info.update(_info)
if self.config["fintune"] >= 1 and self._epoch_cnt%self.config['finetune_fre']==0:
                rnn_flag = self.config['policy_backbone'] in ['gru', 'lstm'] or self.config['transition_backbone'] in ['gru', 'lstm']
if not rnn_flag:
for _ in range(int(self.config["fintune"])):
expert_data = Batch({k:v.reshape(-1, v.shape[-1]) for k,v in expert_data.items()})
batch_nums = len(expert_data) // max(min(len(expert_data) // 256,1),128)
for batch_data in expert_data.split(batch_nums):
self.bc_train_batch(batch_data, {}, "val")
else:
for _ in range(int(self.config["fintune"])):
batch_nums = max(expert_data.shape[1] // max(min(expert_data.shape[1] // 256, 1), 128), 1)
_expert_data = dict()
for k,v in expert_data.items():
_expert_data[k] = torch.chunk(v, batch_nums, dim=1) # [tensor_1, tensor_2, ..., tensor_128]
iter_num = len(_expert_data[k])
for idx in range(iter_num):
self.bc_train_batch(Batch({k:v[idx] for k,v in _expert_data.items()}), {}, "val")
            # NOTE: std adaptation does not currently support GMM distributions
if self.config['std_adapt_strategy'] is not None:
with torch.no_grad():
for i, node_name in enumerate(self._graph.keys()):
error = expert_data[node_name] - generated_data[node_name]
if self.config['std_adapt_strategy'] == 'mean':
std = error.abs().view(-1, error.shape[-1]).mean(0)
elif self.config['std_adapt_strategy'] == 'max':
std, _ = error.abs().view(-1, error.shape[-1]).max(0)
std /= 3
self.adapt_stds[i] = std
info = self._early_stop(info)
if self._epoch_cnt >= self.config['revive_epoch'] + self.config['bc_epoch']:
self._stop_flag = True
info["stop_flag"] = self._stop_flag
del expert_data, generated_data
return {k : info[k] for k in filter(lambda k: not k.startswith('last'), info.keys())}
def _get_matcher_score(self, batch_data, _matcher, score_mode="mean"):
"""
Args:
matcher: always a single matcher here
Returns:
score (torch.Tensor):
"""
matching_nodes = _matcher.matching_nodes
matching_nodes_index = self.matching_nodes_list.index(matching_nodes)
matching_nodes_fit_index = self.matching_nodes_fit_index_list[matching_nodes_index]
matcher_input = get_list_traj(batch_data, matching_nodes, matching_nodes_fit_index)
score = _matcher(*matcher_input)
"""
history_matcher = self.history_matcher_train if self.scope == "train" else self.history_matcher_val
for matcher in history_matcher:
scores.append(matcher(*matcher_input))
if use_by_generator:
history_matcher = self.history_matcher_train if self.scope == "train" else self.history_matcher_val
for matcher in history_matcher:
scores.append(matcher(*matcher_input))
"""
"""
score = torch.mean(torch.cat(scores[:1], axis=-1), dim=-1, keepdim=True)
if score_mode in "mean":
score_total = torch.mean(torch.cat(scores, axis=-1), dim=-1, keepdim=True)
elif score_mode in "min":
score_total,_ = torch.min(torch.cat(scores, axis=-1), dim=-1, keepdim=True)
elif score_mode in "trj_min":
# T B N
score_total_cat = torch.cat(scores, axis=-1)
score_total_trj_sum_min_index = torch.min(torch.sum(score_total_cat, dim=0, keepdim=True),dim=-1, keepdim=True)
index = torch.arange(0,score_total_cat.shape[-2], device=score_total_cat.device) * score_total_cat.shape[-1] + score_total_trj_sum_min_index[1].reshape(-1)
score_total = score_total_cat.reshape(score_total_cat.shape[0],-1)
score_total = score_total[:,index]
score_total = score_total.reshape(score_total_cat.shape[0],score_total_cat.shape[1],1)
else:
raise NotImplementedError
"""
return score
def _generate_rewards(self, generated_data, expert_data, graph, matcher):
def reward_fn(data, matcher, graph):
# generate reward with matcher
score = self._get_matcher_score(data, matcher)
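            # GAN-style reward: -log(1 - D) grows as the generated data fools the matcher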
matcher_reward = - torch.log(1 - score + 1e-4)
matcher_reward = (matcher_reward - matcher_reward.mean()) / (matcher_reward.std() + 1e-4) # normalize reward
# generate reward with mae
shooting_mae = []
for node_name in graph.keys():
if node_name in graph.metric_nodes:
policy_shooting_error = torch.abs(expert_data[node_name] - generated_data[node_name])
shooting_mae.append(policy_shooting_error)
shooting_mae = torch.mean(torch.cat(shooting_mae, axis=-1),dim=-1,keepdim=True)
mae_reward = - shooting_mae
mae_reward = (mae_reward - mae_reward.mean()) / (mae_reward.std() + 1e-4) # normalize reward
reward = matcher_reward*(1-self.config["mae_reward_weight"]) + mae_reward*self.config["mae_reward_weight"]
if self.config['rule_reward_func']:
if (not self.config['rule_reward_matching_nodes']) or (set(matcher.matching_nodes) == set(self.config['rule_reward_matching_nodes'])):
rule_reward = self.config['rule_reward_func'](generated_data, graph)
assert rule_reward.shape == matcher_reward.shape
if self.config['rule_reward_func_normalize']:
rule_reward = (rule_reward - rule_reward.mean()) / (rule_reward.std() + 1e-4)
reward += (rule_reward * self.config['rule_reward_func_weight'])
return reward.detach()
return generate_rewards(generated_data, reward_fn=lambda data: reward_fn(data, matcher, graph))
def _generate_rollout(self, expert_data, graph, matcher, test=False, generate_reward=False, sample_fn=None, clip=False):
# generate rollout
if sample_fn is None:
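            # use the distribution mode when testing; rsample() keeps gradients
            # flowing for the SVG generator, plain sample() otherwise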
sample_fn = lambda dist: dist.mode if test else (dist.rsample() if self.config['generator_algo'] == 'svg' else dist.sample())
generated_data = generate_rollout(expert_data, graph, expert_data.shape[0], sample_fn, self.adapt_stds, clip)
if generate_reward:
generated_data = self._generate_rewards(generated_data, expert_data, graph, matcher)
return generated_data
def _run_matcher(self, expert_data, graph, matchers, matcher_optimizer=None, test=False):
# generate rollout
        # [ TRAINING D ] use the same rollout settings as generator training so the
        # matcher and the generator see identically produced data
generated_data = self._generate_rollout(expert_data, graph, matchers, test=test, clip=1.5)
isnan_index_list = []
for matching_nodes in self.matching_nodes_list:
for node_name in matching_nodes:
if node_name + "_isnan_index_" in expert_data.keys():
isnan_index_list.append(expert_data[node_name + "_isnan_index_"])
if isnan_index_list:
isnan_index = torch.mean(torch.cat(isnan_index_list,axis=-1),axis=-1)
expert_data = expert_data[isnan_index==0]
generated_data = generated_data[isnan_index==0]
# compute matcher score
expert_scores, generated_scores, matcher_losses = [], [], []
matcher_dict = dict()
for matcher_index in range(len(matchers)):
expert_score = self._get_matcher_score(expert_data, matchers[matcher_index])
generated_score = self._get_matcher_score(generated_data, matchers[matcher_index])
expert_scores.append(deepcopy(expert_score.detach()))
generated_scores.append(deepcopy(generated_score.detach()))
real = torch.ones_like(expert_score)
fake = torch.zeros_like(generated_score)
expert_entropy = torch.distributions.Bernoulli(expert_score).entropy().mean()
generated_entropy = torch.distributions.Bernoulli(generated_score).entropy().mean()
entropy_loss = - 0.5 * (expert_entropy + generated_entropy)
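            # NOTE: entropy_loss is computed here but is not currently added to the
            # matcher objective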
matcher_losses.append(F.binary_cross_entropy(expert_score, real) + F.binary_cross_entropy(generated_score, fake))
matcher_dict.update({
f"{self.NAME}/expert_score_{matcher_index}": expert_score.mean().item(),
f"{self.NAME}/generated_score_{matcher_index}": generated_score.mean().item(),
# f"{self.NAME}/expert_score_total": expert_score_total.mean().item(),
# f"{self.NAME}/generated_score_total": generated_score_total.mean().item(),
f"{self.NAME}/expert_acc_{matcher_index}": (expert_score > 0.5).float().mean().item(),
f"{self.NAME}/generated_acc_{matcher_index}": (generated_score < 0.5).float().mean().item(),
f"{self.NAME}/expert_entropy_{matcher_index}": expert_entropy.item(),
f"{self.NAME}/generated_entropy_{matcher_index}": generated_entropy.item(),
f"{self.NAME}/matcher_loss_{matcher_index}": matcher_losses[-1].item(),
# f"{self.NAME}/matcher_loss_total": matcher_loss_total.item(),
})
expert_score = sum(expert_scores) / len(expert_scores)
generated_score = sum(generated_scores) / len(generated_scores)
matcher_loss = sum(matcher_losses) / len(matcher_losses)
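        # adversarial balancing over a moving window of matcher losses: freeze the
        # matcher when it wins too easily (loss below matcher_loss_low) and freeze
        # the generator when the matcher is losing (loss above matcher_loss_high)
        # so it can catch up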
if self.matcher_loss_length:
if self.scope == "train":
self.history_matcher_loss_train.append(matcher_loss.item())
if np.mean(list(self.history_matcher_loss_train)) < self.matcher_loss_low:
self.stop_matcher_update = True
else:
self.stop_matcher_update = False
if np.mean(list(self.history_matcher_loss_train)) > self.matcher_loss_high:
self.stop_generator_update_train = True
else:
self.stop_generator_update_train = False
else:
self.history_matcher_loss_val.append(matcher_loss.item())
if np.mean(list(self.history_matcher_loss_val)) < self.matcher_loss_low:
self.stop_matcher_update = True
else:
self.stop_matcher_update = False
if np.mean(list(self.history_matcher_loss_val)) > self.matcher_loss_high:
self.stop_generator_update_val = True
else:
self.stop_generator_update_val = False
if matcher_optimizer is not None and not self.stop_matcher_update:
matcher_optimizer.zero_grad(set_to_none=True)
matcher_loss.backward()
matcher_grad_norm = nn.utils.clip_grad_norm_(get_models_parameters(*matchers), 0.5)
if torch.any(torch.isnan(matcher_grad_norm)):
self.nan_in_grad()
logger.info(f'Detect nan in gradient, skip this batch! (loss : {matcher_loss}, grad_norm : {matcher_grad_norm})')
else:
matcher_optimizer.step()
# with torch.no_grad():
# matcher_loss_total = F.binary_cross_entropy(expert_score_total, real) + F.binary_cross_entropy(generated_score_total, fake)
info = {
f"{self.NAME}/expert_score": expert_score.mean().item(),
f"{self.NAME}/generated_score": generated_score.mean().item(),
# f"{self.NAME}/expert_score_total": expert_score_total.mean().item(),
# f"{self.NAME}/generated_score_total": generated_score_total.mean().item(),
f"{self.NAME}/expert_acc": (expert_score > 0.5).float().mean().item(),
f"{self.NAME}/generated_acc": (generated_score < 0.5).float().mean().item(),
f"{self.NAME}/expert_entropy": expert_entropy.item(),
f"{self.NAME}/generated_entropy": generated_entropy.item(),
f"{self.NAME}/matcher_loss": matcher_loss.item(),
# f"{self.NAME}/matcher_loss_total": matcher_loss_total.item(),
}
if matcher_optimizer is not None and not self.stop_matcher_update:
info[f"{self.NAME}/matcher_grad_norm"] = matcher_grad_norm.item()
info.update(matcher_dict)
return generated_data, info