Using the Models
Once the REVIVE SDK has finished training the virtual environment model and the policy model, the saved models (.pkl or .onnx) can be found in the log folder (logs/<run_id>).
Using the saved .pkl model
Using the trained virtual environment model (env.pkl)
The virtual environment model is serialized into the env.pkl file. To use it, load the serialized model with pickle and then call venv.infer_one_step() or venv.infer_k_steps() to run virtual environment inference.
import os
import pickle
import numpy as np
# Get the file path of the virtual environment model (replace "run_id" with your actual run ID)
venv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs/run_id", "env.pkl")
# Load the virtual environment model; the with-block closes the file afterwards
with open(venv_path, 'rb') as f:
    venv = pickle.load(f, encoding='utf-8')
# Build the input data; the keys correspond to node names in the decision graph
state = {"states": np.array([-0.5, 0.5, 0.2]), "actions": np.array([1.])}
# Perform single time-step inference using the virtual environment model
output = venv.infer_one_step(state)
print("Virtualenv model 1-step output:", output)
# Perform K time-step inference using the virtual environment model
# The returned list has length k, one entry per time step
output = venv.infer_k_steps(state, k=3)
print("Virtualenv model k-step output:", output)
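If each element of the list returned by venv.infer_k_steps() is a dictionary of NumPy arrays (an assumption here; inspect output[0] to confirm for your decision graph), it is often convenient to stack the list into one array per node, with the time step on the first axis. A minimal sketch, where stack_k_step_outputs is a hypothetical helper rather than part of the REVIVE SDK and the input data is made up:

```python
import numpy as np


def stack_k_step_outputs(outputs):
    """Stack a list of per-step dicts into a dict of arrays shaped (k, ...).

    Hypothetical helper, not part of the REVIVE SDK. Assumes every step dict
    has the same keys and the same array shape for each key.
    """
    if not outputs:
        return {}
    keys = outputs[0].keys()
    # np.stack adds a new leading axis, so axis 0 indexes the time step
    return {key: np.stack([np.asarray(step[key]) for step in outputs], axis=0) for key in keys}


# Made-up data standing in for venv.infer_k_steps(state, k=3)
fake_outputs = [
    {"next_states": np.array([0.1, 0.2, 0.3])},
    {"next_states": np.array([0.2, 0.3, 0.4])},
    {"next_states": np.array([0.3, 0.4, 0.5])},
]
stacked = stack_k_step_outputs(fake_outputs)
print(stacked["next_states"].shape)  # (3, 3)
```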
When using the virtual environment model for inference, the additional parameters deterministic and clip can be passed in. The deterministic parameter determines whether the output is deterministic: if it is True, the most likely output is returned; if it is False, the output is sampled from the model's probability distribution. The default value is True.
The clip parameter determines whether the output is clipped to the specified valid range. The clipping range is taken from the configuration in the *.yaml file; if no range is configured, the minimum and maximum values are computed automatically from the data. The default value is True, meaning the output is clipped.
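The exact output distribution depends on the trained model and is not specified in this guide. Assuming, purely for illustration, a diagonal Gaussian output with a known mean and standard deviation, the deterministic switch behaves like the NumPy sketch below. gaussian_output is a hypothetical helper, not a REVIVE function.

```python
import numpy as np


def gaussian_output(mean, std, deterministic=True, rng=None):
    """Illustrate the `deterministic` switch for an ASSUMED diagonal Gaussian.

    Hypothetical helper, not part of the REVIVE SDK.
    """
    mean = np.asarray(mean, dtype=np.float64)
    if deterministic:
        # For a Gaussian, the most likely value (the mode) is the mean
        return mean.copy()
    rng = np.random.default_rng() if rng is None else rng
    # Otherwise draw one sample per dimension from N(mean, std**2)
    return rng.normal(loc=mean, scale=std)


mean = np.array([0.1, -0.2, 0.3])
std = np.array([0.05, 0.05, 0.05])
print(gaussian_output(mean, std, deterministic=True))  # always equal to mean
print(gaussian_output(mean, std, deterministic=False, rng=np.random.default_rng(0)))  # depends on the seed
```

With deterministic=True repeated calls give identical results, which is usually what you want for evaluation; sampling is useful when you want rollouts that reflect the model's uncertainty.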
Note
For more details, please refer to the REVIVE API documentation for the virtual environment.
Using the policy model (policy.pkl)
The policy model is serialized into the policy.pkl file. To use it, load the serialized policy model with pickle and then call policy.infer() to run policy inference.
import os
import pickle
import numpy as np
# Get the file path of the policy model (replace "run_id" with your actual run ID)
policy_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs/run_id", "policy.pkl")
# Load the policy model; the with-block closes the file afterwards
with open(policy_path, 'rb') as f:
    policy = pickle.load(f, encoding='utf-8')
# Build the input state data
state = {"states": np.array([-0.5, 0.5, 0.2])}
print("Policy model input state:", state)
# Perform inference using the policy model and output the action
action = policy.infer(state)
print("Policy model output action:", action)
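The two pickled models can be combined into a closed-loop rollout: the policy proposes an action, and the virtual environment predicts the next state. The sketch below assumes that policy.infer() returns the action array and that venv.infer_one_step() returns a dict holding the next state under a key such as "next_states"; that key name is hypothetical and depends on the decision graph in your .yaml file. ToyPolicy and ToyVenv are stand-ins so the sketch runs without REVIVE; pass the unpickled policy and venv in practice.

```python
import numpy as np


def rollout(policy, venv, init_states, horizon, next_state_key="next_states"):
    """Alternate policy and virtual-environment inference for `horizon` steps.

    ASSUMPTIONS: policy.infer({"states": s}) returns an action array, and
    venv.infer_one_step({"states": s, "actions": a}) returns a dict containing
    the next state under `next_state_key` (hypothetical key name).
    """
    states = np.asarray(init_states)
    trajectory = []
    for _ in range(horizon):
        actions = policy.infer({"states": states})
        step = venv.infer_one_step({"states": states, "actions": actions})
        next_states = step[next_state_key]
        trajectory.append({"states": states, "actions": actions, "next_states": next_states})
        states = next_states  # feed the prediction back in
    return trajectory


# Toy stand-ins for the unpickled models (hypothetical, for illustration only)
class ToyPolicy:
    def infer(self, state):
        return -0.5 * state["states"][..., :1]


class ToyVenv:
    def infer_one_step(self, state):
        return {"next_states": state["states"] + state["actions"]}


traj = rollout(ToyPolicy(), ToyVenv(), np.array([-0.5, 0.5, 0.2]), horizon=3)
print(len(traj))  # 3
```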
When conducting inference with the policy model, the additional parameters deterministic and clip can be passed in. The deterministic parameter determines whether the output is deterministic: if it is set to True, the most likely action is returned; if it is set to False, the action is sampled from the policy model's probability distribution. The default value is True.
The clip parameter determines whether the output action is clipped to the valid range of the action space. The clipping range is taken from the configuration in the *.yaml file; if no range is configured, the minimum and maximum values are computed automatically from the data. The default value is True, meaning the output action is clipped.
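The automatic range described above can be pictured as per-dimension minimum and maximum values over the data. The NumPy sketch below assumes exactly that (column-wise bounds over a (num_samples, dim) array); data_bounds and clip_output are hypothetical helpers, the sample actions are made up, and REVIVE's internal computation may differ.

```python
import numpy as np


def data_bounds(data):
    """Per-dimension (min, max) over a (num_samples, dim) array (hypothetical helper)."""
    data = np.asarray(data, dtype=np.float64)
    return data.min(axis=0), data.max(axis=0)


def clip_output(output, low, high, clip=True):
    """Clip to [low, high] when clip is True; otherwise return the output unchanged."""
    output = np.asarray(output, dtype=np.float64)
    return np.clip(output, low, high) if clip else output


train_actions = np.array([[-1.0], [0.2], [0.8]])  # made-up action data
low, high = data_bounds(train_actions)
print(clip_output(np.array([1.5]), low, high))              # [0.8]
print(clip_output(np.array([1.5]), low, high, clip=False))  # [1.5]
```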
Note
For more details, please refer to the REVIVE API documentation for the policy model.
Using the saved .onnx model
The virtual environment model and the policy model can also be serialized into .onnx files, which makes cross-platform deployment easier.
Using the virtual environment model (env.onnx)
import os
import onnxruntime
import numpy as np
venv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs/run_id", "env.onnx")
# Create an inference session on the CPU execution provider
venv = onnxruntime.InferenceSession(venv_path, providers=["CPUExecutionProvider"])
# The ONNX model supports a flexible batch size along dimension 0
# Make sure the input data is of float32 type
# venv_input must be a Python dictionary, similar to the input of the .pkl model
venv_input = {'temperature': np.array([0.5, 0.4, 0.3], dtype=np.float32).reshape(3, -1), 'door_open': np.array([1., 0., 1.], dtype=np.float32).reshape(3, -1)}
# List the names of every output you want returned
venv_output_names = ["action", "next_temperature"]
# Run one decision-flow inference step (similar to venv.infer_one_step() for the .pkl model); run() returns a list
# whose arrays follow the order of the names in venv_output_names
output = venv.run(input_feed=venv_input, output_names=venv_output_names)
print(output)
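When inputs arrive one sample at a time, they have to be stacked into float32 arrays with the batch on axis 0 before calling run(), and the returned list can be paired with the requested names for easier access. A minimal sketch of both steps; make_onnx_feed and name_outputs are hypothetical helpers, and the node names are the ones used in the example above.

```python
import numpy as np


def make_onnx_feed(samples, names):
    """Stack per-sample values into float32 arrays shaped (batch, dim).

    Hypothetical helper. `samples` is a list of dicts, one per sample; the
    batch dimension is placed on axis 0.
    """
    feed = {}
    for name in names:
        column = np.asarray([s[name] for s in samples], dtype=np.float32)
        # Scalar features become shape (batch, 1); vector features become (batch, dim)
        feed[name] = column.reshape(len(samples), -1)
    return feed


def name_outputs(output_names, outputs):
    """Pair the list returned by InferenceSession.run() with the requested names."""
    return dict(zip(output_names, outputs))


samples = [
    {"temperature": 0.5, "door_open": 1.0},
    {"temperature": 0.4, "door_open": 0.0},
    {"temperature": 0.3, "door_open": 1.0},
]
feed = make_onnx_feed(samples, ["temperature", "door_open"])
print(feed["temperature"].shape, feed["temperature"].dtype)  # (3, 1) float32
# With a real session:
# results = name_outputs(venv_output_names, venv.run(input_feed=feed, output_names=venv_output_names))
```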
Using the policy model (policy.onnx)
import os
import onnxruntime
import numpy as np
policy_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs/run_id", "policy.onnx")
# Create an inference session on the CPU execution provider
policy = onnxruntime.InferenceSession(policy_path, providers=["CPUExecutionProvider"])
# The ONNX model supports a flexible batch size along dimension 0
# Make sure the input data is of float32 type
# policy_input must be a Python dictionary, similar to the input of the .pkl model
policy_input = {'temperature': np.array([0.5, 0.4, 0.3], dtype=np.float32).reshape(3, -1)}
# List the names of every output you want returned
policy_output_names = ["action"]
# Run one policy inference step (similar to policy.infer() for the .pkl model); run() returns a list
# whose arrays follow the order of the names in policy_output_names
output = policy.run(input_feed=policy_input, output_names=policy_output_names)
print(output)
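For single-step control loops it can be convenient to wrap the session so that it accepts one unbatched sample and returns a dict, roughly mirroring the dict-in style of the .pkl models. The class below is a hypothetical convenience wrapper, not part of REVIVE or onnxruntime; it only relies on InferenceSession.run(output_names, input_feed). A stand-in session is used so the sketch runs without a model file.

```python
import numpy as np


class OnnxSingleSampleRunner:
    """Hypothetical wrapper that runs an ONNX session on one unbatched sample."""

    def __init__(self, session, output_names):
        self.session = session
        self.output_names = list(output_names)

    def __call__(self, sample):
        # Add the batch axis (dim 0) and cast every input to float32;
        # reshape(1, -1) flattens each feature, so this assumes vector-valued inputs
        feed = {name: np.asarray(value, dtype=np.float32).reshape(1, -1)
                for name, value in sample.items()}
        # InferenceSession.run(output_names, input_feed) returns a list of arrays
        outputs = self.session.run(self.output_names, feed)
        # Drop the batch axis and key each output by its name
        return {name: out[0] for name, out in zip(self.output_names, outputs)}


# Stand-in session so the sketch runs without a model file; in practice pass
# onnxruntime.InferenceSession(policy_path, providers=["CPUExecutionProvider"])
class _EchoSession:
    def run(self, output_names, input_feed):
        return [input_feed["temperature"] * 2.0]


runner = OnnxSingleSampleRunner(_EchoSession(), ["action"])
print(runner({"temperature": 0.5}))
```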