''''''
"""
POLIXIR REVIVE, copyright (C) 2021-2023 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
import importlib
import json
import os
import shutil
import sys
import warnings
from copy import deepcopy
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
from matplotlib.pyplot import step
import time
import numpy as np
import pandas as pd
import ray
import revive
import torch
import yaml
from loguru import logger
#from revive.utils.raysgd_utils import BATCH_SIZE
BATCH_SIZE = "*batch_size"
from revive.computation.funs_parser import parser
from revive.computation.graph import *
from revive.data.batch import Batch
from revive.data.processor import DataProcessor
from revive.utils.common_utils import find_later, load_data, plot_traj, PositionalEncoding, create_unit_vector, generate_bin_encoding
DATADIR = os.path.abspath(os.path.join(os.path.dirname(revive.__file__), '../data/'))
[docs]class OfflineDataset(torch.utils.data.Dataset):
r"""An offline dataset class.
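Params:
:data_file: The file path where the training dataset is stored.
:config_file: The file path where the data description file is stored.
:revive_config: Dictionary of settings read while preparing the data (e.g. ``use_time_step_embed``, ``use_traj_id_embed``).
:ignore_check: If True, value mismatches found while checking the data are logged as warnings instead of raising an error.
:horizon: Length of iteration trajectory. Defaults to the length of the shortest trajectory.
:reward_func: Optional callable applied to the loaded data; its output is stored under the ``_data_reward_`` key.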
"""
def __init__(self, data_file : str,
             config_file : str,
             revive_config : dict,
             ignore_check : bool = False,
             horizon : int = None,
             reward_func = None):
    """Build the dataset from a data file and its YAML description.

    Args:
        :data_file: path of the file holding the training data.
        :config_file: path of the YAML file describing the data.
        :revive_config: settings dictionary consulted while pre-processing the data.
        :ignore_check: turn value mismatches found by `_check_data` into warnings.
        :horizon: clip length used in trajectory mode; falls back to the shortest trajectory.
        :reward_func: optional callable; when given, its output on the data is stored as `_data_reward_`.
    """
    # remember the constructor arguments
    self.revive_config = revive_config
    self.reward_func = reward_func
    self.ignore_check = ignore_check
    self.config_file = config_file
    self.data_file = data_file

    # caches filled by `_pre_data` and consumed by the two loaders below
    self._raw_config = None
    self._raw_data = None

    self._pre_data(self.config_file, self.data_file)
    self._load_config(self.config_file)
    self._load_data(self.data_file)

    # hand the per-node options collected by `_pre_data` over to the graph
    graph = self.graph
    graph.ts_nodes = self.ts_nodes
    graph.ts_node_frames = self.ts_node_frames
    graph.freeze_nodes = getattr(self, "freeze_nodes", [])
    graph.nodes_loss_type = getattr(self, "nodes_loss_type", {})

    # drop every key that is neither on the graph, `done`, a ts source node nor a nan mask
    keep = set(self.graph.keys())
    keep.update(self.graph.leaf)
    keep.add('done')
    keep.update(self.ts_nodes.values())
    keep.update(name for name in self.data.keys() if name.endswith("_isnan_index_"))
    for name in list(self.data.keys()):
        if name in keep:
            continue
        try:
            # same order as before: a failing pop skips the remaining containers
            for container in (self.data, self.raw_columns, self.data_configs, self.orders):
                container.pop(name)
            warnings.warn(f'Warning: pop up unused key: {name}')
        except Exception as e:
            logger.info(f"{e}")

    self._check_data()

    # construct the data processor
    params = {}
    for name in self.data_configs.keys():
        params[name] = self._get_process_params(self.data.get(name, None), self.data_configs[name], self.orders[name])
    self.processing_params = params
    # transition outputs (`next_*`) share everything with their current-step counterpart
    for curr_name, next_name in self.graph.transition_map.items():
        self.processing_params[next_name] = self.processing_params[curr_name]
        self.orders[next_name] = self.orders[curr_name]
        self.data_configs[next_name] = self.data_configs[curr_name]
    self.processor = DataProcessor(self.data_configs, self.processing_params, self.orders)
    self.graph.register_processor(self.processor)

    # tunable parameters live in their own batch, separate from self.data
    self.tunable_data = None
    if len(self.graph.tunable) > 0:
        self.tunable_data = Batch()
        for tunable in self.graph.tunable:
            self.tunable_data[tunable] = self.data.pop(tunable)

    self.learning_nodes_num = self.graph.summary_nodes()['network_nodes']

    # by default, dataset is in trajectory mode
    self.trajectory_mode_(horizon)

    if self.reward_func is None:
        return
    # the reward function is evaluated on torch data, then everything goes back to numpy
    self.data.to_torch()
    self.data["_data_reward_"] = to_numpy(self.reward_func(self.data))
    self.data.to_numpy()
def _pre_data(self, config_file, data_file):
    """Read the raw config/data and augment them according to ``metadata.nodes``.

    The YAML description and the data file are loaded, then the optional
    per-node options found under ``metadata.nodes`` are applied in memory:

    * ``step_input``: a ``step_node_`` array (per-trajectory step counter, or its
      positional encoding when ``revive_config["use_time_step_embed"]`` is set)
      is added to the data and appended to the inputs of the flagged nodes.
    * ``ts`` (> 1): a stacked ``ts_<node>`` array of that many frames is built
      for every matching node, padding the start of each trajectory by
      repeating its first row.
    * ``traj_id``: a ``traj`` array (trajectory number, or the encoding returned
      by ``generate_bin_encoding``) is added and appended to the node inputs.
    * ``freeze`` / ``loss_type``: collected into ``self.freeze_nodes`` and
      ``self.nodes_loss_type``.

    Modified config/data are cached in ``self._raw_config`` / ``self._raw_data``.

    Args:
        :config_file: path of the YAML description file.
        :data_file: path of the data file.

    NOTE(review): the nesting below was reconstructed from a source whose
    indentation was lost; the spots marked NOTE(review) should be confirmed
    against the upstream file.
    """
    with open(config_file, 'r', encoding='UTF-8') as f:
        raw_config = yaml.load(f, Loader=yaml.FullLoader)
    raw_data = load_data(data_file)
    graph_dict = raw_config['metadata'].get('graph', None)
    nodes_config = raw_config['metadata'].get('nodes', None)
    self.ts_nodes = {}
    self.ts_node_frames = {}
    use_step_node = False
    self.expert_functions = raw_config['metadata'].get('expert_functions',{})
    if not nodes_config:
        return

    ############################ append step node ############################
    step_node_config = {k:v["step_input"] for k,v in nodes_config.items() if "step_input" in v.keys()}
    if step_node_config:
        pos_emd_flag = self.revive_config["use_time_step_embed"]
        assert "step_node_" not in raw_data.keys()
        trj_index = [0,] + list(raw_data["index"])
        max_traj_len = np.max(np.array(trj_index[1:]) - np.array(trj_index[:-1]))
        if pos_emd_flag:
            d_model = self.revive_config["time_step_embed_size"]
            pos_emd = PositionalEncoding(d_model=d_model, max_len=max_traj_len)
        step_node_data_list = []
        for trj_start_index, trj_end_index in zip(trj_index[:-1], trj_index[1:]):
            # step counter restarts at 0 for every trajectory
            step_index = np.arange(0, trj_end_index-trj_start_index).reshape(-1,1)
            if pos_emd_flag:
                step_index = pos_emd.encode(step_index)
            step_node_data_list.append(step_index)
        raw_data["step_node_"] = np.concatenate(step_node_data_list, axis=0)
        if pos_emd_flag:
            step_columns = [{f'step_node_{i}': {'dim': 'step_node_', 'type': 'continuous', 'min': -1, 'max': 1}} for i in range(d_model)]
        else:
            step_columns = [{'step_node_': {'dim': 'step_node_', 'type': 'continuous', 'min': 0, 'max': max_traj_len*2}},]
        raw_config['metadata']['columns'] += step_columns
        for node, step_node_flag in step_node_config.items():
            if node not in graph_dict.keys():
                continue
            if step_node_flag:
                logger.info(f"Append step_node data as the {node} node's input.")
                graph_dict[node] = graph_dict[node] + ["step_node_",]
                use_step_node = True
        # append step_node_function
        if use_step_node:
            # NOTE(review): assumed that the helper function is only installed for the
            # plain (non-embedded) step counter — confirm against upstream.
            if not pos_emd_flag:
                raw_config['metadata']['graph']['next_step_node_'] = ['step_node_']
                shutil.copyfile(os.path.join(os.path.dirname(revive.__file__),"./common/step_node_function.py"), os.path.join(os.path.dirname(self.config_file), "./step_node_function.py"))
                self.expert_functions['next_step_node_'] = {"node_function": "step_node_function.get_next_step_node"}
            # NOTE(review): assumed to be cached only when a node actually uses the step input.
            self._raw_config = raw_config
            self._raw_data = raw_data

    ############################ append ts ###################################
    # collect nodes config; only nodes asking for more than one frame need stacking
    ts_frames_config = {"ts_"+k:v["ts"] for k,v in nodes_config.items() if "ts" in v.keys()}
    ts_frames_config = {k:v for k,v in ts_frames_config.items() if v > 1}
    if ts_frames_config:
        # every node name appearing on the graph, as output or as input
        nodes = list(graph_dict.keys())
        for output_node in list(graph_dict.keys()):
            nodes += list(graph_dict[output_node])
        nodes = list(set(nodes))
        # strip the ts prefixes so that names refer to the underlying data nodes
        for index, node in enumerate(nodes):
            if node.startswith("next_ts_"):
                nodes[index] = "next_" + node[8:]
            if node.startswith("ts_"):
                nodes[index] = node[3:]
        ts_nodes = {"ts_"+node:node for node in nodes if "ts_"+node in ts_frames_config.keys()}
        for ts_node,node in ts_nodes.items():
            if node not in raw_data.keys():
                logger.error(f"Can't find '{node}' node data.")
                sys.exit()
        self.ts_nodes = ts_nodes
        # ts_node npz data
        trj_index = [0,] + list(raw_data["index"])
        new_data = {k:[] for k in ts_nodes.keys()}
        for trj_start_index, trj_end_index in zip(trj_index[:-1], trj_index[1:]):
            traj_len = trj_end_index - trj_start_index
            for ts_node,node in ts_nodes.items():
                ts_node_frames = ts_frames_config[ts_node]
                self.ts_node_frames[ts_node] = ts_node_frames
                # pad the head of the trajectory by repeating its first row
                pad_data = np.concatenate([np.repeat(raw_data[node][trj_start_index:trj_start_index+1],repeats=ts_node_frames-1,axis=0), raw_data[node][trj_start_index:trj_end_index]])
                # each shifted window contributes one frame, concatenated along the feature axis
                new_data[ts_node].append(np.concatenate([pad_data[offset:offset+traj_len] for offset in range(ts_node_frames)], axis=1))
        new_data = {k:np.concatenate(v,axis=0) for k,v in new_data.items()}
        raw_data.update(new_data)
        # ts_node columns
        for ts_node, node in ts_nodes.items():
            ts_node_frames = ts_frames_config[ts_node]
            node_columns = [c for c in raw_config['metadata']['columns'] if list(c.values())[0]["dim"] == node]
            ts_node_columns = []
            for ts_index in range(ts_node_frames):
                for node_column in node_columns:
                    node_column_value = deepcopy(list(node_column.values())[0])
                    node_column_name = deepcopy(list(node_column.keys())[0])
                    node_column_value["dim"] = ts_node
                    # every frame except the last one is marked as not fitted
                    if ts_index+1 < ts_node_frames:
                        node_column_value["fit"] = False
                    ts_node_columns.append({"ts-"+str(ts_index)+"_"+node_column_name:node_column_value})
            raw_config['metadata']['columns'] += ts_node_columns
        self._raw_config = raw_config
        self._raw_data = raw_data

    ############################ trajectory_id #################
    # collect nodes config
    traj_id_node_config = {k:v["traj_id"] for k,v in nodes_config.items() if "traj_id" in v.keys()}
    if traj_id_node_config:
        traj_id_emd_flag = self.revive_config["use_traj_id_embed"]
        traj_index = [0,] + list(raw_data["index"])
        start_indexes = traj_index[:-1]
        end_indexes = traj_index[1:]
        traj_num = len(start_indexes)
        if traj_id_emd_flag:
            traj_encodings = generate_bin_encoding(traj_num)
        traj_ids = []
        for i_traj in range(traj_num):
            rows = np.ones((end_indexes[i_traj] - start_indexes[i_traj], 1))
            if traj_id_emd_flag:
                traj_ids.append(rows * traj_encodings[i_traj])
            else:
                traj_ids.append(rows * i_traj)
        raw_data['traj'] = np.concatenate(traj_ids)
        if traj_id_emd_flag:
            traj_columns = [{f'traj_{i}': {'dim': 'traj', 'type': 'discrete', 'min': 0, 'max': 1, 'num': 2}} for i in range(traj_encodings.shape[1])]
        else:
            traj_columns = [{'traj': {'dim': 'traj', 'type': 'continuous', 'min': 0, 'max': len(start_indexes)-1}},]
        raw_config['metadata']['columns'] += traj_columns
        for node in traj_id_node_config:
            raw_config['metadata']['graph'][node].append('traj')
        self._raw_config = raw_config
        self._raw_data = raw_data

    ############################ collect freeze node #################
    self.freeze_nodes = [k for k,v in nodes_config.items() if "freeze" in v.keys() and v["freeze"]]
    # the message needs the f-prefix, otherwise the placeholder is printed verbatim
    logger.info(f"Freeze the following nodes without training -> {self.freeze_nodes}")
    self.nodes_loss_type = {k:v["loss_type"] for k,v in nodes_config.items() if "loss_type" in v.keys() and v["loss_type"]}
    logger.info(f"Nodes loss type -> {self.nodes_loss_type}")
def _check_data(self):
    '''
    Check if the data format is correct:
        1. the dimension of every described variable matches the data, and transition pairs share a shape.
        2. function nodes return the expected shape / type / dtype / values.
        3. transition variables agree with the next step of their source variable.

    Value mismatches raise `ValueError` unless `ignore_check` is set (then they are only logged).
    '''
    # prepare fake (all-ones) data for no data nodes so that the checks below can run
    for placeholder_name in self.graph.nodata_node_names:
        assert placeholder_name not in self.data.keys()
        width = len(self.raw_columns[placeholder_name])
        row_count = list(self.data.values())[0].shape[0]
        self.data[placeholder_name] = np.ones((row_count, width), dtype=np.float32)

    # 1. check if the dimension of data matches the dimension described in yaml
    for k, v in self.raw_columns.items():
        assert k in self.data.keys(), f'Cannot find `{k}` in the data file, please check!'
        data = self.data[k]
        assert len(data.shape) == 2, f'Expect data in 2D ndarray, got variable `{k}` in shape {data.shape}, please check!'
        assert data.shape[-1] == len(v), f'Variable `{k}` described in yaml has {len(v)} dims, but got {data.shape[-1]} dims from the data file, please check!'

    for curr_name, next_name in self.graph.transition_map.items():
        assert self.data[curr_name].shape == self.data[next_name].shape, \
            f'Shape mismatch between `{curr_name}` (shape {self.data[curr_name].shape}) and `{next_name}` (shape {self.data[next_name].shape}). ' + \
            f'If it is you who puts `{next_name}` in the data file, please check the way you generate it. ' + \
            f'Otherwise, it is probably you have register a function to compute `{next_name}` but it output a wrong shape, please check the function!'

    # 2. check if the functions are correctly defined
    def _same_signature(node_name, output_data, should_output_data):
        # shape, container type and dtype must all agree with the stored data
        assert output_data.shape == should_output_data.shape, \
            f'Testing function for `{node_name}`. Expect function output shape {should_output_data.shape}, got {output_data.shape} instead!'
        assert type(output_data) == type(should_output_data), \
            f'Testing function for `{node_name}`. Expect function output type {type(should_output_data)}, got {type(output_data)} instead!'
        assert output_data.dtype == should_output_data.dtype, \
            f'Testing function for `{node_name}`. Expect function output dtype {should_output_data.dtype}, got {output_data.dtype} instead!'

    probes = (
        lambda array: array[0],                # 1D case
        lambda array: array[:2],               # 2D case
        lambda array: array[:2][np.newaxis],   # 3D case
    )

    for node_name in self.graph.keys():
        node = self.graph.get_node(node_name)
        if node.node_type != 'function':
            continue

        for probe in probes:
            input_data = {name : probe(self.data[name]) for name in node.input_names}
            should_output_data = probe(self.data[node.name])
            if node.node_function_type == 'torch':
                input_data = {k : torch.tensor(v) for k, v in input_data.items()}
                should_output_data = torch.tensor(should_output_data)
            _same_signature(node_name, node.node_function(input_data), should_output_data)

        # test value on the whole data
        input_data = {name : self.data[name] for name in node.input_names}
        should_output_data = self.data[node.name]
        if node.node_function_type == 'torch':
            input_data = {k : torch.tensor(v) for k, v in input_data.items()}
        output_data = node.node_function(input_data)
        if node.node_function_type == 'torch':
            output_data = output_data.numpy()
        error = np.abs(output_data - should_output_data)
        if not np.max(error) > 1e-8:
            continue
        message = f'Test values for function "{node.name}", find max mismatch {np.max(error, axis=0)}. Please check the function.'
        if self.ignore_check or np.max(error) < 1e-4:
            logger.warning(message)
            continue
        message += '\nIf you are sure that the function is right and the value error is acceptable, configure "ignore_check=True" in the config.json to skip.'
        message += '\nIf you are using the "train.py" script. You can add the "--ignore_check 1" to skip. E.g. python train.py --ignore_check 1'
        logger.error(message)
        raise ValueError(message)

    # 3. check if the transition variables match
    for curr_name, next_name in self.graph.transition_map.items():
        shifted_curr = []
        stored_next = []
        for start, end in zip(self._start_indexes, self._end_indexes):
            # `curr` from step 1 onwards must equal `next` up to the second last step
            shifted_curr.append(self.data[curr_name][start+1:end])
            stored_next.append(self.data[next_name][start:end-1])
        shifted_curr = np.concatenate(shifted_curr)
        stored_next = np.concatenate(stored_next)
        if np.allclose(shifted_curr, stored_next, 1e-4):
            continue
        error = np.abs(shifted_curr - stored_next)
        message = f'Test transition values for {curr_name} and {next_name}, find max mismatch {np.max(error, axis=0)}. ' + \
            f'If {next_name} is provided by you, please check the data file. If you provide {next_name} as a function, please check the function.'
        if self.ignore_check:
            logger.warning(message)
        else:
            logger.error(message)
            raise ValueError(message)

    # delete fake data for no data nodes
    for placeholder_name in self.graph.nodata_node_names:
        assert placeholder_name in self.data.keys()
        self.data.pop(placeholder_name)
def compute_missing_data(self, need_truncate):
    '''
    Fill in nodes that are on the graph but absent from the data.

    Args:
        :need_truncate: when True, the last step of every trajectory is dropped so that
            missing `next_*` variables can be taken from the following step of their source.

    Function nodes are computed with their registered function, `next_*` nodes are sliced
    from the original data, network nodes without data are skipped; anything else raises
    `NotImplementedError`. `self.data` is replaced by the completed batch.
    '''
    spans = list(zip(self._start_indexes, self._end_indexes))

    if need_truncate:
        new_data = {key : [] for key in self.data.keys()}
        new_data.update((name, []) for name in self.graph.transition_map.values())
        # drop the final step of each trajectory for every stored variable
        for key in self.data.keys():
            new_data[key] = np.concatenate([self.data[key][start:end-1] for start, end in spans], axis=0)
    else:
        new_data = deepcopy(self.data)

    for node_name in self.graph.graph_list:
        if node_name in self.data.keys():
            continue
        node = self.graph.get_node(node_name)

        if node.node_type == 'function':
            # expert processing
            warnings.warn(f'Detect node {node_name} is not avaliable in the provided data, trying to compute it ...')
            assert node.node_type == 'function', \
                f'You need to provide the function to compute node {node_name} since it is not given in the data!'
            to_backend = torch.tensor if node.node_function_type == 'torch' else np.array
            feed = {name : to_backend(new_data[name]) for name in node.input_names}
            result = node.node_function(feed)
            new_data[node_name] = np.array(result).astype(np.float32)
        elif node_name.startswith("next_"):
            # next_ processing: take the source variable shifted by one step
            assert need_truncate == True, f"need_truncate == False, however {node_name} is missing in data"
            source_name = node_name[5:]
            warnings.warn(f'transition variable {node_name} is not provided and cannot be computed!')
            for start, end in spans:
                new_data[node_name].append(self.data[source_name][start+1:end])
            new_data[node_name] = np.concatenate(new_data[node_name], axis=0)
        elif node.node_type == 'network':
            # this is for handling empty node
            logger.warning(f"Detect {node_name} as empty node !")
            continue
        else:
            logger.error(f"Node {node_name} is not neither next_ node nor expert node. And {node_name} is missing in data. Please check the dataset file !")
            raise NotImplementedError

    if need_truncate:
        # every trajectory lost one step, so the k-th trajectory starts k rows earlier
        self._start_indexes -= np.arange(self._start_indexes.shape[0])
        self._end_indexes -= np.arange(self._end_indexes.shape[0]) + 1
        self._traj_lengths -= 1
        self._min_length -= 1
        self._max_length -= 1

    self.data = Batch(new_data)
def _load_data(self, data_file : str):
    '''
    load data from the data file and conduct following processes:
        1. parse trajectory length and start and end indexes from data.
           if `index` is not provided, consider trajectory length is equal to 1.
        2. if `done` is not in the data, create an all-zero data for it.
        3. try to compute values of unprovided node with expert function.
        4. if any transition variable is not available, truncate the trajectories by 1.
        5. replace nan values (backward then forward fill) and store a `<key>_isnan_index_` mask.
        6. record network nodes that have no data in `graph.nodata_node_names`.

    Args:
        :data_file: path of the data file, only read when `_pre_data` did not cache the data.

    Return:
        the loaded data (also stored in `self.data`).
    '''
    if self._raw_data:
        raw_data = self._raw_data
    else:
        raw_data = load_data(data_file)

    # make sure data is in float32. `index` is left alone: it holds integer row offsets and
    # float32 cannot represent integers above 2**24 exactly.
    for k, v in raw_data.items():
        if k != 'index' and v.dtype != np.float32:
            raw_data[k] = v.astype(np.float32)

    # mark the start and end of each trajectory
    try:
        self._end_indexes = raw_data.pop('index').astype(int)
    except KeyError:
        # if no index, consider the data is with a length of 1
        warnings.warn('key `index` is not provided, assuming data with length 1!')
        self._end_indexes = np.arange(0, raw_data[list(self.graph.keys())[0]].shape[0]) + 1

    # check if the index is correct
    assert np.all((self._end_indexes[1:] - self._end_indexes[:-1]) > 0), f'index must be incremental order, but got {self._end_indexes}.'
    for output_node_name in raw_data.keys():
        if output_node_name not in self.graph.tunable:
            datasize = raw_data[output_node_name].shape[0]
            assert datasize == self._end_indexes[-1], \
                f'detect index exceed the provided data, the max index is {self._end_indexes[-1]}, but got {output_node_name} in shape {raw_data[output_node_name].shape}.'

    self._start_indexes = np.concatenate([np.array([0]), self._end_indexes[:-1]])
    self._traj_lengths = self._end_indexes - self._start_indexes
    self._min_length = np.min(self._traj_lengths)
    self._max_length = np.max(self._traj_lengths)

    # check if `done` is defined in the data
    if 'done' not in raw_data.keys():
        # when done is not available, set to all zero; the row count is the last end index
        warnings.warn('key `done` is not provided, set it to all zero!')
        raw_data['done'] = np.zeros((int(self._end_indexes[-1]), 1), dtype=np.float32)

    self.data = Batch(raw_data)

    # check if all the transition variables are in the data
    need_truncate = False
    for node_name in self.graph.transition_map.values():
        if node_name not in self.data.keys():
            warnings.warn(f'transition variable {node_name} is not provided and cannot be computed!')
            need_truncate = True
    if need_truncate:
        # clean out existing variables, all of them are regenerated from the truncated data
        for node_name in self.graph.transition_map.values():
            if node_name in self.data.keys():
                self.data.pop(node_name)
        warnings.warn('truncating the trajectory by 1 step to generate transition variables!')
        assert self._min_length > 1, 'cannot truncate trajectory with length 1'

    # compute node if they are defined on the graph but not present in the data
    need_compute = False
    for node_name in self.graph.keys():
        if node_name not in self.data.keys():
            warnings.warn(f'Detect node {node_name} is not avaliable in the provided data, trying to compute it ...')
            need_compute = True

    # here to process the data
    if need_truncate or need_compute:
        self.compute_missing_data(need_truncate)

    # append mask index for nan value
    self.nan_isin_data = False
    for key in list(self.data.keys()):
        isnan_index = np.isnan(self.data[key]).astype(np.float32)
        if np.sum(isnan_index) > 0.5:
            logger.warning(f"Find nan value in {key} node data. Auto set the ignore_check=True.")
            self.data[key+"_isnan_index_"] = isnan_index
            # backward fill first, then forward fill whatever is left at the tail
            df = pd.DataFrame(self.data[key]).bfill().ffill()
            self.data[key] = df.values
            self.ignore_check = True
            self.nan_isin_data = True

    # parse no data nodes: No data node support for training using the revive algo
    nodata_node_names = []
    for node in self.graph.nodes.values():
        output_node_name = node.name
        if output_node_name in self.data.keys():
            continue
        logger.warning(f'Find no data node. Node type: "{node.node_type}". Node name: "{output_node_name}".')
        if node.node_type == "network":
            nodata_node_names.append(output_node_name)
        elif node.node_type == "function":
            # TODO: support automatic calculation of data for function nodes.
            pass
        else:
            logger.error(f"Find unknow node_type: {node.node_type}.")
            raise NotImplementedError
    self.graph.nodata_node_names = nodata_node_names
    # nodes without data cannot be used as metric nodes
    self.graph.metric_nodes = list(set(self.graph.metric_nodes) - set(self.graph.nodata_node_names))

    return self.data
# NOTE: the mode must be set before creating the dataloader
def transition_mode_(self):
    '''
    Set the dataset in transition mode. `__getitem__` will return a transition.

    Every row of the first leaf node becomes a sample of length 1, the horizon is
    set to 1 and sampling is fixed. Returns the dataset itself.
    '''
    sample_count = self.data[self.graph.leaf[0]].shape[0]

    ends = np.arange(1, sample_count + 1)
    starts = np.concatenate([np.array([0]), ends[:-1]])
    lengths = ends - starts

    self.end_indexes = ends
    self.start_indexes = starts
    self.traj_lengths = lengths
    self.min_length = np.min(lengths)
    self.max_length = np.max(lengths)
    self.set_horizon(1)

    # boundaries are built from the lengths of the underlying trajectories
    boundaries = [0]
    boundaries.extend(np.cumsum(self._traj_lengths))
    self.index_to_traj = boundaries

    self.mode = 'transition'
    self.fix_sample = True
    return self
def trajectory_mode_(self, horizon : Optional[int] = None, fix_sample : bool = False):
    r'''
    Set the dataset in trajectory mode. `__getitem__` will return a clip of trajectory.

    Args:
        :horizon: length of the clips; the shortest trajectory length is used when not given.
        :fix_sample: stored as `self.fix_sample`.

    Return:
        the dataset itself.
    '''
    # expose the trajectory boundaries parsed from the data
    self.start_indexes = self._start_indexes
    self.end_indexes = self._end_indexes
    self.traj_lengths = self._traj_lengths
    self.max_length = self._max_length
    self.min_length = self._min_length

    if not horizon:
        horizon = self.min_length
    self.set_horizon(horizon)

    # each trajectory contributes as many samples as whole horizons fit in it
    clips_per_traj = self.traj_lengths // self.horizon
    self.index_to_traj = [0] + list(np.cumsum(clips_per_traj))

    self.fix_sample = fix_sample
    self.mode = 'trajectory'
    return self
def _find_trajectory(self, index : int):
''' perform binary search for the index of true trajectories from the index of the sample trajectory '''
left, right = 0, len(self.index_to_traj) - 1
mid = (left + right) // 2
while not (index >= self.index_to_traj[mid] and index < self.index_to_traj[mid+1]):
if index < self.index_to_traj[mid]:
right = mid - 1
else:
left = mid + 1
mid = (left + right) // 2
return mid
def set_horizon(self, horizon : int):
    r'''
    Set the horizon for loading data.

    Args:
        :horizon: requested horizon; it is capped at the minimum trajectory length
            (`self.min_length`) and a warning is logged when capping happens.
    '''
    shortest = self.min_length
    if horizon > shortest:
        logger.warning(f'Warning: the min length of dataset is {shortest}, which is less than the horzion {horizon} you require. '
                       f'Fallback to use horzion = {shortest}.')
        self.horizon = shortest
    else:
        self.horizon = horizon
    logger.info(f"Set trajectory horizon : {self.horizon}")
def get_dist_configs(self, model_config):
    r'''
    Build the distribution config of every node from the given model config.

    Args:
        :model_config: The given model config.

    Return:
        :dist_configs: config of distributions for each node.
        :total_dims: dimensions for each node when it is considered as input and output.
            (Output dimensions can be different from input dimensions due to the parameterized distribution)
    '''
    dist_configs = {}
    for node_name, node_config in self.data_configs.items():
        dist_configs[node_name] = self._get_dist_config(node_config, model_config)

    total_dims = {}
    for node_name, dist_config in dist_configs.items():
        total_dims[node_name] = self._get_dim(dist_config)

    return dist_configs, total_dims
def _load_config(self, config_file : str):
    """
    load data description from `.yaml` file. Few notes:
        1. the name of each dimension will be discarded since they doesn't help the computation.
        2. dimensions of each node will be reordered (category, discrete, continuous) to speed up computation.
        3. register expert functions and tunable parameters if defined.

    Besides the user supplied expert functions and custom nodes, `next_ts_<x>` nodes whose
    inputs are exactly {`ts_<x>`, `next_<x>`} or {`ts_<x>`, `<x>`} get a helper function
    generated from the templates shipped in `revive/common`. All remaining nodes become
    `NetworkDecisionNode`.

    Sets `self.raw_columns`, `self.data_configs`, `self.orders`, `self.graph`, `self.fit`
    and `self.columns`.

    NOTE(review): expert functions and custom nodes are imported from python files placed
    next to the yaml file, i.e. loading a description executes that code; the yaml itself is
    read with `yaml.FullLoader`. Only use trusted description folders.
    NOTE(review): nesting was reconstructed from a source whose indentation was lost.
    """
    if self._raw_config:
        raw_config = self._raw_config
    else:
        with open(config_file, 'r', encoding='UTF-8') as f:
            raw_config = yaml.load(f, Loader=yaml.FullLoader)

    # collect description for the same node
    data_config = raw_config['metadata']['columns']
    self.columns = data_config
    keys = set([list(d.values())[0]['dim'] for d in data_config])
    raw_columns = {}
    config = {}
    order = {}
    fit = {}
    for config_key in keys:
        raw_columns[config_key], config[config_key], order[config_key], fit[config_key] = self._load_config_for_single_node(data_config, config_key)

    # parse graph and metric_nodes
    graph_dict = raw_config['metadata'].get('graph', None)
    metric_nodes = raw_config['metadata'].get('metric_nodes', None)
    graph = DesicionGraph(graph_dict, raw_columns, fit, metric_nodes)

    # copy the raw columns for transition variables to allow them as input to other nodes
    for curr_name, next_name in graph.transition_map.items():
        raw_columns[next_name] = raw_columns[curr_name]

    # mark tunable parameters
    for node_name in raw_config['metadata'].get('tunable', []):
        graph.mark_tunable(node_name)

    custom_nodes = raw_config['metadata'].get('custom_nodes', None)
    # an empty `expert_functions:` entry in the yaml yields None; treat it as "no functions"
    expert_functions = self.expert_functions if self.expert_functions is not None else {}
    config_dir = os.path.dirname(self.config_file)

    def get_function_type(file_path : str, function_name : str, file_name: str) -> str:
        '''get the function type from type hint on the `def` line of `function_name`'''
        import re
        # match the exact function name, a substring test could pick a longer-named function
        signature = re.compile(r'def\s+' + re.escape(function_name) + r'\s*\(')
        with open(file_path, 'r') as f:
            for line in f:
                if signature.match(line):
                    if 'Tensor' in line:
                        return 'torch'
                    elif 'ndarray' in line:
                        return 'numpy'
                    else:
                        warnings.warn('Type hint is not provided, assume it is an numpy function!')
                        return 'numpy'
        raise ValueError(f'Cannot find function {function_name} in {file_name}.py, please check your yaml!')

    def import_attribute(module_dir : str, module_name : str, attr_name : str, reload : bool = False):
        '''import `module_name` from `module_dir` and return its attribute `attr_name`'''
        sys.path.insert(0, module_dir)
        cached = module_name in sys.modules
        module = importlib.import_module(module_name)
        if reload and cached:
            # the file on disk was rewritten after the first import, the cached module is stale
            importlib.invalidate_caches()
            module = importlib.reload(module)
        return getattr(module, attr_name)

    def register_generated_function(node_name : str, ori_name : str, kind : str):
        '''instantiate the `next_ts_<kind>_function` template for `ori_name` and register it on `node_name`'''
        graph.register_node(node_name, FunctionDecisionNode)
        file_name = f"next_ts_{kind}_function"
        function_name = f"next_ts_placeholder_{kind}_function".replace("placeholder", ori_name)
        file_path = os.path.join(config_dir, file_name + '.py')
        shutil.copyfile(os.path.join(os.path.dirname(revive.__file__), f"./common/{file_name}.py"), file_path)
        # append a copy of the template function specialised for this node to the fresh file
        with open(file_path, 'r+') as file:
            template = file.read()
            general_func = template[template.find('def'):]
            file.write('\n\n' + general_func.replace("placeholder", ori_name))
        function_type = get_function_type(file_path, function_name, file_name)
        # NOTE(review): the delay is kept from the original code; presumably it keeps the
        # rewritten file from being mistaken for an already compiled one — confirm.
        time.sleep(1)
        func = import_attribute(config_dir, file_name, function_name, reload=True)
        graph.get_node(node_name).register_node_function(func, function_type)
        logger.info(f'register node function ({function_type} version) for {node_name}')

    # register expert functions to the graph
    for node_name, function_description in expert_functions.items():
        # NOTE: currently we assume the expert functions are also placed in the same folder as the yaml file.
        if 'node_function' not in function_description.keys():
            continue
        # `node function` should be like [file].[function_name]`
        graph.register_node(node_name, FunctionDecisionNode)
        file_name, function_name = function_description['node_function'].split('.')
        file_path = os.path.join(config_dir, file_name + '.py')
        # fails early when the function is missing from the user file
        function_type = get_function_type(file_path, function_name, file_name)
        parse_file_path = file_path[:-3]+"_parsed.py"
        if not parser(file_path,parse_file_path,self.config_file):
            # fall back to the user file when no parsed file is reported
            parse_file_path = file_path
        function_type = get_function_type(parse_file_path, function_name, file_name)
        file_name = os.path.split(os.path.splitext(parse_file_path)[0])[-1]
        func = import_attribute(os.path.dirname(parse_file_path), file_name, function_name)
        graph.get_node(node_name).register_node_function(func, function_type)
        logger.info(f'register node function ({function_type} version) for {node_name}')

    # automatically register helper functions for `next_ts_*` nodes without an expert function
    for node_name in list(graph.keys()):
        if not node_name.startswith("next_ts_"):
            continue
        ori_name = node_name[8:]
        node_inputs = set(graph.graph_dict[node_name])
        if node_name in expert_functions:
            continue
        if node_inputs == {"ts_"+ori_name, "next_"+ori_name}:
            register_generated_function(node_name, ori_name, "transition")
        if node_inputs == {"ts_"+ori_name, ori_name}:
            register_generated_function(node_name, ori_name, "policy")

    # register custom nodes to the graph
    if custom_nodes is not None:
        for node_name, custom_node in custom_nodes.items():
            # NOTE: currently we assume the custom nodes are also placed in the same folder as the yaml file.
            # `custom node should be given in the form of [file].[node_class_name]`
            file_name, class_name = custom_node.split('.')
            node_class = import_attribute(config_dir, file_name, class_name)
            graph.register_node(node_name, node_class)
            logger.info(f'register custom node `{node_class}` for {node_name}')

    # register other nodes with the default `NetworkDecisionNode`
    for node_name, node in graph.nodes.items():
        if node is None:
            graph.register_node(node_name, NetworkDecisionNode)
            logger.info(f'register the default `NetworkDecisionNode` for {node_name}')

    self.raw_columns, self.data_configs, self.orders, self.graph, self.fit = raw_columns, config, order, graph, fit
def _load_config_for_single_node(self, raw_config : list, node_name : str):
    '''
    Collect the column descriptions that belong to a single node.

    Every element of `raw_config` is expected to be a one-entry dict
    `{column_name: description}`; a column belongs to `node_name` when
    `description['dim'] == node_name`.

    NOTE: the description of each `category` column is modified in place
    (`dim` is overwritten with the number of `values`) and that very dict is
    what ends up in the returned `config`.

    :return
        raw_config: the column entries of this node, in their original order
        config: grouped descriptions, in order: one entry per category column,
                then one merged discrete entry, then one merged continuous entry
        order: dict with `forward` (raw column positions listed in grouped
               order) and `backward` (the inverse permutation)
        fit: booleans collected in raw column order, one per column (one per
             value for a category column); False only when the column
             explicitly sets a falsy `fit`
    '''
    raw_config = [entry for entry in raw_config
                  if list(entry.values())[0]['dim'] == node_name]

    config = []
    discrete_max, discrete_min, discrete_num = [], [], []
    continuous_max, continuous_min = [], []
    type_rank = []  # per column: 1 = category, 2 = discrete, 3 = continuous
    fit = []

    for entry in raw_config:
        column_name = list(entry.keys())[0]
        description = entry[column_name]
        column_type = description['type']

        if column_type == 'category':
            assert 'values' in description.keys(), f'Parsing columns for node `{node_name}`, you must provide `values` for a `category` dimension.'
            description['dim'] = len(description['values'])
            type_rank.append(1)
            config.append(description)
        elif column_type == 'continuous':
            type_rank.append(3)
            continuous_max.append(description.get('max', None))
            continuous_min.append(description.get('min', None))
        elif column_type == 'discrete':
            assert 'num' in description.keys() and description['num'] > 1, f'Parsing columns for node `{node_name}`, you must provide `num` > 1 for a `discrete` dimension.'
            type_rank.append(2)
            discrete_max.append(description.get('max', None))
            discrete_min.append(description.get('min', None))
            discrete_num.append(description['num'])
        else:
            logger.error(f"Data type {description['type']} is not support. Please check the yaml file.")
            raise NotImplementedError

        # a column is fitted unless it carries an explicit falsy `fit`
        should_fit = not ("fit" in description.keys() and not description["fit"])
        repeat = len(description['values']) if column_type == 'category' else 1
        fit += [should_fit] * repeat

    # stable sort by type rank gives the grouped order; sorting the positions
    # of that permutation by its values gives the inverse permutation
    positions = range(len(type_rank))
    forward_order = sorted(positions, key=type_rank.__getitem__)
    backward_order = sorted(positions, key=forward_order.__getitem__)
    order = {
        'forward' : forward_order,
        'backward' : backward_order,
    }

    if discrete_num:
        config.append({'type' : 'discrete', "dim" : len(discrete_num), 'max' : discrete_max, 'min' : discrete_min, 'num' : discrete_num})
    if continuous_max:
        config.append({'type' : 'continuous', 'dim' : len(continuous_max), 'max' : continuous_max, 'min' : continuous_min})

    return raw_config, config, order, fit
def _get_dist_config(self, data_config : List[Dict[str, Any]], model_config : Dict[str, Any]):
    '''
    Turn grouped data descriptions into distribution descriptions.

    Each entry of `data_config` is shallow-copied, its `type` is replaced by the
    distribution type selected in `model_config` and an `output_dim` is attached.
    Entries whose `type` is neither `category` nor `discrete` are treated as
    continuous.

    :param data_config: grouped descriptions, each with at least `type` and `dim`
    :param model_config: must provide the `*_distribution_type` key for every
        variable type present, plus `conditioned_std` (for `normal` / `gmm`)
        and `mixture` (for `gmm`)
    :return: list of distribution descriptions, one per entry of `data_config`
    :raises AssertionError: if a selected distribution type is not supported
        for the corresponding variable type
    '''
    def _fill_gaussian(config):
        # bookkeeping shared by discrete and continuous variables
        if config['type'] == 'normal':
            config['conditioned_std'] = model_config['conditioned_std']
            config['output_dim'] = config['dim'] * (1 + config['conditioned_std'])
        elif config['type'] == 'gmm':
            config['mixture'] = model_config['mixture']
            config['conditioned_std'] = model_config['conditioned_std']
            config['output_dim'] = config['mixture'] * ((1 + config['conditioned_std']) * config['dim'] + 1)

    dist_config = []
    for config in data_config:
        config = config.copy()  # shallow copy: the input description keeps its `type`
        if config['type'] == 'category':
            assert model_config['category_distribution_type'] in ['onehot'], \
                f"distribution type {model_config['category_distribution_type']} is not support for category variables!"
            config['type'] = model_config['category_distribution_type']
            config['output_dim'] = config['dim']
        elif config['type'] == 'discrete':
            assert model_config['discrete_distribution_type'] in ['gmm', 'normal', 'discrete_logistic'], \
                f"distribution type {model_config['discrete_distribution_type']} is not support for discrete variables!"
            config['type'] = model_config['discrete_distribution_type']
            if config['type'] == 'discrete_logistic':
                # `num` is already carried over by the copy
                config['output_dim'] = config['dim'] * 2
            else:
                _fill_gaussian(config)
        else:
            assert model_config['continuous_distribution_type'] in ['gmm', 'normal'], \
                f"distribution type {model_config['continuous_distribution_type']} is not support for continuous variables!"
            config['type'] = model_config['continuous_distribution_type']
            _fill_gaussian(config)
        dist_config.append(config)
    return dist_config
def _get_process_params(self, data : np.ndarray, data_config : List[Dict[str, Any]], order : Dict[str, List[int]]):
    '''
    Get the parameters needed by the data processor.

    :param data: array whose last axis holds the raw columns, or None.
    :param data_config: grouped descriptions (`category` / `discrete` / `continuous`).
        For `discrete` and `continuous` entries every `None` in `max` / `min`
        is replaced IN PLACE by the value observed in `data`.
    :param order: dict whose `forward` entry lists the raw column positions
        in grouped order.
    :return: dict with `forward_slices`, `backward_slices` and
        `additional_parameters`. Per entry of `data_config` the parameters are
        the float32 `values` (category), `(center, half_range)` (continuous)
        or `(center, half_range, num)` (discrete). When `data` is None a
        `(0, 1)` pair is emitted per dimension and a single slice covers all
        dimensions.
    :raises ValueError: if an entry has an unknown `type`.
    '''
    additional_parameters = []
    forward_slices = []
    backward_slices = []

    if data is None:
        total_dims = sum(config["dim"] for config in data_config)
        additional_parameters = [
            (np.array(0).astype(np.float32), np.array(1).astype(np.float32))
            for _ in range(total_dims)
        ]
        forward_slices.append(slice(0, total_dims))
        backward_slices.append(slice(0, total_dims))
        return {
            'forward_slices' : forward_slices,
            'backward_slices' : backward_slices,
            'additional_parameters' : additional_parameters,
        }

    # `take` returns a new array, so no extra copy is needed to protect the input
    data = data.take(order['forward'], axis=-1)

    forward_start_index = 0
    backward_start_index = 0
    for config in data_config:
        if config['type'] == 'category':
            # a category occupies one raw column (forward) and `dim` columns (backward)
            forward_end_index = forward_start_index + 1
            additional_parameters.append(np.array(config['values']).astype(np.float32))
        elif config['type'] in ('continuous', 'discrete'):
            forward_end_index = forward_start_index + config['dim']
            # slice along the LAST axis (the one reordered above) so that data
            # with extra leading axes is handled the same way as 2-D data
            columns = data[..., forward_start_index : forward_end_index]
            columns = columns.reshape((-1, columns.shape[-1]))
            data_max = columns.max(axis=0)
            data_min = columns.min(axis=0)
            for i in range(config['dim']):
                if config['max'][i] is None: config['max'][i] = data_max[i]
                if config['min'][i] is None: config['min'][i] = data_min[i]
            max_num = np.array(config['max']).astype(np.float32)
            min_num = np.array(config['min']).astype(np.float32)
            interval = max_num - min_num
            interval[interval==0] = 2 # prevent dividing zero
            parameters = ((max_num + min_num) / 2, 0.5 * interval)
            if config['type'] == 'discrete':
                parameters += (np.array(config['num']),)
            additional_parameters.append(parameters)
        else:
            raise ValueError(f"Unknown data type `{config['type']}` in data config.")

        backward_end_index = backward_start_index + config['dim']
        forward_slices.append(slice(forward_start_index, forward_end_index))
        backward_slices.append(slice(backward_start_index, backward_end_index))
        forward_start_index = forward_end_index
        backward_start_index = backward_end_index

    return {
        'forward_slices' : forward_slices,
        'backward_slices' : backward_slices,
        'additional_parameters' : additional_parameters,
    }
def _get_dim(self, dist_configs):
    ''' Sum the `dim` (input) and `output_dim` (output) fields over all distribution configs. '''
    input_dim = 0
    output_dim = 0
    for dist in dist_configs:
        input_dim += dist['dim']
    for dist in dist_configs:
        output_dim += dist['output_dim']
    return {'input' : input_dim, 'output' : output_dim}
def __len__(self) -> int:
    ''' Number of items: each trajectory contributes `traj_length // horizon` of them. '''
    items_per_trajectory = self.traj_lengths // self.horizon
    return np.sum(items_per_trajectory)
def __getitem__(self, index : int, raw : bool = False) -> 'Batch':
    '''
    Fetch one item of the dataset.

    In `trajectory` mode the item is a slice of `horizon` consecutive steps
    taken from the trajectory that `index` maps to: at a fixed offset when
    `fix_sample` is set, otherwise at a random valid offset. In `transition`
    mode the item is the single step `self.data[index]`. When tunable data is
    present, the entry of the owning trajectory is merged into the item.

    In trajectory mode, if `nan_isin_data` is set and the first step of the
    slice is flagged in a leaf node's `<node>_isnan_index_` entry, another
    index is drawn at random and fetched instead.

    :param index: index of the item.
    :param raw: return the item as stored instead of passing it through the processor.
    :return: the raw item if `raw` is True, otherwise the processed item.
    :raises ValueError: if the dataset mode is neither `trajectory` nor `transition`.
    '''
    if self.mode == 'trajectory':
        traj_index = self._find_trajectory(index)
        if self.fix_sample: # fix the starting point of each slice in the trajectory
            start_index = self.start_indexes[traj_index] + self.horizon * (index - self.index_to_traj[traj_index])
        else: # randomly sample valid start point from the trajectory
            length = self.end_indexes[traj_index] - self.start_indexes[traj_index]
            start_index = self.start_indexes[traj_index] + np.random.randint(0, length - self.horizon + 1)
        raw_data = self.data[start_index : (start_index + self.horizon)]
        if self.tunable_data is not None:
            raw_data.update(self.tunable_data[traj_index])

        # TODO: Update
        # Skip nan in start index data
        if self.nan_isin_data:
            for node in self.graph.get_leaf():
                nan_key = node + "_isnan_index_"
                if nan_key in raw_data.keys() and np.sum(raw_data[nan_key][:1]) > 0.5:
                    # forward `raw` so that the replacement item has the form the caller asked for
                    return self.__getitem__(np.random.choice(np.arange(self.__len__())), raw=raw)
    elif self.mode == 'transition':
        raw_data = self.data[index]
        if self.tunable_data is not None:
            traj_index = self._find_trajectory(index)
            raw_data.update(self.tunable_data[traj_index])
    else:
        raise ValueError(f'Unknown dataset mode `{self.mode}`, expected `trajectory` or `transition`.')

    if raw:
        return raw_data
    return self.processor.process(raw_data)
def split(self, ratio : float = 0.5, mode : str = 'outside_traj', recall : bool = False) -> Tuple['OfflineDataset', 'OfflineDataset']:
    r''' Split the dataset into a training part and a validation part.

    The dataset itself is rebuilt in place as the training part; the validation
    part is a deep copy rebuilt from the remaining data.

    Args:
        :ratio: Fraction of the data that goes to the validation dataset.
        :mode: Either `outside_traj` or `inside_traj`.
            `outside_traj` splits between trajectories: the trajectories are shuffled and
            every trajectory ends up in exactly one of the two datasets.
            `inside_traj` splits every trajectory: its former part goes to the training set,
            its later part to the validation set.
        :recall: Internal flag marking a fallback call. When a split would leave one side
            empty, the other mode is tried once (with a warning); if the fallback
            fails as well a `RuntimeError` is raised.

    Return:
        (TrainDataset, ValidateDataset)
    '''
    val_dataset = deepcopy(self)

    if mode not in ('outside_traj', 'inside_traj'):
        raise ValueError(f'Split mode {mode} is not understood, please check your config!')

    if mode == 'outside_traj':
        total_traj_num = len(self._start_indexes)
        val_traj_num = int(total_traj_num * ratio)
        train_traj_num = total_traj_num - val_traj_num

        if val_traj_num <= 0 or train_traj_num <= 0:
            message = f'Cannot split a dataset with {total_traj_num} trajectories to {train_traj_num} (training) and {val_traj_num} (validation).'
            if recall:
                raise RuntimeError(message)
            warnings.warn(message)
            warnings.warn('Fallback to `inside_traj` mode!')
            return self.split(ratio=ratio, mode='inside_traj', recall=True)

        # the first `val_traj_num` shuffled trajectories form the validation set
        shuffled = list(range(total_traj_num))
        np.random.shuffle(shuffled)
        val_traj_index = sorted(shuffled[:val_traj_num])
        train_traj_index = sorted(shuffled[val_traj_num:])

        self._rebuild_from_traj_index(train_traj_index)
        val_dataset._rebuild_from_traj_index(val_traj_index)
        return self, val_dataset

    # mode == 'inside_traj'
    training_slices, validation_slices = [], []
    for total_point in self._traj_lengths:
        slice_point = int(total_point * (1 - ratio))
        train_point = slice_point
        validation_point = total_point - slice_point

        if train_point <= 0 or validation_point <= 0:
            message = f'Cannot split a trajectory with {total_point} steps to {train_point} (training) and {validation_point} (validation).'
            if recall:
                raise RuntimeError(message)
            warnings.warn(message)
            warnings.warn('Fallback to `outside_traj` mode!')
            return self.split(ratio=ratio, mode='outside_traj', recall=True)

        training_slices.append(slice(0, slice_point))
        validation_slices.append(slice(slice_point, total_point))

    self._rebuild_from_slices(training_slices)
    val_dataset._rebuild_from_slices(validation_slices)
    return self, val_dataset
def _rebuild_from_traj_index(self, traj_indexes : List[int]) -> 'OfflineDataset':
    ''' Keep only the trajectories listed in `traj_indexes` (in that order) and rebuild the index bookkeeping. '''
    # gather the data of the selected trajectories
    self.data = Batch.cat([
        self.data[self._start_indexes[traj_index] : self._end_indexes[traj_index]]
        for traj_index in traj_indexes
    ])

    # the kept trajectories are now stored back to back, starting at 0
    lengths = self._end_indexes[traj_indexes] - self._start_indexes[traj_indexes]
    ends = np.cumsum(lengths)
    self._traj_lengths = lengths
    self._end_indexes = ends
    self._start_indexes = np.concatenate([np.array([0]), ends[:-1]])
    self._min_length = np.min(lengths)
    self._max_length = np.max(lengths)

    # re-apply the current mode on top of the new bookkeeping
    if self.mode == 'trajectory':
        self.trajectory_mode_()
    else:
        self.transition_mode_()
    return self
def _rebuild_from_slices(self, slices : List[slice]) -> 'OfflineDataset':
    ''' Keep, for every trajectory, only the part selected by the matching slice and rebuild the index bookkeeping. '''
    assert len(self._traj_lengths) == len(slices)

    # each slice is applied relative to the start of its own trajectory
    pieces = [
        self.data[start:end][piece_slice]
        for piece_slice, start, end in zip(slices, self._start_indexes, self._end_indexes)
    ]
    lengths = np.array([piece.shape[0] for piece in pieces])
    self.data = Batch.cat(pieces)

    # the kept pieces are now stored back to back, starting at 0
    ends = np.cumsum(lengths)
    self._traj_lengths = lengths
    self._end_indexes = ends
    self._start_indexes = np.concatenate([np.array([0]), ends[:-1]])
    self._min_length = np.min(lengths)
    self._max_length = np.max(lengths)

    # re-apply the current mode on top of the new bookkeeping
    if self.mode == 'trajectory':
        self.trajectory_mode_()
    else:
        self.transition_mode_()
    return self
class InfiniteDataLoader:
    r""" Wrapper that enables infinite pre-fetching, must use together with InfiniteUniformSampler

    A single iterator over the wrapped loader is created once and kept alive;
    every pass over this object yields exactly one batch, the next one of that
    shared iterator.
    """

    def __init__(self, dataloader : torch.utils.data.DataLoader):
        self.dataloader = dataloader
        self.dataset = dataloader.dataset
        self.iter = iter(dataloader)

    def __iter__(self):
        # the batch is pulled eagerly, at the moment iteration is requested
        next_batch = next(self.iter)
        return iter([next_batch])
def collect_data(expert_data : List[Batch], graph : DesicionGraph) -> Batch:
    r''' Collate function for the PyTorch DataLoader.

    Stacks the samples along axis -2 and converts the result to torch. If the
    graph has a transition map and the entry of its first transition key is
    3-dimensional after stacking, every tunable entry is expanded with a new
    leading axis of that entry's first dimension.
    '''
    batch = Batch.stack(expert_data, axis=-2)
    batch.to_torch()

    if not graph.transition_map:
        return batch

    selected_name = list(graph.transition_map.keys())[0]
    if len(batch[selected_name].shape) != 3:
        return batch

    for tunable_name in graph.tunable:
        tunable = batch[tunable_name]
        keep_existing = [-1] * len(tunable.shape)  # -1 leaves a dimension unchanged
        batch[tunable_name] = tunable.expand(batch[selected_name].shape[0], *keep_existing)
    return batch
def get_loader(dataset : OfflineDataset, config : dict, is_sample : bool = True):
    """ Get the PyTorch DataLoader for training.

    :param dataset: dataset to load from; its `graph` is bound to the collate function.
    :param config: must provide the batch size under `BATCH_SIZE` and `data_workers`.
    :param is_sample: if True, batches are drawn by an `InfiniteUniformSampler` and the
        loader is wrapped in an `InfiniteDataLoader`; otherwise a plain shuffling
        DataLoader is returned.
    """
    batch_size = config[BATCH_SIZE]

    if not is_sample:
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=partial(collect_data, graph=dataset.graph),
            pin_memory=True,
            num_workers=config['data_workers'],
        )

    sampling_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=InfiniteUniformSampler(dataset, batch_size),
        collate_fn=partial(collect_data, graph=dataset.graph),
        pin_memory=True,
        num_workers=config['data_workers'],
    )
    return InfiniteDataLoader(sampling_loader)
def data_creator(config : dict,
                 training_mode : str = 'trajectory',
                 training_horizon : int = None,
                 training_is_sample : bool = True,
                 val_mode : str = 'trajectory',
                 val_horizon : int = None,
                 val_is_sample : bool = False,
                 double : bool = False):
    """
    Get train data loader and validation data loader.

    The datasets are fetched with `ray.get` from `config['dataset']` and
    `config['val_dataset']`; `config` is also filled with `dist_configs`,
    `total_dims` and `learning_nodes_num` taken from the training dataset.
    If only one of the two horizons is given, it is used for both.

    :return: `(train_loader, val_loader)`, or, when `double` is True,
        `(train_loader_train, val_loader_train, train_loader_val, val_loader_val)`
        where the `*_train` loaders use the training mode/horizon/sampling and the
        `*_val` loaders use the validation ones.
    """
    def _with_mode(dataset, mode, horizon):
        # whatever the mode switch returns is what gets handed to the loader
        if mode == 'trajectory':
            return dataset.trajectory_mode_(horizon)
        return dataset.transition_mode_()

    train_dataset = ray.get(config['dataset'])
    val_dataset = ray.get(config['val_dataset'])

    config['dist_configs'], config['total_dims'] = train_dataset.get_dist_configs(config)
    config['learning_nodes_num'] = train_dataset.learning_nodes_num

    # a single given horizon applies to both training and validation
    if training_horizon is None and val_horizon is not None:
        training_horizon = val_horizon
    if training_horizon is not None and val_horizon is None:
        val_horizon = training_horizon

    if double:
        # perform double venv training
        # NOTE: train_dataset_val means training set used in validation
        train_dataset_train = deepcopy(train_dataset)
        val_dataset_train = deepcopy(val_dataset)
        train_dataset_val = deepcopy(train_dataset)
        val_dataset_val = deepcopy(val_dataset)

        train_dataset_train = _with_mode(train_dataset_train, training_mode, training_horizon)
        val_dataset_train = _with_mode(val_dataset_train, training_mode, training_horizon)
        train_dataset_val = _with_mode(train_dataset_val, val_mode, val_horizon)
        val_dataset_val = _with_mode(val_dataset_val, val_mode, val_horizon)

        train_loader_train = get_loader(train_dataset_train, config, training_is_sample)
        val_loader_train = get_loader(val_dataset_train, config, training_is_sample)
        train_loader_val = get_loader(train_dataset_val, config, val_is_sample)
        val_loader_val = get_loader(val_dataset_val, config, val_is_sample)
        return train_loader_train, val_loader_train, train_loader_val, val_loader_val

    train_dataset = _with_mode(train_dataset, training_mode, training_horizon)
    val_dataset = _with_mode(val_dataset, val_mode, val_horizon)
    train_loader = get_loader(train_dataset, config, training_is_sample)
    val_loader = get_loader(val_dataset, config, val_is_sample)
    return train_loader, val_loader
if __name__ == '__main__':
    # Manual smoke test: load a batch, check that process/deprocess round-trips,
    # then exercise the sampler.
    # NOTE(review): `OfflineDataset.__init__` declares `data_file`, `config_file`
    # and `revive_config` without defaults, so this call looks like it needs
    # arguments - confirm before relying on this script.
    dataset = OfflineDataset()
    stack_samples = partial(Batch.stack, axis=1)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=2, collate_fn=stack_samples, shuffle=True)
    data = next(iter(loader))
    print(data)
    print(data[:, 0].obs)

    # round trip of a whole raw item
    random_index = np.random.randint(len(dataset))
    single_data = dataset.__getitem__(random_index, raw=True)
    processed_data = dataset.processor.process(single_data)
    deprocessed_data = dataset.processor.deprocess(processed_data)

    # round trip of the `obs` entry alone
    processor = dataset.processor
    processed_obs = processor.process_single(single_data.obs, 'obs')
    deprocessed_obs = processor.deprocess_single(processed_obs, 'obs')

    for k in single_data.keys():
        assert np.isclose(deprocessed_data[k], single_data[k], atol=1e-6).all(), [k, deprocessed_data[k] - single_data[k]]
    assert np.isclose(deprocessed_data.obs, deprocessed_obs, atol=1e-6).all(), [processed_data.obs - processed_obs]

    plot_traj(single_data)

    # test sampler
    data = torch.rand(1000, 4)
    dataset = torch.utils.data.TensorDataset(data)
    sampler = UniformSampler(dataset, 3)
    loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
    for _ in range(10):
        for b in loader:
            print(b)