"""
POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
from loguru import logger
import time
import math
import warnings
import numpy as np
import torch
import torch.nn.functional as F
from copy import deepcopy
from torch import nn  # , vmap
# from torch.func import stack_module_state, functional_call
from typing import Optional, Union, List, Dict, cast
from collections import OrderedDict, deque
from revive.computation.dists import *
from revive.computation.utils import *
def reglu(x):
a, b = x.chunk(2, dim=-1)
return a * F.relu(b)
def geglu(x):
a, b = x.chunk(2, dim=-1)
return a * F.gelu(b)
class Swish(nn.Module):
def forward(self, x : torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
ACTIVATION_CREATORS = {
'relu' : lambda dim: nn.ReLU(inplace=True),
'elu' : lambda dim: nn.ELU(),
'leakyrelu' : lambda dim: nn.LeakyReLU(negative_slope=0.1, inplace=True),
'tanh' : lambda dim: nn.Tanh(),
'sigmoid' : lambda dim: nn.Sigmoid(),
'identity' : lambda dim: nn.Identity(),
'prelu' : lambda dim: nn.PReLU(dim),
'gelu' : lambda dim: nn.GELU(),
'reglu' : reglu,
'geglu' : geglu,
'swish' : lambda dim: Swish(),
}
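# Usage sketch (illustrative, not part of the original API): each entry is a factory
# taking the feature dimension of the layer it follows, e.g.
#
#     act = ACTIVATION_CREATORS['prelu'](16)   # nn.PReLU with 16 learnable parameters
#     y = act(torch.randn(4, 16))
#
# Note that 'reglu' and 'geglu' map to the plain functions above, so they are applied
# to tensors directly rather than constructed with a dimension.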
# -------------------------------- Layers -------------------------------- #
class Value_Net(nn.Module):
r"""
Value network that maps the input state (a tensor or a dict of tensors) to a scalar value.
Args:
input_dim_dict (Union[int, dict]): The total input dimension, or a dict mapping input node names to their dimensions.
value_hidden_features (int): The number of hidden features in each layer of the value MLP.
value_hidden_layers (int): The number of hidden layers in the value MLP.
value_normalization (str): The normalization method used between hidden layers.
value_activation (str): The activation function used in the hidden layers.
"""
def __init__(self,
input_dim_dict,
value_hidden_features,
value_hidden_layers,
value_normalization,
value_activation,
*args, **kwargs) -> None:
super().__init__()
self.kwargs = kwargs
self.input_dim_dict = input_dim_dict
if isinstance(input_dim_dict, dict):
input_dim = sum(input_dim_dict.values())
else:
input_dim = input_dim_dict
self.value = MLP(input_dim, 1, value_hidden_features, value_hidden_layers,
norm=value_normalization, hidden_activation=value_activation)
if 'ts_conv_config' in self.kwargs.keys() and self.kwargs['ts_conv_config'] != None:
other_state_layer_output = input_dim//2
ts_state_conv_layer_output = input_dim//2 + input_dim%2
self.other_state_layer = MLP(self.kwargs['ts_conv_net_config']['other_net_input'], other_state_layer_output, 0, 0, output_activation=value_activation)
all_node_ts = None
self.conv_ts_node = []
self.conv_other_node = deepcopy(self.kwargs['ts_conv_config']['no_ts_input'])
for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
self.conv_ts_node.append('__history__'+tsnode)
self.conv_other_node.append('__now__'+tsnode)
node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
if all_node_ts == None:
all_node_ts = node_ts
assert all_node_ts==node_ts, f"expect ts step == {all_node_ts}. However got {tsnode} with ts step {node_ts}"
# kernel size = ts - 1: the current ("now") step is excluded from the convolution window
kernel_size = all_node_ts - 1
self.ts_state_conv_layer = ConvBlock(self.kwargs['ts_conv_net_config']['conv_net_input'], ts_state_conv_layer_output, kernel_size, output_activation=value_activation)
def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
if hasattr(self, 'conv_ts_node'):
# handle all nodes that carry a time-step dimension
assert hasattr(self, 'conv_ts_node')
assert hasattr(self, 'conv_other_node')
inputdata = deepcopy(state.detach())  # data is detached from the other nodes
for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
# assert self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['endpoint'] == True
batch_size = np.prod(list(inputdata[tsnode].shape[:-1]))
original_size = list(inputdata[tsnode].shape[:-1])
node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
# inputdata[tsnode]
temp_data = inputdata[tsnode].reshape(batch_size, node_ts, -1)
inputdata['__history__'+tsnode] = temp_data[..., :-1, :]
inputdata['__now__'+tsnode] = temp_data[..., -1, :].reshape([*original_size,-1])
state_other = torch.cat([inputdata[key] for key in self.conv_other_node], dim=-1)
state_ts_conv = torch.cat([inputdata[key] for key in self.conv_ts_node], dim=-1)
state_other = self.other_state_layer(state_other)
original_size = list(state_other.shape[:-1])
state_ts_conv = self.ts_state_conv_layer(state_ts_conv).reshape([*original_size,-1])
state = torch.cat([state_other, state_ts_conv], dim=-1)
else:
state = torch.cat([state[key].detach() for key in self.input_dim_dict.keys()], dim=-1)
output = self.value(state)
return output
class ConvBlock(nn.Module):
r"""
Convolutional block: a linear projection followed by a depthwise 1D convolution over the time dimension.
Args:
in_features (int): The number of input features.
out_features (int): The number of output features.
conv_time_step (int): The kernel size of the 1D convolution, i.e. the number of time steps it covers.
output_activation (str, optional): The activation applied after the linear projection. Default is 'identity'.
"""
def __init__(self,
in_features: int,
out_features: int,
conv_time_step: int,
output_activation : str = 'identity',
):
super().__init__()
output_activation_creator = ACTIVATION_CREATORS[output_activation]
# self.in_features = in_features
# self.out_features = out_features
self.mlp_net = nn.Sequential(
nn.Linear(in_features, out_features),
output_activation_creator(out_features)
)
self.conv_net = nn.Conv1d(in_channels=out_features, out_channels=out_features, groups=out_features, kernel_size=conv_time_step)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.mlp_net(x)
# swap the time-step dim with the feature dim so Conv1d convolves over time
x = x.transpose(-1,-2)
# depthwise convolution over the full time window, then drop the singleton time dim
out = self.conv_net(x).squeeze(-1)
return out
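# Shape sketch (assumed from the forward pass above): the time dimension of the input
# must match `conv_time_step`, since the depthwise convolution consumes the whole window.
#
#     block = ConvBlock(in_features=6, out_features=5, conv_time_step=4)
#     out = block(torch.randn(8, 4, 6))   # [batch, time, features] -> [8, 5]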
class VectorizedLinear(nn.Module):
r"""
Initializes a vectorized linear layer instance.
Args:
in_features (int): The number of input features.
out_features (int): The number of output features.
ensemble_size (int): The number of ensembles to use.
"""
def __init__(self, in_features: int, out_features: int, ensemble_size: int):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.ensemble_size = ensemble_size
self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features))
self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))
self.reset_parameters()
def reset_parameters(self):
# default pytorch init for nn.Linear module
for layer in range(self.ensemble_size):
nn.init.kaiming_uniform_(self.weight[layer], a=math.sqrt(5))
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
input: [ensemble_size, batch_size, input_size]
weight: [ensemble_size, input_size, out_size]
out: [ensemble_size, batch_size, out_size]
"""
# out = torch.einsum("kij,kjl->kil", x, self.weight) + self.bias
out = x @ self.weight + self.bias
# out = torch.bmm(x, self.weight) + self.bias
return out
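# Usage sketch (illustrative): one independent linear layer per ensemble member,
# applied as a single batched matmul.
#
#     layer = VectorizedLinear(in_features=10, out_features=3, ensemble_size=5)
#     out = layer(torch.randn(5, 32, 10))   # -> [5, 32, 3]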
class EnsembleLinear(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
num_ensemble: int,
weight_decay: float = 0.0
) -> None:
super().__init__()
self.num_ensemble = num_ensemble
self.register_parameter("weight", nn.Parameter(torch.zeros(num_ensemble, input_dim, output_dim)))
self.register_parameter("bias", nn.Parameter(torch.zeros(num_ensemble, 1, output_dim)))
nn.init.trunc_normal_(self.weight, std=1 / (2 * input_dim ** 0.5))
self.register_parameter("saved_weight", nn.Parameter(self.weight.detach().clone()))
self.register_parameter("saved_bias", nn.Parameter(self.bias.detach().clone()))
self.weight_decay = weight_decay
self.device = torch.device('cpu')
def forward(self, x: torch.Tensor) -> torch.Tensor:
weight = self.weight
bias = self.bias
if len(x.shape) == 2:
x = torch.einsum('ij,bjk->bik', x, weight)
elif len(x.shape) == 3:
if x.shape[0] == weight.data.shape[0]:
x = torch.einsum('bij,bjk->bik', x, weight)
else:
x = torch.einsum('cij,bjk->bcik', x, weight)
elif len(x.shape) == 4:
if x.shape[0] == weight.data.shape[0]:
x = torch.einsum('cbij,cjk->cbik', x, weight)
else:
x = torch.einsum('cdij,bjk->bcdik', x, weight)
elif len(x.shape) == 5:
x = torch.einsum('bcdij,bjk->bcdik', x, weight)
assert x.shape[0] == bias.shape[0] and x.shape[-1] == bias.shape[-1]
if len(x.shape) == 4:
bias = bias.unsqueeze(1)
elif len(x.shape) == 5:
bias = bias.unsqueeze(1)
bias = bias.unsqueeze(1)
x = x + bias
return x
def load_save(self) -> None:
self.weight.data.copy_(self.saved_weight.data)
self.bias.data.copy_(self.saved_bias.data)
def update_save(self, indexes: List[int]) -> None:
self.saved_weight.data[indexes] = self.weight.data[indexes]
self.saved_bias.data[indexes] = self.bias.data[indexes]
def get_decay_loss(self) -> torch.Tensor:
decay_loss = self.weight_decay * (0.5 * ((self.weight ** 2).sum()))
return decay_loss
def to(self, device):
if not device == self.device:
self.device = device
super().to(device)
self.weight = self.weight.to(self.device)
self.bias = self.bias.to(self.device)
self.saved_weight = self.saved_weight.to(self.device)
self.saved_bias = self.saved_bias.to(self.device)
# -------------------------------- Backbones -------------------------------- #
class MLP(nn.Module):
r"""
Multi-layer Perceptron
Args:
in_features : int, features numbers of the input
out_features : int, features numbers of the output
hidden_features : int, features numbers of the hidden layers
hidden_layers : int, numbers of the hidden layers
norm : str, normalization method between hidden layers, default : None
hidden_activation : str, activation function used in hidden layers, default : 'leakyrelu'
output_activation : str, activation function used in output layer, default : 'identity'
"""
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
norm : str = None,
hidden_activation : str = 'leakyrelu',
output_activation : str = 'identity'):
super(MLP, self).__init__()
hidden_activation_creator = ACTIVATION_CREATORS[hidden_activation]
output_activation_creator = ACTIVATION_CREATORS[output_activation]
if hidden_layers == 0:
self.net = nn.Sequential(
nn.Linear(in_features, out_features),
output_activation_creator(out_features)
)
else:
net = []
for i in range(hidden_layers):
net.append(nn.Linear(in_features if i == 0 else hidden_features, hidden_features))
if norm:
if norm == 'ln':
net.append(nn.LayerNorm(hidden_features))
elif norm == 'bn':
net.append(nn.BatchNorm1d(hidden_features))
else:
raise NotImplementedError(f'{norm} is not supported!')
net.append(hidden_activation_creator(hidden_features))
net.append(nn.Linear(hidden_features, out_features))
net.append(output_activation_creator(out_features))
self.net = nn.Sequential(*net)
def forward(self, x : torch.Tensor) -> torch.Tensor:
r""" forward method of MLP only assume the last dim of x matches `in_features` """
return self.net(x)
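# Usage sketch (illustrative): a two-hidden-layer MLP with layer normalization.
#
#     mlp = MLP(in_features=10, out_features=2, hidden_features=64, hidden_layers=2, norm='ln')
#     y = mlp(torch.randn(32, 10))   # -> [32, 2]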
class VectorizedMLP(nn.Module):
r"""
Vectorized MLP
Args:
in_features : int, features numbers of the input
out_features : int, features numbers of the output
hidden_features : int, features numbers of the hidden layers
hidden_layers : int, numbers of the hidden layers
ensemble_size : int, the number of ensemble members
norm : str, normalization method between hidden layers, default : None
hidden_activation : str, activation function used in hidden layers, default : 'leakyrelu'
output_activation : str, activation function used in output layer, default : 'identity'
"""
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
ensemble_size : int,
norm : str = None,
hidden_activation : str = 'leakyrelu',
output_activation : str = 'identity'):
super(VectorizedMLP, self).__init__()
self.ensemble_size = ensemble_size
hidden_activation_creator = ACTIVATION_CREATORS[hidden_activation]
output_activation_creator = ACTIVATION_CREATORS[output_activation]
if hidden_layers == 0:
self.net = nn.Sequential(
VectorizedLinear(in_features, out_features, ensemble_size),
output_activation_creator(out_features)
)
else:
net = []
for i in range(hidden_layers):
net.append(VectorizedLinear(in_features if i == 0 else hidden_features, hidden_features, ensemble_size))
if norm:
if norm == 'ln':
net.append(nn.LayerNorm(hidden_features))
elif norm == 'bn':
net.append(nn.BatchNorm1d(hidden_features))
else:
raise NotImplementedError(f'{norm} is not supported!')
net.append(hidden_activation_creator(hidden_features))
net.append(VectorizedLinear(hidden_features, out_features, ensemble_size))
net.append(output_activation_creator(out_features))
self.net = nn.Sequential(*net)
def forward(self, x : torch.Tensor) -> torch.Tensor:
r""" forward method of MLP only assume the last dim of x matches `in_features` """
if x.dim() == 2: # [batch_size, dim]
x = x.unsqueeze(0).repeat_interleave(self.ensemble_size, dim=0) # [ensemble_size, batch_size, x_dim]
assert x.dim() == 3
assert x.shape[0] == self.ensemble_size
return self.net(x)
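# Usage sketch (illustrative): a 2D input is broadcast to every ensemble member.
#
#     vmlp = VectorizedMLP(in_features=10, out_features=2, hidden_features=64,
#                          hidden_layers=2, ensemble_size=5)
#     y = vmlp(torch.randn(32, 10))   # -> [5, 32, 2]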
class ResBlock(nn.Module):
"""
Initializes a residual block instance.
Args:
input_feature (int): The number of input features to the block.
output_feature (int): The number of output features from the block.
norm (str, optional): The type of normalization to apply to the block. Default is 'ln' for layer normalization.
"""
def __init__(self, input_feature : int, output_feature : int, norm : str = 'ln'):
super().__init__()
if norm == 'ln':
norm_class = torch.nn.LayerNorm
self.process_net = torch.nn.Sequential(
torch.nn.Linear(input_feature, output_feature),
norm_class(output_feature),
torch.nn.ReLU(True),
torch.nn.Linear(output_feature, output_feature),
norm_class(output_feature),
torch.nn.ReLU(True)
)
else:
self.process_net = torch.nn.Sequential(
torch.nn.Linear(input_feature, output_feature),
torch.nn.ReLU(True),
torch.nn.Linear(output_feature, output_feature),
torch.nn.ReLU(True)
)
if not input_feature == output_feature:
self.skip_net = torch.nn.Linear(input_feature, output_feature)
else:
self.skip_net = torch.nn.Identity()
def forward(self, x : torch.Tensor) -> torch.Tensor:
'''x is expected to be a 2D tensor [batch_size, features]'''
return self.process_net(x) + self.skip_net(x)
class VectorizedResBlock(nn.Module):
r"""
Initializes a vectorized residual block instance.
Args:
in_features (int): The number of input features.
out_features (int): The number of output features.
ensemble_size (int): The number of ensemble members.
norm (str, optional): The normalization method ('ln' for layer normalization, otherwise none). Default is 'ln'.
"""
def __init__(self,
in_features: int,
out_features: int,
ensemble_size: int,
norm : str = 'ln'):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.ensemble_size = ensemble_size
if norm == 'ln':
norm_class = torch.nn.LayerNorm
self.process_net = torch.nn.Sequential(
VectorizedLinear(in_features, out_features, ensemble_size),
norm_class(out_features),
torch.nn.ReLU(True),
VectorizedLinear(out_features, out_features, ensemble_size),
norm_class(out_features),
torch.nn.ReLU(True)
)
else:
self.process_net = torch.nn.Sequential(
VectorizedLinear(in_features, out_features, ensemble_size),
torch.nn.ReLU(True),
VectorizedLinear(out_features, out_features, ensemble_size),
torch.nn.ReLU(True)
)
if not in_features == out_features:
self.skip_net = VectorizedLinear(in_features, out_features, ensemble_size)
else:
self.skip_net = torch.nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
# input: [ensemble_size, batch_size, input_size]
# out: [ensemble_size, batch_size, out_size]
out = self.process_net(x) + self.skip_net(x)
return out
class ResNet(torch.nn.Module):
"""
Initializes a residual neural network (ResNet) instance.
Args:
in_features (int): The number of input features to the ResNet.
out_features (int): The number of output features from the ResNet.
hidden_features (int): The number of hidden features in each residual block.
hidden_layers (int): The number of residual blocks in the ResNet.
norm (str, optional): The type of normalization to apply to the ResNet. Default is 'bn' for batch normalization.
output_activation (str, optional): The type of activation function to apply to the output of the ResNet. Default is 'identity'.
"""
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
norm : str = 'bn',
output_activation : str = 'identity'):
super().__init__()
modules = []
for i in range(hidden_layers):
if i == 0:
modules.append(ResBlock(in_features, hidden_features, norm))
else:
modules.append(ResBlock(hidden_features, hidden_features, norm))
modules.append(torch.nn.Linear(hidden_features, out_features))
modules.append(ACTIVATION_CREATORS[output_activation](out_features))
self.resnet = torch.nn.Sequential(*modules)
def forward(self, x : torch.Tensor) -> torch.Tensor:
'''NOTE: reshape is needed since ResBlock only supports 2D tensors'''
shape = x.shape
x = x.view(-1, shape[-1])
output = self.resnet(x)
output = output.view(*shape[:-1], -1)
return output
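# Usage sketch (illustrative; 'ln' is used here because ResBlock only inserts
# normalization layers when norm == 'ln'):
#
#     net = ResNet(in_features=10, out_features=2, hidden_features=64, hidden_layers=3, norm='ln')
#     y = net(torch.randn(7, 16, 10))   # leading dims are preserved -> [7, 16, 2]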
class VectorizedResNet(torch.nn.Module):
"""
Initializes a residual neural network (ResNet) instance.
Args:
in_features (int): The number of input features to the ResNet.
out_features (int): The number of output features from the ResNet.
hidden_features (int): The number of hidden features in each residual block.
hidden_layers (int): The number of residual blocks in the ResNet.
norm (str, optional): The type of normalization to apply to the ResNet. Default is 'bn' for batch normalization.
output_activation (str, optional): The type of activation function to apply to the output of the ResNet. Default is 'identity'.
"""
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
ensemble_size : int,
norm : str = 'bn',
output_activation : str = 'identity'):
super().__init__()
self.ensemble_size = ensemble_size
modules = []
for i in range(hidden_layers):
if i == 0:
modules.append(VectorizedResBlock(in_features, hidden_features, ensemble_size, norm))
else:
modules.append(VectorizedResBlock(hidden_features, hidden_features, ensemble_size, norm))
modules.append(VectorizedLinear(hidden_features, out_features, ensemble_size))
modules.append(ACTIVATION_CREATORS[output_activation](out_features))
self.resnet = torch.nn.Sequential(*modules)
[docs]
def forward(self, x : torch.Tensor) -> torch.Tensor:
'''NOTE: a 2D input is expanded to [ensemble_size, batch_size, dim] before the vectorized residual blocks'''
shape = x.shape
if len(shape) == 2:
x = x.unsqueeze(0).repeat_interleave(self.ensemble_size, dim=0) # [batch_size, dim] --> [ensemble_size, batch_size, dim]
assert x.dim() == 3, f"shape == {shape}, with length == {len(shape)}, expect length 3"
assert x.shape[0] == self.ensemble_size # [ensemble_size, batch_size, dim]
output = self.resnet(x)
return output
class TAPE(nn.Module):
def __init__(self, d_model, max_len=200, scale_factor=1.0):
super(TAPE, self).__init__()
pe = torch.zeros(max_len, d_model) # positional encoding
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin((position * div_term)*(d_model/max_len))
pe[:, 1::2] = torch.cos((position * div_term)*(d_model/max_len))
pe = scale_factor * pe
self.register_buffer('pe', pe) # this stores the variable in the state_dict (used for non-trainable variables)
def forward(self, x):
return x + self.pe[:x.shape[1], :]
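# Usage sketch (illustrative): adds a scaled sinusoidal positional encoding over the
# sequence dimension, assuming inputs of shape [batch, seq_len, d_model].
#
#     tape = TAPE(d_model=64, max_len=200)
#     y = tape(torch.randn(8, 50, 64))   # -> [8, 50, 64]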
class TAttention(nn.Module):
def __init__(self, d_model, nhead, dropout):
super().__init__()
self.d_model = d_model
self.nhead = nhead
self.qtrans = nn.Linear(d_model, d_model, bias=False)
self.ktrans = nn.Linear(d_model, d_model, bias=False)
self.vtrans = nn.Linear(d_model, d_model, bias=False)
self.attn_dropout = []
if dropout > 0:
for i in range(nhead):
self.attn_dropout.append(nn.Dropout(p=dropout))
self.attn_dropout = nn.ModuleList(self.attn_dropout)
# input LayerNorm
self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
# FFN layerNorm
self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
# FFN
self.ffn = nn.Sequential(
nn.Linear(d_model, d_model),
nn.ReLU(),
nn.Dropout(p=dropout),
nn.Linear(d_model, d_model),
nn.Dropout(p=dropout)
)
def forward(self, x):
x = self.norm1(x)
q = self.qtrans(x)
k = self.ktrans(x)
v = self.vtrans(x)
dim = int(self.d_model / self.nhead)
att_output = []
for i in range(self.nhead):
if i==self.nhead-1:
qh = q[:, :, i * dim:]
kh = k[:, :, i * dim:]
vh = v[:, :, i * dim:]
else:
qh = q[:, :, i * dim:(i + 1) * dim]
kh = k[:, :, i * dim:(i + 1) * dim]
vh = v[:, :, i * dim:(i + 1) * dim]
atten_ave_matrixh = torch.softmax(torch.matmul(qh, kh.transpose(1, 2)), dim=-1)
if self.attn_dropout:
atten_ave_matrixh = self.attn_dropout[i](atten_ave_matrixh)
att_output.append(torch.matmul(atten_ave_matrixh, vh))
att_output = torch.concat(att_output, dim=-1)
# FFN
xt = x + att_output
xt = self.norm2(xt)
att_output = xt + self.ffn(xt)
return att_output
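# Usage sketch (illustrative): a temporal multi-head attention block with a residual
# feed-forward stage, for inputs of shape [batch, seq_len, d_model].
#
#     attn = TAttention(d_model=64, nhead=4, dropout=0.1)
#     y = attn(torch.randn(8, 50, 64))   # -> [8, 50, 64]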
class Tokenizer(nn.Module):
category_offsets: Optional[torch.Tensor]
def __init__(
self,
d_numerical: int,
categories: Optional[List[int]],
d_token: int,
bias: bool,
) -> None:
super().__init__()
if categories is None:
d_bias = d_numerical
self.category_offsets = None
self.category_embeddings = None
else:
d_bias = d_numerical + len(categories)
category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0)
self.register_buffer('category_offsets', category_offsets)
self.category_embeddings = nn.Embedding(sum(categories), d_token)
nn.init.kaiming_uniform_(self.category_embeddings.weight, a=math.sqrt(5))
logger.debug(f'{self.category_embeddings.weight.shape=}')
# take [CLS] token into account
self.weight = nn.Parameter(torch.Tensor(d_numerical + 1, d_token))
self.bias = nn.Parameter(torch.Tensor(d_bias, d_token)) if bias else None
# The initialization is inspired by nn.Linear
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
nn.init.kaiming_uniform_(self.bias, a=math.sqrt(5))
@property
def n_tokens(self) -> int:
return len(self.weight) + (
0 if self.category_offsets is None else len(self.category_offsets)
)
def forward(self, x_num: torch.Tensor, x_cat: Optional[torch.Tensor]):
x_some = x_num if x_cat is None else x_cat
assert x_some is not None
x_num = torch.cat(
[torch.ones(len(x_some), 1, device=x_some.device)] # [CLS]
+ ([] if x_num is None else [x_num]),
dim=1,
)
x = self.weight[None] * x_num[:, :, None]
if x_cat is not None:
x = torch.cat(
[x, self.category_embeddings(x_cat + self.category_offsets[None])],
dim=1,
)
if self.bias is not None:
bias = torch.cat(
[
torch.zeros(1, self.bias.shape[1], device=x.device),
self.bias,
]
)
x = x + bias[None]
return x
class MultiheadAttention(nn.Module):
def __init__(
self, d: int, n_heads: int, dropout: float, initialization: str
) -> None:
if n_heads > 1:
assert d % n_heads == 0
assert initialization in ['xavier', 'kaiming']
super().__init__()
self.W_q = nn.Linear(d, d)
self.W_k = nn.Linear(d, d)
self.W_v = nn.Linear(d, d)
self.W_out = nn.Linear(d, d) if n_heads > 1 else None
self.n_heads = n_heads
self.dropout = nn.Dropout(dropout) if dropout else None
for m in [self.W_q, self.W_k, self.W_v]:
if initialization == 'xavier' and (n_heads > 1 or m is not self.W_v):
# gain is needed since W_qkv is represented with 3 separate layers
nn.init.xavier_uniform_(m.weight, gain=1 / math.sqrt(2))
nn.init.zeros_(m.bias)
if self.W_out is not None:
nn.init.zeros_(self.W_out.bias)
def _reshape(self, x):
batch_size, n_tokens, d = x.shape
d_head = d // self.n_heads
return (
x.reshape(batch_size, n_tokens, self.n_heads, d_head)
.transpose(1, 2)
.reshape(batch_size * self.n_heads, n_tokens, d_head)
)
def forward(
self,
x_q,
x_kv,
key_compression: Optional[nn.Linear],
value_compression: Optional[nn.Linear],
):
q, k, v = self.W_q(x_q), self.W_k(x_kv), self.W_v(x_kv)
for tensor in [q, k, v]:
assert tensor.shape[-1] % self.n_heads == 0
if key_compression is not None:
assert value_compression is not None
k = key_compression(k.transpose(1, 2)).transpose(1, 2)
v = value_compression(v.transpose(1, 2)).transpose(1, 2)
else:
assert value_compression is None
batch_size = len(q)
d_head_key = k.shape[-1] // self.n_heads
d_head_value = v.shape[-1] // self.n_heads
n_q_tokens = q.shape[1]
q = self._reshape(q)
k = self._reshape(k)
attention = F.softmax(q @ k.transpose(1, 2) / math.sqrt(d_head_key), dim=-1)
if self.dropout is not None:
attention = self.dropout(attention)
x = attention @ self._reshape(v)
x = (
x.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value)
.transpose(1, 2)
.reshape(batch_size, n_q_tokens, self.n_heads * d_head_value)
)
if self.W_out is not None:
x = self.W_out(x)
return x
class DistributionWrapper(nn.Module):
r"""wrap output of Module to distribution"""
BASE_TYPES = ['normal', 'gmm', 'onehot', 'discrete_logistic']
SUPPORTED_TYPES = BASE_TYPES + ['mix']
def __init__(self, distribution_type : str = 'normal', **params):
super().__init__()
self.distribution_type = distribution_type
self.params = params
assert self.distribution_type in self.SUPPORTED_TYPES, f"{self.distribution_type} is not supported!"
if self.distribution_type == 'normal':
self.max_logstd = nn.Parameter(torch.ones(self.params['dim']) * 0.0, requires_grad=True)
self.min_logstd = nn.Parameter(torch.ones(self.params['dim']) * -7, requires_grad=True)
if not self.params.get('conditioned_std', True):
self.logstd = nn.Parameter(torch.zeros(self.params['dim']), requires_grad=True)
elif self.distribution_type == 'gmm':
self.max_logstd = nn.Parameter(torch.ones(self.params['mixture'], self.params['dim']) * 0, requires_grad=True)
self.min_logstd = nn.Parameter(torch.ones(self.params['mixture'], self.params['dim']) * -10, requires_grad=True)
if not self.params.get('conditioned_std', True):
self.logstd = nn.Parameter(torch.zeros(self.params['mixture'], self.params['dim']), requires_grad=True)
elif self.distribution_type == 'discrete_logistic':
self.num = self.params['num']
elif self.distribution_type == 'mix':
assert 'dist_config' in self.params.keys(), "You need to provide `dist_config` for Mix distribution"
self.dist_config = self.params['dist_config']
self.wrapper_list = []
self.input_sizes = []
self.output_sizes = []
for config in self.dist_config:
dist_type = config['type']
assert dist_type in self.SUPPORTED_TYPES, f"{dist_type} is not supported!"
assert not dist_type == 'mix', "recursive MixDistribution is not supported!"
self.wrapper_list.append(DistributionWrapper(dist_type, **config))
self.input_sizes.append(config['dim'])
self.output_sizes.append(config['output_dim'])
self.wrapper_list = nn.ModuleList(self.wrapper_list)
def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, payload : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
'''
Wrap the given tensor into a distribution.
:param adapt_std : if provided, it overwrites the std part of the distribution (optional)
:param payload : if provided, it is applied to the output distribution after it is built (optional)
'''
# [ OTHER ] replace the clip with tanh
# [ OTHER ] TODO: a better std-controlling strategy is needed
if self.distribution_type == 'normal':
if self.params.get('conditioned_std', True):
mu, logstd = torch.chunk(x, 2, dim=-1)
if 'soft_clamp' in kwargs and kwargs['soft_clamp'] == True:
# distribution wrapper only for revive_f venv training
logstd = soft_clamp(logstd, self.min_logstd, self.max_logstd)
std = torch.exp(logstd)
else:
# distribution wrapper for others
logstd_logit = logstd
max_std = 0.5
min_std = 0.001
std = (torch.tanh(logstd_logit) + 1) / 2 * (max_std - min_std) + min_std
else:
mu, logstd = x, self.logstd
logstd_logit = self.logstd
max_std = 0.5
min_std = 0.001
std = (torch.tanh(logstd_logit) + 1) / 2 * (max_std - min_std) + min_std
# std = torch.exp(logstd)
if payload is not None:
mu = mu + safe_atanh(payload)
mu = torch.tanh(mu)
if adapt_std is not None:
std = torch.ones_like(mu).to(mu) * adapt_std
return DiagnalNormal(mu, std)
elif self.distribution_type == 'gmm':
if self.params.get('conditioned_std', True):
logits, mus, logstds = torch.split(x, [self.params['mixture'],
self.params['mixture'] * self.params['dim'],
self.params['mixture'] * self.params['dim']], dim=-1)
mus = mus.view(*mus.shape[:-1], self.params['mixture'], self.params['dim'])
logstds = logstds.view(*logstds.shape[:-1], self.params['mixture'], self.params['dim'])
else:
logits, mus = torch.split(x, [self.params['mixture'], self.params['mixture'] * self.params['dim']], dim=-1)
logstds = self.logstd
if payload is not None:
mus = mus + safe_atanh(payload.unsqueeze(dim=-2))
mus = torch.tanh(mus)
stds = adapt_std if adapt_std is not None else torch.exp(soft_clamp(logstds, self.min_logstd, self.max_logstd))
return GaussianMixture(mus, stds, logits)
elif self.distribution_type == 'onehot':
return Onehot(x)
elif self.distribution_type == 'discrete_logistic':
mu, logstd = torch.chunk(x, 2, dim=-1)
logstd = torch.clamp(logstd, -7, 1)
return DiscreteLogistic(mu, torch.exp(logstd), num=self.num)
elif self.distribution_type == 'mix':
xs = torch.split(x, self.output_sizes, dim=-1)
if isinstance(adapt_std, float):
adapt_stds = [adapt_std] * len(self.input_sizes)
elif adapt_std is not None:
adapt_stds = torch.split(adapt_std, self.input_sizes, dim=-1)
else:
adapt_stds = [None] * len(self.input_sizes)
if payload is not None:
payloads = torch.split(payload, self.input_sizes + [payload.shape[-1] - sum(self.input_sizes)], dim=-1)[:-1]
else:
payloads = [None] * len(self.input_sizes)
dists = [wrapper(x, _adapt_std, _payload, **kwargs) for x, _adapt_std, _payload, wrapper in zip(xs, adapt_stds, payloads, self.wrapper_list)]
return MixDistribution(dists)
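# Usage sketch (assumed dist_config layout, following how this class reads it): each
# entry needs 'type', 'dim' and 'output_dim'; for a conditioned-std 'normal' head the
# network output carries mean and log-std, so output_dim == 2 * dim.
#
#     dist_config = [{'type': 'normal', 'dim': 3, 'output_dim': 6}]
#     wrapper = DistributionWrapper('mix', dist_config=dist_config)
#     dist = wrapper(torch.randn(32, 6))   # MixDistribution over a 3-dim DiagnalNormal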
# --------------------------------- Policies -------------------------------- #
class FeedForwardPolicy(torch.nn.Module):
"""Policy network supporting MLP, ResNet and FT-Transformer backbones.
Args:
in_features : The number of input features, or a dictionary describing the input distribution
out_features : The number of output features
hidden_features : The number of hidden features in each layer
hidden_layers : The number of hidden layers
dist_config : A list of configurations for the distributions to use in the model
norm : The type of normalization to apply to the input features
hidden_activation : The activation function to use in the hidden layers
backbone_type : The type of backbone to use in the model
use_multihead : Whether to use a multihead model
use_feature_embed : Whether to use feature embedding
"""
def __init__(self,
in_features : Union[int, dict],
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list,
norm : str = None,
hidden_activation : str = 'leakyrelu',
backbone_type : Union[str, np.str_] = 'mlp',
use_multihead : bool = False,
**kwargs):
super().__init__()
self.multihead = []
self.dist_config = dist_config
self.kwargs = kwargs
if isinstance(in_features, dict):
in_features = sum(in_features.values())
if not use_multihead:
if backbone_type == 'mlp':
self.backbone = MLP(in_features, out_features, hidden_features, hidden_layers, norm, hidden_activation)
elif backbone_type == 'res':
self.backbone = ResNet(in_features, out_features, hidden_features, hidden_layers, norm)
elif backbone_type == 'ft_transformer':
self.backbone = FT_Transformer(in_features, out_features, hidden_features, hidden_layers=hidden_layers)
else:
raise NotImplementedError(f'backbone type {backbone_type} is not supported')
else:
if backbone_type == 'mlp':
self.backbone = MLP(in_features, hidden_features, hidden_features, hidden_layers, norm, hidden_activation)
elif backbone_type == 'res':
self.backbone = ResNet(in_features, hidden_features, hidden_features, hidden_layers, norm)
elif backbone_type == 'ft_transformer':
self.backbone = FT_Transformer(in_features, out_features, hidden_features, hidden_layers=hidden_layers)
else:
raise NotImplementedError(f'backbone type {backbone_type} is not supported')
for dist in self.dist_config:
dist_type = dist["type"]
output_dim = dist["output_dim"]
if dist_type == "onehot" or dist_type == "discrete_logistic":
self.multihead.append(MLP(hidden_features, output_dim, 64, 1, norm=None, hidden_activation='leakyrelu'))
elif dist_type == "normal":
normal_dim = dist["dim"]
for dim in range(normal_dim):
self.multihead.append(MLP(hidden_features, int(output_dim // normal_dim), 64, 1, norm=None, hidden_activation='leakyrelu'))
else:
raise NotImplementedError(f'Dist type {dist_type} is not supported in multihead.')
self.multihead = nn.ModuleList(self.multihead)
self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
if "ts_conv_config" in self.kwargs and self.kwargs['ts_conv_config'] != None:
other_state_layer_output = in_features//2
ts_state_conv_layer_output = in_features//2 + in_features%2
self.other_state_layer = MLP(self.kwargs['ts_conv_net_config']['other_net_input'], other_state_layer_output, 0, 0, output_activation=hidden_activation)
all_node_ts = None
self.conv_ts_node = []
self.conv_other_node = deepcopy(self.kwargs['ts_conv_config']['no_ts_input'])
for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
self.conv_ts_node.append('__history__'+tsnode)
self.conv_other_node.append('__now__'+tsnode)
node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
if all_node_ts == None:
all_node_ts = node_ts
assert all_node_ts==node_ts, f"expect ts step == {all_node_ts}. However got {tsnode} with ts step {node_ts}"
# kernel size = ts - 1: the current ("now") step is excluded from the convolution window
kernel_size = all_node_ts - 1
self.ts_state_conv_layer = ConvBlock(self.kwargs['ts_conv_net_config']['conv_net_input'], ts_state_conv_layer_output, kernel_size, output_activation=hidden_activation)
def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
if "ts_conv_config" in self.kwargs and self.kwargs['ts_conv_config'] != None:
# handle all nodes that carry a time-step dimension
# assert 'input_names' in kwargs
assert hasattr(self, 'conv_ts_node')
assert hasattr(self, 'conv_other_node')
inputdata = deepcopy(state)  # data is detached from the other nodes
for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
# assert self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['endpoint'] == True
batch_size = int(np.prod(list(inputdata[tsnode].shape[:-1])))
original_size = list(inputdata[tsnode].shape[:-1])
node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
# inputdata[tsnode]
temp_data = inputdata[tsnode].reshape(batch_size, node_ts, -1)
inputdata['__history__'+tsnode] = temp_data[..., :-1, :]
inputdata['__now__'+tsnode] = temp_data[..., -1, :].reshape([*original_size,-1])
state_other = torch.cat([inputdata[key] for key in self.conv_other_node], dim=-1)
state_ts_conv = torch.cat([inputdata[key] for key in self.conv_ts_node], dim=-1)
state_other = self.other_state_layer(state_other)
original_size = list(state_other.shape[:-1])
state_ts_conv = self.ts_state_conv_layer(state_ts_conv).reshape([*original_size,-1])
state = torch.cat([state_other, state_ts_conv], dim=-1)
elif isinstance(state, dict):
assert 'input_names' in kwargs
state = torch.cat([state[key] for key in kwargs['input_names']], dim=-1)
if not self.multihead:
output = self.backbone(state)
else:
backbone_output = self.backbone(state)
multihead_output = []
multihead_index = 0
for dist in self.dist_config:
dist_type = dist["type"]
output_dim = dist["output_dim"]
if dist_type == "onehot" or dist_type == "discrete_logistic":
multihead_output.append(self.multihead[multihead_index](backbone_output))
multihead_index += 1
elif dist_type == "normal":
normal_mode_output = []
normal_std_output = []
for head in self.multihead[multihead_index:]:
head_output = head(backbone_output)
if head_output.shape[-1] == 1:
mode = head_output
normal_mode_output.append(mode)
else:
mode, std = torch.chunk(head_output, 2, axis=-1)
normal_mode_output.append(mode)
normal_std_output.append(std)
normal_output = torch.cat(normal_mode_output, axis=-1)
if normal_std_output:
normal_std_output = torch.cat(normal_std_output, axis=-1)
normal_output = torch.cat([normal_output, normal_std_output], axis=-1)
multihead_output.append(normal_output)
break
else:
raise NotImplementedError(f'Dist type {dist_type} is not supported in multihead.')
output = torch.cat(multihead_output, axis= -1)
soft_clamp_flag = True if "soft_clamp" in self.kwargs and self.kwargs['soft_clamp'] == True else False
dist = self.dist_wrapper(output, adapt_std, soft_clamp=soft_clamp_flag, **kwargs)
if hasattr(self, "dist_mu_shift"):
dist = dist.shift(self.dist_mu_shift)
return dist
@torch.no_grad()
def get_action(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], deterministic : bool = True):
dist = self(state)
return dist.mode if deterministic else dist.sample()
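# Usage sketch (illustrative): out_features must equal the sum of 'output_dim' values
# in dist_config; for a conditioned-std 'normal' head, output_dim == 2 * dim.
#
#     dist_config = [{'type': 'normal', 'dim': 2, 'output_dim': 4}]
#     policy = FeedForwardPolicy(in_features=10, out_features=4, hidden_features=64,
#                                hidden_layers=2, dist_config=dist_config, backbone_type='mlp')
#     action = policy.get_action(torch.randn(32, 10))   # mode of the action distribution, [32, 2]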
class RecurrentPolicy(torch.nn.Module):
"""Initializes a recurrent policy network instance.
Args:
in_features (int): The number of input features.
out_features (int): The number of output features.
hidden_features (int): The number of hidden features in the RNN.
hidden_layers (int): The number of layers in the RNN.
norm (str): The normalization method used in the ResNet backbone.
dist_config (list): The configuration for the distributions used in the model.
backbone_type (Union[str, np.str_], optional): The type of RNN to use ('gru' or 'lstm'). Default is 'gru'.
"""
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
norm : str,
dist_config : list,
backbone_type : Union[str, np.str_] ='gru',
**kwargs):
super().__init__()
RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM
rnn_hidden_features = 256
rnn_hidden_layers = min(hidden_layers, 3)
rnn_output_features = min(out_features,8)
self.rnn = RNN(in_features, rnn_hidden_features, rnn_hidden_layers)
self.rnn_mlp = MLP(rnn_hidden_features, rnn_output_features, 0, 0)
self.backbone = ResNet(in_features+rnn_output_features, out_features, hidden_features, hidden_layers, norm)
self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
def reset(self):
self.h = None
def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
x_shape = x.shape
if len(x_shape) == 1:
rnn_output, self.h = self.rnn(x.unsqueeze(0).unsqueeze(0), self.h)
rnn_output = rnn_output.squeeze(0).squeeze(0)
elif len(x_shape) == 2:
rnn_output, self.h = self.rnn(x.unsqueeze(0), self.h)
rnn_output = rnn_output.squeeze(0)
else:
rnn_output, self.h = self.rnn(x, self.h)
rnn_output = self.rnn_mlp(rnn_output)
logits = self.backbone(torch.cat([x, rnn_output],axis=-1))
return self.dist_wrapper(logits, adapt_std, **kwargs)
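# Usage sketch (illustrative): reset() must be called before the first forward call so
# the hidden state `self.h` exists; a 3D input is treated as [horizon, batch, in_features].
#
#     dist_config = [{'type': 'normal', 'dim': 2, 'output_dim': 4}]
#     policy = RecurrentPolicy(in_features=10, out_features=4, hidden_features=64,
#                              hidden_layers=2, norm='ln', dist_config=dist_config)
#     policy.reset()
#     dist = policy(torch.randn(5, 32, 10))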
class TsRecurrentPolicy(torch.nn.Module):
"""Initializes a recurrent policy network instance for ts_node inputs.
Args:
in_features (int): The number of input features.
out_features (int): The number of output features.
hidden_features (int): The number of hidden features in the RNN.
hidden_layers (int): The number of layers in the RNN.
dist_config (list): The configuration for the distributions used in the model.
backbone_type (Union[str, np.str_], optional): The type of RNN to use ('gru' or 'lstm'). Default is 'gru'.
ts_input_names (dict, optional): Mapping from input node names to their number of time steps.
bidirectional (bool, optional): Whether the RNN is bidirectional. Default is True.
dropout (float, optional): Dropout between RNN layers (only applied when hidden_layers > 1). Default is 0.
"""
input_is_dict = True
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
norm : str,
dist_config : list,
backbone_type : Union[str, np.str_] ='gru',
ts_input_names : dict = {},
bidirectional: bool = True,
dropout: float = 0,
**kwargs):
super().__init__()
RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM
self.ts_input_names = ts_input_names
self.input_dim_dict = kwargs.get("input_dim_dict",{})
self.max_ts_steps = max(self.ts_input_names.values())
self.rnn = RNN(input_size=in_features,
hidden_size=hidden_features,
num_layers=hidden_layers,
bidirectional=bidirectional,
batch_first=True,
dropout=dropout if hidden_layers > 1 else 0)
if bidirectional:
self.linear = nn.Linear(hidden_features * 2, out_features)
else:
self.linear = nn.Linear(hidden_features, out_features)
self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
def reset(self):
self.h = None
def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
x_shape = x[list(x.keys())[0]].shape
if len(x_shape) == 3:
x = {k:v.reshape(-1,v.shape[-1]) for k,v in x.items()}
input = {}
for k,v in x.items():
ts_steps = self.ts_input_names.get(k,1)
input[k] = v.reshape(v.shape[0],ts_steps,-1)
if ts_steps < self.max_ts_steps:
input[k] = torch.cat([input[k][:,:1].repeat(1, self.max_ts_steps-ts_steps, 1), input[k]],axis=1)
if len(input) == 1:
rnn_input = list(input.values())[0]
else:
rnn_input = torch.cat(list(input.values()),axis=-1)
out, _ = self.rnn(rnn_input,None)
out = self.linear(out[:, -1, :])
if len(x_shape) == 3:
logits = out.reshape(x_shape[0],x_shape[1],-1)
else:
logits = out.reshape(x_shape[0],-1)
return self.dist_wrapper(logits, adapt_std, **kwargs)
class RecurrentRESPolicy(torch.nn.Module):
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list,
backbone_type : Union[str, np.str_] ='gru',
rnn_hidden_features : int = 64,
window_size : int = 0,
**kwargs):
super().__init__()
self.kwargs = kwargs
RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM
self.in_feature_embed = nn.Sequential(
nn.Linear(in_features, hidden_features),
ACTIVATION_CREATORS['relu'](hidden_features),
nn.Dropout(0.1),
nn.LayerNorm(hidden_features)
)
self.side_net = nn.Sequential(
ResBlock(hidden_features, hidden_features),
nn.Dropout(0.1),
nn.LayerNorm(hidden_features),
ResBlock(hidden_features, hidden_features),
nn.Dropout(0.1),
nn.LayerNorm(hidden_features)
)
self.rnn = RNN(hidden_features, rnn_hidden_features, hidden_layers)
self.backbone = MLP(hidden_features + rnn_hidden_features, out_features, hidden_features, 1)
self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
self.hidden_layers = hidden_layers
self.hidden_features = hidden_features
self.rnn_hidden_features = rnn_hidden_features
self.window_size = window_size
self.hidden_que = deque(maxlen=self.window_size) # [(h_0, c_0), (h_1, c_1), ...]
self.h = None
def reset(self):
self.hidden_que = deque(maxlen=self.window_size)
self.h = None
def preprocess_window(self, a_embed):
shape = a_embed.shape
if len(self.hidden_que) < self.window_size: # do Not pop
if self.hidden_que:
h = self.hidden_que[0]
else:
h = (torch.zeros((self.hidden_layers, shape[1], self.rnn_hidden_features)).to(a_embed), torch.zeros((self.hidden_layers, shape[1], self.rnn_hidden_features)).to(a_embed))
else:
h = self.hidden_que.popleft()
hidden_cat = torch.concat([h[0]] + [hidden[0] for hidden in self.hidden_que] + [torch.zeros_like(h[0]).to(h[0])], dim=1) # (layers, bs * 4, dim)
cell_cat = torch.concat([h[1]] + [hidden[1] for hidden in self.hidden_que] + [torch.zeros_like(h[1]).to(h[1])], dim=1) # (layers, bs * 4, dim)
a_embed = torch.repeat_interleave(a_embed, repeats=len(self.hidden_que)+2, dim=0).view(1, (len(self.hidden_que)+2)*shape[1], -1) # (1, bs * 4, dim)
return a_embed, hidden_cat, cell_cat
def postprocess_window(self, rnn_output, hidden_cat, cell_cat):
hidden_cat = torch.chunk(hidden_cat, chunks=len(self.hidden_que)+2, dim=1) # tuples of (layers, bs, dim)
cell_cat = torch.chunk(cell_cat, chunks=len(self.hidden_que)+2, dim=1) # tuples of (layers, bs, dim)
rnn_output = torch.chunk(rnn_output, chunks=len(self.hidden_que)+2, dim=1) # tuples of (1, bs, dim)
self.hidden_que = deque(zip(hidden_cat[1:], cell_cat[1:]), maxlen=self.window_size) # important: discard the first element
rnn_output = rnn_output[0] # (1, bs, dim)
return rnn_output
def rnn_forward(self, a_embed, joint_train : bool):
if self.window_size > 0:
a_embed, hidden_cat, cell_cat = self.preprocess_window(a_embed)
if joint_train:
rnn_output, (hidden_cat, cell_cat) = self.rnn(a_embed, (hidden_cat, cell_cat))
else:
with torch.no_grad():
rnn_output, (hidden_cat, cell_cat) = self.rnn(a_embed, (hidden_cat, cell_cat))
rnn_output = self.postprocess_window(rnn_output, hidden_cat, cell_cat) #(1, bs, dim)
else:
if joint_train:
rnn_output, self.h = self.rnn(a_embed, self.h) #(1, bs, dim)
else:
with torch.no_grad():
rnn_output, self.h = self.rnn(a_embed, self.h)
return rnn_output
def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, field: str = 'mail', **kwargs) -> ReviveDistribution:
if field == 'bc':
shape = x.shape
joint_train = True
if len(shape) == 2: # (bs, dim)
x_embed = self.in_feature_embed(x)
side_output = self.side_net(x_embed)
a_embed = x_embed.unsqueeze(0) # [1, bs, dim]
rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim]
rnn_output = rnn_output.squeeze(0) # (bs, dim)
logits = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (bs, dim)
else:
assert len(shape) == 3, f"expect len(x.shape) == 3. However got x.shape {shape}"
self.reset()
output = []
for i in range(shape[0]):
a = x[i]
a = a.unsqueeze(0) #(1, bs, dim)
a_embed = self.in_feature_embed(a) # (1, bs, dim)
side_output = self.side_net(a_embed)
rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim]
backbone_output = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (1, bs, dim)
output.append(backbone_output)
logits = torch.concat(output, dim=0)
elif field == 'mail':
shape = x.shape
joint_train = False
if len(shape) == 2:
x_embed = self.in_feature_embed(x)
side_output = self.side_net(x_embed)
a_embed = x_embed.unsqueeze(0)
rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim]
rnn_output = rnn_output.squeeze(0) # (bs, dim)
logits = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (bs, dim)
else:
assert len(shape) == 3, f"expect len(x.shape) == 3. However got x.shape {shape}"
self.reset()
output = []
for i in range(shape[0]):
a = x[i]
a = a.unsqueeze(0) #(1, bs, dim)
a_embed = self.in_feature_embed(a) # (1, bs, dim)
side_output = self.side_net(a_embed)
rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim]
backbone_output = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (1, bs, dim)
output.append(backbone_output)
logits = torch.concat(output, dim=0)
else:
raise NotImplementedError(f"unknown field: {field} in RNN training!")
return self.dist_wrapper(logits, adapt_std, **kwargs)
class ContextualPolicy(torch.nn.Module):
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list,
backbone_type : Union[str, np.str_] ='contextual_gru',
**kwargs):
super().__init__()
self.kwargs = kwargs
RNN = torch.nn.GRU if backbone_type == 'contextual_gru' else torch.nn.LSTM
self.preprocess_mlp = MLP(in_features, hidden_features, 0, 0, output_activation='leakyrelu')
self.rnn = RNN(hidden_features, hidden_features, 1)
self.mlp = MLP(hidden_features + in_features, out_features, hidden_features, hidden_layers)
self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
def reset(self):
self.h = None
def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
in_feature = x
if len(x.shape) == 2:
x = x.unsqueeze(0)
x = self.preprocess_mlp(x)
rnn_output, self.h = self.rnn(x, self.h)
rnn_output = rnn_output.squeeze(0)
else:
x = self.preprocess_mlp(x)
rnn_output, self.h = self.rnn(x)
logits = self.mlp(torch.cat((in_feature, rnn_output), dim=-1))
return self.dist_wrapper(logits, adapt_std, **kwargs)
class EnsembleFeedForwardPolicy(torch.nn.Module):
"""Ensemble policy supporting vectorized MLP ('EnsembleMLP') and vectorized ResNet ('EnsembleRES') backbones.
Args:
in_features : The number of input features, or a dictionary describing the input distribution
out_features : The number of output features
hidden_features : The number of hidden features in each layer
hidden_layers : The number of hidden layers
dist_config : A list of configurations for the distributions to use in the model
norm : The type of normalization to apply to the input features
hidden_activation : The activation function to use in the hidden layers
backbone_type : The type of backbone to use in the model
use_multihead : Whether to use a multihead model
use_feature_embed : Whether to use feature embedding
"""
def __init__(self,
in_features : Union[int, dict],
out_features : int,
hidden_features : int,
hidden_layers : int,
ensemble_size : int,
dist_config : list,
norm : str = None,
hidden_activation : str = 'leakyrelu',
backbone_type : Union[str, np.str_] = 'EnsembleRES',
use_multihead : bool = False,
**kwargs):
super().__init__()
self.multihead = []
self.ensemble_size = ensemble_size
self.dist_config = dist_config
self.kwargs = kwargs
if isinstance(in_features, dict):
in_features = sum(in_features.values())
if backbone_type == 'EnsembleMLP':
self.backbone = VectorizedMLP(in_features, out_features, hidden_features, hidden_layers, ensemble_size, norm, hidden_activation)
elif backbone_type == 'EnsembleRES':
self.backbone = VectorizedResNet(in_features, out_features, hidden_features, hidden_layers, ensemble_size, norm)
else:
raise NotImplementedError(f'backbone type {backbone_type} is not supported')
self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
if isinstance(state, dict):
assert 'input_names' in kwargs
state = torch.cat([state[key] for key in kwargs['input_names']], dim=-1)
output = self.backbone(state)
dist = self.dist_wrapper(output, adapt_std, **kwargs)
if hasattr(self, "dist_mu_shift"):
dist = dist.shift(self.dist_mu_shift)
return dist
@torch.no_grad()
def get_action(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], deterministic : bool = True):
dist = self(state)
return dist.mode if deterministic else dist.sample()
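# Usage sketch (illustrative): each ensemble member produces its own action
# distribution, so a [batch, in_features] input yields [ensemble_size, batch, ...] actions.
#
#     dist_config = [{'type': 'normal', 'dim': 2, 'output_dim': 4}]
#     policy = EnsembleFeedForwardPolicy(in_features=10, out_features=4, hidden_features=64,
#                                        hidden_layers=2, ensemble_size=5, dist_config=dist_config)
#     action = policy.get_action(torch.randn(32, 10))   # -> [5, 32, 2]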
# ------------------------------- Transitions ------------------------------- #
class FeedForwardTransition(FeedForwardPolicy):
r"""Initializes a feedforward transition instance.
Args:
in_features (Union[int, dict]): The number of input features or a dictionary of input feature sizes.
out_features (int): The number of output features.
hidden_features (int): The number of hidden features.
hidden_layers (int): The number of hidden layers.
dist_config (list): The configuration of the distribution to use.
norm (Optional[str]): The normalization method to use. None if no normalization is used.
hidden_activation (str): The activation function to use for the hidden layers.
backbone_type (Union[str, np.str_]): The type of backbone to use.
mode (str): The mode to use for the transition. Either 'global' or 'local'.
obs_dim (Optional[int]): The dimension of the observation for 'local' mode.
use_feature_embed (bool): Whether to use feature embedding.
"""
def __init__(self,
in_features : Union[int, dict],
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list,
norm : Optional[str] = None,
hidden_activation : str = 'leakyrelu',
backbone_type : Union[str, np.str_] = 'mlp',
mode : str = 'global',
obs_dim : Optional[int] = None,
**kwargs):
self.mode = mode
self.obs_dim = obs_dim
if self.mode == 'local':
dist_types = [config['type'] for config in dist_config]
if 'onehot' in dist_types or 'discrete_logistic' in dist_types:
warnings.warn('Detected a distribution type that is not compatible with local mode, falling back to global mode!')
self.mode = 'global'
if self.mode == 'local': assert self.obs_dim is not None, \
"For local mode, the dim of observation should be given!"
super(FeedForwardTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers,
dist_config, norm, hidden_activation, backbone_type, **kwargs)
def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
dist = super(FeedForwardTransition, self).forward(state, adapt_std, **kwargs)
if self.mode == 'local' and self.obs_dim is not None:
dist = dist.shift(state[..., :self.obs_dim])
return dist
class RecurrentTransition(RecurrentPolicy):
r"""
Initializes a recurrent transition instance.
Args:
in_features (int): The number of input features.
out_features (int): The number of output features.
hidden_features (int): The number of hidden features.
hidden_layers (int): The number of hidden layers.
dist_config (list): The distribution configuration.
backbone_type (Union[str, np.str_]): The type of backbone to use.
mode (str): The mode of the transition. Either 'global' or 'local'.
obs_dim (Optional[int]): The dimension of the observation.
"""
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
norm : str,
dist_config : list,
backbone_type : Union[str, np.str_] ='gru',
mode : str = 'global',
obs_dim : Optional[int] = None,
**kwargs):
self.mode = mode
self.obs_dim = obs_dim
if self.mode == 'local': assert not 'onehot' in [config['type'] for config in dist_config], \
"The local mode of transition is not compatible with onehot data! Please fall back to global mode!"
if self.mode == 'local': assert self.obs_dim is not None, \
"For local mode, the dim of observation should be given!"
super(RecurrentTransition, self).__init__(in_features,
out_features,
hidden_features,
hidden_layers,
norm,
dist_config,
backbone_type,
**kwargs)
def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
dist = super(RecurrentTransition, self).forward(state, adapt_std, **kwargs)
if self.mode == 'local' and self.obs_dim is not None:
dist = dist.shift(state[..., :self.obs_dim])
return dist
class TsRecurrentTransition(TsRecurrentPolicy):
r"""
Initializes a recurrent transition instance.
Args:
in_features (int): The number of input features.
out_features (int): The number of output features.
hidden_features (int): The number of hidden features.
hidden_layers (int): The number of hidden layers.
dist_config (list): The distribution configuration.
backbone_type (Union[str, np.str_]): The type of backbone to use.
mode (str): The mode of the transition. Either 'global' or 'local'.
obs_dim (Optional[int]): The dimension of the observation.
"""
input_is_dict = True
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
norm : str,
dist_config : list,
backbone_type : Union[str, np.str_] ='gru',
mode : str = 'global',
obs_dim : Optional[int] = None,
**kwargs):
self.mode = mode
self.obs_dim = obs_dim
if self.mode == 'local': assert not 'onehot' in [config['type'] for config in dist_config], \
"The local mode of transition is not compatible with onehot data! Please fall back to global mode!"
if self.mode == 'local': assert self.obs_dim is not None, \
"For local mode, the dim of observation should be given!"
super(TsRecurrentTransition, self).__init__(in_features,
out_features,
hidden_features,
hidden_layers,
norm,
dist_config,
backbone_type,
**kwargs)
def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
dist = super(TsRecurrentTransition, self).forward(state, adapt_std, **kwargs)
if self.mode == 'local' and self.obs_dim is not None:
dist = dist.shift(state[..., :self.obs_dim])
return dist
class RecurrentRESTransition(RecurrentRESPolicy):
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list,
backbone_type : Union[str, np.str_] = 'mlp',
mode : str = 'global',
obs_dim : Optional[int] = None,
rnn_hidden_features : int = 64,
window_size : int = 0,
**kwargs):
self.mode = mode
self.obs_dim = obs_dim
if self.mode == 'local': assert not 'onehot' in [config['type'] for config in dist_config], \
"The local mode of transition is not compatible with onehot data! Please fall back to global mode!"
if self.mode == 'local': assert self.obs_dim is not None, \
"For local mode, the dim of observation should be given!"
super(RecurrentRESTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers,
dist_config, backbone_type,
rnn_hidden_features=rnn_hidden_features, window_size=window_size,
**kwargs)
def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, field: str = 'mail', **kwargs) -> ReviveDistribution:
dist = super(RecurrentRESTransition, self).forward(state, adapt_std, field, **kwargs)
if self.mode == 'local' and self.obs_dim is not None:
dist = dist.shift(state[..., :self.obs_dim])
return dist
class EnsembleFeedForwardTransition(EnsembleFeedForwardPolicy):
r"""Initializes a feedforward transition instance.
Args:
in_features (Union[int, dict]): The number of input features or a dictionary of input feature sizes.
out_features (int): The number of output features.
hidden_features (int): The number of hidden features.
hidden_layers (int): The number of hidden layers.
ensemble_size (int): The number of ensemble members.
dist_config (list): The configuration of the distribution to use.
norm (Optional[str]): The normalization method to use. None if no normalization is used.
hidden_activation (str): The activation function to use for the hidden layers.
backbone_type (Union[str, np.str_]): The type of backbone to use.
mode (str): The mode to use for the transition. Either 'global' or 'local'.
obs_dim (Optional[int]): The dimension of the observation for 'local' mode.
use_feature_embed (bool): Whether to use feature embedding.
"""
def __init__(self,
in_features : Union[int, dict],
out_features : int,
hidden_features : int,
hidden_layers : int,
ensemble_size : int,
dist_config : list,
norm : Optional[str] = None,
hidden_activation : str = 'leakyrelu',
backbone_type : Union[str, np.str_] = 'EnsembleRES',
mode : str = 'global',
obs_dim : Optional[int] = None,
**kwargs):
self.mode = mode
self.obs_dim = obs_dim
if self.mode == 'local':
dist_types = [config['type'] for config in dist_config]
if 'onehot' in dist_types or 'discrete_logistic' in dist_types:
warnings.warn('Detected a distribution type that is not compatible with the local mode, falling back to global mode!')
self.mode = 'global'
if self.mode == 'local': assert self.obs_dim is not None, \
"For local mode, the dim of observation should be given!"
super(EnsembleFeedForwardTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers, ensemble_size,
dist_config, norm, hidden_activation, backbone_type, **kwargs)
[docs]
def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
dist = super(EnsembleFeedForwardTransition, self).forward(state, adapt_std, **kwargs)
if self.mode == 'local' and self.obs_dim is not None:
dist = dist.shift(state[..., :self.obs_dim])
return dist
# ------------------------------- Matchers ------------------------------- #
[docs]
class FeedForwardMatcher(torch.nn.Module):
r"""
Initializes a feedforward matcher instance.
Args:
in_features (int): The number of input features.
hidden_features (int): The number of hidden features.
hidden_layers (int): The number of hidden layers.
hidden_activation (str): The activation function to use for the hidden layers.
norm (str): The normalization method to use. None if no normalization is used.
backbone_type (Union[str, np.str_]): The type of backbone to use.
"""
def __init__(self,
in_features : int,
hidden_features : int,
hidden_layers : int,
hidden_activation : str = 'leakyrelu',
norm : Optional[str] = None,
backbone_type : Union[str, np.str_] = 'mlp',
**kwargs):
super().__init__()
if backbone_type == 'mlp':
self.backbone = MLP(in_features, 1, hidden_features, hidden_layers, norm, hidden_activation, output_activation='sigmoid')
elif backbone_type == 'res':
self.backbone = ResNet(in_features, 1, hidden_features, hidden_layers, norm, output_activation='sigmoid')
elif backbone_type == 'ft_transformer':
self.backbone = FT_Transformer(in_features, 1, hidden_features, hidden_layers=hidden_layers)
else:
raise NotImplementedError(f'backbone type {backbone_type} is not supported')
[docs]
def forward(self, x : torch.Tensor) -> torch.Tensor:
# x.requires_grad = True
shape = x.shape
if len(shape) == 3: # [horizon, batch_size, dim]
x = x.view(-1, shape[-1]) # [horizon * batch_size, dim]
assert len(x.shape) == 2, f"len(x.shape) == {len(x.shape)}, expect 2"
out = self.backbone(x) # [batch_size, dim]
if len(shape) == 3:
out = out.view(*shape[:-1], -1) # [horizon, batch_size, dim]
return out
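# --- Usage sketch (not part of the original module) --------------------------
# Argument values are illustrative; 3-D inputs are flattened over the horizon
# before scoring and reshaped back afterwards, as in forward above.
def _example_feedforward_matcher():
    matcher = FeedForwardMatcher(in_features=8, hidden_features=32, hidden_layers=2,
                                 backbone_type='mlp')
    x = torch.rand(10, 4, 8)    # [horizon, batch_size, dim]
    scores = matcher(x)         # [horizon, batch_size, 1], in (0, 1) from the sigmoid head
    return scores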
[docs]
class RecurrentMatcher(torch.nn.Module):
def __init__(self,
in_features : int,
hidden_features : int,
hidden_layers : int,
backbone_type : Union[str, np.str_] = 'gru',
bidirect : bool = False):
super().__init__()
RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM
self.rnn = RNN(in_features, hidden_features, hidden_layers, bidirectional=bidirect)
self.output_layer = MLP(hidden_features * (2 if bidirect else 1), 1, 0, 0, output_activation='sigmoid')
[docs]
def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor:
x = torch.cat(inputs, dim=-1)
rnn_output = self.rnn(x)[0]
return self.output_layer(rnn_output)
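# --- Usage sketch (not part of the original module) --------------------------
# Inputs are concatenated on the last axis, so in_features must equal the sum
# of their feature dimensions; every timestep of the sequence receives a score.
def _example_recurrent_matcher():
    matcher = RecurrentMatcher(in_features=6, hidden_features=32, hidden_layers=1)
    obs = torch.rand(20, 4, 4)   # [horizon, batch_size, obs_dim]
    act = torch.rand(20, 4, 2)   # [horizon, batch_size, action_dim]
    scores = matcher(obs, act)   # [horizon, batch_size, 1]
    return scores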
[docs]
class HierarchicalMatcher(torch.nn.Module):
def __init__(self,
in_features : list,
hidden_features : int,
hidden_layers : int,
hidden_activation : str,
norm : Optional[str] = None):
super().__init__()
self.in_features = in_features
process_layers = []
output_layers = []
feature = self.in_features[0]
for i in range(1, len(self.in_features)):
feature += self.in_features[i]
process_layers.append(MLP(feature, hidden_features, hidden_features, hidden_layers, norm, hidden_activation, hidden_activation))
output_layers.append(torch.nn.Linear(hidden_features, 1))
feature = hidden_features
self.process_layers = torch.nn.ModuleList(process_layers)
self.output_layers = torch.nn.ModuleList(output_layers)
[docs]
def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor:
assert len(inputs) == len(self.in_features)
last_feature = inputs[0]
result = 0
for i, process_layer, output_layer in zip(range(1, len(self.in_features)), self.process_layers, self.output_layers):
last_feature = torch.cat([last_feature, inputs[i]], dim=-1)
last_feature = process_layer(last_feature)
result += output_layer(last_feature)
return torch.sigmoid(result)
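# --- Usage sketch (not part of the original module) --------------------------
# in_features is a list: the first tensor is consumed directly and each later
# tensor is concatenated with the running feature before its process layer, so
# forward() expects exactly len(in_features) tensors. Values are illustrative.
def _example_hierarchical_matcher():
    matcher = HierarchicalMatcher(in_features=[4, 3, 2], hidden_features=32,
                                  hidden_layers=1, hidden_activation='leakyrelu')
    a, b, c = torch.rand(8, 4), torch.rand(8, 3), torch.rand(8, 2)
    scores = matcher(a, b, c)    # [8, 1], sigmoid over the summed per-level outputs
    return scores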
[docs]
class VectorizedMatcher(torch.nn.Module):
r"""
Initializes a vectorized (ensemble) matcher instance.
Args:
in_features (int): The number of input features.
hidden_features (int): The number of hidden features.
hidden_layers (int): The number of hidden layers.
ensemble_size (int): The number of ensemble members evaluated in parallel.
hidden_activation (str): The activation function to use for the hidden layers.
norm (str): The normalization method to use. None if no normalization is used.
backbone_type (Union[str, np.str_]): The type of backbone to use.
"""
def __init__(self,
in_features : int,
hidden_features : int,
hidden_layers : int,
ensemble_size : int,
hidden_activation : str = 'leakyrelu',
norm : Optional[str] = None,
backbone_type : Union[str, np.str_] = 'mlp',
**kwargs):
super().__init__()
self.backbone_type = backbone_type
self.ensemble_size = ensemble_size
if backbone_type == 'mlp':
self.backbone = VectorizedMLP(in_features, 1, hidden_features, hidden_layers, ensemble_size, norm, hidden_activation, output_activation='sigmoid')
elif backbone_type == 'res':
self.backbone = VectorizedResNet(in_features, 1, hidden_features, hidden_layers, ensemble_size, norm, output_activation='sigmoid')
else:
raise NotImplementedError(f'backbone type {backbone_type} is not supported')
[docs]
def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor:
x = torch.cat(inputs, dim=-1).detach()
# x.requires_grad = True
shape = x.shape
if len(shape) == 2: # [batch_size, dim]
x = x.unsqueeze(0).repeat_interleave(self.ensemble_size, dim=0) # [batch_size, dim] --> [ensemble_size, batch_size, dim]
assert x.shape[0] == self.ensemble_size
out = self.backbone(x) # [ensemble_size, batch_size, dim]
return out
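# --- Usage sketch (not part of the original module) --------------------------
# A 2-D input is repeated along a new leading axis so that every ensemble
# member scores the same batch; a 3-D input is assumed to already carry it.
def _example_vectorized_matcher():
    matcher = VectorizedMatcher(in_features=6, hidden_features=32, hidden_layers=2,
                                ensemble_size=5, backbone_type='mlp')
    obs = torch.rand(16, 4)
    act = torch.rand(16, 2)
    scores = matcher(obs, act)   # [ensemble_size, batch_size, 1] == [5, 16, 1]
    return scores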
[docs]
class EnsembleMatcher:
def __init__(self,
matcher_record: Dict[str, deque],
ensemble_size : int = 50,
ensemble_choosing_interval: int = 1,
config : dict = None) -> None:
self.matcher_weights_list = []
self.single_structure_dict = config
self.matching_nodes = config['matching_nodes']
self.matching_fit_nodes = config['matching_fit_nodes']
# self.matching_nodes_fit_dict = config['matching_nodes_fit_index']
self.device = torch.device("cpu")
self.min_generate_score = 0.0
self.max_expert_score = 1.0
self.selected_ind = []
counter = 0
for i in reversed(range(len(matcher_record['state_dicts']))):
if counter % ensemble_choosing_interval == 0:
matcher_weight = matcher_record['state_dicts'][i]
self.selected_ind.append(i)
self.matcher_weights_list.append(matcher_weight)
if len(self.matcher_weights_list) >= ensemble_size:
break
counter += 1
self.ensemble_size = len(self.matcher_weights_list)
self.matcher_ensemble = VectorizedMatcher(ensemble_size=self.ensemble_size, **config)
self.init_ensemble_network()
self.load_min_max(matcher_record['expert_scores'], matcher_record['generated_scores'])
[docs]
def load_min_max(self, expert_scores: deque, generate_scores: deque):
self.expert_scores = torch.tensor(expert_scores, dtype=torch.float32, device=self.device)[self.selected_ind] # [ensemble_size, ]
self.generate_scores = torch.tensor(generate_scores, dtype=torch.float32, device=self.device)[self.selected_ind] # [ensemble_size, ]
self.max_expert_score = max(self.expert_scores)
self.min_generate_score = min(self.generate_scores)
self.range = max(abs(self.max_expert_score), abs(self.min_generate_score))
[docs]
def init_ensemble_network(self,):
if self.matcher_ensemble.backbone_type == 'mlp':
with torch.no_grad():
for ind, state_dict in enumerate(self.matcher_weights_list):
for key, val in state_dict.items():
layer_ind = int(key.split('.')[2])
if key.endswith("weight"):
self.matcher_ensemble.backbone.net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
elif key.endswith("bias"):
self.matcher_ensemble.backbone.net[layer_ind].bias.data[ind, 0, ...].copy_(val)
else:
raise NotImplementedError("Only except network params with weight and bias")
elif self.matcher_ensemble.backbone_type == 'res':
with torch.no_grad():
for ind, state_dict in enumerate(self.matcher_weights_list):
for key, val in state_dict.items():
key_ls = key.split('.')
if "skip_net" in key_ls:
block_ind = int(key_ls[2])
if key.endswith("weight"):
self.matcher_ensemble.backbone.resnet[block_ind].skip_net.weight.data[ind, ...].copy_(val.transpose(1, 0))
elif key.endswith("bias"):
self.matcher_ensemble.backbone.resnet[block_ind].skip_net.bias.data[ind, 0, ...].copy_(val)
else:
raise NotImplementedError("Only except network params with weight and bias")
elif "process_net" in key_ls:
block_ind, layer_ind = int(key_ls[2]), int(key_ls[4])
if key.endswith("weight"):
self.matcher_ensemble.backbone.resnet[block_ind].process_net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
elif key.endswith("bias"):
self.matcher_ensemble.backbone.resnet[block_ind].process_net[layer_ind].bias.data[ind, 0, ...].copy_(val)
else:
raise NotImplementedError("Only except network params with weight and bias")
else:
# linear layer at last
block_ind = int(key_ls[2])
if key.endswith("weight"):
self.matcher_ensemble.backbone.resnet[block_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
elif key.endswith("bias"):
self.matcher_ensemble.backbone.resnet[block_ind].bias.data[ind, 0, ...].copy_(val)
else:
raise NotImplementedError("Only except network params with weight and bias")
else:
raise NotImplementedError
[docs]
def run_raw_scores(self, inputs: torch.Tensor, aggregation: Optional[str] = None) -> torch.Tensor:
with torch.no_grad():
scores = self.matcher_ensemble(inputs) # [ensemble_size, batch_size, 1], [ensemble_size, horizon, batch_size, 1]
scores = scores.squeeze(-1).transpose(1, 0).detach() # [batch_size, ensemble_size]
if aggregation == "mean":
return scores.mean(-1)
return scores
[docs]
def run_scores(self, inputs: torch.Tensor, aggregation: Optional[str] = None, clip: bool = True) -> torch.Tensor:
with torch.no_grad():
scores = self.matcher_ensemble(inputs) # [ensemble_size, batch_size, 1]
scores = scores.squeeze(-1).transpose(1, 0).detach() # [batch_size, ensemble_size]
if clip:
assert self.expert_scores is not None and self.generate_scores is not None, "if clip==True, you should call load_min_max() beforehand"
scores = torch.clamp(scores, min=self.generate_scores, max=self.expert_scores)
scores = (scores - self.generate_scores) / (self.expert_scores - self.generate_scores + 1e-7) * self.range
if aggregation == "mean":
return scores.mean(-1)
return scores
[docs]
def to(self, device):
self.device = device
self.matcher_ensemble = self.matcher_ensemble.to(device)
self.expert_scores = self.expert_scores.to(device)
self.generate_scores = self.generate_scores.to(device)
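# --- Worked sketch of the normalization in run_scores() ----------------------
# (not part of the original module). Per ensemble member, raw scores are
# clamped to [generate_score, expert_score] and rescaled to [0, range], where
# range is the larger magnitude of the two recorded extremes.
def _example_score_normalization():
    raw = torch.tensor([[0.2, 0.9], [0.6, 0.4]])   # [batch_size=2, ensemble_size=2]
    generate_scores = torch.tensor([0.30, 0.35])   # per-member score on generated data
    expert_scores = torch.tensor([0.80, 0.85])     # per-member score on expert data
    score_range = max(abs(expert_scores.max()), abs(generate_scores.min()))
    clipped = torch.clamp(raw, min=generate_scores, max=expert_scores)
    normalized = (clipped - generate_scores) / (expert_scores - generate_scores + 1e-7) * score_range
    return normalized.mean(-1)                     # aggregation == "mean"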
# class SequentialMatcher:
# def __init__(self,
# matcher_record: Dict[str, deque],
# ensemble_size : int = 50,
# ensemble_choosing_interval: int = 1,
# config : dict = None) -> None:
# self.matcher_list = []
# self.single_structure_dict = config
# self.matching_nodes = config['matching_nodes']
# self.matching_fit_nodes = config['matching_fit_nodes']
# # self.matching_nodes_fit_dict = config['matching_nodes_fit_index']
# self.device = torch.device("cpu")
# self.min_generate_score = 0.0
# self.max_expert_score = 1.0
# self.selected_ind = []
# counter = 0
# for i in reversed(range(len(matcher_record['state_dicts']))):
# if counter % ensemble_choosing_interval == 0:
# matcher_weight = matcher_record['state_dicts'][i]
# self.selected_ind.append(i)
# _matcher = FeedForwardMatcher(**config)
# _matcher.load_state_dict(matcher_weight)
# self.matcher_list.append(_matcher)
# if len(self.matcher_list) >= ensemble_size:
# break
# counter += 1
# self.ensemble_size = len(self.matcher_list)
# self.load_min_max(matcher_record['expert_scores'], matcher_record['generated_scores'])
# # prepare for vmap func
# self.params, self.buffers = stack_module_state(self.matcher_list)
# """
# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
# """
# self.base_model = deepcopy(self.matcher_list[0])
# self.base_model = self.base_model.to('meta')
# def vmap_forward(self, x):
# def vmap_model(params, buffers, x):
# return functional_call(self.base_model, (params, buffers), (x,))
# return vmap(vmap_model, in_dims=(0, 0, None))(self.params, self.buffers, x)
# def load_min_max(self, expert_scores: deque, generate_scores: deque):
# self.expert_scores = torch.tensor(expert_scores, dtype=torch.float32, device=self.device)[self.selected_ind] # [ensemble_size, ]
# self.generate_scores = torch.tensor(generate_scores, dtype=torch.float32, device=self.device)[self.selected_ind] # [ensemble_size, ]
# self.max_expert_score = max(self.expert_scores)
# self.min_generate_score = min(self.generate_scores)
# self.range = max(abs(self.max_expert_score), abs(self.min_generate_score))
# def run_raw_scores(self, inputs: torch.Tensor, aggregation=None) -> torch.Tensor:
# with torch.no_grad():
# scores = self.vmap_forward(inputs) # [ensemble_size, batch_size, 1], [ensemble_size, horizon, batch_size, 1]
# scores = scores.squeeze(-1).transpose(1, 0).detach() # [batch_size, ensemble_size]
# if aggregation == "mean":
# return scores.mean(-1)
# return scores
# def run_scores(self, inputs: torch.Tensor, aggregation: str=None, clip: bool=True) -> torch.Tensor:
# with torch.no_grad():
# scores = self.vmap_forward(inputs) # [ensemble_size, batch_size, 1]
# scores = scores.squeeze(-1).transpose(1, 0).detach() # [batch_size, ensemble_size]
# if clip:
# assert self.expert_scores is not None and self.generate_scores is not None, "if clip==True, you should call load_min_max() beforehand"
# scores = torch.clamp(scores, min=self.generate_scores, max=self.expert_scores)
# scores = (scores - self.generate_scores) / (self.expert_scores - self.generate_scores + 1e-7) * self.range
# if aggregation == "mean":
# return scores.mean(-1)
# return scores
# def to(self, device):
# self.device = device
# self.matcher_list = [_matcher.to(device) for _matcher in self.matcher_list]
# self.expert_scores = self.expert_scores.to(device)
# self.generate_scores = self.generate_scores.to(device)
# # prepare for vmap func
# self.params, self.buffers = stack_module_state(self.matcher_list)
# ------------------------------- Others ------------------------------- #
[docs]
class VectorizedCritic(VectorizedMLP):
[docs]
def forward(self, x : torch.Tensor) -> torch.Tensor:
x = super(VectorizedCritic, self).forward(x)
return x.squeeze(-1)
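# --- Usage sketch (not part of the original module) --------------------------
# VectorizedCritic is a VectorizedMLP whose trailing singleton output dim is
# squeezed, yielding one value per ensemble member and sample.
def _example_vectorized_critic():
    critic = VectorizedCritic(7, 1, hidden_features=32, hidden_layers=2, ensemble_size=4)
    x = torch.rand(16, 7)        # broadcast to [ensemble_size, batch_size, 7] internally
    q = critic(x)                # [ensemble_size, batch_size] == [4, 16]
    return q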
# ------------------------------- Test ------------------------------- #
if __name__ == "__main__":
def test_ensemble_mlp():
rnd = torch.rand((6, 5))
ensemble_size = 3
ensemble_linear = VectorizedMLP(5, 4, hidden_features=16, hidden_layers=2, ensemble_size=ensemble_size)
linear1 = MLP(5, 4, hidden_features=16, hidden_layers=2)
linear2 = MLP(5, 4, hidden_features=16, hidden_layers=2)
linear3 = MLP(5, 4, hidden_features=16, hidden_layers=2)
state_dicts = deque([linear1.state_dict(), linear2.state_dict(), linear3.state_dict()])
for ind, state_dict in enumerate(state_dicts):
for key, val in state_dict.items():
layer_ind = int(key.split('.')[1])
if key.endswith("weight"):
ensemble_linear.net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
elif key.endswith("bias"):
ensemble_linear.net[layer_ind].bias.data[ind, 0, ...].copy_(val)
else:
raise NotImplementedError("Only except network params with weight and bias")
ensemble_out = ensemble_linear(rnd)
linear1_out = linear1(rnd)
linear2_out = linear2(rnd)
linear3_out = linear3(rnd)
# breakpoint()
print((ensemble_out[0, ...] - linear1_out).abs().max())
print((ensemble_out[1, ...] - linear2_out).abs().max())
print((ensemble_out[2, ...] - linear3_out).abs().max())
def test_ensemble_res_distinct():
in_fea = 256
out_fea = 15
hidden_features = 256
hidden_layers = 4
batch_size = 512
ensemble_size = 3
rnd1 = torch.rand((batch_size, in_fea))
rnd2 = torch.rand((batch_size, in_fea))
rnd3 = torch.rand((batch_size, in_fea))
rnd = torch.stack([rnd1, rnd2, rnd3])
ensemble_linear = VectorizedResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, ensemble_size=ensemble_size, norm=None)
linear1 = ResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, norm=None)
linear2 = ResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, norm=None)
linear3 = ResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, norm=None)
state_dicts = deque([linear1.state_dict(), linear2.state_dict(), linear3.state_dict()])
for ind, state_dict in enumerate(state_dicts):
for key, val in state_dict.items():
key_ls = key.split('.')
if "skip_net" in key_ls:
block_ind = int(key_ls[1])
if key.endswith("weight"):
ensemble_linear.resnet[block_ind].skip_net.weight.data[ind, ...].copy_(val.transpose(1, 0))
elif key.endswith("bias"):
ensemble_linear.resnet[block_ind].skip_net.bias.data[ind, 0, ...].copy_(val)
else:
raise NotImplementedError("Only except network params with weight and bias")
elif "process_net" in key_ls:
block_ind, layer_ind = int(key_ls[1]), int(key_ls[3])
if key.endswith("weight"):
ensemble_linear.resnet[block_ind].process_net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
elif key.endswith("bias"):
ensemble_linear.resnet[block_ind].process_net[layer_ind].bias.data[ind, 0, ...].copy_(val)
else:
raise NotImplementedError("Only except network params with weight and bias")
else:
# linear layer at last
block_ind = int(key_ls[1])
if key.endswith("weight"):
ensemble_linear.resnet[block_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
elif key.endswith("bias"):
ensemble_linear.resnet[block_ind].bias.data[ind, 0, ...].copy_(val)
else:
raise NotImplementedError("Only except network params with weight and bias")
ensemble_out = ensemble_linear(rnd)
linear1_out = linear1(rnd1)
linear2_out = linear2(rnd2)
linear3_out = linear3(rnd3)
# breakpoint()
print((ensemble_out[0, ...] - linear1_out).abs().max())
print((ensemble_out[1, ...] - linear2_out).abs().max())
print((ensemble_out[2, ...] - linear3_out).abs().max())
[docs]
class Value_Net_VectorizedCritic(nn.Module):
r"""
Initializes a vectorized Q-value (critic) network instance.
Args:
input_dim_dict (Union[int, dict]): The input dimension, or a dictionary of per-node input dimensions.
q_hidden_features (int): The number of hidden features of each Q network.
q_hidden_layers (int): The number of hidden layers of each Q network.
num_q_net (int): The number of parallel Q networks.
"""
def __init__(self,
input_dim_dict,
q_hidden_features,
q_hidden_layers,
num_q_net,
*args, **kwargs) -> None:
super().__init__()
self.kwargs = kwargs
self.input_dim_dict = input_dim_dict
if isinstance(input_dim_dict, dict):
input_dim = sum(input_dim_dict.values())
else:
input_dim = input_dim_dict
self.value = VectorizedMLP(input_dim,
1,
q_hidden_features,
q_hidden_layers,
num_q_net)
if 'ts_conv_config' in self.kwargs.keys() and self.kwargs['ts_conv_config'] != None:
other_state_layer_output = q_hidden_features//2
ts_state_conv_layer_output = q_hidden_features//2 + q_hidden_features%2
self.other_state_layer = MLP(self.kwargs['ts_conv_net_config']['other_net_input'], other_state_layer_output, 0, 0)
all_node_ts = None
self.conv_ts_node = []
self.conv_other_node = deepcopy(self.kwargs['ts_conv_config']['no_ts_input'])
for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
self.conv_ts_node.append('__history__'+tsnode)
self.conv_other_node.append('__now__'+tsnode)
node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
if all_node_ts == None:
all_node_ts = node_ts
assert all_node_ts==node_ts, f"expect ts step == {all_node_ts}. However got {tsnode} with ts step {node_ts}"
# kernel size is ts-1: the current-step ("now") state is excluded from the history convolution
kernalsize = all_node_ts-1
self.ts_state_conv_layer = ConvBlock(self.kwargs['ts_conv_net_config']['conv_net_input'], ts_state_conv_layer_output, kernalsize)
[docs]
def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
if hasattr(self, 'conv_ts_node'):
# handle all time-series nodes that carry a ts dimension
assert hasattr(self, 'conv_other_node')
inputdata = deepcopy(state.detach()) # the data is detached so no gradients flow back to other nodes
for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
assert self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['endpoint'] == True
batch_size = np.prod(list(inputdata[tsnode].shape[:-1]))
original_size = list(inputdata[tsnode].shape[:-1])
node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
# inputdata[tsnode]
temp_data = inputdata[tsnode].reshape(batch_size, node_ts, -1)
inputdata['__history__'+tsnode] = temp_data[..., :-1, :]
inputdata['__now__'+tsnode] = temp_data[..., -1, :].reshape([*original_size,-1])
state_other = torch.cat([inputdata[key] for key in self.conv_other_node], dim=-1)
state_ts_conv = torch.cat([inputdata[key] for key in self.conv_ts_node], dim=-1)
state_other = self.other_state_layer(state_other)
original_size = list(state_other.shape[:-1])
state_ts_conv = self.ts_state_conv_layer(state_ts_conv).reshape([*original_size,-1])
state = torch.cat([state_other, state_ts_conv], dim=-1)
else:
pass
# state = torch.cat([state[key].detach() for key in self.input_dim_dict.keys()], dim=-1)
output = self.value(state)
output = output.squeeze(-1)
return output
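# --- Usage sketch (not part of the original module) --------------------------
# The simple path (no ts_conv_config): the critic consumes an already
# concatenated [batch_size, sum(input_dim_dict.values())] tensor and returns
# one value per Q network. Argument values are illustrative only.
def _example_value_net_vectorized_critic():
    net = Value_Net_VectorizedCritic(input_dim_dict={'obs': 5, 'action': 2},
                                     q_hidden_features=64, q_hidden_layers=2,
                                     num_q_net=4)
    state = torch.rand(8, 7)     # [batch_size, obs_dim + action_dim]
    q = net(state)               # [num_q_net, batch_size] == [4, 8]
    return q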