revive.data package

Submodules

revive.data.batch module

class revive.data.batch.Batch(batch_dict: dict | Batch | Sequence[dict | Batch] | ndarray | None = None, copy: bool = False, **kwargs: Any)[source]

Bases: object

__getitem__(index: str | slice | int | ndarray | List[int]) Any[source]

Return self[index].

__setitem__(index: str | slice | int | ndarray | List[int], value: Any) None[source]

Assign value to self[index].

to_numpy() None[source]

Change all torch.Tensor to numpy.ndarray in-place.

detach() None[source]

Detach the tensors in the batch data.

to_torch(dtype: dtype | None = None, device: str | int | device = 'cpu') None[source]

Change all numpy.ndarray to torch.Tensor in-place.
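A minimal usage sketch of construction, indexing, and in-place conversion; the field names obs and act are illustrative, not required by the API:

    import numpy as np
    import torch

    from revive.data.batch import Batch

    batch = Batch(obs=np.zeros((4, 3)), act=np.ones((4, 2)))

    obs = batch['obs']   # string index returns the stored field
    first = batch[0]     # integer index slices every field at once

    batch.to_torch(dtype=torch.float32, device='cpu')  # ndarray -> Tensor, in-place
    batch.to_numpy()                                   # Tensor -> ndarray, in-place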

cat_(batches: Batch | Sequence[dict | Batch]) None[source]

Concatenate a list of (or a single) Batch object into the current batch.

static cat(batches: Sequence[dict | Batch]) Batch[source]
stack_(batches: Sequence[dict | Batch], axis: int = 0) None[source]

Stack a list of Batch objects into the current batch.

static stack(batches: Sequence[dict | Batch], axis: int = 0) Batch[source]
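A short sketch contrasting cat and stack, under the same illustrative field names as above:

    import numpy as np

    from revive.data.batch import Batch

    a = Batch(obs=np.zeros((2, 3)))
    b = Batch(obs=np.ones((2, 3)))

    merged = Batch.cat([a, b])             # concatenates along the existing axis: (4, 3)
    stacked = Batch.stack([a, b], axis=0)  # adds a new leading axis: (2, 2, 3)

    a.cat_(b)  # in-place variant: extends a with the contents of b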
empty_(index: slice | int | ndarray | List[int] | None = None) Batch[source]
static empty(batch: Batch, index: slice | int | ndarray | List[int] | None = None) Batch[source]
update(batch: dict | Batch | None = None, **kwargs: Any) None[source]

Update this batch from another dict/Batch.

__len__() int[source]

Return len(self).

is_empty(recurse: bool = False) bool[source]
property shape: List[int]

Return self.shape.

split(size: int, dim: int = 0, shuffle: bool = True, merge_last: bool = False) Iterator[Batch][source]
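A sketch of mini-batch iteration via split:

    import numpy as np

    from revive.data.batch import Batch

    batch = Batch(obs=np.arange(10).reshape(10, 1))

    # With merge_last=False the final smaller chunk (size 2 here) is yielded
    # on its own; merge_last=True would fold it into the previous chunk.
    for mini in batch.split(size=4, shuffle=False, merge_last=False):
        print(len(mini))  # 4, 4, 2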

revive.data.dataset module

class revive.data.dataset.OfflineDataset(data_file: str, config_file: str, revive_config: dict | None = None, ignore_check: bool = False, horizon: int | None = None, reward_func=None)[source]

Bases: Dataset

An offline dataset class.

Params:
data_file:

The file path where the training dataset is stored.

config_file:

The file path where the data description file is stored.

horizon:

Length of the trajectory clips used for iteration.

compute_missing_data(need_truncate)[source]
transition_mode_()[source]

Set the dataset in transition mode. __getitem__ will return a transition.

trajectory_mode_(horizon: int | None = None, fix_sample: bool = False)[source]

Set the dataset in trajectory mode. __getitem__ will return a clip of trajectory.

set_horizon(horizon: int)[source]

Set the horizon for loading data.
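A hedged sketch of switching sampling modes; the file paths are placeholders, not shipped files:

    from revive.data.dataset import OfflineDataset

    # 'data.npz' and 'config.yaml' are placeholder paths for illustration.
    dataset = OfflineDataset('data.npz', 'config.yaml')

    dataset.transition_mode_()            # __getitem__ now returns single transitions
    dataset.trajectory_mode_(horizon=10)  # __getitem__ now returns 10-step clips
    dataset.set_horizon(20)               # adjust the clip length afterwards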

get_dist_configs(model_config)[source]

Get the config of distributions for each node based on the given model config.

Args:
model_config:

The given model config.

Return:
dist_configs:

config of distributions for each node.

total_dims:

Dimensions of each node when it is considered as input and as output. (Output dimensions can differ from input dimensions due to the parameterized distribution.)
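A call-shape sketch only, continuing from the dataset above; model_config is the node/model configuration built elsewhere in the REVIVE pipeline, and its schema is not reproduced here:

    # Hypothetical: model_config comes from the training configuration.
    dist_configs, total_dims = dataset.get_dist_configs(model_config)

    # total_dims holds per-node input and output dimensions; outputs can be
    # wider than inputs because the parameterized distribution may emit
    # several parameters per output dimension.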

__len__() int[source]
__getitem__(index: int, raw: bool = False) Batch[source]
split(ratio: float = 0.5, mode: str = 'outside_traj', recall: bool = False) Tuple[OfflineDataset, OfflineDataset][source]

Split the dataset into training and validation sets with the given ratio and mode.

Args:
ratio:

Ratio used to split off the validation dataset if it is not explicitly given.

mode:

Mode of automatically splitting the training and validation datasets, chosen from outside_traj and inside_traj. outside_traj means the split happens outside the trajectories, so one trajectory can only be in one dataset. inside_traj means the split happens inside the trajectories, with the former part of each trajectory in the training set and the latter part in the validation set.

Return:

(TrainDataset, ValidateDataset)
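A sketch of both split modes, again using placeholder paths:

    from revive.data.dataset import OfflineDataset

    dataset = OfflineDataset('data.npz', 'config.yaml')  # placeholder paths

    # outside_traj: every trajectory ends up wholly in one of the two datasets.
    train_set, val_set = dataset.split(ratio=0.5, mode='outside_traj')

    # inside_traj: each trajectory is cut, the earlier part feeding the
    # training set and the later part the validation set.
    train_set, val_set = dataset.split(ratio=0.5, mode='inside_traj')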

class revive.data.dataset.RNNOfflineDataset(data_file: str, config_file: str, revive_config: dict | None = None, ignore_check: bool = False, horizon: int | None = None, reward_func=None)[source]

Bases: OfflineDataset

trajectory_mode_(horizon: int | None = None, fix_sample: bool = False)[source]

Set the dataset in trajectory mode. __getitem__ will return a clip of trajectory.

__len__() int[source]
__getitem__(index: int, raw: bool = False) Batch[source]
class revive.data.dataset.UniformSampler(data_source, number_samples, replacement=True)[source]

Bases: Sampler

A uniform data sampler

Args:

data_source (OfflineDataset): dataset to sample from.
number_samples (int): number of samples to draw.
replacement (bool): samples are drawn on-demand with replacement if True, default=``True``.

__len__()[source]
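A sketch of plugging the sampler into a standard torch.utils.data.DataLoader; paths are placeholders, and the collate function normally supplied in practice is omitted to keep the sketch minimal:

    from torch.utils.data import DataLoader

    from revive.data.dataset import OfflineDataset, UniformSampler

    dataset = OfflineDataset('data.npz', 'config.yaml')  # placeholder paths
    sampler = UniformSampler(dataset, number_samples=256, replacement=True)

    loader = DataLoader(dataset, batch_size=32, sampler=sampler)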
class revive.data.dataset.InfiniteUniformSampler(data_source, number_samples, replacement=True)[source]

Bases: Sampler

An infinite data sampler that provides an endless stream of data indices.

Args:

data_source (OfflineDataset): dataset to sample from.
number_samples (int): number of samples to draw.
replacement (bool): samples are drawn on-demand with replacement if True, default=``True``.

class revive.data.dataset.InfiniteDataLoader(dataloader: DataLoader)[source]

Bases: object

Wrapper that enables infinite pre-fetching; must be used together with InfiniteUniformSampler.
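A construction sketch pairing the two classes, with placeholder paths:

    from torch.utils.data import DataLoader

    from revive.data.dataset import (InfiniteDataLoader, InfiniteUniformSampler,
                                     OfflineDataset)

    dataset = OfflineDataset('data.npz', 'config.yaml')  # placeholder paths
    sampler = InfiniteUniformSampler(dataset, number_samples=256)

    # The wrapper keeps pre-fetching batches indefinitely from the inner loader.
    loader = InfiniteDataLoader(DataLoader(dataset, batch_size=32, sampler=sampler))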

revive.data.dataset.collect_data(expert_data: List[Batch], graph: DesicionGraph) Batch[source]

Collate function for a PyTorch DataLoader.

revive.data.dataset.pad_collect_data(expert_data: List[Batch], graph: DesicionGraph) Batch[source]

Collate function for a PyTorch DataLoader, with padding.

revive.data.dataset.get_loader(dataset: OfflineDataset, config: dict, is_sample: bool = True, rnn=False)[source]

Get the PyTorch DataLoader for training.

revive.data.dataset.data_creator(config: dict, training_mode: str = 'trajectory', training_horizon: int | None = None, training_is_sample: bool = True, val_mode: str = 'trajectory', val_horizon: int | None = None, val_is_sample: bool = False, pre_horzion: int = 0, double: bool = False)[source]

Get train data loader and validation data loader.

Returns:

train data loader and validation data loader
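A call-shape sketch, assuming the documented loader pair is returned as a tuple; config stands for the full REVIVE training configuration, whose schema is defined elsewhere in the SDK:

    from revive.data.dataset import data_creator

    config = ...  # hypothetical: the full REVIVE training configuration

    train_loader, val_loader = data_creator(
        config,
        training_mode='trajectory',
        training_horizon=10,
        val_mode='trajectory',
        val_is_sample=False,
    )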

revive.data.dataset.revive_f_rnn_data_creator(config: dict, training_mode: str = 'trajectory', training_horizon: int | None = None, training_is_sample: bool = True, val_mode: str = 'trajectory', val_horizon: int | None = None, val_is_sample: bool = False, pre_horzion: int = 0, double: bool = False)[source]

Get train data loader and validation data loader.

Returns:

train data loader and validation data loader

revive.data.processor module

class revive.data.processor.DataProcessor(data_configs, processing_params, orders)[source]

Bases: object

This class handles the data mapping between the original format and the computation format.

There are two steps for mapping from original to computation:

Step 1: Reorder the data. This is to group variables with the same type to accelerate computation.

Step 2: If the variable is continuous or discrete, normalize the data to [-1, 1].

If the variable is categorical, create an onehot vector.

Mapping from computation to original is the reverse of these steps.

Args:

data_configs (dict): A dictionary containing the configuration of the input data.
processing_params (dict): A dictionary containing the processing parameters.
orders (list): A list of variable orders for reordering the data.
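A numeric sketch of Step 2; the actual statistics (minimums, maximums, category counts) come from processing_params, so the min-max formula below illustrates the shape of the transform rather than the processor's verbatim implementation:

    import numpy as np

    # Continuous/discrete variables: min-max normalization to [-1, 1].
    raw = np.array([0.0, 5.0, 10.0])
    lo, hi = raw.min(), raw.max()
    normalized = 2.0 * (raw - lo) / (hi - lo) - 1.0        # -> [-1., 0., 1.]
    recovered = (normalized + 1.0) / 2.0 * (hi - lo) + lo  # deprocess inverts it

    # Categorical variables: one-hot encoding over the number of classes.
    onehot = np.eye(3)[2]                                  # class 2 of 3 -> [0., 0., 1.]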

property keys
process_single_torch(data: Tensor, key: str) Tensor[source]

Preprocess a single piece of data according to its type: 'category', 'continuous', or 'discrete'.

deprocess_single_torch(data: Tensor, key: str) Tensor[source]

Post-process a single piece of data according to its type: 'category', 'continuous', or 'discrete'.

process_torch(data)[source]

Preprocess batch data according to the type of each variable: 'category', 'continuous', or 'discrete'.

deprocess_torch(data)[source]

Post-process batch data according to the type of each variable: 'category', 'continuous', or 'discrete'.

process_single(data: ndarray, key: str) ndarray[source]

Preprocess a single piece of data according to its type: 'category', 'continuous', or 'discrete'.

deprocess_single(data: ndarray, key: str) ndarray[source]

Post-process a single piece of data according to its type: 'category', 'continuous', or 'discrete'.

process(data: Dict[str, ndarray])[source]

Preprocess batch data according to the type of each variable: 'category', 'continuous', or 'discrete'.

deprocess(data: Dict[str, ndarray])[source]

Post-process batch data according to the type of each variable: 'category', 'continuous', or 'discrete'.

Module contents