''''''
"""
POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
from collections import deque
import time
import torch
import warnings
from loguru import logger
import numpy as np
from copy import deepcopy
from typing import Callable, Dict, List, Union, Optional
from revive import __version__
from revive.computation.graph import DesicionGraph, DesicionNode
from revive.computation.utils import *
from revive.computation.modules import EnsembleMatcher#, SequentialMatcher
from revive.computation.dists import MixDistribution
import time
def find_key_precedings(graph : DesicionGraph,
                        node_name : str,
                        preceding_type : str="network") -> List:
    """Return the inputs of ``node_name`` that are graph nodes of ``preceding_type``.

    Input names that are not keys of ``graph`` are ignored. The result keeps the
    order of ``graph.get_node(node_name).input_names``.
    """
    matches = []
    for input_name in graph.get_node(node_name).input_names:
        if input_name not in graph.keys():
            continue  # not a node of this graph
        if graph.get_node(input_name).node_type == preceding_type:
            matches.append(input_name)
    return matches
class VirtualEnvDev(torch.nn.Module):
    """Development-side virtual environment that runs a ``DesicionGraph`` node by node.

    Every ``network`` node of ``graph`` is registered in ``self.models`` (a
    ``torch.nn.ModuleList``).

    Two groups of public helpers are provided:

    - Helpers that take raw numpy data and handle pre/post-processing themselves:
      ``infer_one_step``, ``infer_k_steps``, ``node_infer`` and ``node_dist``.
    - Helpers that work on already processed tensors: ``pre_computation``,
      ``post_computation``, ``filter_penalty`` and ``forward``.

    When ``train_algo == "REVIVE_FILTER"``, sampled node outputs can additionally
    be filtered through ensemble matchers.
    """

    def __init__(self,
                 graph : DesicionGraph,
                 train_algo : str = None,
                 info: Optional[dict] = None,
                 **kwargs) -> None:
        """
        Args:
            graph: the decision graph to run.
            train_algo: training algorithm name. ``"REVIVE_FILTER"`` loads the
                matcher-based filter components from ``kwargs``.
            info: free-form model information shown by ``__str__``. A new empty
                dict is used when omitted.
            **kwargs: forwarded to ``load_filter_components`` for
                ``"REVIVE_FILTER"``.
        """
        super(VirtualEnvDev, self).__init__()
        self.models = torch.nn.ModuleList()
        self.graph = graph
        for node in self.graph.nodes.values():
            if node.node_type == 'network':
                self.models.append(node.get_network())
        self.set_target_policy_name(list(self.graph.keys()))  # default: every node
        self.revive_version = __version__
        self.device = "cpu"
        self.train_algo = train_algo
        # A shared ``{}`` default would be aliased by every instance.
        self.info = {} if info is None else info
        if self.train_algo == "REVIVE_FILTER":
            self.load_filter_components(**kwargs)
        self.reset()

    def __str__(self):
        return f"Model Info: {self.info}\n" + super().__str__()

    def to(self, device):
        """Move every network node, and the ensemble matchers if loaded, to ``device``.

        A node that fails to move is logged and skipped, so the remaining nodes
        are still moved. Returns ``None``, unlike ``torch.nn.Module.to``.
        """
        if device == self.device:
            return
        self.device = device
        for node in self.graph.nodes.values():
            if node.node_type != 'network':
                continue
            try:
                node.to(self.device)
            except Exception:
                # Best effort: keep moving the other nodes.
                logger.warning(f'Failed to move device the network of "{node.name}"')
        for ensemble_matcher in getattr(self, "ensemble_matcher_list", []):
            ensemble_matcher.to(self.device)

    def check_version(self):
        """Warn when this model was created by a different revive version."""
        if not self.revive_version == __version__:
            warnings.warn(f'detect the venv is create by version {self.revive_version}, but current version is {__version__}, maybe not compactable.')

    def reset(self) -> None:
        """Reset the underlying graph via ``self.graph.reset()``."""
        self.graph.reset()

    def set_target_policy_name(self, target_policy_name : list) -> None:
        """Set the target policy node names and recompute their graph positions.

        ``self.index`` holds the sorted positions, in ``self.graph.items()``
        order, of every node whose name is in ``target_policy_name``.
        ``pre_computation``, ``post_computation`` and ``filter_penalty`` use it
        to split the graph around a policy node.
        """
        self.target_policy_name = target_policy_name
        self.index = []
        for i, (output_name, _) in enumerate(self.graph.items()):
            if output_name in self.target_policy_name:
                self.index.append(i)
        self.index.sort()

    def load_filter_components(self, **kwargs):
        """Store the REVIVE_FILTER matcher records/configs and build the ensemble matchers."""
        if "matcher_record_list" in kwargs:
            self.matcher_record_list = kwargs['matcher_record_list']
            assert "matcher_structure_list" in kwargs
            self.matcher_structure_list = kwargs['matcher_structure_list']
            self.init_ensemble_matcher()
        if "matching_nodes_fit_index_list" in kwargs:
            self.matching_nodes_fit_index_list = kwargs['matching_nodes_fit_index_list']

    def _data_preprocess(self, data : np.ndarray, data_key : str = "obs") -> torch.Tensor:
        """Process ``data`` with the graph processor for ``data_key`` and move it to ``self.device``."""
        data = self.graph.processor.process_single(data, data_key)
        data = to_torch(data, device=self.device)
        return data

    def _data_postprocess(self, data : torch.Tensor, data_key : str) -> np.ndarray:
        """Convert ``data`` to numpy and undo the graph processing for ``data_key``."""
        data = to_numpy(data)
        data = self.graph.processor.deprocess_single(data, data_key)
        return data

    def init_ensemble_matcher(self,
                              ensemble_size: int=1,
                              ensemble_choosing_interval: int=1):
        """Build the ensemble matchers unless they already exist."""
        if hasattr(self, "ensemble_matcher_list"):
            return
        self.reset_ensemble_matcher(ensemble_size, ensemble_choosing_interval)

    def reset_ensemble_matcher(self,
                               ensemble_size: int,
                               ensemble_choosing_interval: int):
        """(Re)build one ``EnsembleMatcher`` per (record, structure) pair on ``self.device``."""
        self.ensemble_matcher_list = [
            EnsembleMatcher(matcher_record,
                            ensemble_size=ensemble_size,
                            ensemble_choosing_interval=ensemble_choosing_interval,
                            config=matcher_config)
            for matcher_record, matcher_config in zip(self.matcher_record_list, self.matcher_structure_list)
        ]
        for ensemble_matcher in self.ensemble_matcher_list:
            ensemble_matcher.to(self.device)

    def _get_filter_scores_ensemble(self,
                                    output: MixDistribution,
                                    state: Batch,
                                    ensemble_matcher: EnsembleMatcher,
                                    node_name : str,
                                    candidate_num: int = 50) -> tuple:
        """Score ``candidate_num`` candidate values of ``node_name`` with ``ensemble_matcher``.

        ``candidate_num - 1`` candidates are sampled from ``output``, and the
        distribution mode is appended as the last candidate.

        Returns:
            ``(scores, candidates)``. ``scores`` is viewed as
            ``[candidate_num, batch_size]``. ``candidates`` has ``candidate_num``
            as its first dim and ``batch_size`` as its second.
        """
        state = deepcopy(state)  # keep the candidates out of the caller's state
        candidates = output.sample((candidate_num - 1,))
        candidates = torch.cat([candidates, output.mode.unsqueeze(0)], dim=0)
        state[node_name] = candidates
        batch_size = candidates.shape[1]
        matcher_input = get_matcher_input(state, node_name, candidate_num,
                                          ensemble_matcher.matching_nodes, nodes_fit_dict=None)
        scores = ensemble_matcher.run_scores(matcher_input, aggregation="mean", clip=True)
        scores = scores.view(candidate_num, batch_size)
        return scores, candidates

    def _filter_one_step_ensemble(self,
                                  output: MixDistribution,
                                  state: Batch,
                                  ensemble_matcher: EnsembleMatcher,
                                  node_name : str,
                                  candidate_num: int = 50,
                                  **kwargs) -> torch.Tensor:
        """Return, per batch element, the candidate with the highest matcher score."""
        scores, candidates = self._get_filter_scores_ensemble(output, state, ensemble_matcher, node_name, candidate_num)
        best = scores.argmax(0)
        batch_index = torch.arange(scores.shape[1], device=best.device)
        selected = candidates[best, batch_index]
        assert len(selected.shape) == 2
        return selected

    def _filter_one_step_vmap(self,
                              output: MixDistribution,
                              state: Batch,
                              ensemble_matcher: None,  # SequentialMatcher
                              node_name : str,
                              candidate_num: int = 50) -> torch.Tensor:
        """Alias of ``_filter_one_step_ensemble``; the computation is identical."""
        return self._filter_one_step_ensemble(output, state, ensemble_matcher, node_name, candidate_num)

    def _find_ensemble_matcher(self, node_name: str) -> EnsembleMatcher:
        """Return the first ensemble matcher whose ``matching_fit_nodes`` contains ``node_name``.

        Raises:
            ValueError: if no matcher was fit for ``node_name``. Silently reusing
                another node's matcher would filter with the wrong model.
        """
        for ensemble_matcher in self.ensemble_matcher_list:
            if node_name in ensemble_matcher.single_structure_dict['matching_fit_nodes']:
                return ensemble_matcher
        raise ValueError(f'No ensemble matcher is fit for node "{node_name}".')

    def _node_output_to_tensor(self,
                               node_name: str,
                               output,
                               data,
                               sample_fn: Callable,
                               kwargs: dict,
                               require_filter_algo: bool = False) -> torch.Tensor:
        """Turn a ``compute_node`` result into a tensor.

        Tensors are returned unchanged. For a distribution output:

        - With ``kwargs['filter'] == True``, the value is chosen by the node's
          ensemble matcher among ``kwargs['candidate_num']`` candidates.
        - Otherwise the value is drawn with ``sample_fn``.
        """
        if isinstance(output, torch.Tensor):
            return output
        if kwargs.get("filter", False) == True:  # noqa: E712 -- keep equality semantics
            if require_filter_algo:
                assert self.train_algo == "REVIVE_FILTER"
            assert "candidate_num" in kwargs, "filter requires param 'candidate_num'"
            ensemble_matcher = self._find_ensemble_matcher(node_name)
            return self._filter_one_step_ensemble(output, data, ensemble_matcher, node_name,
                                                  candidate_num=kwargs['candidate_num'])
        return sample_fn(output)

    def _apply_disturb(self, data, node_name: str, disturb_node_list: Dict) -> None:
        """Add the configured disturbance (if any) to ``data[node_name]`` in place."""
        if node_name not in disturb_node_list:
            return
        cfg = disturb_node_list[node_name]
        data[node_name] += self._disturb(data,
                                         cfg['network'],
                                         cfg['input_nodes'],
                                         cfg['rnd_idx'],
                                         cfg['disturb_weight'])

    def _infer_one_step(self,
                        state : Dict[str, np.ndarray],
                        deterministic : bool = True,
                        clip : bool = True,
                        node_name : str = None,
                        **kwargs) -> Dict[str, np.ndarray]:
        """Run the whole graph (or only ``node_name``) on raw numpy ``state``.

        Provided values are never recomputed. Computed values are clamped to
        ``[-1, 1]`` (in processed space) when ``clip`` is set. Every entry is
        post-processed back to numpy before returning a new dict.
        """
        self.check_version()
        state = deepcopy(state)
        sample_fn = get_sample_function(deterministic)
        for k in list(state.keys()):
            state[k] = self._data_preprocess(state[k], k)
        # Support inferring a single node by name.
        if node_name is None:
            compute_nodes = self.graph.keys()
        else:
            assert node_name in self.graph.keys()
            assert node_name not in state.keys()
            compute_nodes = [node_name]
        for name in compute_nodes:
            if name in state.keys():  # skip provided values
                continue
            output = self.graph.compute_node(name, state)
            state[name] = self._node_output_to_tensor(name, output, state, sample_fn, kwargs,
                                                      require_filter_algo=True)
            if clip:
                state[name] = torch.clamp(state[name], -1, 1)
        for k in list(state.keys()):
            state[k] = self._data_postprocess(state[k], k)
        return state

    def infer_k_steps(self,
                      states : List[Dict[str, np.ndarray]],
                      deterministic : bool = True,
                      clip : bool = True) -> List[Dict[str, np.ndarray]]:
        """Roll the graph forward ``len(states)`` steps.

        Each step's input dict is updated with the previous step's
        ``graph.state_transition`` output and with the tunable values taken from
        ``states[0]``. Note that the dicts in ``states`` are updated in place.
        """
        outputs = []
        backup = {}
        tunable = {name : states[0][name] for name in self.graph.tunable}
        for state in states:
            state.update(backup)
            output = self._infer_one_step(state, deterministic=deterministic, clip=clip)
            outputs.append(output)
            backup = self.graph.state_transition(output)
            backup.update(tunable)
        return outputs

    def infer_one_step(self,
                       state : Dict[str, np.ndarray],
                       deterministic : bool = True,
                       clip : bool = True,
                       **kwargs) -> Dict[str, np.ndarray]:
        """Run the whole graph once on raw numpy ``state``; see ``_infer_one_step``."""
        return self._infer_one_step(state, deterministic=deterministic, clip=clip, **kwargs)

    def node_infer(self,
                   node_name : str,
                   state : Dict[str, np.ndarray],
                   deterministic : bool = True,
                   clip : bool = True) -> Dict[str, np.ndarray]:
        """Compute only ``node_name`` from raw numpy ``state`` and return its numpy value."""
        return self._infer_one_step(state, deterministic=deterministic, clip=clip, node_name=node_name)[node_name]

    def node_dist(self,
                  node_name : str,
                  state : Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Return the raw ``compute_node`` output of ``node_name`` for raw numpy ``state``.

        Inputs are pre-processed, but the result is not sampled or
        post-processed.
        """
        self.check_version()
        state = deepcopy(state)
        for k in list(state.keys()):
            state[k] = self._data_preprocess(state[k], k)
        assert node_name in self.graph.keys()
        assert node_name not in state.keys()
        return self.graph.compute_node(node_name, state)

    def node_pre_computation(self, node_name: str, node_data: np.ndarray):
        """Pre-process ``node_data`` as node ``node_name``."""
        self.check_version()
        return self._data_preprocess(node_data, node_name)

    def node_post_computation(self, node_name: str, node_data: np.ndarray):
        """Post-process ``node_data`` as node ``node_name``."""
        self.check_version()
        return self._data_postprocess(node_data, node_name)

    def forward(self,
                data : Dict[str, torch.Tensor],
                deterministic : bool = True,
                clip : bool = False) -> Dict[str, torch.Tensor]:
        ''' run the target node '''
        # NOTE(review): ``self.target_policy_name`` defaults to a list of node
        # names (see ``__init__``). This method presumably expects it to have
        # been set to a single node name -- confirm against callers.
        self.check_version()
        sample_fn = get_sample_function(deterministic)
        node_name = self.target_policy_name
        output = self.graph.compute_node(node_name, data)
        if isinstance(output, torch.Tensor):
            data[node_name] = output
        else:
            data[node_name] = sample_fn(output)
        if clip:
            data[node_name] = torch.clamp(data[node_name], -1, 1)
        return data

    def pre_computation(self,
                        data : Dict[str, torch.Tensor],
                        deterministic : bool = True,
                        clip : bool = False,
                        policy_index : int = 0,
                        disturb_node_list : Optional[Dict] = None,
                        **kwargs) -> Dict[str, torch.Tensor]:
        '''run all the node before target node. skip if the node value is already available.

        ``data`` is updated in place and returned. Distribution outputs are
        sampled, or filtered when ``filter=True`` and ``candidate_num`` are given
        in ``kwargs``. Each value is then disturbed if listed in
        ``disturb_node_list`` and optionally clamped to ``[-1, 1]``.
        '''
        self.check_version()
        sample_fn = get_sample_function(deterministic)
        disturb_node_list = {} if disturb_node_list is None else disturb_node_list
        for node_name in list(self.graph.keys())[:self.index[policy_index]]:
            if node_name in data.keys():
                continue  # provided by the caller
            output = self.graph.compute_node(node_name, data)
            data[node_name] = self._node_output_to_tensor(node_name, output, data, sample_fn, kwargs)
            self._apply_disturb(data, node_name, disturb_node_list)
            if clip:
                data[node_name] = torch.clamp(data[node_name], -1, 1)
        return data

    def filter_penalty(self,
                       penalty_type : str,
                       data : Dict[str, torch.Tensor],
                       sample_num: int,
                       clip : bool = False,
                       policy_index : int = 0,
                       **kwargs) -> Dict[str, torch.Tensor]:
        '''run all the node after target node and return the mean matcher penalty.

        Penalised nodes are:

        - every network node named ``next_*``;
        - the network inputs of every non-network ``next_*`` node.

        Penalty types:

        - ``"filter"``: ``1 - score`` of the distribution mode.
        - ``"filter_score_std"``: std of the scores of ``sample_num`` candidates.

        Network outputs are replaced by their mode while rolling out.

        Raises:
            NotImplementedError: for an unknown ``penalty_type``.
            RuntimeError: when no node after the target policy is penalised.
        '''
        assert self.train_algo == "REVIVE_FILTER", "filter_penalty_score requires REVIVE_FILTER"
        after_action_nodes = list(self.graph.keys())[self.index[policy_index] + 1:]
        # Find which nodes need a penalty.
        penalty_nodes = []
        for node_name in after_action_nodes:
            if not node_name.startswith("next_"):
                continue
            if self.graph.get_node(node_name).node_type == "network":
                penalty_nodes.append(node_name)
            else:
                penalty_nodes += find_key_precedings(self.graph, node_name, preceding_type="network")
        # Roll out and compute the penalties.
        penalties = []
        for node_name in after_action_nodes:
            output = self.graph.compute_node(node_name, data)
            if isinstance(output, torch.Tensor):
                data[node_name] = output
            else:  # network node
                if node_name in penalty_nodes:
                    ensemble_matcher = self._find_ensemble_matcher(node_name)
                    if penalty_type == "filter":
                        scores, _ = self._get_filter_scores_ensemble(output, data, ensemble_matcher, node_name, candidate_num=1)
                        penalty = 1 - scores.max(0)[0].unsqueeze(-1)
                    elif penalty_type == "filter_score_std":
                        scores, _ = self._get_filter_scores_ensemble(output, data, ensemble_matcher, node_name, candidate_num=sample_num)
                        penalty = scores.std(0).unsqueeze(-1)
                    else:
                        raise NotImplementedError(f"penalty type {penalty_type} not supported!")
                    penalties.append(penalty)
                data[node_name] = output.mode
            if clip:
                data[node_name] = torch.clamp(data[node_name], -1, 1)
        if not penalties:
            raise RuntimeError("filter_penalty found no penalised node after the target policy.")
        return torch.stack(penalties).mean(0)

    def post_computation(self,
                         data : Dict[str, torch.Tensor],
                         deterministic : bool = True,
                         clip : bool = False,
                         policy_index : int = 0,
                         disturb_node_list : Optional[Dict] = None,
                         **kwargs) -> Dict[str, torch.Tensor]:
        '''run all the node after target node.

        Every node after ``self.index[policy_index]`` is recomputed, even if
        present in ``data``. Sampling, filtering, disturbing and clipping work as
        in ``pre_computation``. ``data`` is updated in place and returned.
        '''
        self.check_version()
        sample_fn = get_sample_function(deterministic)
        disturb_node_list = {} if disturb_node_list is None else disturb_node_list
        for node_name in list(self.graph.keys())[self.index[policy_index] + 1:]:
            output = self.graph.compute_node(node_name, data)
            data[node_name] = self._node_output_to_tensor(node_name, output, data, sample_fn, kwargs)
            self._apply_disturb(data, node_name, disturb_node_list)
            if clip:
                data[node_name] = torch.clamp(data[node_name], -1, 1)
        return data

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        """Export the graph to ONNX via ``self.graph.export2onnx``."""
        self.graph.export2onnx(onnx_file, verbose)

    def _disturb(self, data, disturb_net, input_nodes, rnd_idx, disturb_weight):
        """Compute an additive disturbance from ``disturb_net`` (under ``torch.no_grad``).

        Steps:

        1. Concatenate ``data[k]`` for ``k`` in ``input_nodes`` on the last dim.
           For 3-D inputs, a leading dim of size 1 is squeezed away.
        2. Feed the result to ``disturb_net`` and scale by ``disturb_weight``.
        3. For each position of dim 1, pick index ``rnd_idx`` along dim 0. This
           assumes ``rnd_idx`` has one entry per position -- TODO confirm.
        4. Zero the rows where ``rnd_idx < 0``.
        """
        with torch.no_grad():
            input_tensor = torch.cat([data[k2] for k2 in input_nodes], dim=-1)
            if input_tensor.dim() == 3:
                input_tensor = input_tensor.squeeze(dim=0)
            disturb_net.to(input_tensor.device)
            rnd_disturbing = disturb_net(input_tensor) * disturb_weight
            rnd_disturbing = rnd_disturbing[rnd_idx, np.arange(rnd_disturbing.shape[1]), :]
            rnd_disturbing[rnd_idx < 0, ...] = 0
            return rnd_disturbing
class VirtualEnv:
    """User-facing virtual environment wrapping one or more ``VirtualEnvDev`` models.

    Inference calls go to the currently selected env (``self._env``). This is
    the first one by default; see ``set_env``. Device moves, resets and target
    policy changes are applied to every env in ``env_list``.
    """

    def __init__(self, env_list : List[VirtualEnvDev]):
        self._env = env_list[0]
        self.env_list = env_list
        # NOTE(review): ``graph`` and ``train_algo`` come from the first env and
        # are not refreshed by ``set_env``.
        self.graph = self._env.graph
        self.revive_version = __version__
        self.device = "cpu"
        self.train_algo = self._env.train_algo

    def __str__(self):
        return f"REVIVE VERSION: {self.revive_version}\n" \
               f"_env: {self._env}"

    def to(self, device):
        r"""
        Move model to the device specified by the parameter.

        Examples::

            >>> venv_model.to("cpu")
            >>> venv_model.to("cuda")
            >>> venv_model.to("cuda:1")
        """
        if device != self.device:
            self.device = device
            for env in self.env_list:
                env.to(device)

    def check_version(self):
        r"""Check if the revive version of the saved model and the current revive version match."""
        if not self.revive_version == __version__:
            warnings.warn(f'detect the venv is create by version {self.revive_version}, but current version is {__version__}, maybe not compactable.')

    def reset(self) -> None:
        r"""
        Reset every env in ``env_list``.

        When using an RNN for model training, call this before reusing the model
        to reset the hidden-layer information.
        """
        for env in self.env_list:
            env.reset()

    def set_env(self, env_id) -> None:
        r"""Select ``env_list[env_id]`` as the env used for inference."""
        assert env_id < len(self.env_list)
        self._env = self.env_list[env_id]

    @property
    def target_policy_name(self) -> str:
        r''' Get the target policy name. '''
        return self._env.target_policy_name

    def set_target_policy_name(self, target_policy_name) -> None:
        r''' Set the target policy name on every env. '''
        for env in self.env_list:
            env.set_target_policy_name(target_policy_name)

    def replace_policy(self, policy : 'PolicyModel') -> None:
        ''' Replace the target policy with the given policy in every env. '''
        # NOTE(review): the lookup below uses ``target_policy_name`` as a single
        # node name; VirtualEnvDev defaults it to a list -- confirm the caller
        # sets a single name first.
        assert self.target_policy_name == policy.target_policy_name, \
            f'policy name does not match, require {self.target_policy_name} but get {policy.target_policy_name}!'
        for env in self.env_list:
            env.graph.nodes[self.target_policy_name] = policy._policy_model.node

    @torch.no_grad()
    def infer_k_steps(self,
                      states : Union[Dict[str, np.ndarray], List[Dict[str, np.ndarray]]],
                      k : Optional[int] = None,
                      deterministic : bool = True,
                      clip : bool = True) -> List[Dict[str, np.ndarray]]:
        r"""
        Generate k steps interactive data.

        Args:
            :states: a dict of initial input nodes, or a list of per-step input dicts
            :k: how many steps to generate (only used when ``states`` is a dict)
            :deterministic:
                if True, the most likely actions are generated;
                if False, actions are generated by sample.
                Default: True
            :clip:
                if True, computed node values are clamped to [-1, 1] before
                post-processing.
                Default: True

        Return:
            k steps interactive data dict

        Examples::

            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> ten_step_output = venv_model.infer_k_steps(state, k=10)
        """
        self.check_version()
        if isinstance(states, dict):
            states = [states]
            if k is not None:
                # Later steps start empty and are filled from the previous
                # step's state transition.
                states.extend({} for _ in range(k - 1))
        return self._env.infer_k_steps(deepcopy(states), deterministic, clip=clip)

    @torch.no_grad()
    def infer_one_step(self,
                       state : Dict[str, np.ndarray],
                       deterministic : bool = True,
                       clip : bool = True,
                       **kwargs) -> Dict[str, np.ndarray]:
        r"""
        Generate one step interactive data given action.

        Args:
            :state: a dict of input nodes
            :deterministic:
                if True, the most likely actions are generated;
                if False, actions are generated by sample.
                Default: True
            :clip:
                if True, computed node values are clamped to [-1, 1] before
                post-processing.
                Default: True

        Return:
            one step outputs

        Examples::

            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> one_step_output = venv_model.infer_one_step(state)
        """
        self.check_version()
        return self._env.infer_one_step(deepcopy(state), deterministic, clip=clip, **kwargs)

    def reset_ensemble_matcher(self,
                               ensemble_size: int=10,
                               ensemble_choosing_interval: int=10):
        r"""Rebuild the ensemble matchers of the currently selected env only."""
        self._env.reset_ensemble_matcher(ensemble_size, ensemble_choosing_interval)

    def node_pre_computation(self,
                             node_name: str,
                             node_data: np.ndarray):
        r"""Pre-process ``node_data`` as node ``node_name``."""
        self.check_version()
        return self._env.node_pre_computation(node_name, node_data)

    def node_post_computation(self,
                              node_name: str,
                              node_data: np.ndarray):
        r"""Post-process ``node_data`` as node ``node_name`` (the inverse of ``node_pre_computation``)."""
        self.check_version()
        return self._env.node_post_computation(node_name, node_data)

    @torch.no_grad()
    def node_infer(self,
                   node_name: str,
                   state : Dict[str, np.ndarray],
                   deterministic : bool = True,
                   clip : bool = True) -> Dict[str, np.ndarray]:
        r"""
        Generate one step interactive data given node_name.

        Args:
            :node_name: the node to compute
            :state: a dict of input nodes
            :deterministic:
                if True, the most likely actions are generated;
                if False, actions are generated by sample.
                Default: True
            :clip:
                if True, the computed value is clamped to [-1, 1] before
                post-processing.
                Default: True

        Return:
            one step node output

        Examples::

            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> action_output = venv_model.node_infer("action", state)
        """
        self.check_version()
        return self._env.node_infer(node_name, deepcopy(state), deterministic, clip=clip)

    @torch.no_grad()
    def node_dist(self,
                  node_name: str,
                  state : Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        r"""
        Generate one step interactive dist given node_name.

        Args:
            :node_name: the node to compute
            :state: a dict of input nodes

        Return:
            the raw output of the node (not sampled and not post-processed),
            presumably a distribution object for network nodes

        Examples::

            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> action_dist = venv_model.node_dist("action", state)
        """
        self.check_version()
        return self._env.node_dist(node_name, deepcopy(state))

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        r"""
        Exporting the model to onnx mode.

        Reference: https://pytorch.org/docs/stable/onnx.html

        Args:
            :onnx_file: the onnx model file save path.
            :verbose: if True, prints a description of the model being exported to stdout.
                In addition, the final ONNX graph will include the field ``doc_string``
                from the exported model which mentions the source code locations for ``model``.
        """
        self._env.export2onnx(onnx_file, verbose)
class PolicyModelDev(torch.nn.Module):
    """Development-side wrapper around one or more policy ``DesicionNode`` objects.

    Only the first node (``self.node``) is used for inference, device moves and
    ONNX export. All nodes are reset together.
    """

    def __init__(self, nodes : List[DesicionNode,]):
        super().__init__()
        self.nodes = nodes
        self.node = self.nodes[0]
        # NOTE(review): a plain list, so these networks are not registered as
        # sub-modules of this torch.nn.Module.
        self.models = [node.get_network() for node in self.nodes]
        self.target_policy_name = self.node.name
        self.revive_version = __version__
        self.device = "cpu"
        self.reset()

    def to(self, device):
        """Move the inference node to ``device`` (no-op if it is already there)."""
        if device != self.device:
            self.device = device
            self.node.to(self.device)

    def check_version(self):
        """Warn when this policy was created by a different revive version."""
        if not self.revive_version == __version__:
            warnings.warn(f'detect the policy is create by version {self.revive_version}, but current version is {__version__}, maybe not compactable.')

    def reset(self):
        """Reset every wrapped node."""
        for node in self.nodes:
            node.reset()

    def _data_preprocess(self, data : np.ndarray, data_key : str = "obs") -> torch.Tensor:
        """Process ``data`` with the node processor for ``data_key`` and move it to ``self.device``."""
        data = self.node.processor.process_single(data, data_key)
        data = to_torch(data, device=self.device)
        return data

    def _data_postprocess(self, data : torch.Tensor, data_key : str = "action1") -> np.ndarray:
        """Convert ``data`` to numpy and undo the node processing for ``data_key``."""
        data = to_numpy(data)
        data = self.node.processor.deprocess_single(data, data_key)
        return data

    def infer(self, state : Dict[str, np.ndarray], deterministic : bool = True, clip : bool = False) -> np.ndarray:
        """Compute the policy output for raw numpy ``state``.

        Every entry of ``state`` is pre-processed, then the node is run. A
        distribution output is sampled with the deterministic/stochastic sampler.
        The result is optionally clamped to ``[-1, 1]`` and then post-processed
        as ``self.target_policy_name``.
        """
        self.check_version()
        state = deepcopy(state)
        sample_fn = get_sample_function(deterministic)
        for k, v in state.items():
            state[k] = self._data_preprocess(v, data_key=k)
        output = self.node(state)
        action = output if isinstance(output, torch.Tensor) else sample_fn(output)
        if clip:
            action = torch.clamp(action, -1, 1)
        return self._data_postprocess(action, self.target_policy_name)

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        """Export the inference node to ONNX."""
        self.node.export2onnx(onnx_file, verbose)
class PolicyModel:
    """User-facing policy wrapping a ``PolicyModelDev`` with an optional post-process hook.

    ``post_process(state, action)``, when given, receives the input state
    (merged with any ``additional_info``) and the computed action, and returns
    the final action.
    """

    def __init__(self, policy_model_dev : PolicyModelDev, post_process : Optional[Callable[[Dict[str, np.ndarray], np.ndarray], np.ndarray]] = None):
        self._policy_model = policy_model_dev
        self.post_process = post_process
        self.revive_version = __version__
        self.device = "cpu"

    def to(self, device: str):
        r"""
        Move model to the device specified by the parameter.

        Examples::

            >>> policy_model.to("cpu")
            >>> policy_model.to("cuda")
            >>> policy_model.to("cuda:1")
        """
        if device != self.device:
            self.device = device
            self._policy_model.to(self.device)

    def check_version(self):
        r"""Check if the revive version of the saved model and the current revive version match."""
        if not self.revive_version == __version__:
            warnings.warn(f'detect the policy is create by version {self.revive_version}, but current version is {__version__}, maybe not compactable.')

    def reset(self):
        r"""
        Reset the wrapped policy.

        When using an RNN for model training, call this before reusing the model
        to reset the hidden-layer information.
        """
        self._policy_model.reset()

    @property
    def target_policy_name(self) -> str:
        r''' Get the target policy name. '''
        return self._policy_model.target_policy_name

    @torch.no_grad()
    def infer(self,
              state : Dict[str, np.ndarray],
              deterministic : bool = True,
              clip : bool = True,
              additional_info : Optional[Dict[str, np.ndarray]] = None) -> np.ndarray:
        r"""
        Generate action according policy.

        Args:
            :state: a dict contain *ALL* the input nodes of the policy node
            :deterministic:
                if True, the most likely actions are generated;
                if False, actions are generated by sample.
                Default: True
            :clip:
                if True, the output is clamped to [-1, 1] before post-processing,
                i.e. to the range set in the yaml file;
                if False, the output is not clipped.
                Default: True
            :additional_info: a dict of additional info for post process
                (merged into a copy of ``state``; may be None)

        Return:
            action

        Examples::

            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> action = policy_model.infer(state)
        """
        self.check_version()
        action = self._policy_model.infer(deepcopy(state), deterministic, clip=clip)
        if self.post_process is not None:
            # Merge into a copy so the caller's ``state`` is not modified and a
            # missing ``additional_info`` does not crash ``dict.update``.
            post_state = dict(state)
            if additional_info is not None:
                post_state.update(additional_info)
            action = self.post_process(post_state, action)
        return action

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        r"""
        Exporting the model to onnx mode.

        Reference: https://pytorch.org/docs/stable/onnx.html

        Args:
            :onnx_file: the onnx model file save path.
            :verbose: if True, prints a description of the model being exported to stdout.
                In addition, the final ONNX graph will include the field ``doc_string``
                from the exported model which mentions the source code locations for ``model``.
        """
        if self.post_process is not None:
            warnings.warn('Currently, post process will not be exported.')
        self._policy_model.export2onnx(onnx_file, verbose)