"""
    POLIXIR REVIVE, copyright (C) 2021-2025 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.
    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""
import importlib
import json
import os
import shutil
import sys
import warnings
from copy import deepcopy
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
from matplotlib.pyplot import step
import time
import numpy as np
import pandas as pd
import ray
import revive
import torch
import yaml
from loguru import logger
from revive.computation.funs_parser import parser
from revive.computation.graph import *
from revive.data.batch import Batch
from revive.data.processor import DataProcessor
from revive.utils.common_utils import find_later, load_data, plot_traj, PositionalEncoding, create_unit_vector, generate_bin_encoding
DATADIR = os.path.abspath(os.path.join(os.path.dirname(revive.__file__), '../data/'))
BATCH_SIZE = "*batch_size"
class OfflineDataset(torch.utils.data.Dataset):
    r"""An offline dataset class.
    Args:
        :data_file: The file path where the training dataset is stored.
        :config_file: The file path where the data description (`.yaml`) file is stored.
        :revive_config: The REVIVE training config dict (used e.g. for step/trajectory-id embeddings).
        :ignore_check: If True, downgrade strict data-check errors to warnings.
        :horizon: Length of the trajectory clips returned in trajectory mode.
        :reward_func: Optional function used to pre-compute `_data_reward_` from the data.
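
    Example (a minimal sketch; the file names are placeholders, and `revive_config`
    may be required when the yaml marks `step_input` or `traj_id` nodes)::

        dataset = OfflineDataset(data_file='data.npz', config_file='config.yaml', horizon=10)
        batch = dataset[0]  # a `Batch` holding one clip of data for each node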
    """
    def __init__(self, data_file : str,
                       config_file : str,
                       revive_config : dict = None,
                       ignore_check : bool = False,
                       horizon : int = None,
                       reward_func = None):
        self.data_file = data_file
        self.config_file = config_file
        self.ignore_check = ignore_check
        self.reward_func = reward_func
        self.revive_config = revive_config
        self._raw_data = None
        self._raw_config = None
        self._pre_data(self.config_file, self.data_file)
        self._load_config(self.config_file)
        self._load_data(self.data_file)
        self.graph.ts_nodes = self.ts_nodes
        self.graph.freeze_nodes = getattr(self, "freeze_nodes", [])
        self.graph.nodes_loss_type = getattr(self, "nodes_loss_type", {})
        
        # pop unused keys
        used_keys = list(self.graph.keys()) + self.graph.leaf + ['done']
        used_keys += list(self.ts_nodes.values())
        used_keys += list([key for key in self.data.keys() if key.endswith("_isnan_index_")])
        for key in list(self.data.keys()):
            if key not in used_keys:
                try:
                    self.data.pop(key)
                    self.raw_columns.pop(key)
                    self.data_configs.pop(key)
                    self.orders.pop(key)
                    warnings.warn(f'Warning: popping unused key: {key}')
                except Exception as e:
                    logger.info(f"{e}")
        self._check_data()
        # construct the data processor
        self.processing_params = {k : self._get_process_params(self.data.get(k, None), self.data_configs[k], self.orders[k]) for k in self.data_configs.keys()}
        for curr_name, next_name in self.graph.transition_map.items():
            self.processing_params[next_name] = self.processing_params[curr_name]
            self.orders[next_name] = self.orders[curr_name]
            self.data_configs[next_name] = self.data_configs[curr_name]
        self.processor = DataProcessor(self.data_configs, self.processing_params, self.orders)
        self.graph.register_processor(self.processor)
        # separate tunable parameter from self.data
        if len(self.graph.tunable) > 0:
            self.tunable_data = Batch()
            for tunable in self.graph.tunable:
                self.tunable_data[tunable] = self.data.pop(tunable)
        else:
            self.tunable_data = None
        self.learning_nodes_num = self.graph.summary_nodes()['network_nodes']
        # by default, dataset is in trajectory mode
        self.trajectory_mode_(horizon)
        if self.reward_func is not None:
            self.data.to_torch()
            self.data["_data_reward_"] = to_numpy(self.reward_func(self.data))
            self.data.to_numpy()
    def _pre_data(self, config_file, data_file):
        # parse ts_nodes
        with open(config_file, 'r', encoding='UTF-8') as f:
            raw_config = yaml.load(f, Loader=yaml.FullLoader)
        raw_data = load_data(data_file)
        
        for n in raw_config['metadata']['graph'].keys():
            if raw_config['metadata']['graph'][n] is None:
                raw_config['metadata']['graph'][n] = []
        
        graph_dict = raw_config['metadata'].get('graph', None)
        nodes_config = raw_config['metadata'].get('nodes', None)
        self.ts_nodes = {}
        self.ts_node_frames = {}
        self.ts_frames_config = {}
        use_step_node = False
        self.expert_functions = raw_config['metadata'].get('expert_functions',{})
        if nodes_config:
            ############################ append step node ############################
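            # A node marked with `step_input` gets a per-timestep feature appended to its
            # inputs: the raw step index within the trajectory, or, when
            # `use_time_step_embed` is enabled in the config, a sinusoidal positional
            # embedding of size `time_step_embed_size` with values in [-1, 1].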
            step_node_config = {k:v["step_input"] for k,v in nodes_config.items() if "step_input" in v.keys() }
            if step_node_config:
                pos_emd_flag = self.revive_config["use_time_step_embed"]
                assert "step_node_" not in raw_data.keys()
                trj_index = [0,] + list(raw_data["index"])
                max_traj_len = np.max(np.array(trj_index[1:]) - np.array(trj_index[:-1]))
                step_node_data_list = []
                if pos_emd_flag:
                    d_model = self.revive_config["time_step_embed_size"]
                    pos_emd = PositionalEncoding(d_model=d_model, max_len=max_traj_len)
                for trj_start_index, trj_end_index in zip(trj_index[:-1], trj_index[1:]):
                    step_index = np.arange(0, trj_end_index-trj_start_index).reshape(-1,1)
                    if pos_emd_flag:
                        step_index = pos_emd.encode(step_index)
                    step_node_data_list.append(step_index)
                raw_data["step_node_"] = np.concatenate(step_node_data_list, axis=0)
                raw_config['metadata']['columns'] += [{'step_node_': {'dim': 'step_node_', 'type': 'continuous', 'min': 0, 'max': max_traj_len*2}},] if not pos_emd_flag else [{f'step_node_{i}': {'dim': 'step_node_', 'type': 'continuous', 'min': -1, 'max': 1}} for i in range(d_model)]
                for node, step_node_flag in step_node_config.items():
                    if node not in graph_dict.keys():
                        continue 
                    if step_node_flag:
                        logger.info(f"Append step_node data as the {node} node's input.")
                        graph_dict[node] = graph_dict[node] + ["step_node_",]
                        use_step_node = True
                # append step_node_function
                if use_step_node:
                    if not pos_emd_flag:
                        raw_config['metadata']['graph']['next_step_node_'] = ['step_node_']
                        # self.expert_functions = raw_config['metadata'].get('expert_functions',{})
                        shutil.copyfile(os.path.join(os.path.dirname(revive.__file__),"./common/step_node_function.py"),  os.path.join(os.path.dirname(self.config_file), "./step_node_function.py"))
                        self.expert_functions['next_step_node_'] = {"node_function": "step_node_function.get_next_step_node"}
                self._raw_config = raw_config
                self._raw_data = raw_data
                    
            ############################ append ts ###################################
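            # A node configured with `ts` is expanded into a `ts_<node>` variable that stacks
            # the current frame with its previous `ts` frames along the feature axis.
            # In "repeat" mode the first frame of each trajectory is repeated to pad the start;
            # in "sequence" mode the first `max_ts_frames` steps of each trajectory are dropped
            # instead and the trajectory indexes are shortened accordingly.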
            # collect nodes config
            ts_frames_config = {}
            for k,v in nodes_config.items():                    
                if "ts" in v.keys():
                    if "ts_endpoint" not in v.keys():
                        ts_endpoint = v["ts"] + 1
                    else:
                        ts_endpoint = v["ts_endpoint"]
                    
                    if "ts_repeat" not in v.keys():
                        ts_mode = "repeat"
                    else:
                        if v["ts_repeat"]:
                            ts_mode = "repeat"
                        else:
                            ts_mode = "sequence"
                        
                    assert ts_endpoint <= (v["ts"]+1) and ts_endpoint >= 1
                    assert ts_mode in ["repeat", "sequence"], f'ts_mode -> {ts_mode} not in ["repeat", "sequence"]'
                    
                    ts_frames_config["ts_"+k] = [v["ts"], ts_endpoint, ts_mode]
            
            if ts_frames_config:
                max_ts_frames = max([v[0] for k,v in ts_frames_config.items()])
                if "sequence" in [v[2] for k,v in ts_frames_config.items()]:
                    ts_mode = "sequence"
                else:
                    ts_mode = "repeat"
                nodes = list(graph_dict.keys())
                for output_node in list(graph_dict.keys()):
                    nodes += list(graph_dict[output_node])
                nodes = list(set(nodes))
                # map ts_/next_ts_ node names back to their base node names
                for index, node in enumerate(nodes):
                    if node.startswith("next_ts_"):
                        nodes[index] = "next_" + node[8:]
                    if node.startswith("ts_"):
                        nodes[index] = node[3:]     
                ts_nodes = {"ts_"+node:node for node in nodes if "ts_"+node in ts_frames_config.keys()}
                for ts_node,node in ts_nodes.items():
                    if node not in raw_data.keys():
                        logger.error(f"Can't find '{node}' node data.")
                        sys.exit()
                self.ts_nodes = ts_nodes
                
                # ts_node npz data
                trj_index = [0,] + list(raw_data["index"])
                new_data = {k:[] for k in ts_nodes.keys()}
                new_index = []
                i = 0
                for trj_start_index, trj_end_index in zip(trj_index[:-1], trj_index[1:]):
                    i += 1
                    first_ts_node_flag = True
                    for ts_node,node in ts_nodes.items():
                        ts_node_frames, ts_endpoint, ts_mode = ts_frames_config[ts_node]
                        assert ts_node_frames >= 1
                        self.ts_node_frames[ts_node] = ts_node_frames
                        self.ts_frames_config[ts_node] = ts_frames_config[ts_node]
                        pad_data = np.concatenate([np.repeat(raw_data[node][trj_start_index:trj_start_index+1],repeats=ts_node_frames,axis=0), raw_data[node][trj_start_index:trj_end_index]])
                        if ts_mode != "repeat":
                            assert trj_end_index - trj_start_index > ts_node_frames
                            if first_ts_node_flag:
                                new_index.append(trj_end_index-(max_ts_frames*i))
                            ts_node_data = np.concatenate([pad_data[i:i+(trj_end_index-trj_start_index)] for i in range(ts_node_frames+1)], axis=1)
                            new_data[ts_node].append(ts_node_data[max_ts_frames:])
                            
                            if first_ts_node_flag:
                                for k,v in raw_data.items():
                                    if k == "index":
                                        continue
                                    if k not in new_data.keys():
                                        new_data[k] = []
                                    new_data[k].append(raw_data[k][trj_start_index:trj_end_index][max_ts_frames:])
                            first_ts_node_flag = False
        
                        else:
                            new_data[ts_node].append(np.concatenate([pad_data[i:i+(trj_end_index-trj_start_index)] for i in range(ts_node_frames+1)], axis=1))
                
   
                new_index = list(sorted(set(new_index)))
                new_data = {k:np.concatenate(v,axis=0) for k,v in new_data.items()}
                for k,v in new_data.items():
                    if k in ts_frames_config.keys():
                        node = k[3:]
                        ts_endpoint = ts_frames_config[k][1]
                        new_data[k]= v[:,:ts_endpoint*raw_data[node].shape[-1]]
                if new_index:
                    logger.info("Ts_mode : {Sequence}")
                    new_data["index"] = np.array(new_index)  
                                  
                raw_data.update(new_data)
                # ts_node columns
                for ts_node, node in ts_nodes.items():
                    ts_node_frames, ts_endpoint, ts_mode = ts_frames_config[ts_node]
                    node_columns = [c for c in raw_config['metadata']['columns'] if list(c.values())[0]["dim"] == node]
                    ts_node_columns = []
                    ts_index = 0
                    for ts_index in range(ts_endpoint):
                        for node_column in node_columns:
                            node_column_value = deepcopy(list(node_column.values())[0])
                            node_column_name = deepcopy(list(node_column.keys())[0])
                            node_column_value["dim"] = ts_node
                            if ts_index+1 < ts_node_frames:
                                node_column_value["fit"] = False
                            ts_node_columns.append({"ts_"+str(ts_index)+"_"+node_column_name:node_column_value})
                    raw_config['metadata']['columns'] += ts_node_columns
                self._raw_config = raw_config
                self._raw_data = raw_data
            ##################################################################
            ############################ trajectory_id #################
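            # A node marked with `traj_id` gets a trajectory identifier appended to its inputs:
            # either the scalar trajectory index, or, when `use_traj_id_embed` is enabled in
            # the config, a fixed-length binary encoding of that index.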
            # collect nodes config
            traj_id_node_config = {k:v["traj_id"] for k,v in nodes_config.items() if "traj_id" in v.keys()}
            if traj_id_node_config:
                traj_id_emd_flag = self.revive_config["use_traj_id_embed"]
                traj_index = [0,] + list(raw_data["index"])
                start_indexes = traj_index[:-1]
                end_indexes = traj_index[1:]
                traj_num = len(start_indexes)
                if traj_id_emd_flag:
                    traj_encodings = generate_bin_encoding(traj_num)
                traj_ids = []
                for i_traj in range(traj_num):
                    if traj_id_emd_flag:
                        traj_ids.append(np.ones((end_indexes[i_traj] - start_indexes[i_traj], 1)) * traj_encodings[i_traj])
                    else:
                        traj_ids.append(np.ones((end_indexes[i_traj] - start_indexes[i_traj], 1)) * i_traj)
                raw_data['traj'] = np.concatenate(traj_ids)
                raw_config['metadata']['columns'] += [{'traj': {'dim': 'traj', 'type': 'continuous', 'min': 0, 'max': len(start_indexes)-1}},] if not traj_id_emd_flag else [{f'traj_{i}': {'dim': 'traj', 'type': 'discrete', 'min': 0, 'max': 1, 'num': 2}} for i in range(traj_encodings.shape[1])]
                for node, _ in traj_id_node_config.items():
                    raw_config['metadata']['graph'][node].append('traj')
                self._raw_config = raw_config
                self._raw_data = raw_data
            ##################################################################
            ############################ collect freeze node #################
            # collect nodes config
            self.freeze_nodes = [k for k,v in nodes_config.items() if "freeze" in v.keys() and v["freeze"]]
            logger.info("Freeze the following nodes without training -> {self.freeze_nodes}")
            self.nodes_loss_type = {k:v["loss_type"] for k,v in nodes_config.items() if "loss_type" in v.keys() and v["loss_type"]}
            logger.info(f"Nodes loss type -> {self.nodes_loss_type}")
            ##################################################################
        ############################# delta_node ######################
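        # If the graph references a `delta_<x>` node that is absent from the data, and both
        # `<x>` and `next_<x>` are present with all-continuous columns, the delta is computed
        # automatically as `next_<x> - <x>` and appended to the columns description.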
        nodes = list(graph_dict.keys())
        for output_node in list(graph_dict.keys()):
            nodes += list(graph_dict[output_node])
        nodes = list(set(nodes))
        for node in nodes:
            if node.startswith("delta_"):
                if node not in list(raw_data.keys()):
                    if node[6:] in nodes and "next_"+node[6:] in nodes:
                        node_columns = [c for c in raw_config['metadata']['columns'] if list(c.values())[0]["dim"] == node[6:]]
                        if all(isinstance(item[list(item.keys())[0]]['type'], str) and item[list(item.keys())[0]]['type'] == 'continuous' for item in node_columns):
                            raw_data[node] = raw_data["next_"+node[6:]] - raw_data[node[6:]]
                            delta_node_columns = [{"delta_"+list(column.keys())[0] : {"dim":node, "type": list(column.values())[0]["type"]}} for column in node_columns]
                            raw_config['metadata']['columns'] += delta_node_columns
                            logger.warning(f"Automatically compute the {node} node.")
                        else:
                            logger.warning(f"All dimensions of the {node} node must be of continuous type.")    
        self._raw_config = raw_config
        self._raw_data = raw_data
        ##################################################################
    def _check_data(self):
        '''check if the data format is correct'''
        ''' prepare fake data for no data nodes '''
        for nodata_node_name in self.graph.nodata_node_names:
            assert nodata_node_name not in self.data.keys()
            node_data_dims = len(self.raw_columns[nodata_node_name])
            self.data[nodata_node_name] = np.ones((list(self.data.values())[0].shape[0], node_data_dims), dtype=np.float32)
        '''1. check if the dimension of data matches the dimension described in yaml'''
        for k, v in self.raw_columns.items():
            assert k in self.data.keys(), f'Cannot find `{k}` in the data file, please check!'
            data = self.data[k]
            assert len(data.shape) == 2, f'Expect data in 2D ndarray, got variable `{k}` in shape {data.shape}, please check!'
            assert data.shape[-1] == len(v), f'Variable `{k}` described in yaml has {len(v)} dims, but got {data.shape[-1]} dims from the data file, please check!'
        for curr_name, next_name in self.graph.transition_map.items():
            assert self.data[curr_name].shape == self.data[next_name].shape, \
                f'Shape mismatch between `{curr_name}` (shape {self.data[curr_name].shape}) and `{next_name}` (shape {self.data[next_name].shape}). ' + \
                f'If you put `{next_name}` in the data file yourself, please check the way you generate it. ' + \
                f'Otherwise, you have probably registered a function to compute `{next_name}` that outputs a wrong shape, please check the function!'
        '''2. check if the functions are correctly defined'''
        for node_name in self.graph.keys():
            node = self.graph.get_node(node_name)
            if node.node_type == 'function':
                # test 1D case
                input_data = {name : self.data[name][0] for name in node.input_names}
                should_output_data = self.data[node.name][0]
                
                if node.node_function_type == 'torch':
                    input_data = {k : torch.tensor(v) for k, v in input_data.items()}
                    should_output_data = torch.tensor(should_output_data)
                output_data = node.node_function(input_data)
                
                assert output_data.shape == should_output_data.shape, \
                    f'Testing function for `{node_name}`. Expect function output shape {should_output_data.shape}, got {output_data.shape} instead!'
                assert type(output_data) == type(should_output_data), \
                    f'Testing function for `{node_name}`. Expect function output type {type(should_output_data)}, got {type(output_data)} instead!'
                assert output_data.dtype == should_output_data.dtype, \
                    f'Testing function for `{node_name}`. Expect function output dtype {should_output_data.dtype}, got {output_data.dtype} instead!'
                # test 2D case
                input_data = {name : self.data[name][:2] for name in node.input_names}
                should_output_data = self.data[node.name][:2]
                
                if node.node_function_type == 'torch':
                    input_data = {k : torch.tensor(v) for k, v in input_data.items()}
                    should_output_data = torch.tensor(should_output_data)
                output_data = node.node_function(input_data)
                assert output_data.shape == should_output_data.shape, \
                    f'Testing function for `{node_name}`. Expect function output shape {should_output_data.shape}, got {output_data.shape} instead!'
                assert type(output_data) == type(should_output_data), \
                    f'Testing function for `{node_name}`. Expect function output type {type(should_output_data)}, got {type(output_data)} instead!'
                assert output_data.dtype == should_output_data.dtype, \
                    f'Testing function for `{node_name}`. Expect function output dtype {should_output_data.dtype}, got {output_data.dtype} instead!'
                # test 3D case
                input_data = {name : self.data[name][:2][np.newaxis] for name in node.input_names}
                should_output_data = self.data[node.name][:2][np.newaxis]
                
                if node.node_function_type == 'torch':
                    input_data = {k : torch.tensor(v) for k, v in input_data.items()}
                    should_output_data = torch.tensor(should_output_data)
                output_data = node.node_function(input_data)
                assert output_data.shape == should_output_data.shape, \
                    f'Testing function for `{node_name}`. Expect function output shape {should_output_data.shape}, got {output_data.shape} instead!'
                assert type(output_data) == type(should_output_data), \
                    f'Testing function for `{node_name}`. Expect function output type {type(should_output_data)}, got {type(output_data)} instead!'
                assert output_data.dtype == should_output_data.dtype, \
                    f'Testing function for `{node_name}`. Expect function output dtype {should_output_data.dtype}, got {output_data.dtype} instead!'
                # test value
                input_data = {name : self.data[name] for name in node.input_names}
                should_output_data = self.data[node.name]
                
                if node.node_function_type == 'torch':
                    input_data = {k : torch.tensor(v) for k, v in input_data.items()}
                output_data = node.node_function(input_data)
                if node.node_function_type == 'torch':
                    output_data = output_data.numpy()
                
                error = np.abs(output_data - should_output_data)
                if np.max(error) > 1e-8:
                    message = f'Test values for function "{node.name}", find max mismatch {np.max(error, axis=0)}. Please check the function.'
                    if self.ignore_check or np.max(error) < 1e-2:
                        logger.warning(message)
                    else:
                        message += '\nIf you are sure that the function is right and the value error is acceptable, configure "ignore_check=True" in the config.json to skip.'
                        message += '\nIf you are using the "train.py" script. You can add the "--ignore_check 1" to skip. E.g. python train.py --ignore_check 1'
                        logger.error(message)
                        raise ValueError(message)                    
        '''3. check if the transition variables match'''
        for curr_name, next_name in self.graph.transition_map.items():
            curr_data = []
            next_data = []
            for start, end in zip(self._start_indexes, self._end_indexes):
                curr_data.append(self.data[curr_name][start+1:end])
                next_data.append(self.data[next_name][start:end-1])
            if not np.allclose(np.concatenate(curr_data), np.concatenate(next_data), 1e-4):
                error = np.abs(np.concatenate(curr_data) - np.concatenate(next_data))
                message = f'Test transition values for {curr_name} and {next_name}, find max mismatch {np.max(error, axis=0)}. ' + \
                    f'If {next_name} is provided by you, please check the data file. If you provide {next_name} as a function, please check the function.'
                if self.ignore_check:
                    logger.warning(message)
                else:
                    logger.error(message)
                    raise ValueError(message)
        ''' Delete fake data for no data nodes '''
        for nodata_node_name in self.graph.nodata_node_names:
            assert nodata_node_name in self.data.keys()
            self.data.pop(nodata_node_name)
    def compute_missing_data(self, need_truncate):
        if need_truncate:
            new_data = {k : [] for k in self.data.keys()}
            for node_name in self.graph.transition_map.values():
                new_data[node_name] = []
            for start, end in zip(self._start_indexes, self._end_indexes):
                for k in self.data.keys():
                    new_data[k].append(self.data[k][start:end-1])
            for k in self.data.keys():
                new_data[k] = np.concatenate(new_data[k], axis=0)
        else:
            new_data = deepcopy(self.data)
        for node_name in self.graph.graph_list:
            if node_name not in self.data.keys():
                node = self.graph.get_node(node_name)
                if node.node_type == 'function':
                    # expert processing
                    warnings.warn(f'Detected that node {node_name} is not available in the provided data, trying to compute it ...')
                    assert node.node_type == 'function', \
                        f'You need to provide the function to compute node {node_name} since it is not given in the data!'
                    inputs = {name : new_data[name] for name in node.input_names}
                    convert_func = torch.tensor if node.node_function_type == 'torch' else np.array
                    inputs = {k : convert_func(v) for k, v in inputs.items()}
                    output = node.node_function(inputs)
                    new_data[node_name] = np.array(output).astype(np.float32)
                elif node_name.startswith("next_"):
                    # next_ processing
                    assert need_truncate, f"need_truncate is False, but {node_name} is missing from the data"
                    ori_name = node_name[5:]
                    warnings.warn(f'transition variable {node_name} is not provided; deriving it by truncating trajectories.')
                    for start, end in zip(self._start_indexes, self._end_indexes):
                        new_data[node_name].append(self.data[ori_name][start+1:end])
                    new_data[node_name] = np.concatenate(new_data[node_name], axis=0)
                elif node.node_type == 'network':
                    # this is for handling empty node 
                    logger.warning(f"Detect {node_name} as empty node !")
                    continue
                else:
                    logger.error(f"Node {node_name} is not neither next_ node nor expert node. And {node_name} is missing in data. Please check the dataset file !")
                    raise NotImplementedError
        if need_truncate:
            self._start_indexes -= np.arange(self._start_indexes.shape[0])
            self._end_indexes -= np.arange(self._end_indexes.shape[0]) + 1
            self._traj_lengths -= 1
            self._min_length -= 1
            self._max_length -= 1
        self.data = Batch(new_data) 
    def _load_data(self, data_file : str):
        '''
            load data from the data file and conduct the following processes:
            1. parse trajectory lengths and start/end indexes from the data;
               if `index` is not provided, assume each trajectory has length 1.
            2. if `done` is not in the data, create all-zero data for it.
            3. try to compute values of unprovided nodes with expert functions.
            4. if any transition variable is not available, truncate the trajectories by 1.
        '''
        if self._raw_data:
            raw_data = self._raw_data
        else:
            raw_data = load_data(data_file)
        
        # logger.error("This is for devlop. Please delete it.")
        # nan_index = np.random.choice(raw_data["temperature"].shape[0], size=500)
        # raw_data["temperature"][nan_index] = np.nan
        # make sure data is in float32
        for k, v in raw_data.items():
            if v.dtype != np.float32:
                raw_data[k] = v.astype(np.float32)
        # mark the start and end of each trajectory
        try:
            self._end_indexes = raw_data.pop('index').astype(int)
        except:
            # if no index, consider the data is with a length of 1
            warnings.warn('key `index` is not provided, assuming data with length 1!')
            self._end_indexes = np.arange(0, raw_data[list(self.graph.keys())[0]].shape[0]) + 1
        
        # check if the index is correct
        assert np.all((self._end_indexes[1:] - self._end_indexes[:-1]) > 0), f'index must be in strictly increasing order, but got {self._end_indexes}.'
        for output_node_name in raw_data.keys():
            if output_node_name not in self.graph.tunable:
                datasize = raw_data[output_node_name].shape[0]
                assert datasize == self._end_indexes[-1], \
                    f'detected index exceeding the provided data: the max index is {self._end_indexes[-1]}, but got {output_node_name} with shape {raw_data[output_node_name].shape}.'
        
        self._start_indexes = np.concatenate([np.array([0]), self._end_indexes[:-1]])
        self._traj_lengths = self._end_indexes - self._start_indexes
        self._min_length = np.min(self._traj_lengths)
        self._max_length = np.max(self._traj_lengths)
        # check if `done` is defined in the data
        if not 'done' in raw_data.keys(): 
            # when done is not available, set to all zero
            warnings.warn('key `done` is not provided, set it to all zero!')
            raw_data['done'] = np.zeros((datasize, 1), dtype=np.float32)
        self.data = Batch(raw_data)
        # check if all the transition variables are in the data
        need_truncate = False
        for node_name in self.graph.transition_map.values():
            if node_name not in self.data.keys() and 'ts' not in node_name:
                warnings.warn(f'transition variable {node_name} is not provided and cannot be computed!')
                need_truncate = True
        if need_truncate:
            for node_name in self.graph.transition_map.values(): # clean out existing variables 
                if node_name in self.data.keys(): 
                    self.data.pop(node_name) 
            warnings.warn('truncating the trajectory by 1 step to generate transition variables!')
            assert self._min_length > 1, 'cannot truncate trajectory with length 1'
        # compute node if they are defined on the graph but not present in the data
        need_compute = False
        for node_name in self.graph.keys():
            if node_name not in self.data.keys():
                warnings.warn(f'Detected that node {node_name} is not available in the provided data, trying to compute it ...')
                need_compute = True
        
        # here to process the data
        if need_truncate or need_compute:
            self.compute_missing_data(need_truncate)
        # append mask index for nan value
        self.nan_isin_data = False
        for key in list(self.data.keys()):
            isnan_index = np.isnan(self.data[key]).astype(np.float32) 
            if np.sum(isnan_index) > 0.5:
                logger.warning(f"Find nan value in {key} node data. Auto set the ignore_check=True.")
                self.data[key+"_isnan_index_"] = isnan_index
                df = pd.DataFrame(self.data[key])
                df.fillna(method="bfill", inplace=True)
                df.fillna(method="ffill", inplace=True)
                self.data[key] = df.values
                self.ignore_check = True
                self.nan_isin_data = True
        # parse no-data nodes: nodes without data are supported for training by the revive algorithm
        nodata_node_names = []
        for node in self.graph.nodes.values():
            output_node_name = node.name
            if output_node_name not in list(self.data.keys()):
                logger.warning(f'Found a node without data. Node type: "{node.node_type}". Node name: "{output_node_name}".')
                if node.node_type == "network":
                    nodata_node_names.append(output_node_name)
                elif node.node_type == "function":
                    # TODO: support automatic calculation of data for function nodes.
                    pass
                else:
                    logger.error(f"Find unknow node_type: {node.node_type}.")
                    raise NotImplementedError
        self.graph.nodata_node_names = nodata_node_names
        self.graph.metric_nodes = list(set(self.graph.metric_nodes) - set(self.graph.nodata_node_names)) 
        return self.data
    
    # NOTE: the mode should be set before creating the dataloader
    def transition_mode_(self):
        ''' Set the dataset in transition mode. `__getitem__` will return a transition. '''
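        # In transition mode every timestep is treated as its own length-1 trajectory,
        # so the horizon is fixed to 1 and sampling is deterministic (fix_sample=True).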
        self.end_indexes = np.arange(0, self.data[self.graph.leaf[0]].shape[0]) + 1
        # self.end_indexes = np.arange(0, self.data[list(self.graph.keys())[0]].shape[0]) + 1
        self.start_indexes = np.concatenate([np.array([0]), self.end_indexes[:-1]])
        self.traj_lengths = self.end_indexes - self.start_indexes
        self.min_length = np.min(self.traj_lengths)
        self.max_length = np.max(self.traj_lengths)  
        self.set_horizon(1)
        self.index_to_traj = [0] + list(np.cumsum(self._traj_lengths))
        self.mode = 'transition'
        self.fix_sample = True
        return self 
    def trajectory_mode_(self, horizon : Optional[int] = None, fix_sample : bool = False):
        r''' Set the dataset in trajectory mode. `__getitem__` will return a clip of trajectory. '''
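        # Each sample is a clip of `horizon` consecutive steps; if no horizon is given,
        # the minimum trajectory length is used so every trajectory yields at least one clip.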
        self.end_indexes = self._end_indexes
        self.start_indexes = self._start_indexes
        self.traj_lengths = self._traj_lengths
        self.min_length = self._min_length
        self.max_length = self._max_length
        horizon = horizon or self.min_length
        self.set_horizon(horizon)
        self.index_to_traj = [0] + list(np.cumsum(self.traj_lengths // self.horizon))
        self.mode = 'trajectory'
        self.fix_sample = fix_sample
        return self 
    def _find_trajectory(self, index : int):
        ''' perform binary search for the index of true trajectories from the index of the sample trajectory '''
        left, right = 0, len(self.index_to_traj) - 1
        mid = (left + right) // 2
        while not (index >= self.index_to_traj[mid] and index < self.index_to_traj[mid+1]):
            if index < self.index_to_traj[mid]:
                right = mid - 1
            else:
                left = mid + 1
            mid = (left + right) // 2
        return mid
    def set_horizon(self, horizon : int):
        r''' Set the horizon for loading data '''
        if horizon > self.min_length:
            logger.warning(f'Warning: the min length of the dataset is {self.min_length}, which is less than the horizon {horizon} you require. ' + \
                           f'Fallback to use horizon = {self.min_length}.')
        self.horizon = min(horizon, self.min_length)
        logger.info(f"Set trajectory horizon : {self.horizon}") 
    def get_dist_configs(self, model_config):
        r'''
        Get the config of distributions for each node based on the given model config.
        Args:
            :model_config: The given model config.
        Return:
            :dist_configs: config of distributions for each node.
            :total_dims: dimensions for each node when it is considered as input and output. 
                        (Output dimensions can be different from input dimensions due to the parameterized distribution)
        '''
        dist_configs = {k : self._get_dist_config(self.data_configs[k], model_config) for k in self.data_configs.keys()}
        total_dims = {k : self._get_dim(dist_configs[k]) for k in dist_configs.keys()}
        return dist_configs, total_dims 
    def _load_config(self, config_file : str):
        """
            load data description from `.yaml` file. Few notes:
            1. the name of each dimension will be discarded since they doesn't help the computation.
            2. dimensions of each node will be reordered (category, discrete, continuous) to speed up computation.
            3. register expert functions and tunable parameters if defined.
        """
        if self._raw_config:
            raw_config = self._raw_config
        else:
            with open(config_file, 'r', encoding='UTF-8') as f:
                raw_config = yaml.load(f, Loader=yaml.FullLoader)
        # collect description for the same node
        data_config = raw_config['metadata']['columns']
        self.columns = data_config
        keys = set([list(d.values())[0]['dim'] for d in data_config])
        raw_columns = {}
        config = {}
        order = {}
        fit = {}
        
        for config_key in keys:
            raw_columns[config_key], config[config_key], order[config_key], fit[config_key] = self._load_config_for_single_node(data_config, config_key)
        # parse graph
        graph_dict = raw_config['metadata'].get('graph', None)
        # parse metric_nodes
        metric_nodes = raw_config['metadata'].get('metric_nodes', None)
        graph = DesicionGraph(graph_dict, raw_columns, fit, metric_nodes)
        graph.ts_node_frames = self.ts_node_frames
        graph.ts_frames_config = self.ts_frames_config
        # copy the raw columns for transition variables to allow them as input to other nodes
        for curr_name, next_name in graph.transition_map.items():
            raw_columns[next_name] = raw_columns[curr_name]
        # if you use next_obs in the graph without obs as input, next_obs would lose its description
        # (typically when you use ts_obs instead of obs); the commented block below would fix that:
        # for node_name in list(graph.keys()):
        #     if node_name.startswith("next_"):
        #         ori_name = node_name[5:]
        #         if ori_name not in graph.leaf and "ts_"+ori_name in graph.leaf:
        #             raw_columns[node_name] = raw_columns[ori_name]
        # mark tunable parameters
        for node_name in raw_config['metadata'].get('tunable', []): graph.mark_tunable(node_name)
        
        # expert_functions = raw_config['metadata'].get('expert_functions',{})
        custom_nodes = raw_config['metadata'].get('custom_nodes', None)
        def get_function_type(file_path : str, function_name : str, file_name: str) -> str:
            '''get the function type from type hint'''
            with open(file_path, 'r') as f:
                for line in f.readlines():
                    if line.startswith('def ') and function_name in line:
                        if 'Tensor' in line:
                            return 'torch'
                        elif 'ndarray' in line:
                            return 'numpy'
                        else:
                            warnings.warn('Type hint is not provided, assuming it is a torch function!')
                            return 'torch'
            raise ValueError(f'Cannot find function {function_name} in {file_name}.py, please check your yaml!')
        
        # later = find_later(self.config_file, 'data')
        # head = '.'.join(later[:-1])
        # register expert functions to the graph
        if self.expert_functions is not None:
            for node_name, function_description in self.expert_functions.items():
                # NOTE: currently we assume the expert functions are also placed in the same folder as the yaml file.
                if 'node_function' in function_description.keys(): # `node_function` should be in the form `[file].[function_name]`
                    if function_description['node_function'] in ["graph.delta_node_function",]:
                        graph.register_node(node_name, FunctionDecisionNode)
                        func_str = function_description['node_function'].replace("graph","graph.get_node(node_name)")
                        node_function_type="torch"
                        graph.get_node(node_name).register_node_function(eval(func_str), node_function_type=node_function_type)
                        logger.info(f'register node function ({node_function_type} version) for {node_name}')
                    else:
                        graph.register_node(node_name, FunctionDecisionNode)
                        file_name, function_name = function_description['node_function'].split('.')
                        file_path = os.path.join(os.path.dirname(self.config_file), file_name + '.py')
                        function_type = get_function_type(file_path, function_name, file_name)
                        parse_file_path = file_path[:-3]+"_parsed.py"
                        if not parser(file_path,parse_file_path,self.config_file):
                            parse_file_path = file_path
                        function_type = get_function_type(parse_file_path, function_name, file_name)
                        file_name = os.path.split(os.path.splitext(parse_file_path)[0])[-1]
                        sys.path.insert(0, os.path.dirname(parse_file_path))
                        source_file = importlib.import_module(f'{file_name}')
                        func = eval(f'source_file.{function_name}')
                        graph.get_node(node_name).register_node_function(func, function_type)
                        logger.info(f'register node function ({function_type} version) for {node_name}')
        for node_name in list(graph.keys()):
            if node_name.startswith("next_ts_"):
                ori_name = node_name[8:]
                if set(graph.graph_dict[node_name]) == set(["ts_"+ori_name, "next_"+ori_name]) and (node_name not in self.expert_functions):
                    # automatically register expert functions to this node
                    graph.register_node(node_name, FunctionDecisionNode)
                    func_description = "next_ts_transition_function.next_ts_placeholder_transition_function"
                    file_name, function_name = func_description.split('.')
                    function_name = function_name.replace("placeholder", ori_name)
                    shutil.copyfile(os.path.join(os.path.dirname(revive.__file__),"./common/next_ts_transition_function.py"),  os.path.join(os.path.dirname(self.config_file), "./next_ts_transition_function.py"))
                    file_path = os.path.join(os.path.dirname(self.config_file), file_name + '.py')
                    with open(file_path, 'r+') as file:
                        data = file.read()
                        general_func = data[data.find('def'):]
                        new_func = '\n\n' + general_func.replace("placeholder", ori_name)
                        file.write(new_func)
                    function_type = get_function_type(file_path, function_name, file_name)
                    sys.path.insert(0, os.path.dirname(file_path))
                    time.sleep(1)
                    source_file = importlib.import_module(f'{file_name}')
                    func = eval(f'source_file.{function_name}')
                    graph.get_node(node_name).register_node_function(func, function_type)
                    logger.info(f'register node function ({function_type} version) for {node_name}')
                if set(graph.graph_dict[node_name]) == set(["ts_"+ori_name, ori_name]) and (node_name not in self.expert_functions):
                    # automatically register expert functions to this node
                    graph.register_node(node_name, FunctionDecisionNode)
                    func_description = "next_ts_policy_function.next_ts_placeholder_policy_function"
                    file_name, function_name = func_description.split('.')
                    function_name = function_name.replace("placeholder", ori_name)
                    shutil.copyfile(os.path.join(os.path.dirname(revive.__file__),"./common/next_ts_policy_function.py"),  os.path.join(os.path.dirname(self.config_file), "./next_ts_policy_function.py"))
                    file_path = os.path.join(os.path.dirname(self.config_file), file_name + '.py')
                    with open(file_path, 'r+') as file:
                        data = file.read()
                        general_func = data[data.find('def'):]
                        new_func = '\n\n' + general_func.replace("placeholder", ori_name)
                        file.write(new_func)
                    function_type = get_function_type(file_path, function_name, file_name)
                    sys.path.insert(0, os.path.dirname(file_path))
                    time.sleep(1)
                    source_file = importlib.import_module(f'{file_name}')
                    func = eval(f'source_file.{function_name}')
                    graph.get_node(node_name).register_node_function(func, function_type)
                    logger.info(f'register node function ({function_type} version) for {node_name}')
            elif node_name.startswith("next_"):
                ori_name = node_name[5:]
                if {ori_name, f"delta_{ori_name}"} == set(graph.graph_dict[node_name]):
                    # automatically register delta functions to this node
                    graph.register_node(node_name, FunctionDecisionNode)
                    func_description = f"next_function.next_placeholder_function"
                    file_name, function_name = func_description.split('.')
                    function_name = function_name.replace("placeholder", ori_name)
                    file_path = os.path.join(os.path.dirname(self.config_file), file_name + '.py')
                    if not os.path.isfile(file_path):
                        shutil.copyfile(os.path.join(os.path.dirname(revive.__file__),"./common/next_function.py"),  os.path.join(os.path.dirname(self.config_file), "./next_function.py"))
                    with open(file_path, 'r+') as file:
                        data = file.read()
                        general_func = data[data.find('def'):]
                        new_func = '\n\n' + general_func.replace("placeholder", ori_name)
                        file.write(new_func)
                    function_type = get_function_type(file_path, function_name, file_name)
                    sys.path.insert(0, os.path.dirname(file_path))
                    time.sleep(1)
                    source_file = importlib.import_module(f'{file_name}')
                    func = eval(f'source_file.{function_name}')
                    graph.get_node(node_name).register_node_function(func, function_type)
                    logger.info(f'register node function ({function_type} version) for {node_name}')
        # register custom nodes to the graph
        if custom_nodes is not None:
            for node_name, custom_node in custom_nodes.items():
                # NOTE: currently we assume the custom nodes are also placed in the same folder as the yaml file.
                # `custom node should be given in the form of [file].[node_class_name]`
                file_name, class_name = custom_node.split('.')
                file_path = os.path.join(os.path.dirname(self.config_file), file_name + '.py')
                sys.path.insert(0, os.path.dirname(file_path))
                source_file = importlib.import_module(f'{file_name}')
                node_class = eval(f'source_file.{class_name}')
                graph.register_node(node_name, node_class)
                logger.info(f'register custom node `{node_class}` for {node_name}')
        # register other nodes with the default `NetworkDecisionNode`
        for node_name, node in graph.nodes.items():
            if node is None:
                graph.register_node(node_name, NetworkDecisionNode)
                logger.info(f'register the default `NetworkDecisionNode` for {node_name}')
        self.raw_columns, self.data_configs, self.orders, self.graph, self.fit = raw_columns, config, order, graph, fit
    def _load_config_for_single_node(self, raw_config : list, node_name : str):
        '''
            load config for a single node. 
            :return 
                raw_config: columns belong to this node
                config: collected columns with type in order: category, discrete, continuous 
                order: order of the index to convert to the collected columns
        '''
        raw_config = list(filter(lambda d: list(d.values())[0]['dim'] == node_name, raw_config))      
        config = []
        discrete_count = 0
        discrete_min = []
        discrete_max = []
        discrete_num = []
        fit = []
        continuous_count = 0
        continuous_min = []
        continuous_max = []
        order = []
        for index, d in enumerate(raw_config):
            name = list(d.keys())[0]
            _config = d[name]
            if _config['type'] == 'category':
                assert 'values' in _config.keys(), f'Parsing columns for node `{node_name}`, you must provide `values` for a `category` dimension.'
                _config['dim'] = len(_config['values'])
                order.append((index, 1))
                config.append(_config)
            elif _config['type'] == 'continuous':
                order.append((index, 3))
                continuous_count += 1
                continuous_max.append(_config.get('max', None))
                continuous_min.append(_config.get('min', None))
            elif _config['type'] == 'discrete':
                assert 'num' in _config.keys() and _config['num'] > 1, f'Parsing columns for node `{node_name}`, you must provide `num` > 1 for a `discrete` dimension.'
                discrete_count += 1 
                order.append((index, 2))
                discrete_max.append(_config.get('max', None))
                discrete_min.append(_config.get('min', None))
                discrete_num.append(_config['num'])
            else:
                logger.error(f"Data type {_config['type']} is not support. Please check the yaml file.")
                raise NotImplementedError
            if "fit" in _config.keys() and not _config["fit"]:
                if _config['type'] == 'category':
                    fit += [False,]*len(_config['values'])
                else:
                    fit.append(False)
            else:
                if _config['type'] == 'category':
                    fit += [True,]*len(_config['values'])
                else:
                    fit.append(True)
        order = sorted(order, key=lambda x: x[1])
        forward_order = [o[0] for o in order]
        order = [(ordered_index, origin_index) for ordered_index, origin_index in enumerate(forward_order)]
        order = sorted(order, key=lambda x: x[1])
        backward_order = [o[0] for o in order]
        order = {
            'forward' : forward_order,
            'backward' : backward_order,
        }
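        # `forward` reorders the raw columns so that category dims come first, then discrete,
        # then continuous; `backward` restores the original column order. For example, yaml
        # columns [continuous, continuous, category] give forward = [2, 0, 1] and
        # backward = [1, 2, 0].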
        
        if discrete_count > 0:
            config.append({'type' : 'discrete', "dim" : discrete_count, 'max' : discrete_max, 'min' : discrete_min, 'num' : discrete_num})
        
        if continuous_count > 0:
            config.append({'type' : 'continuous', 'dim' : continuous_count, 'max' : continuous_max, 'min' : continuous_min})
        return raw_config, config, order, fit
    def _get_dist_config(self, data_config : List[Dict[str, Any]], model_config : Dict[str, Any]):
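        # Maps each data block to an output distribution and its parameter dimension, e.g. a
        # 3-dim continuous block with 'normal' and conditioned_std=True needs
        # output_dim = 3 * (1 + 1) = 6 (a mean and a std per dim), while 'gmm' with mixture=5
        # needs output_dim = 5 * ((1 + 1) * 3 + 1) = 35.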
        dist_config = []
        for config in data_config:
            config = config.copy()
            if config['type'] == 'category':
                assert model_config['category_distribution_type'] in ['onehot'], \
                    f"distribution type {model_config['category_distribution_type']} is not supported for category variables!"
                config['type'] = model_config['category_distribution_type']
                config['output_dim'] = config['dim']
            elif config['type'] == 'discrete':
                assert model_config['discrete_distribution_type'] in ['gmm', 'normal', 'discrete_logistic'], \
                    f"distribution type {model_config['discrete_distribution_type']} is not supported for discrete variables!"
                config['type'] = model_config['discrete_distribution_type']
                if config['type'] == 'discrete_logistic':
                    config['output_dim'] = config['dim'] * 2
                    config['num'] = config['num']
                elif config['type'] == 'normal':
                    config['conditioned_std'] = model_config['conditioned_std']
                    config['output_dim'] = config['dim'] * (1 + config['conditioned_std'])
                elif config['type'] == 'gmm':
                    config['mixture'] = model_config['mixture']
                    config['conditioned_std'] = model_config['conditioned_std']
                    config['output_dim'] = config['mixture'] * ((1 + config['conditioned_std']) * config['dim'] + 1)
            else:
                assert model_config['continuous_distribution_type'] in ['gmm', 'normal'], \
                    f"distribution type {model_config['continuous_distribution_type']} is not supported for continuous variables!"
                config['type'] = model_config['continuous_distribution_type']
                if config['type'] == 'normal':
                    config['conditioned_std'] = model_config['conditioned_std']
                    config['output_dim'] = config['dim'] * (1 + config['conditioned_std'])
                elif config['type'] == 'gmm':
                    config['mixture'] = model_config['mixture']
                    config['conditioned_std'] = model_config['conditioned_std']
                    config['output_dim'] = config['mixture'] * ((1 + config['conditioned_std']) * config['dim'] + 1)
            dist_config.append(config)
        return dist_config
    def _get_process_params(self, data : np.ndarray, data_config : Dict[str, Union[int, str, List[float]]], order : List[int]):
        ''' get necessary parameters for data processor '''
        additional_parameters = []
        standardization_parameters = []
        forward_slices = []
        backward_slices = []
        forward_start_index = 0
        backward_start_index = 0
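        # `additional_parameters` collect, per block, the values the data processor can use to
        # rescale the data (the category values, or (center, half-range) pairs plus `num` for
        # discrete blocks); `standardization_parameters` collect (mean, std). `forward_slices`
        # index the blocks in the raw (reordered) columns, while `backward_slices` index them
        # in the processed representation, where a category column expands to one dimension per
        # possible value.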
        if data is None:
            total_dims = sum([config["dim"] for config in data_config])
            for dim in range(total_dims):
                additional_parameters.append((np.array(0).astype(np.float32), np.array(1).astype(np.float32)))
                standardization_parameters.append((np.array(0).astype(np.float32), np.array(1).astype(np.float32)))
            forward_slices.append(slice(0,total_dims))
            backward_slices.append(slice(0,total_dims))
            return {
                'forward_slices' : forward_slices,
                'backward_slices' : backward_slices,
                'additional_parameters' : additional_parameters,
                'standardization_parameters' : standardization_parameters,
            }
        data = data.copy()
        data = data.take(order['forward'], axis=-1)
        for config in data_config:
            if config['type'] == 'category':
                forward_end_index = forward_start_index + 1
                _data = data[:, forward_start_index : forward_end_index]
                _data = _data.reshape((-1, _data.shape[-1]))
                data_mean = _data.mean(axis=0)
                data_std = _data.std(axis=0)
                additional_parameters.append(np.array(config['values']).astype(np.float32))
                # TODO
                standardization_parameters.append((np.array(0).astype(np.float32), np.array(1).astype(np.float32)))
            elif config['type'] == 'continuous':
                forward_end_index = forward_start_index + config['dim']
                _data = data[:, forward_start_index : forward_end_index]
                _data = _data.reshape((-1, _data.shape[-1]))
                data_max = _data.max(axis=0)
                data_min = _data.min(axis=0)
                data_mean = _data.mean(axis=0)
                data_std = _data.std(axis=0)
                for i in range(config['dim']):
                    if config['max'][i] is None: config['max'][i] = data_max[i]
                    if config['min'][i] is None: config['min'][i] = data_min[i]
                max_num = np.array(config['max']).astype(np.float32)
                min_num = np.array(config['min']).astype(np.float32)
                interval = max_num - min_num
                interval[interval==0] = 2 # prevent dividing zero
                additional_parameters.append(((max_num + min_num) / 2, 0.5 * interval))
                standardization_parameters.append((data_mean, data_std))
            elif config['type'] == 'discrete':
                forward_end_index = forward_start_index + config['dim']
                _data = data[:, forward_start_index : forward_end_index]
                _data = _data.reshape((-1, _data.shape[-1]))
                data_max = _data.max(axis=0)
                data_min = _data.min(axis=0)
                data_mean = _data.mean(axis=0)
                data_std = _data.std(axis=0)
                for i in range(config['dim']):
                    if config['max'][i] is None: config['max'][i] = data_max[i]
                    if config['min'][i] is None: config['min'][i] = data_min[i]
                max_num = np.array(config['max']).astype(np.float32)
                min_num = np.array(config['min']).astype(np.float32)
                interval = max_num - min_num
                interval[interval==0] = 2 # prevent dividing zero
                additional_parameters.append(((max_num + min_num) / 2, 0.5 * interval, np.array(config['num'])))
                # TODO
                standardization_parameters.append((np.array(0).astype(np.float32), np.array(1).astype(np.float32)))
            backward_end_index = backward_start_index + config['dim']
            forward_slices.append(slice(forward_start_index, forward_end_index))
            backward_slices.append(slice(backward_start_index, backward_end_index))
            forward_start_index = forward_end_index
            backward_start_index = backward_end_index
        return {
            'forward_slices' : forward_slices,
            'backward_slices' : backward_slices,
            'additional_parameters' : additional_parameters,
            'standardization_parameters' : standardization_parameters,
        }
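    # Illustration (a sketch with made-up numbers, not executed): for a single
    # continuous config with dim=2, max=[10, 1] and min=[0, -1], the method above
    # would produce
    #   additional_parameters[0]      == (array([5., 0.]), array([5., 1.]))   # (center, half-interval)
    #   standardization_parameters[0] == (per-column mean, per-column std)
    #   forward_slices[0] == backward_slices[0] == slice(0, 2)
    # Category and discrete configs instead store the raw category values or
    # (center, half-interval, num), and fall back to (0, 1) standardization.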
    def _get_dim(self, dist_configs):
        return {
            'input' : sum([d['dim'] for d in dist_configs]),
            'output' : sum([d['output_dim'] for d in dist_configs])
        }
[docs]
    def __len__(self) -> int:
        return np.sum(self.traj_lengths // self.horizon) 
[docs]
    def __getitem__(self, index : int, raw : bool = False) -> Batch:
        if self.mode == 'trajectory':
            traj_index = self._find_trajectory(index)
            if self.fix_sample: # fix the starting point of each slice in the trajectory
                start_index = self.start_indexes[traj_index] + self.horizon * (index - self.index_to_traj[traj_index])
            else: # randomly sample valid start point from the trajectory
                length = self.end_indexes[traj_index] - self.start_indexes[traj_index]
                start_index = self.start_indexes[traj_index] + np.random.randint(0, length - self.horizon + 1)
            raw_data = self.data[start_index : (start_index + self.horizon)]
            if self.tunable_data is not None:
                raw_data.update(self.tunable_data[traj_index])
                # tunable_data = Batch()
                # for tunable, data in self.tunable_data.items():
                #     tunable_data[tunable] = data[traj_index][np.newaxis].repeat(self.horizon, axis=0)
                # raw_data.update(tunable_data)
            # TODO: Update
            # Skip nan in start index data
            if self.nan_isin_data:
                for node in self.graph.get_leaf():
                    if node + "_isnan_index_" in raw_data.keys():
                        if np.sum(raw_data[node + "_isnan_index_"][:1]) > 0.5:
                            return self.__getitem__(np.random.choice(np.arange(self.__len__())))
            
        elif self.mode == 'transition':
            raw_data = self.data[index]
            if self.tunable_data is not None:
                traj_index = self._find_trajectory(index)
                raw_data.update(self.tunable_data[traj_index])
        if raw:
            return raw_data
        
        return self.processor.process(raw_data) 
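    # Usage sketch (hypothetical horizon, not executed): in trajectory mode each
    # item is a processed clip of `horizon` consecutive steps.
    #   >>> dataset.trajectory_mode_(10)
    #   >>> batch = dataset[0]                        # Batch with leading dim == 10
    #   >>> raw = dataset.__getitem__(0, raw=True)    # the same clip before processing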
[docs]
    def split(self, ratio : float = 0.5, mode : str = 'outside_traj', recall : bool = False) -> Tuple['OfflineDataset', 'OfflineDataset']:
        r''' split the dataset into train and validation with the given ratio and mode
        
        Args:
            :ratio: Fraction of the data used for the validation dataset if it is not explicitly given.
            :mode: Mode of automatically splitting the training and validation datasets, chosen from `outside_traj` and `inside_traj`.
                  `outside_traj` means the split happens outside the trajectories, so one trajectory can only belong to one dataset.
                  `inside_traj` means the split happens inside the trajectories: the former part of each trajectory goes to the training set, the latter part to the validation set.
        Return: 
            (TrainDataset, ValidateDataset)
        
        '''
        val_dataset = deepcopy(self)
        rng = np.random.default_rng(seed=42)
        if mode == 'outside_traj':
            total_traj_num = len(self._start_indexes)
            val_traj_num = int(total_traj_num * ratio)
            train_traj_num = total_traj_num - val_traj_num
            if not (val_traj_num > 0 and train_traj_num > 0):
                message = f'Cannot split a dataset with {total_traj_num} trajectories to {train_traj_num} (training) and {val_traj_num} (validation).'
                if recall:
                    raise RuntimeError(message)
                else:
                    warnings.warn(message)
                    warnings.warn('Falling back to `inside_traj` mode!')
                    return self.split(ratio=ratio, mode='inside_traj', recall=True)
            total_traj_index = list(range(total_traj_num))
            rng.shuffle(total_traj_index)
            val_traj_index = sorted(total_traj_index[:val_traj_num])
            train_traj_index = sorted(total_traj_index[val_traj_num:])
            self._rebuild_from_traj_index(train_traj_index)
            val_dataset._rebuild_from_traj_index(val_traj_index)
        elif mode == 'inside_traj':
            training_slices = []
            validation_slices = []
            for total_point in self._traj_lengths:
                slice_point = int(total_point * (1 - ratio))
                train_point = slice_point
                validation_point = total_point - slice_point
                if not (train_point > 0 and validation_point > 0):
                    message = f'Cannot split a trajectory with {total_point} steps to {train_point} (training) and {validation_point} (validation).'
                    if recall:
                        raise RuntimeError(message)
                    else:
                        warnings.warn(message)
                        warnings.warn('Falling back to `outside_traj` mode!')
                        return self.split(ratio=ratio, mode='outside_traj', recall=True)
                training_slices.append(slice(0, slice_point))
                validation_slices.append(slice(slice_point, total_point))
            self._rebuild_from_slices(training_slices)
            val_dataset._rebuild_from_slices(validation_slices)    
        else:
            raise ValueError(f'Split mode {mode} is not understood, please check your config!')
        
        return self, val_dataset 
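    # Usage sketch (hypothetical ratio, not executed): hold out 20% of the data
    # for validation; `outside_traj` keeps whole trajectories in a single split
    # and falls back to `inside_traj` automatically when there are too few
    # trajectories (and vice versa).
    #   >>> train_ds, val_ds = dataset.split(ratio=0.2, mode='outside_traj')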
    def _rebuild_from_traj_index(self, traj_indexes : List[int]) -> 'OfflineDataset':
        ''' rebuild the dataset by subsampling the trajectories '''
        # rebuild data
        new_data = []
        for traj_index in traj_indexes:
            start = self._start_indexes[traj_index]
            end = self._end_indexes[traj_index]
            new_data.append(self.data[start:end])
        self.data = Batch.cat(new_data)
        # rebuild index
        self._traj_lengths = self._end_indexes[traj_indexes] - self._start_indexes[traj_indexes]
        self._end_indexes = np.cumsum(self._traj_lengths)
        self._start_indexes = np.concatenate([np.array([0]), self._end_indexes[:-1]])
        self._min_length = np.min(self._traj_lengths)
        self._max_length = np.max(self._traj_lengths)
        self.trajectory_mode_() if self.mode == 'trajectory' else self.transition_mode_()
        return self
    def _rebuild_from_slices(self, slices : List[slice]) -> 'OfflineDataset':
        ''' rebuild the dataset by slicing inside the trajectory '''
        assert len(self._traj_lengths) == len(slices)
        # rebuild data
        new_data = []
        lengths = []
        for s, start, end in zip(slices, self._start_indexes, self._end_indexes):
            data = self.data[start:end][s]
            lengths.append(data.shape[0])
            new_data.append(data)
        self.data = Batch.cat(new_data)
        # rebuild index
        self._traj_lengths = np.array(lengths)
        self._end_indexes = np.cumsum(self._traj_lengths)
        self._start_indexes = np.concatenate([np.array([0]), self._end_indexes[:-1]])
        self._min_length = np.min(self._traj_lengths)
        self._max_length = np.max(self._traj_lengths)
        self.trajectory_mode_() if self.mode == 'trajectory' else self.transition_mode_()
        return self 
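    # Index bookkeeping example (made-up lengths, not executed): after keeping
    # trajectories of lengths [3, 5], the rebuild above gives
    #   _traj_lengths  == [3, 5]
    #   _end_indexes   == cumsum([3, 5])          == [3, 8]
    #   _start_indexes == [0] + _end_indexes[:-1] == [0, 3]
    # so trajectory i occupies self.data[_start_indexes[i] : _end_indexes[i]].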
[docs]
class RNNOfflineDataset(OfflineDataset):
[docs]
    def trajectory_mode_(self, horizon : Optional[int] = None, fix_sample : bool = False):
        r''' Set the dataset to trajectory mode. `__getitem__` will return a full trajectory. '''
        self.end_indexes = self._end_indexes
        self.start_indexes = self._start_indexes
        self.traj_lengths = self._traj_lengths
        self.min_length = self._min_length
        self.max_length = self._max_length
        # horizon = horizon or self.min_length
        # self.set_horizon(horizon)
        # self.index_to_traj = [0] + list(np.cumsum(self.traj_lengths // self.horizon))
        self.mode = 'trajectory'
        self.fix_sample = fix_sample
        return self 
[docs]
    def __len__(self) -> int:
        return len(self.start_indexes) 
[docs]
    def __getitem__(self, index : int, raw : bool = False) -> Batch:
        if self.mode == 'trajectory':
            # traj_index = self._find_trajectory(index)
            # if self.fix_sample: # fix the starting point of each slice in the trajectory
            #     start_index = self.start_indexes[traj_index] + self.horizon * (index - self.index_to_traj[traj_index])
            # else: # randomly sample valid start point from the trajectory
            #     length = self.end_indexes[traj_index] - self.start_indexes[traj_index]
            #     start_index = self.start_indexes[traj_index] + np.random.randint(0, length - self.horizon + 1)
            traj_index = index
            start_index = self.start_indexes[index]
            end_index = self.end_indexes[index]
            raw_data = self.data[start_index : end_index]
            if self.tunable_data is not None:
                raw_data.update(self.tunable_data[traj_index])
            # TODO: Update
            # Skip nan in start index data
            if self.nan_isin_data:
                for node in self.graph.get_leaf():
                    if node + "_isnan_index_" in raw_data.keys():
                        if np.sum(raw_data[node + "_isnan_index_"][:1]) > 0.5:
                            return self.__getitem__(np.random.choice(np.arange(self.__len__())))
        if raw:
            return raw_data
        
        return self.processor.process(raw_data) 
 
[docs]
class InfiniteDataLoader:
    r""" Wrapper that enables infinite pre-fetching, must use together with InfiniteUniformSampler"""
    def __init__(self, dataloader : torch.utils.data.DataLoader):
        self.dataloader = dataloader
        self.dataset = self.dataloader.dataset
        self.iter = iter(self.dataloader)
    def __iter__(self):
        return iter([next(self.iter)]) 
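# Usage sketch (hypothetical objects, not executed): InfiniteDataLoader yields
# exactly one pre-fetched batch per `for` loop, while the underlying
# InfiniteUniformSampler keeps the iterator alive across epochs.
#   >>> base = torch.utils.data.DataLoader(dataset,
#   ...     batch_sampler=InfiniteUniformSampler(dataset, batch_size),
#   ...     collate_fn=partial(collect_data, graph=dataset.graph))
#   >>> loader = InfiniteDataLoader(base)
#   >>> for batch in loader:   # one batch per pass
#   ...     pass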
[docs]
def collect_data(expert_data : List[Batch], graph : DesicionGraph) -> Batch:
    r''' Collate function for PyTorch DataLoader '''
    expert_data = Batch.stack(expert_data, axis=-2)
    expert_data.to_torch()
    if graph.transition_map:
        selected_name = list(graph.transition_map.keys())[0]
        if len(expert_data[selected_name].shape) == 3:
            for tunable_name in graph.tunable:
                expert_data[tunable_name] = expert_data[tunable_name].expand(expert_data[selected_name].shape[0], *[-1] * len(expert_data[tunable_name].shape))
    return expert_data 
[docs]
def pad_collect_data(expert_data : List[Batch], graph : DesicionGraph) -> Batch:
    r''' Collate function for PyTorch DataLoader that pads variable-length trajectories and returns a loss mask '''
    batch_size = len(expert_data)
    # print(f"len: {[data.shape[0] for data in expert_data]}")
    max_len = max([data.shape[0] for data in expert_data])
    loss_mask = np.ones((max_len, batch_size, 1))
    for index in range(batch_size):
        data = expert_data[index]
        padding_len = max_len - data.shape[0]
        if padding_len > 0:
            padding = ((0, padding_len), (0, 0))
            for k, v in data.items():
                data[k] = np.pad(v, padding)
            loss_mask[-padding_len:, index] = 0.
    expert_data = Batch.stack(expert_data, axis=-2)
    expert_data.to_torch()
    loss_mask = torch.as_tensor(loss_mask)
    if graph.transition_map:
        selected_name = list(graph.transition_map.keys())[0]
        if len(expert_data[selected_name].shape) == 3:
            for tunable_name in graph.tunable:
                expert_data[tunable_name] = expert_data[tunable_name].expand(expert_data[selected_name].shape[0], *[-1] * len(expert_data[tunable_name].shape))
    return expert_data, loss_mask 
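# Shape sketch (made-up lengths, not executed): stacking two trajectories of
# length 3 and 5 pads the shorter one with zeros up to max_len == 5 and returns
#   expert_data[k] -> tensor of shape (5, 2, dim_k)   # (max_len, batch, dim)
#   loss_mask      -> tensor of shape (5, 2, 1), with loss_mask[3:, 0] == 0
# so the padded steps can be masked out of the loss.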
[docs]
def get_loader(dataset : OfflineDataset, config : dict, is_sample : bool = True, rnn=False):
    """ Get the PoTorch DataLoader for training """
    batch_size = config[BATCH_SIZE]
    collate_fn = partial(pad_collect_data, graph=dataset.graph) if rnn else partial(collect_data, graph=dataset.graph)
    if is_sample:
        loader = torch.utils.data.DataLoader(dataset, batch_sampler=InfiniteUniformSampler(dataset, batch_size),
                                            collate_fn=collate_fn, pin_memory=True, num_workers=config['data_workers'])
        loader = InfiniteDataLoader(loader)
    else:
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                            collate_fn=collate_fn, pin_memory=True, num_workers=config['data_workers'])
    return loader 
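# Usage sketch (hypothetical config values, not executed): `config` must provide
# the batch-size key (BATCH_SIZE) and 'data_workers'; with is_sample=True the
# returned object is an InfiniteDataLoader driven by InfiniteUniformSampler.
#   >>> loader = get_loader(train_dataset, {BATCH_SIZE: 256, 'data_workers': 2})
#   >>> batch = next(iter(loader))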
[docs]
def data_creator(config : dict, 
                 training_mode : str = 'trajectory', 
                 training_horizon : int = None,
                 training_is_sample : bool = True,
                 val_mode : str = 'trajectory',
                 val_horizon : int = None,
                 val_is_sample : bool = False,
                 pre_horzion: int = 0,
                 double : bool = False):
    """
        Get train data loader and validation data loader.
        :return: train data loader and validation data loader
    """        
    train_dataset = ray.get(config['dataset'])
    val_dataset = ray.get(config['val_dataset'])
    config['dist_configs'], config['total_dims'] = train_dataset.get_dist_configs(config)
    config['learning_nodes_num'] = train_dataset.learning_nodes_num
    if training_horizon is None and val_horizon is not None:
        training_horizon = val_horizon
    if training_horizon is not None and val_horizon is None:
        val_horizon = training_horizon
    
    # `pre_horzion` defines the horizon used for preprocessing, which must be shorter than the entire sequence length
    if pre_horzion > 0:
        if training_mode == "transition":
            logger.error("When using pre_horzion > 0, dataset loader mode must be set to 'trajectory'!")
            time.sleep(3)
            sys.exit()
        elif training_mode == "trajectory":
            assert training_horizon > pre_horzion
            assert val_horizon > pre_horzion
        else:
            raise ValueError(f"Unknown training_mode: {training_mode}")
    
    if not double:
        train_dataset = train_dataset.trajectory_mode_(training_horizon) if training_mode == 'trajectory' else train_dataset.transition_mode_()
        val_dataset = val_dataset.trajectory_mode_(val_horizon) if val_mode == 'trajectory' else val_dataset.transition_mode_()
        train_loader = get_loader(train_dataset, config, training_is_sample)
        val_loader = get_loader(val_dataset, config, val_is_sample)
        return train_loader, val_loader
    else: # perform double venv training
        # NOTE: train_dataset_val means the training set used in validation
        train_dataset_train = deepcopy(train_dataset)
        val_dataset_train = deepcopy(val_dataset)
        train_dataset_val = deepcopy(train_dataset)
        val_dataset_val = deepcopy(val_dataset)
        train_dataset_train = train_dataset_train.trajectory_mode_(training_horizon) if training_mode == 'trajectory' else train_dataset_train.transition_mode_()
        val_dataset_train = val_dataset_train.trajectory_mode_(training_horizon) if training_mode == 'trajectory' else val_dataset_train.transition_mode_()
        train_dataset_val = train_dataset_val.trajectory_mode_(val_horizon) if val_mode == 'trajectory' else train_dataset_val.transition_mode_()
        val_dataset_val = val_dataset_val.trajectory_mode_(val_horizon) if val_mode == 'trajectory' else val_dataset_val.transition_mode_()
        train_loader_train = get_loader(train_dataset_train, config, training_is_sample)
        val_loader_train = get_loader(val_dataset_train, config, training_is_sample)
        train_loader_val = get_loader(train_dataset_val, config, val_is_sample)
        val_loader_val = get_loader(val_dataset_val, config, val_is_sample)
        return train_loader_train, val_loader_train, train_loader_val, val_loader_val 
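# Usage sketch (hypothetical ray object refs, not executed): the datasets are
# passed through the config as ray object references before calling the creator.
#   >>> config['dataset'] = ray.put(train_dataset)
#   >>> config['val_dataset'] = ray.put(val_dataset)
#   >>> train_loader, val_loader = data_creator(config,
#   ...     training_mode='trajectory', training_horizon=10)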
[docs]
def revive_f_rnn_data_creator(config : dict, 
                            training_mode : str = 'trajectory', 
                            training_horizon : int = None,
                            training_is_sample : bool = True,
                            val_mode : str = 'trajectory',
                            val_horizon : int = None,
                            val_is_sample : bool = False,
                            pre_horzion: int = 0,
                            double : bool = False):
    """
        Get train data loader and validation data loader.
        :return: train data loader and validation data loader
    """        
    mail_train_dataset = ray.get(config['dataset'])
    mail_val_dataset = ray.get(config['val_dataset'])
    bc_train_dataset = ray.get(config['rnn_dataset'])
    bc_val_dataset = ray.get(config['rnn_val_dataset'])
    config['dist_configs'], config['total_dims'] = mail_train_dataset.get_dist_configs(config)
    config['learning_nodes_num'] = mail_train_dataset.learning_nodes_num
    if training_horizon is None and val_horizon is not None:
        training_horizon = val_horizon
    if training_horizon is not None and val_horizon is None:
        val_horizon = training_horizon
    
    # `pre_horzion` defines the horizon used for preprocessing, which must be shorter than the entire sequence length
    if pre_horzion > 0:
        if training_mode == "transition":
            logger.error("When using pre_horzion > 0, dataset loader mode must be set to 'trajectory'!")
            time.sleep(3)
            sys.exit()
        elif training_mode == "trajectory":
            assert training_horizon > pre_horzion
            assert val_horizon > pre_horzion
        else:
            raise ValueError(f"Unknown training_mode: {training_mode}")
    
    if not double:
        mail_train_dataset = mail_train_dataset.trajectory_mode_(training_horizon)
        mail_val_dataset = mail_val_dataset.trajectory_mode_(val_horizon)
        bc_train_dataset = bc_train_dataset.trajectory_mode_(training_horizon)
        bc_val_dataset = bc_val_dataset.trajectory_mode_(val_horizon)
        mail_train_loader = get_loader(mail_train_dataset, config, training_is_sample)
        mail_val_loader = get_loader(mail_val_dataset, config, val_is_sample)
        bc_train_loader = get_loader(bc_train_dataset, config, training_is_sample, rnn=True)
        bc_val_loader = get_loader(bc_val_dataset, config, val_is_sample, rnn=True)
        return mail_train_loader, mail_val_loader, bc_train_loader, bc_val_loader
    else: # perform double venv training
        # NOTE: train_dataset_val means the training set used in validation
        mail_train_dataset_train = deepcopy(mail_train_dataset)
        mail_val_dataset_train = deepcopy(mail_val_dataset)
        mail_train_dataset_val = deepcopy(mail_train_dataset)
        mail_val_dataset_val = deepcopy(mail_val_dataset)
        
        mail_train_dataset_train = mail_train_dataset_train.trajectory_mode_(training_horizon)
        mail_val_dataset_train = mail_val_dataset_train.trajectory_mode_(training_horizon)
        mail_train_dataset_val = mail_train_dataset_val.trajectory_mode_(val_horizon)
        mail_val_dataset_val = mail_val_dataset_val.trajectory_mode_(val_horizon)
        mail_train_loader_train = get_loader(mail_train_dataset_train, config, training_is_sample)
        mail_val_loader_train = get_loader(mail_val_dataset_train, config, training_is_sample)
        mail_train_loader_val = get_loader(mail_train_dataset_val, config, val_is_sample)
        mail_val_loader_val = get_loader(mail_val_dataset_val, config, val_is_sample)
        # *******
        bc_train_dataset_train = deepcopy(bc_train_dataset)
        bc_val_dataset_train = deepcopy(bc_val_dataset)
        bc_train_dataset_val = deepcopy(bc_train_dataset)
        bc_val_dataset_val = deepcopy(bc_val_dataset)
        bc_train_dataset_train = bc_train_dataset_train.trajectory_mode_(training_horizon)
        bc_val_dataset_train = bc_val_dataset_train.trajectory_mode_(training_horizon)
        bc_train_dataset_val = bc_train_dataset_val.trajectory_mode_(val_horizon)
        bc_val_dataset_val = bc_val_dataset_val.trajectory_mode_(val_horizon)
        bc_train_loader_train = get_loader(bc_train_dataset_train, config, training_is_sample, rnn=True)
        bc_val_loader_train = get_loader(bc_val_dataset_train, config, training_is_sample, rnn=True)
        bc_train_loader_val = get_loader(bc_train_dataset_val, config, val_is_sample, rnn=True)
        bc_val_loader_val = get_loader(bc_val_dataset_val, config, val_is_sample, rnn=True)
        return mail_train_loader_train, mail_val_loader_train, mail_train_loader_val, mail_val_loader_val, \
               bc_train_loader_train, bc_val_loader_train, bc_train_loader_val, bc_val_loader_val
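# Usage sketch (hypothetical ray object refs, not executed): in addition to
# 'dataset'/'val_dataset', this creator expects RNN-specific datasets under
# 'rnn_dataset'/'rnn_val_dataset'; the BC loaders are built with rnn=True so
# pad_collect_data pads and masks variable-length trajectories.
#   >>> loaders = revive_f_rnn_data_creator(config, training_horizon=10)
#   >>> mail_train, mail_val, bc_train, bc_val = loaders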
if __name__ == '__main__':
    dataset = RNNOfflineDataset("/home/ubuntu/chenjiawei/Refrigerator/data/refrigeration_test.npz", "/home/ubuntu/chenjiawei/Refrigerator/data/refrigeration.yaml")
    batch_size = 10
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                                collate_fn=partial(pad_collect_data, graph=dataset.graph), pin_memory=True, num_workers=0)
    data, loss_mask = next(iter(loader))
    print(data)
    print(data[:, 0].obs)
    
    single_data = dataset.__getitem__(np.random.randint(len(dataset)), raw=True)
    processed_data = dataset.processor.process(single_data)
    deprocessed_data = dataset.processor.deprocess(processed_data)
    processor = dataset.processor
    processed_obs = processor.process_single(single_data.obs, 'obs')
    deprocessed_obs = processor.deprocess_single(processed_obs, 'obs')
    for k in single_data.keys():
        assert np.all(np.isclose(deprocessed_data[k], single_data[k], atol=1e-6)), [k, deprocessed_data[k] - single_data[k]]
    assert np.all(np.isclose(deprocessed_data.obs, deprocessed_obs, atol=1e-6)), [processed_data.obs - processed_obs]
    plot_traj(single_data)
    # test sampler
    data = torch.rand(1000, 4)
    dataset = torch.utils.data.TensorDataset(data)
    sampler = UniformSampler(dataset, 3)
    loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
    for _ in range(10):
        for b in loader:
            print(b)