Source code for revive.computation.inference_cn

""" 本文件只为用来生成中文文档,请不要另作它用 """
"""
    POLIXIR REVIVE, copyright (C) 2021-2023 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""

import torch
import warnings
from loguru import logger
import numpy as np
from copy import deepcopy
from typing import Callable, Dict, List, Union, Optional

from revive import __version__
from revive.computation.graph import DesicionGraph, DesicionNode
from revive.computation.utils import *


class VirtualEnvDev(torch.nn.Module):
    def __init__(self, graph : DesicionGraph) -> None:
        super(VirtualEnvDev, self).__init__()
        self.models = torch.nn.ModuleList()
        self.graph = graph

        for node in self.graph.nodes.values():
            if node.node_type == 'network':
                self.models.append(node.get_network())

        self.set_target_policy_name(list(self.graph.keys()))  # default
        self.revive_version = __version__
        self.device = "cpu"

    def to(self, device):
        if device != self.device:
            self.device = device
            for node in self.graph.nodes.values():
                if node.node_type == 'network':
                    try:
                        node.to(self.device)
                    except:
                        logger.warning(f'Failed to move the network of "{node.name}" to the target device')
                        pass

    def check_version(self):
        if not self.revive_version == __version__:
            warnings.warn(f'Detected that the venv was created by version {self.revive_version}, '
                          f'but the current version is {__version__}; they may not be compatible.')

    def reset(self) -> None:
        self.graph.reset()

    def set_target_policy_name(self, target_policy_name : list) -> None:
        self.target_policy_name = target_policy_name

        # find target index
        self.index = []
        for i, (output_name, input_names) in enumerate(self.graph.items()):
            if output_name in self.target_policy_name:
                self.index.append(i)
        self.index.sort()

    def _data_preprocess(self, data : np.ndarray, data_key : str = "obs") -> torch.Tensor:
        data = self.graph.processor.process_single(data, data_key)
        data = to_torch(data, device=self.device)
        return data

    def _data_postprocess(self, data : torch.Tensor, data_key : str) -> np.ndarray:
        data = to_numpy(data)
        data = self.graph.processor.deprocess_single(data, data_key)
        return data

    def _infer_one_step(self,
                        state : Dict[str, np.ndarray],
                        deterministic : bool = True,
                        clip : bool = True,
                        node_name : str = None) -> Dict[str, np.ndarray]:
        self.check_version()
        state = deepcopy(state)
        sample_fn = get_sample_function(deterministic)

        for k in list(state.keys()):
            state[k] = self._data_preprocess(state[k], k)

        # Support inferring nodes based on node name calls
        if node_name is None:
            compute_nodes = self.graph.keys()
        else:
            assert node_name in self.graph.keys()
            assert not node_name in state.keys()
            compute_nodes = [node_name, ]

        for node_name in compute_nodes:
            if not node_name in state.keys():  # skip provided values
                output = self.graph.compute_node(node_name, state)
                if isinstance(output, torch.Tensor):
                    state[node_name] = output
                else:
                    state[node_name] = sample_fn(output)
                if clip:
                    state[node_name] = torch.clamp(state[node_name], -1, 1)

        for k in list(state.keys()):
            state[k] = self._data_postprocess(state[k], k)

        return state

    def infer_k_steps(self,
                      states : List[Dict[str, np.ndarray]],
                      deterministic : bool = True,
                      clip : bool = True) -> List[Dict[str, np.ndarray]]:
        outputs = []
        backup = {}
        tunable = {name : states[0][name] for name in self.graph.tunable}

        for state in states:
            state.update(backup)
            output = self._infer_one_step(state, deterministic=deterministic, clip=clip)
            outputs.append(output)
            backup = self.graph.state_transition(output)
            backup.update(tunable)

        return outputs

    def infer_one_step(self, state : Dict[str, np.ndarray], deterministic : bool = True, clip : bool = True) -> Dict[str, np.ndarray]:
        return self._infer_one_step(state, deterministic=deterministic, clip=clip)

    def node_infer(self, node_name : str, state : Dict[str, np.ndarray], deterministic : bool = True, clip : bool = True) -> Dict[str, np.ndarray]:
        return self._infer_one_step(state, deterministic=deterministic, clip=clip, node_name=node_name)[node_name]

    def forward(self, data : Dict[str, torch.Tensor], deterministic : bool = True, clip : bool = False) -> Dict[str, torch.Tensor]:
        '''Run the target node.'''
        self.check_version()
        sample_fn = get_sample_function(deterministic)
        node_name = self.target_policy_name

        output = self.graph.compute_node(node_name, data)
        if isinstance(output, torch.Tensor):
            data[node_name] = output
        else:
            data[node_name] = sample_fn(output)
        if clip:
            data[node_name] = torch.clamp(data[node_name], -1, 1)

        return data

    def pre_computation(self,
                        data : Dict[str, torch.Tensor],
                        deterministic : bool = True,
                        clip : bool = False,
                        policy_index : int = 0) -> Dict[str, torch.Tensor]:
        '''Run all the nodes before the target node; skip a node if its value is already available.'''
        self.check_version()
        sample_fn = get_sample_function(deterministic)

        for node_name in list(self.graph.keys())[:self.index[policy_index]]:
            if not node_name in data.keys():
                output = self.graph.compute_node(node_name, data)
                if isinstance(output, torch.Tensor):
                    data[node_name] = output
                else:
                    data[node_name] = sample_fn(output)
                if clip:
                    data[node_name] = torch.clamp(data[node_name], -1, 1)
            else:
                print(f'Skip {node_name}, since it is provided in the inputs.')

        return data

    def post_computation(self,
                         data : Dict[str, torch.Tensor],
                         deterministic : bool = True,
                         clip : bool = False,
                         policy_index : int = 0) -> Dict[str, torch.Tensor]:
        '''Run all the nodes after the target node.'''
        self.check_version()
        sample_fn = get_sample_function(deterministic)

        for node_name in list(self.graph.keys())[self.index[policy_index]+1:]:
            output = self.graph.compute_node(node_name, data)
            if isinstance(output, torch.Tensor):
                data[node_name] = output
            else:
                data[node_name] = sample_fn(output)
            if clip:
                data[node_name] = torch.clamp(data[node_name], -1, 1)

        return data

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        self.graph.export2onnx(onnx_file, verbose)
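
The pre_computation / post_computation pair splits one pass through the decision flow graph around the target policy: run every node before the policy, fill in the policy output yourself, then run the remaining nodes. Below is a minimal sketch of that split, assuming a trained VirtualEnvDev instance venv_dev whose target policy node is "action"; the node names, tensor shapes, and the venv_dev variable itself are illustrative assumptions, not part of this module.

import torch

batch = {"obs": torch.zeros(1, 10)}                        # hypothetical, already-processed tensors
batch = venv_dev.pre_computation(batch, policy_index=0)    # nodes before the policy; provided keys are skipped
batch["action"] = torch.zeros(1, 3)                        # hypothetical externally computed policy output
batch = venv_dev.post_computation(batch, policy_index=0)   # remaining nodes after the policy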

class VirtualEnv:
    def __init__(self, env_list : List[VirtualEnvDev]):
        self._env = env_list[0]
        self.env_list = env_list
        self.graph = self._env.graph
        self.revive_version = __version__
        self.device = "cpu"

    def to(self, device):
        r"""
        Move the model to the given device; either cpu or cuda can be specified.

        Example::

            >>> venv_model.to("cpu")
            >>> venv_model.to("cuda")
            >>> venv_model.to("cuda:1")
        """
        if device != self.device:
            self.device = device
            for env in self.env_list:
                env.to(device)

    def check_version(self):
        r"""Check whether the REVIVE SDK version used to train the model matches the currently installed REVIVE SDK version."""
        if not self.revive_version == __version__:
            warnings.warn(f'Detected that the venv was created by version {self.revive_version}, '
                          f'but the current version is {__version__}; they may not be compatible.')

    def reset(self) -> None:
        r"""
        Reset the hidden state of the model. Models trained with an RNN need to call this method each time before use.
        """
        for env in self.env_list:
            env.reset()

    @property
    def target_policy_name(self) -> str:
        r'''Get the name of the policy node.'''
        return self._env.target_policy_name

    def set_target_policy_name(self, target_policy_name) -> None:
        r'''Set the name of the policy node.'''
        for env in self.env_list:
            env.set_target_policy_name(target_policy_name)

    def replace_policy(self, policy : 'PolicyModel') -> None:
        '''Replace the current policy node model with the given policy model.'''
        assert self.target_policy_name == policy.target_policy_name, \
            f'policy name does not match, require {self.target_policy_name} but got {policy.target_policy_name}!'
        for env in self.env_list:
            env.graph.nodes[self.target_policy_name] = policy._policy_model.node

    @torch.no_grad()
    def infer_one_step(self, state : Dict[str, np.ndarray], deterministic : bool = True, clip : bool = True) -> Dict[str, np.ndarray]:
        r"""
        Generate one step of interaction data; one step means one complete pass through the decision flow graph.

        Args:
            :state: A dict holding the initial input node data; it should contain all the leaf nodes of the decision flow graph.

            :deterministic: If True, generate data deterministically; if False, sample from the distributions to generate data. Default: True

        Returns:
            A dict holding the one-step interaction data, whose keys are node names and whose values are node data arrays.

        Example::

            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> one_step_output = venv_model.infer_one_step(state)
        """
        self.check_version()
        return self._env.infer_one_step(deepcopy(state), deterministic, clip=clip)

    @torch.no_grad()
    def infer_k_steps(self,
                      states : Union[Dict[str, np.ndarray], List[Dict[str, np.ndarray]]],
                      k : Optional[int] = None,
                      deterministic : bool = True,
                      clip : bool = True) -> List[Dict[str, np.ndarray]]:
        r"""
        Generate k steps of interaction data; each step means one complete pass through the decision flow graph.

        Args:
            :states: A dict holding the initial input node data; it should contain all the leaf nodes of the decision flow graph.

            :k: A positive integer. If it is 1, one step of interaction data is returned; if it is 10, the decision flow graph is run iteratively 10 times and 10 steps of data are returned.

            :deterministic: If True, generate data deterministically; if False, sample from the distributions to generate data. Default: True

        Returns:
            A list with one dict per step, whose keys are node names and whose values are the node data arrays of that step.

        Example::

            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> ten_step_output = venv_model.infer_k_steps(state, k=10)
        """
        self.check_version()
        if isinstance(states, dict):
            states = [states]
            if k is not None:
                for i in range(k - 1):
                    states.append({})
        return self._env.infer_k_steps(deepcopy(states), deterministic, clip=clip)

    @torch.no_grad()
    def node_infer(self, node_name: str, state : Dict[str, np.ndarray], deterministic : bool = True, clip : bool = True) -> Dict[str, np.ndarray]:
        r"""
        Run inference with the specified node model.

        Args:
            :state: A dict containing all the input data of the node.

            :deterministic: If True, generate data deterministically; if False, sample from the distributions to generate data. Default: True.

        Returns:
            The node output.

        Example::

            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> action_output = venv_model.node_infer("action", state)
        """
        self.check_version()
        return self._env.node_infer(node_name, deepcopy(state), deterministic, clip=clip)

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        r"""
        Export the environment model to ONNX format.

        Reference: https://pytorch.org/docs/stable/onnx.html

        Args:
            :onnx_file: The file path where the ONNX model is stored.

            :verbose: Default: True. If True, print a description of the exported model; the final ONNX graph will include the doc_string field of the exported model, which mentions the source code location of the model.
        """
        self._env.export2onnx(onnx_file, verbose)
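
A minimal usage sketch for a trained VirtualEnv at inference time. Loading via pickle and the file name "env.pkl" are assumptions about how the trained model was saved, and the "obs" key and shape are placeholders for the leaf nodes of your own decision flow graph.

import pickle
import numpy as np

with open("env.pkl", "rb") as f:                         # assumed save path of the trained venv
    venv_model = pickle.load(f)
venv_model.reset()                                        # needed for models trained with an RNN
state = {"obs": np.zeros((1, 10), dtype=np.float32)}      # hypothetical leaf-node inputs
trajectory = venv_model.infer_k_steps(state, k=10)        # a list of 10 per-step dicts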

class PolicyModelDev(torch.nn.Module):
    def __init__(self, nodes : List[DesicionNode]):
        super().__init__()
        self.nodes = nodes
        self.node = self.nodes[0]
        self.models = [node.get_network() for node in self.nodes]
        self.target_policy_name = [node.name for node in self.nodes]
        self.target_policy_name = self.target_policy_name[0]
        self.revive_version = __version__
        self.device = "cpu"

    def to(self, device):
        if device != self.device:
            self.device = device
            self.node.to(self.device)

    def check_version(self):
        if not self.revive_version == __version__:
            warnings.warn(f'Detected that the policy was created by version {self.revive_version}, '
                          f'but the current version is {__version__}; they may not be compatible.')

    def reset(self):
        for node in self.nodes:
            node.reset()

    def _data_preprocess(self, data : np.ndarray, data_key : str = "obs") -> torch.Tensor:
        data = self.node.processor.process_single(data, data_key)
        data = to_torch(data, device=self.device)
        return data

    def _data_postprocess(self, data : torch.Tensor, data_key : str = "action1") -> np.ndarray:
        data = to_numpy(data)
        data = self.node.processor.deprocess_single(data, data_key)
        return data

    def infer(self, state : Dict[str, np.ndarray], deterministic : bool = True, clip : bool = False) -> np.ndarray:
        self.check_version()
        state = deepcopy(state)
        sample_fn = get_sample_function(deterministic)

        for k, v in state.items():
            state[k] = self._data_preprocess(v, data_key=k)

        output = self.node(state)
        if isinstance(output, torch.Tensor):
            action = output
        else:
            action = sample_fn(output)
        if clip:
            action = torch.clamp(action, -1, 1)

        action = self._data_postprocess(action, self.target_policy_name)
        return action

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        self.node.export2onnx(onnx_file, verbose)
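
A short sketch of PolicyModelDev.infer: it preprocesses the raw inputs, evaluates or samples the policy node, optionally clips in the normalized space, and returns a deprocessed numpy action. The instance policy_dev, the "obs" key, and the shape below are illustrative assumptions.

import numpy as np

state = {"obs": np.zeros((1, 10), dtype=np.float32)}      # hypothetical policy input
action = policy_dev.infer(state, deterministic=True)       # numpy array in the original action space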

class PolicyModel:
    def __init__(self,
                 policy_model_dev : PolicyModelDev,
                 post_process : Optional[Callable[[Dict[str, np.ndarray], np.ndarray], np.ndarray]] = None):
        self._policy_model = policy_model_dev
        self.post_process = post_process
        self.revive_version = __version__
        self.device = "cpu"

    def to(self, device: str):
        r"""
        Move the model to the given device; either cpu or cuda can be specified.

        Example::

            >>> policy_model.to("cpu")
            >>> policy_model.to("cuda")
            >>> policy_model.to("cuda:1")
        """
        if device != self.device:
            self.device = device
            self._policy_model.to(self.device)

    def check_version(self):
        r"""Check whether the REVIVE SDK version used to train the model matches the currently installed REVIVE SDK version."""
        if not self.revive_version == __version__:
            warnings.warn(f'Detected that the policy was created by version {self.revive_version}, '
                          f'but the current version is {__version__}; they may not be compatible.')

    def reset(self):
        r"""
        Reset the hidden state of the model. Models trained with an RNN need to call this method each time before use.
        """
        self._policy_model.reset()

    @property
    def target_policy_name(self) -> str:
        r'''Get the name of the policy node.'''
        return self._policy_model.target_policy_name

    @torch.no_grad()
    def infer(self,
              state : Dict[str, np.ndarray],
              deterministic : bool = True,
              clip : bool = True,
              additional_info : Optional[Dict[str, np.ndarray]] = None) -> np.ndarray:
        r"""
        Run inference with the policy model and output an action.

        Args:
            :state: A dict containing all the input data of the policy node.

            :deterministic: If True, generate data deterministically; if False, sample from the distributions to generate data. Default: True.

            :clip: If True, the output action values are clipped to the range configured in the YAML file; if False, the output action values are not clipped and may fall outside that range. Default: True

            :additional_info: Leave as None by default.

        Returns:
            The action.

        Example::

            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> action = policy_model.infer(state)
        """
        self.check_version()
        action = self._policy_model.infer(deepcopy(state), deterministic, clip=clip)
        if self.post_process is not None:
            state.update(additional_info)
            action = self.post_process(state, action)
        return action

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        r"""
        Export the policy model to ONNX format.

        Reference: https://pytorch.org/docs/stable/onnx.html

        Args:
            :onnx_file: The file path where the ONNX model is stored.

            :verbose: Default: True. If True, print a description of the exported model; the final ONNX graph will include the doc_string field of the exported model, which mentions the source code location of the model.
        """
        if self.post_process is not None:
            warnings.warn('Currently, post process will not be exported.')
        self._policy_model.export2onnx(onnx_file, verbose)
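
A minimal usage sketch for a trained PolicyModel. Loading via pickle and the file name "policy.pkl" are assumptions about how the model was saved; the "obs" key and shape are placeholders for your own policy node inputs.

import pickle
import numpy as np

with open("policy.pkl", "rb") as f:                        # assumed save path of the trained policy
    policy_model = pickle.load(f)
policy_model.reset()                                        # needed for policies trained with an RNN
state = {"obs": np.zeros((1, 10), dtype=np.float32)}        # hypothetical policy inputs
action = policy_model.infer(state, deterministic=True, clip=True)
policy_model.export2onnx("policy.onnx", verbose=True)       # note: a post_process function is not exported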