revive.algo

revive.algo.venv.base

class revive.algo.venv.base.VenvOperator(*args, **kwargs)[source]

Bases: object

The base venv class.

NAME = None

Name of the algorithm used.

property metric_name

This defines the metric to minimize during hyperparameter search.

property nodes_models_train
property other_models_train
property nodes_models_val
property other_models_val
PARAMETER_DESCRIPTION = []
classmethod get_parameters(command=None, **kargs)[source]
classmethod get_tune_parameters(config: dict, **kargs)[source]

Use ray.tune to wrap the parameters to be searched.
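As a rough sketch of the idea, a subclass can describe its searchable hyperparameters in PARAMETER_DESCRIPTION and wrap them into ray.tune search objects. The entry schema and the make_search_space helper below are illustrative assumptions, not the library's actual format.

    # Minimal sketch, assuming each PARAMETER_DESCRIPTION entry names a
    # parameter and lists candidate values; the real schema may differ.
    from ray import tune

    PARAMETER_DESCRIPTION = [
        {'name': 'lr', 'search_values': [1e-4, 1e-3, 1e-2]},
        {'name': 'hidden_features', 'search_values': [64, 128, 256]},
    ]

    def make_search_space(config: dict) -> dict:
        search_space = dict(config)
        for desc in PARAMETER_DESCRIPTION:
            if 'search_values' in desc:
                search_space[desc['name']] = tune.choice(desc['search_values'])
        return search_space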

model_creator(config: dict, graph: DesicionGraph)[source]

Create all the models. The algorithm needs to define models for the nodes to be learned.

Args:
config

configuration parameters

Return:

a list of models
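As a rough illustration, an override might build one small network per node in the decision graph. The graph iteration (graph.keys()) and the config['dims'] layout below are assumptions made so the sketch is self-contained.

    # Minimal sketch, assuming the graph exposes its node names and the config
    # carries input/output dimensions per node.
    from typing import List
    import torch

    def model_creator(config: dict, graph) -> List[torch.nn.Module]:
        models = []
        for node_name in graph.keys():                    # assumed graph API
            in_dim, out_dim = config['dims'][node_name]   # assumed config layout
            models.append(torch.nn.Sequential(
                torch.nn.Linear(in_dim, config['hidden_features']),
                torch.nn.ReLU(),
                torch.nn.Linear(config['hidden_features'], out_dim),
            ))
        return models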

optimizer_creator(models: List[Module], config: dict)[source]

Define optimizers for the created models.

Args:
models

list of all the models

config

configuration parameters

Return:

a list of optimizers
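A typical override simply pairs each model returned by model_creator with a torch optimizer. The 'venv_lr' config key below is a placeholder, not a documented field.

    # Minimal sketch: one Adam optimizer per model.
    from typing import List
    import torch

    def optimizer_creator(models: List[torch.nn.Module],
                          config: dict) -> List[torch.optim.Optimizer]:
        return [torch.optim.Adam(m.parameters(), lr=config.get('venv_lr', 1e-3))
                for m in models]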

data_creator(config: dict)[source]

Create DataLoaders.

Args:
config

configuration parameters

Return:

(train_loader, val_loader)
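A plausible implementation wraps the train and validation splits in standard PyTorch DataLoaders. The dataset arguments are added here only so the sketch is self-contained; the real method reads the splits from the operator itself.

    # Minimal sketch returning the (train_loader, val_loader) pair described above.
    from torch.utils.data import DataLoader

    def data_creator(config: dict, train_dataset, val_dataset):
        train_loader = DataLoader(train_dataset,
                                  batch_size=config['batch_size'], shuffle=True)
        val_loader = DataLoader(val_dataset,
                                batch_size=config['batch_size'], shuffle=False)
        return train_loader, val_loader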

nan_in_grad()[source]
before_train_epoch(*args, **kwargs)[source]
train_epoch(*args, **kwargs)[source]
validate(*args, **kwargs)[source]
train_batch(expert_data, batch_info, scope='train')[source]

Define the training process for a batch of data.

validate_batch(expert_data, batch_info, scope='valEnv_on_trainData')[source]

Define the validation process for a batch of data.

Args:

expert_data: The batch of offline data.

batch_info: A batch info dict.

scope: scope='valEnv_on_trainData' means the model trained on the validation dataset is tested on the training data.

revive.algo.policy.base

class revive.algo.policy.base.PolicyOperator(*args, **kwargs)[source]

Bases: object

property env
property policy
property val_policy
property other_models
PARAMETER_DESCRIPTION = []
classmethod get_parameters(command=None, **kargs)[source]
classmethod get_tune_parameters(config: Dict[str, Any], **kargs)[source]

Use ray.tune to wrap the parameters to be searched.

model_creator(config: Dict[str, Any], node: FunctionDecisionNode) → List[Module][source]

Create all the models. The algorithm needs to define models for the nodes to be learned.

Args:
config

configuration parameters

Return:

a list of models
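For the policy case the returned list usually holds a single policy network for the decision node being trained. The config['policy_dims'] lookup below is an assumption for illustration.

    # Minimal sketch: one policy network for the decision node.
    from typing import List
    import torch

    def model_creator(config: dict, node) -> List[torch.nn.Module]:
        obs_dim, act_dim = config['policy_dims']          # assumed config layout
        policy = torch.nn.Sequential(
            torch.nn.Linear(obs_dim, config['hidden_features']),
            torch.nn.Tanh(),
            torch.nn.Linear(config['hidden_features'], act_dim),
        )
        return [policy]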

optimizer_creator(models: List[Module], config: Dict[str, Any]) → List[Optimizer][source]

Define optimizers for the created models.

Args:
models

list of all the models

config

configuration parameters

Return:

a list of optimizers

data_creator(config: Dict[str, Any])[source]

Create DataLoaders.

Args:
config

configuration parameters

Return:

(train_loader, val_loader)

get_ope_dataset()[source]

Convert the dataset to the OPEDataset format used by d3pe.

venv_test(expert_data: Batch, target_policy, traj_length=None, scope: str = 'trainPolicy_on_valEnv')[source]

Use the virtual environment model to test the policy model.
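Conceptually, this rolls the target policy out inside the learned virtual environment and reports a return-like metric under the given scope. The outline below, including the reward_fn argument, is a hypothetical sketch of that idea, not the actual implementation.

    # Hypothetical outline: roll out in the virtual env and average a reward signal.
    def venv_test_sketch(operator, expert_data, target_policy, traj_length,
                         reward_fn, scope='trainPolicy_on_valEnv'):
        traj = operator.generate_rollout(expert_data, target_policy,
                                         operator.env, traj_length)
        rewards = reward_fn(traj)            # hypothetical reward helper
        return {f'{scope}_reward': float(rewards.mean())}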

generate_rollout(expert_data: Batch, target_policy, env: Union[VirtualEnvDev, List[VirtualEnvDev]], traj_length: int, maintain_grad_flow: bool = False, deterministic: bool = True, clip: bool = False)[source]

Generate trajectories based on current policy.

Args:
expert_data

sampled data from the dataset

target_policy

the policy used to generate the rollout

env

the virtual environment model (or list of models) used to step the rollout

traj_length

length of the generated trajectories

maintain_grad_flow

whether to keep the gradient flow through the generated rollout

Return:

batch trajectories
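The rollout itself is an autoregressive loop: start from the states sampled in expert_data, let the policy choose an action, let the virtual environment predict the next transition, and repeat for traj_length steps. The sketch below captures only that loop shape; the policy and env call conventions (env.infer_one_step) are hypothetical stand-ins for the real interfaces.

    # Hypothetical loop shape for generating a batch of trajectories.
    def generate_rollout_sketch(initial_state, target_policy, env,
                                traj_length, deterministic=True):
        state = initial_state                 # batch of starting states from expert_data
        trajectory = []
        for _ in range(traj_length):
            action = target_policy(state, deterministic=deterministic)  # assumed call
            next_state = env.infer_one_step(state, action)              # assumed env API
            trajectory.append((state, action, next_state))
            state = next_state
        return trajectory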

before_train_epoch(*args, **kwargs)[source]
train_epoch(*args, **kwargs)[source]
validate(*args, **kwargs)[source]
train_batch(expert_data: Batch, batch_info: Dict[str, float], scope: str = 'train')[source]
validate_batch(expert_data: Batch, batch_info: Dict[str, float], scope: str = 'trainPolicy_on_valEnv')[source]