"""
POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
import torch
import warnings
import numpy as np
from loguru import logger
from copy import deepcopy
from collections import OrderedDict
from typing import Any, Callable, Dict, Iterable, List, Union
from revive.data.processor import DataProcessor
from revive.computation.dists import ReviveDistribution
from revive.computation.modules import *
class DesicionNode:
    '''Abstract base class for a node that produces a decision on the graph.

    Subclasses implement ``call`` to map the (already processed) input data
    to either a raw tensor or a ``ReviveDistribution``.
    '''

    # marks the concrete kind of this node, e.g. 'network' or 'function'
    node_type : str = None

    def __init__(self, name : str, input_names : List[str], input_descriptions : List[Dict[str, Dict[str, Any]]]):
        # exactly one description entry is expected per input variable
        assert len(input_descriptions) == len(input_names)
        self.name = name
        self.input_names = input_names
        self.input_descriptions = input_descriptions
        # the global DataProcessor is attached later via register_processor()
        self.processor = None
        self.reset()

    def __call__(self, data : Dict[str, torch.Tensor], *args, **kwargs) -> 'Union[torch.Tensor, ReviveDistribution]':
        '''
        Run a forward computation of this node.
        NOTE: The input data was transferred by self.processor. You can use `self.processor.deprocess_torch(data)` to get the original data.
        '''
        self.before_compute_node()
        kwargs["node"] = self
        result = self.call(data, *args, **kwargs)
        self.after_compute_node()
        return result

    def register_processor(self, processor : 'DataProcessor'):
        '''Attach the global data processor to this node.'''
        self.processor = processor

    def remove_processor(self):
        '''Detach the registered data processor.'''
        self.processor = None

    def before_compute_node(self):
        # hook executed right before `call`; subclasses may override
        pass

    def after_compute_node(self):
        # hook executed right after `call`: leave the initial state and
        # advance the per-trajectory step counter
        if self.initial_state_flag:
            self.initial_state_flag = False
        self.trajectory_steps += 1

    def to(self, device : str) -> 'DesicionNode':
        '''Move this node to `device` (no-op for the stateless base class).'''
        return self

    def requires_grad_(self, mode : bool = False) -> 'DesicionNode':
        '''Toggle gradient tracking for this node (no-op for the base class).'''
        return self

    def train(self) -> 'DesicionNode':
        '''Switch this node to training mode (no-op for the base class).'''
        return self

    def eval(self) -> 'DesicionNode':
        '''Switch this node to evaluation mode (no-op for the base class).'''
        return self

    def reset(self) -> None:
        '''Reset the node state; useful when the node wraps an RNN.'''
        self.initial_state_flag = True
        self.trajectory_steps = 0

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        '''Export the node to an onnx file, with input from the original data space.'''
        assert self.processor is not None, 'please register processor before export!'
        snapshot = deepcopy(self)
        snapshot = snapshot.to('cpu')
        snapshot.requires_grad_(False)
        snapshot.eval()

        class ExportHelper(torch.nn.Module):
            # wraps process -> node -> deprocess so the exported graph
            # consumes and emits values in the original data space
            def forward(self, state : Dict[str, torch.Tensor]) -> torch.Tensor:
                state = snapshot.processor.process_torch(state)
                output = snapshot(state)
                action = output if isinstance(output, torch.Tensor) else output.mode
                action = torch.clamp(action, -1, 1)
                return snapshot.processor.deprocess_single_torch(action, snapshot.name)

        sample_inputs, batch_axes = {}, {}
        for in_name, description in zip(self.input_names, self.input_descriptions):
            sample_inputs[in_name] = torch.randn(len(description), dtype=torch.float32).unsqueeze(0)
            batch_axes[in_name] = [0]  # axis 0 (batch) stays dynamic
        torch.onnx.export(ExportHelper(), (sample_inputs, {}), onnx_file, verbose=verbose,
                          input_names=self.input_names, output_names=[self.name],
                          dynamic_axes=batch_axes, opset_version=11)

    def __str__(self) -> str:
        return '\n'.join([
            f'node class : {type(self)}',
            f'node name : {self.name}',
            f'node inputs : {self.input_names}',
            f'processor : {self.processor}',
            f'node type : {self.node_type}',
        ])
class NetworkDecisionNode(DesicionNode):
    '''Decision node whose output is produced by a learnable neural network.'''

    node_type = 'network'

    def __init__(self, name: str, input_names: List[str], input_descriptions: List[Dict[str, Dict[str, Any]]]):
        super().__init__(name, input_names, input_descriptions)
        self.network = None          # created by initialize_network() or injected by set_network()
        self.input_is_dict = False   # True when the backbone consumes a dict of tensors instead of a concat

    def set_network(self, network : torch.nn.Module):
        ''' set the network from a different source '''
        self.network = network

    def get_network(self) -> torch.nn.Module:
        ''' return all the network in this node '''
        return self.network

    def initialize_network(self,
                           input_dim : Union[int, dict],
                           output_dim : int,
                           hidden_features : int,
                           hidden_layers : int,
                           backbone_type : str,
                           dist_config : list,
                           is_transition : bool = False,
                           hidden_activation : str = 'leakyrelu',
                           norm : str = None,
                           transition_mode : 'Optional[str]' = None,
                           obs_dim : 'Optional[int]' = None,
                           input_dim_dict : dict = {},
                           *args, **kwargs):
        ''' initialize the network of this node

        Builds a transition or policy backbone matching `backbone_type`.
        If a network already exists the call is a no-op (a warning is logged).
        Raises ValueError for unsupported backbone types.
        '''
        # remember the arguments so the network could be re-created later
        self.initialize_network_parameters = {
            "input_dim" : input_dim,
            "output_dim" : output_dim,
            "hidden_features" : hidden_features,
            "hidden_layers" : hidden_layers,
            "backbone_type" : backbone_type,
            "dist_config" : dist_config,
            "is_transition" : is_transition,
            "hidden_activation" : hidden_activation,
            "norm" : norm,
            "transition_mode" : transition_mode,
            "obs_dim" : obs_dim,
            "input_dim_dict" : input_dim_dict,
            "args" : args,
            "kwargs" : kwargs,
        }
        if "ts_conv_config" in self.initialize_network_parameters["kwargs"] and self.initialize_network_parameters["kwargs"]["ts_conv_config"] != None:
            self.ts_conv_node = True
        if self.network:
            _para = sum(param.sum().item() for param in self.network.parameters())
            logger.warning(f"The node network {self.name} has been initialized. Skip initialization with sum {_para}")
            # BUGFIX: `time` is not imported at module level, so the original
            # `time.sleep(2)` raised NameError; import locally to keep the
            # intended pause (lets the warning be noticed) without the crash.
            import time
            time.sleep(2)
            return
        kwargs['node_name'] = self.name
        # NOTE(review): `self.ts_node_frames` is attached externally
        # (see DesicionGraph.register_node) — confirm it is always set before this call.
        ts_input_names = {k:v for k,v in self.ts_node_frames.items() if k in self.input_names}
        if is_transition:
            if backbone_type in ['mlp', 'res', 'ft_transformer']:
                network = FeedForwardTransition(input_dim,
                                                output_dim,
                                                hidden_features,
                                                hidden_layers,
                                                norm=norm,
                                                hidden_activation=hidden_activation,
                                                dist_config=dist_config,
                                                backbone_type=backbone_type,
                                                mode=transition_mode,
                                                obs_dim=obs_dim,
                                                **kwargs)
            elif backbone_type in ['gru', 'lstm', 'ts_transformer']:
                assert isinstance(input_dim, int), "assert isinstance(input_dim, int)"
                if ts_input_names:
                    # time-series inputs: the backbone consumes a dict and the
                    # effective input dim is averaged over the frame counts
                    self.input_is_dict = True
                    input_dim = int(sum([v/ts_input_names.get(k,1) for k,v in input_dim_dict.items()]))
                    if backbone_type in ['gru', 'lstm']:
                        network = TsRecurrentTransition(input_dim,
                                                        output_dim,
                                                        hidden_features,
                                                        hidden_layers,
                                                        norm=norm,
                                                        dist_config=dist_config,
                                                        backbone_type=backbone_type,
                                                        ts_input_names = ts_input_names,
                                                        input_dim_dict=input_dim_dict,
                                                        **kwargs)
                    else:
                        network = TsTransformerTransition(input_dim,
                                                          output_dim,
                                                          hidden_features,
                                                          hidden_layers,
                                                          norm=norm,
                                                          dist_config=dist_config,
                                                          ts_input_names = ts_input_names,
                                                          input_dim_dict=input_dim_dict,
                                                          **kwargs)
                else:
                    network = RecurrentTransition(input_dim, output_dim,
                                                  hidden_features, hidden_layers,
                                                  norm=norm,
                                                  dist_config=dist_config,
                                                  backbone_type=backbone_type,
                                                  mode=transition_mode,
                                                  obs_dim=obs_dim,
                                                  **kwargs)
            else:
                raise ValueError(f'Initializing node `{self.name}`, backbone type {backbone_type} is not supported!')
        else:
            if backbone_type in ['mlp', 'res', 'ft_transformer']:
                network = FeedForwardPolicy(input_dim, output_dim,
                                            hidden_features, hidden_layers,
                                            dist_config=dist_config,
                                            norm=norm,
                                            hidden_activation=hidden_activation,
                                            backbone_type=backbone_type,
                                            **kwargs)
            elif backbone_type in ['gru', 'lstm', 'ts_transformer']:
                assert isinstance(input_dim, int), "assert isinstance(input_dim, int)"
                if ts_input_names:
                    self.input_is_dict = True
                    input_dim = int(sum([v/ts_input_names.get(k,1) for k,v in input_dim_dict.items()]))
                    if backbone_type in ['gru', 'lstm']:
                        network = TsRecurrentPolicy(input_dim,
                                                    output_dim,
                                                    hidden_features,
                                                    hidden_layers,
                                                    norm=norm,
                                                    dist_config=dist_config,
                                                    backbone_type=backbone_type,
                                                    ts_input_names = ts_input_names,
                                                    input_dim_dict=input_dim_dict,
                                                    **kwargs)
                    else:
                        network = TsTransformerPolicy(input_dim,
                                                      output_dim,
                                                      hidden_features,
                                                      hidden_layers,
                                                      norm=norm,
                                                      dist_config=dist_config,
                                                      ts_input_names = ts_input_names,
                                                      input_dim_dict=input_dim_dict,
                                                      **kwargs)
                else:
                    network = RecurrentPolicy(input_dim,
                                              output_dim,
                                              hidden_features,
                                              hidden_layers,
                                              norm=norm,
                                              dist_config=dist_config,
                                              backbone_type=backbone_type,
                                              **kwargs)
            elif backbone_type in ['contextual_gru', 'contextual_lstm']:
                assert isinstance(input_dim, int), "assert isinstance(input_dim, int)"
                network = ContextualPolicy(input_dim,
                                           output_dim,
                                           hidden_features,
                                           hidden_layers,
                                           dist_config,
                                           backbone_type,
                                           **kwargs)
            else:
                raise ValueError(f'Initializing node `{self.name}`, backbone type {backbone_type} is not supported!')
        self.network = network
        logger.info(f"The node network {self.name} has been initialized -> {self.network.__class__.__name__}")

    def call(self, data : Dict[str, torch.Tensor], *args, **kwargs) -> 'ReviveDistribution':
        '''
        Run a forward computation of this node.
        NOTE: The input data was transferred by self.processor. You can use `self.processor.deprocess_torch(data)` to get the original data.
        '''
        data = self.get_inputs(data)
        # dict-consuming backbones (custom / ts-conv / ts nodes) receive the dict
        # directly; every other backbone gets the inputs concatenated on the last axis
        inputs = data if hasattr(self, 'custom_node') or hasattr(self, 'ts_conv_node') or self.input_is_dict or hasattr(self.network, 'input_is_dict') else torch.cat([data[k] for k in self.input_names], dim=-1)
        kwargs['input_names'] = self.input_names
        output_dist = self.network(inputs, *args, **kwargs)
        return output_dist

    def to(self, device : str) -> 'DesicionNode':
        ''' change the device of this node '''
        self.network = self.network.to(device)
        return self

    def requires_grad_(self, mode : bool = False) -> 'DesicionNode':
        ''' change the requirement of gradient for this node '''
        self.network.requires_grad_(mode)
        return self

    def train(self) -> 'DesicionNode':
        ''' set the state of this node to training '''
        self.network.train()
        return self

    def eval(self) -> 'DesicionNode':
        ''' set the state of this node to evaluation '''
        self.network.eval()
        return self

    def reset(self) -> 'NetworkDecisionNode':
        ''' reset the state of this node, useful when node is an RNN

        Returns self for chaining (the original annotation said None but the
        method has always returned self).
        '''
        super().reset()
        try:
            self.network.reset()
        except AttributeError:
            # stateless networks (or an unset network) have nothing to reset
            pass
        return self

    def __str__(self) -> str:
        info = [super(NetworkDecisionNode, self).__str__()]
        info.append(f'network : {self.network}')
        return '\n'.join(info)
class FunctionDecisionNode(DesicionNode):
    '''Decision node whose output is produced by a user-provided function
    (either numpy-based or torch-based).'''

    node_type = 'function'

    def __init__(self, name: str, input_names: List[str], input_descriptions: List[Dict[str, Dict[str, Any]]]):
        super().__init__(name, input_names, input_descriptions)
        self.node_function = None
        self.node_function_type = None  # 'none' (auto-detect), 'numpy' or 'torch'

    def register_node_function(self,
                               node_function : Union[Callable[[Dict[str, np.ndarray]], np.ndarray],
                                                     Callable[[Dict[str, torch.Tensor]], torch.Tensor]],
                               node_function_type : str):
        '''Attach the expert function and its declared type to this node.'''
        self.node_function = node_function
        self.node_function_type = node_function_type

    def remove_node_function(self):
        '''Detach the expert function from this node.'''
        self.node_function = None
        self.node_function_type = None

    def delta_node_function(self, data):
        '''Compute `next_X` as `X + delta_X`; only valid for `next_*` nodes.'''
        assert self.name.startswith("next_")
        node_name = self.name[5:]
        delta_node_name = "delta_" + node_name
        return data[node_name] + data[delta_node_name]

    def _check_finite(self, output, is_numpy : bool):
        '''Raise ValueError (and log) if `output` contains inf or nan.'''
        lib = np if is_numpy else torch
        mean = lib.mean(output)
        if lib.isinf(mean).item():
            logger.error(f"Find inf in {self.name} node function output {output}")
            raise ValueError(f"Find inf in {self.name} node function output {output}")
        if lib.isnan(mean).item():
            logger.error(f"Find nan in {self.name} node function output {output}")
            raise ValueError(f"Find nan in {self.name} node function output {output}")

    def call(self, data : Dict[str, torch.Tensor], *args, **kwargs) -> Union[torch.Tensor, 'ReviveDistribution']:
        ''' NOTE: if there is any provided function defined in numpy, this process cannot maintain gradients '''
        data = self.get_inputs(data)
        torch_data = list(data.values())[0]  # reference tensor for device/dtype
        deprocessed_data = self.processor.deprocess_torch(data)
        # Automatically detect node function type on the first call
        if self.node_function_type == 'none':
            try:
                probe = self.node_function(deprocessed_data)
                self._check_finite(probe, is_numpy=False)
                self.node_function_type = "torch"
                logger.warning(f"Automatically parse '{self.name}' node function into a torch functions")
            except Exception:
                self.node_function_type = "numpy"
                logger.warning(f"Automatically parse '{self.name}' node function into a numpy functions")
        if self.node_function_type == 'numpy':
            # torch -> numpy (gradients are lost through numpy functions)
            for k in deprocessed_data.keys():
                deprocessed_data[k] = deprocessed_data[k].detach().cpu().numpy()
            output = self.node_function(deprocessed_data)
            self._check_finite(output, is_numpy=True)
            output = self.processor.process_single(output, self.name)
            output = torch.as_tensor(output).to(torch_data)  # numpy -> torch
        else:
            # BUGFIX: the original only computed `output` inside the detection
            # branch, so a second call with an established 'torch' type hit an
            # unbound variable, and the first call processed the output twice.
            output = self.node_function(deprocessed_data)
            self._check_finite(output, is_numpy=False)
            output = self.processor.process_single_torch(output, self.name)
        return output

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        ''' export the node to onnx file, with input from original data space '''
        if self.node_function_type == 'numpy':
            warnings.warn(f'Detect function in node `{self.name}` with type numpy, export may be incorrect.')
        super(FunctionDecisionNode, self).export2onnx(onnx_file, verbose)

    def __str__(self) -> str:
        info = [super(FunctionDecisionNode, self).__str__()]
        info.append(f'node function : {self.node_function}')
        info.append(f'node function type: {self.node_function_type}')
        return '\n'.join(info)

    def reset(self) -> 'FunctionDecisionNode':
        ''' reset the state of this node, useful when node is an RNN

        Returns self for chaining (the original annotation said None but the
        method has always returned self).
        '''
        super().reset()
        return self
class DesicionGraph:
    r''' A collection of DecisionNodes.

    Holds the sorted computation graph (``graph_dict``), the node objects
    (``nodes``), optional target copies of the nodes for target-network
    training, and bookkeeping such as leaf variables, the transition map
    (``X`` -> ``next_X``) and tunable external factors.
    '''

    def __init__(self,
                 graph_dict : Dict[str, List[str]],
                 descriptions : Dict[str, List[Dict[str, Dict[str, Any]]]],
                 fit,
                 metric_nodes) -> None:
        self.descriptions = descriptions
        self.graph_list = []
        # topologically sort the raw graph; also fills self.graph_list
        self.graph_dict = self.sort_graph(graph_dict)
        self.fit = fit
        self.leaf = self.get_leaf(self.graph_dict)
        self.transition_map = self._get_transition_map(self.graph_dict)
        # leaf variables that are not produced by any transition are external
        self.external_factors = list(filter(lambda x: not x in self.transition_map.keys(), self.leaf))
        self.tunable = []
        # node objects are registered later; initialize placeholders in order
        self.nodes = OrderedDict()
        for node_name in self.graph_dict.keys():
            self.nodes[node_name] = None
        if metric_nodes is None:
            # default: every node participates in the metric
            self.metric_nodes = list(self.nodes.keys())
        else:
            self.metric_nodes = []
            for node in self.nodes.keys():
                if node in metric_nodes:
                    if node in self.leaf:
                        # leaf nodes are inputs, they cannot be scored
                        logger.info(f"Node '{node}' is a leaf node, it should't be a metric node.")
                        continue
                    assert node in self.nodes.keys(), f"Metric node '{node}' is not in Graph, Please check yaml."
                    self.metric_nodes.append(node)
        assert len(self.metric_nodes) >= 1, f"At least one non-leaf node is required for metric."
        self.is_target_network = False

    def register_node(self, node_name : str, node_class):
        r''' Register a node with given node class '''
        assert self.nodes[node_name] is None, f'Cannot register node `{node_name}`, the node is already registered as `{type(self.nodes[node_name])}`'
        input_names = self.graph_dict[node_name]
        self.nodes[node_name] = node_class(node_name, input_names, [self.descriptions[input_name] for input_name in input_names])
        # NOTE(review): `self.ts_node_frames` is not set in __init__ here —
        # presumably attached externally before registration; confirm against callers.
        self.nodes[node_name].ts_node_frames = self.ts_node_frames
        # TODO: UPDATE
        if node_class.node_type == 'function':
            # function nodes are not learnable, exclude them from the metric
            if node_name in self.metric_nodes:
                self.metric_nodes.remove(node_name)
                assert len(self.metric_nodes) >= 1, f"At least one non-leaf node is required for metric."

    @property
    def learnable_node_names(self) -> List[str]:
        r'''A list of names for learnable nodes the graph'''
        node_names = []
        for node_name, node in self.nodes.items():
            if not node.node_type == 'function':
                node_names.append(node_name)
        return node_names

    def register_target_nodes(self):
        '''Create a deep copy of all nodes to serve as target networks.'''
        self.target_nodes = deepcopy(self.nodes)

    def del_target_nodes(self):
        '''Delete the target copies; only allowed while the main nodes are active.'''
        assert not self.is_target_network
        del self.target_nodes

    def use_target_network(self,):
        '''Swap the target nodes in as the active nodes (idempotent).'''
        if self.is_target_network is False:
            self.target_nodes, self.nodes = self.nodes, self.target_nodes
            self.is_target_network = True

    def not_use_target_network(self,):
        '''Swap the main nodes back as the active nodes (idempotent).'''
        if self.is_target_network is True:
            self.target_nodes, self.nodes = self.nodes, self.target_nodes
            self.is_target_network = False

    def update_target_network(self, polyak=0.99):
        '''Polyak-average the active network parameters into the target copies.'''
        with torch.no_grad():
            for node_name, node in self.nodes.items():
                if not node.node_type == 'function':
                    target_node = self.target_nodes[node_name]
                    for p, p_targ in zip(node.network.parameters(), target_node.network.parameters()):
                        # NB: We use an in-place operations "mul_", "add_" to update target
                        # params, as opposed to "mul" and "add", which would make new tensors.
                        p_targ.data.mul_(polyak)
                        p_targ.data.add_((1 - polyak) * p.data)

    def mark_tunable(self, node_name : str) -> None:
        r'''Mark a leaf variable as tunable'''
        assert node_name in self.external_factors, 'Only external factors can be tunable!'
        if node_name in self.tunable:
            warnings.warn(f'{node_name} is already marked as a tunable node, skip.')
        else:
            self.tunable.append(node_name)

    def register_processor(self, processor : DataProcessor):
        r'''Register data processor to the graph and nodes'''
        self.processor = processor
        for node in self.nodes.values():
            node.register_processor(self.processor)

    def get_node(self, node_name : str, use_target: bool = False) -> DesicionNode:
        '''get the node by name

        When `use_target` is True, network nodes are looked up in the target
        copies instead (target copies must have been registered).
        '''
        if self.nodes[node_name].node_type == 'network':
            if use_target:
                assert hasattr(self, "target_nodes"), "Not have target nodes. You should register target nodes firstly."
                return self.target_nodes[node_name]
        return self.nodes[node_name]

    def compute_node(self, node_name : str, inputs : Dict[str, torch.Tensor], use_target: bool = False, *args, **kwargs):
        '''compute the node by name'''
        # NOTE(review): `self.freeze_nodes` is not set in __init__ here —
        # presumably attached externally; confirm it exists before this is called.
        if node_name in self.freeze_nodes:
            with torch.no_grad():
                # Those nodes which are treated as freezed nodes,
                # the network bounding to them should be at eval() mode more than no_grad mode
                if self.nodes[node_name].network.training:
                    self.nodes[node_name].network.eval()
                return self.get_node(node_name, use_target)(inputs, *args, **kwargs)
        else:
            return self.get_node(node_name, use_target)(inputs, *args, **kwargs)

    def get_relation_node_names(self) -> List[str]:
        '''
        get all the nodes that related to the learning (network) nodes.
        NOTE: this is the default list if you have matcher and value functions.
        '''
        node_names = []
        for node in self.nodes.values():
            node_name = node.name
            input_names = node.input_names
            if not (node.node_type == 'function'): # skip function nodes
                for name in input_names + [node_name]:
                    if not (name in node_names):
                        node_names.append(name)
        return node_names

    def get_node_value_net_node_names(self) -> List[str]:
        '''
        get all the nodes that related to the learning (network) nodes.
        NOTE: this is the default list if you have matcher and value functions.
        '''
        node_names = []
        for node in self.nodes.values():
            if not (node.node_type == 'function'): # skip function nodes
                node_names.append(node.name)
        return node_names

    def summary_nodes(self) -> Dict[str, int]:
        '''Count nodes by category (network / function / unregistered / unknown).'''
        network_nodes = 0
        function_nodes = 0
        unregistered_nodes = 0
        unknown_nodes = 0
        for node in self.nodes.values():
            if node is None:
                unregistered_nodes += 1
            else:
                if node.node_type == 'network':
                    network_nodes += 1
                elif node.node_type == 'function':
                    function_nodes += 1
                else:
                    unknown_nodes += 1
        return {
            'network_nodes' : network_nodes,
            'function_nodes' : function_nodes,
            'unregistered_nodes' : unregistered_nodes,
            'unknown_nodes' : unknown_nodes,
        }

    def collect_models(self) -> List[torch.nn.Module]:
        '''return all the network that registered in this graph'''
        return [node.get_network() for node in self.nodes.values() if node.node_type == 'network']

    def is_equal_venv(self, source_graph : 'DesicionGraph', policy_node) -> bool:
        ''' check if new graph shares the same virtual environments '''
        for node_name in self.nodes.keys():
            target_node = self.get_node(node_name)
            if node_name in policy_node: # do not judge policy node
                continue
            if target_node.node_type == "function": # do not judge env node with defined expert function
                logger.warning(f'Detected "{node_name}" is attached with a expert function. Please check if it is RIGHT?!')
                continue
            source_node = source_graph.get_node(node_name)
            if target_node.input_names != source_node.input_names:
                logger.warning(f'Detected "{node_name}" is attached with different input names. Please check if it is RIGHT?!')
                return False
            if target_node.input_descriptions != source_node.input_descriptions:
                logger.warning(f'Detected "{node_name}" is attached with different input descriptions. Please check if it is RIGHT?!')
                return False
            if target_node.node_type != source_node.node_type:
                logger.warning(f'Detected "{node_name}" is attached with different input node_type. Please check if it is RIGHT?!')
                return False
        return True

    def is_equal_structure(self, source_graph : 'DesicionGraph') -> bool:
        ''' check if new graph shares the same structure '''
        if self.graph_dict != source_graph.graph_dict:
            return False
        for node_name in self.nodes.keys():
            target_node = self.get_node(node_name)
            source_node = source_graph.get_node(node_name)
            if target_node.input_names != source_node.input_names:
                return False
            if target_node.input_descriptions != source_node.input_descriptions:
                return False
            if target_node.node_type != source_node.node_type:
                return False
            # function nodes are user supplied, so structural equality is not assumed
            if target_node.node_type == "function":
                return False
        return True

    def copy_graph_node(self, source_graph : 'DesicionGraph') -> bool:
        '''try copy all the node from source graph '''
        for node_name in self.nodes.keys():
            target_node = self.get_node(node_name)
            if node_name not in source_graph.nodes.keys():
                if target_node.node_type == 'network':
                    logger.warning(f'Find new network node "{node_name}" is not in source_graph. Initialize a new network node.')
                elif target_node.node_type == 'function':
                    logger.warning(f'Find new function node "{node_name}" is not in source_graph.')
                else:
                    raise NotImplementedError
            else:
                try:
                    source_node = source_graph.get_node(node_name)
                    assert target_node.input_names == source_node.input_names
                    assert target_node.input_descriptions == source_node.input_descriptions
                    if source_node.node_type == 'network' and target_node.node_type == 'network':
                        target_node.set_network(source_node.get_network())
                        para = sum(param.sum().item() for param in source_node.get_network().parameters())
                        logger.warning(f'Network copy "{node_name}" node with parameters sum {para}.')
                    if hasattr(source_node, 'ts_conv_node'):
                        target_node.ts_conv_node = True
                    logger.info(f'Successfully copy "{node_name}" node from source_graph.')
                except Exception as e:
                    logger.warning(f'Failed copy "{node_name}" node from source_graph. {e}')

    def get_leaf(self, graph : Dict[str, List[str]] = None) -> List[str]:
        ''' return the leaf of the graph in *alphabet order*

        A leaf is a variable that appears as an input but is never produced
        as an output. When `graph` is None the graph itself is used (it
        supports the dict protocol via keys()/values()).
        '''
        if graph is None:
            graph = self
        outputs = [name for name in graph.keys()]
        inputs = []
        for names in graph.values(): inputs += names
        inputs = set(inputs)
        leaf = [name for name in inputs if name not in outputs]
        leaf = sorted(leaf) # make sure the order of leaf is fixed given a graph
        return leaf

    def _get_transition_map(self, graph : Dict[str, List[str]]) -> Dict[str, str]:
        '''Map each state variable `X` to its produced `next_X` output.'''
        outputs = [name for name in graph.keys()]
        inputs = []
        for names in graph.values(): inputs += names
        inputs = set(inputs)
        # transition = {name[5:] : name for name in outputs if name.startswith('next_') and name[5:] in inputs}
        transition = {name[5:] : name for name in outputs if name.startswith('next_') and (name[5:] in inputs or "ts_"+name[5:] in inputs)}
        return transition

    def sort_graph(self, graph_dict : dict) -> OrderedDict:
        '''Sort arbitrary computation graph to the topological order'''
        ordered_graph = OrderedDict()
        computed = self.get_leaf(graph_dict)
        self.graph_list = deepcopy(computed)
        # sort output: repeatedly pick the alphabetically-first node whose
        # inputs are all available; fail if no progress can be made (cycle)
        while len(graph_dict) > 0:
            find_new_node = False
            for output_name in sorted(graph_dict.keys()):
                input_names = graph_dict[output_name]
                if all([name in computed for name in input_names]):
                    ordered_graph[output_name] = graph_dict.pop(output_name)
                    computed.append(output_name)
                    self.graph_list.append(output_name)
                    find_new_node = True
                    break
            if not find_new_node:
                raise ValueError('Cannot find any computable node, check if there are loops or isolations on the graph!\n' + \
                                 f'current computed nodes: {computed}, node waiting to be computed: {graph_dict}')
        # sort input: for `next_X` nodes put `X` first, then append the
        # remaining inputs in computation order and sort them
        for output_name, input_names in ordered_graph.items():
            sorted_input_names = []
            if output_name.startswith('next_') and output_name[5:] in input_names:
                sorted_input_names.append(output_name[5:])
                input_names.pop(input_names.index(output_name[5:]))
            for name in computed:
                if name in input_names:
                    for i_name in input_names:
                        if i_name == name:
                            sorted_input_names.append(name)
            sorted_input_names.sort()
            ordered_graph[output_name] = sorted_input_names
        assert self._is_sort_graph(ordered_graph), f"{ordered_graph}, graph is not correctly sorted!"
        return ordered_graph

    def _is_sort_graph(self, graph : OrderedDict) -> bool:
        ''' check if a graph is sorted '''
        computed = self.get_leaf(graph)
        for output_name, input_names in graph.items():
            for name in input_names:
                if not name in computed:
                    return False
            computed.append(output_name)
        return True

    def to(self, device : str) -> 'DesicionGraph':
        '''Move every node of the graph to `device`.'''
        for node in self.nodes.values(): node.to(device)
        return self

    def requires_grad_(self, mode : bool = False) -> 'DesicionGraph':
        '''Toggle gradient tracking for every node of the graph.'''
        for node in self.nodes.values(): node.requires_grad_(mode)
        return self

    def eval(self) -> 'DesicionGraph':
        '''Switch every node of the graph to evaluation mode.'''
        for node in self.nodes.values(): node.eval()
        return self

    def reset(self) -> 'DesicionGraph':
        ''' reset graph, useful for stateful graph '''
        for node in self.nodes.values():
            node.reset()
        return self

    # dict-like protocol, delegating to the sorted graph description -------

    def __getitem__(self, name : str) -> List[str]:
        return self.graph_dict[name]

    def keys(self) -> Iterable[str]:
        return self.graph_dict.keys()

    def values(self) -> Iterable[List[str]]:
        return self.graph_dict.values()

    def items(self) -> Iterable[Dict[str, List[str]]]:
        return self.graph_dict.items()

    def __len__(self) -> int:
        return len(self.graph_dict)

    def __str__(self) -> str:
        node_info = [node.__str__() for node in self.nodes.values()]
        return '\n\n'.join(node_info)

    def __call__(self, state : Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        ''' compute the whole graph, from leaf node to all output nodes

        NOTE: `state` is updated in place with each node's output.
        '''
        assert all([node_name in state.keys() for node_name in self.leaf])
        for node_name, node in self.nodes.items():
            output = node(state)
            if isinstance(output, torch.Tensor):
                action = output
            else:
                # distributions are collapsed to their mode
                action = output.mode
            action = torch.clamp(action, -1, 1)
            state[node_name] = action
        actions = {node_name : state[node_name] for node_name in self.nodes.keys()}
        return actions

    def state_transition(self, state : Dict[str, torch.Tensor], copy : bool = False) -> Dict[str, torch.Tensor]:
        '''Build the next-step state by mapping each `next_X` output back to `X`.'''
        new_state = {}
        for new_name, old_name in self.transition_map.items():
            new_state[new_name] = state[old_name]
        if copy:
            new_state = deepcopy(new_state)
        return new_state

    def export2onnx(self, onnx_file : str, verbose : bool = True):
        ''' export the graph to onnx file, with input from original data space '''
        assert self.processor is not None, 'please register processor before export!'
        for node_name, node in self.nodes.items():
            if node.node_type == 'function':
                if node.node_function_type == 'numpy':
                    warnings.warn(f'Detect function in node {node_name} with type numpy, may be incorrect.')
        # export a frozen CPU copy so the live graph is untouched
        graph = deepcopy(self)
        graph = graph.reset()
        graph = graph.to('cpu')
        graph.requires_grad_(False)
        graph.eval()

        class ExportHelper(torch.nn.Module):
            # wraps process -> graph -> deprocess so the onnx graph
            # consumes and emits values in the original data space
            def forward(self, state : Dict[str, torch.Tensor]) -> torch.Tensor:
                state = graph.processor.process_torch(state)
                actions = graph(state)
                actions = graph.processor.deprocess_torch(actions)
                return tuple([actions[node_name] for node_name in graph.nodes.keys()])

        demo_inputs = {}
        dynamic_axes={}
        for name in self.leaf:
            description = self.descriptions[name]
            demo_inputs[name] = torch.randn(len(description), dtype=torch.float32).unsqueeze(0)
            dynamic_axes[name] = [0] # axis 0 (batch) stays dynamic
        torch.onnx.export(ExportHelper(), (demo_inputs,{}), onnx_file, verbose=verbose, input_names=self.leaf, output_names=list(self.nodes.keys()), dynamic_axes=dynamic_axes, opset_version=11)