revive.data

class revive.data.dataset.OfflineDataset(data_file: str, config_file: str, revive_config: dict | None = None, ignore_check: bool = False, horizon: int | None = None, reward_func=None)

Bases: Dataset

An offline dataset class.

Params:
data_file:

The file path where the training dataset is stored.

config_file:

The file path where the data description file is stored.

horizon:

Length of the trajectory clips used when iterating over the dataset.

compute_missing_data(need_truncate)
transition_mode_()

Set the dataset to transition mode, in which __getitem__ returns a single transition.

trajectory_mode_(horizon: int | None = None, fix_sample: bool = False)

Set the dataset to trajectory mode, in which __getitem__ returns a clip of a trajectory.
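To make the difference between the two modes concrete, here is a minimal sketch on a toy dataset. `ToyOfflineDataset` and everything inside it are hypothetical and are not REVIVE's implementation. The sketch assumes the data is a dict of NumPy arrays with trajectories stored back to back, assumes a clip never crosses a trajectory boundary, and ignores the `fix_sample` argument.

```python
from typing import Dict, List, Optional

import numpy as np


class ToyOfflineDataset:
    """Hypothetical stand-in mimicking transition vs. trajectory mode.

    Not REVIVE's implementation. Each value in `data` has shape
    (total_steps, dim), with trajectories concatenated along axis 0.
    """

    def __init__(self, data: Dict[str, np.ndarray], traj_lengths: List[int]):
        self.data = data
        self.traj_lengths = traj_lengths
        self.horizon: Optional[int] = None
        self.mode = "transition"

    def transition_mode_(self) -> None:
        self.mode = "transition"

    def trajectory_mode_(self, horizon: Optional[int] = None) -> None:
        if horizon is not None:
            self.horizon = horizon
        if self.horizon is None:
            raise ValueError("trajectory mode needs a horizon")
        self.mode = "trajectory"

    def _clip_starts(self) -> List[int]:
        # Global start index of every clip of length `horizon` that fits
        # entirely inside one trajectory (assumption: no boundary crossing).
        starts, offset = [], 0
        for length in self.traj_lengths:
            starts.extend(range(offset, offset + length - self.horizon + 1))
            offset += length
        return starts

    def __len__(self) -> int:
        if self.mode == "transition":
            return sum(self.traj_lengths)
        return len(self._clip_starts())

    def __getitem__(self, index: int) -> Dict[str, np.ndarray]:
        if self.mode == "transition":
            # A single time step: each value has shape (dim,).
            return {k: v[index] for k, v in self.data.items()}
        start = self._clip_starts()[index]
        # A clip: each value has shape (horizon, dim).
        return {k: v[start:start + self.horizon] for k, v in self.data.items()}


obs = np.arange(10, dtype=float).reshape(10, 1)  # two trajectories: 6 + 4 steps
ds = ToyOfflineDataset({"obs": obs}, traj_lengths=[6, 4])
ds.trajectory_mode_(horizon=3)
```

With trajectories of 6 and 4 steps, transition mode exposes 10 items, while trajectory mode with `horizon=3` exposes 4 + 2 = 6 clips.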

set_horizon(horizon: int)

Set the horizon used when loading data.

get_dist_configs(model_config)

Get the config of distributions for each node based on the given model config.

Args:
model_config:

The given model config.

Return:
dist_configs:

The config of the distribution for each node.

total_dims:

The dimensions of each node when it is used as input and as output. (Output dimensions can differ from input dimensions because of the parameterized distribution.)
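The following sketch illustrates why input and output dimensions can differ. `toy_total_dims` and its `{"input": ..., "output": ...}` layout are hypothetical and do not reproduce the structure actually returned by `get_dist_configs`. It assumes a continuous node is modelled as a diagonal Gaussian (one mean and one scale parameter per dimension) and a discrete node as a categorical distribution over `dim` classes.

```python
from typing import Dict


def toy_total_dims(node_dims: Dict[str, int],
                   node_dists: Dict[str, str]) -> Dict[str, Dict[str, int]]:
    """Hypothetical: per-node input vs. output dimensions.

    Assumptions: "normal" nodes are diagonal Gaussians whose network head
    predicts a mean and a scale per dimension; "categorical" nodes with
    `dim` classes predict `dim` logits and are fed back as one-hot vectors.
    """
    total_dims = {}
    for name, dim in node_dims.items():
        dist = node_dists[name]
        if dist == "normal":
            output_dim = 2 * dim  # mean + scale for every dimension
        elif dist == "categorical":
            output_dim = dim  # one logit per class
        else:
            raise ValueError(f"unsupported distribution: {dist}")
        total_dims[name] = {"input": dim, "output": output_dim}
    return total_dims


dims = toy_total_dims({"obs": 3, "action": 4},
                      {"obs": "normal", "action": "categorical"})
```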

__len__() → int
__getitem__(index: int, raw: bool = False) → Batch
split(ratio: float = 0.5, mode: str = 'outside_traj', recall: bool = False) → Tuple[OfflineDataset, OfflineDataset]

Split the dataset into a training dataset and a validation dataset according to the given ratio and mode.

Args:
ratio:

Ratio used to split off the validation dataset if one is not explicitly given.

mode:

Mode for automatically splitting the training and validation datasets; choose from outside_traj and inside_traj. outside_traj splits outside the trajectories, so each trajectory belongs to exactly one dataset. inside_traj splits inside the trajectories: the earlier part of each trajectory goes to the training set and the later part to the validation set.

Return:

(TrainDataset, ValidateDataset)
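The two split modes can be sketched on a list of trajectory arrays as follows. `toy_split` is hypothetical and is not REVIVE's implementation; it assumes `ratio` is the fraction of data that goes to the validation set, and it ignores the `recall` argument.

```python
import random
from typing import List, Tuple

import numpy as np


def toy_split(trajs: List[np.ndarray], ratio: float = 0.5,
              mode: str = "outside_traj",
              seed: int = 0) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Hypothetical sketch; `ratio` is assumed to be the validation fraction."""
    if mode == "outside_traj":
        # Whole trajectories are assigned to exactly one side.
        order = list(range(len(trajs)))
        random.Random(seed).shuffle(order)
        val_ids = set(order[:int(round(len(trajs) * ratio))])
        train = [t for i, t in enumerate(trajs) if i not in val_ids]
        val = [t for i, t in enumerate(trajs) if i in val_ids]
        return train, val
    if mode == "inside_traj":
        # Each trajectory is cut in two: earlier steps train, later steps validate.
        train, val = [], []
        for t in trajs:
            cut = int(len(t) * (1 - ratio))
            train.append(t[:cut])
            val.append(t[cut:])
        return train, val
    raise ValueError(f"unknown mode: {mode}")


trajs = [np.arange(n).reshape(n, 1) for n in (10, 8, 6, 4)]
```

With outside_traj, no trajectory contributes steps to both sets, so validation trajectories stay unseen during training; with inside_traj, every trajectory appears in both sets, and validation measures how well later steps of partly seen trajectories are predicted.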