''''''
"""
POLIXIR REVIVE, copyright (C) 2021-2023 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
import torch
import warnings
from loguru import logger
import numpy as np
from copy import deepcopy
from typing import Callable, Dict, List, Union, Optional
from revive import __version__
from revive.computation.graph import DesicionGraph, DesicionNode
from revive.computation.utils import *
class VirtualEnvDev(torch.nn.Module):
    """Development-time virtual environment over a single decision graph.

    Wraps a :class:`DesicionGraph`, registers its network nodes in a
    :class:`torch.nn.ModuleList` (so parameters are tracked by torch), and
    exposes one-step / k-step inference with pre- and post-processing of
    numpy data through the graph's processor.
    """

    def __init__(self, graph : DesicionGraph) -> None:
        super(VirtualEnvDev, self).__init__()
        # Register network nodes so their parameters are visible to torch.
        self.models = torch.nn.ModuleList()
        self.graph = graph
        for node in self.graph.nodes.values():
            if node.node_type == 'network':
                self.models.append(node.get_network())
        # By default, treat every graph output as a target policy.
        self.set_target_policy_name(list(self.graph.keys()))
        self.revive_version = __version__
        self.device = "cpu"

    def to(self, device):
        """Best-effort move of every network node to ``device``.

        A node that fails to move is logged and skipped so the remaining
        nodes are still moved.
        """
        if device != self.device:
            self.device = device
            for node in self.graph.nodes.values():
                if node.node_type == 'network':
                    try:
                        node.to(self.device)
                    except Exception:
                        # Narrowed from a bare `except:` so that
                        # KeyboardInterrupt/SystemExit still propagate.
                        logger.warning(f'Failed to move device the network of "{node.name}"')

    def check_version(self):
        """Warn if this venv was created by a different revive version."""
        if not self.revive_version == __version__:
            warnings.warn(f'detect the venv is create by version {self.revive_version}, but current version is {__version__}, maybe not compactable.')

    def reset(self) -> None:
        """Reset the graph (e.g. clear RNN hidden state) before reuse."""
        self.graph.reset()

    def set_target_policy_name(self, target_policy_name : list) -> None:
        """Set the target policy node name(s) and cache their graph indices.

        Args:
            target_policy_name: list of node names to treat as target policies.
        """
        self.target_policy_name = target_policy_name
        # find target index
        self.index = []
        for i, (output_name, input_names) in enumerate(self.graph.items()):
            if output_name in self.target_policy_name:
                self.index.append(i)
        self.index.sort()

    def _data_preprocess(self, data : np.ndarray, data_key : str = "obs") -> torch.Tensor:
        """Normalize raw numpy data for ``data_key`` and move it to ``self.device``."""
        data = self.graph.processor.process_single(data, data_key)
        data = to_torch(data, device=self.device)
        return data

    def _data_postprocess(self, data : torch.Tensor, data_key : str) -> np.ndarray:
        """Convert a tensor back to denormalized numpy data for ``data_key``."""
        data = to_numpy(data)
        data = self.graph.processor.deprocess_single(data, data_key)
        return data

    def _infer_one_step(self,
                        state : Dict[str, np.ndarray],
                        deterministic : bool = True,
                        clip : bool = True,
                        node_name : Optional[str] = None) -> Dict[str, np.ndarray]:
        """Run one inference pass over the graph (or a single named node).

        Values already present in ``state`` are treated as given and skipped.
        The input dict is deep-copied, so the caller's data is not mutated.
        """
        self.check_version()
        state = deepcopy(state)
        sample_fn = get_sample_function(deterministic)
        for k in list(state.keys()):
            state[k] = self._data_preprocess(state[k], k)
        # Support inferring nodes based on node name calls
        if node_name is None:
            compute_nodes = self.graph.keys()
        else:
            assert node_name in self.graph.keys()
            assert not node_name in state.keys()
            compute_nodes = [node_name, ]
        for node_name in compute_nodes:
            if not node_name in state.keys():  # skip provided values
                output = self.graph.compute_node(node_name, state)
                if isinstance(output, torch.Tensor):
                    state[node_name] = output
                else:
                    # Distribution output: draw via mode or sampling.
                    state[node_name] = sample_fn(output)
                if clip: state[node_name] = torch.clamp(state[node_name], -1, 1)
        for k in list(state.keys()):
            state[k] = self._data_postprocess(state[k], k)
        return state

    def infer_k_steps(self,
                      states : List[Dict[str, np.ndarray]],
                      deterministic : bool = True,
                      clip : bool = True) -> List[Dict[str, np.ndarray]]:
        """Roll the graph forward for ``len(states)`` steps.

        Each step's transition output (plus the tunable values from the first
        step) is fed back into the next step's input dict.
        """
        outputs = []
        backup = {}
        # Tunable inputs are held fixed across the whole rollout.
        tunable = {name : states[0][name] for name in self.graph.tunable}
        for state in states:
            state.update(backup)
            output = self._infer_one_step(state, deterministic=deterministic, clip=clip)
            outputs.append(output)
            backup = self.graph.state_transition(output)
            backup.update(tunable)
        return outputs

    def infer_one_step(self,
                       state : Dict[str, np.ndarray],
                       deterministic : bool = True,
                       clip : bool = True) -> Dict[str, np.ndarray]:
        """Run a single full-graph inference step."""
        return self._infer_one_step(state, deterministic=deterministic, clip=clip)

    def node_infer(self,
                   node_name : str,
                   state : Dict[str, np.ndarray],
                   deterministic : bool = True,
                   clip : bool = True) -> Dict[str, np.ndarray]:
        """Infer only ``node_name`` from the given inputs and return its value."""
        return self._infer_one_step(state, deterministic=deterministic, clip=clip, node_name=node_name)[node_name]

    def node_dist(self,
                  node_name : str,
                  state : Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Return the raw (unsampled) output distribution of ``node_name``.

        Unlike :meth:`node_infer`, no sampling, clipping, or post-processing
        is applied.
        """
        self.check_version()
        state = deepcopy(state)
        for k in list(state.keys()):
            state[k] = self._data_preprocess(state[k], k)
        assert node_name in self.graph.keys()
        assert not node_name in state.keys()
        compute_nodes = [node_name, ]
        for node_name in compute_nodes:
            if not node_name in state.keys():  # skip provided values
                output = self.graph.compute_node(node_name, state)
        return output

    def node_pre_computation(self, node_name: str, node_data: np.ndarray):
        """Apply only the preprocessing step for ``node_name``."""
        self.check_version()
        return self._data_preprocess(node_data, node_name)

    def node_post_computation(self, node_name: str, node_data: np.ndarray):
        """Apply only the postprocessing step for ``node_name``."""
        self.check_version()
        return self._data_postprocess(node_data, node_name)

    def forward(self,
                data : Dict[str, torch.Tensor],
                deterministic : bool = True,
                clip : bool = False) -> Dict[str, torch.Tensor]:
        ''' run the target node '''
        self.check_version()
        sample_fn = get_sample_function(deterministic)
        # NOTE(review): target_policy_name is stored as a list by
        # set_target_policy_name, but it is used here as a single node name —
        # confirm callers set a single-element/str value before forward().
        node_name = self.target_policy_name
        output = self.graph.compute_node(node_name, data)
        if isinstance(output, torch.Tensor):
            data[node_name] = output
        else:
            data[node_name] = sample_fn(output)
        if clip: data[node_name] = torch.clamp(data[node_name], -1, 1)
        return data

    def pre_computation(self,
                        data : Dict[str, torch.Tensor],
                        deterministic : bool = True,
                        clip : bool = False,
                        policy_index : int = 0) -> Dict[str, torch.Tensor]:
        '''run all the node before target node. skip if the node value is already available.'''
        self.check_version()
        sample_fn = get_sample_function(deterministic)
        for node_name in list(self.graph.keys())[:self.index[policy_index]]:
            if not node_name in data.keys():
                output = self.graph.compute_node(node_name, data)
                if isinstance(output, torch.Tensor):
                    data[node_name] = output
                else:
                    data[node_name] = sample_fn(output)
                if clip: data[node_name] = torch.clamp(data[node_name], -1, 1)
        return data

    def post_computation(self,
                         data : Dict[str, torch.Tensor],
                         deterministic : bool = True,
                         clip : bool = False,
                         policy_index : int = 0) -> Dict[str, torch.Tensor]:
        '''run all the node after target node'''
        self.check_version()
        sample_fn = get_sample_function(deterministic)
        for node_name in list(self.graph.keys())[self.index[policy_index]+1:]:
            output = self.graph.compute_node(node_name, data)
            if isinstance(output, torch.Tensor):
                data[node_name] = output
            else:
                data[node_name] = sample_fn(output)
            if clip: data[node_name] = torch.clamp(data[node_name], -1, 1)
        return data

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        """Delegate ONNX export to the underlying graph."""
        self.graph.export2onnx(onnx_file, verbose)
class VirtualEnv:
    """User-facing virtual environment over an ensemble of :class:`VirtualEnvDev`.

    Holds a list of trained environment models; ``self._env`` is the one
    currently selected for inference. All inference entry points deep-copy
    their inputs and run under ``torch.no_grad()``.
    """

    def __init__(self, env_list : List[VirtualEnvDev]):
        self._env = env_list[0]
        self.env_list = env_list
        self.graph = self._env.graph
        self.revive_version = __version__
        self.device = "cpu"

    def to(self, device):
        r"""
        Move model to the device specified by the parameter.
        Examples::
            >>> venv_model.to("cpu")
            >>> venv_model.to("cuda")
            >>> venv_model.to("cuda:1")
        """
        if device != self.device:
            self.device = device
            for env in self.env_list:
                env.to(device)

    def check_version(self):
        r"""Check if the revive version of the saved model and the current revive version match."""
        if not self.revive_version == __version__:
            warnings.warn(f'detect the venv is create by version {self.revive_version}, but current version is {__version__}, maybe not compactable.')

    def reset(self) -> None:
        r"""
        When using RNN for model training, this method needs to be called before model reuse
        to reset the hidden layer information.
        """
        for env in self.env_list:
            env.reset()

    def set_env(self, env_id) -> None:
        """Select which environment model in the ensemble to use.

        Args:
            env_id: index into ``self.env_list``.
        """
        # Bug fix: the original referenced the undefined global `env_list`
        # (NameError at call time); it must read from `self.env_list`.
        assert env_id < len(self.env_list)
        self._env = self.env_list[env_id]

    @property
    def target_policy_name(self) -> str:
        r''' Get the target policy name. '''
        return self._env.target_policy_name

    def set_target_policy_name(self, target_policy_name) -> None:
        r''' Set the target policy name. '''
        for env in self.env_list:
            env.set_target_policy_name(target_policy_name)

    def replace_policy(self, policy : 'PolicyModel') -> None:
        ''' Replace the target policy with the given policy. '''
        assert self.target_policy_name == policy.target_policy_name, \
            f'policy name does not match, require {self.target_policy_name} but get {policy.target_policy_name}!'
        for env in self.env_list:
            env.graph.nodes[self.target_policy_name] = policy._policy_model.node

    @torch.no_grad()
    def infer_k_steps(self,
                      states : Union[Dict[str, np.ndarray], List[Dict[str, np.ndarray]]],
                      k : Optional[int] = None,
                      deterministic : bool = True,
                      clip : bool = True) -> List[Dict[str, np.ndarray]]:
        r"""
        Generate k steps interactive data.
        Args:
            :states: a dict of initial input nodes
            :k: how many steps to generate
            :deterministic:
                if True, the most likely actions are generated;
                if False, actions are generated by sample.
                Default: True
        Return:
            k steps interactive data dict
        Examples::
            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> ten_step_output = venv_model.infer_k_steps(state, k=10)
        """
        self.check_version()
        if isinstance(states, dict):
            states = [states]
        if k is not None:
            # Pad with empty dicts; later steps are filled by state transition.
            for i in range(k - 1):
                states.append({})
        return self._env.infer_k_steps(deepcopy(states), deterministic, clip=clip)

    @torch.no_grad()
    def infer_one_step(self,
                       state : Dict[str, np.ndarray],
                       deterministic : bool = True,
                       clip : bool = True) -> Dict[str, np.ndarray]:
        r"""
        Generate one step interactive data given action.
        Args:
            :state: a dict of input nodes
            :deterministic:
                if True, the most likely actions are generated;
                if False, actions are generated by sample.
                Default: True
        Return:
            one step outputs
        Examples::
            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> one_step_output = venv_model.infer_one_step(state)
        """
        self.check_version()
        return self._env.infer_one_step(deepcopy(state), deterministic, clip=clip)

    def node_pre_computation(self,
                             node_name: str,
                             node_data: np.ndarray):
        """Apply only the preprocessing step of ``node_name`` to ``node_data``."""
        self.check_version()
        return self._env.node_pre_computation(node_name, node_data)

    def node_post_computation(self,
                              node_name: str,
                              node_data: np.ndarray):
        """Apply only the postprocessing step of ``node_name`` to ``node_data``."""
        self.check_version()
        # Bug fix: the original delegated to node_pre_computation (copy-paste
        # error), silently preprocessing instead of postprocessing.
        return self._env.node_post_computation(node_name, node_data)

    @torch.no_grad()
    def node_infer(self,
                   node_name: str,
                   state : Dict[str, np.ndarray],
                   deterministic : bool = True,
                   clip : bool = True) -> Dict[str, np.ndarray]:
        r"""
        Generate one step interactive data given node_name.
        Args:
            :state: a dict of input nodes
            :deterministic:
                if True, the most likely actions are generated;
                if False, actions are generated by sample.
                Default: True
        Return:
            one step node output
        Examples::
            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> action_output = venv_model.node_infer("action", state)
        """
        self.check_version()
        return self._env.node_infer(node_name, deepcopy(state), deterministic, clip=clip)

    @torch.no_grad()
    def node_dist(self,
                  node_name: str,
                  state : Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        r"""
        Generate one step interactive dist given node_name.
        Args:
            :state: a dict of input nodes
        Return:
            the raw output distribution of the node
        Examples::
            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> action_dist = venv_model.node_dist("action", state)
        """
        self.check_version()
        return self._env.node_dist(node_name, deepcopy(state))

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        r"""
        Exporting the model to onnx mode.
        Reference: https://pytorch.org/docs/stable/onnx.html
        Args:
            :onnx_file: the onnx model file save path.
            :verbose: if True, prints a description of the model being exported to stdout.
                In addition, the final ONNX graph will include the field ``doc_string```
                from the exported model which mentions the source code locations for ``model``.
        """
        self._env.export2onnx(onnx_file, verbose)
class PolicyModelDev(torch.nn.Module):
    """Development-time wrapper around the policy node(s) of a decision graph.

    Holds the decision nodes that make up the policy; ``self.node`` (the
    first node) is the one actually used for inference and export. The
    target policy name is that first node's name.
    """

    def __init__(self, nodes : List[DesicionNode]):
        super().__init__()
        self.nodes = nodes
        # Only the first node is used for inference/export.
        self.node = self.nodes[0]
        self.models = [node.get_network() for node in self.nodes]
        self.target_policy_name = [node.name for node in self.nodes]
        self.target_policy_name = self.target_policy_name[0]
        self.revive_version = __version__
        self.device = "cpu"

    def to(self, device):
        """Move the active policy node to ``device``."""
        if device != self.device:
            self.device = device
            self.node.to(self.device)

    def check_version(self):
        """Warn if this policy was created by a different revive version."""
        if not self.revive_version == __version__:
            warnings.warn(f'detect the policy is create by version {self.revive_version}, but current version is {__version__}, maybe not compactable.')

    def reset(self):
        """Reset all policy nodes (e.g. clear RNN hidden state)."""
        for node in self.nodes:
            node.reset()

    def _data_preprocess(self, data : np.ndarray, data_key : str = "obs") -> torch.Tensor:
        """Normalize raw numpy data for ``data_key`` and move it to ``self.device``."""
        data = self.node.processor.process_single(data, data_key)
        data = to_torch(data, device=self.device)
        return data

    # Annotation fix: `torch.tensor` is the factory function, not a type;
    # the parameter type is `torch.Tensor`.
    def _data_postprocess(self, data : torch.Tensor, data_key : str = "action1") -> np.ndarray:
        """Convert a tensor back to denormalized numpy data for ``data_key``."""
        data = to_numpy(data)
        data = self.node.processor.deprocess_single(data, data_key)
        return data

    def infer(self, state : Dict[str, np.ndarray], deterministic : bool = True, clip : bool = False) -> np.ndarray:
        """Compute the policy action for ``state``.

        Args:
            state: a dict containing all input nodes of the policy node.
            deterministic: if True take the most likely action, else sample.
            clip: if True, clamp the (normalized) action into [-1, 1].
        Returns:
            the denormalized action as a numpy array.
        """
        self.check_version()
        state = deepcopy(state)
        sample_fn = get_sample_function(deterministic)
        for k, v in state.items():
            state[k] = self._data_preprocess(v, data_key=k)
        output = self.node(state)
        if isinstance(output, torch.Tensor):
            action = output
        else:
            # Distribution output: draw via mode or sampling.
            action = sample_fn(output)
        if clip: action = torch.clamp(action, -1, 1)
        action = self._data_postprocess(action, self.target_policy_name)
        return action

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        """Delegate ONNX export to the active policy node."""
        self.node.export2onnx(onnx_file, verbose)
class PolicyModel:
    """User-facing policy wrapper around a :class:`PolicyModelDev`.

    Optionally applies a ``post_process(state, action)`` callable to the raw
    action before returning it.
    """

    def __init__(self, policy_model_dev : PolicyModelDev, post_process : Optional[Callable[[Dict[str, np.ndarray], np.ndarray], np.ndarray]] = None):
        self._policy_model = policy_model_dev
        self.post_process = post_process
        self.revive_version = __version__
        self.device = "cpu"

    def to(self, device: str):
        r"""
        Move model to the device specified by the parameter.
        Examples::
            >>> policy_model.to("cpu")
            >>> policy_model.to("cuda")
            >>> policy_model.to("cuda:1")
        """
        if device != self.device:
            self.device = device
            self._policy_model.to(self.device)

    def check_version(self):
        r"""Check if the revive version of the saved model and the current revive version match."""
        if not self.revive_version == __version__:
            warnings.warn(f'detect the policy is create by version {self.revive_version}, but current version is {__version__}, maybe not compactable.')

    def reset(self):
        r"""
        When using RNN for model training, this method needs to be called before model reuse
        to reset the hidden layer information.
        """
        self._policy_model.reset()

    @property
    def target_policy_name(self) -> str:
        r''' Get the target policy name. '''
        return self._policy_model.target_policy_name

    @torch.no_grad()
    def infer(self,
              state : Dict[str, np.ndarray],
              deterministic : bool = True,
              clip : bool = True,
              additional_info : Optional[Dict[str, np.ndarray]] = None) -> np.ndarray:
        r"""
        Generate action according policy.
        Args:
            :state: a dict contain *ALL* the input nodes of the policy node
            :deterministic:
                if True, the most likely actions are generated;
                if False, actions are generated by sample.
                Default: True
            :clip:
                if True, The output will be cropped to the range set in the yaml file;
                if False, actions are generated by sample.
                Default: True
            :additional_info: a dict of additional info for post process
        Return:
            action
        Examples::
            >>> state = {"obs": obs_array, "static_obs": static_obs_array}
            >>> action = policy_model.infer(state)
        """
        self.check_version()
        action = self._policy_model.infer(deepcopy(state), deterministic, clip=clip)
        if self.post_process is not None:
            # Robustness fix: the original called state.update(additional_info)
            # unconditionally, raising TypeError when additional_info was left
            # at its default None while post_process was set.
            if additional_info is not None:
                state.update(additional_info)
            action = self.post_process(state, action)
        return action

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        r"""
        Exporting the model to onnx mode.
        Reference: https://pytorch.org/docs/stable/onnx.html
        Args:
            :onnx_file: the onnx model file save path.
            :verbose: if True, prints a description of the model being exported to stdout.
                In addition, the final ONNX graph will include the field ``doc_string```
                from the exported model which mentions the source code locations for ``model``.
        """
        if self.post_process is not None:
            warnings.warn('Currently, post process will not be exported.')
        self._policy_model.export2onnx(onnx_file, verbose)