revive.data package
Submodules
revive.data.batch module
- class revive.data.batch.Batch(batch_dict: Optional[Union[dict, Batch, Sequence[Union[dict, Batch]], ndarray]] = None, copy: bool = False, **kwargs: Any)
Bases: object
- __setitem__(index: Union[str, slice, int, ndarray, List[int]], value: Any) → None
Assign value to self[index].
- to_torch(dtype: Optional[dtype] = None, device: Union[str, int, device] = 'cpu') → None
Change all numpy.ndarray to torch.Tensor in-place.
- cat_(batches: Union[Batch, Sequence[Union[dict, Batch]]]) → None
Concatenate a list of (or one) Batch objects into the current batch.
- stack_(batches: Sequence[Union[dict, Batch]], axis: int = 0) → None
Stack a list of Batch objects into the current batch along the given axis.
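The difference between cat_ and stack_ presumably mirrors NumPy's concatenate versus stack. The following is a minimal pure-Python sketch, assuming a batch behaves like a dict of equally sized arrays (nested lists stand in for ndarray). cat_batches and stack_batches are hypothetical helpers, not part of revive, and they return new objects instead of modifying a batch in place.

```python
from typing import Dict, List, Sequence

# Hypothetical stand-in: a "batch" maps each key to a list of rows,
# where each row is a list of floats (a pure-Python substitute for ndarray).
BatchDict = Dict[str, List[List[float]]]


def cat_batches(batches: Sequence[BatchDict]) -> BatchDict:
    """Concatenate along the existing first axis: batch sizes add up."""
    return {key: [row for b in batches for row in b[key]] for key in batches[0]}


def stack_batches(batches: Sequence[BatchDict]) -> Dict[str, List[List[List[float]]]]:
    """Stack along a new leading axis (axis=0): one entry per input batch."""
    return {key: [b[key] for b in batches] for key in batches[0]}


a = {"obs": [[0.0, 1.0], [2.0, 3.0]]}  # 2 rows
b = {"obs": [[4.0, 5.0], [6.0, 7.0]]}  # 2 rows

cat = cat_batches([a, b])
stk = stack_batches([a, b])
print(len(cat["obs"]))                      # 4: rows are appended
print(len(stk["obs"]), len(stk["obs"][0]))  # 2 2: a new leading axis of size 2
```

In array terms, two batches of shape (2, 2) concatenate to (4, 2) but stack to (2, 2, 2).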
- static empty(batch: Batch, index: Optional[Union[slice, int, ndarray, List[int]]] = None) → Batch
- update(batch: Optional[Union[dict, Batch]] = None, **kwargs: Any) → None
Update this batch from another dict/Batch.
- property shape: List[int]
Return the shape of the batch as a list of integers.
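To illustrate how __setitem__, update, and shape fit together, here is a simplified sketch. MiniBatch is a hypothetical class, not revive's Batch: it stores fields as Python lists, supports only string and integer indices, and its shape reports only the shared batch length rather than a full common shape.

```python
from typing import Any, Dict, List, Optional


class MiniBatch:
    """Hypothetical, simplified stand-in for a dict-of-arrays batch.

    Field values are Python lists whose first axis is the batch axis.
    """

    def __init__(self, **kwargs: Any) -> None:
        self.__dict__.update(kwargs)

    def __setitem__(self, index: Any, value: Any) -> None:
        if isinstance(index, str):
            # String index: set or replace a whole field.
            self.__dict__[index] = value
        else:
            # Integer index: write one row into each field named in `value`.
            for key, row in value.items():
                self.__dict__[key][index] = row

    def update(self, batch: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
        # Merge fields from a dict and/or keyword arguments; later values win.
        if batch is not None:
            self.__dict__.update(batch)
        self.__dict__.update(kwargs)

    @property
    def shape(self) -> List[int]:
        # Simplification: report only the batch length shared by all fields.
        lengths = [len(v) for v in self.__dict__.values()]
        return [min(lengths)] if lengths else []


b = MiniBatch(obs=[[0.0, 1.0], [2.0, 3.0]], act=[[0], [1]])
b["rew"] = [1.0, 0.0]         # add a field by name
b[0] = {"obs": [9.0, 9.0]}    # overwrite row 0 of "obs"
b.update(done=[False, True])  # merge keyword fields
print(b.shape)                # [2]
```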
revive.data.dataset module
- class revive.data.dataset.OfflineDataset(data_file: str, config_file: str, revive_config: dict, ignore_check: bool = False, horizon: Optional[int] = None, reward_func=None)
Bases: Dataset
An offline dataset class.
- Params:
- data_file
The file path where the training dataset is stored.
- config_file
The file path where the data description file is stored.
- horizon
Length of the trajectory clips used in trajectory mode.
- transition_mode_()
Set the dataset to transition mode, in which __getitem__ returns a single transition.
- trajectory_mode_(horizon: Optional[int] = None, fix_sample: bool = False)
Set the dataset to trajectory mode, in which __getitem__ returns a clip of a trajectory.
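A rough sketch of how the two modes change what a dataset index refers to, assuming clips never cross trajectory boundaries and only full-length clips are kept. The helpers transition_index, clip_index, and get_clip are hypothetical; revive may pad, sample, or handle fix_sample differently.

```python
from typing import List, Sequence, Tuple

# Hypothetical data: each trajectory is a list of per-step records (here, ints).
trajectories: List[List[int]] = [[0, 1, 2, 3], [10, 11, 12]]


def transition_index(trajs: Sequence[Sequence[int]]) -> List[Tuple[int, int]]:
    """Transition mode: every (trajectory, step) pair is one item."""
    return [(t, s) for t, traj in enumerate(trajs) for s in range(len(traj))]


def clip_index(trajs: Sequence[Sequence[int]], horizon: int) -> List[Tuple[int, int]]:
    """Trajectory mode: every start step that leaves room for a full clip."""
    return [(t, s) for t, traj in enumerate(trajs)
            for s in range(len(traj) - horizon + 1)]


def get_clip(trajs: Sequence[Sequence[int]], item: Tuple[int, int],
             horizon: int) -> List[int]:
    """Return the `horizon` consecutive steps starting at `item`."""
    t, s = item
    return list(trajs[t][s:s + horizon])


transitions = transition_index(trajectories)
clips = clip_index(trajectories, horizon=2)
print(len(transitions))                             # 7 single steps
print(len(clips))                                   # 3 + 2 = 5 clips of length 2
print(get_clip(trajectories, clips[3], horizon=2))  # [10, 11]
```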
- get_dist_configs(model_config)
Get the config of distributions for each node based on the given model config.
- Args:
- model_config
The given model config.
- Return:
- dist_configs
Configuration of the distribution for each node.
- total_dims
Dimensions of each node when it is used as an input and as an output. Output dimensions can differ from input dimensions because of the parameterized distribution.
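As an illustration of why input and output dimensions can differ, the sketch below assumes continuous variables are modeled by a diagonal Gaussian (a mean and a standard deviation per dimension) and categorical variables by one logit per class. The node description format and the node_dims helper are hypothetical, not revive's actual config schema or distribution choices.

```python
from typing import Dict, List, Tuple

# Hypothetical node descriptions: (variable type, size).
# For "category", size is the number of classes.
nodes: Dict[str, List[Tuple[str, int]]] = {
    "obs": [("continuous", 3)],
    "action": [("continuous", 2), ("category", 4)],
}


def node_dims(variables: List[Tuple[str, int]]) -> Dict[str, int]:
    """Assumed rule: a continuous input takes `size` dims, but a diagonal
    Gaussian output needs a mean and a std per dim (2 * size); a categorical
    variable uses a one-hot input and one logit per class (both `size`)."""
    input_dim = 0
    output_dim = 0
    for var_type, size in variables:
        input_dim += size
        output_dim += 2 * size if var_type == "continuous" else size
    return {"input": input_dim, "output": output_dim}


total_dims = {name: node_dims(vs) for name, vs in nodes.items()}
print(total_dims)
# {'obs': {'input': 3, 'output': 6}, 'action': {'input': 6, 'output': 8}}
```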
- split(ratio: float = 0.5, mode: str = 'outside_traj', recall: bool = False) → Tuple[OfflineDataset, OfflineDataset]
Split the dataset into a training set and a validation set with the given ratio and mode.
- Args:
- ratio
Ratio used to split off the validation dataset when one is not explicitly given.
- mode
Mode for automatically splitting the training and validation datasets; choose from outside_traj and inside_traj. outside_traj splits between trajectories, so each trajectory belongs to only one dataset. inside_traj splits within each trajectory: the earlier part of a trajectory goes to the training set and the later part to the validation set.
- Return:
(TrainDataset, ValidateDataset)
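The two split modes can be sketched on plain lists of steps. This assumes ratio is the fraction assigned to validation and that outside_traj picks validation trajectories at random; split_outside_traj and split_inside_traj are hypothetical helpers, not revive functions, and the recall flag is not modeled.

```python
import random
from typing import List, Sequence, Tuple

Trajectory = List[int]


def split_outside_traj(trajs: Sequence[Trajectory], ratio: float,
                       seed: int = 0) -> Tuple[List[Trajectory], List[Trajectory]]:
    """Whole trajectories go to exactly one side of the split."""
    order = list(range(len(trajs)))
    random.Random(seed).shuffle(order)
    val_ids = set(order[:int(len(trajs) * ratio)])
    train = [list(t) for i, t in enumerate(trajs) if i not in val_ids]
    val = [list(t) for i, t in enumerate(trajs) if i in val_ids]
    return train, val


def split_inside_traj(trajs: Sequence[Trajectory],
                      ratio: float) -> Tuple[List[Trajectory], List[Trajectory]]:
    """Each trajectory is cut: earlier steps to training, later steps to validation."""
    train, val = [], []
    for t in trajs:
        cut = len(t) - int(len(t) * ratio)
        train.append(list(t[:cut]))
        val.append(list(t[cut:]))
    return train, val


trajs = [[0, 1, 2, 3], [10, 11, 12, 13], [20, 21, 22, 23], [30, 31, 32, 33]]
train_o, val_o = split_outside_traj(trajs, ratio=0.5)
train_i, val_i = split_inside_traj(trajs, ratio=0.5)
print(len(train_o), len(val_o))  # 2 2: whole trajectories on each side
print(train_i[0], val_i[0])      # [0, 1] [2, 3]: each trajectory is cut
```

outside_traj keeps temporal context intact but validates on unseen trajectories; inside_traj validates on the future of trajectories the model has partly seen.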
- class revive.data.dataset.UniformSampler(data_source, number_samples, replacement=True)
Bases: Sampler
A uniform data sampler.
- Args:
- data_source (OfflineDataset): dataset to sample from.
- number_samples (int): number of samples to draw.
- replacement (bool): if True, samples are drawn on demand with replacement. Default: True.
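A minimal sketch of uniform index sampling that follows the PyTorch Sampler protocol (__iter__ yields indices, __len__ gives their count) without depending on torch. UniformIndexSampler and its seed parameter are hypothetical; the real class may draw indices differently.

```python
import random
from typing import Iterator, Optional, Sized


class UniformIndexSampler:
    """Hypothetical stand-in following the Sampler protocol."""

    def __init__(self, data_source: Sized, number_samples: int,
                 replacement: bool = True, seed: Optional[int] = None) -> None:
        self.data_source = data_source
        self.number_samples = number_samples
        self.replacement = replacement
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.replacement:
            # Each index is drawn independently and uniformly from [0, n).
            for _ in range(self.number_samples):
                yield self.rng.randrange(n)
        else:
            # Without replacement: a uniform random subset, no repeats.
            yield from self.rng.sample(range(n), min(self.number_samples, n))

    def __len__(self) -> int:
        if self.replacement:
            return self.number_samples
        return min(self.number_samples, len(self.data_source))


sampler = UniformIndexSampler(list(range(10)), number_samples=5, seed=0)
indices = list(sampler)
no_rep = UniformIndexSampler(list(range(10)), number_samples=5,
                             replacement=False, seed=0)
no_rep_indices = list(no_rep)
```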
- class revive.data.dataset.InfiniteUniformSampler(data_source, number_samples, replacement=True)
Bases: Sampler
An infinite data sampler that provides an unbounded stream of data indices.
- Args:
- data_source (OfflineDataset): dataset to sample from.
- number_samples (int): number of samples to draw.
- replacement (bool): if True, samples are drawn on demand with replacement. Default: True.
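An infinite sampler differs from a finite one only in that its iterator never stops, so the consumer decides how many indices to take. The sketch below is hypothetical: InfiniteIndexSampler leaves out number_samples (an endless stream has no fixed length) and, without replacement, simply emits one fresh random permutation per pass.

```python
import itertools
import random
from typing import Iterator, Optional, Sized


class InfiniteIndexSampler:
    """Hypothetical sketch of an infinite uniform index sampler."""

    def __init__(self, data_source: Sized, replacement: bool = True,
                 seed: Optional[int] = None) -> None:
        self.n = len(data_source)
        self.replacement = replacement
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[int]:
        while True:
            if self.replacement:
                # Independent uniform draws from [0, n), forever.
                yield self.rng.randrange(self.n)
            else:
                # One fresh random permutation per pass, repeated forever.
                yield from self.rng.sample(range(self.n), self.n)


# The stream never ends, so the consumer must bound it, e.g. with islice.
with_rep = list(itertools.islice(InfiniteIndexSampler(range(5), seed=0), 12))
no_rep = list(itertools.islice(
    InfiniteIndexSampler(range(5), replacement=False, seed=0), 10))
```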
- class revive.data.dataset.InfiniteDataLoader(dataloader: DataLoader)
Bases: object
Wrapper that enables infinite pre-fetching. It must be used together with InfiniteUniformSampler.
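One plausible reading of this wrapper: create the underlying iterator once and keep reusing it, so a new "epoch" continues the same stream instead of restarting (for a PyTorch DataLoader, each fresh iter() starts a new pass and, unless persistent_workers is set, new worker processes). That only works if the stream never ends, which is why an infinite sampler is required. PersistentLoader is a hypothetical name, and itertools.count() stands in for a loader.

```python
import itertools
from typing import Generic, Iterable, Iterator, TypeVar

T = TypeVar("T")


class PersistentLoader(Generic[T]):
    """Hypothetical sketch: wrap an iterable and keep a single iterator alive.

    Every call to iter() returns the same underlying iterator, so a new
    "epoch" continues where the previous one stopped instead of restarting.
    """

    def __init__(self, loader: Iterable[T]) -> None:
        self._iterator: Iterator[T] = iter(loader)

    def __iter__(self) -> Iterator[T]:
        return self._iterator

    def __next__(self) -> T:
        return next(self._iterator)


# itertools.count() stands in for a loader driven by an infinite sampler.
stream = PersistentLoader(itertools.count())
first = list(itertools.islice(stream, 3))   # [0, 1, 2]
second = list(itertools.islice(stream, 3))  # [3, 4, 5]: no restart
```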
- revive.data.dataset.collect_data(expert_data: List[Batch], graph: DesicionGraph) → Batch
Collate function for the PyTorch DataLoader.
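A collate function turns the list of samples fetched for one mini-batch into a single batch object. The sketch below is hypothetical: plain dicts replace Batch, and a list of node names stands in for the graph argument, on the assumption that the graph decides which keys are kept. With PyTorch, such a function is passed as DataLoader(..., collate_fn=functools.partial(collate_dicts, keys=[...])).

```python
from typing import Any, Dict, List, Sequence


def collate_dicts(samples: Sequence[Dict[str, Any]],
                  keys: Sequence[str]) -> Dict[str, List[Any]]:
    """Hypothetical collate function: turn a list of per-sample dicts into one
    dict of per-key lists, keeping only the requested keys."""
    return {key: [sample[key] for sample in samples] for key in keys}


samples = [
    {"obs": [0.0, 1.0], "action": [1], "extra": "a"},
    {"obs": [2.0, 3.0], "action": [0], "extra": "b"},
]
batch = collate_dicts(samples, keys=["obs", "action"])
print(batch)  # {'obs': [[0.0, 1.0], [2.0, 3.0]], 'action': [[1], [0]]}
```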
- revive.data.dataset.get_loader(dataset: OfflineDataset, config: dict, is_sample: bool = True)
Get the PyTorch DataLoader for training.
- revive.data.dataset.data_creator(config: dict, training_mode: str = 'trajectory', training_horizon: Optional[int] = None, training_is_sample: bool = True, val_mode: str = 'trajectory', val_horizon: Optional[int] = None, val_is_sample: bool = False, double: bool = False)
Get train data loader and validation data loader.
- Returns
train data loader and validation data loader
revive.data.expert_function_parsed module
revive.data.processor module
- class revive.data.processor.DataProcessor(data_configs, processing_params, orders)
Bases: object
This class handles the mapping of data between the original format and the computation format.
Mapping from the original format to the computation format takes two steps:
- Step 1: Reorder the data to group variables of the same type, which accelerates computation.
- Step 2: If a variable is continuous or discrete, normalize it to [-1, 1]; if it is categorical, encode it as a one-hot vector.
Mapping from computation to original is the reverse of these steps.
- Args:
- data_configs (dict): a dictionary containing the configuration of the input data.
- processing_params (dict): a dictionary containing the processing parameters.
- orders (list): a list of variable orders used to reorder the data.
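The two forward steps can be sketched for a single record. This assumes min-max scaling to [-1, 1] using per-variable bounds and a stable sort by type for the reordering; the config dicts, type order, and helper names are hypothetical and do not reproduce revive's actual data_configs or processing_params format.

```python
from typing import Dict, List, Sequence

# Hypothetical per-variable configs (not revive's actual schema).
configs = [
    {"name": "speed", "type": "continuous", "min": 0.0, "max": 10.0},
    {"name": "gear", "type": "category", "values": [1, 2, 3]},
    {"name": "temp", "type": "continuous", "min": -20.0, "max": 40.0},
]


def reorder(configs: Sequence[Dict]) -> List[Dict]:
    """Step 1: group variables of the same type together (stable order)."""
    type_order = {"continuous": 0, "discrete": 1, "category": 2}
    return sorted(configs, key=lambda c: type_order[c["type"]])


def process_row(row: Dict[str, float], configs: Sequence[Dict]) -> List[float]:
    """Step 2: min-max scale numeric variables to [-1, 1]; one-hot categories."""
    out: List[float] = []
    for c in reorder(configs):
        x = row[c["name"]]
        if c["type"] in ("continuous", "discrete"):
            lo, hi = c["min"], c["max"]
            out.append(2.0 * (x - lo) / (hi - lo) - 1.0)
        else:
            out.extend(1.0 if x == v else 0.0 for v in c["values"])
    return out


result = process_row({"speed": 5.0, "gear": 2, "temp": 40.0}, configs)
print(result)  # [0.0, 1.0, 0.0, 1.0, 0.0]
```

The output order is speed, temp (both continuous), then the three one-hot slots for gear.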
- property keys
- process_single_torch(data: Tensor, key: str) → Tensor
Preprocess the data of a single key according to its type: ‘category’, ‘continuous’, or ‘discrete’.
- deprocess_single_torch(data: Tensor, key: str) → Tensor
Post-process the data of a single key according to its type: ‘category’, ‘continuous’, or ‘discrete’.
- process_torch(data)
Preprocess batch data according to the type of each variable: ‘category’, ‘continuous’, or ‘discrete’.
- deprocess_torch(data)
Post-process batch data according to the type of each variable: ‘category’, ‘continuous’, or ‘discrete’.
- process_single(data: ndarray, key: str) → ndarray
Preprocess the data of a single key according to its type: ‘category’, ‘continuous’, or ‘discrete’.
- deprocess_single(data: ndarray, key: str) → ndarray
Post-process the data of a single key according to its type: ‘category’, ‘continuous’, or ‘discrete’.
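Post-processing reverses the forward mapping per type. The sketch below assumes min-max scaling as the forward step, defines a 'discrete' variable as taking num evenly spaced values in [lo, hi] (so deprocessing snaps to the nearest one), and maps categorical scores back with argmax. All function names and the num parameter are hypothetical illustrations, not revive's implementation.

```python
from typing import List


def process_single(x: float, lo: float, hi: float) -> float:
    """Map a value in [lo, hi] to [-1, 1] (assumed min-max scaling)."""
    return 2.0 * (x - lo) / (hi - lo) - 1.0


def deprocess_continuous(y: float, lo: float, hi: float) -> float:
    """Exact inverse of process_single."""
    return (y + 1.0) / 2.0 * (hi - lo) + lo


def deprocess_discrete(y: float, lo: float, hi: float, num: int) -> float:
    """Inverse scaling, then snap to the nearest of `num` evenly spaced
    values in [lo, hi] (an assumed definition of a 'discrete' variable)."""
    x = deprocess_continuous(y, lo, hi)
    step = (hi - lo) / (num - 1)
    k = max(0, min(num - 1, round((x - lo) / step)))
    return lo + k * step


def deprocess_category(scores: List[float], values: List[int]) -> int:
    """Pick the category with the largest (one-hot or logit) score."""
    best = max(range(len(scores)), key=lambda i: scores[i])
    return values[best]


y = process_single(7.5, 0.0, 10.0)                     # 0.5
print(deprocess_continuous(y, 0.0, 10.0))              # 7.5
print(deprocess_discrete(0.43, 0.0, 10.0, num=11))     # 7.0
print(deprocess_category([0.1, 0.7, 0.2], [1, 2, 3]))  # 2
```

Snapping discrete outputs matters because a model's continuous prediction in [-1, 1] rarely lands exactly on a valid grid value.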