revive.computation¶
- class revive.computation.inference_cn.VirtualEnv(env_list: List[VirtualEnvDev])[source]¶
Bases:
object
- to(device)[source]¶
切换模型所在设备,可以指定cpu或cuda。
示例:
>>> venv_model.to("cpu") >>> venv_model.to("cuda") >>> venv_model.to("cuda:1")
- property target_policy_name: str¶
获得策略节点的名称。
- replace_policy(policy: PolicyModel) None [source]¶
使用给定的策略节点模型代替当前的策略节点模型。
- infer_one_step(state: Dict[str, ndarray], deterministic: bool = True, clip: bool = True) Dict[str, ndarray] [source]¶
生成1步交互数据,1步表示决策流图完整的运行一遍。
- 参数:
- states:
包含初始输入节点数据的字典,初始的节点数据应包括决策流图的所有叶子节点。
- deterministic:
如果参数是True, 进行确定性的数据生成; 如果参数是False, 从分布中进行采样生成数据。 默认值: True
- 返回值:
字典,含有1步交互数据,key是节点名,value节点数据数组。
示例:
>>> state = {"obs": obs_array, "static_obs": static_obs_array} >>> one_step_output = venv_model.infer_one_step(state)
- infer_k_steps(states: Dict[str, ndarray] | List[Dict[str, ndarray]], k: int | None = None, deterministic: bool = True, clip: bool = True) List[Dict[str, ndarray]] [source]¶
生成k步交互数据,每一步表示决策流图完整的运行一遍。
- 参数:
- states:
包含初始输入节点数据的字典,初始的节点数据应包括决策流图的所有叶子节点。
- k:
正整数,如果是1,则返回一步的交互数据;如果是10,则决策流图迭代的运行10次,返回10次的数据。
- deterministic:
如果参数是True, 进行确定性的数据生成; 如果参数是False, 从分布中进行采样生成数据。 默认值: True
- 返回值:
字典,含有k步交互数据,key是节点名,value是含有k步该节点数据的数组。
示例:
>>> state = {"obs": obs_array, "static_obs": static_obs_array} >>> ten_step_output = venv_model.infer_k_steps(state, k=10)
- node_infer(node_name: str, state: Dict[str, ndarray], deterministic: bool = True, clip: bool = True) Dict[str, ndarray] [source]¶
使用指定节点模型进行推理.
- 参数:
- state:
包含节点所有输入数据的字典。
- deterministic:
如果参数是True, 进行确定性的数据生成; 如果参数是False, 从分布中进行采样生成数据。 默认值: True。
- 返回值:
节点输出。
示例:
>>> state = {"obs": obs_array, "static_obs": static_obs_array} >>> action_output = venv_model.node_infer("action", state)
- export2onnx(onnx_file: str, verbose: bool = True)[source]¶
导出环境模型为onnx格式。
参考: https://pytorch.org/docs/stable/onnx.html
- 参数:
- onnx_file:
存储onnx模型的文件地址。
- verbose:
默认为True。 如果为True,打印导出到的模型的描述,最终的ONNX图将包括导出模型中的字段doc_string,其中提到model的源代码位置。.
- class revive.computation.inference_cn.PolicyModel(policy_model_dev: PolicyModelDev, post_process: Callable[[Dict[str, ndarray], ndarray], ndarray] | None = None)[source]¶
Bases:
object
- to(device: str)[source]¶
切换模型所在设备,可以指定cpu或cuda。
示例:
>>> policy_model.to("cpu") >>> policy_model.to("cuda") >>> policy_model.to("cuda:1")
- property target_policy_name: None¶
获得策略节点的名称。
- infer(state: Dict[str, ndarray], deterministic: bool = True, clip: bool = True, additional_info: Dict[str, ndarray] | None = None) ndarray [source]¶
使用策略模型进行推理,输出动作.
- 参数:
- state:
包含策略节点所有输入数据的字典
- deterministic:
如果参数是True, 进行确定性的数据生成; 如果参数是False, 从分布中进行采样生成数据。 默认值: True。
- clip:
如果为True,输出的动作数值将会被裁剪到YAML文件中配置的范围; 如果为False,不对输出的动作数值进行裁剪,输出值可能存在越界的情况。 默认值:True
- additional_info:
默认为None即可。
- 返回值:
动作。
示例:
>>> state = {"obs": obs_array, "static_obs": static_obs_array} >>> action = policy_model.infer(state)
- export2onnx(onnx_file: str, verbose: bool = True)[source]¶
导出策略模型为onnx格式。
参考: https://pytorch.org/docs/stable/onnx.html
- 参数:
- onnx_file:
存储onnx模型的文件地址。
- verbose:
默认为True。 如果为True,打印导出到的模型的描述,最终的ONNX图将包括导出模型中的字段doc_string,其中提到model的源代码位置。.