Training Models

Once we have prepared our training dataset (.npz or .h5 file), decision flow graph description file (.yaml), and reward function (reward.py), we can start training the virtual environment model and policy model using the command python train.py. This script will instantiate revive.server.ReviveServer and start training.
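
As an illustration, a minimal .npz training dataset can be assembled with NumPy. Note that the array names below (obs, action) and the use of an index array to mark trajectory boundaries are assumptions for this sketch; the actual keys must match the node names declared in your decision flow graph (.yaml).

```python
import numpy as np

# Hypothetical node names -- the real keys must match the nodes
# declared in the decision flow graph (.yaml).
obs = np.random.randn(1000, 4).astype(np.float32)     # observation node
action = np.random.randn(1000, 1).astype(np.float32)  # action node
# Assumed convention: cumulative indices marking where each trajectory ends.
index = np.array([500, 1000])

np.savez("test_data.npz", obs=obs, action=action, index=index)

# Round-trip check: the archive contains the expected arrays.
data = np.load("test_data.npz")
print(sorted(data.files))  # ['action', 'index', 'obs']
```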

Example training script

python train.py -df <training data file path> -cf <decision flow graph file path> -rf <reward function file path> -vm <virtual environment model training mode> -pm <policy model training mode> --run_id <training experiment name>

The command line arguments that can be defined when running the train.py script are as follows:

  • -df: The file path of the training data (.npz or .h5 file).

  • -vf: The file path of the validation data (optional).

  • -cf: The file path of the decision flow graph ( .yaml ).

  • -rf: The file path of the defined reward function (reward.py) (required only for policy training).

  • -rcf: The .json file that supports hyperparameter configuration (optional).

  • -tpn: The name of the policy node. It must be a node defined in the decision flow graph. If not specified, by default, the node ranked first in topological order will be used as the policy node.

  • -vm: The training mode of the virtual environment model: once, tune, or None.

    • once mode: REVIVE SDK will train the model using default parameters.

    • tune mode: REVIVE SDK will use hyperparameter search to train the model, which requires a lot of computing power and time to search for hyperparameters to obtain better model results.

    • None mode: REVIVE SDK will not train the virtual environment model. This is suitable when an existing virtual environment model is reused and only the policy needs to be trained.

  • -pm: The training mode of the policy model: once, tune, or None.

    • once mode: REVIVE SDK will train the model using default parameters.

    • tune mode: REVIVE SDK will use hyperparameter search to train the model, which requires a lot of computing power and time to search for hyperparameters to obtain better model results.

    • None mode: REVIVE SDK will not train the policy. This is suitable for cases where only the virtual environment needs to be trained without policy training.

  • --run_id: The name provided by the user for the training experiment. REVIVE will create logs/<run_id> as the log directory. If not provided, REVIVE will randomly generate a name.
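
For reference, a reward function in reward.py maps batched node values to a per-sample reward. The sketch below is purely illustrative: the function name get_reward, the node names obs and action, and the use of NumPy arrays are assumptions for this example; consult the bundled test_reward.py for the exact signature the SDK expects (it may operate on tensors rather than arrays).

```python
import numpy as np

def get_reward(data):
    """Hypothetical reward: penalize distance from the origin and
    action magnitude. ``data`` maps node names (as defined in the
    decision flow graph) to batched arrays of shape (batch, dim)."""
    obs = data["obs"]        # assumed node name
    action = data["action"]  # assumed node name
    # Negative quadratic cost; returns shape (batch, 1).
    return (-np.sum(obs ** 2, axis=-1, keepdims=True)
            - 0.1 * np.sum(action ** 2, axis=-1, keepdims=True))

batch = {"obs": np.zeros((8, 4)), "action": np.ones((8, 1))}
print(get_reward(batch).shape)  # (8, 1)
```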

Training the Virtual Environment and Policy Models

python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm once -pm once --run_id once

python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm tune -pm tune --run_id tune

Training the Virtual Environment Model Only

python train.py -df test_data.npz -cf test.yaml -vm once -pm None --run_id venv_once

python train.py -df test_data.npz -cf test.yaml -vm tune -pm None --run_id venv_tune

Note

The reward function is required only for policy training.

Train Policy Model on an Existing Environment Model

Important

When training the policy model separately, the REVIVE SDK looks up the previously trained virtual environment model by its run_id. Make sure --run_id matches the run that produced the virtual environment model.

python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm None -pm once --run_id venv_once

python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm None -pm tune --run_id venv_tune

Note

train.py is a generic script that calls the REVIVE SDK for model training. We can also develop custom training workflows based on revive.server.ReviveServer.

See :doc:`REVIVE API <revive_server>` for more information on revive.server.ReviveServer. During training, we can open the log directory with TensorBoard at any time to monitor the training process. After the REVIVE SDK completes training of the virtual environment model and policy model, the saved models (.pkl or .onnx) can be found in the log directory (logs/<run_id>).
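
The saved .pkl model can then be restored with Python's standard pickle module. Since the actual class of a trained REVIVE model is defined by the SDK, the snippet below uses a stand-in object purely to demonstrate the load pattern; the file name policy.pkl is likewise an assumption here, so check your own log directory for the real artifact names.

```python
import pickle

class DummyPolicy:
    """Stand-in for a model object saved by REVIVE during training."""
    def act(self, obs):
        return [0.0 for _ in obs]

# Simulate a saved model file (in practice this is produced by train.py
# under logs/<run_id>; the exact file name is an assumption here).
with open("policy.pkl", "wb") as f:
    pickle.dump(DummyPolicy(), f)

# Loading pattern: unpickle, then query the restored object.
with open("policy.pkl", "rb") as f:
    policy = pickle.load(f)

print(policy.act([1.0, 2.0]))  # [0.0, 0.0]
```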

REVIVE provides a hyperparameter tuning tool, and starting training in tune mode will switch to hyperparameter tuning mode. In this mode, REVIVE will sample multiple sets of hyperparameters from the preset hyperparameter space for model training. Generally, using hyperparameter tuning mode can obtain better models. For detailed explanations of the hyperparameter tuning mode, please refer to the REVIVE API (:doc:`revive.conf <revive_conf>`). We can also adjust the hyperparameter search space and search method by modifying the relevant configurations in config.json.

python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -rcf config.json --run_id test
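
For orientation only, a hyperparameter configuration file passed via -rcf might resemble the fragment below. Every key and value here is an illustrative assumption, not the SDK's actual schema; refer to the bundled revive/data/config.json and the revive.conf API documentation for the real format.

```json
{
    "global_seed": 42,
    "venv_search_space": {
        "bc_batch_size": [256, 512, 1024],
        "g_lr": {"min": 1e-6, "max": 1e-4}
    },
    "policy_search_space": {
        "ppo_epoch": [10, 20, 50]
    }
}
```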

Note

The example files test_data.npz, test.yaml, test_reward.py, and config.json are stored in the revive/data folder.