# Source code for revive.computation.inference

''''''
"""
    POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""
from collections import deque
import time
import torch
import warnings
from loguru import logger
import numpy as np
from copy import deepcopy
from typing import Callable, Dict, List, Union, Optional

from revive import __version__
from revive.computation.graph import DesicionGraph, DesicionNode
from revive.computation.utils import *
from revive.computation.modules import EnsembleMatcher#, SequentialMatcher
from revive.computation.dists import MixDistribution

import time


def find_key_precedings(graph : DesicionGraph, node_name : str, preceding_type : str="network") -> List:
    """Return the direct inputs of ``node_name`` that are graph nodes of type ``preceding_type``.

    Inputs of ``node_name`` that are not keys of ``graph`` are ignored. The
    remaining inputs are kept, in their original input order, when their
    ``node_type`` equals ``preceding_type``.
    """
    graph_keys = graph.keys()
    matched = []
    for candidate in graph.get_node(node_name).input_names:
        if candidate not in graph_keys:
            continue
        if graph.get_node(candidate).node_type == preceding_type:
            matched.append(candidate)
    return matched
def get_matcher_input(batch_data: "Batch", target_node_name: str, candidate_num: int, node_names: List[str], nodes_fit_dict: Optional[dict] = None) -> torch.Tensor:
    """Gather the data of ``node_names`` into one flat, detached matcher input.

    The target node's data is used as-is. It is expected to already hold the
    candidates along a leading axis (callers pass
    ``[candidate_num, batch_size, dim]``). Every other node is tiled
    ``candidate_num`` times along a new leading axis so it lines up with the
    candidates. Each node is then flattened to ``[-1, last_dim]``, and all
    nodes are concatenated along the last dimension.

    Args:
        batch_data (Batch): mapping from node name to tensor.
        target_node_name (str): name of the node whose data holds the candidates.
        candidate_num (int): number of candidates per sample.
        node_names (List[str]): nodes to gather, in concatenation order.
        nodes_fit_dict (dict, optional): per-node index applied to the last
            dimension of that node's data. It is ignored when ``None`` or empty.
            When given, it must contain every name in ``node_names``.

    Return:
        Detached tensor of the concatenated node data.
    """
    datas = []
    for name in node_names:
        temp = batch_data[name]
        if nodes_fit_dict:
            temp = temp[..., nodes_fit_dict[name]]
        if name != target_node_name:
            # Tile non-candidate nodes so each copy pairs with one candidate of the target node.
            temp = temp.unsqueeze(0).repeat_interleave(candidate_num, dim=0)
        # reshape (not view): selections made via nodes_fit_dict may be non-contiguous.
        datas.append(temp.reshape(-1, temp.size(-1)))
    return torch.cat(datas, dim=-1).detach()
class VirtualEnvDev(torch.nn.Module):
    """Development-side virtual environment built on a ``DesicionGraph``.

    It registers the networks of the graph's ``'network'`` nodes in a
    ``ModuleList`` and provides:

    * numpy-level inference: ``infer_one_step``, ``infer_k_steps``,
      ``node_infer``, ``node_dist``;
    * tensor-level rollout around the target policy node:
      ``pre_computation``, ``forward``, ``post_computation``,
      ``filter_penalty``.

    When ``train_algo == "REVIVE_FILTER"``, ensemble matchers are loaded from
    ``kwargs``. Stochastic node outputs can then be filtered by matcher scores
    (``filter=True`` together with ``candidate_num``).
    """

    def __init__(self, graph : DesicionGraph, train_algo : str = None, info: Optional[dict] = None, **kwargs) -> None:
        super(VirtualEnvDev, self).__init__()
        self.models = torch.nn.ModuleList()
        self.graph = graph
        for node in self.graph.nodes.values():
            if node.node_type == 'network':
                self.models.append(node.get_network())
        self.set_target_policy_name(list(self.graph.keys()))  # default: every node
        self.revive_version = __version__
        self.device = "cpu"
        self.train_algo = train_algo
        # A fresh dict per instance; a shared ``{}`` default would leak between environments.
        self.info = {} if info is None else info
        if self.train_algo == "REVIVE_FILTER":
            self.load_filter_components(**kwargs)
        self.reset()

    def __str__(self):
        return f"Model Info: {self.info}\n" + super().__str__()

    def to(self, device):
        """Move the network nodes (and the ensemble matchers, if any) to ``device``.

        Failures on individual nodes are logged and skipped so that the other
        nodes are still moved.
        """
        if device != self.device:
            self.device = device
            for node in self.graph.nodes.values():
                if node.node_type == 'network':
                    try:
                        node.to(self.device)
                    except Exception:
                        logger.warning(f'Failed to move device the network of "{node.name}"')
            if hasattr(self, "ensemble_matcher_list"):
                for ensemble_matcher in self.ensemble_matcher_list:
                    ensemble_matcher.to(self.device)

    def check_version(self):
        """Warn if this model was created with a different REVIVE version."""
        if not self.revive_version == __version__:
            warnings.warn(f'detect the venv is create by version {self.revive_version}, but current version is {__version__}, maybe not compactable.')

    def reset(self) -> None:
        """Reset the underlying graph (``graph.reset()``)."""
        self.graph.reset()

    def set_target_policy_name(self, target_policy_name : list) -> None:
        """Set the target policy node name(s) and recompute their graph positions in ``self.index``.

        Args:
            target_policy_name: a list of node names, or a single node name.
        """
        self.target_policy_name = target_policy_name
        # A single name must match exactly; ``in`` on a str would do substring matching.
        if isinstance(target_policy_name, str):
            target_names = [target_policy_name]
        else:
            target_names = target_policy_name
        self.index = sorted(
            i for i, (output_name, _) in enumerate(self.graph.items()) if output_name in target_names
        )

    def load_filter_components(self, **kwargs):
        """Load the matcher records/structures (and optional fit indices) used by REVIVE_FILTER."""
        if "matcher_record_list" in kwargs:
            self.matcher_record_list = kwargs['matcher_record_list']
            assert "matcher_structure_list" in kwargs
            self.matcher_structure_list = kwargs['matcher_structure_list']
            self.init_ensemble_matcher()
        if "matching_nodes_fit_index_list" in kwargs:
            self.matching_nodes_fit_index_list = kwargs['matching_nodes_fit_index_list']

    def _data_preprocess(self, data : np.ndarray, data_key : str = "obs") -> torch.Tensor:
        """Process raw ``data`` of node ``data_key`` with the graph processor and convert it to a tensor on ``self.device``."""
        data = self.graph.processor.process_single(data, data_key)
        data = to_torch(data, device=self.device)
        return data

    def _data_postprocess(self, data : torch.Tensor, data_key : str) -> np.ndarray:
        """Convert ``data`` to numpy and deprocess it with the graph processor for node ``data_key``."""
        data = to_numpy(data)
        data = self.graph.processor.deprocess_single(data, data_key)
        return data

    def init_ensemble_matcher(self, ensemble_size: int=1, ensemble_choosing_interval: int=1):
        """Build the ensemble matchers unless they already exist."""
        if hasattr(self, "ensemble_matcher_list"):
            return
        self.reset_ensemble_matcher(ensemble_size, ensemble_choosing_interval)

    def reset_ensemble_matcher(self, ensemble_size: int, ensemble_choosing_interval: int):
        """(Re)build one ``EnsembleMatcher`` per matcher record and move it to ``self.device``."""
        self.ensemble_matcher_list = [
            EnsembleMatcher(matcher_record, ensemble_size=ensemble_size, ensemble_choosing_interval=ensemble_choosing_interval, config=matcher_config)
            for matcher_record, matcher_config in zip(self.matcher_record_list, self.matcher_structure_list)
        ]
        for ensemble_matcher in self.ensemble_matcher_list:
            ensemble_matcher.to(self.device)

    def _find_ensemble_matcher(self, node_name : str) -> EnsembleMatcher:
        """Return the first ensemble matcher whose ``matching_fit_nodes`` contains ``node_name``.

        Raises:
            ValueError: if no matcher is fitted for ``node_name``.
        """
        for ensemble_matcher in self.ensemble_matcher_list:
            if node_name in ensemble_matcher.single_structure_dict['matching_fit_nodes']:
                return ensemble_matcher
        raise ValueError(f'No ensemble matcher is fitted for node "{node_name}".')

    def _get_filter_scores_ensemble(self, output: MixDistribution, state: Batch, ensemble_matcher: EnsembleMatcher, node_name : str, candidate_num: int = 50) -> "tuple[torch.Tensor, torch.Tensor]":
        """Score ``candidate_num`` candidate values of ``node_name`` with ``ensemble_matcher``.

        ``candidate_num - 1`` samples are drawn from ``output``, and the mode
        of ``output`` is appended as the last candidate.

        Returns:
            ``(scores, candidates)``, where scores is ``[candidate_num, batch_size]``
            and candidates is ``[candidate_num, batch_size, dim]``.
        """
        state = deepcopy(state)  # candidates are written into a private copy only
        candidates = output.sample((candidate_num-1,))  # [candidate_num - 1, batch_size, dim]
        candidates = torch.cat([candidates, output.mode.unsqueeze(0)], dim=0)  # [candidate_num, batch_size, dim]
        state[node_name] = candidates
        batch_size = candidates.shape[1]
        matching_nodes = ensemble_matcher.matching_nodes
        matcher_input = get_matcher_input(state, node_name, candidate_num, matching_nodes, nodes_fit_dict=None)  # [candidate_num * batch_size, dim]
        scores = ensemble_matcher.run_scores(matcher_input, aggregation="mean", clip=True)
        scores = scores.view(candidate_num, batch_size)
        return scores, candidates

    def _filter_one_step_ensemble(self, output: MixDistribution, state: Batch, ensemble_matcher: EnsembleMatcher, node_name : str, candidate_num: int = 50, **kwargs) -> torch.Tensor:
        """Return, per batch element, the best-scoring candidate of ``node_name`` (``[batch_size, dim]``)."""
        scores, candidates = self._get_filter_scores_ensemble(output, state, ensemble_matcher, node_name, candidate_num)
        indices = scores.argmax(0)
        batch_size = scores.shape[1]
        selected = candidates[indices, torch.arange(batch_size)]  # [batch_size, dim]
        assert len(selected.shape) == 2
        return selected

    def _filter_one_step_vmap(self, output: MixDistribution, state: Batch, ensemble_matcher: None, node_name : str, candidate_num: int = 50) -> torch.Tensor:
        """Same selection as ``_filter_one_step_ensemble``; kept for a (currently disabled) sequential matcher."""
        return self._filter_one_step_ensemble(output, state, ensemble_matcher, node_name, candidate_num)

    def _sample_or_filter(self, output, data, node_name : str, sample_fn, options : dict) -> torch.Tensor:
        """Turn a non-tensor node output into a value.

        When ``options['filter']`` is set, the matcher-filtered candidate is
        returned. Otherwise the result is ``sample_fn(output)``.
        """
        if options.get("filter", False):
            assert "candidate_num" in options, f"filter requires param \'candidate_num\'"
            ensemble_matcher = self._find_ensemble_matcher(node_name)
            return self._filter_one_step_ensemble(output, data, ensemble_matcher, node_name, candidate_num=options['candidate_num'])
        return sample_fn(output)

    def _apply_disturb(self, data : Dict[str, torch.Tensor], node_name : str, disturb_node_list : Dict) -> None:
        """Add the configured disturbance (if any) to ``data[node_name]`` in place."""
        if node_name in disturb_node_list:
            cfg = disturb_node_list[node_name]
            _add_value = self._disturb(data, cfg['network'], cfg['input_nodes'], cfg['rnd_idx'], cfg['disturb_weight'])
            data[node_name] += _add_value

    def _infer_one_step(self, state : Dict[str, np.ndarray], deterministic : bool = True, clip : bool = True, node_name : str = None, **kwargs) -> Dict[str, np.ndarray]:
        """Compute all missing nodes (or only ``node_name``) from raw numpy ``state``.

        Provided values are not recomputed. Inputs are preprocessed, and every
        entry of the result is deprocessed back to numpy.
        """
        self.check_version()
        state = deepcopy(state)
        sample_fn = get_sample_function(deterministic)
        for k in list(state.keys()):
            state[k] = self._data_preprocess(state[k], k)

        # Support inferring nodes based on node name calls
        if node_name is None:
            compute_nodes = self.graph.keys()
        else:
            assert node_name in self.graph.keys()
            assert not node_name in state.keys()
            compute_nodes = [node_name, ]

        for name in compute_nodes:
            if name in state.keys():
                continue  # skip provided values
            output = self.graph.compute_node(name, state)
            if isinstance(output, torch.Tensor):
                state[name] = output
            else:
                if kwargs.get("filter", False):
                    assert self.train_algo == "REVIVE_FILTER"
                state[name] = self._sample_or_filter(output, state, name, sample_fn, kwargs)
            if clip:
                state[name] = torch.clamp(state[name], -1, 1)

        for k in list(state.keys()):
            state[k] = self._data_postprocess(state[k], k)
        return state

    def infer_k_steps(self, states : List[Dict[str, np.ndarray]], deterministic : bool = True, clip : bool = True) -> List[Dict[str, np.ndarray]]:
        """Roll out ``len(states)`` steps, feeding each step's ``graph.state_transition`` into the next.

        Tunable nodes keep the values given in ``states[0]``. Note that the
        dicts in ``states`` are updated in place.
        """
        outputs = []
        backup = {}
        tunable = {name : states[0][name] for name in self.graph.tunable}
        for state in states:
            state.update(backup)
            output = self._infer_one_step(state, deterministic=deterministic, clip=clip)
            outputs.append(output)
            backup = self.graph.state_transition(output)
            backup.update(tunable)
        return outputs

    def infer_one_step(self, state : Dict[str, np.ndarray], deterministic : bool = True, clip : bool = True, **kwargs) -> Dict[str, np.ndarray]:
        return self._infer_one_step(state, deterministic=deterministic, clip=clip, **kwargs)

    def node_infer(self, node_name : str, state : Dict[str, np.ndarray], deterministic : bool = True, clip : bool = True) -> Dict[str, np.ndarray]:
        return self._infer_one_step(state, deterministic=deterministic, clip=clip, node_name=node_name)[node_name]

    def node_dist(self, node_name : str, state : Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Return the raw ``graph.compute_node`` output of ``node_name`` for the preprocessed ``state``.

        The output is not sampled or deprocessed. ``node_name`` must be a graph
        node and must not already be present in ``state``.
        """
        self.check_version()
        state = deepcopy(state)
        for k in list(state.keys()):
            state[k] = self._data_preprocess(state[k], k)
        assert node_name in self.graph.keys()
        assert not node_name in state.keys()
        return self.graph.compute_node(node_name, state)

    def node_pre_computation(self, node_name: str, node_data: np.ndarray):
        self.check_version()
        return self._data_preprocess(node_data, node_name)

    def node_post_computation(self, node_name: str, node_data: np.ndarray):
        self.check_version()
        return self._data_postprocess(node_data, node_name)

    def forward(self, data : Dict[str, torch.Tensor], deterministic : bool = True, clip : bool = False) -> Dict[str, torch.Tensor]:
        ''' run the target node '''
        self.check_version()
        sample_fn = get_sample_function(deterministic)
        node_name = self.target_policy_name
        output = self.graph.compute_node(node_name, data)
        if isinstance(output, torch.Tensor):
            data[node_name] = output
        else:
            data[node_name] = sample_fn(output)
        if clip:
            data[node_name] = torch.clamp(data[node_name], -1, 1)
        return data

    def pre_computation(self, data : Dict[str, torch.Tensor], deterministic : bool = True, clip : bool = False, policy_index : int = 0, disturb_node_list : Optional[Dict] = None, **kwargs) -> Dict[str, torch.Tensor]:
        '''run all the node before target node. skip if the node value is already available.'''
        self.check_version()
        if disturb_node_list is None:
            disturb_node_list = {}
        sample_fn = get_sample_function(deterministic)
        for node_name in list(self.graph.keys())[:self.index[policy_index]]:
            if node_name in data.keys():
                continue  # provided in the inputs
            output = self.graph.compute_node(node_name, data)
            if isinstance(output, torch.Tensor):
                data[node_name] = output
            else:
                data[node_name] = self._sample_or_filter(output, data, node_name, sample_fn, kwargs)
                self._apply_disturb(data, node_name, disturb_node_list)
            if clip:
                data[node_name] = torch.clamp(data[node_name], -1, 1)
        return data

    def filter_penalty(self, penalty_type : str, data : Dict[str, torch.Tensor], sample_num: int, clip : bool = False, policy_index : int = 0, **kwargs) -> torch.Tensor:
        '''Roll out every node after the target node and return the mean matcher-based penalty.

        Penalized nodes are the ``next_*`` network nodes, plus the network
        inputs of non-network ``next_*`` nodes. Each network node is then set
        to its distribution mode.

        Args:
            penalty_type: ``"filter"`` (``1 - best score`` of the mode) or
                ``"filter_score_std"`` (std of the scores over ``sample_num`` candidates).
        '''
        assert self.train_algo == "REVIVE_FILTER", "filter_penalty_score requires REVIVE_FILTER"
        penaltys = []
        after_action_nodes = list(self.graph.keys())[self.index[policy_index]+1:]

        # find which nodes need to compute penaltys
        penalty_nodes = []
        for node_name in after_action_nodes:
            if node_name.startswith("next_"):
                if self.graph.get_node(node_name).node_type == "network":
                    penalty_nodes.append(node_name)
                else:
                    penalty_nodes += find_key_precedings(self.graph, node_name, preceding_type="network")

        # rollout and compute penaltys
        for node_name in after_action_nodes:
            output = self.graph.compute_node(node_name, data)
            if isinstance(output, torch.Tensor):
                data[node_name] = output
            else:  # network node
                if node_name in penalty_nodes:
                    ensemble_matcher = self._find_ensemble_matcher(node_name)
                    if penalty_type == "filter":
                        scores, _ = self._get_filter_scores_ensemble(output, data, ensemble_matcher, node_name, candidate_num=1)  # [candidate_num, batch_size]
                        penalty = 1 - scores.max(0)[0].unsqueeze(-1)
                    elif penalty_type == "filter_score_std":
                        scores, _ = self._get_filter_scores_ensemble(output, data, ensemble_matcher, node_name, candidate_num=sample_num)  # [candidate_num, batch_size]
                        penalty = scores.std(0).unsqueeze(-1)
                    else:
                        raise NotImplementedError(f"penalty type {penalty_type} not supported!")
                    penaltys.append(penalty)
                data[node_name] = output.mode
            if clip:
                data[node_name] = torch.clamp(data[node_name], -1, 1)
        return torch.stack(penaltys).mean(0)

    def post_computation(self, data : Dict[str, torch.Tensor], deterministic : bool = True, clip : bool = False, policy_index : int = 0, disturb_node_list : Optional[Dict] = None, **kwargs) -> Dict[str, torch.Tensor]:
        '''run all the node after target node'''
        self.check_version()
        if disturb_node_list is None:
            disturb_node_list = {}
        sample_fn = get_sample_function(deterministic)
        for node_name in list(self.graph.keys())[self.index[policy_index]+1:]:
            output = self.graph.compute_node(node_name, data)
            if isinstance(output, torch.Tensor):
                data[node_name] = output
            else:
                data[node_name] = self._sample_or_filter(output, data, node_name, sample_fn, kwargs)
                self._apply_disturb(data, node_name, disturb_node_list)
            if clip:
                data[node_name] = torch.clamp(data[node_name], -1, 1)
        return data

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        self.graph.export2onnx(onnx_file, verbose)

    def _disturb(self, data, disturb_net, input_nodes, rnd_idx, disturb_weight):
        """Compute an additive disturbance from ``disturb_net`` applied to the concatenated ``input_nodes``.

        For column ``j``, row ``rnd_idx[j]`` of the network output is selected,
        and entries whose ``rnd_idx`` is negative are zeroed.
        """
        with torch.no_grad():
            input_tensor = torch.cat([data[k2] for k2 in input_nodes], dim=-1)
            if input_tensor.dim() == 3:
                input_tensor = input_tensor.squeeze(dim=0)
            disturb_net.to(input_tensor.device)
            rnd_disturbing = disturb_net(input_tensor) * disturb_weight
            rnd_disturbing = rnd_disturbing[rnd_idx, np.arange(rnd_disturbing.shape[1]), :]
            rnd_disturbing[rnd_idx < 0, ...] = 0
        return rnd_disturbing
class VirtualEnv:
    r"""User-facing virtual environment wrapping a list of ``VirtualEnvDev`` models.

    Inference is delegated to the currently selected model ``_env``, which is
    ``env_list[0]`` initially and can be switched with ``set_env``. ``to``,
    ``reset``, ``set_target_policy_name`` and ``replace_policy`` are applied to
    every model in ``env_list``.
    """

    def __init__(self, env_list : List[VirtualEnvDev]):
        self._env = env_list[0]
        self.env_list = env_list
        self.graph = self._env.graph
        self.revive_version = __version__
        self.device = "cpu"
        self.train_algo = self._env.train_algo

    def __str__(self):
        return f"REVIVE VERSION: {self.revive_version}\n" \
               f"_env: {self._env}"

    def to(self, device):
        r"""
        Move model to the device specified by the parameter.

        Examples::

            >>> venv_model.to("cpu")
            >>> venv_model.to("cuda")
            >>> venv_model.to("cuda:1")
        """
        if device != self.device:
            self.device = device
            for env in self.env_list:
                env.to(device)

    def check_version(self):
        r"""Check if the revive version of the saved model and the current revive version match."""
        if not self.revive_version == __version__:
            warnings.warn(f'detect the venv is create by version {self.revive_version}, but current version is {__version__}, maybe not compactable.')

    def reset(self) -> None:
        r"""
        When using RNN for model training, this method needs to be called before model reuse
        to reset the hidden layer information.
        """
        for env in self.env_list:
            env.reset()

    def set_env(self, env_id) -> None:
        r"""Select ``env_list[env_id]`` as the model used for inference."""
        assert env_id < len(self.env_list)
        self._env = self.env_list[env_id]

    @property
    def target_policy_name(self) -> str:
        r''' Get the target policy name. '''
        return self._env.target_policy_name

    def set_target_policy_name(self, target_policy_name) -> None:
        r''' Set the target policy name. '''
        for env in self.env_list:
            env.set_target_policy_name(target_policy_name)

    def replace_policy(self, policy : 'PolicyModel') -> None:
        ''' Replace the target policy with the given policy. '''
        assert self.target_policy_name == policy.target_policy_name, \
            f'policy name does not match, require {self.target_policy_name} but get {policy.target_policy_name}!'
        for env in self.env_list:
            env.graph.nodes[self.target_policy_name] = policy._policy_model.node

    @torch.no_grad()
    def infer_k_steps(self,
                      states : Union[Dict[str, np.ndarray], List[Dict[str, np.ndarray]]],
                      k : Optional[int] = None,
                      deterministic : bool = True,
                      clip : bool = True) -> List[Dict[str, np.ndarray]]:
        r"""
        Generate k steps interactive data.

        Args:
            :states: a dict of initial input nodes, or a list of per-step input dicts

            :k: how many steps to generate (only used when ``states`` is a single dict)

            :deterministic:
                if True, the most likely actions are generated;
                if False, actions are generated by sample.
                Default: True

            :clip: if True, computed node values are clamped to [-1, 1] before deprocessing.
                Default: True

        Return:
            k steps interactive data dict

        Examples::

            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> ten_step_output = venv_model.infer_k_steps(state, k=10)
        """
        self.check_version()
        if isinstance(states, dict):
            states = [states]
            if k is not None:
                states.extend({} for _ in range(k - 1))
        return self._env.infer_k_steps(deepcopy(states), deterministic, clip=clip)

    @torch.no_grad()
    def infer_one_step(self,
                       state : Dict[str, np.ndarray],
                       deterministic : bool = True,
                       clip : bool = True,
                       **kwargs) -> Dict[str, np.ndarray]:
        r"""
        Generate one step interactive data given action.

        Args:
            :state: a dict of input nodes

            :deterministic:
                if True, the most likely actions are generated;
                if False, actions are generated by sample.
                Default: True

            :clip: if True, computed node values are clamped to [-1, 1] before deprocessing.
                Default: True

        Return:
            one step outputs

        Examples::

            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> one_step_output = venv_model.infer_one_step(state)
        """
        self.check_version()
        return self._env.infer_one_step(deepcopy(state), deterministic, clip=clip, **kwargs)

    def reset_ensemble_matcher(self, ensemble_size: int=10, ensemble_choosing_interval: int=10):
        r"""Rebuild the ensemble matchers of the currently selected model only."""
        self._env.reset_ensemble_matcher(ensemble_size, ensemble_choosing_interval)

    def node_pre_computation(self, node_name: str, node_data: np.ndarray):
        r"""Process raw ``node_data`` of ``node_name`` with the graph processor and return a tensor."""
        self.check_version()
        return self._env.node_pre_computation(node_name, node_data)

    def node_post_computation(self, node_name: str, node_data: np.ndarray):
        r"""Deprocess ``node_data`` of ``node_name`` with the graph processor and return it as numpy."""
        self.check_version()
        # Previously dispatched to node_pre_computation, which processed the data a second time.
        return self._env.node_post_computation(node_name, node_data)

    @torch.no_grad()
    def node_infer(self,
                   node_name: str,
                   state : Dict[str, np.ndarray],
                   deterministic : bool = True,
                   clip : bool = True) -> Dict[str, np.ndarray]:
        r"""
        Generate one step interactive data given node_name.

        Args:
            :state: a dict of input nodes

            :deterministic:
                if True, the most likely actions are generated;
                if False, actions are generated by sample.
                Default: True

        Return:
            one step node output

        Examples::

            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> action_output = venv_model.node_infer("action", state)
        """
        self.check_version()
        return self._env.node_infer(node_name, deepcopy(state), deterministic, clip=clip)

    @torch.no_grad()
    def node_dist(self,
                  node_name: str,
                  state : Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        r"""
        Compute the output distribution of ``node_name`` for one step.

        Args:
            :node_name: the node whose output is computed; it must not be present in ``state``

            :state: a dict of input nodes

        Return:
            the raw (not sampled, not deprocessed) output of the node

        Examples::

            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> action_dist = venv_model.node_dist("action", state)
        """
        self.check_version()
        return self._env.node_dist(node_name, deepcopy(state))

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        r"""
        Exporting the model to onnx mode.

        Reference: https://pytorch.org/docs/stable/onnx.html

        Args:
            :onnx_file: the onnx model file save path.
            :verbose: if True, prints a description of the model being exported to stdout.
                In addition, the final ONNX graph will include the field ``doc_string```
                from the exported model which mentions the source code locations for ``model``.
        """
        self._env.export2onnx(onnx_file, verbose)
class PolicyModelDev(torch.nn.Module):
    """Development-side policy model built on the first of ``nodes``.

    ``nodes[0]`` is the target policy node used by ``infer``/``export2onnx``;
    ``reset`` is applied to every node.
    """

    def __init__(self, nodes : List[DesicionNode]):
        super().__init__()
        self.nodes = nodes
        self.node = self.nodes[0]
        self.models = [node.get_network() for node in self.nodes]
        # The model is identified by its first (target) node.
        self.target_policy_name = self.node.name
        self.revive_version = __version__
        self.device = "cpu"
        self.reset()

    def to(self, device):
        """Move the target policy node to ``device``."""
        if device != self.device:
            self.device = device
            self.node.to(self.device)

    def check_version(self):
        """Warn if this policy was created with a different REVIVE version."""
        if not self.revive_version == __version__:
            warnings.warn(f'detect the policy is create by version {self.revive_version}, but current version is {__version__}, maybe not compactable.')

    def reset(self):
        """Call ``reset()`` on every node."""
        for node in self.nodes:
            node.reset()

    def _data_preprocess(self, data : np.ndarray, data_key : str = "obs") -> torch.Tensor:
        """Process raw ``data`` of ``data_key`` with the node's processor and convert it to a tensor on ``self.device``."""
        data = self.node.processor.process_single(data, data_key)
        data = to_torch(data, device=self.device)
        return data

    def _data_postprocess(self, data : torch.Tensor, data_key : str = "action1") -> np.ndarray:
        """Convert ``data`` to numpy and deprocess it with the node's processor for ``data_key``."""
        data = to_numpy(data)
        data = self.node.processor.deprocess_single(data, data_key)
        return data

    def infer(self, state : Dict[str, np.ndarray], deterministic : bool = True, clip : bool = False) -> np.ndarray:
        """Compute the target node's output from raw ``state``, returned deprocessed as numpy.

        Non-tensor node outputs are turned into values with
        ``get_sample_function(deterministic)``. With ``clip``, the value is
        clamped to [-1, 1] before deprocessing.
        """
        self.check_version()
        state = deepcopy(state)
        sample_fn = get_sample_function(deterministic)
        for k, v in state.items():
            state[k] = self._data_preprocess(v, data_key=k)
        output = self.node(state)
        if isinstance(output, torch.Tensor):
            action = output
        else:
            action = sample_fn(output)
        if clip:
            action = torch.clamp(action, -1, 1)
        action = self._data_postprocess(action, self.target_policy_name)
        return action

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        self.node.export2onnx(onnx_file, verbose)
class PolicyModel:
    r"""User-facing policy wrapping a ``PolicyModelDev``, with an optional ``post_process`` hook.

    ``post_process(state, action)`` is called on the action returned by
    ``infer`` when it is not ``None``.
    """

    def __init__(self, policy_model_dev : PolicyModelDev, post_process : Optional[Callable[[Dict[str, np.ndarray], np.ndarray], np.ndarray]] = None):
        self._policy_model = policy_model_dev
        self.post_process = post_process
        self.revive_version = __version__
        self.device = "cpu"

    def to(self, device: str):
        r"""
        Move model to the device specified by the parameter.

        Examples::

            >>> policy_model.to("cpu")
            >>> policy_model.to("cuda")
            >>> policy_model.to("cuda:1")
        """
        if device != self.device:
            self.device = device
            self._policy_model.to(self.device)

    def check_version(self):
        r"""Check if the revive version of the saved model and the current revive version match."""
        if not self.revive_version == __version__:
            warnings.warn(f'detect the policy is create by version {self.revive_version}, but current version is {__version__}, maybe not compactable.')

    def reset(self):
        r"""
        When using RNN for model training, this method needs to be called before model reuse
        to reset the hidden layer information.
        """
        self._policy_model.reset()

    @property
    def target_policy_name(self) -> str:
        r''' Get the target policy name. '''
        return self._policy_model.target_policy_name

    @torch.no_grad()
    def infer(self,
              state : Dict[str, np.ndarray],
              deterministic : bool = True,
              clip : bool = True,
              additional_info : Optional[Dict[str, np.ndarray]] = None) -> np.ndarray:
        r"""
        Generate action according policy.

        Args:
            :state: a dict contain *ALL* the input nodes of the policy node

            :deterministic:
                if True, the most likely actions are generated;
                if False, actions are generated by sample.
                Default: True

            :clip:
                if True, the action is clamped to [-1, 1] before deprocessing;
                if False, it is left unclamped.
                Default: True

            :additional_info: a dict of additional info for post process.
                It is merged into a copy of ``state``; the caller's ``state`` is not modified.

        Return:
            action

        Examples::

            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> action = policy_model.infer(state)
        """
        self.check_version()
        action = self._policy_model.infer(deepcopy(state), deterministic, clip=clip)
        if self.post_process is not None:
            # Copy so the caller's dict is untouched; additional_info may legitimately be None.
            post_state = dict(state)
            if additional_info is not None:
                post_state.update(additional_info)
            action = self.post_process(post_state, action)
        return action

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        r"""
        Exporting the model to onnx mode.

        Reference: https://pytorch.org/docs/stable/onnx.html

        Args:
            :onnx_file: the onnx model file save path.
            :verbose: if True, prints a description of the model being exported to stdout.
                In addition, the final ONNX graph will include the field ``doc_string```
                from the exported model which mentions the source code locations for ``model``.
        """
        if self.post_process is not None:
            warnings.warn('Currently, post process will not be exported.')
        self._policy_model.export2onnx(onnx_file, verbose)