Using Model
======================

Once the REVIVE SDK has finished training the virtual environment model and the policy model, the saved models (``.pkl`` or ``.onnx``) can be found in the log folder (``logs/``).

Using the saved ``.pkl`` model
----------------------------------------------------

**Using the trained virtual environment model (env.pkl)**

The virtual environment model is serialized into the ``env.pkl`` file. To use it, load the serialized environment model with ``pickle`` and then call the ``venv.infer_one_step()`` function or the ``venv.infer_k_steps()`` function to perform virtual environment inference.

.. code:: python

    import os
    import pickle
    import numpy as np

    # Get the file path of the virtual environment model
    venv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs/run_id", "env.pkl")

    # Load the virtual environment model
    venv = pickle.load(open(venv_path, 'rb'), encoding='utf-8')

    # Generate state data
    state = {"states": np.array([-0.5, 0.5, 0.2]), "actions": np.array([1.])}

    # Perform single time-step inference using the virtual environment model
    output = venv.infer_one_step(state)
    print("Virtualenv model 1-step output:", output)

    # Perform K time-step inference using the virtual environment model
    # The length of the returned list is k, corresponding to the output of K time steps.
    output = venv.infer_k_steps(state, k=3)
    print("Virtualenv model k-step output:", output)

When running inference with the virtual environment model, the additional ``deterministic`` and ``clip`` parameters can be passed in. The ``deterministic`` parameter determines whether the output is deterministic: if ``True``, the most likely output is returned; if ``False``, the output is sampled from the model's probability distribution. The default value is ``True``. The ``clip`` parameter determines whether the output is clipped to the specified valid range. The clipping range is taken from the configuration in the ``*.yaml`` file; if it is not configured there, the minimum and maximum values are computed automatically from the data. The default value is ``True``, meaning the outputs are clipped.

.. note:: Please refer to the API documentation :doc:`REVIVE API virtual environment <../revive_inference>` for more details.

**Using the policy model (policy.pkl)**

The policy model is serialized into the ``policy.pkl`` file. To use it, load the serialized policy model with ``pickle`` and then call the ``policy.infer()`` function to perform policy inference.

.. code:: python

    import os
    import pickle
    import numpy as np

    # Get the file path of the policy model
    policy_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs/tmp", "policy.pkl")

    # Load the policy model
    policy = pickle.load(open(policy_path, 'rb'), encoding='utf-8')

    # Generate state data
    state = {"states": np.array([-0.5, 0.5, 0.2])}
    print("Policy model input state:", state)

    # Perform inference using the policy model and output the action
    action = policy.infer(state)
    print("Policy model output action:", action)

When running inference with the policy model, the additional ``deterministic`` and ``clip`` parameters can likewise be passed in. The ``deterministic`` parameter determines whether the output is deterministic: if ``True``, the most likely action is returned; if ``False``, the action is sampled from the policy's probability distribution. The default value is ``True``. The ``clip`` parameter determines whether the output action is clipped to the valid range of the action space. The clipping range is taken from the configuration in the ``*.yaml`` file; if it is not configured there, the minimum and maximum values are computed automatically from the data. The default value is ``True``, meaning the output action is clipped. A usage sketch is shown below.
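The following is a minimal sketch of passing these flags, assuming ``deterministic`` and ``clip`` are accepted as keyword arguments by ``infer_one_step()``, ``infer_k_steps()``, and ``infer()``; the model paths and state dictionaries are illustrative and follow the examples above.

.. code:: python

    import os
    import pickle
    import numpy as np

    # Illustrative log directory; replace with the actual run directory.
    log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs/run_id")
    venv = pickle.load(open(os.path.join(log_dir, "env.pkl"), 'rb'), encoding='utf-8')
    policy = pickle.load(open(os.path.join(log_dir, "policy.pkl"), 'rb'), encoding='utf-8')

    venv_state = {"states": np.array([-0.5, 0.5, 0.2]), "actions": np.array([1.])}
    policy_state = {"states": np.array([-0.5, 0.5, 0.2])}

    # Sample from the learned distributions instead of returning the most likely output,
    # while still clipping the results to the configured valid range.
    sampled_output = venv.infer_one_step(venv_state, deterministic=False, clip=True)
    sampled_outputs = venv.infer_k_steps(venv_state, k=3, deterministic=False, clip=True)
    sampled_action = policy.infer(policy_state, deterministic=False, clip=True)

    print(sampled_output)
    print(sampled_action)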
.. note:: For more details, please refer to the API documentation :doc:`REVIVE API policy model <../revive_inference>`.

Using the saved ``.onnx`` model
------------------------------------------------

The virtual environment model and the policy model can also be serialized into ``.onnx`` files, which facilitates cross-platform deployment.

**Using the virtual environment model (env.onnx)**

.. code:: python

    import os
    import onnxruntime
    import numpy as np

    venv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs/run_id", "env.onnx")
    venv = onnxruntime.InferenceSession(venv_path)

    # The onnx model now supports a flexible batch size, by default on dim 0.
    # Ensure that the input data is of float type.
    # venv_input needs to be a Python dictionary, similar to the input of the .pkl model.
    venv_input = {'temperature': np.array([0.5, 0.4, 0.3], dtype=np.float32).reshape(3, -1),
                  'door_open': np.array([1., 0., 1.], dtype=np.float32).reshape(3, -1)}

    # The names of the requested outputs are listed in venv_output_names.
    venv_output_names = ["action", "next_temperature"]

    # Run one decision-flow inference with the input data (similar to venv.infer_one_step() of the .pkl model) --> returns a list.
    # The outputs are stored in the list in the order of the names in venv_output_names.
    output = venv.run(input_feed=venv_input, output_names=venv_output_names)
    print(output)

**Using the policy model (policy.onnx)**

.. code:: python

    import os
    import onnxruntime
    import numpy as np

    policy_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs/run_id", "policy.onnx")
    policy = onnxruntime.InferenceSession(policy_path)

    # The onnx model now supports a flexible batch size, by default on dim 0.
    # Ensure that the input data is of float type.
    # policy_input needs to be a Python dictionary, similar to the input of the .pkl model.
    policy_input = {'temperature': np.array([0.5, 0.4, 0.3], dtype=np.float32).reshape(3, -1)}

    # The names of the requested outputs are listed in policy_output_names.
    policy_output_names = ["action"]

    # Run one decision-flow inference with the input data (similar to policy.infer() of the .pkl model) --> returns a list.
    # The outputs are stored in the list in the order of the names in policy_output_names.
    output = policy.run(input_feed=policy_input, output_names=policy_output_names)
    print(output)
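If the node names expected by an exported graph are not known in advance, they can be inspected with ``onnxruntime``'s ``get_inputs()`` and ``get_outputs()``, which is a convenient way to fill the input dictionary and the ``*_output_names`` lists. The sketch below is illustrative and assumes the ``env.onnx`` path used above.

.. code:: python

    import os
    import onnxruntime

    venv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs/run_id", "env.onnx")
    venv = onnxruntime.InferenceSession(venv_path)

    # List the graph's input names, shapes, and types
    # (these names are the keys expected in the input dictionary).
    for node in venv.get_inputs():
        print("input :", node.name, node.shape, node.type)

    # List the graph's output names
    # (any of these can be requested via output_names).
    for node in venv.get_outputs():
        print("output:", node.name, node.shape, node.type)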