Training the Model
Once we have prepared the training dataset (.npz or .h5 file), the decision flow graph description file (.yaml), and the reward function (reward.py), we can start training the virtual environment model and the policy model with the command python train.py. This script instantiates revive.server.ReviveServer and starts training.
Example training script
python train.py -df <training data file path> -cf <decision flow graph file path> -rf <reward function file path> -vm <virtual environment model training mode> -pm <policy model training mode> --run_id <training experiment name>
The command line arguments that can be defined when running the train.py script are as follows:
-df: The file path of the training data (.npz or .h5 file).
-vf: The file path of the validation data (optional).
-cf: The file path of the decision flow graph (.yaml).
-rf: The file path of the reward function (reward.py), required only for policy training.
-rcf: The .json file that supplies the hyperparameter configuration (optional).
-tpn: The name of the policy node. It must be a node defined in the decision flow graph. If not specified, the first node in topological order is used as the policy node by default.
-vm: The training mode of the virtual environment model, one of once, tune, or None.
once: REVIVE SDK trains the model using default parameters.
tune: REVIVE SDK uses hyperparameter search to train the model; this requires substantial computing power and time, but can produce better models.
None: REVIVE SDK does not train the virtual environment. This is suitable when an existing virtual environment is reused and only the policy needs to be trained.
-pm: The training mode of the policy model, one of once, tune, or None.
once: REVIVE SDK trains the model using default parameters.
tune: REVIVE SDK uses hyperparameter search to train the model; this requires substantial computing power and time, but can produce better models.
None: REVIVE SDK does not train the policy. This is suitable when only the virtual environment needs to be trained.
--run_id: The name provided by the user for the training experiment. REVIVE creates logs/<run_id> as the log directory. If not provided, REVIVE randomly generates a name.
Training the Virtual Environment and Policy Models
python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm once -pm once --run_id once
python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm tune -pm tune --run_id tune
Training the Virtual Environment Model Only
python train.py -df test_data.npz -cf test.yaml -vm once -pm None --run_id venv_once
python train.py -df test_data.npz -cf test.yaml -vm tune -pm None --run_id venv_tune
Note
The reward function needs to be defined only for policy training.
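For reference, the sketch below shows what such a reward function file might look like. It assumes the REVIVE convention of a get_reward function that receives a dict of tensors keyed by node name; the node names "obs" and "action" are hypothetical and must match the nodes declared in your .yaml decision flow graph.

import torch
from typing import Dict

def get_reward(data: Dict[str, torch.Tensor]) -> torch.Tensor:
    # "obs" and "action" are hypothetical node names; use the names from your graph.
    obs = data["obs"]
    action = data["action"]
    # Illustrative shaping: penalize large actions and deviation of the first
    # observation dimension from zero. Replace this with your task's reward logic.
    reward = -torch.sum(action ** 2, dim=-1, keepdim=True) - torch.abs(obs[..., :1])
    return reward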
Training the Policy Model on an Existing Virtual Environment Model
Important
When training the policy model separately, the REVIVE SDK looks up the previously trained virtual environment model by its run_id.
python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm None -pm once --run_id venv_once
python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -vm None -pm tune --run_id venv_tune
Note
train.py is a generic script that calls the REVIVE SDK for model training. We can also build custom training workflows on top of revive.server.ReviveServer.
See revive.server.ReviveServer in the REVIVE API (revive_server) for more information.
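As a minimal sketch of such a custom workflow (the keyword argument names below are assumptions mirroring the command line options -df, -cf, -rf and --run_id; check the ReviveServer constructor in the REVIVE API for the exact signature and training methods):

from revive.server import ReviveServer

server = ReviveServer(
    dataset_file_path="test_data.npz",    # assumed name for the -df argument
    dependency_file_path="test.yaml",     # assumed name for the -cf argument
    reward_file_path="test_reward.py",    # assumed name for the -rf argument
    run_id="custom_run",                  # assumed name for the --run_id argument
)
server.train()  # assumed entry point that trains the virtual environment and policy models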
During training, we can open the log directory with TensorBoard at any time to monitor the training process. After the REVIVE SDK completes the virtual environment model training and the policy model training, we can find the saved models (.pkl or .onnx) in the log folder (logs/<run_id>).
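For example, monitoring can be started with a standard TensorBoard invocation:

tensorboard --logdir logs/<run_id>

After training finishes, the exported .onnx policy can be loaded with onnxruntime for inference. The sketch below is illustrative only; the file name policy.onnx and the observation shape are assumptions, so inspect your own logs/<run_id> folder for the actual artifacts.

import numpy as np
import onnxruntime as ort

# Load the exported policy model (the file name is an assumption).
session = ort.InferenceSession("logs/once/policy.onnx")
input_name = session.get_inputs()[0].name
# Hypothetical observation with batch size 1 and 4 features.
obs = np.zeros((1, 4), dtype=np.float32)
action = session.run(None, {input_name: obs})[0]
print(action)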
REVIVE provides a hyperparameter tuning tool; starting training in tune mode switches to hyperparameter tuning mode.
In this mode, REVIVE samples multiple sets of hyperparameters from the preset hyperparameter space and trains a model with each of them. Hyperparameter tuning mode can generally obtain better models. For a detailed explanation of the hyperparameter tuning mode, please refer to the REVIVE API (revive_conf).
We can also adjust the hyperparameter search space and the search method by modifying the relevant configurations in config.json.
python train.py -df test_data.npz -cf test.yaml -rf test_reward.py -rcf config.json --run_id test
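As a purely illustrative sketch of editing such a configuration before passing it via -rcf (the key "train_venv_trials" is a hypothetical placeholder; the real parameter names and search-space format are listed in the REVIVE API under revive_conf):

import json

with open("config.json") as f:
    config = json.load(f)

print(json.dumps(config, indent=2))    # inspect the current hyperparameter settings
config["train_venv_trials"] = 50       # hypothetical parameter override, for illustration only

with open("my_config.json", "w") as f:
    json.dump(config, f, indent=2)

The edited file can then be passed to train.py with -rcf my_config.json.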
Note
The example data test_data.npz, test.yaml, test_reward.py and config.json are stored in the revive/data folder.