revive.data
- class revive.data.dataset.OfflineDataset(data_file: str, config_file: str, revive_config: dict | None = None, ignore_check: bool = False, horizon: int | None = None, reward_func=None)
Bases: Dataset
An offline dataset class.
- Params:
- data_file:
The file path where the training dataset is stored.
- config_file:
The file path where the data description file is stored.
- horizon:
Length of the trajectory clips used when iterating over the dataset.
- transition_mode_()
Set the dataset to transition mode. __getitem__ then returns a single transition.
- trajectory_mode_(horizon: int | None = None, fix_sample: bool = False)
Set the dataset to trajectory mode. __getitem__ then returns a clip (a contiguous segment) of a trajectory.
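The two mode switches above change what __getitem__ yields. The following is a minimal, self-contained sketch of that behavior, not REVIVE's implementation. The class ToyModeDataset, its data layout (one dict per trajectory, mapping a node name to an array of shape (T, dim)), and the choice to enumerate every valid clip start are assumptions made for illustration; fix_sample is not modeled. The trailing underscore in the method names is read here, as in PyTorch, as an in-place change of the dataset's state.

```python
from typing import Dict, List, Tuple

import numpy as np


class ToyModeDataset:
    """Hypothetical illustration of transition vs. trajectory mode (not REVIVE code)."""

    def __init__(self, trajectories: List[Dict[str, np.ndarray]]):
        self.trajectories = trajectories
        # Length T of each trajectory, taken from any one of its nodes.
        self._lengths = [len(next(iter(traj.values()))) for traj in trajectories]
        # Transition mode indexes every (trajectory, time step) pair.
        self._steps: List[Tuple[int, int]] = [
            (i, t) for i, n in enumerate(self._lengths) for t in range(n)
        ]
        self._clips: List[Tuple[int, int]] = []
        self.horizon = None
        self.mode = "transition"

    def transition_mode_(self) -> None:
        # In-place switch: items become single time steps.
        self.mode = "transition"

    def trajectory_mode_(self, horizon: int) -> None:
        # In-place switch: items become clips of `horizon` consecutive steps.
        if horizon < 1:
            raise ValueError("horizon must be >= 1")
        self.mode = "trajectory"
        self.horizon = horizon
        # Every start index that leaves room for a full clip; trajectories
        # shorter than `horizon` contribute no clips.
        self._clips = [
            (i, s) for i, n in enumerate(self._lengths) for s in range(n - horizon + 1)
        ]

    def __len__(self) -> int:
        return len(self._steps) if self.mode == "transition" else len(self._clips)

    def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
        if self.mode == "transition":
            i, t = self._steps[idx]
            return {name: arr[t] for name, arr in self.trajectories[i].items()}
        i, s = self._clips[idx]
        return {name: arr[s:s + self.horizon] for name, arr in self.trajectories[i].items()}
```

With two trajectories of lengths 5 and 3, transition mode yields 8 items of shape (dim,), while trajectory mode with horizon=3 yields 4 clips of shape (3, dim).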
- get_dist_configs(model_config)
Get the distribution config of each node based on the given model config.
- Args:
- model_config:
The given model config.
- Return:
- dist_configs:
Distribution config of each node.
- total_dims:
Dimensions of each node when it is used as an input and as an output. (Output dimensions can differ from input dimensions because the output parameterizes a distribution.)
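To illustrate why a node's output dimension can differ from its input dimension, here is a hypothetical sketch, not REVIVE's get_dist_configs. It assumes two node kinds: a continuous node of size d modeled as a diagonal Gaussian, so the network must output a mean and a scale per dimension (2d values), and a categorical node with n classes, fed in as a one-hot vector and predicted as n logits. The function name sketch_total_dims, the node kinds, and the returned dict layout are all assumptions; REVIVE's actual distribution choices and return format may differ.

```python
from typing import Dict, Tuple


def sketch_total_dims(nodes: Dict[str, Tuple[str, int]]) -> Dict[str, Dict[str, int]]:
    """Hypothetical: map each node to its input and output dimensions.

    nodes maps a node name to (kind, dim), where kind is "continuous"
    or "categorical" (an assumption made for this illustration).
    """
    total_dims: Dict[str, Dict[str, int]] = {}
    for name, (kind, dim) in nodes.items():
        if kind == "continuous":
            # Diagonal Gaussian: one mean and one scale parameter per dimension.
            total_dims[name] = {"input": dim, "output": 2 * dim}
        elif kind == "categorical":
            # One-hot input, one logit per class as output.
            total_dims[name] = {"input": dim, "output": dim}
        else:
            raise ValueError(f"unknown node kind: {kind!r}")
    return total_dims
```

Under these assumptions a 3-dimensional continuous observation needs 6 output units, while a 4-class categorical action needs 4 in both directions.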
- split(ratio: float = 0.5, mode: str = 'outside_traj', recall: bool = False) → Tuple[OfflineDataset, OfflineDataset]
Split the dataset into training and validation datasets with the given ratio and mode.
- Args:
- ratio:
Ratio used to split off the validation dataset if one is not explicitly given.
- mode:
Mode for automatically splitting the training and validation datasets; choose from outside_traj and inside_traj. outside_traj splits between trajectories, so each trajectory appears in only one of the two datasets. inside_traj splits within each trajectory: the earlier part of every trajectory goes to the training set and the later part to the validation set.
- Return:
(TrainDataset, ValidateDataset)
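The two split modes can be sketched on a plain list of trajectory arrays. This is a hypothetical illustration, not REVIVE's split: the function split_trajectories, its seed parameter, and in particular the reading of the ratio as the fraction of data placed in the training set (named train_fraction here) are assumptions, since the convention is not spelled out above. The recall flag is not modeled.

```python
from typing import List, Tuple

import numpy as np


def split_trajectories(
    trajs: List[np.ndarray],
    train_fraction: float = 0.5,
    mode: str = "outside_traj",
    seed: int = 0,
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Hypothetical sketch of outside_traj vs. inside_traj splitting."""
    if not 0.0 < train_fraction < 1.0:
        raise ValueError("train_fraction must be in (0, 1)")
    if mode == "outside_traj":
        # Whole trajectories are assigned to one side, so none is shared.
        order = np.random.default_rng(seed).permutation(len(trajs))
        n_train = int(round(train_fraction * len(trajs)))
        train = [trajs[i] for i in order[:n_train]]
        val = [trajs[i] for i in order[n_train:]]
        return train, val
    if mode == "inside_traj":
        # Each trajectory is cut in time: earlier steps go to training,
        # later steps to validation.
        train, val = [], []
        for traj in trajs:
            cut = int(round(train_fraction * len(traj)))
            train.append(traj[:cut])
            val.append(traj[cut:])
        return train, val
    raise ValueError(f"mode must be 'outside_traj' or 'inside_traj', got {mode!r}")
```

The trade-off: outside_traj tests generalization to unseen trajectories, while inside_traj keeps every trajectory represented in both sets but validates only on later time steps.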