revive.computation

class revive.computation.inference_cn.VirtualEnv(env_list: List[VirtualEnvDev])[source]

Bases: object

to(device)[source]

切换模型所在设备,可以指定cpu或cuda。

示例:

>>> venv_model.to("cpu")
>>> venv_model.to("cuda")
>>> venv_model.to("cuda:1")
check_version()[source]

检查训练模型使用的REVIVE SDK版本与当前安装的REVIVE SDK版本是否一致。

reset() None[source]

重置模型隐藏层信息,使用RNN训练的模型需要在每次开始使用时调用该方法。

property target_policy_name: str

获得策略节点的名称。

set_target_policy_name(target_policy_name) None[source]

设置策略节点的名称。

replace_policy(policy: PolicyModel) None[source]

使用给定的策略节点模型代替当前的策略节点模型。

infer_one_step(state: Dict[str, ndarray], deterministic: bool = True, clip: bool = True) Dict[str, ndarray][source]

生成1步交互数据,1步表示决策流图完整的运行一遍。

参数:
states

包含初始输入节点数据的字典,初始的节点数据应包括决策流图的所有叶子节点。

deterministic

如果参数是True, 进行确定性的数据生成; 如果参数是False, 从分布中进行采样生成数据。 默认值: True

返回值:

字典,含有1步交互数据,key是节点名,value节点数据数组。

示例:

>>> state = {"obs": obs_array, "static_obs": static_obs_array}
>>> one_step_output = venv_model.infer_one_step(state)
infer_k_steps(states: Union[Dict[str, ndarray], List[Dict[str, ndarray]]], k: Optional[int] = None, deterministic: bool = True, clip: bool = True) List[Dict[str, ndarray]][source]

生成k步交互数据,每一步表示决策流图完整的运行一遍。

参数:
states

包含初始输入节点数据的字典,初始的节点数据应包括决策流图的所有叶子节点。

k

正整数,如果是1,则返回一步的交互数据;如果是10,则决策流图迭代的运行10次,返回10次的数据。

deterministic

如果参数是True, 进行确定性的数据生成; 如果参数是False, 从分布中进行采样生成数据。 默认值: True

返回值:

字典,含有k步交互数据,key是节点名,value是含有k步该节点数据的数组。

示例:

>>> state = {"obs": obs_array, "static_obs": static_obs_array}
>>> ten_step_output = venv_model.infer_k_steps(state, k=10)
node_infer(node_name: str, state: Dict[str, ndarray], deterministic: bool = True, clip: bool = True) Dict[str, ndarray][source]

使用指定节点模型进行推理.

参数:
state

包含节点所有输入数据的字典。

deterministic

如果参数是True, 进行确定性的数据生成; 如果参数是False, 从分布中进行采样生成数据。 默认值: True。

返回值:

节点输出。

示例:

>>> state = {"obs": obs_array, "static_obs": static_obs_array}
>>> action_output = venv_model.node_infer("action", state)
export2onnx(onnx_file: str, verbose: bool = True)[source]

导出环境模型为onnx格式。

参考: https://pytorch.org/docs/stable/onnx.html

参数:
onnx_file

存储onnx模型的文件地址。

verbose

默认为True。 如果为True,打印导出到的模型的描述,最终的ONNX图将包括导出模型中的字段doc_string,其中提到model的源代码位置。.

class revive.computation.inference_cn.PolicyModel(policy_model_dev: PolicyModelDev, post_process: Optional[Callable[[Dict[str, ndarray], ndarray], ndarray]] = None)[source]

Bases: object

to(device: str)[source]

切换模型所在设备,可以指定cpu或cuda。

示例:

>>> policy_model.to("cpu")
>>> policy_model.to("cuda")
>>> policy_model.to("cuda:1")
check_version()[source]

检查训练模型使用的REVIVE SDK版本与当前安装的REVIVE SDK版本是否一致。

reset()[source]

重置模型隐藏层信息,使用RNN训练的模型需要在每次开始使用时调用该方法。

property target_policy_name: None

获得策略节点的名称。

infer(state: Dict[str, ndarray], deterministic: bool = True, clip: bool = True, additional_info: Optional[Dict[str, ndarray]] = None) ndarray[source]

使用策略模型进行推理,输出动作.

参数:
state

包含策略节点所有输入数据的字典

deterministic

如果参数是True, 进行确定性的数据生成; 如果参数是False, 从分布中进行采样生成数据。 默认值: True。

clip

如果为True,输出的动作数值将会被裁剪到YAML文件中配置的范围; 如果为False,不对输出的动作数值进行裁剪,输出值可能存在越界的情况。 默认值:True

additional_info

默认为None即可。

返回值:

动作。

示例:

>>> state = {"obs": obs_array, "static_obs": static_obs_array}
>>> action = policy_model.infer(state)
export2onnx(onnx_file: str, verbose: bool = True)[source]

导出策略模型为onnx格式。

参考: https://pytorch.org/docs/stable/onnx.html

参数:
onnx_file

存储onnx模型的文件地址。

verbose

默认为True。 如果为True,打印导出到的模型的描述,最终的ONNX图将包括导出模型中的字段doc_string,其中提到model的源代码位置。.