''''''
"""
    POLIXIR REVIVE, copyright (C) 2021-2025 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.
    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""
import shutil
from copy import deepcopy
import torch
import numpy as np
from loguru import logger
import matplotlib.pyplot as plt
from collections import deque
from torch import nn
from torch.functional import F
from ray import train
from revive.computation.graph import *
from revive.computation.modules import *
from revive.utils.raysgd_utils import BATCH_SIZE, NUM_SAMPLES, AverageMeterCollection
from revive.utils.common_utils import *
from revive.data.dataset import data_creator, revive_f_rnn_data_creator
from revive.algo.venv.base import catch_error
try:
    from revive.dist.algo.venv.revive_p import PPOOperator
    logger.info(f"Import encryption venv algorithm module -> PPOOperator!")
except:
    from revive.algo.venv.revive_p import PPOOperator
    logger.info(f"Import venv algorithm module -> PPOOperator!")
class FILTEROperator(PPOOperator):
    NAME = "REVIVE_VENV"  # alternative name: "REVIVE_FILTER"
    # Hyper-parameters accepted by this operator. Each entry declares the option
    # ``name``, its ``type`` and ``default``, and optionally an ``abbreviation``,
    # a ``description``, a ``doc`` flag and ``search_mode``/``search_values``.
    # How these keys are consumed is defined outside this file.
    PARAMETER_DESCRIPTION = [
        {
            "name" : "bc_epoch",
            "abbreviation" : "bep",
            "type" : int,
            "default" : 1500,
        },
        {
            "name" : "bc_lr",
            "type" : float,
            "default" : 1e-3,
        },
        {
            "name" : "bc_steps",
            "type" : int,
            "default" : 1,
        },
        {
            "name" : "revive_batch_size",
            "description": "Batch size of training process.",
            "abbreviation" : "mbs",
            "type" : int,
            "default" : 1024,
            'doc': True,
        },
        {
            "name" : "revive_epoch",
            "description": "Number of epochs for the MAIL training process",
            "abbreviation" : "mep",
            "type" : int,
            "default" : 1500,
            'doc': True,
        },
        {
            "name" : "matcher_pretrain_epoch",
            "abbreviation" : "dpe",
            "type" : int,
            "default" : 0,
        },
        {
            "name" : "policy_hidden_features",
            "description": "Number of neurons per layer of the policy network.",
            "abbreviation" : "phf",
            "type" : int,
            "default" : 256,
            'doc': True,
        },
        {
            "name" : "policy_hidden_layers",
            "description": "Depth of policy network.",
            "abbreviation" : "phl",
            "type" : int,
            "default" : 4,
            'doc': True,
        },
        {
            "name" : "policy_activation",
            "abbreviation" : "pa",
            "type" : str,
            "default" : 'leakyrelu',
        },
        {
            "name" : "policy_normalization",
            "abbreviation" : "pn",
            "type" : str,
            "default" : None,
        },
        {
            "name" : "policy_backbone",
            "description": "Backbone of policy network.",
            "abbreviation" : "pb",
            "type" : str,
            "default" : "res",
            'doc': True,
        },
        {
            "name" : "transition_hidden_features",
            "description": "Number of neurons per layer of the transition network.",
            "abbreviation" : "thf",
            "type" : int,
            "default" : 256,
            'doc': True,
        },
        {
            "name" : "transition_hidden_layers",
            # every other 'doc' entry carries a description; this one was missing it
            "description": "Depth of transition network.",
            "abbreviation" : "thl",
            "type" : int,
            "default" : 4,
            'doc': True,
        },
        {
            "name" : "transition_activation",
            "abbreviation" : "ta",
            "type" : str,
            "default" : 'leakyrelu',
        },
        {
            "name" : "transition_normalization",
            "abbreviation" : "tn",
            "type" : str,
            "default" : None,
        },
        {
            "name" : "transition_backbone",
            "description": "Backbone of transition network.",
            "abbreviation" : "tb",
            "type" : str,
            "default" : "res",
            'doc': True,
        },
        {
            "name" : "matching_nodes",
            "type" : list,
            "default" : 'auto',
        },
        {
            "name" : "matching_fit_nodes",
            "type" : list,
            "default" : 'auto',
        },
        {
            "name" : "matcher_hidden_features",
            "description": "Number of neurons per layer of the matcher network.",
            "abbreviation" : "dhf",
            "type" : int,
            "default" : 256,
            'doc': True,
        },
        {
            "name" : "matcher_hidden_layers",
            "description": "Depth of the matcher network.",
            "abbreviation" : "dhl",
            "type" : int,
            "default" : 4,
            'doc': True,
        },
        {
            "name" : "matcher_activation",
            "abbreviation" : "da",
            "type" : str,
            "default" : 'leakyrelu',
        },
        {
            "name" : "matcher_normalization",
            "abbreviation" : "dn",
            "type" : str,
            "default" : None,
        },
        {
            "name" : "state_nodes",
            "type" : list,
            "default" : 'auto',
        },
        {
            "name" : "value_hidden_features",
            "abbreviation" : "vhf",
            "type" : int,
            "default" : 256,
        },
        {
            "name" : "value_hidden_layers",
            "abbreviation" : "vhl",
            "type" : int,
            "default" : 4,
        },
        {
            "name" : "value_activation",
            "abbreviation" : "va",
            "type" : str,
            "default" : 'leakyrelu',
        },
        {
            "name" : "value_normalization",
            "abbreviation" : "vn",
            "type" : str,
            "default" : None,
        },
        {
            "name" : "generator_type",
            "abbreviation" : "gt",
            "type" : str,
            "default" : "res",
        },
        {
            "name" : "matcher_type",
            "abbreviation" : "dt",
            "type" : str,
            "default" : "res",
        },
        {
            "name" : "birnn",
            "type" : bool,
            "default" : False,
        },
        {
            "name" : "std_adapt_strategy",
            "abbreviation" : "sas",
            "type" : str,
            "default" : None,
        },
        {
            "name" : "generator_algo",
            "abbreviation" : "ga",
            "type" : str,
            "default" : "ppo",
        },
        {
            "name" : "ppo_runs",
            "type" : int,
            "default" : 2,
        },
        {
            "name" : "ppo_epsilon",
            "type" : float,
            "default" : 0.2,
        },
        {
            "name" : "ppo_l2norm_cof",
            "type" : float,
            "default" : 0,
        },
        {
            "name" : "ppo_entropy_cof",
            "type" : float,
            "default" : 0,
        },
        {
            "name" : "generator_sup_cof",
            "type" : float,
            "default" : 0,
        },
        {
            "name" : "gae_gamma",
            "type" : float,
            "default" : 0.99,
        },
        {
            "name" : "gae_lambda",
            "type" : float,
            "default" : 0.95,
        },
        {
            "name" : "g_steps",
            "description": "The number of update rounds of the generator in each epoch.",
            "type" : int,
            "default" : 1,
            "search_mode" : "grid",
            "search_values" : [1, 3, 5],
            'doc': True,
        },
        {
            "name" : "d_steps",
            "description": "Number of update rounds of matcher in each epoch.",
            "type" : int,
            "default" : 1,
            "search_mode" : "grid",
            "search_values" : [1, 3, 5],
            'doc': True,
        },
        {
            "name" : "g_lr",
            "description": "Initial learning rate of the generator nodes nets.",
            "type" : float,
            "default" : 4e-5,
            "search_mode" : "continuous",
            "search_values" : [1e-6, 1e-4],
            'doc': True,
        },
        {
            "name" : "d_lr",
            "description": "Initial learning rate of the matcher.",
            "type" : float,
            "default" : 6e-4,
            "search_mode" : "continuous",
            "search_values" : [1e-6, 1e-3],
            'doc': True,
        },
        {
            "name" : "value_lr",
            "type" : float,
            "default" : 1e-3,
        },
        {
            "name" : "matcher_loss_length",
            "description": "Matcher loss length.",
            "type" : int,
            "default" : 0,
        },
        {
            "name" : "matcher_loss_high",
            "description": "Matcher loss high value. When the matcher_loss is above this value, the generator would stop training.",
            "type" : float,
            "default" : 1.2,
        },
        {
            "name" : "matcher_loss_low",
            "description": "Matcher loss low value. When the matcher_loss is below this value, the matcher would stop training.",
            "type" : float,
            "default" : 0.3,
        },
        {
            "name" : "matcher_sample",
            "description": "Sample the data for training the matcher.",
            "type" : bool,
            "default" : False,
        },
        {
            "name" : "mae_reward_weight",
            "description": "reward = (1-mae_reward_weight)*matcher_reward + mae_reward_weight*mae_reward.",
            "type" : float,
            "default" : 0.25,
        },
        {
            "name" : "generator_data_repeat",
            "description": "Repeat rollout more data to train generator.",
            "type" : int,
            "default" : 1,
        },
        {
            "name" : "rnn_hidden_features",
            "description": "RNN hidden dims",
            "type" : int,
            "default" : 64,
        },
        {
            "name" : "window_size",
            "description": "length of the sliding_window in RNN",
            "type" : int,
            "default" : 0,
        },
        {
            "name" : "mix_data",
            "type" : bool,
            "default" : False,
        },
        {
            "name" : "quantile",
            "type" : float,
            "default" : 0.95,
        },
        {
            "name" : "matcher_grad_norm_clip",
            "type" : float,
            "default" : 0.5,
        },
        {
            "name" : "mix_sample",
            "type" : bool,
            "default" : True,
        },
        {
            "name" : "mix_sample_ratio",
            "type" : float,
            "default" : 0.5,
        },
        {
            "name" : "replace_with_expert",
            "type" : bool,
            "default" : True,
        },
        {
            "name" : "replace_ratio",
            "type" : float,
            "default" : 0.1,
        },
        {
            "name" : "gp_coef",
            "type" : float,
            "default" : 0.5,
        },
        {
            "name" : "discr_ent_coef",
            "type" : float,
            "default" : 0.01,
        },
        {
            "name" : "matcher_l2_norm_coeff",
            "type" : float,
            "default" : 0.0005,
        },
        {
            "name" : "value_l2_norm_coef",
            "type" : float,
            "default" : 1e-6,
        },
        {
            "name" : "generator_l2_norm_coef",
            "type" : float,
            "default" : 1e-6,
        },
        {
            "name" : "matcher_record_len",
            "type" : int,
            "default" : 50,
        },
        {
            "name" : "matcher_record_interval",
            "type" : int,
            "default" : 1,
        },
        {
            "name" : "fix_std",
            "type" : float,
            "default" : 0.125,
        },
        {
            "name" : "bc_l2_coef",
            "type" : float,
            "default" : 5e-5,
        },
        {
            "name" : "logstd_loss_coef",
            "type" : float,
            "default" : 0.01,
        },
        {
            "name" : "entropy_coef",
            "type" : float,
            "default" : 0.0,
        },
        {
            "name" : "bc_loss",
            "type" : str,
            "default" : "nll",
        },
        {
            "name" : "controller_weight",
            "type" : float,
            "default" : 10,
        },
    ]
    @property
    def nodes_models_mail(self):
        """MAIL node networks: the leading ``config['learning_nodes_num']`` entries of ``self.mail_models``."""
        models = self.mail_models
        node_count = self.config['learning_nodes_num']
        return models[:node_count]
    @property
    def other_models_mail(self):
        """MAIL models after the node networks, i.e. ``self.mail_models`` from index ``config['learning_nodes_num']`` on."""
        models = self.mail_models
        node_count = self.config['learning_nodes_num']
        return models[node_count:]
    @catch_error
    def __init__(self, config : dict):
        """Initialise the operator, then pin each node's adaptive std to ``fix_std``.

        After the parent constructor has run, ``self.adapt_stds[i]`` is set for every
        node index ``i`` of ``self._graph`` to ``self.config['fix_std']``, or to
        ``None`` when ``fix_std`` is 0.

        :param config: configuration parameters (read back here through
            ``self.config``, presumably populated by the parent constructor).
        """
        super().__init__(config)
        # NOTE: Currently, adaptation do not support gmm distributions
        with torch.no_grad():
            for node_index, _node_name in enumerate(self._graph.keys()):
                fixed_std = self.config['fix_std']
                self.adapt_stds[node_index] = fixed_std if fixed_std != 0 else None
    def data_creator(self, config : dict):
        """Create the data loaders used for double-venv training.

        :param config: configuration parameters; ``config['rnn_dataset']`` selects
            ``revive_f_rnn_data_creator`` instead of the default ``data_creator``.
        :return: when ``rnn_dataset`` is set, 8 loaders: the four MAIL loaders
            ``(mail_train_loader_train, mail_val_loader_train, mail_train_loader_val,
            mail_val_loader_val)`` followed by the four BC loaders in the same order.
            Otherwise, 4 loaders ``(train_loader_train, val_loader_train,
            train_loader_val, val_loader_val)`` built with horizon 1.
        """
        # NOTE(review): the batch size is written into ``self.config`` while the
        # loaders are built from ``config``; they are the same object at the call in
        # ``_setup_componects`` -- confirm for any other caller.
        self.config[BATCH_SIZE] = self.config['revive_batch_size']
        if config['rnn_dataset']:
            # Unpacking enforces that exactly 8 loaders are returned.
            (mail_train_loader_train, mail_val_loader_train,
             mail_train_loader_val, mail_val_loader_val,
             bc_train_loader_train, bc_val_loader_train,
             bc_train_loader_val, bc_val_loader_val) = revive_f_rnn_data_creator(config, double=True)
            return (mail_train_loader_train, mail_val_loader_train,
                    mail_train_loader_val, mail_val_loader_val,
                    bc_train_loader_train, bc_val_loader_train,
                    bc_train_loader_val, bc_val_loader_val)
        # ``data_creator`` here is the module-level function imported from
        # revive.data.dataset, not this method.
        train_loader_train, val_loader_train, train_loader_val, val_loader_val = \
            data_creator(config, training_horizon=1, val_horizon=1, double=True)
        return train_loader_train, val_loader_train, train_loader_val, val_loader_val
    def _setup_componects(self):
        r'''setup models, optimizers and dataloaders.

        Creates the data loaders (via ``self.data_creator``), the BC models and
        optimizers for ``self.graph_train`` and ``self.graph_val``, and the MAIL
        models and optimizers on a deep copy of ``self._graph``. Each model set is
        registered on its graph. Ray Train's ``prepare_data_loader``/``prepare_model``
        are tried first; if they raise, the raw loaders are kept and models are
        moved to ``self._device``.

        :raises NotImplementedError: if ``data_creator`` returns neither 4 nor 8 loaders.
        '''
        # register data loader for double venv training
        loaders = self.data_creator(self.config)
        if len(loaders) == 4:
            train_loader_train, val_loader_train, train_loader_val, val_loader_val = loaders
        elif len(loaders) == 8:
            # The first four loaders are the MAIL ones; only the MAIL train/train
            # loader is used, the last four are the BC loaders.
            (mail_train_loader_train, _, _, _,
             train_loader_train, val_loader_train, train_loader_val, val_loader_val) = loaders
        else:
            raise NotImplementedError
        # ``except Exception`` (rather than a bare ``except``) keeps the best-effort
        # fallback without swallowing KeyboardInterrupt / SystemExit.
        try:
            self._train_loader_train = train.torch.prepare_data_loader(train_loader_train, move_to_device=False)
            self._val_loader_train = train.torch.prepare_data_loader(val_loader_train, move_to_device=False)
            self._train_loader_val = train.torch.prepare_data_loader(train_loader_val, move_to_device=False)
            self._val_loader_val = train.torch.prepare_data_loader(val_loader_val, move_to_device=False)
        except Exception:
            self._train_loader_train = train_loader_train
            self._val_loader_train = val_loader_train
            self._train_loader_val = train_loader_val
            self._val_loader_val = val_loader_val
        if self.config['rnn_dataset']:
            # NOTE(review): relies on data_creator returning 8 loaders whenever
            # ``rnn_dataset`` is set; otherwise ``mail_train_loader_train`` is unbound.
            try:
                self._mail_train_loader_train = train.torch.prepare_data_loader(mail_train_loader_train, move_to_device=False)
            except Exception:
                self._mail_train_loader_train = mail_train_loader_train

        def _prepare_models(models):
            # Replace each model in place with its Ray Train wrapper, or move it to the device.
            for model_index, model in enumerate(models):
                try:
                    models[model_index] = train.torch.prepare_model(model)
                except Exception:
                    models[model_index] = model.to(self._device)

        # init bc train models
        self.train_models = self.bc_model_creator(self.config, self.graph_train)
        _prepare_models(self.train_models)
        self.train_optimizers = self.bc_optimizer_creator(self.train_models, self.config)
        # register models to graph
        self._register_models_to_graph(self.graph_train, self.nodes_models_train)
        # init bc val models
        self.val_models = self.bc_model_creator(self.config, self.graph_val)
        _prepare_models(self.val_models)
        self.val_optimizers = self.bc_optimizer_creator(self.val_models, self.config)
        # register models to graph
        self._register_models_to_graph(self.graph_val, self.nodes_models_val)
        # prepare for MAIL training
        self.graph_mail = deepcopy(self._graph)
        self._setup_mail_componects(self.config)
        # register models to graph
        self._register_models_to_graph(self.graph_mail, self.nodes_models_mail)
    def _setup_mail_componects(self, config : dict):
        """Create the MAIL models and their optimizers.

        Resets the matcher bookkeeping lists, builds ``self.mail_models`` on
        ``self.graph_mail`` (which must already exist) and then creates
        ``self.mail_optimizers``. Each model is wrapped with Ray Train's
        ``prepare_model`` when possible, otherwise moved to ``self._device``.

        :param config: configuration parameters
        """
        # saving matchers' info (appended to by matcher_model_creator)
        self.matcher_structure_list = []
        self.matcher_record_list = []
        # init mail models, ordered as [node networks, value networks, matchers]
        # (generators first, see mail_model_creator)
        self.mail_models = self.mail_model_creator(config, self.graph_mail)
        for model_index, model in enumerate(self.mail_models):
            try:
                self.mail_models[model_index] = train.torch.prepare_model(model)
            except Exception:
                # ``Exception`` rather than a bare ``except`` so that
                # KeyboardInterrupt / SystemExit still propagate.
                self.mail_models[model_index] = model.to(self._device)
        self.mail_optimizers = self.mail_optimizer_creator(self.mail_models, config)
    def mail_model_creator(self, config, graph):
        """Create every model used by MAIL training.

        The matchers must be built first: ``matcher_model_creator`` sets
        ``self.matcher_num``, which ``mail_generator_model_creator`` reads to decide
        how many value networks to create.

        :param config: configuration parameters
        :param graph: decision graph used for MAIL training
        :return: list ordered as ``[generator networks..., matcher networks...]``.
        """
        matchers = self.matcher_model_creator(config, graph)
        generators = self.mail_generator_model_creator(config, graph)
        return [*generators, *matchers]
    def bc_model_creator(self, config, graph):
        """
        Create generator models for behaviour cloning.

        Initializes the network of every node of ``graph`` whose ``node_type`` is
        ``'network'``; other nodes are logged and skipped. Non-transition (policy)
        nodes come first and use the ``policy_*`` settings. The nodes listed in
        ``graph.transition_map`` come next and use the ``transition_*`` settings.
        :param config: configuration parameters
        :param graph: decision graph whose nodes are initialized
        :return: all the models, policy networks first, then transition networks.
        """
        rnn_kwargs = {"rnn_hidden_features": config['rnn_hidden_features'],
                      "window_size": config['window_size']}
        dims = config['total_dims']
        built = []

        def _trainable_node(name):
            # Fetch the node; log and return None when it is not a 'network' node.
            candidate = graph.get_node(name)
            if candidate.node_type == 'network':
                return candidate
            logger.info(f'skip {name}, the node is already registered as {candidate.node_type}.')
            return None

        # policy (non-transition) nodes
        for name in list(graph.keys()):
            if name in graph.transition_map.values():
                continue
            node = _trainable_node(name)
            if node is None:
                continue
            in_dim = get_input_dim_from_graph(graph, name, dims)
            if hasattr(node, 'custom_node'):
                in_dim = get_input_dim_dict_from_graph(graph, name, dims)
            in_dim_dict = get_input_dim_dict_from_graph(graph, name, dims)
            node.initialize_network(in_dim, dims[name]['output'],
                                    hidden_features=config['policy_hidden_features'],
                                    hidden_layers=config['policy_hidden_layers'],
                                    hidden_activation=config['policy_activation'],
                                    norm=config['policy_normalization'],
                                    backbone_type=config['policy_backbone'],
                                    dist_config=config['dist_configs'][name],
                                    is_transition=False,
                                    soft_clamp=True,
                                    input_dim_dict=in_dim_dict,
                                    **rnn_kwargs)
            built.append(node.get_network())

        # transition nodes
        for name in graph.transition_map.values():
            node = _trainable_node(name)
            if node is None:
                continue
            # name[5:] strips a 5-character prefix (presumably "next_")
            if name[5:] in node.input_names:
                mode = config['transition_mode']
            else:
                mode = 'global'
                if config['transition_mode']:
                    warnings.warn('Fallback to global transition mode since the transition variable is not provided as an input!')
            in_dim = get_input_dim_from_graph(graph, name, dims)
            if hasattr(node, 'custom_node'):
                in_dim = get_input_dim_dict_from_graph(graph, name, dims)
            # NOTE(review): mail_generator_model_creator passes input_dim_dict to
            # transition nodes but not to policy nodes -- the opposite of this
            # method; confirm the asymmetry is intended.
            node.initialize_network(in_dim, dims[name]['output'],
                                    hidden_features=config['transition_hidden_features'],
                                    hidden_layers=config['transition_hidden_layers'],
                                    hidden_activation=config['transition_activation'],
                                    norm=config['transition_normalization'],
                                    backbone_type=config['transition_backbone'],
                                    dist_config=config['dist_configs'][name],
                                    is_transition=True,
                                    transition_mode=mode,
                                    obs_dim=dims[name]['input'],
                                    soft_clamp=True,
                                    **rnn_kwargs)
            built.append(node.get_network())

        assert len(built) > 0, 'at least one node need to be a network to run training!'
        return built
    def matcher_model_creator(self, config, graph):
        """
        Create matcher models.

        One matcher is created per group of matching nodes.
        ``config['matching_nodes']`` is either ``'auto'`` (use
        ``graph.get_relation_node_names()``), a flat list of node names (one
        matcher), or a list of such lists (one matcher per list).
        ``config['matching_fit_nodes']`` defaults to the matching nodes. When it is
        given, it is used as-is and must provide at least one entry per matcher.

        Side effects:
        - sets ``self.matching_nodes``, ``self.matching_nodes_list``,
          ``self.matching_fit_nodes_list``, ``self.matching_nodes_fit_index_list``
          and ``self.matcher_num``;
        - for feed-forward matchers, appends to ``self.matcher_record_list`` and
          ``self.matcher_structure_list``;
        - removes node names listed in ``graph.nodata_node_names`` from each
          matching-node group, in place.

        :param config: configuration parameters
        :param graph: decision graph the matchers are built for
        :return: all the models (one matcher per matching-node group).
        :raises ValueError: if there are no matching nodes, if fewer fit-node groups
            than matching-node groups are given, or if ``config['matcher_type']`` is
            not a supported type.
        :raises DeprecationWarning: for ``matcher_type == 'hierarchical'``.
        """
        networks = []
        self.matching_nodes = graph.get_relation_node_names() if config['matching_nodes'] == 'auto' else config['matching_nodes']
        if len(self.matching_nodes) == 0:
            # fail with a clear message instead of an IndexError on [0] below
            raise ValueError('No matching nodes are configured, at least one is required to build a matcher!')
        if isinstance(self.matching_nodes[0], str):
            self.matching_nodes_list = [self.matching_nodes, ]
        else:
            self.matching_nodes_list = self.matching_nodes
        logger.info(f"matching_nodes_list : {self.matching_nodes_list}")
        self.matching_fit_nodes_list = self.matching_nodes_list if ("matching_fit_nodes" not in config.keys() or config['matching_fit_nodes'] == 'auto') else config['matching_fit_nodes']
        logger.info(f"matching_fit_nodes_list : {self.matching_fit_nodes_list}")
        self.matching_nodes_fit_index_list = []
        self.matcher_num = len(self.matching_nodes_list)
        if len(self.matching_fit_nodes_list) < self.matcher_num:
            # zip() below would silently drop matchers, leaving fewer networks than
            # self.matcher_num and misaligning the model slicing in mail_optimizer_creator.
            raise ValueError(f"matching_fit_nodes provides {len(self.matching_fit_nodes_list)} entries "
                             f"but matching_nodes configures {self.matcher_num} matchers!")

        for matching_nodes, matching_fit_nodes in zip(self.matching_nodes_list, self.matching_fit_nodes_list):
            input_dim = 0
            # del nodata node
            # NOTE(review): this mutates the group list in place, which may be the
            # very list stored in config['matching_nodes'].
            for node in graph.nodata_node_names:
                if node in matching_nodes:
                    logger.info(f"{matching_nodes}, {node}")
                    matching_nodes.remove(node)
            matching_nodes_fit_index = {}
            for node_name in matching_nodes:
                oral_node_name = node_name
                # "next_<name>" nodes use the fit entry of "<name>"
                if node_name.startswith("next_"):
                    node_name = node_name[5:]
                assert node_name in self._graph.fit.keys(), f"Node name {oral_node_name} not in {self._graph.fit.keys()}"
                # matcher input width = sum of the node's fit entries
                # (presumably a per-dimension mask -- TODO confirm)
                input_dim += np.sum(self._graph.fit[node_name])
                matching_nodes_fit_index[oral_node_name] = self._graph.fit[node_name]
            self.matching_nodes_fit_index_list.append(matching_nodes_fit_index)

            logger.info(f"Matcher nodes: {matching_nodes}, Input dim : {input_dim}")
            if config['matcher_type'] in ['mlp', 'res', 'transformer', 'ft_transformer']:
                matcher_network = FeedForwardMatcher(input_dim,
                                                     config['matcher_hidden_features'],
                                                     config['matcher_hidden_layers'],
                                                     norm=config['matcher_normalization'],
                                                     hidden_activation=config['matcher_activation'],
                                                     backbone_type=config['matcher_type'])
                matcher_network.matching_nodes = matching_nodes  # input to matcher
                matcher_network.matching_fit_nodes = matching_fit_nodes
                networks.append(matcher_network)
                self.matcher_record_list.append({"state_dicts": deque(maxlen=self.config["matcher_record_len"]),
                                                 "expert_scores": deque(maxlen=self.config["matcher_record_len"]),
                                                 "generated_scores": deque(maxlen=self.config["matcher_record_len"]),
                                                 })
                self.matcher_structure_list.append({"in_features": input_dim,
                                                    "hidden_features": config['matcher_hidden_features'],
                                                    "hidden_layers": config['matcher_hidden_layers'],
                                                    "norm": config['matcher_normalization'],
                                                    "hidden_activation": config['matcher_activation'],
                                                    "backbone_type": config['matcher_type'],
                                                    "matching_nodes": matching_nodes,
                                                    "matching_fit_nodes": matching_fit_nodes,
                                                    "matching_nodes_fit_index": matching_nodes_fit_index,
                                                    })
            elif config['matcher_type'] in ['gru', 'lstm']:
                # NOTE(review): recurrent matchers get no matcher_record_list /
                # matcher_structure_list entry -- confirm downstream code tolerates this.
                matcher_network = RecurrentMatcher(input_dim,
                                                   config['matcher_hidden_features'],
                                                   config['matcher_hidden_layers'],
                                                   backbone_type=config['matcher_type'],
                                                   bidirect=config['birnn'])
                matcher_network.matching_nodes = matching_nodes
                matcher_network.matching_fit_nodes = matching_fit_nodes
                networks.append(matcher_network)
            elif config['matcher_type'] == 'hierarchical':
                raise DeprecationWarning('This may not correctly work due to registration of functions and leaf nodes')
            else:
                # an unknown type previously produced no matcher at all, silently
                # breaking the matcher_num-based model slicing
                raise ValueError(f"Unsupported matcher_type: {config['matcher_type']}!")
        return networks
    def mail_generator_model_creator(self, config, graph):
        """
        Create generator models for MAIL training.

        Initializes the network of every ``'network'`` node of ``graph``: policy
        (non-transition) nodes first, then the nodes in ``graph.transition_map``.
        It then appends ``self.matcher_num`` value networks (``MLP`` with a single
        output). The value-network input width is the summed ``'input'`` dimension
        of the state nodes, which are ``graph.get_leaf()`` when
        ``config['state_nodes']`` is ``'auto'``. ``self.matcher_num`` must already
        be set, i.e. ``matcher_model_creator`` has to run first.
        :param config: configuration parameters
        :param graph: decision graph whose nodes are initialized
        :return: all the models: node networks followed by the value networks.
        """
        rnn_kwargs = {"rnn_hidden_features": config['rnn_hidden_features'],
                      "window_size": config['window_size']}
        dims = config['total_dims']
        built = []

        def _trainable_node(name):
            # Fetch the node; log and return None when it is not a 'network' node.
            candidate = graph.get_node(name)
            if candidate.node_type == 'network':
                return candidate
            logger.info(f'skip {name}, the node is already registered as {candidate.node_type}.')
            return None

        # policy (non-transition) nodes
        for name in list(graph.keys()):
            if name in graph.transition_map.values():
                continue
            node = _trainable_node(name)
            if node is None:
                continue
            in_dim = get_input_dim_from_graph(graph, name, dims)
            if hasattr(node, 'custom_node'):
                in_dim = get_input_dim_dict_from_graph(graph, name, dims)
            node.initialize_network(in_dim, dims[name]['output'],
                                    hidden_features=config['policy_hidden_features'],
                                    hidden_layers=config['policy_hidden_layers'],
                                    hidden_activation=config['policy_activation'],
                                    norm=config['policy_normalization'],
                                    backbone_type=config['policy_backbone'],
                                    dist_config=config['dist_configs'][name],
                                    is_transition=False,
                                    soft_clamp=True,
                                    **rnn_kwargs)
            built.append(node.get_network())

        # transition nodes
        for name in graph.transition_map.values():
            node = _trainable_node(name)
            if node is None:
                continue
            # name[5:] strips a 5-character prefix (presumably "next_")
            if name[5:] in node.input_names:
                mode = config['transition_mode']
            else:
                mode = 'global'
                if config['transition_mode']:
                    warnings.warn('Fallback to global transition mode since the transition variable is not provided as an input!')
            in_dim = get_input_dim_from_graph(graph, name, dims)
            if hasattr(node, 'custom_node'):
                in_dim = get_input_dim_dict_from_graph(graph, name, dims)
            in_dim_dict = get_input_dim_dict_from_graph(graph, name, dims)
            node.initialize_network(in_dim, dims[name]['output'],
                                    hidden_features=config['transition_hidden_features'],
                                    hidden_layers=config['transition_hidden_layers'],
                                    hidden_activation=config['transition_activation'],
                                    norm=config['transition_normalization'],
                                    backbone_type=config['transition_backbone'],
                                    dist_config=config['dist_configs'][name],
                                    is_transition=True,
                                    transition_mode=mode,
                                    obs_dim=dims[name]['input'],
                                    soft_clamp=True,
                                    input_dim_dict=in_dim_dict,
                                    **rnn_kwargs)
            built.append(node.get_network())

        assert len(built) > 0, 'at least one node need to be a network to run training!'

        # create value function: one value network per matcher
        state_nodes = graph.get_leaf() if config['state_nodes'] == 'auto' else config['state_nodes']
        value_in_dim = sum(dims[state_name]['input'] for state_name in state_nodes)
        built.extend(MLP(value_in_dim, 1, config['value_hidden_features'], config['value_hidden_layers'],
                         norm=config['value_normalization'], hidden_activation=config['value_activation'])
                     for _ in range(self.matcher_num))
        return built
[docs]
    def bc_optimizer_creator(self, models, config):
        """
        Build the optimizer used for behaviour-cloning training.

        :param models: models whose parameters are optimized jointly
        :param config: configuration parameters; ``config['bc_lr']`` is the Adam learning rate
        :return: a one-element list holding the Adam optimizer
        """
        bc_params = get_models_parameters(*models)
        bc_optimizer = torch.optim.Adam(bc_params, lr=config['bc_lr'])
        return [bc_optimizer]
[docs]
    def mail_optimizer_creator(self, models, config):
        """
        Optimizer creator for MAIL (adversarial) training.

        ``models`` is sliced as ``[*node_models, *value_nets, *matcher_nets]``, where the value nets
        and the matchers each occupy ``self.matcher_num`` trailing slots.

        :param models: node models, followed by value nets, followed by matchers
        :param config: configuration parameters (reads ``g_lr``, ``value_lr`` and ``d_lr``)
        :return: ``[bc_node_optimizer, generator_node_optimizer, generator_value_optimizer, matcher_optimizer]``
                 where ``bc_node_optimizer`` is always ``None``
        """
        n_matchers = self.matcher_num
        tail = 2 * n_matchers
        node_models = models[:-tail]
        value_nets = models[-tail:-n_matchers]
        matcher_nets = models[-n_matchers:]

        self.node_model_num = len(node_models)
        assert self.node_model_num == self.config['learning_nodes_num'], "self.node_model_num != self.config['learning_nodes_num']"

        # One Adam shared by all node models, and a separate Adam for all value nets.
        generator_node_optimizer = torch.optim.Adam(get_models_parameters(*node_models), lr=config['g_lr'])
        generator_value_optimizer = torch.optim.Adam(get_models_parameters(*value_nets), lr=config['value_lr'])
        # [ TRAINING D ] RMSProp is able to reduce the discriminator shocking
        matcher_optimizer = torch.optim.RMSprop(get_models_parameters(*matcher_nets), lr=config['d_lr'])

        # No dedicated BC optimizer is created for this phase, hence the leading None.
        return [None, generator_node_optimizer, generator_value_optimizer, matcher_optimizer]
    def _save_models(self, path: str, with_env:bool=True, save_best_graph:bool=True, model_prefixes:str=""):
        """
        Write the train/val graphs (and optionally the virtual environment) to ``path``.

        param: path, where to save the models
        param: with_env, whether to save venv along with the models
        param: save_best_graph, use ``best_graph_train``/``best_graph_val`` instead of
               ``graph_train``/``graph_val``; when False every network node of both graphs is also
               saved individually as ``<node>_train.pt`` / ``<node>_val.pt``
        param: model_prefixes, optional prefix (joined with ``_``) for the ``venv.pkl`` file name
        """
        prefix = model_prefixes + "_" if model_prefixes else model_prefixes

        if save_best_graph:
            src_train, src_val = self.best_graph_train, self.best_graph_val
        else:
            src_train, src_val = self.graph_train, self.graph_val
        graph_train = deepcopy(src_train).to("cpu")
        graph_val = deepcopy(src_val).to("cpu")
        graph_train.reset()
        graph_val.reset()

        if not save_best_graph:
            # Checkpoint: dump every network node of the train graph, then of the val graph.
            for graph, suffix in ((graph_train, '_train.pt'), (graph_val, '_val.pt')):
                for node_name in graph.keys():
                    node = graph.get_node(node_name)
                    if node.node_type != 'network':
                        continue
                    torch.save(deepcopy(node.get_network()).cpu(), os.path.join(path, node_name + suffix))

        if not with_env:
            return

        # Matcher records are attached to the venvs only when the first record holds saved state dicts.
        venv_kwargs = {}
        if hasattr(self, "matcher_record_list") and bool(self.matcher_record_list[0]['state_dicts']):
            venv_kwargs = {
                "matcher_record_list": self.matcher_record_list,
                "matcher_structure_list": self.matcher_structure_list,
                "matching_nodes_fit_index_list": self.matching_nodes_fit_index_list,
            }

        venv_train = VirtualEnvDev(graph_train, train_algo="REVIVE_FILTER", **venv_kwargs)
        torch.save(venv_train, os.path.join(path, "venv_train.pt"))
        venv_val = VirtualEnvDev(graph_val, train_algo="REVIVE_FILTER", **venv_kwargs)
        torch.save(venv_val, os.path.join(path, "venv_val.pt"))

        venv = VirtualEnv([venv_train, venv_val])
        with open(os.path.join(path, prefix + 'venv.pkl'), 'wb') as f:
            pickle.dump(venv, f)
    @catch_error
    def train_epoch(self):
        """
        Run one training epoch and return the logged metrics.

        * While ``epoch <= revive_epoch + matcher_pretrain_epoch`` the MAIL graph is trained. This is up
          to 10 rounds of ``d_steps`` matcher updates, stopping early once the matcher loss is <= 1.1.
          After the first ``matcher_pretrain_epoch`` epochs, ``g_steps`` generator updates per matcher
          follow, and the resulting rollout is evaluated with MAE / MSE.
        * While ``epoch <= bc_epoch``, behaviour cloning runs ``bc_steps`` updates on the train loader
          (scope ``'train'``) and then ``bc_steps`` updates on the val loader (scope ``'val'``).

        The stop flag is raised once both phases are over. Metric keys starting with ``'last'`` are
        dropped from the returned dict.
        """
        info = {}
        self._load_best_models()
        self._epoch_cnt += 1
        logger.info(f"Train Epoch : {self._epoch_cnt} ")
        # Train MAIL:
        if self._epoch_cnt <= self.config['revive_epoch'] + self.config['matcher_pretrain_epoch']:
            graph = self.graph_mail
            generator_other_nets = self.other_models_mail[:-self.matcher_num]  # value net(s)
            matchers = self.other_models_mail[-self.matcher_num:]
            generator_optimizers = self.mail_optimizers[1]
            other_generator_optimizers = self.mail_optimizers[2]  # value net optim(s)
            matcher_optimizer = self.mail_optimizers[-1]
            scope = "train"
            # [ TRAINING D ] avoid matcher loss shocking: up to 10 rounds of ``d_steps`` updates,
            # stopping as soon as the matcher loss is 1.1 or below.
            for _ in range(10):
                for _ in range(self.config['d_steps']):
                    self.global_step += 1
                    expert_data = self._sample_mixed_expert_batch()
                    _, _info = self._run_matcher(expert_data, graph, matchers, matcher_optimizer,
                                                 test=not self.config["matcher_sample"])
                    info.update({k + '_train': v for k, v in _info.items()})
                d_loss = info[f'{self.NAME}/matcher_loss_train']
                if d_loss <= 1.1:
                    break

            # [ TRAINING G ] skipped during the matcher pre-training epochs.
            generated_data = None
            if self._epoch_cnt > self.config['matcher_pretrain_epoch']:
                repeat = self.config['generator_data_repeat']
                for _ in range(self.config['g_steps']):
                    self.global_step += 1
                    expert_data = self._sample_mixed_expert_batch()
                    # Repeating expert data (along dim 1) for more generator data generation
                    for k, _ in expert_data.items():
                        expert_data[k] = torch.cat([expert_data[k] for _ in range(repeat)], dim=1)
                    for matcher_index in range(len(matchers)):
                        _info, generated_data = self._run_generator(deepcopy(expert_data),
                                                                    graph,
                                                                    generator_other_nets[matcher_index],
                                                                    matchers[matcher_index],
                                                                    generator_optimizers,
                                                                    other_generator_optimizers,
                                                                    matcher_index=matcher_index,
                                                                    scope=scope)
                        info.update({k + '_train': v for k, v in _info.items()})
            # [ Val G ] needs a generator rollout. None exists during matcher pre-training
            # (or with g_steps == 0); evaluating unconditionally used an unbound variable.
            if generated_data is not None:
                info.update(self._mae_test_for_MAIL(expert_data, generated_data, scope="MailModel_on_TrainData"))
                info.update(self._mse_test_for_MAIL(expert_data, generated_data, scope="MailModel_on_TrainData"))
        # [ TRAINING BC ] train-scope model on the train loader, then val-scope model on the val loader.
        if self._epoch_cnt <= self.config['bc_epoch']:
            info.update(self._run_bc_phase(self._train_loader_train, 'train', info))
            info.update(self._run_bc_phase(self._val_loader_train, 'val', info))
        info = self._early_stop(info)
        if self._epoch_cnt >= (self.config['revive_epoch'] + self.config['matcher_pretrain_epoch']) and self._epoch_cnt >= self.config['bc_epoch']:
            self._stop_flag = True
            info["stop_flag"] = self._stop_flag
        return {k: info[k] for k in info.keys() if not k.startswith('last')}

    def _sample_mixed_expert_batch(self):
        """
        Draw one batch from the train loader and one from the val loader, concatenate them key by key
        with ``torch.cat`` (default dim 0), move the result to ``self._device`` and return it.
        """
        expert_data_train = next(iter(self._train_loader_train))
        expert_data_val = next(iter(self._val_loader_train))
        expert_data = Batch({k: torch.cat([expert_data_train[k], expert_data_val[k]]) for k in expert_data_train.keys()})
        expert_data.to_torch(device=self._device)
        return expert_data

    def _run_bc_phase(self, loader, scope, info):
        """
        Run ``config['bc_steps']`` behaviour-cloning updates on batches drawn from ``loader``.

        :param loader: data loader to draw batches from
        :param scope: ``'train'`` or ``'val'``, forwarded to ``bc_train_batch`` in both dataset modes
        :param info: metrics collected so far this epoch; merged into every ``batch_info``
        :return: summary of the averaged metrics of this phase
        """
        metric_meters = AverageMeterCollection()
        for _ in range(self.config['bc_steps']):
            self.global_step += 1
            batch = next(iter(loader))  # [horizon=1, batch_size, dim]
            if self.config['rnn_dataset']:
                batch, loss_mask = batch
            else:
                loss_mask = torch.tensor([1.])
            loss_mask = loss_mask.to(self._device)
            batch.to_torch(device=self._device)
            batch_info = {
                "global_step": self.global_step,
            }
            batch_info.update(info)
            if self.config['rnn_dataset']:
                metrics = self.bc_train_batch(batch, batch_info=batch_info, scope=scope, loss_type=self.config['bc_loss'],
                                              dataset_mode="trajectory", loss_mask=loss_mask)
            else:
                # mini batch training on the flattened batch (all leading dims merged)
                batch = Batch({k: v.reshape(-1, v.shape[-1]) for k, v in batch.items()})
                batch_nums = max(len(batch) // 128, 1)  # argument for ``Batch.split``; at least 1
                for mini_batch in batch.split(batch_nums):
                    metrics = self.bc_train_batch(mini_batch, batch_info=batch_info, scope=scope, loss_type=self.config['bc_loss'])
            # Only the last mini-batch's metrics are recorded for this step.
            metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
        return metric_meters.summary()
    def _generate_rollout(self, expert_data, graph, matcher=None, test=False, generate_reward=False, sample_fn=None, clip=False,
                          mix_sample=False, mix_sample_ratio=0.5):
        """
        Roll ``graph`` out along ``expert_data`` using ``generate_rollout``.

        :param sample_fn: how to draw a value from a node distribution. By default this is
                          ``dist.mode`` when ``test`` is set, otherwise ``dist.rsample()`` if
                          ``config['generator_algo'] == 'svg'`` and ``dist.sample()`` in all other cases.
        :param generate_reward: when set, the rollout is generated with the ``replace_with_expert`` /
                                ``replace_ratio`` config options and then passed, with ``matcher``, through
                                ``_generate_rewards``. Otherwise ``mix_sample`` / ``mix_sample_ratio`` are
                                forwarded instead.
        :return: the generated rollout
        """
        if sample_fn is None:
            def _default_sample(dist):
                if test:
                    return dist.mode
                if self.config['generator_algo'] == 'svg':
                    return dist.rsample()
                return dist.sample()
            sample_fn = _default_sample

        # presumably the horizon (train_epoch documents batches as [horizon, batch_size, dim]) -- confirm
        steps = expert_data.shape[0]
        if not generate_reward:
            return generate_rollout(expert_data, graph, steps, sample_fn, self.adapt_stds, clip,
                                    mix_sample=mix_sample, mix_sample_ratio=mix_sample_ratio)

        rollout = generate_rollout(expert_data, graph, steps, sample_fn, self.adapt_stds, clip,
                                   replace_with_expert=self.config['replace_with_expert'],
                                   replace_ratio=self.config['replace_ratio'])
        return self._generate_rewards(rollout, expert_data, graph, matcher)
    def _run_matcher(self, expert_data, graph, matchers, matcher_optimizer=None, test=False, scope=None):
        """
        Score expert and generated data with every matcher (discriminator). If an optimizer is given,
        take one step on the matcher loss averaged over all matchers.

        :param expert_data: batch of expert data
        :param graph: decision graph used to generate the rollout
        :param matchers: list of matcher networks
        :param matcher_optimizer: optimizer over the matcher parameters; ``None`` only evaluates
        :param test: forwarded to ``_generate_rollout`` (where it selects the distribution mode)
        :param scope: unused here
        :return: ``(generated_data, info)``. ``info`` holds per-matcher metrics (suffixed with the
                 matcher index) plus metrics averaged over all matchers.
        """
        # [ TRAINING D ] the same setting as policy learning: data should be the same as policy learn
        generated_data = self._generate_rollout(expert_data,
                                                graph,
                                                test=test,
                                                clip=1.5,
                                                mix_sample=self.config['mix_sample'],
                                                mix_sample_ratio=self.config['mix_sample_ratio'])
        # Keep only samples for which every "<node>_isnan_index_" flag of the matching nodes is zero.
        isnan_index_list = [expert_data[node_name + "_isnan_index_"]
                            for matching_nodes in self.matching_nodes_list
                            for node_name in matching_nodes
                            if node_name + "_isnan_index_" in expert_data.keys()]
        if isnan_index_list:
            isnan_index = torch.mean(torch.cat(isnan_index_list, axis=-1), axis=-1)
            expert_data = expert_data[isnan_index == 0]
            generated_data = generated_data[isnan_index == 0]
        if self.config["mix_data"]:
            expert_data, generated_data = self.mix_data_process(expert_data, generated_data, matchers, matcher_optimizer)

        # compute matcher score
        cw = self.config['controller_weight']
        expert_scores, generated_scores, matcher_losses, grad_losses = [], [], [], []
        expert_entropies, generated_entropies = [], []
        matcher_dict = dict()
        for matcher_index, matcher in enumerate(matchers):
            expert_score = self._get_matcher_score(expert_data, matcher)
            generated_score = self._get_matcher_score(generated_data, matcher)
            expert_scores.append(deepcopy(expert_score.detach()))
            generated_scores.append(deepcopy(generated_score.detach()))
            if self._epoch_cnt % self.config['matcher_record_interval'] == 0:
                record = self.matcher_record_list[matcher_index]
                record['expert_scores'].append(expert_score.detach().cpu().numpy().mean().item())
                record['generated_scores'].append(generated_score.detach().cpu().numpy().mean().item())
                record['state_dicts'].append({k: v.detach().cpu() for k, v in matcher.state_dict().items()})
            real = torch.ones_like(expert_score)
            fake = torch.zeros_like(generated_score)
            expert_entropy = torch.distributions.Bernoulli(expert_score).entropy().mean()
            generated_entropy = torch.distributions.Bernoulli(generated_score).entropy().mean()
            expert_entropies.append(expert_entropy.item())
            generated_entropies.append(generated_entropy.item())
            entropy_loss = -self.config['discr_ent_coef'] * (expert_entropy + generated_entropy)
            l2_norm_loss = self.net_l2_norm(matcher) * self.config['matcher_l2_norm_coeff']

            # Cross-entropy objective: push expert scores towards 1 and generated scores towards 0.
            expert_term = torch.log(expert_score + 1e-8)
            generated_term = torch.log(1 - generated_score + 1e-8)
            if cw > 0:
                # Optional controller term: each log term gets -cw/2 * s^2 + cw/2 * s added.
                expert_term = expert_term - cw / 2 * torch.pow(expert_score, 2) + cw / 2 * expert_score
                generated_term = generated_term - cw / 2 * torch.pow(generated_score, 2) + cw / 2 * generated_score
            loss_discriminator = (-real * expert_term - (1 - fake) * generated_term).mean()
            loss_discriminator = loss_discriminator + entropy_loss + l2_norm_loss

            total_loss = loss_discriminator
            # gradient penalty on the squared norm of d(loss)/d(matcher parameters)
            if self.config['gp_coef'] > 0:
                # torch.autograd.grad returns the gradients without accumulating into ``.grad``, so
                # no zero_grad is needed here, which also keeps ``matcher_optimizer=None`` working.
                grads = torch.autograd.grad(outputs=loss_discriminator, inputs=matcher.parameters(), create_graph=True)
                grad_loss = 0
                for grad in grads:
                    grad_loss += torch.pow(grad, 2).sum()
                grad_losses.append(grad_loss)
                total_loss = loss_discriminator + self.config['gp_coef'] * grad_loss
            matcher_losses.append(total_loss)

            matcher_dict.update({
                f"{self.NAME}/expert_score_{matcher_index}": expert_score.mean().item(),
                f"{self.NAME}/generated_score_{matcher_index}": generated_score.mean().item(),
                f"{self.NAME}/expert_acc_{matcher_index}": (expert_score > 0.5).float().mean().item(),
                f"{self.NAME}/generated_acc_{matcher_index}": (generated_score < 0.5).float().mean().item(),
                f"{self.NAME}/expert_entropy_{matcher_index}": expert_entropy.item(),
                f"{self.NAME}/generated_entropy_{matcher_index}": generated_entropy.item(),
                f"{self.NAME}/matcher_loss_{matcher_index}": loss_discriminator.item(),
            })

        # Aggregates: element-wise mean over matchers for scores, mean of per-matcher values otherwise.
        expert_score = sum(expert_scores) / len(expert_scores)
        generated_score = sum(generated_scores) / len(generated_scores)
        matcher_loss = sum(matcher_losses) / len(matcher_losses)
        info = {}
        if matcher_optimizer is not None:
            matcher_optimizer.zero_grad(set_to_none=True)
            matcher_loss.backward()
            matcher_grad_norm = nn.utils.clip_grad_norm_(get_models_parameters(*matchers), self.config["matcher_grad_norm_clip"])
            if torch.any(torch.isnan(matcher_grad_norm)):
                self.nan_in_grad()
                logger.info(f'Detect nan in gradient, skip this batch! (loss : {matcher_loss}, grad_norm : {matcher_grad_norm})')
            else:
                matcher_optimizer.step()
            info[f"{self.NAME}/matcher_grad_norm"] = matcher_grad_norm.item()
        info.update({
            f"{self.NAME}/expert_score": expert_score.mean().item(),
            f"{self.NAME}/generated_score": generated_score.mean().item(),
            f"{self.NAME}/expert_acc": (expert_score > 0.5).float().mean().item(),
            f"{self.NAME}/generated_acc": (generated_score < 0.5).float().mean().item(),
            f"{self.NAME}/expert_entropy": sum(expert_entropies) / len(expert_entropies),
            f"{self.NAME}/generated_entropy": sum(generated_entropies) / len(generated_entropies),
            f"{self.NAME}/matcher_loss": matcher_loss.item(),
        })
        if self.config['gp_coef'] > 0:
            grad_loss = sum(grad_losses) / len(grad_losses)
            info.update({
                f"{self.NAME}/grad_penalty_loss": grad_loss.item(),
            })
        info.update(matcher_dict)
        return generated_data, info
    
[docs]
    def PPO_step(self, 
                 generated_data, 
                 graph, 
                 value_net, 
                 generator_optimizers, 
                 matcher, 
                 other_generator_optimizers,
                 epsilon=0.1, 
                 lam=0, 
                 w_ent=0, 
                 matcher_index=None, 
                 scope=None):
        """
        One PPO update for a single matcher.

        First the value net is regressed onto ``generated_data.returns``. Then the network nodes listed in
        ``matcher.matching_fit_nodes`` are updated with the clipped surrogate objective.

        :param generated_data: generated trajectory; read for ``returns``, ``advs``, each fitted node's
                               values and its ``<node>_log_prob`` entry
        :param graph: decision graph
        :param value_net: value net, evaluated on the concatenation of ``self.state_nodes``
        :param generator_optimizers: optimizer for the graph's models (``None`` skips the policy step)
        :param matcher: matcher whose ``matching_fit_nodes`` selects the nodes to update
        :param other_generator_optimizers: optimizer for the value net (``None`` skips the value step)
        :param epsilon: hyperparameter for clipping in the policy objective
        :param lam: unused here
        :param w_ent: the weight of the entropy bonus
        :param matcher_index: suffix appended to the logged metric names
        :param scope: must be ``"train"`` or ``"val"``
        :return: dict of logged metrics
        """
        assert scope in ["train", "val"], f"scope : {scope}"
        info = {}

        # ---- critic: regress the value net onto the rollout returns ----
        critic_input = get_concat_traj(generated_data, self.state_nodes).detach()
        value_o = value_net(critic_input)
        value_residual = value_o - generated_data.returns.detach()
        v_loss = value_residual.pow(2).mean() + self.net_l2_norm(value_net) * self.config['value_l2_norm_coef']
        if other_generator_optimizers is not None:
            other_generator_optimizers.zero_grad(set_to_none=True)
            v_loss.backward()
            value_grad_norm = nn.utils.clip_grad_norm_(get_models_parameters(value_net), 100)
            if torch.any(torch.isnan(value_grad_norm)):
                self.nan_in_grad()
                logger.info(f'Detect nan in gradient, skip this batch! (loss : {v_loss}, grad_norm : {value_grad_norm})')
            else:
                other_generator_optimizers.step()

        # ---- actor: recompute log-probs for the fitted network nodes ----
        graph.reset()
        old_log_probs, new_log_probs, stds = [], [], []
        e_loss = 0
        l2_norm_loss = 0
        fit_nodes = matcher.matching_fit_nodes
        for node_name, adapt_std in zip(list(graph.keys()), self.adapt_stds):
            if node_name not in fit_nodes:
                continue  # node is not fitted against this matcher
            node = graph.get_node(node_name)
            if node.node_type != 'network':
                continue
            dist = graph.compute_node(node_name, generated_data, adapt_std=adapt_std)
            entropy = dist.entropy()
            stds.append(dist.std.mean().item())
            e_loss += entropy.mean()
            l2_norm_loss += self.net_l2_norm(node.network)
            info[node_name + '_entropy'] = entropy.mean().item()
            old_log_probs.append(generated_data[node_name + '_log_prob'])
            new_log_probs.append(dist.log_prob(generated_data[node_name]).unsqueeze(dim=-1))

        prob_old = torch.cat(old_log_probs, dim=-1)
        prob_new = torch.cat(new_log_probs, dim=-1)
        # Cap the log-ratio at 40 before exponentiating so exp() cannot overflow.
        ratio = torch.exp(torch.clamp_max(prob_new - prob_old.detach(), 40))
        kl = (prob_old - prob_new).detach()
        advs = generated_data.advs
        unclipped_obj = ratio * advs
        clipped_obj = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advs
        p_loss = - torch.min(unclipped_obj, clipped_obj).mean()
        total_loss = p_loss - e_loss * w_ent + self.config['generator_l2_norm_coef'] * l2_norm_loss

        logged = (
            ("v_loss", v_loss.mean().item()),
            ("p_loss", p_loss.mean().item()),
            ("e_loss", e_loss.mean().item()),
            ("total_loss", total_loss.mean().item()),
            ("ratio", ratio.mean().item()),
            ("value", value_o.mean().item()),
            ("return", generated_data.returns.mean().item()),
            ("advantages", advs.mean().item()),
            ("std", np.mean(stds)),
            ("kl", kl.mean().item()),
            ("max_kl", kl.max().item()),
            ("min_kl", kl.min().item()),
            ("max_logp_new", prob_new.max().item()),
            ("min_logp_new", prob_new.min().item()),
            ("max_logp_old", prob_old.max().item()),
            ("min_logp_old", prob_old.min().item()),
            ("mean_logp_old", prob_old.mean().item()),
            ("mean_logp_new", prob_new.mean().item()),
        )
        info.update({f"{metric}_{matcher_index}": value for metric, value in logged})

        if generator_optimizers is not None:
            generator_optimizers.zero_grad(set_to_none=True)
            total_loss.backward()
            policy_grad_norm = nn.utils.clip_grad_norm_(get_models_parameters(*graph.collect_models()), 100)
            if torch.isnan(policy_grad_norm.mean()):
                self.nan_in_grad()
                logger.info(f'Detect nan in gradient, skip this batch! (loss : {total_loss}, grad_norm : {policy_grad_norm})')
            else:
                generator_optimizers.step()
            info["ppo_grad_norm"] = policy_grad_norm.item()
        return info
    def _mae_test_for_MAIL(self, expert_data, generated_data, scope="MailModel_on_TrainData"):
        """
        Mean-absolute-error evaluation of ``self.graph_mail``.

        * one-step MAE: starting from the expert values of the graph's leaf nodes, every node is computed
          in ``graph.keys()`` order from the values built so far (network nodes use their distribution
          mode) and compared with ``expert_data``;
        * rollout MAE: ``generated_data`` is compared directly with ``expert_data``.

        Only nodes in ``graph.metric_nodes`` are scored. Each error is summed over the last axis and
        averaged over the rest, and the totals are divided by ``self.total_dim``.

        :return: dict with ``{NAME}/<node>_one_step_mae_<scope>``, ``{NAME}/average_one_step_mae_<scope>``,
                 ``{NAME}/<node>_rollout_mae_<scope>`` and ``{NAME}/average_rollout_mae_<scope>`` entries
        """
        def _valid_weight(node_name):
            # NOTE(review): this is a *scalar* (1 - mean of the node's "_isnan_index_" flags, i.e. the
            # fraction of valid entries if the flags are 0/1), not an element-wise mask -- confirm intended.
            key = node_name + "_isnan_index_"
            if key in expert_data.keys():
                return 1 - torch.mean(expert_data[key])
            return None

        info = {}
        graph = self.graph_mail
        graph.reset()
        # Pure evaluation: every result is converted with .item(), so no autograd graph is needed.
        with torch.no_grad():
            new_data = Batch({name: expert_data[name] for name in graph.leaf})
            total_mae = 0
            for node_name in graph.keys():
                is_network = graph.get_node(node_name).node_type == 'network'
                output = graph.compute_node(node_name, new_data)
                new_data[node_name] = output.mode if is_network else output
                if not is_network or node_name not in graph.metric_nodes:
                    continue
                weight = _valid_weight(node_name)
                diff = new_data[node_name] - expert_data[node_name]
                if weight is not None:
                    diff = diff * weight
                node_mae = diff.abs().sum(dim=-1).mean()
                total_mae += node_mae.item()
                info[f"{self.NAME}/{node_name}_one_step_mae_{scope}"] = node_mae.item()
            info[f"{self.NAME}/average_one_step_mae_{scope}"] = total_mae / self.total_dim

            mae_error = 0
            for node_name in graph.keys():
                if node_name not in graph.metric_nodes:
                    continue
                weight = _valid_weight(node_name)
                abs_err = torch.abs(expert_data[node_name] - generated_data[node_name])
                if weight is not None:
                    abs_err = abs_err * weight
                policy_shooting_error = abs_err.sum(dim=-1).mean()
                mae_error += policy_shooting_error.item()
                info[f"{self.NAME}/{node_name}_rollout_mae_{scope}"] = policy_shooting_error.item()
            info[f"{self.NAME}/average_rollout_mae_{scope}"] = mae_error / self.total_dim
        return info
    def _mse_test_for_MAIL(self, expert_data, generated_data, scope="MailModel_on_TrainData"):
        """
        Mean-squared-error evaluation of ``self.graph_mail``.

        * one-step MSE: starting from the expert values of the graph's leaf nodes, every node is computed
          in ``graph.keys()`` order from the values built so far (network nodes use their distribution
          mode) and compared with ``expert_data``;
        * rollout MSE: ``generated_data`` is compared directly with ``expert_data`` for each node in
          ``graph.metric_nodes``.

        Only metric nodes are scored. Each squared error is summed over the last axis and averaged over
        the rest, and the totals are divided by ``self.total_dim``.

        :return: dict with ``{NAME}/<node>_one_step_mse_<scope>``, ``{NAME}/average_one_step_mse_<scope>``,
                 ``{NAME}/<node>_rollout_mse_<scope>`` and ``{NAME}/average_rollout_mse_<scope>`` entries
        """
        def _valid_weight(node_name):
            # NOTE(review): this is a *scalar* (1 - mean of the node's "_isnan_index_" flags, i.e. the
            # fraction of valid entries if the flags are 0/1), not an element-wise mask -- confirm intended.
            key = node_name + "_isnan_index_"
            if key in expert_data.keys():
                return 1 - torch.mean(expert_data[key])
            return None

        info = {}
        graph = self.graph_mail
        graph.reset()
        # Pure evaluation: every result is converted with .item(), so no autograd graph is needed.
        with torch.no_grad():
            new_data = Batch({name: expert_data[name] for name in graph.leaf})
            total_mse = 0
            for node_name in graph.keys():
                is_network = graph.get_node(node_name).node_type == 'network'
                output = graph.compute_node(node_name, new_data)
                new_data[node_name] = output.mode if is_network else output
                if not is_network or node_name not in graph.metric_nodes:
                    continue
                weight = _valid_weight(node_name)
                diff = new_data[node_name] - expert_data[node_name]
                if weight is not None:
                    diff = diff * weight
                node_mse = (diff ** 2).sum(dim=-1).mean()
                total_mse += node_mse.item()
                info[f"{self.NAME}/{node_name}_one_step_mse_{scope}"] = node_mse.item()
            info[f"{self.NAME}/average_one_step_mse_{scope}"] = total_mse / self.total_dim

            mse_error = 0
            for node_name in graph.metric_nodes:
                weight = _valid_weight(node_name)
                diff = expert_data[node_name] - generated_data[node_name]
                if weight is not None:
                    diff = diff * weight
                policy_rollout_mse = (diff ** 2).sum(dim=-1).mean()
                mse_error += policy_rollout_mse.item()
                info[f"{self.NAME}/{node_name}_rollout_mse_{scope}"] = policy_rollout_mse.item()
            info[f"{self.NAME}/average_rollout_mse_{scope}"] = mse_error / self.total_dim
        return info