revive.server

class revive.server.ReviveServer(dataset_file_path: str, dataset_desc_file_path: str, val_file_path: str | None = None, user_module_file_path: str | None = None, matcher_reward_file_path: str | None = None, reward_file_path: str | None = None, target_policy_name: str | None = None, log_dir: str | None = None, run_id: str | None = None, address: str | None = None, venv_mode: str = 'tune', policy_mode: str = 'tune', tuning_mode: str = 'None', tune_initial_state: Dict[str, ndarray] | None = None, debug: bool = False, revive_config_file_path: str | None = None, **kwargs)[source]

Bases: object

A class that uses ray to manage all the training tasks. It can automatically search for optimal hyper-parameters.

ReviveServer will do five steps to initialize:

1. Create or connect to a ray cluster. The behavior is controlled by the address parameter. If address is None, ReviveServer creates its own cluster; if it is specified, ReviveServer connects to the existing cluster at that address.

2. Load the training config. The config is stored in revive/config.py. You can change these parameters by editing that file, by passing them on the command line, or through custom_config.

3. Load the data and its config, and register the reward function. The data files are specified by the dataset_file_path, dataset_desc_file_path and val_file_path parameters. Note that val_file_path is optional; if it is not specified, revive will split a validation set out of the training data. All the data is put into the ray object store so it can be shared across the whole cluster.

4. Create the folder that stores results. The top-level folder of these logs is controlled by the log_dir parameter; if it is not provided, the default is the logs folder under the revive repository. The second-level folder is controlled by the run_id parameter in the training config; if it is not specified, a random id is generated for the folder. All training results are placed in the second-level folder.

5. Create the result server as a ray actor, and try to load existing results from the log folder. This is very useful when you want to train a policy or tune parameters from an already trained simulator.
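Step 3 above loads the training data from an .npz file (see dataset_file_path below). The following is a minimal sketch of producing such a file with NumPy; the array names ("obs", "action") and the trajectory "index" array are illustrative assumptions, and must match the node names declared in your dataset description YAML:

```python
import os
import tempfile

import numpy as np

# Hedged sketch: build a toy offline dataset and save it in the .npz
# format that dataset_file_path expects. The array names ("obs",
# "action") and the trajectory "index" array are assumptions for
# illustration; they must agree with the dataset description YAML
# passed as dataset_desc_file_path.
obs = np.random.randn(1000, 4).astype(np.float32)     # 1000 transitions, 4-dim observation
action = np.random.randn(1000, 2).astype(np.float32)  # 2-dim action
index = np.array([500, 1000])                         # hypothetical trajectory end offsets

path = os.path.join(tempfile.mkdtemp(), "data.npz")
np.savez_compressed(path, obs=obs, action=action, index=index)

# Verify the file round-trips through NumPy's loader.
loaded = np.load(path)
print(sorted(loaded.files))  # ['action', 'index', 'obs']
```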

Initialize a ReviveServer.

Args:
dataset_file_path (str):

The file path where the training dataset is stored. If val_file_path is None, some data will be cut out from the training dataset as the validation dataset. (e.g., "/data/data.npz")

dataset_desc_file_path (str):

The file path where the data description file is stored. (e.g., "/data/test.yaml")

val_file_path (str):

The file path where the validation dataset is stored. If it is None, the validation dataset will be cut out from the training dataset.

reward_file_path (str):

The storage path of the file that defines the reward function.

target_policy_name (str):

Name of target policy to be optimized. Maximize the defined reward by optimizing the policy. If it is None, the first policy in the graph will be chosen.

log_dir (str):

Folder where training logs and saved models are stored.

run_id (str):

The ID of the current experiment, used to distinguish different training runs. When it is not provided, an ID is generated automatically.

address (str):

The address of the ray cluster. If it is None, ReviveServer will create its own cluster.

venv_mode ("tune", "once", "None"):

Controls the mode of venv training. "tune" means conducting a hyper-parameter search; "once" means training with the default hyper-parameters; "None" means skip.

policy_mode ("tune", "once", "None"):

Controls the mode of policy training. "tune" means conducting a hyper-parameter search; "once" means training with the default hyper-parameters; "None" means skip.

tuning_mode ("max", "min", "None"):

Controls the mode of parameter tuning. "max" and "min" mean tuning is enabled and give the optimization direction; "None" means skip. This feature is currently unstable.

tune_initial_state (Dict[str, np.ndarray]):

Initial state for parameter tuning, needed when tuning mode is enabled.

debug (bool):

If True, enter debug mode for debugging.

custom_config:

Path to a JSON file whose contents override the default parameters.

kwargs:

Keyword arguments that can be used to override default parameters.

revive_config_file_path (str):

Path to the preprocess config file.
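Putting the arguments above together, a hypothetical construction call might look like the following. Every path is a placeholder, and the sketch degrades gracefully when the revive package is not installed:

```python
# Hedged sketch of constructing a ReviveServer from the documented
# arguments. Every path below is a hypothetical placeholder; the
# try/except lets the sketch run even where revive is not installed.
try:
    from revive.server import ReviveServer

    server = ReviveServer(
        dataset_file_path="/data/data.npz",
        dataset_desc_file_path="/data/test.yaml",
        reward_file_path="/data/reward.py",
        venv_mode="once",    # train the venv once with default hyper-parameters
        policy_mode="tune",  # hyper-parameter search for the policy
        run_id="demo_run",
    )
except Exception as exc:     # revive missing, or placeholder paths absent
    server = None
    print(f"ReviveServer unavailable in this environment: {exc!r}")
```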

train(env_save_path: str | None = None)[source]

Train the virtual environment and the policy. Steps:

  1. Start a ray worker to train the virtual environment based on the data;

  2. Start a ray worker to train the policy based on the virtual environment.
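The two stages can also be driven through the individual methods documented below. A minimal sketch, assuming only the method names in this reference (any object exposing them, such as a ReviveServer instance, works):

```python
# Hedged sketch of the two-stage pipeline: drive any object that
# exposes the documented train_venv() / train_policy(env_save_path)
# methods, e.g. a ReviveServer instance.
def run_pipeline(server, env_save_path=None):
    server.train_venv()                 # stage 1: fit the virtual environment to the data
    server.train_policy(env_save_path)  # stage 2: train the policy inside that environment
```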

train_venv()[source]

Start a ray worker to train the virtual environment based on the data.

train_policy(env_save_path: str | None = None)[source]

Start a ray worker to train the policy based on the virtual environment.

Args:
env_save_path:

Path to the saved virtual environment models.

Note

Before training the policy, environment models and a reward function must be provided.
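The reward function is loaded from the file given by reward_file_path. The exact signature the SDK expects is defined there; purely as an illustration, a batch-wise reward that maps arrays to per-sample rewards could look like this (the function name, node names, and signature are all assumptions, not the SDK's contract):

```python
import numpy as np

# Hypothetical reward-function sketch. The actual interface REVIVE
# expects is defined by the SDK; here we assume a function mapping a
# batch dict of arrays to an (N, 1) reward column.
def get_reward(data):
    obs = data["obs"]        # hypothetical node name, shape (N, obs_dim)
    action = data["action"]  # hypothetical node name, shape (N, act_dim)
    # Example shaping: keep the first observation dim near zero and
    # penalize large actions.
    return -np.abs(obs[:, :1]) - 0.1 * np.sum(action ** 2, axis=1, keepdims=True)
```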

tune_parameter(env_save_path: str | None = None)[source]

Tune parameters on specified virtual environments.

Args:
env_save_path:

Path to the saved virtual environment models.

Note

This feature is currently unstable.

stop_train() → None[source]

Stop all training tasks.

get_virtualenv_env() → Tuple[VirtualEnv, Dict[str, str | float], Dict[int, Tuple[str, str]]][source]

Get virtual environment models and train log.

Returns:

virtual environment models and train log

get_policy_model() → Tuple[PolicyModel, Dict[str, str | float], Dict[int, Tuple[str, str]]][source]

Get policy based on specified virtual environments.

Returns:

policy models and train log

get_parameter() → Tuple[ndarray, Dict[str, str | float]][source]

Get tuned parameters based on specified virtual environments.

Returns:

current best parameters and training log
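The getters above each return a tuple of (model, metrics, log). A sketch of consuming those documented return shapes after training, where `server` can be any object exposing the two methods (the metric names inside the dicts are whatever training produced, not fixed by this sketch):

```python
# Hedged sketch: unpack the documented 3-tuples returned by
# get_virtualenv_env() and get_policy_model(). `server` can be any
# object with these methods, e.g. a trained ReviveServer.
def summarize_results(server):
    venv, venv_metrics, venv_log = server.get_virtualenv_env()
    policy, policy_metrics, policy_log = server.get_policy_model()
    return {
        "venv_metrics": venv_metrics,
        "policy_metrics": policy_metrics,
        "n_venv_log_entries": len(venv_log),
        "n_policy_log_entries": len(policy_log),
    }
```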