revive.algo
revive.algo.venv.base
- class revive.algo.venv.base.VenvOperator(*args, **kwargs)
Bases:
object
The base venv class.
- NAME = None
Name of the used algorithm.
- property metric_name
Defines the metric to minimize during hyperparameter search.
- property nodes_models_train
- property other_models_train
- property nodes_models_val
- property other_models_val
- PARAMETER_DESCRIPTION = []
- classmethod get_tune_parameters(config: dict, **kargs)
Use ray.tune to wrap the parameters to be searched.
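As a rough illustration of what wrapping parameters for search can look like, the sketch below turns a list of parameter descriptions into per-parameter samplers. It uses only the standard library: the plain callables stand in for ray.tune sampling primitives such as `tune.uniform` or `tune.choice`. The entry fields (`name`, `default`, `search_mode`, `search_values`) and the helpers `build_search_space` and `sample_config` are assumptions made for this example, not the documented layout of `PARAMETER_DESCRIPTION` or the behaviour of `get_tune_parameters`.

```python
import random

# Hypothetical description entries; the field names are assumptions for illustration.
PARAMETER_DESCRIPTION = [
    {"name": "lr", "default": 1e-3, "search_mode": "continuous", "search_values": [1e-5, 1e-2]},
    {"name": "hidden_layers", "default": 2, "search_mode": "grid", "search_values": [2, 3, 4]},
    {"name": "batch_size", "default": 256},  # no search information: stays fixed
]


def build_search_space(config, descriptions, rng):
    """Map each parameter name to a zero-argument sampler (stand-in for ray.tune)."""
    space = {}
    for desc in descriptions:
        name = desc["name"]
        mode = desc.get("search_mode")
        values = desc.get("search_values")
        if mode == "continuous":
            low, high = values
            # Default arguments bind the current bounds to this lambda.
            space[name] = lambda low=low, high=high: rng.uniform(low, high)
        elif mode == "grid":
            space[name] = lambda values=values: rng.choice(values)
        else:
            # Not searched: use the value from config, falling back to the default.
            fixed = config.get(name, desc["default"])
            space[name] = lambda fixed=fixed: fixed
    return space


def sample_config(space):
    """Draw one concrete configuration from the search space."""
    return {name: sampler() for name, sampler in space.items()}


rng = random.Random(0)
space = build_search_space({"batch_size": 128}, PARAMETER_DESCRIPTION, rng)
trial = sample_config(space)
```

Parameters without search information keep the value supplied in `config`, so a user override such as `batch_size=128` survives into every sampled trial.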
- model_creator(config: dict, graph: DesicionGraph)
Create all the models. The algorithm needs to define models for the nodes to be learned.
- Args:
- config:
configuration parameters
- Return:
a list of models
- optimizer_creator(models: List[Module], config: dict)
Define optimizers for the created models.
- Args:
- models:
list of all the models
- config:
configuration parameters
- Return:
a list of optimizers
- data_creator()
Create DataLoaders.
- Args:
- config:
configuration parameters
- Return:
(train_loader, val_loader)
- train_batch(expert_data, batch_info, scope='train')
Define the training process for a batch of data.
- validate_batch(expert_data, batch_info, scope='valEnv_on_trainData', loss_mask=None)
Define the validation process for a batch of data.
- Args:
expert_data: A batch of offline data.
batch_info: A batch info dict.
scope: If scope=valEnv_on_trainData, the training data is evaluated on the model trained on the validation dataset.
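The scope strings that appear in this reference ('train', 'valEnv_on_trainData', 'trainPolicy_on_valEnv') suggest a "model_on_data" naming pattern. The helper below is a minimal sketch of that reading; `parse_scope` is a hypothetical function written for this example and is not part of the library, and the naming convention itself is inferred from the defaults shown here rather than stated by the source.

```python
def parse_scope(scope):
    """Split a scope such as 'valEnv_on_trainData' into (evaluated_part, target_part).

    Scopes without the '_on_' separator (for example 'train') are returned
    unchanged with None as the second element.
    """
    if "_on_" not in scope:
        return scope, None
    evaluated_part, target_part = scope.split("_on_", 1)
    return evaluated_part, target_part


# 'valEnv_on_trainData': the env trained on validation data, run on training data.
model_part, data_part = parse_scope("valEnv_on_trainData")
```

Under this reading, 'valEnv_on_trainData' pairs the environment trained on the validation split with the training data, which matches the description of the `scope` argument.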
revive.algo.policy.base
- class revive.algo.policy.base.PolicyOperator(*args, **kwargs)
Bases:
object
- property env
- property train_policy
- property val_policy
- property policy
- property other_train_models
- property other_val_models
- PARAMETER_DESCRIPTION = []
- classmethod get_tune_parameters(config: Dict[str, Any], **kargs)
Use ray.tune to wrap the parameters to be searched.
- optimizer_creator(models: List[Module], config: Dict[str, Any]) → List[Optimizer]
Define optimizers for the created models.
- Args:
- models:
list of all the models
- config:
configuration parameters
- Return:
a list of optimizers
- data_creator()
Create DataLoaders.
- Args:
- config:
configuration parameters
- Return:
(train_loader, val_loader)
- venv_test(expert_data: Batch, target_policy, traj_length=None, scope: str = 'trainPolicy_on_valEnv')
Use the virtual environment model to test the policy model.
- generate_rollout(expert_data: Batch, target_policy, env: VirtualEnvDev | List[VirtualEnvDev], traj_length: int, maintain_grad_flow: bool = False, deterministic: bool = True, clip: bool = False)
Generate trajectories based on current policy.
- Args:
- expert_data:
sampled data from the dataset.
- target_policy:
target_policy
- env:
env
- traj_length:
traj_length
- maintain_grad_flow:
maintain_grad_flow
- Return:
batch trajectories
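The general idea of generating a trajectory from a policy and an environment model can be sketched as a simple loop: start from a state taken from the data, query the policy for an action, let the environment predict the next state, and repeat for `traj_length` steps. The code below is a toy, standard-library illustration of that loop only. The names `rollout`, `toy_policy`, and `toy_env` are hypothetical, and none of this reflects how `generate_rollout` is implemented or how it handles batching, gradient flow, determinism, or clipping.

```python
def rollout(initial_state, policy, env_step, traj_length):
    """Alternate policy and environment calls for traj_length steps.

    Returns a list of transitions, each a dict with the state, the action
    chosen by the policy, and the next state predicted by the environment.
    """
    trajectory = []
    state = dict(initial_state)  # copy so the caller's data is not mutated
    for _ in range(traj_length):
        action = policy(state)
        next_state = env_step(state, action)
        trajectory.append({"state": state, "action": action, "next_state": next_state})
        state = next_state
    return trajectory


def toy_policy(state):
    """Hypothetical policy: push x halfway toward zero."""
    return -0.5 * state["x"]


def toy_env(state, action):
    """Hypothetical environment model: the action is added to x."""
    return {"x": state["x"] + action}


traj = rollout({"x": 8.0}, toy_policy, toy_env, traj_length=3)
```

With these toy functions, x halves at every step, so three steps starting from 8.0 end at 1.0.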