"""
POLIXIR REVIVE, copyright (C) 2021-2025 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
from loguru import logger
import time
import math
import warnings
import numpy as np
import torch
import torch.nn.functional as F
from copy import deepcopy
from torch import nn#, vmap
# from torch.func import stack_module_state, functional_call
from typing import Optional, Union, List, Dict, cast
from collections import OrderedDict, deque
from revive.computation.dists import *
from revive.computation.utils import *
def reglu(x):
a, b = x.chunk(2, dim=-1)
return a * F.relu(b)
def geglu(x):
a, b = x.chunk(2, dim=-1)
return a * F.gelu(b)
class Swish(nn.Module):
def forward(self, x : torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
ACTIVATION_CREATORS = {
'relu' : lambda dim: nn.ReLU(inplace=True),
'elu' : lambda dim: nn.ELU(),
'leakyrelu' : lambda dim: nn.LeakyReLU(negative_slope=0.1, inplace=True),
'tanh' : lambda dim: nn.Tanh(),
'sigmoid' : lambda dim: nn.Sigmoid(),
'identity' : lambda dim: nn.Identity(),
'prelu' : lambda dim: nn.PReLU(dim),
'gelu' : lambda dim: nn.GELU(),
'reglu' : reglu,
'geglu' : geglu,
'swish' : lambda dim: Swish(),
}
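# Example (sketch): most entries are factories that take a feature dim and return an activation
# module (the dim argument is only used by 'prelu'); 'reglu' and 'geglu' are plain functions.
# >>> act = ACTIVATION_CREATORS['leakyrelu'](64)
# >>> act(torch.randn(8, 64)).shape    # torch.Size([8, 64])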
# -------------------------------- Layers -------------------------------- #
class Value_Net(nn.Module):
r"""
A value network that maps (a concatenation of) the input nodes to a single scalar value.
Args:
input_dim_dict (Union[dict, int]): Either a dict mapping input node names to their feature dims, or the total input dim.
value_hidden_features (int): The number of hidden features in the value MLP.
value_hidden_layers (int): The number of hidden layers in the value MLP.
value_normalization (str): The normalization method used between hidden layers.
value_activation (str): The activation function used in hidden layers.
"""
def __init__(self,
input_dim_dict,
value_hidden_features,
value_hidden_layers,
value_normalization,
value_activation,
*args, **kwargs) -> None:
super().__init__()
self.kwargs = kwargs
self.input_dim_dict = input_dim_dict
if isinstance(input_dim_dict, dict):
input_dim = sum(input_dim_dict.values())
else:
input_dim = input_dim_dict
self.value = MLP(input_dim, 1, value_hidden_features, value_hidden_layers,
norm=value_normalization, hidden_activation=value_activation)
if 'ts_conv_config' in self.kwargs.keys() and self.kwargs['ts_conv_config'] != None:
other_state_layer_output = input_dim//2
ts_state_conv_layer_output = input_dim//2 + input_dim%2
self.other_state_layer = MLP(self.kwargs['ts_conv_net_config']['other_net_input'], other_state_layer_output, 0, 0, output_activation=value_activation)
all_node_ts = None
self.conv_ts_node = []
self.conv_other_node = deepcopy(self.kwargs['ts_conv_config']['no_ts_input'])
for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
self.conv_ts_node.append('__history__'+tsnode)
self.conv_other_node.append('__now__'+tsnode)
node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
if all_node_ts == None:
all_node_ts = node_ts
assert all_node_ts==node_ts, f"expect ts step == {all_node_ts}. However got {tsnode} with ts step {node_ts}"
# the kernel spans ts-1 steps: the current ("now") state is excluded from the history convolution
kernel_size = all_node_ts - 1
self.ts_state_conv_layer = ConvBlock(self.kwargs['ts_conv_net_config']['conv_net_input'], ts_state_conv_layer_output, kernel_size, output_activation=value_activation)
def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
if hasattr(self, 'conv_ts_node'):
# handle all time-series nodes that carry a time dimension
assert hasattr(self, 'conv_ts_node')
assert hasattr(self, 'conv_other_node')
inputdata = {key: value.detach() for key, value in state.items()} # detach so gradients do not flow back into the other nodes
for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
# assert self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['endpoint'] == True
batch_size = int(np.prod(list(inputdata[tsnode].shape[:-1])))
original_size = list(inputdata[tsnode].shape[:-1])
node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
# inputdata[tsnode]
temp_data = inputdata[tsnode].reshape(batch_size, node_ts, -1)
inputdata['__history__'+tsnode] = temp_data[..., :-1, :]
inputdata['__now__'+tsnode] = temp_data[..., -1, :].reshape([*original_size,-1])
state_other = torch.cat([inputdata[key] for key in self.conv_other_node], dim=-1)
state_ts_conv = torch.cat([inputdata[key] for key in self.conv_ts_node], dim=-1)
state_other = self.other_state_layer(state_other)
original_size = list(state_other.shape[:-1])
state_ts_conv = self.ts_state_conv_layer(state_ts_conv).reshape([*original_size,-1])
state = torch.cat([state_other, state_ts_conv], dim=-1)
else:
state = torch.cat([state[key].detach() for key in self.input_dim_dict.keys()], dim=-1)
output = self.value(state)
return output
class ConvBlock(nn.Module):
r"""
A small linear layer followed by a depthwise 1D convolution over the time dimension.
Args:
in_features (int): The number of input features.
out_features (int): The number of output features (also the number of conv channels).
conv_time_step (int): The kernel size of the 1D convolution over the time dimension.
output_activation (str, optional): The activation applied after the linear layer. Default is 'identity'.
"""
def __init__(self,
in_features: int,
out_features: int,
conv_time_step: int,
output_activation : str = 'identity',
):
super().__init__()
output_activation_creator = ACTIVATION_CREATORS[output_activation]
# self.in_features = in_features
# self.out_features = out_features
self.mlp_net = nn.Sequential(
nn.Linear(in_features, out_features),
output_activation_creator(out_features)
)
self.conv_net = nn.Conv1d(in_channels=out_features, out_channels=out_features, groups=out_features, kernel_size=conv_time_step)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.mlp_net(x)
# swap the time-step dim with the feature dim: Conv1d expects (batch, channels, time)
x = x.transpose(-1,-2)
# the depthwise conv spans the whole history window, collapsing the time dim
out = self.conv_net(x).squeeze(-1)
return out
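# Example (sketch, illustrative sizes): with a 3-step history window, ConvBlock maps
# (batch, 3, in_features) to (batch, out_features) because the depthwise kernel spans the window.
# >>> block = ConvBlock(in_features=8, out_features=4, conv_time_step=3)
# >>> block(torch.randn(16, 3, 8)).shape    # torch.Size([16, 4])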
class VectorizedLinear(nn.Module):
r"""
Initializes a vectorized linear layer instance.
Args:
in_features (int): The number of input features.
out_features (int): The number of output features.
ensemble_size (int): The number of ensembles to use.
"""
def __init__(self, in_features: int, out_features: int, ensemble_size: int):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.ensemble_size = ensemble_size
self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features))
self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))
self.reset_parameters()
def reset_parameters(self):
# default pytorch init for nn.Linear module
for layer in range(self.ensemble_size):
nn.init.kaiming_uniform_(self.weight[layer], a=math.sqrt(5))
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
input: [ensemble_size, batch_size, input_size]
weight: [ensemble_size, input_size, out_size]
out: [ensemble_size, batch_size, out_size]
"""
# out = torch.einsum("kij,kjl->kil", x, self.weight) + self.bias
out = x @ self.weight + self.bias
# out = torch.bmm(x, self.weight) + self.bias
return out
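# Example (sketch): one independent linear map per ensemble member, applied in parallel.
# >>> layer = VectorizedLinear(in_features=8, out_features=4, ensemble_size=5)
# >>> layer(torch.randn(5, 32, 8)).shape    # torch.Size([5, 32, 4])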
class EnsembleLinear(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
num_ensemble: int,
weight_decay: float = 0.0
) -> None:
super().__init__()
self.num_ensemble = num_ensemble
self.register_parameter("weight", nn.Parameter(torch.zeros(num_ensemble, input_dim, output_dim)))
self.register_parameter("bias", nn.Parameter(torch.zeros(num_ensemble, 1, output_dim)))
nn.init.trunc_normal_(self.weight, std=1 / (2 * input_dim ** 0.5))
self.register_parameter("saved_weight", nn.Parameter(self.weight.detach().clone()))
self.register_parameter("saved_bias", nn.Parameter(self.bias.detach().clone()))
self.weight_decay = weight_decay
self.device = torch.device('cpu')
def forward(self, x: torch.Tensor) -> torch.Tensor:
weight = self.weight
bias = self.bias
if len(x.shape) == 2:
x = torch.einsum('ij,bjk->bik', x, weight)
elif len(x.shape) == 3:
if x.shape[0] == weight.data.shape[0]:
x = torch.einsum('bij,bjk->bik', x, weight)
else:
x = torch.einsum('cij,bjk->bcik', x, weight)
elif len(x.shape) == 4:
if x.shape[0] == weight.data.shape[0]:
x = torch.einsum('cbij,cjk->cbik', x, weight)
else:
x = torch.einsum('cdij,bjk->bcdik', x, weight)
elif len(x.shape) == 5:
x = torch.einsum('bcdij,bjk->bcdik', x, weight)
assert x.shape[0] == bias.shape[0] and x.shape[-1] == bias.shape[-1]
if len(x.shape) == 4:
bias = bias.unsqueeze(1)
elif len(x.shape) == 5:
bias = bias.unsqueeze(1)
bias = bias.unsqueeze(1)
x = x + bias
return x
def load_save(self) -> None:
self.weight.data.copy_(self.saved_weight.data)
self.bias.data.copy_(self.saved_bias.data)
def update_save(self, indexes: List[int]) -> None:
self.saved_weight.data[indexes] = self.weight.data[indexes]
self.saved_bias.data[indexes] = self.bias.data[indexes]
def get_decay_loss(self) -> torch.Tensor:
decay_loss = self.weight_decay * (0.5 * ((self.weight ** 2).sum()))
return decay_loss
def to(self, device):
if not device == self.device:
self.device = device
super().to(device)
self.weight = self.weight.to(self.device)
self.bias = self.bias.to(self.device)
self.saved_weight = self.saved_weight.to(self.device)
self.saved_bias = self.saved_bias.to(self.device)
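# Example (sketch): a plain 2D batch is broadcast to every ensemble member by the einsum above.
# >>> layer = EnsembleLinear(input_dim=8, output_dim=4, num_ensemble=7)
# >>> layer(torch.randn(32, 8)).shape    # torch.Size([7, 32, 4])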
# -------------------------------- Backbones -------------------------------- #
class MLP(nn.Module):
r"""
Multi-layer Perceptron
Args:
in_features : int, features numbers of the input
out_features : int, features numbers of the output
hidden_features : int, features numbers of the hidden layers
hidden_layers : int, numbers of the hidden layers
norm : str, normalization method between hidden layers, default : None
hidden_activation : str, activation function used in hidden layers, default : 'leakyrelu'
output_activation : str, activation function used in output layer, default : 'identity'
"""
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
norm : str = None,
hidden_activation : str = 'leakyrelu',
output_activation : str = 'identity'):
super(MLP, self).__init__()
hidden_activation_creator = ACTIVATION_CREATORS[hidden_activation]
output_activation_creator = ACTIVATION_CREATORS[output_activation]
if hidden_layers == 0:
self.net = nn.Sequential(
nn.Linear(in_features, out_features),
output_activation_creator(out_features)
)
else:
net = []
for i in range(hidden_layers):
net.append(nn.Linear(in_features if i == 0 else hidden_features, hidden_features))
if norm:
if norm == 'ln':
net.append(nn.LayerNorm(hidden_features))
elif norm == 'bn':
net.append(nn.BatchNorm1d(hidden_features))
else:
raise NotImplementedError(f'{norm} is not supported!')
net.append(hidden_activation_creator(hidden_features))
net.append(nn.Linear(hidden_features, out_features))
net.append(output_activation_creator(out_features))
self.net = nn.Sequential(*net)
def forward(self, x : torch.Tensor) -> torch.Tensor:
r""" forward method of MLP only assume the last dim of x matches `in_features` """
return self.net(x)
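# Example (sketch, illustrative sizes): a two-hidden-layer MLP with layer normalization.
# >>> net = MLP(in_features=16, out_features=4, hidden_features=64, hidden_layers=2, norm='ln')
# >>> net(torch.randn(8, 16)).shape    # torch.Size([8, 4])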
class VectorizedMLP(nn.Module):
r"""
Vectorized MLP
Args:
in_features : int, features numbers of the input
out_features : int, features numbers of the output
hidden_features : int, features numbers of the hidden layers
hidden_layers : int, numbers of the hidden layers
norm : str, normalization method between hidden layers, default : None
hidden_activation : str, activation function used in hidden layers, default : 'leakyrelu'
output_activation : str, activation function used in output layer, default : 'identity'
"""
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
ensemble_size : int,
norm : str = None,
hidden_activation : str = 'leakyrelu',
output_activation : str = 'identity'):
super(VectorizedMLP, self).__init__()
self.ensemble_size = ensemble_size
hidden_activation_creator = ACTIVATION_CREATORS[hidden_activation]
output_activation_creator = ACTIVATION_CREATORS[output_activation]
if hidden_layers == 0:
self.net = nn.Sequential(
VectorizedLinear(in_features, out_features, ensemble_size),
output_activation_creator(out_features)
)
else:
net = []
for i in range(hidden_layers):
net.append(VectorizedLinear(in_features if i == 0 else hidden_features, hidden_features, ensemble_size))
if norm:
if norm == 'ln':
net.append(nn.LayerNorm(hidden_features))
elif norm == 'bn':
net.append(nn.BatchNorm1d(hidden_features))
else:
raise NotImplementedError(f'{norm} is not supported!')
net.append(hidden_activation_creator(hidden_features))
net.append(VectorizedLinear(hidden_features, out_features, ensemble_size))
net.append(output_activation_creator(out_features))
self.net = nn.Sequential(*net)
def forward(self, x : torch.Tensor) -> torch.Tensor:
r""" forward method of MLP only assume the last dim of x matches `in_features` """
if x.dim() == 2: # [batch_size, dim]
x = x.unsqueeze(0).repeat_interleave(self.ensemble_size, dim=0) # [ensemble_size, batch_size, x_dim]
assert x.dim() == 3
assert x.shape[0] == self.ensemble_size
return self.net(x)
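# Example (sketch): a 2D batch is repeated across the ensemble dim before the vectorized layers.
# >>> net = VectorizedMLP(in_features=16, out_features=4, hidden_features=64, hidden_layers=2, ensemble_size=5)
# >>> net(torch.randn(8, 16)).shape    # torch.Size([5, 8, 4])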
class ResBlock(nn.Module):
"""
Initializes a residual block instance.
Args:
input_feature (int): The number of input features to the block.
output_feature (int): The number of output features from the block.
norm (str, optional): The type of normalization to apply to the block. Default is 'ln' for layer normalization.
"""
def __init__(self, input_feature : int, output_feature : int, norm : str = 'ln', dropout: float = 0):
super().__init__()
self.dropout = nn.Dropout(dropout) if dropout else None
if norm == 'ln':
norm_class = torch.nn.LayerNorm
self.process_net = torch.nn.Sequential(
torch.nn.Linear(input_feature, output_feature),
norm_class(output_feature),
torch.nn.ReLU(True),
torch.nn.Linear(output_feature, output_feature),
norm_class(output_feature),
torch.nn.ReLU(True)
)
else:
self.process_net = torch.nn.Sequential(
torch.nn.Linear(input_feature, output_feature),
torch.nn.ReLU(True),
torch.nn.Linear(output_feature, output_feature),
torch.nn.ReLU(True)
)
if not input_feature == output_feature:
self.skip_net = torch.nn.Linear(input_feature, output_feature)
else:
self.skip_net = torch.nn.Identity()
def forward(self, x : torch.Tensor) -> torch.Tensor:
'''x should be a 2D Tensor due to batchnorm'''
if self.dropout is not None:
return self.dropout(self.process_net(x)) + self.skip_net(x)
else:
return self.process_net(x) + self.skip_net(x)
class VectorizedResBlock(nn.Module):
r"""
Initializes a vectorized residual block instance (one residual block per ensemble member).
Args:
in_features (int): The number of input features.
out_features (int): The number of output features.
ensemble_size (int): The number of ensemble members.
norm (str, optional): The type of normalization to apply. Default is 'ln' for layer normalization.
"""
def __init__(self,
in_features: int,
out_features: int,
ensemble_size: int,
norm : str = 'ln'):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.ensemble_size = ensemble_size
if norm == 'ln':
norm_class = torch.nn.LayerNorm
self.process_net = torch.nn.Sequential(
VectorizedLinear(in_features, out_features, ensemble_size),
norm_class(out_features),
torch.nn.ReLU(True),
VectorizedLinear(out_features, out_features, ensemble_size),
norm_class(out_features),
torch.nn.ReLU(True)
)
else:
self.process_net = torch.nn.Sequential(
VectorizedLinear(in_features, out_features, ensemble_size),
torch.nn.ReLU(True),
VectorizedLinear(out_features, out_features, ensemble_size),
torch.nn.ReLU(True)
)
if not in_features == out_features:
self.skip_net = VectorizedLinear(in_features, out_features, ensemble_size)
else:
self.skip_net = torch.nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
# input: [ensemble_size, batch_size, input_size]
# out: [ensemble_size, batch_size, out_size]
out = self.process_net(x) + self.skip_net(x)
return out
class ResNet(torch.nn.Module):
"""
Initializes a residual neural network (ResNet) instance.
Args:
in_features (int): The number of input features to the ResNet.
out_features (int): The number of output features from the ResNet.
hidden_features (int): The number of hidden features in each residual block.
hidden_layers (int): The number of residual blocks in the ResNet.
norm (str, optional): The type of normalization to apply. Default is 'ln' for layer normalization.
output_activation (str, optional): The type of activation function to apply to the output of the ResNet. Default is 'identity'.
dropout (float, optional): Dropout probability applied inside each residual block. Default is 0.
"""
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
norm : str = 'ln',
output_activation : str = 'identity',
dropout: float = 0):
super().__init__()
modules = []
for i in range(hidden_layers):
if i == 0:
modules.append(ResBlock(in_features, hidden_features, norm, dropout))
else:
modules.append(ResBlock(hidden_features, hidden_features, norm, dropout))
modules.append(torch.nn.Linear(hidden_features, out_features))
modules.append(ACTIVATION_CREATORS[output_activation](out_features))
self.resnet = torch.nn.Sequential(*modules)
def forward(self, x : torch.Tensor) -> torch.Tensor:
'''NOTE: reshape is needed since ResBlock only supports 2D Tensors'''
shape = x.shape
x = x.view(-1, shape[-1])
output = self.resnet(x)
output = output.view(*shape[:-1], -1)
return output
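# Example (sketch): arbitrary leading dims are flattened to 2D for the residual blocks and restored.
# >>> net = ResNet(in_features=16, out_features=4, hidden_features=64, hidden_layers=2)
# >>> net(torch.randn(10, 8, 16)).shape    # torch.Size([10, 8, 4])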
class VectorizedResNet(torch.nn.Module):
"""
Initializes a vectorized residual neural network (ResNet) instance, i.e. an ensemble of ResNets evaluated in parallel.
Args:
in_features (int): The number of input features to the ResNet.
out_features (int): The number of output features from the ResNet.
hidden_features (int): The number of hidden features in each residual block.
hidden_layers (int): The number of residual blocks in the ResNet.
ensemble_size (int): The number of ensemble members.
norm (str, optional): The type of normalization to apply. Default is 'ln' for layer normalization.
output_activation (str, optional): The type of activation function to apply to the output of the ResNet. Default is 'identity'.
"""
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
ensemble_size : int,
norm : str = 'ln',
output_activation : str = 'identity'):
super().__init__()
self.ensemble_size = ensemble_size
modules = []
for i in range(hidden_layers):
if i == 0:
modules.append(VectorizedResBlock(in_features, hidden_features, ensemble_size, norm))
else:
modules.append(VectorizedResBlock(hidden_features, hidden_features, ensemble_size, norm))
modules.append(VectorizedLinear(hidden_features, out_features, ensemble_size))
modules.append(ACTIVATION_CREATORS[output_activation](out_features))
self.resnet = torch.nn.Sequential(*modules)
def forward(self, x : torch.Tensor) -> torch.Tensor:
'''NOTE: 2D inputs are broadcast to all ensemble members; the backbone operates on [ensemble_size, batch_size, dim] Tensors'''
shape = x.shape
if len(shape) == 2:
x = x.unsqueeze(0).repeat_interleave(self.ensemble_size, dim=0) # [batch_size, dim] --> [ensemble_size, batch_size, dim]
assert x.dim() == 3, f"shape == {shape}, with length == {len(shape)}, expect length 3"
assert x.shape[0] == self.ensemble_size # [ensemble_size, batch_size, dim]
output = self.resnet(x)
return output
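# Example (sketch): a 2D batch is broadcast to every ensemble member of the vectorized ResNet.
# >>> net = VectorizedResNet(in_features=16, out_features=4, hidden_features=64, hidden_layers=2, ensemble_size=5)
# >>> net(torch.randn(8, 16)).shape    # torch.Size([5, 8, 4])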
class TAPE(nn.Module):
def __init__(self, d_model, max_len=200, scale_factor=1.0):
super(TAPE, self).__init__()
pe = torch.zeros(max_len, d_model) # positional encoding
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin((position * div_term)*(d_model/max_len))
pe[:, 1::2] = torch.cos((position * div_term)*(d_model/max_len))
pe = scale_factor * pe
self.register_buffer('pe', pe) # this stores the variable in the state_dict (used for non-trainable variables)
def forward(self, x):
return x + self.pe[:x.shape[1], :]
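# Example (sketch): adds a fixed (non-trainable) positional encoding to a (batch, seq_len, d_model) input.
# >>> pe = TAPE(d_model=32)
# >>> pe(torch.randn(4, 50, 32)).shape    # torch.Size([4, 50, 32])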
class TAttention(nn.Module):
def __init__(self, d_model, nhead, dropout):
super().__init__()
self.d_model = d_model
self.nhead = nhead
self.qtrans = nn.Linear(d_model, d_model, bias=False)
self.ktrans = nn.Linear(d_model, d_model, bias=False)
self.vtrans = nn.Linear(d_model, d_model, bias=False)
self.attn_dropout = []
if dropout > 0:
for i in range(nhead):
self.attn_dropout.append(nn.Dropout(p=dropout))
self.attn_dropout = nn.ModuleList(self.attn_dropout)
# input LayerNorm
self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
# FFN layerNorm
self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
# FFN
self.ffn = nn.Sequential(
nn.Linear(d_model, d_model),
nn.ReLU(),
nn.Dropout(p=dropout),
nn.Linear(d_model, d_model),
nn.Dropout(p=dropout)
)
def forward(self, x):
x = self.norm1(x)
q = self.qtrans(x)
k = self.ktrans(x)
v = self.vtrans(x)
dim = int(self.d_model / self.nhead)
att_output = []
for i in range(self.nhead):
if i==self.nhead-1:
qh = q[:, :, i * dim:]
kh = k[:, :, i * dim:]
vh = v[:, :, i * dim:]
else:
qh = q[:, :, i * dim:(i + 1) * dim]
kh = k[:, :, i * dim:(i + 1) * dim]
vh = v[:, :, i * dim:(i + 1) * dim]
atten_ave_matrixh = torch.softmax(torch.matmul(qh, kh.transpose(1, 2)), dim=-1)
if self.attn_dropout:
atten_ave_matrixh = self.attn_dropout[i](atten_ave_matrixh)
att_output.append(torch.matmul(atten_ave_matrixh, vh))
att_output = torch.concat(att_output, dim=-1)
# FFN
xt = x + att_output
xt = self.norm2(xt)
att_output = xt + self.ffn(xt)
return att_output
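# Example (sketch): temporal multi-head self-attention over a (batch, seq_len, d_model) sequence.
# >>> attn = TAttention(d_model=32, nhead=4, dropout=0.1)
# >>> attn(torch.randn(4, 50, 32)).shape    # torch.Size([4, 50, 32])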
class Tokenizer(nn.Module):
category_offsets: Optional[torch.Tensor]
def __init__(
self,
d_numerical: int,
categories: Optional[List[int]],
d_token: int,
bias: bool,
) -> None:
super().__init__()
if categories is None:
d_bias = d_numerical
self.category_offsets = None
self.category_embeddings = None
else:
d_bias = d_numerical + len(categories)
category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0)
self.register_buffer('category_offsets', category_offsets)
self.category_embeddings = nn.Embedding(sum(categories), d_token)
nn.init.kaiming_uniform_(self.category_embeddings.weight, a=math.sqrt(5))
print(f'{self.category_embeddings.weight.shape=}')
# take [CLS] token into account
self.weight = nn.Parameter(torch.Tensor(d_numerical + 1, d_token))
self.bias = nn.Parameter(torch.Tensor(d_bias, d_token)) if bias else None
# The initialization is inspired by nn.Linear
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
nn.init.kaiming_uniform_(self.bias, a=math.sqrt(5))
@property
def n_tokens(self) -> int:
return len(self.weight) + (
0 if self.category_offsets is None else len(self.category_offsets)
)
def forward(self, x_num: torch.Tensor, x_cat: Optional[torch.Tensor]):
x_some = x_num if x_cat is None else x_cat
assert x_some is not None
x_num = torch.cat(
[torch.ones(len(x_some), 1, device=x_some.device)] # [CLS]
+ ([] if x_num is None else [x_num]),
dim=1,
)
x = self.weight[None] * x_num[:, :, None]
if x_cat is not None:
x = torch.cat(
[x, self.category_embeddings(x_cat + self.category_offsets[None])],
dim=1,
)
if self.bias is not None:
bias = torch.cat(
[
torch.zeros(1, self.bias.shape[1], device=x.device),
self.bias,
]
)
x = x + bias[None]
return x
class MultiheadAttention(nn.Module):
def __init__(
self, d: int, n_heads: int, dropout: float, initialization: str
) -> None:
if n_heads > 1:
assert d % n_heads == 0
assert initialization in ['xavier', 'kaiming']
super().__init__()
self.W_q = nn.Linear(d, d)
self.W_k = nn.Linear(d, d)
self.W_v = nn.Linear(d, d)
self.W_out = nn.Linear(d, d) if n_heads > 1 else None
self.n_heads = n_heads
self.dropout = nn.Dropout(dropout) if dropout else None
for m in [self.W_q, self.W_k, self.W_v]:
if initialization == 'xavier' and (n_heads > 1 or m is not self.W_v):
# gain is needed since W_qkv is represented with 3 separate layers
nn.init.xavier_uniform_(m.weight, gain=1 / math.sqrt(2))
nn.init.zeros_(m.bias)
if self.W_out is not None:
nn.init.zeros_(self.W_out.bias)
def _reshape(self, x):
batch_size, n_tokens, d = x.shape
d_head = d // self.n_heads
return (
x.reshape(batch_size, n_tokens, self.n_heads, d_head)
.transpose(1, 2)
.reshape(batch_size * self.n_heads, n_tokens, d_head)
)
def forward(
self,
x_q,
x_kv,
key_compression: Optional[nn.Linear],
value_compression: Optional[nn.Linear],
):
q, k, v = self.W_q(x_q), self.W_k(x_kv), self.W_v(x_kv)
for tensor in [q, k, v]:
assert tensor.shape[-1] % self.n_heads == 0
if key_compression is not None:
assert value_compression is not None
k = key_compression(k.transpose(1, 2)).transpose(1, 2)
v = value_compression(v.transpose(1, 2)).transpose(1, 2)
else:
assert value_compression is None
batch_size = len(q)
d_head_key = k.shape[-1] // self.n_heads
d_head_value = v.shape[-1] // self.n_heads
n_q_tokens = q.shape[1]
q = self._reshape(q)
k = self._reshape(k)
attention = F.softmax(q @ k.transpose(1, 2) / math.sqrt(d_head_key), dim=-1)
if self.dropout is not None:
attention = self.dropout(attention)
x = attention @ self._reshape(v)
x = (
x.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value)
.transpose(1, 2)
.reshape(batch_size, n_q_tokens, self.n_heads * d_head_value)
)
if self.W_out is not None:
x = self.W_out(x)
return x
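# Example (sketch): self-attention without key/value compression; x attends to itself.
# >>> mha = MultiheadAttention(d=32, n_heads=4, dropout=0.1, initialization='kaiming')
# >>> x = torch.randn(4, 10, 32)
# >>> mha(x, x, key_compression=None, value_compression=None).shape    # torch.Size([4, 10, 32])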
class DistributionWrapper(nn.Module):
r"""wrap output of Module to distribution"""
BASE_TYPES = ['normal', 'gmm', 'onehot', 'discrete_logistic']
SUPPORTED_TYPES = BASE_TYPES + ['mix']
def __init__(self, distribution_type : str = 'normal', **params):
super().__init__()
self.distribution_type = distribution_type
self.params = params
assert self.distribution_type in self.SUPPORTED_TYPES, f"{self.distribution_type} is not supported!"
if self.distribution_type == 'normal':
self.max_logstd = nn.Parameter(torch.ones(self.params['dim']) * 0.0, requires_grad=True)
self.min_logstd = nn.Parameter(torch.ones(self.params['dim']) * -7, requires_grad=True)
if not self.params.get('conditioned_std', True):
self.logstd = nn.Parameter(torch.zeros(self.params['dim']), requires_grad=True)
elif self.distribution_type == 'gmm':
self.max_logstd = nn.Parameter(torch.ones(self.params['mixture'], self.params['dim']) * 0, requires_grad=True)
self.min_logstd = nn.Parameter(torch.ones(self.params['mixture'], self.params['dim']) * -10, requires_grad=True)
if not self.params.get('conditioned_std', True):
self.logstd = nn.Parameter(torch.zeros(self.params['mixture'], self.params['dim']), requires_grad=True)
elif self.distribution_type == 'discrete_logistic':
self.num = self.params['num']
elif self.distribution_type == 'mix':
assert 'dist_config' in self.params.keys(), "You need to provide `dist_config` for Mix distribution"
self.dist_config = self.params['dist_config']
self.wrapper_list = []
self.input_sizes = []
self.output_sizes = []
for config in self.dist_config:
dist_type = config['type']
assert dist_type in self.SUPPORTED_TYPES, f"{dist_type} is not supported!"
assert not dist_type == 'mix', "recursive MixDistribution is not supported!"
self.wrapper_list.append(DistributionWrapper(dist_type, **config))
self.input_sizes.append(config['dim'])
self.output_sizes.append(config['output_dim'])
self.wrapper_list = nn.ModuleList(self.wrapper_list)
def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, payload : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
'''
Wrap the given tensor into a distribution.
:param adapt_std : if given, it overwrites the std part of the distribution (optional)
:param payload : payload applied to the output distribution after it is built (optional)
'''
# [ OTHER ] replace the clip with tanh
# [ OTHER ] better std controlling strategy is needed todo
if self.distribution_type == 'normal':
if self.params.get('conditioned_std', True):
mu, logstd = torch.chunk(x, 2, dim=-1)
if 'soft_clamp' in kwargs and kwargs['soft_clamp'] == True:
# distribution wrapper only for revive_f venv training
logstd = soft_clamp(logstd, self.min_logstd, self.max_logstd)
std = torch.exp(logstd)
else:
# distribution wrapper for others
logstd_logit = logstd
max_std = 0.5
min_std = 0.001
std = (torch.tanh(logstd_logit) + 1) / 2 * (max_std - min_std) + min_std
else:
mu, logstd = x, self.logstd
logstd_logit = self.logstd
max_std = 0.5
min_std = 0.001
std = (torch.tanh(logstd_logit) + 1) / 2 * (max_std - min_std) + min_std
# std = torch.exp(logstd)
if payload is not None:
mu = mu + safe_atanh(payload)
mu = torch.tanh(mu)
"""replace tanh
node = kwargs["node"]
mu = node.processor.process_torch({node.name:mu})[node.name]
mu = torch.clamp(mu, min=-1.1, max=1.1)
"""
if adapt_std is not None:
std = torch.ones_like(mu).to(mu) * adapt_std
return DiagnalNormal(mu, std)
elif self.distribution_type == 'gmm':
if self.params.get('conditioned_std', True):
logits, mus, logstds = torch.split(x, [self.params['mixture'],
self.params['mixture'] * self.params['dim'],
self.params['mixture'] * self.params['dim']], dim=-1)
mus = mus.view(*mus.shape[:-1], self.params['mixture'], self.params['dim'])
logstds = logstds.view(*logstds.shape[:-1], self.params['mixture'], self.params['dim'])
else:
logits, mus = torch.split(x, [self.params['mixture'], self.params['mixture'] * self.params['dim']], dim=-1)
logstds = self.logstd
if payload is not None:
mus = mus + safe_atanh(payload.unsqueeze(dim=-2))
mus = torch.tanh(mus)
stds = adapt_std if adapt_std is not None else torch.exp(soft_clamp(logstds, self.min_logstd, self.max_logstd))
return GaussianMixture(mus, stds, logits)
elif self.distribution_type == 'onehot':
return Onehot(x)
elif self.distribution_type == 'discrete_logistic':
mu, logstd = torch.chunk(x, 2, dim=-1)
logstd = torch.clamp(logstd, -7, 1)
return DiscreteLogistic(mu, torch.exp(logstd), num=self.num)
elif self.distribution_type == 'mix':
xs = torch.split(x, self.output_sizes, dim=-1)
if isinstance(adapt_std, float):
adapt_stds = [adapt_std] * len(self.input_sizes)
elif adapt_std is not None:
adapt_stds = torch.split(adapt_std, self.input_sizes, dim=-1)
else:
adapt_stds = [None] * len(self.input_sizes)
if payload is not None:
payloads = torch.split(payload, self.input_sizes + [payload.shape[-1] - sum(self.input_sizes)], dim=-1)[:-1]
else:
payloads = [None] * len(self.input_sizes)
dists = [wrapper(x, _adapt_std, _payload, **kwargs) for x, _adapt_std, _payload, wrapper in zip(xs, adapt_stds, payloads, self.wrapper_list)]
return MixDistribution(dists)
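# Example (sketch): wrapping network logits into a diagonal Gaussian with conditioned std.
# The input width is 2 * dim (mean and log-std chunks); the mean is squashed to (-1, 1) by tanh.
# The returned distribution classes come from revive.computation.dists (imported above).
# >>> wrapper = DistributionWrapper('normal', dim=3, conditioned_std=True)
# >>> dist = wrapper(torch.randn(8, 6))    # DiagnalNormal over 3-dim outputs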
# --------------------------------- Policies -------------------------------- #
class FeedForwardPolicy(torch.nn.Module):
"""Policy for using mlp, resnet and transformer backbone.
Args:
in_features : The number of input features, or a dictionary describing the input distribution
out_features : The number of output features
hidden_features : The number of hidden features in each layer
hidden_layers : The number of hidden layers
dist_config : A list of configurations for the distributions to use in the model
norm : The type of normalization to apply to the input features
hidden_activation : The activation function to use in the hidden layers
backbone_type : The type of backbone to use in the model
use_multihead : Whether to use a multihead model
use_feature_embed : Whether to use feature embedding
"""
def __init__(self,
in_features : Union[int, dict],
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list,
norm : str = None,
hidden_activation : str = 'leakyrelu',
backbone_type : Union[str, np.str_] = 'mlp',
use_multihead : bool = False,
**kwargs):
super().__init__()
self.multihead = []
self.dist_config = dist_config
self.kwargs = kwargs
if isinstance(in_features, dict):
in_features = sum(in_features.values())
if not use_multihead:
if backbone_type == 'mlp':
self.backbone = MLP(in_features, out_features, hidden_features, hidden_layers, norm, hidden_activation)
elif backbone_type == 'res':
self.backbone = ResNet(in_features, out_features, hidden_features, hidden_layers, norm)
elif backbone_type == 'ft_transformer':
self.backbone = FT_Transformer(in_features, out_features, hidden_features, hidden_layers=hidden_layers)
else:
raise NotImplementedError(f'backbone type {backbone_type} is not supported')
else:
if backbone_type == 'mlp':
self.backbone = MLP(in_features, hidden_features, hidden_features, hidden_layers, norm, hidden_activation)
elif backbone_type == 'res':
self.backbone = ResNet(in_features, hidden_features, hidden_features, hidden_layers, norm)
elif backbone_type == 'ft_transformer':
self.backbone = FT_Transformer(in_features, out_features, hidden_features, hidden_layers=hidden_layers)
else:
raise NotImplementedError(f'backbone type {backbone_type} is not supported')
for dist in self.dist_config:
dist_type = dist["type"]
output_dim = dist["output_dim"]
if dist_type == "onehot" or dist_type == "discrete_logistic":
self.multihead.append(MLP(hidden_features, output_dim, 64, 1, norm=None, hidden_activation='leakyrelu'))
elif dist_type == "normal":
normal_dim = dist["dim"]
for dim in range(normal_dim):
self.multihead.append(MLP(hidden_features, int(output_dim // normal_dim), 64, 1, norm=None, hidden_activation='leakyrelu'))
else:
raise NotImplementedError(f'Dist type {dist_type} is not supported in multihead.')
self.multihead = nn.ModuleList(self.multihead)
self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
if "ts_conv_config" in self.kwargs and self.kwargs['ts_conv_config'] != None:
other_state_layer_output = in_features//2
ts_state_conv_layer_output = in_features//2 + in_features%2
self.other_state_layer = MLP(self.kwargs['ts_conv_net_config']['other_net_input'], other_state_layer_output, 0, 0, output_activation=hidden_activation)
all_node_ts = None
self.conv_ts_node = []
self.conv_other_node = deepcopy(self.kwargs['ts_conv_config']['no_ts_input'])
for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
self.conv_ts_node.append('__history__'+tsnode)
self.conv_other_node.append('__now__'+tsnode)
node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
if all_node_ts == None:
all_node_ts = node_ts
assert all_node_ts==node_ts, f"expect ts step == {all_node_ts}. However got {tsnode} with ts step {node_ts}"
# the kernel spans ts-1 steps: the current ("now") state is excluded from the history convolution
kernel_size = all_node_ts - 1
self.ts_state_conv_layer = ConvBlock(self.kwargs['ts_conv_net_config']['conv_net_input'], ts_state_conv_layer_output, kernel_size, output_activation=hidden_activation)
def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
if "ts_conv_config" in self.kwargs and self.kwargs['ts_conv_config'] != None:
# handle all time-series nodes that carry a time dimension
# assert 'input_names' in kwargs
assert hasattr(self, 'conv_ts_node')
assert hasattr(self, 'conv_other_node')
inputdata = deepcopy(state) # data is detached from the other nodes
for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
# assert self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['endpoint'] == True
batch_size = int(np.prod(list(inputdata[tsnode].shape[:-1])))
original_size = list(inputdata[tsnode].shape[:-1])
node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
# inputdata[tsnode]
temp_data = inputdata[tsnode].reshape(batch_size, node_ts, -1)
inputdata['__history__'+tsnode] = temp_data[..., :-1, :]
inputdata['__now__'+tsnode] = temp_data[..., -1, :].reshape([*original_size,-1])
state_other = torch.cat([inputdata[key] for key in self.conv_other_node], dim=-1)
state_ts_conv = torch.cat([inputdata[key] for key in self.conv_ts_node], dim=-1)
state_other = self.other_state_layer(state_other)
original_size = list(state_other.shape[:-1])
state_ts_conv = self.ts_state_conv_layer(state_ts_conv).reshape([*original_size,-1])
state = torch.cat([state_other, state_ts_conv], dim=-1)
elif isinstance(state, dict):
assert 'input_names' in kwargs
state = torch.cat([state[key] for key in kwargs['input_names']], dim=-1)
if not self.multihead:
output = self.backbone(state)
else:
backbone_output = self.backbone(state)
multihead_output = []
multihead_index = 0
for dist in self.dist_config:
dist_type = dist["type"]
output_dim = dist["output_dim"]
if dist_type == "onehot" or dist_type == "discrete_logistic":
multihead_output.append(self.multihead[multihead_index](backbone_output))
multihead_index += 1
elif dist_type == "normal":
normal_mode_output = []
normal_std_output = []
for head in self.multihead[multihead_index:]:
head_output = head(backbone_output)
if head_output.shape[-1] == 1:
mode = head_output
normal_mode_output.append(mode)
else:
mode, std = torch.chunk(head_output, 2, dim=-1)
normal_mode_output.append(mode)
normal_std_output.append(std)
normal_output = torch.cat(normal_mode_output, axis=-1)
if normal_std_output:
normal_std_output = torch.cat(normal_std_output, axis=-1)
normal_output = torch.cat([normal_output, normal_std_output], axis=-1)
multihead_output.append(normal_output)
break
else:
raise NotImplementedError(f'Dist type {dist_type} is not supported in multihead.')
output = torch.cat(multihead_output, axis= -1)
soft_clamp_flag = True if "soft_clamp" in self.kwargs and self.kwargs['soft_clamp'] == True else False
dist = self.dist_wrapper(output, adapt_std, soft_clamp=soft_clamp_flag, **kwargs)
if hasattr(self, "dist_mu_shift"):
dist = dist.shift(self.dist_mu_shift)
return dist
@torch.no_grad()
def get_action(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], deterministic : bool = True):
dist = self(state)
return dist.mode if deterministic else dist.sample()
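# Example (sketch, illustrative sizes): a Gaussian policy head on an MLP backbone.
# For a conditioned-std normal, output_dim is 2 * dim (mean and log-std).
# >>> dist_config = [{'type': 'normal', 'dim': 2, 'output_dim': 4, 'conditioned_std': True}]
# >>> policy = FeedForwardPolicy(in_features=6, out_features=4, hidden_features=64,
# ...                            hidden_layers=2, dist_config=dist_config, backbone_type='mlp')
# >>> action = policy.get_action(torch.randn(32, 6))    # deterministic action, shape (32, 2)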
class RecurrentPolicy(torch.nn.Module):
"""Initializes a recurrent policy network instance.
Args:
in_features (int): The number of input features.
out_features (int): The number of output features.
hidden_features (int): The number of hidden features in the RNN.
hidden_layers (int): The number of layers in the RNN.
dist_config (list): The configuration for the distributions used in the model.
backbone_type (Union[str, np.str_], optional): The type of RNN to use ('gru' or 'lstm'). Default is 'gru'.
"""
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
norm : str,
dist_config : list,
backbone_type : Union[str, np.str_] ='gru',
**kwargs):
super().__init__()
RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM
rnn_hidden_features = 256
rnn_hidden_layers = min(hidden_layers, 3)
rnn_output_features = min(out_features,8)
self.rnn = RNN(in_features, rnn_hidden_features, rnn_hidden_layers)
self.rnn_mlp = MLP(rnn_hidden_features, rnn_output_features, 0, 0)
self.backbone = ResNet(in_features+rnn_output_features, out_features, hidden_features, hidden_layers, norm)
self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
def reset(self):
self.h = None
def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
x_shape = x.shape
if len(x_shape) == 1:
rnn_output, self.h = self.rnn(x.unsqueeze(0).unsqueeze(0), self.h)
rnn_output = rnn_output.squeeze(0).squeeze(0)
elif len(x_shape) == 2:
rnn_output, self.h = self.rnn(x.unsqueeze(0), self.h)
rnn_output = rnn_output.squeeze(0)
else:
rnn_output, self.h = self.rnn(x, self.h)
rnn_output = self.rnn_mlp(rnn_output)
logits = self.backbone(torch.cat([x, rnn_output],axis=-1))
return self.dist_wrapper(logits, adapt_std, **kwargs)
class TsRecurrentPolicy(torch.nn.Module):
"""Initializes a recurrent policy network instance for ts_node.
Args:
in_features (int): The number of input features.
out_features (int): The number of output features.
hidden_features (int): The number of hidden features in the RNN.
hidden_layers (int): The number of layers in the RNN.
dist_config (list): The configuration for the distributions used in the model.
backbone_type (Union[str, np.str_], optional): The type of RNN to use ('gru' or 'lstm'). Default is 'gru'.
"""
input_is_dict = True
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
norm : str,
dist_config : list,
backbone_type : Union[str, np.str_] ='gru',
ts_input_names : dict = {},
bidirectional: bool = True,
dropout: float = 0.1,
rnn_input_withtime=True,
smooth_input=False,
**kwargs):
super().__init__()
RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM
self.ts_input_names = ts_input_names
self.input_dim_dict = kwargs.get("input_dim_dict",{})
self.max_ts_steps = max([v[1] for v in self.ts_input_names.values()])
self.rnn_input_withtime = rnn_input_withtime
self.smooth_input = smooth_input
if self.smooth_input:
self.smooth_linear_1 = nn.Linear(in_features, in_features)
self.smooth_linear_2 = nn.Linear(in_features, in_features)
if self.rnn_input_withtime:
rnn_in_features = in_features + 1
else:
rnn_in_features = in_features
self.rnn = RNN(input_size=rnn_in_features,
hidden_size=hidden_features,
num_layers=hidden_layers,
bidirectional=bidirectional,
batch_first=True,
dropout=dropout if hidden_layers > 1 else 0)
if bidirectional:
self.linear = nn.Linear(hidden_features * 2, out_features)
# self.linear = ResNet(hidden_features * 2, out_features, hidden_features, hidden_layers, norm, output_activation='identity',dropout=dropout)
else:
self.linear = nn.Linear(hidden_features, out_features)
# self.linear = ResNet(hidden_features, out_features, hidden_features, hidden_layers, norm, output_activation='identity',dropout=dropout)
self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
def reset(self):
self.h = None
def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
x_shape = x[list(x.keys())[0]].shape
if len(x_shape) == 3:
x = {k:v.reshape(-1,v.shape[-1]) for k,v in x.items()}
input = {}
for k,v in x.items():
if k in self.ts_input_names:
ts_steps = self.ts_input_names.get(k)[1]
else:
ts_steps = 1
input[k] = v.reshape(v.shape[0],ts_steps,-1)
if ts_steps < self.max_ts_steps:
# input[k] = torch.cat([input[k][:,:1].repeat(1, self.max_ts_steps+1-ts_steps, 1), input[k]],axis=1)
input[k] = torch.cat([input[k], input[k][:,-1:].repeat(1, self.max_ts_steps-ts_steps, 1)],axis=1)
if len(input) == 1:
rnn_input = list(input.values())[0]
else:
rnn_input = torch.cat(list(input.values()),axis=-1)
if self.smooth_input:
rnn_input_pool = F.avg_pool1d(rnn_input.permute(0,2,1), kernel_size=3, stride=1, padding=1).permute(0,2,1)
rnn_input = self.smooth_linear_1(rnn_input_pool) + self.smooth_linear_2(rnn_input-rnn_input_pool)
if self.rnn_input_withtime:
time_steps = torch.arange(rnn_input.shape[1]).unsqueeze(0).unsqueeze(2).expand(rnn_input.shape[0], rnn_input.shape[1], 1).float().to(rnn_input.device)
time_steps_normalized = time_steps / (rnn_input.shape[1] - 1)
rnn_input = torch.cat([rnn_input, time_steps_normalized], dim=2)
out, _ = self.rnn(rnn_input,None)
out = self.linear(out[:, -1, :])
if len(x_shape) == 3:
logits = out.reshape(x_shape[0],x_shape[1],-1)
else:
logits = out.reshape(x_shape[0],-1)
return self.dist_wrapper(logits, adapt_std, **kwargs)
class RecurrentRESPolicy(torch.nn.Module):
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list,
backbone_type : Union[str, np.str_] ='gru',
rnn_hidden_features : int = 64,
window_size : int = 0,
**kwargs):
super().__init__()
self.kwargs = kwargs
RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM
self.in_feature_embed = nn.Sequential(
nn.Linear(in_features, hidden_features),
ACTIVATION_CREATORS['relu'](hidden_features),
nn.Dropout(0.1),
nn.LayerNorm(hidden_features)
)
self.side_net = nn.Sequential(
ResBlock(hidden_features, hidden_features),
nn.Dropout(0.1),
nn.LayerNorm(hidden_features),
ResBlock(hidden_features, hidden_features),
nn.Dropout(0.1),
nn.LayerNorm(hidden_features)
)
self.rnn = RNN(hidden_features, rnn_hidden_features, hidden_layers)
self.backbone = MLP(hidden_features + rnn_hidden_features, out_features, hidden_features, 1)
self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
self.hidden_layers = hidden_layers
self.hidden_features = hidden_features
self.rnn_hidden_features = rnn_hidden_features
self.window_size = window_size
self.hidden_que = deque(maxlen=self.window_size) # [(h_0, c_0), (h_1, c_1), ...]
self.h = None
def reset(self):
self.hidden_que = deque(maxlen=self.window_size)
self.h = None
def preprocess_window(self, a_embed):
shape = a_embed.shape
if len(self.hidden_que) < self.window_size: # do Not pop
if self.hidden_que:
h = self.hidden_que[0]
else:
h = (torch.zeros((self.hidden_layers, shape[1], self.rnn_hidden_features)).to(a_embed), torch.zeros((self.hidden_layers, shape[1], self.rnn_hidden_features)).to(a_embed))
else:
h = self.hidden_que.popleft()
hidden_cat = torch.concat([h[0]] + [hidden[0] for hidden in self.hidden_que] + [torch.zeros_like(h[0]).to(h[0])], dim=1) # (layers, bs * 4, dim)
cell_cat = torch.concat([h[1]] + [hidden[1] for hidden in self.hidden_que] + [torch.zeros_like(h[1]).to(h[1])], dim=1) # (layers, bs * 4, dim)
a_embed = torch.repeat_interleave(a_embed, repeats=len(self.hidden_que)+2, dim=0).view(1, (len(self.hidden_que)+2)*shape[1], -1) # (1, bs * 4, dim)
return a_embed, hidden_cat, cell_cat
def postprocess_window(self, rnn_output, hidden_cat, cell_cat):
hidden_cat = torch.chunk(hidden_cat, chunks=len(self.hidden_que)+2, dim=1) # tuples of (layers, bs, dim)
cell_cat = torch.chunk(cell_cat, chunks=len(self.hidden_que)+2, dim=1) # tuples of (layers, bs, dim)
rnn_output = torch.chunk(rnn_output, chunks=len(self.hidden_que)+2, dim=1) # tuples of (1, bs, dim)
self.hidden_que = deque(zip(hidden_cat[1:], cell_cat[1:]), maxlen=self.window_size) # important to discard the first element !!!
rnn_output = rnn_output[0] # (1, bs, dim)
return rnn_output
def rnn_forward(self, a_embed, joint_train : bool):
if self.window_size > 0:
a_embed, hidden_cat, cell_cat = self.preprocess_window(a_embed)
if joint_train:
rnn_output, (hidden_cat, cell_cat) = self.rnn(a_embed, (hidden_cat, cell_cat))
else:
with torch.no_grad():
rnn_output, (hidden_cat, cell_cat) = self.rnn(a_embed, (hidden_cat, cell_cat))
rnn_output = self.postprocess_window(rnn_output, hidden_cat, cell_cat) #(1, bs, dim)
else:
if joint_train:
rnn_output, self.h = self.rnn(a_embed, self.h) #(1, bs, dim)
else:
with torch.no_grad():
rnn_output, self.h = self.rnn(a_embed, self.h)
return rnn_output
def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, field: str = 'mail', **kwargs) -> ReviveDistribution:
if field == 'bc':
shape = x.shape
joint_train = True
if len(shape) == 2: # (bs, dim)
x_embed = self.in_feature_embed(x)
side_output = self.side_net(x_embed)
a_embed = x_embed.unsqueeze(0) # [1, bs, dim]
rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim]
rnn_output = rnn_output.squeeze(0) # (bs, dim)
logits = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (bs, dim)
else:
assert len(shape) == 3, f"expect len(x.shape) == 3. However got x.shape {shape}"
self.reset()
output = []
for i in range(shape[0]):
a = x[i]
a = a.unsqueeze(0) #(1, bs, dim)
a_embed = self.in_feature_embed(a) # (1, bs, dim)
side_output = self.side_net(a_embed)
rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim]
backbone_output = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (1, bs, dim)
output.append(backbone_output)
logits = torch.concat(output, dim=0)
elif field == 'mail':
shape = x.shape
joint_train = False
if len(shape) == 2:
x_embed = self.in_feature_embed(x)
side_output = self.side_net(x_embed)
a_embed = x_embed.unsqueeze(0)
rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim]
rnn_output = rnn_output.squeeze(0) # (bs, dim)
logits = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (bs, dim)
else:
assert len(shape) == 3, f"expect len(x.shape) == 3. However got x.shape {shape}"
self.reset()
output = []
for i in range(shape[0]):
a = x[i]
a = a.unsqueeze(0) #(1, bs, dim)
a_embed = self.in_feature_embed(a) # (1, bs, dim)
side_output = self.side_net(a_embed)
rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim]
backbone_output = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (1, bs, dim)
output.append(backbone_output)
logits = torch.concat(output, dim=0)
else:
raise NotImplementedError(f"unknow field: {field} in RNN training !")
return self.dist_wrapper(logits, adapt_std, **kwargs)
class ContextualPolicy(torch.nn.Module):
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list,
backbone_type : Union[str, np.str_] ='contextual_gru',
**kwargs):
super().__init__()
self.kwargs = kwargs
RNN = torch.nn.GRU if backbone_type == 'contextual_gru' else torch.nn.LSTM
self.preprocess_mlp = MLP(in_features, hidden_features, 0, 0, output_activation='leakyrelu')
self.rnn = RNN(hidden_features, hidden_features, 1)
self.mlp = MLP(hidden_features + in_features, out_features, hidden_features, hidden_layers)
self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
def reset(self):
self.h = None
def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
in_feature = x
if len(x.shape) == 2:
x = x.unsqueeze(0)
x = self.preprocess_mlp(x)
rnn_output, self.h = self.rnn(x, self.h)
rnn_output = rnn_output.squeeze(0)
else:
x = self.preprocess_mlp(x)
rnn_output, self.h = self.rnn(x)
logits = self.mlp(torch.cat((in_feature, rnn_output), dim=-1))
return self.dist_wrapper(logits, adapt_std, **kwargs)
class EnsembleFeedForwardPolicy(torch.nn.Module):
"""Policy for using mlp, resnet and transformer backbone.
Args:
in_features : The number of input features, or a dictionary describing the input distribution
out_features : The number of output features
hidden_features : The number of hidden features in each layer
hidden_layers : The number of hidden layers
dist_config : A list of configurations for the distributions to use in the model
norm : The type of normalization to apply to the input features
hidden_activation : The activation function to use in the hidden layers
backbone_type : The type of backbone to use in the model
use_multihead : Whether to use a multihead model
use_feature_embed : Whether to use feature embedding
"""
def __init__(self,
in_features : Union[int, dict],
out_features : int,
hidden_features : int,
hidden_layers : int,
ensemble_size : int,
dist_config : list,
norm : str = None,
hidden_activation : str = 'leakyrelu',
backbone_type : Union[str, np.str_] = 'EnsembleRES',
use_multihead : bool = False,
**kwargs):
super().__init__()
self.multihead = []
self.ensemble_size = ensemble_size
self.dist_config = dist_config
self.kwargs = kwargs
if isinstance(in_features, dict):
in_features = sum(in_features.values())
if backbone_type == 'EnsembleMLP':
self.backbone = VectorizedMLP(in_features, out_features, hidden_features, hidden_layers, ensemble_size, norm, hidden_activation)
elif backbone_type == 'EnsembleRES':
self.backbone = VectorizedResNet(in_features, out_features, hidden_features, hidden_layers, ensemble_size, norm)
else:
raise NotImplementedError(f'backbone type {backbone_type} is not supported')
self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
if isinstance(state, dict):
assert 'input_names' in kwargs
state = torch.cat([state[key] for key in kwargs['input_names']], dim=-1)
output = self.backbone(state)
dist = self.dist_wrapper(output, adapt_std, **kwargs)
if hasattr(self, "dist_mu_shift"):
dist = dist.shift(self.dist_mu_shift)
return dist
@torch.no_grad()
def get_action(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], deterministic : bool = True):
dist = self(state)
return dist.mode if deterministic else dist.sample()
# ------------------------------- Transitions ------------------------------- #
class FeedForwardTransition(FeedForwardPolicy):
r"""Initializes a feedforward transition instance.
Args:
in_features (Union[int, dict]): The number of input features or a dictionary of input feature sizes.
out_features (int): The number of output features.
hidden_features (int): The number of hidden features.
hidden_layers (int): The number of hidden layers.
dist_config (list): The configuration of the distribution to use.
norm (Optional[str]): The normalization method to use. None if no normalization is used.
hidden_activation (str): The activation function to use for the hidden layers.
backbone_type (Union[str, np.str_]): The type of backbone to use.
mode (str): The mode to use for the transition. Either 'global' or 'local'.
obs_dim (Optional[int]): The dimension of the observation for 'local' mode.
use_feature_embed (bool): Whether to use feature embedding.
"""
def __init__(self,
in_features : Union[int, dict],
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list,
norm : Optional[str] = None,
hidden_activation : str = 'leakyrelu',
backbone_type : Union[str, np.str_] = 'mlp',
mode : str = 'global',
obs_dim : Optional[int] = None,
**kwargs):
self.mode = mode
self.obs_dim = obs_dim
if self.mode == 'local':
dist_types = [config['type'] for config in dist_config]
if 'onehot' in dist_types or 'discrete_logistic' in dist_types:
warnings.warn('Detected a distribution type that is not compatible with local mode; falling back to global mode!')
self.mode = 'global'
if self.mode == 'local': assert self.obs_dim is not None, \
"For local mode, the dim of observation should be given!"
super(FeedForwardTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers,
dist_config, norm, hidden_activation, backbone_type, **kwargs)
def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
dist = super(FeedForwardTransition, self).forward(state, adapt_std, **kwargs)
if self.mode == 'local' and self.obs_dim is not None:
dist = dist.shift(state[..., :self.obs_dim])
return dist
class RecurrentTransition(RecurrentPolicy):
r"""
Initializes a recurrent transition instance.
Args:
in_features (int): The number of input features.
out_features (int): The number of output features.
hidden_features (int): The number of hidden features.
hidden_layers (int): The number of hidden layers.
dist_config (list): The distribution configuration.
backbone_type (Union[str, np.str_]): The type of backbone to use.
mode (str): The mode of the transition. Either 'global' or 'local'.
obs_dim (Optional[int]): The dimension of the observation.
"""
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
norm : str,
dist_config : list,
backbone_type : Union[str, np.str_] ='gru',
mode : str = 'global',
obs_dim : Optional[int] = None,
**kwargs):
self.mode = mode
self.obs_dim = obs_dim
if self.mode == 'local': assert 'onehot' not in [config['type'] for config in dist_config], \
"The local mode of transition is not compatible with onehot data! Please fall back to global mode!"
if self.mode == 'local': assert self.obs_dim is not None, \
"For local mode, the dim of observation should be given!"
super(RecurrentTransition, self).__init__(in_features,
out_features,
hidden_features,
hidden_layers,
norm,
dist_config,
backbone_type,
**kwargs)
def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
dist = super(RecurrentTransition, self).forward(state, adapt_std, **kwargs)
if self.mode == 'local' and self.obs_dim is not None:
dist = dist.shift(state[..., :self.obs_dim])
return dist
class TsRecurrentTransition(TsRecurrentPolicy):
r"""
Initializes a recurrent transition instance for ts_node.
Args:
in_features (int): The number of input features.
out_features (int): The number of output features.
hidden_features (int): The number of hidden features.
hidden_layers (int): The number of hidden layers.
dist_config (list): The distribution configuration.
backbone_type (Union[str, np.str_]): The type of backbone to use.
mode (str): The mode of the transition. Either 'global' or 'local'.
obs_dim (Optional[int]): The dimension of the observation.
"""
input_is_dict = True
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
norm : str,
dist_config : list,
backbone_type : Union[str, np.str_] ='gru',
mode : str = 'global',
obs_dim : Optional[int] = None,
**kwargs):
self.mode = mode
self.obs_dim = obs_dim
if self.mode == 'local': assert 'onehot' not in [config['type'] for config in dist_config], \
"The local mode of transition is not compatible with onehot data! Please fall back to global mode!"
if self.mode == 'local': assert self.obs_dim is not None, \
"For local mode, the dim of observation should be given!"
super(TsRecurrentTransition, self).__init__(in_features,
out_features,
hidden_features,
hidden_layers,
norm,
dist_config,
backbone_type,
**kwargs)
[docs]
def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
dist = super(TsRecurrentTransition, self).forward(state, adapt_std, **kwargs)
if self.mode == 'local' and self.obs_dim is not None:
dist = dist.shift(state[..., :self.obs_dim])
return dist
[docs]
class RecurrentRESTransition(RecurrentRESPolicy):
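r"""
Initializes a recurrent RES transition instance.
Args:
in_features (int): The number of input features.
out_features (int): The number of output features.
hidden_features (int): The number of hidden features.
hidden_layers (int): The number of hidden layers.
dist_config (list): The distribution configuration.
backbone_type (Union[str, np.str_]): The type of backbone to use.
mode (str): The mode of the transition. Either 'global' or 'local'.
obs_dim (Optional[int]): The dimension of the observation. Required in 'local' mode.
rnn_hidden_features (int): The number of hidden features of the recurrent part.
window_size (int): The window size used by the recurrent part.
"""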
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list,
backbone_type : Union[str, np.str_] = 'mlp',
mode : str = 'global',
obs_dim : Optional[int] = None,
rnn_hidden_features : int = 64,
window_size : int = 0,
**kwargs):
self.mode = mode
self.obs_dim = obs_dim
if self.mode == 'local': assert 'onehot' not in [config['type'] for config in dist_config], \
"The local mode of transition is not compatible with onehot data! Please fall back to global mode!"
if self.mode == 'local': assert self.obs_dim is not None, \
"For local mode, the dim of observation should be given!"
super(RecurrentRESTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers,
dist_config, backbone_type,
rnn_hidden_features=rnn_hidden_features, window_size=window_size,
**kwargs)
[docs]
def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, field: str = 'mail', **kwargs) -> ReviveDistribution:
dist = super(RecurrentRESTransition, self).forward(state, adapt_std, field, **kwargs)
if self.mode == 'local' and self.obs_dim is not None:
dist = dist.shift(state[..., :self.obs_dim])
return dist
[docs]
class EnsembleFeedForwardTransition(EnsembleFeedForwardPolicy):
r"""Initializes a feedforward transition instance.
Args:
in_features (Union[int, dict]): The number of input features or a dictionary of input feature sizes.
out_features (int): The number of output features.
hidden_features (int): The number of hidden features.
hidden_layers (int): The number of hidden layers.
dist_config (list): The configuration of the distribution to use.
norm (Optional[str]): The normalization method to use. None if no normalization is used.
hidden_activation (str): The activation function to use for the hidden layers.
backbone_type (Union[str, np.str_]): The type of backbone to use.
mode (str): The mode to use for the transition. Either 'global' or 'local'.
obs_dim (Optional[int]): The dimension of the observation for 'local' mode.
use_feature_embed (bool): Whether to use feature embedding.
"""
def __init__(self,
in_features : Union[int, dict],
out_features : int,
hidden_features : int,
hidden_layers : int,
ensemble_size : int,
dist_config : list,
norm : Optional[str] = None,
hidden_activation : str = 'leakyrelu',
backbone_type : Union[str, np.str_] = 'EnsembleRES',
mode : str = 'global',
obs_dim : Optional[int] = None,
**kwargs):
self.mode = mode
self.obs_dim = obs_dim
if self.mode == 'local':
dist_types = [config['type'] for config in dist_config]
if 'onehot' in dist_types or 'discrete_logistic' in dist_types:
warnings.warn('Detected a distribution type that is not compatible with the local mode; falling back to global mode!')
self.mode = 'global'
if self.mode == 'local': assert self.obs_dim is not None, \
"For local mode, the dim of observation should be given!"
super(EnsembleFeedForwardTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers, ensemble_size,
dist_config, norm, hidden_activation, backbone_type, **kwargs)
[docs]
def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
dist = super(EnsembleFeedForwardTransition, self).forward(state, adapt_std, **kwargs)
if self.mode == 'local' and self.obs_dim is not None:
dist = dist.shift(state[..., :self.obs_dim])
return dist
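# Note (sketch, not part of the original docs): the ensemble variant evaluates
# ensemble_size transition networks in parallel, so sampled outputs carry a leading
# ensemble dimension, e.g. [ensemble_size, batch_size, out_features].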
# ------------------------------- Matchers ------------------------------- #
[docs]
class FeedForwardMatcher(torch.nn.Module):
r"""
Initializes a feedforward matcher instance.
Args:
in_features (int): The number of input features.
hidden_features (int): The number of hidden features.
hidden_layers (int): The number of hidden layers.
hidden_activation (str): The activation function to use for the hidden layers.
norm (str): The normalization method to use. None if no normalization is used.
backbone_type (Union[str, np.str_]): The type of backbone to use.
"""
def __init__(self,
in_features : int,
hidden_features : int,
hidden_layers : int,
hidden_activation : str = 'leakyrelu',
norm : Optional[str] = None,
backbone_type : Union[str, np.str_] = 'mlp',
**kwargs):
super().__init__()
if backbone_type == 'mlp':
self.backbone = MLP(in_features, 1, hidden_features, hidden_layers, norm, hidden_activation, output_activation='sigmoid')
elif backbone_type == 'res':
self.backbone = ResNet(in_features, 1, hidden_features, hidden_layers, norm, output_activation='sigmoid')
elif backbone_type == 'ft_transformer':
self.backbone = FT_Transformer(in_features, 1, hidden_features, hidden_layers=hidden_layers)
else:
raise NotImplementedError(f'backbone type {backbone_type} is not supported')
[docs]
def forward(self, x : torch.Tensor) -> torch.Tensor:
# x.requires_grad = True
shape = x.shape
if len(shape) == 3: # [horizon, batch_size, dim]
x = x.view(-1, shape[-1]) # [horizon * batch_size, dim]
assert len(x.shape) == 2, f"len(x.shape) == {len(x.shape)}, expect 2"
out = self.backbone(x) # [batch_size, dim]
if len(shape) == 3:
out = out.view(*shape[:-1], -1) # [horizon, batch_size, dim]
return out
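# Usage sketch (illustrative only; sizes are arbitrary). The matcher scores
# concatenated node features and outputs values in (0, 1); 3D inputs of shape
# [horizon, batch_size, dim] are flattened, scored, then reshaped back:
#
#     matcher = FeedForwardMatcher(in_features=12, hidden_features=64, hidden_layers=2)
#     x = torch.rand(10, 32, 12)    # [horizon, batch_size, in_features]
#     scores = matcher(x)           # [horizon, batch_size, 1]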
[docs]
class RecurrentMatcher(torch.nn.Module):
def __init__(self,
in_features : int,
hidden_features : int,
hidden_layers : int,
backbone_type : Union[str, np.str_] = 'gru',
bidirect : bool = False):
super().__init__()
RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM
self.rnn = RNN(in_features, hidden_features, hidden_layers, bidirectional=bidirect)
self.output_layer = MLP(hidden_features * (2 if bidirect else 1), 1, 0, 0, output_activation='sigmoid')
[docs]
def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor:
x = torch.cat(inputs, dim=-1)
rnn_output = self.rnn(x)[0]
return self.output_layer(rnn_output)
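# Usage sketch (illustrative only; sizes are arbitrary). Inputs are per-node feature
# sequences with time first; they are concatenated on the last dimension and scored
# step by step by the RNN:
#
#     matcher = RecurrentMatcher(in_features=12, hidden_features=64, hidden_layers=1)
#     obs, act = torch.rand(10, 32, 8), torch.rand(10, 32, 4)   # [horizon, batch, dim]
#     scores = matcher(obs, act)    # [horizon, batch, 1], values in (0, 1)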
[docs]
class HierarchicalMatcher(torch.nn.Module):
def __init__(self,
in_features : list,
hidden_features : int,
hidden_layers : int,
hidden_activation : str,
norm : Optional[str] = None):
super().__init__()
self.in_features = in_features
process_layers = []
output_layers = []
feature = self.in_features[0]
for i in range(1, len(self.in_features)):
feature += self.in_features[i]
process_layers.append(MLP(feature, hidden_features, hidden_features, hidden_layers, norm, hidden_activation, hidden_activation))
output_layers.append(torch.nn.Linear(hidden_features, 1))
feature = hidden_features
self.process_layers = torch.nn.ModuleList(process_layers)
self.output_layers = torch.nn.ModuleList(output_layers)
[docs]
def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor:
assert len(inputs) == len(self.in_features)
last_feature = inputs[0]
result = 0
for i, process_layer, output_layer in zip(range(1, len(self.in_features)), self.process_layers, self.output_layers):
last_feature = torch.cat([last_feature, inputs[i]], dim=-1)
last_feature = process_layer(last_feature)
result += output_layer(last_feature)
return torch.sigmoid(result)
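# Usage sketch (illustrative only; sizes are arbitrary). Each input group is folded in
# one at a time; every stage emits a score and the per-stage scores are summed before
# the final sigmoid:
#
#     matcher = HierarchicalMatcher(in_features=[8, 4, 2], hidden_features=64,
#                                   hidden_layers=1, hidden_activation='leakyrelu')
#     a, b, c = torch.rand(32, 8), torch.rand(32, 4), torch.rand(32, 2)
#     scores = matcher(a, b, c)     # [32, 1], values in (0, 1)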
[docs]
class VectorizedMatcher(torch.nn.Module):
r"""
Initializes a vectorized (ensemble) matcher instance.
Args:
in_features (int): The number of input features.
hidden_features (int): The number of hidden features.
hidden_layers (int): The number of hidden layers.
ensemble_size (int): The number of matchers evaluated in parallel.
hidden_activation (str): The activation function to use for the hidden layers.
norm (Optional[str]): The normalization method to use. None if no normalization is used.
backbone_type (Union[str, np.str_]): The type of backbone to use.
"""
def __init__(self,
in_features : int,
hidden_features : int,
hidden_layers : int,
ensemble_size : int,
hidden_activation : str = 'leakyrelu',
norm : Optional[str] = None,
backbone_type : Union[str, np.str_] = 'mlp',
**kwargs):
super().__init__()
self.backbone_type = backbone_type
self.ensemble_size = ensemble_size
if backbone_type == 'mlp':
self.backbone = VectorizedMLP(in_features, 1, hidden_features, hidden_layers, ensemble_size, norm, hidden_activation, output_activation='sigmoid')
elif backbone_type == 'res':
self.backbone = VectorizedResNet(in_features, 1, hidden_features, hidden_layers, ensemble_size, norm, output_activation='sigmoid')
else:
raise NotImplementedError(f'backbone type {backbone_type} is not supported')
[docs]
def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor:
x = torch.cat(inputs, dim=-1).detach()
# x.requires_grad = True
shape = x.shape
if len(shape) == 2: # [batch_size, dim]
x = x.unsqueeze(0).repeat_interleave(self.ensemble_size, dim=0) # [batch_size, dim] --> [ensemble_size, batch_size, dim]
assert x.shape[0] == self.ensemble_size
out = self.backbone(x) # [ensemble_size, batch_size, dim]
return out
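# Usage sketch (illustrative only; sizes are arbitrary). A 2D input is broadcast to
# every ensemble member; the output keeps the ensemble as the leading dimension:
#
#     matcher = VectorizedMatcher(in_features=12, hidden_features=64,
#                                 hidden_layers=2, ensemble_size=5)
#     x = torch.rand(32, 12)        # [batch_size, in_features]
#     scores = matcher(x)           # [ensemble_size, batch_size, 1]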
[docs]
class EnsembleMatcher:
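r"""
Builds an ensemble of matchers from recorded matcher snapshots.
Args:
matcher_record (Dict[str, deque]): Recorded snapshots; expected to contain 'state_dicts', 'expert_scores' and 'generated_scores'.
ensemble_size (int): The maximum number of snapshots kept in the ensemble.
ensemble_choosing_interval (int): Keep every k-th snapshot, counting back from the latest.
config (dict): The construction config of a single matcher, forwarded to VectorizedMatcher.
"""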
def __init__(self,
matcher_record: Dict[str, deque],
ensemble_size : int = 50,
ensemble_choosing_interval: int = 1,
config : dict = None) -> None:
self.matcher_weights_list = []
self.single_structure_dict = config
self.matching_nodes = config['matching_nodes']
self.matching_fit_nodes = config['matching_fit_nodes']
# self.matching_nodes_fit_dict = config['matching_nodes_fit_index']
self.device = torch.device("cpu")
self.min_generate_score = 0.0
self.max_expert_score = 1.0
self.selected_ind = []
counter = 0
for i in reversed(range(len(matcher_record['state_dicts']))):
if counter % ensemble_choosing_interval == 0:
matcher_weight = matcher_record['state_dicts'][i]
self.selected_ind.append(i)
self.matcher_weights_list.append(matcher_weight)
if len(self.matcher_weights_list) >= ensemble_size:
break
counter += 1
self.ensemble_size = len(self.matcher_weights_list)
self.matcher_ensemble = VectorizedMatcher(ensemble_size=self.ensemble_size, **config)
self.init_ensemble_network()
self.load_min_max(matcher_record['expert_scores'], matcher_record['generated_scores'])
[docs]
def load_min_max(self, expert_scores: deque, generate_scores: deque):
self.expert_scores = torch.tensor(expert_scores, dtype=torch.float32, device=self.device)[self.selected_ind] # [ensemble_size, ]
self.generate_scores = torch.tensor(generate_scores, dtype=torch.float32, device=self.device)[self.selected_ind] # [ensemble_size, ]
self.max_expert_score = max(self.expert_scores)
self.min_generate_score = min(self.generate_scores)
self.range = max(abs(self.max_expert_score), abs(self.min_generate_score))
[docs]
def init_ensemble_network(self):
if self.matcher_ensemble.backbone_type == 'mlp':
with torch.no_grad():
for ind, state_dict in enumerate(self.matcher_weights_list):
for key, val in state_dict.items():
layer_ind = int(key.split('.')[2])
if key.endswith("weight"):
self.matcher_ensemble.backbone.net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
elif key.endswith("bias"):
self.matcher_ensemble.backbone.net[layer_ind].bias.data[ind, 0, ...].copy_(val)
else:
raise NotImplementedError("Only except network params with weight and bias")
elif self.matcher_ensemble.backbone_type == 'res':
with torch.no_grad():
for ind, state_dict in enumerate(self.matcher_weights_list):
for key, val in state_dict.items():
key_ls = key.split('.')
if "skip_net" in key_ls:
block_ind = int(key_ls[2])
if key.endswith("weight"):
self.matcher_ensemble.backbone.resnet[block_ind].skip_net.weight.data[ind, ...].copy_(val.transpose(1, 0))
elif key.endswith("bias"):
self.matcher_ensemble.backbone.resnet[block_ind].skip_net.bias.data[ind, 0, ...].copy_(val)
else:
raise NotImplementedError("Only except network params with weight and bias")
elif "process_net" in key_ls:
block_ind, layer_ind = int(key_ls[2]), int(key_ls[4])
if key.endswith("weight"):
self.matcher_ensemble.backbone.resnet[block_ind].process_net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
elif key.endswith("bias"):
self.matcher_ensemble.backbone.resnet[block_ind].process_net[layer_ind].bias.data[ind, 0, ...].copy_(val)
else:
raise NotImplementedError("Only except network params with weight and bias")
else:
# linear layer at last
block_ind = int(key_ls[2])
if key.endswith("weight"):
self.matcher_ensemble.backbone.resnet[block_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
elif key.endswith("bias"):
self.matcher_ensemble.backbone.resnet[block_ind].bias.data[ind, 0, ...].copy_(val)
else:
raise NotImplementedError("Only except network params with weight and bias")
else:
raise NotImplementedError
[docs]
def run_raw_scores(self, inputs: torch.Tensor, aggregation=None) -> torch.Tensor:
with torch.no_grad():
scores = self.matcher_ensemble(inputs) # [ensemble_size, batch_size, 1], [ensemble_size, horizon, batch_size, 1]
scores = scores.squeeze(-1).transpose(1, 0).detach() # [batch_size, ensemble_size]
if aggregation == "mean":
return scores.mean(-1)
return scores
[docs]
def run_scores(self, inputs: torch.Tensor, aggregation: str=None, clip: bool=True) -> torch.Tensor:
with torch.no_grad():
scores = self.matcher_ensemble(inputs) # [ensemble_size, batch_size, 1]
scores = scores.squeeze(-1).transpose(1, 0).detach() # [batch_size, ensemble_size]
if clip:
assert self.expert_scores is not None and self.generate_scores is not None, "if clip==True, you should call load_min_max() beforehand"
scores = torch.clamp(scores, min=self.generate_scores, max=self.expert_scores)
scores = (scores - self.generate_scores) / (self.expert_scores - self.generate_scores + 1e-7) * self.range
if aggregation == "mean":
return scores.mean(-1)
return scores
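# Usage sketch (illustrative only): once constructed, the ensemble can score generated
# transitions, clipping and rescaling with the recorded expert / generated score ranges:
#
#     scores = ensemble_matcher.run_scores(x, aggregation='mean', clip=True)  # [batch_size]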
[docs]
def to(self, device):
self.device = device
self.matcher_ensemble = self.matcher_ensemble.to(device)
self.expert_scores = self.expert_scores.to(device)
self.generate_scores = self.generate_scores.to(device)
# class SequentialMatcher:
# def __init__(self,
# matcher_record: Dict[str, deque],
# ensemble_size : int = 50,
# ensemble_choosing_interval: int = 1,
# config : dict = None) -> None:
# self.matcher_list = []
# self.single_structure_dict = config
# self.matching_nodes = config['matching_nodes']
# self.matching_fit_nodes = config['matching_fit_nodes']
# # self.matching_nodes_fit_dict = config['matching_nodes_fit_index']
# self.device = torch.device("cpu")
# self.min_generate_score = 0.0
# self.max_expert_score = 1.0
# self.selected_ind = []
# counter = 0
# for i in reversed(range(len(matcher_record['state_dicts']))):
# if counter % ensemble_choosing_interval == 0:
# matcher_weight = matcher_record['state_dicts'][i]
# self.selected_ind.append(i)
# _matcher = FeedForwardMatcher(**config)
# _matcher.load_state_dict(matcher_weight)
# self.matcher_list.append(_matcher)
# if len(self.matcher_list) >= ensemble_size:
# break
# counter += 1
# self.ensemble_size = len(self.matcher_list)
# self.load_min_max(matcher_record['expert_scores'], matcher_record['generated_scores'])
# # prepare for vmap func
# self.params, self.buffers = stack_module_state(self.matcher_list)
# """
# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
# """
# self.base_model = deepcopy(self.matcher_list[0])
# self.base_model = self.base_model.to('meta')
# def vmap_forward(self, x):
# def vmap_model(params, buffers, x):
# return functional_call(self.base_model, (params, buffers), (x,))
# return vmap(vmap_model, in_dims=(0, 0, None))(self.params, self.buffers, x)
# def load_min_max(self, expert_scores: deque, generate_scores: deque):
# self.expert_scores = torch.tensor(expert_scores, dtype=torch.float32, device=self.device)[self.selected_ind] # [ensemble_size, ]
# self.generate_scores = torch.tensor(generate_scores, dtype=torch.float32, device=self.device)[self.selected_ind] # [ensemble_size, ]
# self.max_expert_score = max(self.expert_scores)
# self.min_generate_score = min(self.generate_scores)
# self.range = max(abs(self.max_expert_score), abs(self.min_generate_score))
# def run_raw_scores(self, inputs: torch.Tensor, aggregation=None) -> torch.Tensor:
# with torch.no_grad():
# scores = self.vmap_forward(inputs) # [ensemble_size, batch_size, 1], [ensemble_size, horizon, batch_size, 1]
# scores = scores.squeeze(-1).transpose(1, 0).detach() # [batch_size, ensemble_size]
# if aggregation == "mean":
# return scores.mean(-1)
# return scores
# def run_scores(self, inputs: torch.Tensor, aggregation: str=None, clip: bool=True) -> torch.Tensor:
# with torch.no_grad():
# scores = self.vmap_forward(inputs) # [ensemble_size, batch_size, 1]
# scores = scores.squeeze(-1).transpose(1, 0).detach() # [batch_size, ensemble_size]
# if clip:
# assert self.expert_scores is not None and self.generate_scores is not None, "if clip==True, you should call load_min_max() beforehand"
# scores = torch.clamp(scores, min=self.generate_scores, max=self.expert_scores)
# scores = (scores - self.generate_scores) / (self.expert_scores - self.generate_scores + 1e-7) * self.range
# if aggregation == "mean":
# return scores.mean(-1)
# return scores
# def to(self, device):
# self.device = device
# self.matcher_list = [_matcher.to(device) for _matcher in self.matcher_list]
# self.expert_scores = self.expert_scores.to(device)
# self.generate_scores = self.generate_scores.to(device)
# # prepare for vmap func
# self.params, self.buffers = stack_module_state(self.matcher_list)
# ------------------------------- Others ------------------------------- #
[docs]
class VectorizedCritic(VectorizedMLP):
[docs]
def forward(self, x : torch.Tensor) -> torch.Tensor:
x = super(VectorizedCritic, self).forward(x)
return x.squeeze(-1)
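# Usage sketch (illustrative only; sizes are arbitrary). VectorizedCritic evaluates
# several Q networks in one pass and drops the trailing singleton dimension:
#
#     critic = VectorizedCritic(14, 1, hidden_features=64, hidden_layers=2, ensemble_size=4)
#     q = critic(torch.rand(32, 14))    # [ensemble_size, batch_size] = [4, 32]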
# ------------------------------- Test ------------------------------- #
if __name__ == "__main__":
def test_ensemble_mlp():
rnd = torch.rand((6, 5))
ensemble_size = 3
ensemble_linear = VectorizedMLP(5, 4, hidden_features=16, hidden_layers=2, ensemble_size=ensemble_size)
linear1 = MLP(5, 4, hidden_features=16, hidden_layers=2)
linear2 = MLP(5, 4, hidden_features=16, hidden_layers=2)
linear3 = MLP(5, 4, hidden_features=16, hidden_layers=2)
state_dicts = deque([linear1.state_dict(), linear2.state_dict(), linear3.state_dict()])
for ind, state_dict in enumerate(state_dicts):
for key, val in state_dict.items():
layer_ind = int(key.split('.')[1])
if key.endswith("weight"):
ensemble_linear.net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
elif key.endswith("bias"):
ensemble_linear.net[layer_ind].bias.data[ind, 0, ...].copy_(val)
else:
raise NotImplementedError("Only except network params with weight and bias")
ensemble_out = ensemble_linear(rnd)
linear1_out = linear1(rnd)
linear2_out = linear2(rnd)
linear3_out = linear3(rnd)
# breakpoint()
print((ensemble_out[0, ...] - linear1_out).abs().max())
print((ensemble_out[1, ...] - linear2_out).abs().max())
print((ensemble_out[2, ...] - linear3_out).abs().max())
def test_ensemble_res_distinct():
in_fea = 256
out_fea = 15
hidden_features=256
hidden_layers=4
batch_size = 512
ensemble_size = 3
rnd1 = torch.rand((batch_size, in_fea))
rnd2 = torch.rand((batch_size, in_fea))
rnd3 = torch.rand((batch_size, in_fea))
rnd = torch.stack([rnd1, rnd2, rnd3])
ensemble_linear = VectorizedResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, ensemble_size=ensemble_size, norm=None)
linear1 = ResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, norm=None)
linear2 = ResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, norm=None)
linear3 = ResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, norm=None)
state_dicts = deque([linear1.state_dict(), linear2.state_dict(), linear3.state_dict()])
for ind, state_dict in enumerate(state_dicts):
for key, val in state_dict.items():
key_ls = key.split('.')
if "skip_net" in key_ls:
block_ind = int(key_ls[1])
if key.endswith("weight"):
ensemble_linear.resnet[block_ind].skip_net.weight.data[ind, ...].copy_(val.transpose(1, 0))
elif key.endswith("bias"):
ensemble_linear.resnet[block_ind].skip_net.bias.data[ind, 0, ...].copy_(val)
else:
raise NotImplementedError("Only except network params with weight and bias")
elif "process_net" in key_ls:
block_ind, layer_ind = int(key_ls[1]), int(key_ls[3])
if key.endswith("weight"):
ensemble_linear.resnet[block_ind].process_net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
elif key.endswith("bias"):
ensemble_linear.resnet[block_ind].process_net[layer_ind].bias.data[ind, 0, ...].copy_(val)
else:
raise NotImplementedError("Only except network params with weight and bias")
else:
# linear layer at last
block_ind = int(key_ls[1])
if key.endswith("weight"):
ensemble_linear.resnet[block_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
elif key.endswith("bias"):
ensemble_linear.resnet[block_ind].bias.data[ind, 0, ...].copy_(val)
else:
raise NotImplementedError("Only except network params with weight and bias")
ensemble_out = ensemble_linear(rnd)
linear1_out = linear1(rnd1)
linear2_out = linear2(rnd2)
linear3_out = linear3(rnd3)
# breakpoint()
print((ensemble_out[0, ...] - linear1_out).abs().max())
print((ensemble_out[1, ...] - linear2_out).abs().max())
print((ensemble_out[2, ...] - linear3_out).abs().max())
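# Run the quick consistency checks when this module is executed directly (assumed
# entry point; both functions only print maximum absolute deviations, which should
# be close to zero when the vectorized copies match the individual networks).
test_ensemble_mlp()
test_ensemble_res_distinct()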
[docs]
class Value_Net_VectorizedCritic(nn.Module):
r"""
Initializes a vectorized linear layer instance.
Args:
in_features (int): The number of input features.
out_features (int): The number of output features.
ensemble_size (int): The number of ensembles to use.
"""
def __init__(self,
input_dim_dict,
q_hidden_features,
q_hidden_layers,
num_q_net,
*args, **kwargs) -> None:
super().__init__()
self.kwargs = kwargs
self.input_dim_dict = input_dim_dict
if isinstance(input_dim_dict, dict):
input_dim = sum(input_dim_dict.values())
else:
input_dim = input_dim_dict
self.value = VectorizedMLP(input_dim,
1,
q_hidden_features,
q_hidden_layers,
num_q_net)
if 'ts_conv_config' in self.kwargs.keys() and self.kwargs['ts_conv_config'] is not None:
other_state_layer_output = q_hidden_features//2
ts_state_conv_layer_output = q_hidden_features//2 + q_hidden_features%2
self.other_state_layer = MLP(self.kwargs['ts_conv_net_config']['other_net_input'], other_state_layer_output, 0, 0)
all_node_ts = None
self.conv_ts_node = []
self.conv_other_node = deepcopy(self.kwargs['ts_conv_config']['no_ts_input'])
for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
self.conv_ts_node.append('__history__'+tsnode)
self.conv_other_node.append('__now__'+tsnode)
node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
if all_node_ts is None:
all_node_ts = node_ts
assert all_node_ts==node_ts, f"expect ts step == {all_node_ts}. However got {tsnode} with ts step {node_ts}"
# kernel size is ts - 1: the current-step state is handled separately
kernel_size = all_node_ts - 1
self.ts_state_conv_layer = ConvBlock(self.kwargs['ts_conv_net_config']['conv_net_input'], ts_state_conv_layer_output, kernel_size)
[docs]
def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
if hasattr(self, 'conv_ts_node'):
#deal with all ts nodes with ts dim
assert hasattr(self, 'conv_ts_node')
assert hasattr(self, 'conv_other_node')
inputdata = deepcopy(state.detach()) # data is detached so gradients do not flow back to other nodes
for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
assert self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['endpoint'] == True
batch_size = np.prod(list(inputdata[tsnode].shape[:-1]))
original_size = list(inputdata[tsnode].shape[:-1])
node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
# inputdata[tsnode]
temp_data = inputdata[tsnode].reshape(batch_size, node_ts, -1)
inputdata['__history__'+tsnode] = temp_data[..., :-1, :]
inputdata['__now__'+tsnode] = temp_data[..., -1, :].reshape([*original_size,-1])
state_other = torch.cat([inputdata[key] for key in self.conv_other_node], dim=-1)
state_ts_conv = torch.cat([inputdata[key] for key in self.conv_ts_node], dim=-1)
state_other = self.other_state_layer(state_other)
original_size = list(state_other.shape[:-1])
state_ts_conv = self.ts_state_conv_layer(state_ts_conv).reshape([*original_size,-1])
state = torch.cat([state_other, state_ts_conv], dim=-1)
else:
pass
# state = torch.cat([state[key].detach() for key in self.input_dim_dict.keys()], dim=-1)
output = self.value(state)
output = output.squeeze(-1)
return output
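# Usage sketch (illustrative only; sizes are arbitrary and no ts_conv_config is used,
# so the forward pass feeds the given features straight to the vectorized MLP):
#
#     critic = Value_Net_VectorizedCritic(input_dim_dict={'obs': 10, 'action': 4},
#                                         q_hidden_features=64, q_hidden_layers=2,
#                                         num_q_net=4)
#     q = critic(torch.rand(32, 14))    # [num_q_net, batch_size] = [4, 32]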