使用训练完成的模型 ====================== 当REVIVE SDK完成虚拟环境模型训练和策略模型训练后。我们可以在日志文件夹( ``logs/`` )下找到保存的模型( ``.pkl`` 或 ``.onnx`` )。 使用保存的 ``.pkl`` 模型 ---------------------------- **使用训练好的虚拟环境模型(env.pkl)** 虚拟环境模型会被序列化为 ``env.pkl`` 文件。使用虚拟环境模型时需要使用 ``pickle`` 加载序列化的环境模型,然后使用 ``venv.infer_one_step()`` 函数或 ``venv.infer_k_step()`` 函数进行虚拟环境推理。 .. code:: python import os import pickle import numpy as np # 获得虚拟环境模型的文件路径 venv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs/run_id", "env.pkl") # 加载虚拟环境模型 venv = pickle.load(open(venv_path, 'rb'), encoding='utf-8') # 生成状态数据 state = {"states": np.array([-0.5, 0.5, 0.2]), "actions": np.array([1.])} # 使用虚拟环境模型进行单时间步的推理 output = venv.infer_one_step(state) print("Virtualenv model 1-step output:", output) # 使用虚拟环境模型进行K个时间步推理 # 返回列表的长度为k,对应于K个时间步的输出。 output = venv.infer_k_steps(state, k=3) print("Virtualenv model k-step output:", output) 使用虚拟环境模型进行推理时,可以额外传入 ``deterministic`` 和 ``clip`` 参数。其中 ``deterministic`` 参数用来决定 输出是否是确定性的。如果为True,则返回最可能的输出;如果为False,则根据模型的概率分布进行采样输出。默认值为True。 ``clip`` 用来决定是否应将输出的裁剪到指定的有效范围内, 裁剪范围是根据 ``*.yaml`` 文件中的配置进行,如果没有配置则从自动从数据中计算最小值和最大值。默认值为True,表示对输出的动作进行剪切。 .. note:: 详情请参考API :doc:`REVIVE API 虚拟环境 <../revive_inference>`. **使用策略模型(policy.pkl)** 策略模型会被序列化为 ``policy.pkl`` 文件。使用策略模型时需要使用 ``pickle`` 加载序列化的决策模型,然后使用 ``policy.infer()`` 函数进行策略模型推理。 .. code:: python import os import pickle import numpy as np # 获得策略模型的文件路径 policy_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs/tmp", "policy.pkl") # 加载策略模型 policy = pickle.load(open(policy_path, 'rb'), encoding='utf-8') # 生成状态数据 state = {"states": np.array([-0.5, 0.5, 0.2])} print("Policy model input state:", state) # 使用策略模型进行推理,输出动作 action = policy.infer(state) print("Policy model output action:", action) 策略模型进行推理时,同样可以额外传入 ``deterministic`` 和 ``clip`` 参数。其中 ``deterministic`` 参数用来决定 输出是否是确定性的。如果为True,则返回最可能的输出策略动作;如果为False,则根据策略模型的概率分布进行采样动作输出。默认值为True。 ``clip`` 用来决定是否应将输出的动作剪切到动作空间的有效范围内, 裁剪范围是根据 ``*.yaml`` 文件中的配置进行,如果没有配置则从自动从数据中计算最小值和最大值。默认值为True,表示对输出的动作进行剪切。 .. note:: 详情请参考API :doc:`REVIVE API 策略模型 <../revive_inference_cn>`. 使用保存的 ``.onnx`` 模型 ---------------------------- 虚拟环境模型和策略模型也会被序列化为 ``.onnx`` 文件,使用 ``.onnx`` 模型可以方便的进行跨平台部署。 **使用虚拟环境模型(env.onnx)** .. code:: python import os import onnxruntime import numpy as np venv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs/run_id", "env.onnx") venv = onnxruntime.InferenceSession(venv_path) # onnx 模型现在已经支持灵活的 batch_size,默认在第 0 维 # 确保输入的数据为浮点类型 # venv_input需要是一个python的字典, 与.pkl模型的输入相似 venv_input = {'temperature' : np.array([0.5, 0.4, 0.3], dtype=np.float32).reshape(3, -1), 'door_open': np.array([1., 0., 1.], dtype=np.float32).reshape(3, -1)} # 输出的数据需要将其名称囊括在venv_output_names列表中 venv_output_names = ["action", "next_temperature"] # 这里将用输入的数据进行一次决策流的推理(类似于 .pkl 的 venv.infer_one_step()) --> 返回一个列表(List) # 输出将以venv_output_names的名称顺序存储在数组中 output = venv.run(input_feed=venv_input, output_names=venv_output_names) print(output) **使用策略模型(policy.onnx)** .. code:: python import os import onnxruntime import numpy as np policy_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs/run_id", "policy.onnx") policy = onnxruntime.InferenceSession(policy_path) # onnx 模型现在已经支持灵活的 batch_size,默认在第 0 维 # 确保输入的数据为浮点类型 # policy_input需要是一个python的字典, 与.pkl模型的输入相似 policy_input = {'temperature' : np.array([0.5, 0.4, 0.3], dtype=np.float32).reshape(3, -1)} # 输出的数据需要将其名称囊括在policy_output_names列表中 policy_output_names = ["action"] # 这里将用输入的数据进行一次决策流的推理(类似于 .pkl 的 policy.infer_one_step()) --> 返回一个列表(List) # 输出将以venv_output_names的名称顺序存储在数组中 output = policy.run(input_feed=policy_input, output_names=policy_output_names) print(output)