"""
POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
from copy import deepcopy
from concurrent import futures
import torch
import numpy as np
from ray import train
from loguru import logger
from torch import nn
from torch.functional import F
from revive.computation.graph import *
from revive.computation.modules import *
from revive.computation.dists import all_equal
from revive.utils.common_utils import *
from revive.data.dataset import data_creator
from revive.algo.venv.base import VenvOperator, catch_error
from revive.utils.raysgd_utils import BATCH_SIZE, NUM_SAMPLES, AverageMeterCollection
class ReviveOperator(VenvOperator):
NAME = "REVIVE"
def matcher_model_creator(self, config, graph):
    """Create one matcher (discriminator) network per matching-node group.

    Also caches the grouping bookkeeping on ``self`` (``matching_nodes_list``,
    ``matching_fit_nodes_list``, ``matching_nodes_fit_index_list``,
    ``matcher_num``) for later score computation.

    :param config: configuration dict; relevant keys are ``matching_nodes``,
        ``matching_fit_nodes``, ``matcher_type`` and the ``matcher_*``
        architecture parameters.
    :param graph: decision-flow graph describing node relations.
    :return: list of matcher networks, one per matching-node group.
    """
    networks = []
    self.matching_nodes = graph.get_relation_node_names() if config['matching_nodes'] == 'auto' else config['matching_nodes']
    if isinstance(self.matching_nodes[0], str):
        # A single flat list of node names -> wrap into a one-group list.
        self.matching_nodes_list = [self.matching_nodes, ]
    else:
        self.matching_nodes_list = self.matching_nodes
    logger.info(f"matching_nodes_list : {self.matching_nodes_list}")
    self.matching_fit_nodes_list = self.matching_nodes_list if ("matching_fit_nodes" not in config.keys() or config['matching_fit_nodes'] == 'auto') else config['matching_fit_nodes']
    logger.info(f"matching_fit_nodes_list : {self.matching_fit_nodes_list}")
    self.matching_nodes_fit_index_list = []
    self.matcher_num = len(self.matching_nodes_list)
    for matching_nodes, matching_fit_nodes in zip(self.matching_nodes_list, self.matching_fit_nodes_list):
        input_dim = 0
        # Drop nodes that have no data in the dataset.
        for node in graph.nodata_node_names:
            if node in matching_nodes:
                logger.info(f"{matching_nodes}, {node}")
                matching_nodes.remove(node)
        matching_nodes_fit_index = {}
        for node_name in matching_nodes:
            oral_node_name = node_name
            if node_name.startswith("next_"):
                # "next_X" shares the dimension bookkeeping of node "X".
                node_name = node_name[5:]
            assert node_name in self._graph.fit.keys(), f"Node name {oral_node_name} not in {self._graph.fit.keys()}"
            input_dim += np.sum(self._graph.fit[node_name])
            matching_nodes_fit_index[oral_node_name] = self._graph.fit[node_name]
        self.matching_nodes_fit_index_list.append(matching_nodes_fit_index)
        logger.info(f"Matcher nodes: {matching_nodes}, Input dim : {input_dim}")
        if config['matcher_type'] in ['mlp', 'res', 'transformer', 'ft_transformer']:
            matcher_network = FeedForwardMatcher(input_dim,
                                                 config['matcher_hidden_features'],
                                                 config['matcher_hidden_layers'],
                                                 norm=config['matcher_normalization'],
                                                 hidden_activation=config['matcher_activation'],
                                                 backbone_type=config['matcher_type'])
            matcher_network.matching_nodes = matching_nodes  # input to matcher
            matcher_network.matching_fit_nodes = matching_fit_nodes
            networks.append(matcher_network)
        elif config['matcher_type'] in ['gru', 'lstm']:
            matcher_network = RecurrentMatcher(input_dim,
                                               config['matcher_hidden_features'],
                                               config['matcher_hidden_layers'],
                                               backbone_type=config['matcher_type'],
                                               bidirect=config['birnn'])
            matcher_network.matching_nodes = matching_nodes
            matcher_network.matching_fit_nodes = matching_fit_nodes
            networks.append(matcher_network)
        elif config['matcher_type'] == 'hierarchical':
            # The hierarchical matcher is no longer supported. The legacy
            # construction code that used to follow this raise was
            # unreachable (and referenced an undefined `total_dims`), so it
            # has been removed.
            raise DeprecationWarning('This may not correctly work due to registration of functions and leaf nodes')
        else:
            # Previously unknown types fell through silently, yielding no
            # matcher network for this group; fail loudly instead.
            raise ValueError(f"Unknown matcher_type: {config['matcher_type']}")
    return networks
def model_creator(self, config, graph):
    """Build all venv models: generators first, matchers appended at the end.

    The ordering matters: downstream code slices the matchers off the tail
    of this list using ``self.matcher_num`` (set by matcher_model_creator,
    which is therefore invoked first).
    """
    matchers = self.matcher_model_creator(config, graph)
    generators = self.generator_model_creator(config, graph)
    return generators + matchers
def data_creator(self):
    """Build the behaviour-cloning data loaders.

    Sets the batch size to the BC batch size, then delegates to the
    module-level ``data_creator`` in trajectory or transition mode
    depending on ``venv_train_dataset_mode``.
    """
    self.config[BATCH_SIZE] = self.config['bc_batch_size']
    horizon = self.config['venv_rollout_horizon']
    kwargs = dict(training_is_sample=False, val_horizon=horizon, double=True)
    if self.config["venv_train_dataset_mode"] == "trajectory":
        kwargs.update(training_mode="trajectory", training_horizon=horizon)
    else:
        kwargs.update(training_mode='transition')
    return data_creator(self.config, **kwargs)
def switch_mail_data_loader(self, mode="revive"):
    """Swap the BC data loaders for the larger-batch REVIVE loaders.

    :param mode: kept for interface compatibility; loaders built here
        always use the revive batch size and trajectory mode.
    """
    logger.info("Switch bc dataloader to revive dataloader.")
    self.config[BATCH_SIZE] = self.config['revive_batch_size']
    train_loader_train, val_loader_train, train_loader_val, val_loader_val = \
        data_creator(self.config,
                     training_mode="trajectory",
                     training_horizon=self.config['venv_rollout_horizon'],
                     val_horizon=self.config['venv_rollout_horizon'],
                     double=True)
    try:
        # Wrap the loaders for distributed training when a ray train
        # session is active.
        self._train_loader_train = train.torch.prepare_data_loader(train_loader_train, move_to_device=False)
        self._val_loader_train = train.torch.prepare_data_loader(val_loader_train, move_to_device=False)
        self._train_loader_val = train.torch.prepare_data_loader(train_loader_val, move_to_device=False)
        self._val_loader_val = train.torch.prepare_data_loader(val_loader_val, move_to_device=False)
    except Exception:
        # Not running under ray train (or preparation failed): fall back to
        # the raw loaders. The original bare `except:` also swallowed
        # KeyboardInterrupt/SystemExit, hence the narrowing.
        self._train_loader_train = train_loader_train
        self._val_loader_train = val_loader_train
        self._train_loader_val = train_loader_val
        self._val_loader_val = val_loader_val
@catch_error
def __init__(self, config):
    """Initialize the REVIVE venv operator.

    :param config: training configuration dict.
    """
    super().__init__(config)
    # History of max(train, val) value-net losses, used for early stopping.
    self._v_loss_list = []
    # Per-node std overrides used when sampling rollouts; None means
    # "use the distribution's own std".
    self.adapt_stds = [None] * len(self._graph)
    self.state_nodes = self._graph.get_leaf() if config['state_nodes'] == 'auto' else config['state_nodes']
    # Total schedule = BC pretraining epochs + adversarial (revive) epochs.
    self._total_epoch = self.config['revive_epoch'] + self.config['bc_epoch']
    # NOTE(review): self.matching_nodes is set in matcher_model_creator(),
    # presumably reached through super().__init__() — verify.
    logger.info(f'Using {self.matching_nodes} as matching nodes!')
    logger.info(f'Using {self.state_nodes} as state nodes!')
def _early_stop(self, info):
    """Trigger early stop when the value loss has exploded.

    Stops training when each of the most recent 50 recorded value losses
    (max of train and val) exceeds 1e4, then defers to the base-class
    early-stop logic.

    :param info: metric dict for the current epoch.
    :return: the info returned by the base implementation.
    """
    # early stop if last 50 v_loss all greater than 1e4
    if 'REVIVE/v_loss_train' in info.keys():
        self._v_loss_list.append(max(info['REVIVE/v_loss_train'], info['REVIVE/v_loss_val']))
        # Require a full window of 50 samples: the original applied np.all
        # to a shorter slice, so a single exploded loss right at the start
        # of training could stop it immediately.
        if len(self._v_loss_list) >= 50 and np.all(np.array(self._v_loss_list[-50:]) > 1e4):
            logger.info('Early stop triggered by value explosion!')
            self._stop_flag = True
    return super()._early_stop(info)
def bc_train_batch(self, expert_data, batch_info, scope='train', loss_type=None, dataset_mode=None, loss_mask=None):
    """Run one behaviour-cloning batch and update the generator networks.

    :param expert_data: batch of expert transitions/trajectories.
    :param batch_info: auxiliary batch metadata (not used directly here).
    :param scope: 'train' updates the train-side models; anything else
        updates the val-side models (double-model training).
    :param loss_type: BC loss override ("nll", "mae", "mse",
        "gaussian_loss" or "user_module.<name>"); defaults to
        config["bc_loss_type"].
    :param dataset_mode: "trajectory" or "transition"; defaults to
        config["venv_train_dataset_mode"].
    :param loss_mask: optional precomputed mask; when None a mask is
        derived from the *_isnan_index_ columns.
    :return: dict of per-node and total BC metrics for this batch.
    """
    self._batch_cnt += 1
    if loss_type is None:
        loss_type = self.config.get("bc_loss_type", "nll")
    # Pick the model set / graph / optimizer for the requested scope.
    if scope == 'train':
        models = self.train_models
        graph = self.graph_train
        optimizer = self.train_optimizers[0]
    else:
        models = self.val_models
        graph = self.graph_val
        optimizer = self.val_optimizers[0]
    graph.reset()
    info = {}
    loss = 0
    logstd_loss_flag = True
    if not dataset_mode:
        dataset_mode = self.config["venv_train_dataset_mode"]
    if dataset_mode == "trajectory":
        # Roll the full trajectory through the graph once, keeping the
        # per-node distributions for the NLL loss below.
        sample_fn = lambda dist: dist.mode
        generated_data, dist_dict = generate_rollout(expert_data,
                                                     graph,
                                                     expert_data.shape[0],
                                                     sample_fn,
                                                     self.adapt_stds,
                                                     clip=1.5,
                                                     return_dist=True,
                                                     mode="train")
        if self.config["pre_horzion"] > 0:
            # Drop the warm-up prefix ("pre_horzion", sic — config key name).
            expert_data = expert_data[self.config["pre_horzion"]:]
            generated_data = generated_data[self.config["pre_horzion"]:]
    for node_name in graph.keys():
        node = graph.get_node(node_name)
        if node.node_type == 'network':
            # Missing value processing by isnan_index
            isnan_index_list = []
            isnan_index = 1.
            # check whether nan is in inputs
            for node_name_ in node.input_names:
                if (node_name_ + "_isnan_index_") in expert_data.keys():
                    isnan_index_list.append(expert_data[node_name_ + "_isnan_index_"])
            # check whether nan is in outputs
            if (node_name + "_isnan_index_") in expert_data.keys():
                isnan_index_list.append(expert_data[node_name + "_isnan_index_"])
            if loss_mask is not None:
                # Caller-supplied mask takes precedence.
                pass
            else:
                # NOTE(review): loss_mask computed here for the first network
                # node is reused for every subsequent node of this batch —
                # confirm this sharing is intended.
                if isnan_index_list:
                    isnan_index, _ = torch.max(torch.cat(isnan_index_list, axis=-1), axis=-1, keepdim=True)
                    loss_mask = 1 - isnan_index
                else:
                    loss_mask = 1.
            if dataset_mode == "trajectory":
                action_dist = dist_dict[node_name]
                action = generated_data[node_name]
            else:
                action_dist = graph.compute_node(node_name, expert_data)
                action = action_dist.mode
            # Losses other than NLL assume a single Normal head; otherwise
            # fall back to NLL and skip the logstd regularizer.
            if not (all_equal(node.network.dist_wrapper.wrapper_list) and node.network.dist_wrapper.wrapper_list[0].distribution_type == 'normal'):
                loss_type = "nll"
                logstd_loss_flag = False
            # Empty node processing, use rollout data as expert data for nodata nodes
            if node_name in self._graph.nodata_node_names:
                if dataset_mode == "trajectory":
                    expert_data[node_name] = generated_data[node_name]
                else:
                    expert_data[node_name] = action_dist.mode
                continue
            # Calculate loss (per-node override first, then the batch default).
            _loss_type = graph.nodes_loss_type.get(node_name, loss_type)
            if _loss_type == "mae":
                policy_loss = ((action - expert_data[node_name])*loss_mask).abs().sum(dim=-1).mean()
            elif _loss_type == "mse":
                policy_loss = (((action - expert_data[node_name])*loss_mask)**2).sum(dim=-1).mean()
            elif _loss_type == "nll" or _loss_type == "log_prob" :
                if dataset_mode == "trajectory":
                    # Per-timestep distributions: accumulate NLL over steps.
                    policy_loss = 0
                    for traj_index in range(expert_data.shape[0]):
                        if isinstance(loss_mask, float):
                            policy_loss -= (dist_dict[node_name][traj_index].log_prob(expert_data[node_name][traj_index])).mean()
                        else:
                            policy_loss -= (dist_dict[node_name][traj_index].log_prob(expert_data[node_name][traj_index])*loss_mask[traj_index]).mean()
                else:
                    policy_loss = - (action_dist.log_prob(expert_data[node_name]) * loss_mask).mean()
                if self.config["bc_l2_coef"] > 0:
                    policy_loss += self.net_l2_norm(node.network) * self.config["bc_l2_coef"]
                if self.config["entropy_coef"] > 0:
                    # NOTE(review): the entropy term is *added* with a positive
                    # coefficient, which penalizes entropy — confirm sign.
                    policy_loss += self.config["entropy_coef"] * action_dist.entropy().mean()
                if logstd_loss_flag and self.config["logstd_loss_coef"]>0:
                    # Shrink the learnable logstd bounds towards each other.
                    policy_loss += self.config["logstd_loss_coef"] * node.network.dist_wrapper.wrapper_list[0].max_logstd.sum() \
                                   - self.config["logstd_loss_coef"] * node.network.dist_wrapper.wrapper_list[0].min_logstd.sum()
            elif _loss_type == "gaussian_loss":
                assert dataset_mode != "trajectory", "gaussian_loss does not support RNN yet"
                target = expert_data[node_name]
                mean, std = action_dist.mode, action_dist.std
                var = torch.pow(std, 2)
                # Average over batch and dim, sum over ensembles.
                mse_loss_inv = (torch.pow(mean - target, 2) / var * loss_mask).mean()
                var_loss = (torch.log(var) * loss_mask).mean()
                policy_loss = mse_loss_inv + var_loss
                if self.config["bc_l2_coef"] > 0:
                    policy_loss += self.net_l2_norm(node.network) * self.config["bc_l2_coef"]
                if self.config["entropy_coef"] > 0:
                    policy_loss += self.config["entropy_coef"] * action_dist.entropy().mean()
                if logstd_loss_flag and self.config["logstd_loss_coef"]>0:
                    policy_loss += self.config["logstd_loss_coef"] * node.network.dist_wrapper.wrapper_list[0].max_logstd.sum() \
                                   - self.config["logstd_loss_coef"] * node.network.dist_wrapper.wrapper_list[0].min_logstd.sum()
                info[f"{self.NAME}/{node_name}_mse_loss_inv_{scope}"] = mse_loss_inv.item()
                info[f"{self.NAME}/{node_name}_var_loss_{scope}"] = var_loss.item()
            elif _loss_type.startswith("user_module."):
                # Delegate to a user-registered loss function.
                loss_name = _loss_type[len("user_module."):]
                loss_function = self.config["user_module"].get(loss_name, None)
                assert loss_function is not None
                kwargs = {
                    "node_dist" : action_dist,
                    "node_name" : node_name,
                    "isnan_index_list" : isnan_index_list,
                    "isnan_index" : isnan_index,
                    "loss_mask" : loss_mask,
                    "graph" : graph,
                    "expert_data" : expert_data,
                }
                policy_loss = loss_function(kwargs)
            else:
                raise NotImplementedError
            loss += policy_loss
            info[f"{self.NAME}/{node_name}_bc_loss_{scope}"] = policy_loss.item()
            if dataset_mode == "trajectory":
                # action_dist holds one distribution per timestep here.
                info[f"{self.NAME}/{node_name}_bc_std_{scope}"] = sum([_action_dist.std.mean().item() for _action_dist in action_dist]) / len(action_dist)
            else:
                info[f"{self.NAME}/{node_name}_bc_std_{scope}"] = action_dist.std.mean().item()
    info[f"{self.NAME}/total_bc_loss_{scope}"] = loss.item()
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad.clip_grad_norm_(get_models_parameters(*models), 50)
    if torch.isnan(grad_norm.mean()):
        # Skip the optimizer step when the gradient exploded to NaN.
        self.nan_in_grad()
        logger.info(f'Detect nan in gradient, skip this batch! (loss : {loss}, grad_norm : {grad_norm})')
    else:
        optimizer.step()
    info[f"{self.NAME}/bc_grad_norm"] = grad_norm.item()
    return info
@catch_error
def train_epoch(self):
    """Run one training epoch.

    The first ``bc_epoch`` epochs perform behaviour-cloning pretraining;
    later epochs run adversarial (matcher/generator) training. At the
    boundary the BC data loaders are swapped for the revive loaders.

    :return: epoch metrics with transient 'last*' keys filtered out.
    """
    info = {}
    if self._epoch_cnt == self.config['bc_epoch']:
        # BC pretraining just finished: switch to the revive loaders.
        self.switch_mail_data_loader()
        # self._load_best_models()
    self._epoch_cnt += 1
    logger.info(f"Train epoch : {self._epoch_cnt} ")
    if self._epoch_cnt <= self.config['bc_epoch']:
        # Bc Pretrain
        def bc_on_train_data(self):
            # One BC pass over the training split (train-side models).
            metric_meters_train = AverageMeterCollection()
            for batch_idx, batch in enumerate(iter(self._train_loader_train)):
                batch_info = {
                    "batch_idx": batch_idx,
                    "global_step": self.global_step
                }
                batch_info.update(info)
                batch.to_torch(device=self._device)
                metrics = self.bc_train_batch(batch, batch_info=batch_info, scope='train')
                metric_meters_train.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
                self.global_step += 1
            torch.cuda.empty_cache()
            return metric_meters_train
        def bc_on_val_data(self):
            # One BC pass over the validation split (trains the val-side models).
            metric_meters_val = AverageMeterCollection()
            for batch_idx, batch in enumerate(iter(self._val_loader_train)):
                batch_info = {
                    "batch_idx": batch_idx,
                    "global_step": self.global_step
                }
                batch_info.update(info)
                batch.to_torch(device=self._device)
                metrics = self.bc_train_batch(batch, batch_info=batch_info, scope='val')
                metric_meters_val.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
                self.global_step += 1
            torch.cuda.empty_cache()
            return metric_meters_val
        if self._speedup:
            # Run the two passes concurrently.
            # NOTE(review): both closures increment self.global_step without
            # a lock when run in parallel — confirm this is acceptable.
            with futures.ThreadPoolExecutor() as executor:
                bc_on_train_data_result = executor.submit(bc_on_train_data, self)
                bc_on_val_data_result = executor.submit(bc_on_val_data, self)
                futures.wait([bc_on_train_data_result, bc_on_val_data_result])
                metric_meters_train = bc_on_train_data_result.result()
                metric_meters_val = bc_on_val_data_result.result()
        else:
            metric_meters_train = bc_on_train_data(self)
            metric_meters_val = bc_on_val_data(self)
        info = metric_meters_train.summary()
        info.update(metric_meters_val.summary())
        return {k : info[k] for k in filter(lambda k: not k.startswith('last'), info.keys())}
    else:
        # Mail Train
        def mail_on_train_data(self):
            # One adversarial pass with the train-side models.
            graph = self.graph_train
            # value net(s)
            generator_other_nets = self.other_models_train[:-self.matcher_num]
            matchers = self.other_models_train[-self.matcher_num:]
            generator_optimizers = self.train_optimizers[1]
            # value net optim(s)
            other_generator_optimizers = self.train_optimizers[2]
            matcher_optimizer = self.train_optimizers[-1]
            scope = "train"
            info ={}
            # [ TRAINING D ] avoid matcher loss shocking
            for _ in range(10):
                for _ in range(self.config['d_steps']):
                    self.global_step += 1
                    expert_data = next(iter(self._train_loader_train))
                    expert_data.to_torch(device=self._device)
                    _, _info = self._run_matcher(expert_data, graph, matchers, matcher_optimizer, test=not self.config["matcher_sample"])
                    _info = {k + '_train' : v for k, v in _info.items()}
                    info.update(_info)
                d_loss = info[f'{self.NAME}/matcher_loss_train']
                # Keep training the matcher (at most 10 rounds of d_steps)
                # until its loss drops below the threshold.
                if d_loss <= 1.35:
                    break
            # Train generator
            if self._epoch_cnt > self.config['bc_epoch'] + self.config['matcher_pretrain_epoch']:
                for _ in range(self.config['g_steps']):
                    self.global_step += 1
                    expert_data = next(iter(self._train_loader_train))
                    expert_data.to_torch(device=self._device)
                    # Repeating expert data for more generator data generation
                    for k, _ in expert_data.items():
                        expert_data[k] = torch.cat([expert_data[k] for _ in range(self.config['generator_data_repeat'])], dim=1)
                    for matcher_index in range(len(matchers)):
                        _info, generated_data = self._run_generator(deepcopy(expert_data),
                                                                    graph,
                                                                    generator_other_nets[matcher_index],
                                                                    matchers[matcher_index],
                                                                    generator_optimizers,
                                                                    other_generator_optimizers,
                                                                    matcher_index=matcher_index,
                                                                    scope=scope)
                        _info = {k + f'_{scope}' : v for k, v in _info.items()}
                        info.update(_info)
                # Optional periodic BC finetuning on the (repeated) expert data.
                if self.config["fintune"] >= 1 and self._epoch_cnt % self.config['finetune_fre']==0:
                    if self.config["venv_train_dataset_mode"] == "transition":
                        for _ in range(int(self.config["fintune"])):
                            expert_data = Batch({k:v.reshape(-1, v.shape[-1]) for k,v in expert_data.items()})
                            batch_nums = max(len(expert_data) // max(min(len(expert_data) // 256, 1), 128), 1)
                            for batch_data in expert_data.split(batch_nums):
                                _info = self.bc_train_batch(batch_data, {}, scope, dataset_mode="transition",
                                                            loss_type=self.config['bc_loss'])
                                info.update(_info)
                    else:
                        for _ in range(int(self.config["fintune"])):
                            # NOTE(review): expert_data is re-repeated on every
                            # finetune iteration, so it grows with each pass
                            # when "fintune" > 1 — confirm intended.
                            if self.config['bc_repeat']>= 0 :
                                expert_data = Batch({k:v.repeat(1, self.config["bc_repeat"], 1) for k,v in expert_data.items()})
                                logger.info(f'bc_repeat num is {expert_data.shape[1]}')
                            else:
                                expert_data = Batch({k:v.repeat(1, v.shape[0], 1) for k,v in expert_data.items()})
                            batch_nums = max(expert_data.shape[1] // max(min(expert_data.shape[1] // 256, 1), 128), 1)
                            batch_nums = min(16, batch_nums)
                            random_index = np.random.permutation(expert_data.shape[1])
                            idx_index = torch.from_numpy(random_index).chunk(batch_nums)
                            for idx in range(batch_nums):
                                _info = self.bc_train_batch(Batch({k:v[:,idx_index[idx],:] for k,v in expert_data.items()}), {}, scope,
                                                            loss_type=self.config['bc_loss'])
                                info.update(_info)
            torch.cuda.empty_cache()
            return info
        def mail_on_val_data(self):
            # training on valiadation set
            graph = self.graph_val
            generator_other_nets = self.other_models_val[:-self.matcher_num]
            matchers = self.other_models_val[-self.matcher_num:]
            generator_optimizers = self.val_optimizers[1]
            other_generator_optimizers = self.val_optimizers[2]
            matcher_optimizer = self.val_optimizers[-1]
            scope = "val"
            info = {}
            # [ TRAINING D ] avoid matcher loss shocking
            for _ in range(10):
                for _ in range(self.config['d_steps']):
                    self.global_step += 1
                    expert_data = next(iter(self._val_loader_train))
                    expert_data.to_torch(device=self._device)
                    _, _info = self._run_matcher(expert_data, graph, matchers, matcher_optimizer, test=not self.config["matcher_sample"])
                    _info = {k + '_val' : v for k, v in _info.items()}
                    info.update(_info)
                d_loss = info[f'{self.NAME}/matcher_loss_val']
                if d_loss <= 1.35:
                    break
            # Train generator
            if self._epoch_cnt > self.config['bc_epoch'] + self.config['matcher_pretrain_epoch']:
                for _ in range(self.config['g_steps']):
                    self.global_step += 1
                    expert_data = next(iter(self._val_loader_train))
                    expert_data.to_torch(device=self._device)
                    # Repeating expert data for more generator data generation
                    for k, _ in expert_data.items():
                        expert_data[k] = torch.cat([expert_data[k] for _ in range(self.config['generator_data_repeat'])], dim=1)
                    for matcher_index in range(len(matchers)):
                        _info, generated_data = self._run_generator(deepcopy(expert_data),
                                                                    graph,
                                                                    generator_other_nets[matcher_index],
                                                                    matchers[matcher_index],
                                                                    generator_optimizers,
                                                                    other_generator_optimizers,
                                                                    matcher_index=matcher_index,
                                                                    scope=scope)
                        _info = {k + '_val' : v for k, v in _info.items()}
                        info.update(_info)
                if self.config["fintune"] >= 1 and self._epoch_cnt%self.config['finetune_fre']==0:
                    if self.config["venv_train_dataset_mode"] == "transition":
                        for _ in range(int(self.config["fintune"])):
                            expert_data = Batch({k:v.reshape(-1, v.shape[-1]) for k,v in expert_data.items()})
                            batch_nums = max(len(expert_data) // max(min(len(expert_data) // 256, 1), 128), 1)
                            for batch_data in expert_data.split(batch_nums):
                                _info = self.bc_train_batch(batch_data, {}, scope, dataset_mode="transition",
                                                            loss_type=self.config['bc_loss'])
                                info.update(_info)
                    else:
                        for _ in range(int(self.config["fintune"])):
                            if self.config['bc_repeat']>= 0 :
                                expert_data = Batch({k:v.repeat(1, self.config["bc_repeat"], 1) for k,v in expert_data.items()})
                                logger.info(f'bc_repeat num is {expert_data.shape[1]}')
                            else:
                                expert_data = Batch({k:v.repeat(1, v.shape[0], 1) for k,v in expert_data.items()})
                            # NOTE(review): repeat(1, 1, 1) below is a no-op copy
                            # not present in the train-side twin — confirm.
                            expert_data = Batch({k:v.repeat(1, 1, 1) for k,v in expert_data.items()})
                            batch_nums = max(expert_data.shape[1] // max(min(expert_data.shape[1] // 256, 1), 128), 1)
                            batch_nums = min(16, batch_nums)
                            random_index = np.random.permutation(expert_data.shape[1])
                            idx_index = torch.from_numpy(random_index).chunk(batch_nums)
                            for idx in range(batch_nums):
                                _info = self.bc_train_batch(Batch({k:v[:,idx_index[idx],:] for k,v in expert_data.items()}), {}, "val",
                                                            loss_type=self.config['bc_loss'])
                                info.update(_info)
            torch.cuda.empty_cache()
            return info
        """
        # NOTE: Currently, adaptation do not support gmm distributions
        if self.config['std_adapt_strategy'] is not None:
            with torch.no_grad():
                for i, node_name in enumerate(self._graph.keys()):
                    error = expert_data[node_name] - generated_data[node_name]
                    if self.config['std_adapt_strategy'] == 'mean':
                        std = error.abs().view(-1, error.shape[-1]).mean(0)
                    elif self.config['std_adapt_strategy'] == 'max':
                        std, _ = error.abs().view(-1, error.shape[-1]).max(0)
                    std /= 3
                    self.adapt_stds[i] = std
        """
        if self._speedup:
            # Run train-side and val-side adversarial passes concurrently.
            with futures.ThreadPoolExecutor() as executor:
                mail_on_train_data_result = executor.submit(mail_on_train_data, self)
                mail_on_val_data_result = executor.submit(mail_on_val_data, self)
                futures.wait([mail_on_train_data_result, mail_on_val_data_result])
                mail_on_train_data_info = mail_on_train_data_result.result()
                mail_on_val_data_info = mail_on_val_data_result.result()
        else:
            mail_on_train_data_info = mail_on_train_data(self)
            mail_on_val_data_info = mail_on_val_data(self)
        info.update(mail_on_train_data_info)
        info.update(mail_on_val_data_info)
        info = self._early_stop(info)
        return {k : info[k] for k in filter(lambda k: not k.startswith('last'), info.keys())}
def _get_matcher_score(self, batch_data, matcher):
    """Score a batch with one matcher (discriminator).

    Args:
        batch_data: batch of node tensors to score.
        matcher: always a single matcher here.

    Returns:
        score (torch.Tensor): the matcher's output for the batch.
    """
    nodes = matcher.matching_nodes
    # Look up the fit-index bookkeeping stored for this node group when
    # the matcher was created.
    fit_index = self.matching_nodes_fit_index_list[self.matching_nodes_list.index(nodes)]
    pieces = get_list_traj(batch_data, nodes, fit_index)
    features = torch.cat(pieces, dim=-1).detach()
    return matcher(features)
def _generate_rewards(self, generated_data, expert_data, graph, matcher):
    """Attach matcher-based rewards to a generated rollout.

    The reward is a (normalized) mix of the matcher log-score, an optional
    MAE shooting-error term and an optional user-defined rule reward.

    :param generated_data: rollout produced by the generator.
    :param expert_data: matching expert batch (for the MAE term).
    :param graph: decision-flow graph in use.
    :param matcher: matcher used to score the generated data.
    :return: result of generate_rewards() on generated_data.
    """
    def reward_fn(data, matcher, graph):
        # Generate reward by matcher
        score = self._get_matcher_score(data, matcher)
        # -log(1 - D): large when the matcher believes the data is real.
        matcher_reward = - torch.log(1 - score + 1e-4)
        matcher_reward = (matcher_reward - matcher_reward.mean()) / (matcher_reward.std() + 1e-4)
        # Generate reward by mae metric
        # Using config['mae_reward_weight'] parameter configuration weight
        if self.config["mae_reward_weight"]:
            shooting_mae = []
            for node_name in graph.keys():
                if node_name in graph.metric_nodes:
                    policy_shooting_error = torch.abs(expert_data[node_name] - generated_data[node_name])
                    shooting_mae.append(policy_shooting_error)
            shooting_mae = torch.mean(torch.cat(shooting_mae, axis=-1),dim=-1,keepdim=True)
            mae_reward = - shooting_mae
            # normalize reward
            mae_reward = (mae_reward - mae_reward.mean()) / (mae_reward.std() + 1e-4)
            reward = matcher_reward*(1-self.config["mae_reward_weight"]) + mae_reward*self.config["mae_reward_weight"]
        else:
            reward = matcher_reward
        # Generate reward by user-defined reward function.
        # Using config['rule_reward_func_weight'] parameter configuration weight.
        # Using config['rule_reward_func_normalize'] parameter configuration weight normalize.
        if self.config['rule_reward_func']:
            # Apply the rule reward only when this matcher covers the
            # configured node set (or no restriction was given).
            if (not self.config['rule_reward_matching_nodes']) or (set(matcher.matching_nodes) == set(self.config['rule_reward_matching_nodes'])):
                rule_reward = self.config['rule_reward_func'](generated_data, graph)
                assert rule_reward.shape == matcher_reward.shape
                if self.config['rule_reward_func_normalize']:
                    rule_reward = (rule_reward - rule_reward.mean()) / (rule_reward.std() + 1e-4)
                reward += (rule_reward * self.config['rule_reward_func_weight'])
        return reward.detach()
    return generate_rewards(generated_data, reward_fn=lambda data: reward_fn(data, matcher, graph))
def _generate_rollout(self,
                      expert_data,
                      graph,
                      matcher=None,
                      test=False,
                      generate_reward=False,
                      sample_fn=None,
                      clip=False,
                      mix_sample=False,
                      mix_sample_ratio=0.5):
    """Roll the graph forward on expert data, optionally attaching rewards.

    :param expert_data: batch of expert trajectories seeding the rollout.
    :param graph: decision-flow graph to roll out.
    :param matcher: matcher used for reward generation (only needed when
        generate_reward is True).
    :param test: with no custom sample_fn, use distribution modes instead
        of sampling.
    :param generate_reward: attach matcher-based rewards to the rollout.
    :param sample_fn: optional sampling function dist -> tensor.
    :param clip: clipping argument forwarded to generate_rollout.
    :param mix_sample: mix expert/generated samples (no-reward path only).
    :param mix_sample_ratio: ratio used by mix sampling.
    :return: the generated rollout batch.
    """
    if sample_fn is None:
        # Default sampling: mode for evaluation, rsample for SVG-style
        # pathwise gradients, plain sampling otherwise.
        def sample_fn(dist, test=test, generator_algo=self.config['generator_algo']):
            if test:
                return dist.mode
            if generator_algo == 'svg':
                return dist.rsample()
            return dist.sample()
    horizon = expert_data.shape[0]
    if generate_reward:
        rollout = generate_rollout(expert_data, graph, horizon, sample_fn, self.adapt_stds, clip)
        return self._generate_rewards(rollout, expert_data, graph, matcher)
    return generate_rollout(expert_data, graph, horizon, sample_fn, self.adapt_stds, clip,
                            mix_sample=mix_sample, mix_sample_ratio=mix_sample_ratio)
def mix_data_process(self, expert_data, generated_data, matchers, matcher_optimizer):
    """Mix hard-to-discriminate generated samples into the expert pool.

    A throwaway copy of the matchers is trained for one step; generated
    samples that the copies score above the expert-score quantile are moved
    into the expert data. Both batches are flattened to 2-D in the process.

    :param expert_data: expert batch (flattened in place).
    :param generated_data: generated batch (flattened in place).
    :param matchers: the real matcher networks (left un-stepped).
    :param matcher_optimizer: the real optimizer (copied, not stepped).
    :return: (expert_data, generated_data) after mixing.
    """
    _matchers = deepcopy(matchers)
    _matcher_optimizer = deepcopy(matcher_optimizer)
    _matcher_optimizer.load_state_dict(matcher_optimizer.state_dict())
    # Reshape the expert data and generated data
    for k, v in generated_data.items():
        shape = v.shape
        generated_data[k] = generated_data[k].reshape(-1, shape[-1])
        if k in expert_data:
            expert_data[k] = expert_data[k].reshape(-1, shape[-1])
    # Compute matcher loss
    expert_scores, generated_scores, matcher_losses = [], [], []
    for matcher_index in range(len(_matchers)):
        expert_score = self._get_matcher_score(expert_data, _matchers[matcher_index])
        generated_score = self._get_matcher_score(generated_data, _matchers[matcher_index])
        expert_scores.append(deepcopy(expert_score.detach()))
        generated_scores.append(deepcopy(generated_score.detach()))
        real = torch.ones_like(expert_score)
        fake = torch.zeros_like(generated_score)
        expert_entropy = torch.distributions.Bernoulli(expert_score).entropy().mean()
        generated_entropy = torch.distributions.Bernoulli(generated_score).entropy().mean()
        total_loss = F.binary_cross_entropy(expert_score, real) + F.binary_cross_entropy(generated_score, fake)
        if self.config['discr_ent_coef'] > 0:
            total_loss += -self.config['discr_ent_coef'] * (expert_entropy + generated_entropy)
        if self.config['matcher_l2_norm_coeff'] > 0:
            # NOTE(review): the L2 term is computed on the *original*
            # matchers while the loss updates the copies — confirm intended.
            total_loss += self.net_l2_norm(matchers[matcher_index]) * self.config['matcher_l2_norm_coeff']
        matcher_losses.append(total_loss)
    matcher_loss = sum(matcher_losses) / len(matcher_losses)
    if _matcher_optimizer is not None:
        _matcher_optimizer.zero_grad(set_to_none=True)
        matcher_loss.backward()
        matcher_grad_norm = nn.utils.clip_grad_norm_(get_models_parameters(*_matchers), self.config["matcher_grad_norm"])
        if torch.isnan(matcher_grad_norm.mean()):
            # Skip the step when the gradient exploded to NaN.
            self.nan_in_grad()
            logger.info(f'Detect nan in gradient, skip this batch! (loss : {matcher_loss}, grad_norm : {matcher_grad_norm})')
        else:
            _matcher_optimizer.step()
    for matcher_index in range(len(_matchers)):
        generated_score = self._get_matcher_score(generated_data, _matchers[matcher_index]).detach().cpu().numpy()
        expert_score = self._get_matcher_score(expert_data, _matchers[matcher_index]).detach().cpu().numpy()
        # Generated samples scoring above the expert quantile look expert-like.
        threshold = np.quantile(expert_score, self.config['quantile'])
        selected_indices = np.where(generated_score > threshold)[0]
        # in case all generated data are regarded as expert data
        if len(selected_indices) >= generated_score.shape[0]:
            selected_indices = selected_indices[::2]
        mask = np.ones(generated_score.shape[0], np.bool_)
        mask[selected_indices] = 0
        move_num = len(selected_indices)
        # Append the selected generated samples to the expert pool, shuffle,
        # and drop them from the generated pool.
        shuffle_idx = torch.randperm(expert_data.shape[0] + move_num)
        for k, v in expert_data.items():
            temp = torch.cat([v, generated_data[k][selected_indices]], dim=0)
            expert_data[k] = temp[shuffle_idx, :]
            generated_data[k] = generated_data[k][mask]
    return expert_data, generated_data
def net_l2_norm(self, network, mean=False):
    """Sum of squared trainable parameters of *network*.

    :param network: a torch module.
    :param mean: when True, divide by the total trainable parameter count
        (at least 1) to return an average instead of a sum.
    :return: scalar tensor with the (mean) squared L2 norm.
    """
    total = 0
    count = 0
    for param in network.parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not regularized
        total = total + param.pow(2).sum()
        count = count + np.prod(list(param.data.shape))
    if mean:
        total = total / max(count, 1)
    return total
def _run_matcher(self, expert_data, graph, matchers, matcher_optimizer=None, test=False):
    """ Train matcher.

    Rolls out the current generator, optionally mixes data, computes the
    discriminator loss for every matcher and (when an optimizer is given)
    applies one update step over all matchers.

    :param expert_data: batch of expert trajectories.
    :param graph: decision-flow graph (train- or val-side).
    :param matchers: list of matcher networks to train.
    :param matcher_optimizer: optimizer over all matchers; None = no update.
    :param test: when True the rollout uses distribution modes.
    :return: (generated_data, info) — the rollout batch and a metric dict.
    """
    # Generate rollout data
    # [ TRAINING D ] the same setting as policy learning: data should be the same as policy learn
    generated_data = self._generate_rollout(expert_data,
                                            graph,
                                            test=test,
                                            clip=1.5,
                                            mix_sample=self.config['mix_sample'],
                                            mix_sample_ratio=self.config['mix_sample_ratio'])
    if self.config["pre_horzion"] > 0:
        # Drop the warm-up prefix ("pre_horzion", sic — config key name).
        expert_data = expert_data[self.config["pre_horzion"]:]
        generated_data = generated_data[self.config["pre_horzion"]:]
    # Missing value processing
    isnan_index_list = []
    for matching_nodes in self.matching_nodes_list:
        for node_name in matching_nodes:
            if node_name + "_isnan_index_" in expert_data.keys():
                isnan_index_list.append(expert_data[node_name + "_isnan_index_"])
    if isnan_index_list:
        # Keep only samples with no missing values in any matching node.
        isnan_index = torch.mean(torch.cat(isnan_index_list,axis=-1),axis=-1)
        expert_data = expert_data[isnan_index==0]
        generated_data = generated_data[isnan_index==0]
    # Using mixed data techniques
    if self.config["mix_data"]:
        expert_data, generated_data = self.mix_data_process(expert_data, generated_data, matchers, matcher_optimizer)
    # Calculate the score of the matcher
    expert_scores, generated_scores, matcher_losses, grad_losses = [], [], [], []
    matcher_dict = dict()
    for matcher_index in range(len(matchers)):
        expert_score = self._get_matcher_score(expert_data, matchers[matcher_index])
        generated_score = self._get_matcher_score(generated_data, matchers[matcher_index])
        expert_scores.append(deepcopy(expert_score.detach()))
        generated_scores.append(deepcopy(generated_score.detach()))
        real = torch.ones_like(expert_score)
        fake = torch.zeros_like(generated_score)
        expert_entropy = torch.distributions.Bernoulli(expert_score).entropy().mean()
        generated_entropy = torch.distributions.Bernoulli(generated_score).entropy().mean()
        # loss_discriminator = F.binary_cross_entropy(expert_score, real) + F.binary_cross_entropy(generated_score, fake)
        if self.config['controller_weight'] > 0:
            output_weight = self.config['controller_weight']
            # logger.info(f'Detect controller_weight! ({output_weight})')
            # BCE augmented with a quadratic "controller" regularizer on the raw scores.
            loss_discriminator = (-real * (torch.log(expert_score + 1e-8) - self.config['controller_weight']/2 * torch.pow(expert_score, 2) + self.config['controller_weight']/2 * expert_score) + \
                                  - (1 - fake) * (torch.log(1 - generated_score + 1e-8) - self.config['controller_weight']/2 * torch.pow(generated_score, 2) + self.config['controller_weight']/2 * generated_score)).mean()
        else:
            loss_discriminator = F.binary_cross_entropy(expert_score, real) + F.binary_cross_entropy(generated_score, fake)
        if self.config['discr_ent_coef']>0:
            # Encourage uncertain (high-entropy) discriminator outputs.
            loss_discriminator += -self.config['discr_ent_coef'] * (expert_entropy + generated_entropy)
        if self.config['matcher_l2_norm_coeff']>0:
            loss_discriminator += self.net_l2_norm(matchers[matcher_index]) * self.config['matcher_l2_norm_coeff']
        # Gradient penalty
        if self.config['gp_coef'] > 0:
            # NOTE(review): this penalty is on gradients w.r.t. the matcher
            # *parameters*, not w.r.t. inputs as in WGAN-GP — confirm intended.
            matcher_optimizer.zero_grad(set_to_none=True)
            grads = torch.autograd.grad(outputs=loss_discriminator, inputs=matchers[matcher_index].parameters(), create_graph=True)
            grad_loss = 0
            for grad in grads:
                grad_loss += torch.pow(grad, 2).sum()
            grad_losses.append(grad_loss)
            total_loss = loss_discriminator + self.config['gp_coef'] * grad_loss
            matcher_losses.append(total_loss)
        else:
            total_loss = loss_discriminator
            matcher_losses.append(total_loss)
        matcher_dict.update({
            f"{self.NAME}/expert_score_{matcher_index}": expert_score.mean().item(),
            f"{self.NAME}/generated_score_{matcher_index}": generated_score.mean().item(),
            f"{self.NAME}/expert_acc_{matcher_index}": (expert_score > 0.5).float().mean().item(),
            f"{self.NAME}/generated_acc_{matcher_index}": (generated_score < 0.5).float().mean().item(),
            f"{self.NAME}/expert_entropy_{matcher_index}": expert_entropy.item(),
            f"{self.NAME}/generated_entropy_{matcher_index}": generated_entropy.item(),
            f"{self.NAME}/matcher_loss_{matcher_index}": loss_discriminator.item(),
        })
    # Aggregate over matchers for the headline metrics and the update.
    expert_score = sum(expert_scores) / len(expert_scores)
    generated_score = sum(generated_scores) / len(generated_scores)
    matcher_loss = sum(matcher_losses) / len(matcher_losses)
    info = {}
    if matcher_optimizer is not None:
        matcher_optimizer.zero_grad(set_to_none=True)
        matcher_loss.backward()
        matcher_grad_norm = nn.utils.clip_grad_norm_(get_models_parameters(*matchers), self.config["matcher_grad_norm_clip"])
        if torch.isnan(matcher_grad_norm.mean()):
            # Skip the step when the gradient exploded to NaN.
            self.nan_in_grad()
            logger.info(f'Detect nan in gradient, skip this batch! (loss : {matcher_loss}, grad_norm : {matcher_grad_norm})')
        else:
            matcher_optimizer.step()
        info[f"{self.NAME}/matcher_grad_norm"] = matcher_grad_norm.item()
    # NOTE(review): expert_entropy/generated_entropy below are the values
    # from the *last* matcher in the loop, not an average — confirm.
    info.update({
        f"{self.NAME}/expert_score": expert_score.mean().item(),
        f"{self.NAME}/generated_score": generated_score.mean().item(),
        f"{self.NAME}/expert_acc": (expert_score > 0.5).float().mean().item(),
        f"{self.NAME}/generated_acc": (generated_score < 0.5).float().mean().item(),
        f"{self.NAME}/expert_entropy": expert_entropy.item(),
        f"{self.NAME}/generated_entropy": generated_entropy.item(),
        f"{self.NAME}/matcher_loss": matcher_loss.item(),
    })
    if self.config['gp_coef'] > 0:
        grad_loss = sum(grad_losses) / len(grad_losses)
        info.update({
            f"{self.NAME}/grad_penalty_loss": grad_loss.item(),
        })
    info.update(matcher_dict)
    return generated_data, info