Source code for revive.computation.modules

"""
    POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""
from loguru import logger
import time
import torch
import math
import warnings
from copy import deepcopy
from torch import nn#, vmap
# from torch.func import stack_module_state, functional_call
from typing import Optional, Union, List, Dict, cast
from collections import OrderedDict, deque

from revive.computation.dists import *
from revive.computation.utils import *


[docs] def reglu(x):
    a, b = x.chunk(2, dim=-1)
    return a * F.relu(b)


[docs] def geglu(x):
    a, b = x.chunk(2, dim=-1)
    return a * F.gelu(b)


[docs] class Swish(nn.Module):
[docs]     def forward(self, x : torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)


ACTIVATION_CREATORS = {
    'relu' : lambda dim: nn.ReLU(inplace=True),
    'elu' : lambda dim: nn.ELU(),
    'leakyrelu' : lambda dim: nn.LeakyReLU(negative_slope=0.1, inplace=True),
    'tanh' : lambda dim: nn.Tanh(),
    'sigmoid' : lambda dim: nn.Sigmoid(),
    'identity' : lambda dim: nn.Identity(),
    'prelu' : lambda dim: nn.PReLU(dim),
    'gelu' : lambda dim: nn.GELU(),
    'reglu' : reglu,
    'geglu' : geglu,
    'swish' : lambda dim: Swish(),
}
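Each entry in ACTIVATION_CREATORS takes a feature dimension and returns an activation module (only 'prelu' actually uses the dimension); 'reglu' and 'geglu' are plain functions that halve the last dimension. A minimal sketch of how the table is used (illustrative only):

import torch
from revive.computation.modules import ACTIVATION_CREATORS, reglu

act = ACTIVATION_CREATORS['leakyrelu'](64)   # the dim argument is ignored by this entry
y = act(torch.randn(8, 64))                  # shape (8, 64)
h = reglu(torch.randn(8, 128))               # gated: output shape (8, 64)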
# -------------------------------- Layers -------------------------------- #

[docs] class Value_Net(nn.Module):
    r"""
    Value network that maps a (possibly dict-valued) state to a scalar value.

    Args:
        input_dim_dict (Union[dict, int]): Per-node input dimensions, or the total input dimension.
        value_hidden_features (int): The number of hidden features per layer.
        value_hidden_layers (int): The number of hidden layers.
        value_normalization (str): The normalization method used between hidden layers.
        value_activation (str): The activation function used in the hidden layers.
    """
    def __init__(self,
                 input_dim_dict,
                 value_hidden_features,
                 value_hidden_layers,
                 value_normalization,
                 value_activation,
                 *args, **kwargs) -> None:
        super().__init__()
        self.kwargs = kwargs
        self.input_dim_dict = input_dim_dict

        if isinstance(input_dim_dict, dict):
            input_dim = sum(input_dim_dict.values())
        else:
            input_dim = input_dim_dict

        self.value = MLP(input_dim, 1, value_hidden_features, value_hidden_layers,
                         norm=value_normalization, hidden_activation=value_activation)

        if 'ts_conv_config' in self.kwargs.keys() and self.kwargs['ts_conv_config'] != None:
            other_state_layer_output = input_dim//2
            ts_state_conv_layer_output = input_dim//2 + input_dim%2

            self.other_state_layer = MLP(self.kwargs['ts_conv_net_config']['other_net_input'],
                                         other_state_layer_output, 0, 0,
                                         output_activation=value_activation)

            all_node_ts = None
            self.conv_ts_node = []
            self.conv_other_node = deepcopy(self.kwargs['ts_conv_config']['no_ts_input'])
            for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
                self.conv_ts_node.append('__history__'+tsnode)
                self.conv_other_node.append('__now__'+tsnode)
                node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
                if all_node_ts == None:
                    all_node_ts = node_ts
                assert all_node_ts==node_ts, f"expect ts step == {all_node_ts}. However got {tsnode} with ts step {node_ts}"

            # kernel size is ts-1: the "now" step is removed and handled separately
            kernalsize = all_node_ts-1
            self.ts_state_conv_layer = ConvBlock(self.kwargs['ts_conv_net_config']['conv_net_input'],
                                                 ts_state_conv_layer_output, kernalsize,
                                                 output_activation=value_activation)

[docs]     def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
        if hasattr(self, 'conv_ts_node'):
            # deal with all ts nodes that carry a time-step dimension
            assert hasattr(self, 'conv_ts_node')
            assert hasattr(self, 'conv_other_node')
            inputdata = deepcopy(state.detach())   # data is detached from the other nodes
            for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
                # assert self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['endpoint'] == True
                batch_size = np.prod(list(inputdata[tsnode].shape[:-1]))
                original_size = list(inputdata[tsnode].shape[:-1])
                node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
                temp_data = inputdata[tsnode].reshape(batch_size, node_ts, -1)
                inputdata['__history__'+tsnode] = temp_data[..., :-1, :]
                inputdata['__now__'+tsnode] = temp_data[..., -1, :].reshape([*original_size,-1])

            state_other = torch.cat([inputdata[key] for key in self.conv_other_node], dim=-1)
            state_ts_conv = torch.cat([inputdata[key] for key in self.conv_ts_node], dim=-1)

            state_other = self.other_state_layer(state_other)
            original_size = list(state_other.shape[:-1])
            state_ts_conv = self.ts_state_conv_layer(state_ts_conv).reshape([*original_size,-1])

            state = torch.cat([state_other, state_ts_conv], dim=-1)
        else:
            state = torch.cat([state[key].detach() for key in self.input_dim_dict.keys()], dim=-1)
        output = self.value(state)
        return output
[docs] class ConvBlock(nn.Module):
    r"""
    A linear layer followed by a depthwise 1D convolution over the time-step dimension.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        conv_time_step (int): The kernel size of the convolution along the time dimension.
        output_activation (str): The activation applied after the linear layer. Default is 'identity'.
    """
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 conv_time_step: int,
                 output_activation : str = 'identity',
                 ):
        super().__init__()
        output_activation_creator = ACTIVATION_CREATORS[output_activation]
        # self.in_features = in_features
        # self.out_features = out_features
        self.mlp_net = nn.Sequential(
            nn.Linear(in_features, out_features),
            output_activation_creator(out_features)
        )
        self.conv_net = nn.Conv1d(in_channels=out_features,
                                  out_channels=out_features,
                                  groups=out_features,
                                  kernel_size=conv_time_step)

[docs]     def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mlp_net(x)
        # exchange the time-step dim with the feature dim
        x = x.transpose(-1, -2)
        # the convolution collapses the time dimension, so squeeze it away
        out = self.conv_net(x).squeeze(-1)
        return out
[docs] class VectorizedLinear(nn.Module):
    r"""
    Initializes a vectorized linear layer instance.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        ensemble_size (int): The number of ensembles to use.
    """
    def __init__(self, in_features: int, out_features: int, ensemble_size: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features))
        self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))

        self.reset_parameters()

[docs]     def reset_parameters(self):
        # default pytorch init for nn.Linear module
        for layer in range(self.ensemble_size):
            nn.init.kaiming_uniform_(self.weight[layer], a=math.sqrt(5))

        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.bias, -bound, bound)

[docs]     def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        input: [ensemble_size, batch_size, input_size]
        weight: [ensemble_size, input_size, out_size]
        out: [ensemble_size, batch_size, out_size]
        """
        # out = torch.einsum("kij,kjl->kil", x, self.weight) + self.bias
        out = x @ self.weight + self.bias
        # out = torch.bmm(x, self.weight) + self.bias
        return out
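A minimal usage sketch (illustrative, not part of the module): VectorizedLinear applies one independent linear map per ensemble member with a single batched matmul.

import torch
from revive.computation.modules import VectorizedLinear

layer = VectorizedLinear(in_features=4, out_features=8, ensemble_size=5)
x = torch.randn(5, 32, 4)    # [ensemble_size, batch_size, in_features]
out = layer(x)               # [5, 32, 8]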
[docs] class EnsembleLinear(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_ensemble: int,
        weight_decay: float = 0.0
    ) -> None:
        super().__init__()

        self.num_ensemble = num_ensemble

        self.register_parameter("weight", nn.Parameter(torch.zeros(num_ensemble, input_dim, output_dim)))
        self.register_parameter("bias", nn.Parameter(torch.zeros(num_ensemble, 1, output_dim)))

        nn.init.trunc_normal_(self.weight, std=1 / (2 * input_dim ** 0.5))

        self.register_parameter("saved_weight", nn.Parameter(self.weight.detach().clone()))
        self.register_parameter("saved_bias", nn.Parameter(self.bias.detach().clone()))

        self.weight_decay = weight_decay
        self.device = torch.device('cpu')

[docs]     def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight
        bias = self.bias

        if len(x.shape) == 2:
            x = torch.einsum('ij,bjk->bik', x, weight)
        elif len(x.shape) == 3:
            if x.shape[0] == weight.data.shape[0]:
                x = torch.einsum('bij,bjk->bik', x, weight)
            else:
                x = torch.einsum('cij,bjk->bcik', x, weight)
        elif len(x.shape) == 4:
            if x.shape[0] == weight.data.shape[0]:
                x = torch.einsum('cbij,cjk->cbik', x, weight)
            else:
                x = torch.einsum('cdij,bjk->bcdik', x, weight)
        elif len(x.shape) == 5:
            x = torch.einsum('bcdij,bjk->bcdik', x, weight)

        assert x.shape[0] == bias.shape[0] and x.shape[-1] == bias.shape[-1]
        if len(x.shape) == 4:
            bias = bias.unsqueeze(1)
        elif len(x.shape) == 5:
            bias = bias.unsqueeze(1)
            bias = bias.unsqueeze(1)

        x = x + bias
        return x

[docs]     def load_save(self) -> None:
        self.weight.data.copy_(self.saved_weight.data)
        self.bias.data.copy_(self.saved_bias.data)

[docs]     def update_save(self, indexes: List[int]) -> None:
        self.saved_weight.data[indexes] = self.weight.data[indexes]
        self.saved_bias.data[indexes] = self.bias.data[indexes]

[docs]     def get_decay_loss(self) -> torch.Tensor:
        decay_loss = self.weight_decay * (0.5 * ((self.weight ** 2).sum()))
        return decay_loss

[docs]     def to(self, device):
        if not device == self.device:
            self.device = device
            super().to(device)
            self.weight = self.weight.to(self.device)
            self.bias = self.bias.to(self.device)
            self.saved_weight = self.saved_weight.to(self.device)
            self.saved_bias = self.saved_bias.to(self.device)
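A minimal usage sketch (illustrative only): a shared 2D input is broadcast to every ensemble member, and update_save/load_save snapshot and restore per-member weights.

import torch
from revive.computation.modules import EnsembleLinear

layer = EnsembleLinear(input_dim=4, output_dim=8, num_ensemble=5, weight_decay=1e-4)
out = layer(torch.randn(32, 4))    # [5, 32, 8]

layer.update_save([0, 2])          # snapshot the weights of members 0 and 2
layer.load_save()                  # restore every member from its last snapshot
penalty = layer.get_decay_loss()   # 0.5 * weight_decay * ||weight||^2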
# -------------------------------- Backbones -------------------------------- #
[docs] class MLP(nn.Module):
    r"""
    Multi-layer Perceptron

    Args:
        in_features : int, number of input features
        out_features : int, number of output features
        hidden_features : int, number of features in each hidden layer
        hidden_layers : int, number of hidden layers
        norm : str, normalization method between hidden layers, default : None
        hidden_activation : str, activation function used in hidden layers, default : 'leakyrelu'
        output_activation : str, activation function used in output layer, default : 'identity'
    """
    def __init__(self,
                 in_features : int,
                 out_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 norm : str = None,
                 hidden_activation : str = 'leakyrelu',
                 output_activation : str = 'identity'):
        super(MLP, self).__init__()

        hidden_activation_creator = ACTIVATION_CREATORS[hidden_activation]
        output_activation_creator = ACTIVATION_CREATORS[output_activation]

        if hidden_layers == 0:
            self.net = nn.Sequential(
                nn.Linear(in_features, out_features),
                output_activation_creator(out_features)
            )
        else:
            net = []
            for i in range(hidden_layers):
                net.append(nn.Linear(in_features if i == 0 else hidden_features, hidden_features))
                if norm:
                    if norm == 'ln':
                        net.append(nn.LayerNorm(hidden_features))
                    elif norm == 'bn':
                        net.append(nn.BatchNorm1d(hidden_features))
                    else:
                        raise NotImplementedError(f'{norm} is not supported!')
                net.append(hidden_activation_creator(hidden_features))
            net.append(nn.Linear(hidden_features, out_features))
            net.append(output_activation_creator(out_features))
            self.net = nn.Sequential(*net)

[docs]     def forward(self, x : torch.Tensor) -> torch.Tensor:
        r"""
        forward method of MLP
        only assumes the last dim of x matches `in_features`
        """
        return self.net(x)
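A minimal usage sketch (illustrative only): with norm='ln' the MLP accepts arbitrary leading batch dimensions, since only the last dimension must equal in_features.

import torch
from revive.computation.modules import MLP

net = MLP(in_features=10, out_features=4, hidden_features=64, hidden_layers=2, norm='ln')
y = net(torch.randn(32, 10))       # -> (32, 4)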
[docs] class VectorizedMLP(nn.Module):
    r"""
    Vectorized MLP

    Args:
        in_features : int, number of input features
        out_features : int, number of output features
        hidden_features : int, number of features in each hidden layer
        hidden_layers : int, number of hidden layers
        ensemble_size : int, number of ensemble members
        norm : str, normalization method between hidden layers, default : None
        hidden_activation : str, activation function used in hidden layers, default : 'leakyrelu'
        output_activation : str, activation function used in output layer, default : 'identity'
    """
    def __init__(self,
                 in_features : int,
                 out_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 ensemble_size : int,
                 norm : str = None,
                 hidden_activation : str = 'leakyrelu',
                 output_activation : str = 'identity'):
        super(VectorizedMLP, self).__init__()
        self.ensemble_size = ensemble_size

        hidden_activation_creator = ACTIVATION_CREATORS[hidden_activation]
        output_activation_creator = ACTIVATION_CREATORS[output_activation]

        if hidden_layers == 0:
            self.net = nn.Sequential(
                VectorizedLinear(in_features, out_features, ensemble_size),
                output_activation_creator(out_features)
            )
        else:
            net = []
            for i in range(hidden_layers):
                net.append(VectorizedLinear(in_features if i == 0 else hidden_features, hidden_features, ensemble_size))
                if norm:
                    if norm == 'ln':
                        net.append(nn.LayerNorm(hidden_features))
                    elif norm == 'bn':
                        net.append(nn.BatchNorm1d(hidden_features))
                    else:
                        raise NotImplementedError(f'{norm} is not supported!')
                net.append(hidden_activation_creator(hidden_features))
            net.append(VectorizedLinear(hidden_features, out_features, ensemble_size))
            net.append(output_activation_creator(out_features))
            self.net = nn.Sequential(*net)

[docs]     def forward(self, x : torch.Tensor) -> torch.Tensor:
        r"""
        forward method of VectorizedMLP
        only assumes the last dim of x matches `in_features`
        """
        if x.dim() == 2:
            # [batch_size, dim] -> [ensemble_size, batch_size, dim]
            x = x.unsqueeze(0).repeat_interleave(self.ensemble_size, dim=0)
        assert x.dim() == 3
        assert x.shape[0] == self.ensemble_size
        return self.net(x)
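A minimal usage sketch (illustrative only): a 2D input is repeated across the ensemble axis before the vectorized layers are applied.

import torch
from revive.computation.modules import VectorizedMLP

net = VectorizedMLP(in_features=10, out_features=4, hidden_features=64,
                    hidden_layers=2, ensemble_size=7, norm='ln')
y = net(torch.randn(32, 10))       # -> (7, 32, 4)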
[docs] class ResBlock(nn.Module):
    """
    Initializes a residual block instance.

    Args:
        input_feature (int): The number of input features to the block.
        output_feature (int): The number of output features from the block.
        norm (str, optional): The type of normalization to apply to the block. Default is 'ln' for layer normalization.
    """
    def __init__(self, input_feature : int, output_feature : int, norm : str = 'ln'):
        super().__init__()

        if norm == 'ln':
            norm_class = torch.nn.LayerNorm
            self.process_net = torch.nn.Sequential(
                torch.nn.Linear(input_feature, output_feature),
                norm_class(output_feature),
                torch.nn.ReLU(True),
                torch.nn.Linear(output_feature, output_feature),
                norm_class(output_feature),
                torch.nn.ReLU(True)
            )
        else:
            self.process_net = torch.nn.Sequential(
                torch.nn.Linear(input_feature, output_feature),
                torch.nn.ReLU(True),
                torch.nn.Linear(output_feature, output_feature),
                torch.nn.ReLU(True)
            )

        if not input_feature == output_feature:
            self.skip_net = torch.nn.Linear(input_feature, output_feature)
        else:
            self.skip_net = torch.nn.Identity()

[docs]     def forward(self, x : torch.Tensor) -> torch.Tensor:
        '''x should be a 2D Tensor due to batchnorm'''
        return self.process_net(x) + self.skip_net(x)
[docs] class VectorizedResBlock(nn.Module):
    r"""
    Initializes a vectorized residual block instance.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        ensemble_size (int): The number of ensembles to use.
        norm (str, optional): The type of normalization to apply. Default is 'ln' for layer normalization.
    """
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 ensemble_size: int,
                 norm : str = 'ln'):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        if norm == 'ln':
            norm_class = torch.nn.LayerNorm
            self.process_net = torch.nn.Sequential(
                VectorizedLinear(in_features, out_features, ensemble_size),
                norm_class(out_features),
                torch.nn.ReLU(True),
                VectorizedLinear(out_features, out_features, ensemble_size),
                norm_class(out_features),
                torch.nn.ReLU(True)
            )
        else:
            self.process_net = torch.nn.Sequential(
                VectorizedLinear(in_features, out_features, ensemble_size),
                torch.nn.ReLU(True),
                VectorizedLinear(out_features, out_features, ensemble_size),
                torch.nn.ReLU(True)
            )

        if not in_features == out_features:
            self.skip_net = VectorizedLinear(in_features, out_features, ensemble_size)
        else:
            self.skip_net = torch.nn.Identity()

[docs]     def forward(self, x: torch.Tensor) -> torch.Tensor:
        # input: [ensemble_size, batch_size, input_size]
        # out: [ensemble_size, batch_size, out_size]
        out = self.process_net(x) + self.skip_net(x)
        return out
[docs] class ResNet(torch.nn.Module):
    """
    Initializes a residual neural network (ResNet) instance.

    Args:
        in_features (int): The number of input features to the ResNet.
        out_features (int): The number of output features from the ResNet.
        hidden_features (int): The number of hidden features in each residual block.
        hidden_layers (int): The number of residual blocks in the ResNet.
        norm (str, optional): The type of normalization to apply to the ResNet. Default is 'bn' for batch normalization.
        output_activation (str, optional): The type of activation function to apply to the output of the ResNet. Default is 'identity'.
    """
    def __init__(self,
                 in_features : int,
                 out_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 norm : str = 'bn',
                 output_activation : str = 'identity'):
        super().__init__()

        modules = []
        for i in range(hidden_layers):
            if i == 0:
                modules.append(ResBlock(in_features, hidden_features, norm))
            else:
                modules.append(ResBlock(hidden_features, hidden_features, norm))
        modules.append(torch.nn.Linear(hidden_features, out_features))
        modules.append(ACTIVATION_CREATORS[output_activation](out_features))

        self.resnet = torch.nn.Sequential(*modules)

[docs]     def forward(self, x : torch.Tensor) -> torch.Tensor:
        '''NOTE: reshape is needed since ResBlock only supports 2D Tensors'''
        shape = x.shape
        x = x.view(-1, shape[-1])
        output = self.resnet(x)
        output = output.view(*shape[:-1], -1)
        return output
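A minimal usage sketch (illustrative only): the forward pass flattens all leading dimensions to 2D for the residual blocks and restores them afterwards.

import torch
from revive.computation.modules import ResNet

net = ResNet(in_features=10, out_features=4, hidden_features=64, hidden_layers=3, norm='ln')
y = net(torch.randn(2, 16, 10))    # -> (2, 16, 4)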
[docs] class VectorizedResNet(torch.nn.Module):
    """
    Initializes a vectorized residual neural network (ResNet) instance.

    Args:
        in_features (int): The number of input features to the ResNet.
        out_features (int): The number of output features from the ResNet.
        hidden_features (int): The number of hidden features in each residual block.
        hidden_layers (int): The number of residual blocks in the ResNet.
        ensemble_size (int): The number of ensemble members.
        norm (str, optional): The type of normalization to apply to the ResNet. Default is 'bn' for batch normalization.
        output_activation (str, optional): The type of activation function to apply to the output of the ResNet. Default is 'identity'.
    """
    def __init__(self,
                 in_features : int,
                 out_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 ensemble_size : int,
                 norm : str = 'bn',
                 output_activation : str = 'identity'):
        super().__init__()
        self.ensemble_size = ensemble_size

        modules = []
        for i in range(hidden_layers):
            if i == 0:
                modules.append(VectorizedResBlock(in_features, hidden_features, ensemble_size, norm))
            else:
                modules.append(VectorizedResBlock(hidden_features, hidden_features, ensemble_size, norm))
        modules.append(VectorizedLinear(hidden_features, out_features, ensemble_size))
        modules.append(ACTIVATION_CREATORS[output_activation](out_features))

        self.resnet = torch.nn.Sequential(*modules)

[docs]     def forward(self, x : torch.Tensor) -> torch.Tensor:
        '''NOTE: a 2D input [batch_size, dim] is broadcast to [ensemble_size, batch_size, dim]'''
        shape = x.shape
        if len(shape) == 2:
            # [batch_size, dim] --> [ensemble_size, batch_size, dim]
            x = x.unsqueeze(0).repeat_interleave(self.ensemble_size, dim=0)
        assert x.dim() == 3, f"shape == {shape}, with length == {len(shape)}, expect length 3"
        assert x.shape[0] == self.ensemble_size   # [ensemble_size, batch_size, dim]
        output = self.resnet(x)
        return output
[docs] class TAPE(nn.Module):
    def __init__(self, d_model, max_len=200, scale_factor=1.0):
        super(TAPE, self).__init__()
        pe = torch.zeros(max_len, d_model)  # positional encoding
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin((position * div_term)*(d_model/max_len))
        pe[:, 1::2] = torch.cos((position * div_term)*(d_model/max_len))
        pe = scale_factor * pe
        # this stores the variable in the state_dict (used for non-trainable variables)
        self.register_buffer('pe', pe)

[docs]     def forward(self, x):
        return x + self.pe[:x.shape[1], :]
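A minimal usage sketch (illustrative only): TAPE adds a fixed sinusoidal positional encoding over the sequence dimension of a [batch, seq_len, d_model] tensor, with seq_len at most max_len.

import torch
from revive.computation.modules import TAPE

tape = TAPE(d_model=16)
y = tape(torch.randn(8, 50, 16))   # same shape, positions 0..49 encoded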
[docs] class TAttention(nn.Module): def __init__(self, d_model, nhead, dropout): super().__init__() self.d_model = d_model self.nhead = nhead self.qtrans = nn.Linear(d_model, d_model, bias=False) self.ktrans = nn.Linear(d_model, d_model, bias=False) self.vtrans = nn.Linear(d_model, d_model, bias=False) self.attn_dropout = [] if dropout > 0: for i in range(nhead): self.attn_dropout.append(nn.Dropout(p=dropout)) self.attn_dropout = nn.ModuleList(self.attn_dropout) # input LayerNorm self.norm1 = nn.LayerNorm(d_model, eps=1e-5) # FFN layerNorm self.norm2 = nn.LayerNorm(d_model, eps=1e-5) # FFN self.ffn = nn.Sequential( nn.Linear(d_model, d_model), nn.ReLU(), nn.Dropout(p=dropout), nn.Linear(d_model, d_model), nn.Dropout(p=dropout) )
[docs] def forward(self, x): x = self.norm1(x) q = self.qtrans(x) k = self.ktrans(x) v = self.vtrans(x) dim = int(self.d_model / self.nhead) att_output = [] for i in range(self.nhead): if i==self.nhead-1: qh = q[:, :, i * dim:] kh = k[:, :, i * dim:] vh = v[:, :, i * dim:] else: qh = q[:, :, i * dim:(i + 1) * dim] kh = k[:, :, i * dim:(i + 1) * dim] vh = v[:, :, i * dim:(i + 1) * dim] atten_ave_matrixh = torch.softmax(torch.matmul(qh, kh.transpose(1, 2)), dim=-1) if self.attn_dropout: atten_ave_matrixh = self.attn_dropout[i](atten_ave_matrixh) att_output.append(torch.matmul(atten_ave_matrixh, vh)) att_output = torch.concat(att_output, dim=-1) # FFN xt = x + att_output xt = self.norm2(xt) att_output = xt + self.ffn(xt) return att_output
[docs] class TsTransformer(nn.Module):
    def __init__(self, in_features, out_features, hidden_layers, hidden_features, dropout=0.):
        super(TsTransformer, self).__init__()
        layers = []
        layers.append(nn.Linear(in_features, hidden_features))
        layers.append(TAPE(hidden_features))
        for i in range(hidden_layers):
            layers.append(TAttention(hidden_features, nhead=4, dropout=dropout))
        layers.append(nn.Linear(hidden_features, out_features))
        self.layers = nn.Sequential(*layers)

[docs]     def forward(self, x):
        if len(x.shape) == 2:
            x = x.reshape(x.shape[0], -1, 2)
        output = self.layers(x).squeeze(-1)[:, -1:]
        return output
[docs] class Transformer1D(nn.Module): """ This is an experimental backbone. """ def __init__(self, in_features : int, out_features : int, transformer_features : int = 16, transformer_heads : int = 8, transformer_layers : int = 4): super().__init__() self.linear = torch.nn.Linear(in_features, out_features) self.register_parameter('in_weight', torch.nn.Parameter(torch.randn(out_features, transformer_features))) self.register_parameter('in_bais', torch.nn.Parameter(torch.zeros(out_features, transformer_features))) self.register_parameter('out_weight', torch.nn.Parameter(torch.randn(out_features, transformer_features))) self.register_parameter('out_bais', torch.nn.Parameter(torch.zeros(out_features))) torch.nn.init.xavier_normal_(self.in_weight) torch.nn.init.zeros_(self.out_weight) # zero initialize encoder_layer = torch.nn.TransformerEncoderLayer(transformer_features, transformer_heads, 512) self.transformer = torch.nn.TransformerEncoder(encoder_layer, transformer_layers)
[docs] def forward(self, x : torch.Tensor) -> torch.Tensor: shape = x.shape x = x.view(-1, x.shape[-1]) # [B, I] x = self.linear(x) # [B, O] x = x.unsqueeze(dim=-1) # [B, O, 1] x = x * self.in_weight + self.in_bais # [B, O, F] x = x.permute(1, 0, 2).contiguous() # [O, B, F] x = self.transformer(x) # [O, B, F] x = x.permute(1, 0, 2).contiguous() # [B, O, F] x = torch.sum(x * self.out_weight, dim=-1) + self.out_bais # [B, O] x = x.view(*shape[:-1], x.shape[-1]) return x
[docs] class Tokenizer(nn.Module): category_offsets: Optional[torch.Tensor] def __init__( self, d_numerical: int, categories: Optional[List[int]], d_token: int, bias: bool, ) -> None: super().__init__() if categories is None: d_bias = d_numerical self.category_offsets = None self.category_embeddings = None else: d_bias = d_numerical + len(categories) category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0) self.register_buffer('category_offsets', category_offsets) self.category_embeddings = nn.Embedding(sum(categories), d_token) nn.init.kaiming_uniform_(self.category_embeddings.weight, a=math.sqrt(5)) print(f'{self.category_embeddings.weight.shape=}') # take [CLS] token into account self.weight = nn.Parameter(torch.Tensor(d_numerical + 1, d_token)) self.bias = nn.Parameter(torch.Tensor(d_bias, d_token)) if bias else None # The initialization is inspired by nn.Linear nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) if self.bias is not None: nn.init.kaiming_uniform_(self.bias, a=math.sqrt(5)) @property def n_tokens(self) -> int: return len(self.weight) + ( 0 if self.category_offsets is None else len(self.category_offsets) )
[docs] def forward(self, x_num: torch.Tensor, x_cat: Optional[torch.Tensor]): x_some = x_num if x_cat is None else x_cat assert x_some is not None x_num = torch.cat( [torch.ones(len(x_some), 1, device=x_some.device)] # [CLS] + ([] if x_num is None else [x_num]), dim=1, ) x = self.weight[None] * x_num[:, :, None] if x_cat is not None: x = torch.cat( [x, self.category_embeddings(x_cat + self.category_offsets[None])], dim=1, ) if self.bias is not None: bias = torch.cat( [ torch.zeros(1, self.bias.shape[1], device=x.device), self.bias, ] ) x = x + bias[None] return x
[docs] class MultiheadAttention(nn.Module): def __init__( self, d: int, n_heads: int, dropout: float, initialization: str ) -> None: if n_heads > 1: assert d % n_heads == 0 assert initialization in ['xavier', 'kaiming'] super().__init__() self.W_q = nn.Linear(d, d) self.W_k = nn.Linear(d, d) self.W_v = nn.Linear(d, d) self.W_out = nn.Linear(d, d) if n_heads > 1 else None self.n_heads = n_heads self.dropout = nn.Dropout(dropout) if dropout else None for m in [self.W_q, self.W_k, self.W_v]: if initialization == 'xavier' and (n_heads > 1 or m is not self.W_v): # gain is needed since W_qkv is represented with 3 separate layers nn.init.xavier_uniform_(m.weight, gain=1 / math.sqrt(2)) nn.init.zeros_(m.bias) if self.W_out is not None: nn.init.zeros_(self.W_out.bias) def _reshape(self, x): batch_size, n_tokens, d = x.shape d_head = d // self.n_heads return ( x.reshape(batch_size, n_tokens, self.n_heads, d_head) .transpose(1, 2) .reshape(batch_size * self.n_heads, n_tokens, d_head) )
[docs] def forward( self, x_q, x_kv, key_compression: Optional[nn.Linear], value_compression: Optional[nn.Linear], ): q, k, v = self.W_q(x_q), self.W_k(x_kv), self.W_v(x_kv) for tensor in [q, k, v]: assert tensor.shape[-1] % self.n_heads == 0 if key_compression is not None: assert value_compression is not None k = key_compression(k.transpose(1, 2)).transpose(1, 2) v = value_compression(v.transpose(1, 2)).transpose(1, 2) else: assert value_compression is None batch_size = len(q) d_head_key = k.shape[-1] // self.n_heads d_head_value = v.shape[-1] // self.n_heads n_q_tokens = q.shape[1] q = self._reshape(q) k = self._reshape(k) attention = F.softmax(q @ k.transpose(1, 2) / math.sqrt(d_head_key), dim=-1) if self.dropout is not None: attention = self.dropout(attention) x = attention @ self._reshape(v) x = ( x.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value) .transpose(1, 2) .reshape(batch_size, n_q_tokens, self.n_heads * d_head_value) ) if self.W_out is not None: x = self.W_out(x) return x
[docs] class FT_Transformer(nn.Module): """FT_Transformer References: - https://github.com/Yura52/tabular-dl-revisiting-models/blob/main/bin/ft_transformer.py """ def __init__( self, in_features : int, out_features : int, hidden_features : int =256, hidden_layers : int = 5, categories: Optional[List[int]]=None, token_bias: bool = True, # transformer n_heads: int = 8, d_ffn_factor: float = 1.33, attention_dropout: float = 0.0, ffn_dropout: float = 0.0, residual_dropout: float = 0.0, activation: str = "reglu", prenormalization: bool = True, initialization: str = "kaiming", # linformer kv_compression: Optional[float] = None, kv_compression_sharing: Optional[str] = None, ) -> None: d_numerical = in_features d_out = out_features n_layers = hidden_layers d_token = hidden_features assert (kv_compression is None) ^ (kv_compression_sharing is not None) super().__init__() self.tokenizer = Tokenizer(d_numerical, categories, d_token, token_bias) n_tokens = self.tokenizer.n_tokens def make_kv_compression(): assert kv_compression compression = nn.Linear( n_tokens, int(n_tokens * kv_compression), bias=False ) if initialization == 'xavier': nn.init.xavier_uniform_(compression.weight) return compression self.shared_kv_compression = ( make_kv_compression() if kv_compression and kv_compression_sharing == 'layerwise' else None ) def make_normalization(): return nn.LayerNorm(d_token) d_hidden = int(d_token * d_ffn_factor) self.layers = nn.ModuleList([]) for layer_idx in range(n_layers): layer = nn.ModuleDict( { 'attention': MultiheadAttention( d_token, n_heads, attention_dropout, initialization ), 'linear0': nn.Linear( d_token, d_hidden * (2 if activation.endswith('glu') else 1) ), 'linear1': nn.Linear(d_hidden, d_token), 'norm1': make_normalization(), } ) if not prenormalization or layer_idx: layer['norm0'] = make_normalization() if kv_compression and self.shared_kv_compression is None: layer['key_compression'] = make_kv_compression() if kv_compression_sharing == 'headwise': layer['value_compression'] = make_kv_compression() else: assert kv_compression_sharing == 'key-value' self.layers.append(layer) self.activation = ACTIVATION_CREATORS[activation] self.last_activation = ACTIVATION_CREATORS[activation] self.prenormalization = prenormalization self.last_normalization = make_normalization() if prenormalization else None self.ffn_dropout = ffn_dropout self.residual_dropout = residual_dropout self.head = nn.Linear(int(d_token//2), d_out) def _get_kv_compressions(self, layer): return ( (self.shared_kv_compression, self.shared_kv_compression) if self.shared_kv_compression is not None else (layer['key_compression'], layer['value_compression']) if 'key_compression' in layer and 'value_compression' in layer else (layer['key_compression'], layer['key_compression']) if 'key_compression' in layer else (None, None) ) def _start_residual(self, x, layer, norm_idx): x_residual = x if self.prenormalization: norm_key = f'norm{norm_idx}' if norm_key in layer: x_residual = layer[norm_key](x_residual) return x_residual def _end_residual(self, x, x_residual, layer, norm_idx): if self.residual_dropout: x_residual = F.dropout(x_residual, self.residual_dropout, self.training) x = x + x_residual if not self.prenormalization: x = layer[f'norm{norm_idx}'](x) return x
[docs] def forward(self, x_num, x_cat=None): if len(x_num.shape) == 3: x = x_num.view(-1, x_num.shape[-1]) else: x = x_num x = self.tokenizer(x, x_cat) for layer_idx, layer in enumerate(self.layers): is_last_layer = layer_idx + 1 == len(self.layers) layer = cast(Dict[str, nn.Module], layer) x_residual = self._start_residual(x, layer, 0) x_residual = layer['attention']( # for the last attention, it is enough to process only [CLS] (x_residual[:, :1] if is_last_layer else x_residual), x_residual, *self._get_kv_compressions(layer), ) if is_last_layer: x = x[:, : x_residual.shape[1]] x = self._end_residual(x, x_residual, layer, 0) x_residual = self._start_residual(x, layer, 1) x_residual = layer['linear0'](x_residual) x_residual = self.activation(x_residual) if self.ffn_dropout: x_residual = F.dropout(x_residual, self.ffn_dropout, self.training) x_residual = layer['linear1'](x_residual) x = self._end_residual(x, x_residual, layer, 1) assert x.shape[1] == 1 x = x[:, 0] if self.last_normalization is not None: x = self.last_normalization(x) x = self.last_activation(x) x = self.head(x) x = x.squeeze(-1) if len(x_num.shape) == 3: x = x.view(-1, x_num.shape[1], x.shape[-1]) return x
[docs] class DistributionWrapper(nn.Module): r"""wrap output of Module to distribution""" BASE_TYPES = ['normal', 'gmm', 'onehot', 'discrete_logistic'] SUPPORTED_TYPES = BASE_TYPES + ['mix'] def __init__(self, distribution_type : str = 'normal', **params): super().__init__() self.distribution_type = distribution_type self.params = params assert self.distribution_type in self.SUPPORTED_TYPES, f"{self.distribution_type} is not supported!" if self.distribution_type == 'normal': self.max_logstd = nn.Parameter(torch.ones(self.params['dim']) * 0.0, requires_grad=True) self.min_logstd = nn.Parameter(torch.ones(self.params['dim']) * -7, requires_grad=True) if not self.params.get('conditioned_std', True): self.logstd = nn.Parameter(torch.zeros(self.params['dim']), requires_grad=True) elif self.distribution_type == 'gmm': self.max_logstd = nn.Parameter(torch.ones(self.params['mixture'], self.params['dim']) * 0, requires_grad=True) self.min_logstd = nn.Parameter(torch.ones(self.params['mixture'], self.params['dim']) * -10, requires_grad=True) if not self.params.get('conditioned_std', True): self.logstd = nn.Parameter(torch.zeros(self.params['mixture'], self.params['dim']), requires_grad=True) elif self.distribution_type == 'discrete_logistic': self.num = self.params['num'] elif self.distribution_type == 'mix': assert 'dist_config' in self.params.keys(), "You need to provide `dist_config` for Mix distribution" self.dist_config = self.params['dist_config'] self.wrapper_list = [] self.input_sizes = [] self.output_sizes = [] for config in self.dist_config: dist_type = config['type'] assert dist_type in self.SUPPORTED_TYPES, f"{dist_type} is not supported!" assert not dist_type == 'mix', "recursive MixDistribution is not supported!" self.wrapper_list.append(DistributionWrapper(dist_type, **config)) self.input_sizes.append(config['dim']) self.output_sizes.append(config['output_dim']) self.wrapper_list = nn.ModuleList(self.wrapper_list)
[docs] def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, payload : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution: ''' Warp the given tensor to distribution :param adapt_std : it will overwrite the std part of the distribution (optional) :param payload : payload will be applied to the output distribution after built (optional) ''' # [ OTHER ] replace the clip with tanh # [ OTHER ] better std controlling strategy is needed todo if self.distribution_type == 'normal': if self.params.get('conditioned_std', True): mu, logstd = torch.chunk(x, 2, dim=-1) if 'soft_clamp' in kwargs and kwargs['soft_clamp'] == True: # distribution wrapper only for revive_f venv training logstd = soft_clamp(logstd, self.min_logstd, self.max_logstd) std = torch.exp(logstd) else: # distribution wrapper for others logstd_logit = logstd max_std = 0.5 min_std = 0.001 std = (torch.tanh(logstd_logit) + 1) / 2 * (max_std - min_std) + min_std else: mu, logstd = x, self.logstd logstd_logit = self.logstd max_std = 0.5 min_std = 0.001 std = (torch.tanh(logstd_logit) + 1) / 2 * (max_std - min_std) + min_std # std = torch.exp(logstd) if payload is not None: mu = mu + safe_atanh(payload) mu = torch.tanh(mu) if adapt_std is not None: std = torch.ones_like(mu).to(mu) * adapt_std return DiagnalNormal(mu, std) elif self.distribution_type == 'gmm': if self.params.get('conditioned_std', True): logits, mus, logstds = torch.split(x, [self.params['mixture'], self.params['mixture'] * self.params['dim'], self.params['mixture'] * self.params['dim']], dim=-1) mus = mus.view(*mus.shape[:-1], self.params['mixture'], self.params['dim']) logstds = logstds.view(*logstds.shape[:-1], self.params['mixture'], self.params['dim']) else: logits, mus = torch.split(x, [self.params['mixture'], self.params['mixture'] * self.params['dim']], dim=-1) logstds = self.logstd if payload is not None: mus = mus + safe_atanh(payload.unsqueeze(dim=-2)) mus = torch.tanh(mus) stds = adapt_std if adapt_std is not None else torch.exp(soft_clamp(logstds, self.min_logstd, self.max_logstd)) return GaussianMixture(mus, stds, logits) elif self.distribution_type == 'onehot': return Onehot(x) elif self.distribution_type == 'discrete_logistic': mu, logstd = torch.chunk(x, 2, dim=-1) logstd = torch.clamp(logstd, -7, 1) return DiscreteLogistic(mu, torch.exp(logstd), num=self.num) elif self.distribution_type == 'mix': xs = torch.split(x, self.output_sizes, dim=-1) if isinstance(adapt_std, float): # print("look !!!") # print(self.input_sizes) # print(len(self.input_sizes)) adapt_stds = [adapt_std] * len(self.input_sizes) elif adapt_std is not None: adapt_stds = torch.split(adapt_std, self.input_sizes, dim=-1) else: adapt_stds = [None] * len(self.input_sizes) if payload is not None: payloads = torch.split(payload, self.input_sizes + [payload.shape[-1] - sum(self.input_sizes)], dim=-1)[:-1] else: payloads = [None] * len(self.input_sizes) dists = [wrapper(x, _adapt_std, _payload, **kwargs) for x, _adapt_std, _payload, wrapper in zip(xs, adapt_stds, payloads, self.wrapper_list)] return MixDistribution(dists)
[docs]     def extra_repr(self) -> str:
        return 'type={}, dim={}'.format(
            self.distribution_type,
            self.params['dim'] if not self.distribution_type == 'mix' else len(self.wrapper_list)
        )
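A minimal usage sketch (illustrative only): a 'normal' wrapper with the default conditioned_std expects 2*dim features (mean plus std logits); .sample() and .mode are provided by the distribution classes from revive.computation.dists, which the get_action helpers below already rely on.

import torch
from revive.computation.modules import DistributionWrapper

wrapper = DistributionWrapper('normal', dim=2)
dist = wrapper(torch.randn(32, 4))   # DiagnalNormal with tanh-squashed mean
a = dist.sample()                    # (32, 2); dist.mode gives the mean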
# --------------------------------- Policies -------------------------------- #
[docs] class FeedForwardPolicy(torch.nn.Module): """Policy for using mlp, resnet and transformer backbone. Args: in_features : The number of input features, or a dictionary describing the input distribution out_features : The number of output features hidden_features : The number of hidden features in each layer hidden_layers : The number of hidden layers dist_config : A list of configurations for the distributions to use in the model norm : The type of normalization to apply to the input features hidden_activation : The activation function to use in the hidden layers backbone_type : The type of backbone to use in the model use_multihead : Whether to use a multihead model use_feature_embed : Whether to use feature embedding """ def __init__(self, in_features : Union[int, dict], out_features : int, hidden_features : int, hidden_layers : int, dist_config : list, norm : str = None, hidden_activation : str = 'leakyrelu', backbone_type : Union[str, np.str_] = 'mlp', use_multihead : bool = False, **kwargs): super().__init__() self.multihead = [] self.dist_config = dist_config self.kwargs = kwargs if isinstance(in_features, dict): in_features = sum(in_features.values()) if not use_multihead: if backbone_type == 'mlp': self.backbone = MLP(in_features, out_features, hidden_features, hidden_layers, norm, hidden_activation) elif backbone_type == 'res': self.backbone = ResNet(in_features, out_features, hidden_features, hidden_layers, norm) elif backbone_type == 'ft_transformer': self.backbone = FT_Transformer(in_features, out_features, hidden_features, hidden_layers=hidden_layers) else: raise NotImplementedError(f'backbone type {backbone_type} is not supported') else: if backbone_type == 'mlp': self.backbone = MLP(in_features, hidden_features, hidden_features, hidden_layers, norm, hidden_activation) elif backbone_type == 'res': self.backbone = ResNet(in_features, hidden_features, hidden_features, hidden_layers, norm) elif backbone_type == 'ft_transformer': self.backbone = FT_Transformer(in_features, out_features, hidden_features, hidden_layers=hidden_layers) else: raise NotImplementedError(f'backbone type {backbone_type} is not supported') for dist in self.dist_config: dist_type = dist["type"] output_dim = dist["output_dim"] if dist_type == "onehot" or dist_type == "discrete_logistic": self.multihead.append(MLP(hidden_features, output_dim, 64, 1, norm=None, hidden_activation='leakyrelu')) elif dist_type == "normal": normal_dim = dist["dim"] for dim in range(normal_dim): self.multihead.append(MLP(hidden_features, int(output_dim // normal_dim), 64, 1, norm=None, hidden_activation='leakyrelu')) else: raise NotImplementedError(f'Dist type {dist_type} is not supported in multihead.') self.multihead = nn.ModuleList(self.multihead) self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config) if "ts_conv_config" in self.kwargs and self.kwargs['ts_conv_config'] != None: other_state_layer_output = in_features//2 ts_state_conv_layer_output = in_features//2 + in_features%2 self.other_state_layer = MLP(self.kwargs['ts_conv_net_config']['other_net_input'], other_state_layer_output, 0, 0, output_activation=hidden_activation) all_node_ts = None self.conv_ts_node = [] self.conv_other_node = deepcopy(self.kwargs['ts_conv_config']['no_ts_input']) for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys(): self.conv_ts_node.append('__history__'+tsnode) self.conv_other_node.append('__now__'+tsnode) node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts'] if all_node_ts == None: 
all_node_ts = node_ts assert all_node_ts==node_ts, f"expect ts step == {all_node_ts}. However got {tsnode} with ts step {node_ts}" #ts-1 remove now state kernalsize = all_node_ts-1 self.ts_state_conv_layer = ConvBlock(self.kwargs['ts_conv_net_config']['conv_net_input'], ts_state_conv_layer_output, kernalsize, output_activation=hidden_activation)
[docs] def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution: if "ts_conv_config" in self.kwargs and self.kwargs['ts_conv_config'] != None: #deal with all ts nodes with ts dim # assert 'input_names' in kwargs assert hasattr(self, 'conv_ts_node') assert hasattr(self, 'conv_other_node') inputdata = deepcopy(state) #data is deteched from other node for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys(): # assert self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['endpoint'] == True batch_size = int(np.prod(list(inputdata[tsnode].shape[:-1]))) original_size = list(inputdata[tsnode].shape[:-1]) node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts'] # inputdata[tsnode] temp_data = inputdata[tsnode].reshape(batch_size, node_ts, -1) inputdata['__history__'+tsnode] = temp_data[..., :-1, :] inputdata['__now__'+tsnode] = temp_data[..., -1, :].reshape([*original_size,-1]) state_other = torch.cat([inputdata[key] for key in self.conv_other_node], dim=-1) state_ts_conv = torch.cat([inputdata[key] for key in self.conv_ts_node], dim=-1) state_other = self.other_state_layer(state_other) original_size = list(state_other.shape[:-1]) state_ts_conv = self.ts_state_conv_layer(state_ts_conv).reshape([*original_size,-1]) state = torch.cat([state_other, state_ts_conv], dim=-1) elif isinstance(state, dict): assert 'input_names' in kwargs state = torch.cat([state[key] for key in kwargs['input_names']], dim=-1) if not self.multihead: output = self.backbone(state) else: backbone_output = self.backbone(state) multihead_output = [] multihead_index = 0 for dist in self.dist_config: dist_type = dist["type"] output_dim = dist["output_dim"] if dist_type == "onehot" or dist_type == "discrete_logistic": multihead_output.append(self.multihead[multihead_index](backbone_output)) multihead_index += 1 elif dist_type == "normal": normal_mode_output = [] normal_std_output = [] for head in self.multihead[multihead_index:]: head_output = head(backbone_output) if head_output.shape[-1] == 1: mode = head_output normal_mode_output.append(mode) else: mode, std = torch.chunk(head_output, 2, axis=-1) normal_mode_output.append(mode) normal_std_output.append(std) normal_output = torch.cat(normal_mode_output, axis=-1) if normal_std_output: normal_std_output = torch.cat(normal_std_output, axis=-1) normal_output = torch.cat([normal_output, normal_std_output], axis=-1) multihead_output.append(normal_output) break else: raise NotImplementedError(f'Dist type {dist_type} is not supported in multihead.') output = torch.cat(multihead_output, axis= -1) soft_clamp_flag = True if "soft_clamp" in self.kwargs and self.kwargs['soft_clamp'] == True else False dist = self.dist_wrapper(output, adapt_std, soft_clamp=soft_clamp_flag, **kwargs) if hasattr(self, "dist_mu_shift"): dist = dist.shift(self.dist_mu_shift) return dist
[docs]     def reset(self):
        pass

[docs]     @torch.no_grad()
    def get_action(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], deterministic : bool = True):
        dist = self(state)
        return dist.mode if deterministic else dist.sample()
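A minimal usage sketch (illustrative only), using the dist_config format seen throughout this file (here a single 'normal' head of dim 2, so the backbone outputs 2*2 features):

import torch
from revive.computation.modules import FeedForwardPolicy

dist_config = [{'type': 'normal', 'dim': 2, 'output_dim': 4}]
policy = FeedForwardPolicy(in_features=3, out_features=4, hidden_features=64,
                           hidden_layers=2, dist_config=dist_config, backbone_type='mlp')
obs = torch.randn(8, 3)
dist = policy(obs)                 # mixed distribution over the configured heads
action = policy.get_action(obs)    # deterministic action of the 2-dim normal head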
[docs] class RecurrentPolicy(torch.nn.Module): """Initializes a recurrent policy network instance. Args: in_features (int): The number of input features. out_features (int): The number of output features. hidden_features (int): The number of hidden features in the RNN. hidden_layers (int): The number of layers in the RNN. dist_config (list): The configuration for the distributions used in the model. backbone_type (Union[str, np.str_], optional): The type of RNN to use ('gru' or 'lstm'). Default is 'gru'. """ def __init__(self, in_features : int, out_features : int, hidden_features : int, hidden_layers : int, norm : str, dist_config : list, backbone_type : Union[str, np.str_] ='gru', **kwargs): super().__init__() RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM rnn_hidden_features = 256 rnn_hidden_layers = min(hidden_layers, 3) rnn_output_features = min(out_features,8) self.rnn = RNN(in_features, rnn_hidden_features, rnn_hidden_layers) self.rnn_mlp = MLP(rnn_hidden_features, rnn_output_features, 0, 0) self.backbone = ResNet(in_features+rnn_output_features, out_features, hidden_features, hidden_layers, norm) self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
[docs]     def reset(self):
        self.h = None

[docs]     def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
        x_shape = x.shape
        if len(x_shape) == 1:
            rnn_output, self.h = self.rnn(x.unsqueeze(0).unsqueeze(0), self.h)
            rnn_output = rnn_output.squeeze(0).squeeze(0)
        elif len(x_shape) == 2:
            rnn_output, self.h = self.rnn(x.unsqueeze(0), self.h)
            rnn_output = rnn_output.squeeze(0)
        else:
            rnn_output, self.h = self.rnn(x, self.h)
        rnn_output = self.rnn_mlp(rnn_output)
        logits = self.backbone(torch.cat([x, rnn_output], axis=-1))
        return self.dist_wrapper(logits, adapt_std, **kwargs)
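A minimal rollout sketch (illustrative only): the policy is stateful, so reset() must be called before each new trajectory to (re)initialize the hidden state self.h.

import torch
from revive.computation.modules import RecurrentPolicy

dist_config = [{'type': 'normal', 'dim': 2, 'output_dim': 4}]
policy = RecurrentPolicy(in_features=5, out_features=4, hidden_features=64,
                         hidden_layers=2, norm='ln', dist_config=dist_config)
policy.reset()
for t in range(10):
    dist = policy(torch.randn(8, 5))   # one step for a batch of 8 trajectories
    action = dist.sample()             # the GRU state is carried between calls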
[docs] class TsRecurrentPolicy(torch.nn.Module): input_is_dict = True """Initializes a recurrent policy network instance for ts_node. Args: in_features (int): The number of input features. out_features (int): The number of output features. hidden_features (int): The number of hidden features in the RNN. hidden_layers (int): The number of layers in the RNN. dist_config (list): The configuration for the distributions used in the model. backbone_type (Union[str, np.str_], optional): The type of RNN to use ('gru' or 'lstm'). Default is 'gru'. """ def __init__(self, in_features : int, out_features : int, hidden_features : int, hidden_layers : int, norm : str, dist_config : list, backbone_type : Union[str, np.str_] ='gru', ts_input_names : dict = {}, bidirectional: bool = True, dropout: float = 0, **kwargs): super().__init__() RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM self.ts_input_names = ts_input_names self.input_dim_dict = kwargs.get("input_dim_dict",{}) self.max_ts_steps = max(self.ts_input_names.values()) self.rnn = RNN(input_size=in_features, hidden_size=hidden_features, num_layers=hidden_layers, bidirectional=bidirectional, batch_first=True, dropout=dropout if hidden_layers > 1 else 0) if bidirectional: self.linear = nn.Linear(hidden_features * 2, out_features) else: self.linear = nn.Linear(hidden_features, out_features) self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
[docs]     def reset(self):
        self.h = None
[docs] def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution: x_shape = x[list(x.keys())[0]].shape if len(x_shape) == 3: x = {k:v.reshape(-1,v.shape[-1]) for k,v in x.items()} input = {} for k,v in x.items(): ts_steps = self.ts_input_names.get(k,1) input[k] = v.reshape(v.shape[0],ts_steps,-1) if ts_steps < self.max_ts_steps: input[k] = torch.cat([input[k][:,:1].repeat(1, self.max_ts_steps-ts_steps, 1), input[k]],axis=1) if len(input) == 1: rnn_input = list(input.values())[0] else: rnn_input = torch.cat(list(input.values()),axis=-1) out, _ = self.rnn(rnn_input,None) out = self.linear(out[:, -1, :]) if len(x_shape) == 3: logits = out.reshape(x_shape[0],x_shape[1],-1) else: logits = out.reshape(x_shape[0],-1) return self.dist_wrapper(logits, adapt_std, **kwargs)
[docs] class TsTransformerPolicy(torch.nn.Module): input_is_dict = True """Initializes a recurrent policy network instance for ts_node. Args: in_features (int): The number of input features. out_features (int): The number of output features. hidden_features (int): The number of hidden features in the RNN. hidden_layers (int): The number of layers in the RNN. dist_config (list): The configuration for the distributions used in the model. backbone_type (Union[str, np.str_], optional): The type of RNN to use ('gru' or 'lstm'). Default is 'gru'. """ def __init__(self, in_features : int, out_features : int, hidden_features : int, hidden_layers : int, norm : str, dist_config : list, ts_input_names : dict = {}, dropout: float = 0, **kwargs): super().__init__() self.ts_input_names = ts_input_names self.input_dim_dict = kwargs.get("input_dim_dict",{}) self.max_ts_steps = max(self.ts_input_names.values()) self.backbone = TsTransformer(in_features=in_features, out_features=out_features, hidden_layers=hidden_layers, hidden_features=hidden_features, dropout=dropout) self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
[docs]     def reset(self):
        self.h = None
[docs] def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution: x_shape = x[list(x.keys())[0]].shape if len(x_shape) == 3: x = {k:v.reshape(-1,v.shape[-1]) for k,v in x.items()} input = {} for k,v in x.items(): ts_steps = self.ts_input_names.get(k,1) input[k] = v.reshape(v.shape[0],ts_steps,-1) if ts_steps < self.max_ts_steps: input[k] = torch.cat([input[k][:,:1].repeat(1, self.max_ts_steps-ts_steps, 1), input[k]],axis=1) if len(input) == 1: transformer_input = list(input.values())[0] else: transformer_input = torch.cat(list(input.values()),axis=-1) out = self.backbone(transformer_input) if len(x_shape) == 3: logits = out.reshape(x_shape[0],x_shape[1],-1) else: logits = out.reshape(x_shape[0],-1) return self.dist_wrapper(logits, adapt_std, **kwargs)
[docs] class RecurrentRESPolicy(torch.nn.Module): def __init__(self, in_features : int, out_features : int, hidden_features : int, hidden_layers : int, dist_config : list, backbone_type : Union[str, np.str_] ='gru', rnn_hidden_features : int = 64, window_size : int = 0, **kwargs): super().__init__() self.kwargs = kwargs RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM self.in_feature_embed = nn.Sequential( nn.Linear(in_features, hidden_features), ACTIVATION_CREATORS['relu'](hidden_features), nn.Dropout(0.1), nn.LayerNorm(hidden_features) ) self.side_net = nn.Sequential( ResBlock(hidden_features, hidden_features), nn.Dropout(0.1), nn.LayerNorm(hidden_features), ResBlock(hidden_features, hidden_features), nn.Dropout(0.1), nn.LayerNorm(hidden_features) ) self.rnn = RNN(hidden_features, rnn_hidden_features, hidden_layers) self.backbone = MLP(hidden_features + rnn_hidden_features, out_features, hidden_features, 1) self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config) self.hidden_layers = hidden_layers self.hidden_features = hidden_features self.rnn_hidden_features = rnn_hidden_features self.window_size = window_size self.hidden_que = deque(maxlen=self.window_size) # [(h_0, c_0), (h_1, c_1), ...] self.h = None
[docs]     def reset(self):
        self.hidden_que = deque(maxlen=self.window_size)
        self.h = None
[docs] def preprocess_window(self, a_embed): shape = a_embed.shape if len(self.hidden_que) < self.window_size: # do Not pop if self.hidden_que: h = self.hidden_que[0] else: h = (torch.zeros((self.hidden_layers, shape[1], self.rnn_hidden_features)).to(a_embed), torch.zeros((self.hidden_layers, shape[1], self.rnn_hidden_features)).to(a_embed)) else: h = self.hidden_que.popleft() hidden_cat = torch.concat([h[0]] + [hidden[0] for hidden in self.hidden_que] + [torch.zeros_like(h[0]).to(h[0])], dim=1) # (layers, bs * 4, dim) cell_cat = torch.concat([h[1]] + [hidden[1] for hidden in self.hidden_que] + [torch.zeros_like(h[1]).to(h[1])], dim=1) # (layers, bs * 4, dim) a_embed = torch.repeat_interleave(a_embed, repeats=len(self.hidden_que)+2, dim=0).view(1, (len(self.hidden_que)+2)*shape[1], -1) # (1, bs * 4, dim) return a_embed, hidden_cat, cell_cat
[docs] def postprocess_window(self, rnn_output, hidden_cat, cell_cat): hidden_cat = torch.chunk(hidden_cat, chunks=len(self.hidden_que)+2, dim=1) # tuples of (layers, bs, dim) cell_cat = torch.chunk(cell_cat, chunks=len(self.hidden_que)+2, dim=1) # tuples of (layers, bs, dim) rnn_output = torch.chunk(rnn_output, chunks=len(self.hidden_que)+2, dim=1) # tuples of (1, bs, dim) self.hidden_que = deque(zip(hidden_cat[1:], cell_cat[1:]), maxlen=self.window_size) # important to discrad the first element !!! rnn_output = rnn_output[0] # (1, bs, dim) return rnn_output
[docs] def rnn_forward(self, a_embed, joint_train : bool): if self.window_size > 0: a_embed, hidden_cat, cell_cat = self.preprocess_window(a_embed) if joint_train: rnn_output, (hidden_cat, cell_cat) = self.rnn(a_embed, (hidden_cat, cell_cat)) else: with torch.no_grad(): rnn_output, (hidden_cat, cell_cat) = self.rnn(a_embed, (hidden_cat, cell_cat)) rnn_output = self.postprocess_window(rnn_output, hidden_cat, cell_cat) #(1, bs, dim) else: if joint_train: rnn_output, self.h = self.rnn(a_embed, self.h) #(1, bs, dim) else: with torch.no_grad(): rnn_output, self.h = self.rnn(a_embed, self.h) return rnn_output
[docs] def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, field: str = 'mail', **kwargs) -> ReviveDistribution: if field == 'bc': shape = x.shape joint_train = True if len(shape) == 2: # (bs, dim) x_embed = self.in_feature_embed(x) side_output = self.side_net(x_embed) a_embed = x_embed.unsqueeze(0) # [1, bs, dim] rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim] rnn_output = rnn_output.squeeze(0) # (bs, dim) logits = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (bs, dim) else: assert len(shape) == 3, f"expect len(x.shape) == 3. However got x.shape {shape}" self.reset() output = [] for i in range(shape[0]): a = x[i] a = a.unsqueeze(0) #(1, bs, dim) a_embed = self.in_feature_embed(a) # (1, bs, dim) side_output = self.side_net(a_embed) rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim] backbone_output = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (1, bs, dim) output.append(backbone_output) logits = torch.concat(output, dim=0) elif field == 'mail': shape = x.shape joint_train = False if len(shape) == 2: x_embed = self.in_feature_embed(x) side_output = self.side_net(x_embed) a_embed = x_embed.unsqueeze(0) rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim] rnn_output = rnn_output.squeeze(0) # (bs, dim) logits = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (bs, dim) else: assert len(shape) == 3, f"expect len(x.shape) == 3. However got x.shape {shape}" self.reset() output = [] for i in range(shape[0]): a = x[i] a = a.unsqueeze(0) #(1, bs, dim) a_embed = self.in_feature_embed(a) # (1, bs, dim) side_output = self.side_net(a_embed) rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim] backbone_output = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (1, bs, dim) output.append(backbone_output) logits = torch.concat(output, dim=0) else: raise NotImplementedError(f"unknow field: {field} in RNN training !") return self.dist_wrapper(logits, adapt_std, **kwargs)
[docs] class ContextualPolicy(torch.nn.Module): def __init__(self, in_features : int, out_features : int, hidden_features : int, hidden_layers : int, dist_config : list, backbone_type : Union[str, np.str_] ='contextual_gru', **kwargs): super().__init__() self.kwargs = kwargs RNN = torch.nn.GRU if backbone_type == 'contextual_gru' else torch.nn.LSTM self.preprocess_mlp = MLP(in_features, hidden_features, 0, 0, output_activation='leakyrelu') self.rnn = RNN(hidden_features, hidden_features, 1) self.mlp = MLP(hidden_features + in_features, out_features, hidden_features, hidden_layers) self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
[docs]     def reset(self):
        self.h = None
[docs] def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution: in_feature = x if len(x.shape) == 2: x = x.unsqueeze(0) x = self.preprocess_mlp(x) rnn_output, self.h = self.rnn(x, self.h) rnn_output = rnn_output.squeeze(0) else: x = self.preprocess_mlp(x) rnn_output, self.h = self.rnn(x) logits = self.mlp(torch.cat((in_feature, rnn_output), dim=-1)) return self.dist_wrapper(logits, adapt_std, **kwargs)
[docs] class EnsembleFeedForwardPolicy(torch.nn.Module): """Policy for using mlp, resnet and transformer backbone. Args: in_features : The number of input features, or a dictionary describing the input distribution out_features : The number of output features hidden_features : The number of hidden features in each layer hidden_layers : The number of hidden layers dist_config : A list of configurations for the distributions to use in the model norm : The type of normalization to apply to the input features hidden_activation : The activation function to use in the hidden layers backbone_type : The type of backbone to use in the model use_multihead : Whether to use a multihead model use_feature_embed : Whether to use feature embedding """ def __init__(self, in_features : Union[int, dict], out_features : int, hidden_features : int, hidden_layers : int, ensemble_size : int, dist_config : list, norm : str = None, hidden_activation : str = 'leakyrelu', backbone_type : Union[str, np.str_] = 'EnsembleRES', use_multihead : bool = False, **kwargs): super().__init__() self.multihead = [] self.ensemble_size = ensemble_size self.dist_config = dist_config self.kwargs = kwargs if isinstance(in_features, dict): in_features = sum(in_features.values()) if backbone_type == 'EnsembleMLP': self.backbone = VectorizedMLP(in_features, out_features, hidden_features, hidden_layers, ensemble_size, norm, hidden_activation) elif backbone_type == 'EnsembleRES': self.backbone = VectorizedResNet(in_features, out_features, hidden_features, hidden_layers, ensemble_size, norm) else: raise NotImplementedError(f'backbone type {backbone_type} is not supported') self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
[docs] def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
    if isinstance(state, dict):
        assert 'input_names' in kwargs
        state = torch.cat([state[key] for key in kwargs['input_names']], dim=-1)
    output = self.backbone(state)
    dist = self.dist_wrapper(output, adapt_std, **kwargs)
    if hasattr(self, "dist_mu_shift"):
        dist = dist.shift(self.dist_mu_shift)
    return dist
[docs] def reset(self): pass
[docs] @torch.no_grad()
def get_action(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], deterministic : bool = True):
    dist = self(state)
    return dist.mode if deterministic else dist.sample()
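When the policy receives a dictionary of node tensors, `forward` concatenates them along the last dimension in the order given by `kwargs['input_names']`. A minimal sketch of that preprocessing step in plain PyTorch (the node names and sizes here are made up for illustration):

import torch

state = {
    'obs':    torch.randn(32, 11),
    'action': torch.randn(32, 3),
}
input_names = ['obs', 'action']              # order matters: it fixes the feature layout
flat = torch.cat([state[key] for key in input_names], dim=-1)
assert flat.shape == (32, 14)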
# ------------------------------- Transitions ------------------------------- #
[docs] class FeedForwardTransition(FeedForwardPolicy):
    r"""Initializes a feedforward transition instance.

    Args:
        in_features (Union[int, dict]): The number of input features or a dictionary of input feature sizes.
        out_features (int): The number of output features.
        hidden_features (int): The number of hidden features.
        hidden_layers (int): The number of hidden layers.
        dist_config (list): The configuration of the distribution to use.
        norm (Optional[str]): The normalization method to use. None if no normalization is used.
        hidden_activation (str): The activation function to use for the hidden layers.
        backbone_type (Union[str, np.str_]): The type of backbone to use.
        mode (str): The mode to use for the transition. Either 'global' or 'local'.
        obs_dim (Optional[int]): The dimension of the observation for 'local' mode.
        use_feature_embed (bool): Whether to use feature embedding.
    """

    def __init__(self,
                 in_features : Union[int, dict],
                 out_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 dist_config : list,
                 norm : Optional[str] = None,
                 hidden_activation : str = 'leakyrelu',
                 backbone_type : Union[str, np.str_] = 'mlp',
                 mode : str = 'global',
                 obs_dim : Optional[int] = None,
                 **kwargs):
        self.mode = mode
        self.obs_dim = obs_dim

        if self.mode == 'local':
            dist_types = [config['type'] for config in dist_config]
            if 'onehot' in dist_types or 'discrete_logistic' in dist_types:
                warnings.warn('Detected distribution types that are not compatible with the local mode, falling back to global mode!')
                self.mode = 'global'

        if self.mode == 'local':
            assert self.obs_dim is not None, \
                "For local mode, the dim of observation should be given!"

        super(FeedForwardTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers,
                                                    dist_config, norm, hidden_activation, backbone_type, **kwargs)
[docs] def reset(self): pass
[docs] def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
    dist = super(FeedForwardTransition, self).forward(state, adapt_std, **kwargs)
    if self.mode == 'local' and self.obs_dim is not None:
        dist = dist.shift(state[..., :self.obs_dim])
    return dist
[docs] class RecurrentTransition(RecurrentPolicy):
    r"""Initializes a recurrent transition instance.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        hidden_features (int): The number of hidden features.
        hidden_layers (int): The number of hidden layers.
        dist_config (list): The distribution configuration.
        backbone_type (Union[str, np.str_]): The type of backbone to use.
        mode (str): The mode of the transition. Either 'global' or 'local'.
        obs_dim (Optional[int]): The dimension of the observation.
    """

    def __init__(self,
                 in_features : int,
                 out_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 norm : str,
                 dist_config : list,
                 backbone_type : Union[str, np.str_] = 'gru',
                 mode : str = 'global',
                 obs_dim : Optional[int] = None,
                 **kwargs):
        self.mode = mode
        self.obs_dim = obs_dim

        if self.mode == 'local':
            assert not 'onehot' in [config['type'] for config in dist_config], \
                "The local mode of transition is not compatible with onehot data! Please fall back to global mode!"

        if self.mode == 'local':
            assert self.obs_dim is not None, \
                "For local mode, the dim of observation should be given!"

        super(RecurrentTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers,
                                                  norm, dist_config, backbone_type, **kwargs)
[docs] def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
    dist = super(RecurrentTransition, self).forward(state, adapt_std, **kwargs)
    if self.mode == 'local' and self.obs_dim is not None:
        dist = dist.shift(state[..., :self.obs_dim])
    return dist
[docs] class TsRecurrentTransition(TsRecurrentPolicy):
    r"""Initializes a recurrent transition instance.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        hidden_features (int): The number of hidden features.
        hidden_layers (int): The number of hidden layers.
        dist_config (list): The distribution configuration.
        backbone_type (Union[str, np.str_]): The type of backbone to use.
        mode (str): The mode of the transition. Either 'global' or 'local'.
        obs_dim (Optional[int]): The dimension of the observation.
    """

    input_is_dict = True

    def __init__(self,
                 in_features : int,
                 out_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 norm : str,
                 dist_config : list,
                 backbone_type : Union[str, np.str_] = 'gru',
                 mode : str = 'global',
                 obs_dim : Optional[int] = None,
                 **kwargs):
        self.mode = mode
        self.obs_dim = obs_dim

        if self.mode == 'local':
            assert not 'onehot' in [config['type'] for config in dist_config], \
                "The local mode of transition is not compatible with onehot data! Please fall back to global mode!"

        if self.mode == 'local':
            assert self.obs_dim is not None, \
                "For local mode, the dim of observation should be given!"

        super(TsRecurrentTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers,
                                                    norm, dist_config, backbone_type, **kwargs)
[docs] def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
    dist = super(TsRecurrentTransition, self).forward(state, adapt_std, **kwargs)
    if self.mode == 'local' and self.obs_dim is not None:
        dist = dist.shift(state[..., :self.obs_dim])
    return dist
[docs] class TsTransformerTransition(TsTransformerPolicy):
    r"""Initializes a recurrent transition instance.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        hidden_features (int): The number of hidden features.
        hidden_layers (int): The number of hidden layers.
        dist_config (list): The distribution configuration.
        mode (str): The mode of the transition. Either 'global' or 'local'.
        obs_dim (Optional[int]): The dimension of the observation.
    """

    input_is_dict = True

    def __init__(self,
                 in_features : int,
                 out_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 norm : str,
                 dist_config : list,
                 mode : str = 'global',
                 obs_dim : Optional[int] = None,
                 ts_input_names : dict = {},
                 dropout : float = 0,
                 **kwargs):
        self.mode = mode
        self.obs_dim = obs_dim

        if self.mode == 'local':
            assert not 'onehot' in [config['type'] for config in dist_config], \
                "The local mode of transition is not compatible with onehot data! Please fall back to global mode!"

        if self.mode == 'local':
            assert self.obs_dim is not None, \
                "For local mode, the dim of observation should be given!"

        super(TsTransformerTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers,
                                                      norm, dist_config, ts_input_names, dropout, **kwargs)
[docs] def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
    dist = super(TsTransformerTransition, self).forward(state, adapt_std, **kwargs)
    if self.mode == 'local' and self.obs_dim is not None:
        dist = dist.shift(state[..., :self.obs_dim])
    return dist
[docs] class RecurrentRESTransition(RecurrentRESPolicy):
    def __init__(self,
                 in_features : int,
                 out_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 dist_config : list,
                 backbone_type : Union[str, np.str_] = 'mlp',
                 mode : str = 'global',
                 obs_dim : Optional[int] = None,
                 rnn_hidden_features : int = 64,
                 window_size : int = 0,
                 **kwargs):
        self.mode = mode
        self.obs_dim = obs_dim

        if self.mode == 'local':
            assert not 'onehot' in [config['type'] for config in dist_config], \
                "The local mode of transition is not compatible with onehot data! Please fall back to global mode!"

        if self.mode == 'local':
            assert self.obs_dim is not None, \
                "For local mode, the dim of observation should be given!"

        super(RecurrentRESTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers,
                                                     dist_config, backbone_type,
                                                     rnn_hidden_features=rnn_hidden_features,
                                                     window_size=window_size, **kwargs)
[docs] def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, field : str = 'mail', **kwargs) -> ReviveDistribution:
    dist = super(RecurrentRESTransition, self).forward(state, adapt_std, field, **kwargs)
    if self.mode == 'local' and self.obs_dim is not None:
        dist = dist.shift(state[..., :self.obs_dim])
    return dist
[docs] class EnsembleFeedForwardTransition(EnsembleFeedForwardPolicy):
    r"""Initializes an ensemble feedforward transition instance.

    Args:
        in_features (Union[int, dict]): The number of input features or a dictionary of input feature sizes.
        out_features (int): The number of output features.
        hidden_features (int): The number of hidden features.
        hidden_layers (int): The number of hidden layers.
        ensemble_size (int): The number of ensemble members.
        dist_config (list): The configuration of the distribution to use.
        norm (Optional[str]): The normalization method to use. None if no normalization is used.
        hidden_activation (str): The activation function to use for the hidden layers.
        backbone_type (Union[str, np.str_]): The type of backbone to use.
        mode (str): The mode to use for the transition. Either 'global' or 'local'.
        obs_dim (Optional[int]): The dimension of the observation for 'local' mode.
        use_feature_embed (bool): Whether to use feature embedding.
    """

    def __init__(self,
                 in_features : Union[int, dict],
                 out_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 ensemble_size : int,
                 dist_config : list,
                 norm : Optional[str] = None,
                 hidden_activation : str = 'leakyrelu',
                 backbone_type : Union[str, np.str_] = 'EnsembleRES',
                 mode : str = 'global',
                 obs_dim : Optional[int] = None,
                 **kwargs):
        self.mode = mode
        self.obs_dim = obs_dim

        if self.mode == 'local':
            dist_types = [config['type'] for config in dist_config]
            if 'onehot' in dist_types or 'discrete_logistic' in dist_types:
                warnings.warn('Detected distribution types that are not compatible with the local mode, falling back to global mode!')
                self.mode = 'global'

        if self.mode == 'local':
            assert self.obs_dim is not None, \
                "For local mode, the dim of observation should be given!"

        super(EnsembleFeedForwardTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers,
                                                            ensemble_size, dist_config, norm, hidden_activation,
                                                            backbone_type, **kwargs)
[docs] def reset(self): pass
[docs] def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
    dist = super(EnsembleFeedForwardTransition, self).forward(state, adapt_std, **kwargs)
    if self.mode == 'local' and self.obs_dim is not None:
        dist = dist.shift(state[..., :self.obs_dim])
    return dist
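In 'local' mode the network predicts a delta over the current observation, and the resulting distribution is shifted by `state[..., :obs_dim]` so samples live in absolute observation space. A minimal sketch of that shift with a plain `torch.distributions.Normal` standing in for the REVIVE distribution's `shift` method (sizes are illustrative):

import torch
from torch.distributions import Normal

obs_dim = 3
state = torch.randn(32, 5)                  # first obs_dim entries are the current observation
delta_mean = torch.zeros(32, obs_dim)       # what the backbone would predict in local mode
delta_std = torch.ones(32, obs_dim)

local_dist = Normal(delta_mean, delta_std)
shifted = Normal(local_dist.mean + state[..., :obs_dim], local_dist.stddev)  # analogue of dist.shift(...)
next_obs_sample = shifted.sample()          # absolute next observation, not a delta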
# ------------------------------- Matchers ------------------------------- #
[docs] class FeedForwardMatcher(torch.nn.Module):
    r"""Initializes a feedforward matcher instance.

    Args:
        in_features (int): The number of input features.
        hidden_features (int): The number of hidden features.
        hidden_layers (int): The number of hidden layers.
        hidden_activation (str): The activation function to use for the hidden layers.
        norm (str): The normalization method to use. None if no normalization is used.
        backbone_type (Union[str, np.str_]): The type of backbone to use.
    """

    def __init__(self,
                 in_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 hidden_activation : str = 'leakyrelu',
                 norm : str = None,
                 backbone_type : Union[str, np.str_] = 'mlp',
                 **kwargs):
        super().__init__()

        if backbone_type == 'mlp':
            self.backbone = MLP(in_features, 1, hidden_features, hidden_layers, norm, hidden_activation,
                                output_activation='sigmoid')
        elif backbone_type == 'res':
            self.backbone = ResNet(in_features, 1, hidden_features, hidden_layers, norm,
                                   output_activation='sigmoid')
        elif backbone_type == 'ft_transformer':
            self.backbone = FT_Transformer(in_features, 1, hidden_features, hidden_layers=hidden_layers)
        else:
            raise NotImplementedError(f'backbone type {backbone_type} is not supported')
[docs] def forward(self, x : torch.Tensor) -> torch.Tensor:
    # x.requires_grad = True
    shape = x.shape
    if len(shape) == 3:                   # [horizon, batch_size, dim]
        x = x.view(-1, shape[-1])         # [horizon * batch_size, dim]
    assert len(x.shape) == 2, f"len(x.shape) == {len(x.shape)}, expect 2"
    out = self.backbone(x)                # [batch_size, dim]
    if len(shape) == 3:
        out = out.view(*shape[:-1], -1)   # [horizon, batch_size, dim]
    return out
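The matcher backbone only handles 2-D inputs, so trajectory batches are folded into the batch dimension and unfolded again afterwards. A minimal sketch of that fold/unfold pattern in plain PyTorch (the matrix product stands in for `self.backbone`):

import torch

horizon, batch_size, dim = 10, 32, 7
x = torch.randn(horizon, batch_size, dim)

flat = x.view(-1, dim)                                 # [horizon * batch_size, dim]
scores = torch.sigmoid(flat @ torch.randn(dim, 1))     # stand-in for self.backbone(flat)
scores = scores.view(horizon, batch_size, -1)          # back to [horizon, batch_size, 1]
assert scores.shape == (horizon, batch_size, 1)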
[docs] class RecurrentMatcher(torch.nn.Module):
    def __init__(self,
                 in_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 backbone_type : Union[str, np.str_] = 'gru',
                 bidirect : bool = False):
        super().__init__()
        RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM
        self.rnn = RNN(in_features, hidden_features, hidden_layers, bidirectional=bidirect)
        self.output_layer = MLP(hidden_features * (2 if bidirect else 1), 1, 0, 0, output_activation='sigmoid')
[docs] def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor:
    x = torch.cat(inputs, dim=-1)
    rnn_output = self.rnn(x)[0]
    return self.output_layer(rnn_output)
[docs] class HierarchicalMatcher(torch.nn.Module):
    def __init__(self,
                 in_features : list,
                 hidden_features : int,
                 hidden_layers : int,
                 hidden_activation : str,
                 norm : str = None):
        super().__init__()
        self.in_features = in_features
        process_layers = []
        output_layers = []
        feature = self.in_features[0]
        for i in range(1, len(self.in_features)):
            feature += self.in_features[i]
            process_layers.append(MLP(feature, hidden_features, hidden_features, hidden_layers,
                                      norm, hidden_activation, hidden_activation))
            output_layers.append(torch.nn.Linear(hidden_features, 1))
            feature = hidden_features
        self.process_layers = torch.nn.ModuleList(process_layers)
        self.output_layers = torch.nn.ModuleList(output_layers)
[docs] def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor:
    assert len(inputs) == len(self.in_features)
    last_feature = inputs[0]
    result = 0
    for i, process_layer, output_layer in zip(range(1, len(self.in_features)), self.process_layers, self.output_layers):
        last_feature = torch.cat([last_feature, inputs[i]], dim=-1)
        last_feature = process_layer(last_feature)
        result += output_layer(last_feature)
    return torch.sigmoid(result)
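Each stage concatenates the running feature with the next input, maps it through an MLP, and contributes a scalar logit; the logits are summed before the final sigmoid. A minimal sketch of the same cascade with plain `torch.nn.Linear` layers standing in for the MLP blocks (sizes are made up):

import torch

in_features = [4, 3, 2]                     # three matching inputs of different widths
hidden = 16
process = torch.nn.ModuleList([
    torch.nn.Linear(in_features[0] + in_features[1], hidden),   # stage 1: 4 + 3
    torch.nn.Linear(hidden + in_features[2], hidden),           # stage 2: 16 + 2
])
heads = torch.nn.ModuleList([torch.nn.Linear(hidden, 1) for _ in range(2)])

inputs = [torch.randn(32, d) for d in in_features]
last, logit = inputs[0], 0
for i, (p, h) in enumerate(zip(process, heads), start=1):
    last = torch.relu(p(torch.cat([last, inputs[i]], dim=-1)))
    logit = logit + h(last)
score = torch.sigmoid(logit)                # [32, 1], one matching score per sample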
[docs] class VectorizedMatcher(torch.nn.Module):
    r"""Initializes a vectorized (ensemble) matcher instance.

    Args:
        in_features (int): The number of input features.
        hidden_features (int): The number of hidden features.
        hidden_layers (int): The number of hidden layers.
        ensemble_size (int): The number of ensemble members.
        hidden_activation (str): The activation function to use for the hidden layers.
        norm (str): The normalization method to use. None if no normalization is used.
        backbone_type (Union[str, np.str_]): The type of backbone to use.
    """

    def __init__(self,
                 in_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 ensemble_size : int,
                 hidden_activation : str = 'leakyrelu',
                 norm : str = None,
                 backbone_type : Union[str, np.str_] = 'mlp',
                 **kwargs):
        super().__init__()
        self.backbone_type = backbone_type
        self.ensemble_size = ensemble_size

        if backbone_type == 'mlp':
            self.backbone = VectorizedMLP(in_features, 1, hidden_features, hidden_layers, ensemble_size,
                                          norm, hidden_activation, output_activation='sigmoid')
        elif backbone_type == 'res':
            self.backbone = VectorizedResNet(in_features, 1, hidden_features, hidden_layers, ensemble_size,
                                             norm, output_activation='sigmoid')
        else:
            raise NotImplementedError(f'backbone type {backbone_type} is not supported')
[docs] def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor:
    x = torch.cat(inputs, dim=-1).detach()
    # x.requires_grad = True
    shape = x.shape
    if len(shape) == 2:  # [batch_size, dim]
        x = x.unsqueeze(0).repeat_interleave(self.ensemble_size, dim=0)  # -> [ensemble_size, batch_size, dim]
    assert x.shape[0] == self.ensemble_size
    out = self.backbone(x)  # [ensemble_size, batch_size, dim]
    return out
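Every ensemble member scores the same batch, so a 2-D input is replicated along a new leading ensemble dimension before the vectorized backbone is applied. A minimal sketch of that replication step:

import torch

ensemble_size, batch_size, dim = 5, 32, 7
x = torch.randn(batch_size, dim)

x_ens = x.unsqueeze(0).repeat_interleave(ensemble_size, dim=0)   # [ensemble_size, batch_size, dim]
assert x_ens.shape == (ensemble_size, batch_size, dim)
assert torch.equal(x_ens[0], x_ens[-1])    # every member receives an identical copy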
[docs] class EnsembleMatcher:
    def __init__(self,
                 matcher_record: Dict[str, deque],
                 ensemble_size : int = 50,
                 ensemble_choosing_interval: int = 1,
                 config : dict = None) -> None:
        self.matcher_weights_list = []
        self.single_structure_dict = config
        self.matching_nodes = config['matching_nodes']
        self.matching_fit_nodes = config['matching_fit_nodes']
        # self.matching_nodes_fit_dict = config['matching_nodes_fit_index']
        self.device = torch.device("cpu")
        self.min_generate_score = 0.0
        self.max_expert_score = 1.0

        self.selected_ind = []
        counter = 0
        for i in reversed(range(len(matcher_record['state_dicts']))):
            if counter % ensemble_choosing_interval == 0:
                matcher_weight = matcher_record['state_dicts'][i]
                self.selected_ind.append(i)
                self.matcher_weights_list.append(matcher_weight)
                if len(self.matcher_weights_list) >= ensemble_size:
                    break
            counter += 1

        self.ensemble_size = len(self.matcher_weights_list)
        self.matcher_ensemble = VectorizedMatcher(ensemble_size=self.ensemble_size, **config)
        self.init_ensemble_network()
        self.load_min_max(matcher_record['expert_scores'], matcher_record['generated_scores'])
[docs] def load_min_max(self, expert_scores: deque, generate_scores: deque):
    self.expert_scores = torch.tensor(expert_scores, dtype=torch.float32, device=self.device)[self.selected_ind]      # [ensemble_size, ]
    self.generate_scores = torch.tensor(generate_scores, dtype=torch.float32, device=self.device)[self.selected_ind]  # [ensemble_size, ]
    self.max_expert_score = max(self.expert_scores)
    self.min_generate_score = min(self.generate_scores)
    self.range = max(abs(self.max_expert_score), abs(self.min_generate_score))
[docs] def init_ensemble_network(self,):
    if self.matcher_ensemble.backbone_type == 'mlp':
        with torch.no_grad():
            for ind, state_dict in enumerate(self.matcher_weights_list):
                for key, val in state_dict.items():
                    layer_ind = int(key.split('.')[2])
                    if key.endswith("weight"):
                        self.matcher_ensemble.backbone.net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
                    elif key.endswith("bias"):
                        self.matcher_ensemble.backbone.net[layer_ind].bias.data[ind, 0, ...].copy_(val)
                    else:
                        raise NotImplementedError("Only expect network params with weight and bias")
    elif self.matcher_ensemble.backbone_type == 'res':
        with torch.no_grad():
            for ind, state_dict in enumerate(self.matcher_weights_list):
                for key, val in state_dict.items():
                    key_ls = key.split('.')
                    if "skip_net" in key_ls:
                        block_ind = int(key_ls[2])
                        if key.endswith("weight"):
                            self.matcher_ensemble.backbone.resnet[block_ind].skip_net.weight.data[ind, ...].copy_(val.transpose(1, 0))
                        elif key.endswith("bias"):
                            self.matcher_ensemble.backbone.resnet[block_ind].skip_net.bias.data[ind, 0, ...].copy_(val)
                        else:
                            raise NotImplementedError("Only expect network params with weight and bias")
                    elif "process_net" in key_ls:
                        block_ind, layer_ind = int(key_ls[2]), int(key_ls[4])
                        if key.endswith("weight"):
                            self.matcher_ensemble.backbone.resnet[block_ind].process_net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
                        elif key.endswith("bias"):
                            self.matcher_ensemble.backbone.resnet[block_ind].process_net[layer_ind].bias.data[ind, 0, ...].copy_(val)
                        else:
                            raise NotImplementedError("Only expect network params with weight and bias")
                    else:  # linear layer at last
                        block_ind = int(key_ls[2])
                        if key.endswith("weight"):
                            self.matcher_ensemble.backbone.resnet[block_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
                        elif key.endswith("bias"):
                            self.matcher_ensemble.backbone.resnet[block_ind].bias.data[ind, 0, ...].copy_(val)
                        else:
                            raise NotImplementedError("Only expect network params with weight and bias")
    else:
        raise NotImplementedError
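The key detail above is the layout difference: `torch.nn.Linear` stores weights as `(out_features, in_features)`, while the vectorized layers store them as `(ensemble_size, in_features, out_features)` with a broadcastable bias, hence the `transpose(1, 0)` on every copy. A minimal sketch of loading one member's weights, using plain tensors for illustration rather than the actual vectorized modules:

import torch

ensemble_size, in_f, out_f = 3, 5, 4
vec_weight = torch.zeros(ensemble_size, in_f, out_f)   # vectorized layout
vec_bias = torch.zeros(ensemble_size, 1, out_f)

member = torch.nn.Linear(in_f, out_f)                  # weight: (out_f, in_f), bias: (out_f,)
ind = 0
with torch.no_grad():
    vec_weight[ind, ...].copy_(member.weight.data.transpose(1, 0))
    vec_bias[ind, 0, ...].copy_(member.bias.data)

# the vectorized matmul now reproduces member(x) for slice `ind`
x = torch.randn(8, in_f)
assert torch.allclose(x @ vec_weight[ind] + vec_bias[ind], member(x), atol=1e-6)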
[docs] def run_raw_scores(self, inputs: torch.Tensor, aggregation=None) -> torch.Tensor:
    with torch.no_grad():
        scores = self.matcher_ensemble(inputs)                # [ensemble_size, batch_size, 1] or [ensemble_size, horizon, batch_size, 1]
        scores = scores.squeeze(-1).transpose(1, 0).detach()  # [batch_size, ensemble_size]
        if aggregation == "mean":
            return scores.mean(-1)
        return scores
[docs] def run_scores(self, inputs: torch.Tensor, aggregation: str = None, clip: bool = True) -> torch.Tensor:
    with torch.no_grad():
        scores = self.matcher_ensemble(inputs)                # [ensemble_size, batch_size, 1]
        scores = scores.squeeze(-1).transpose(1, 0).detach()  # [batch_size, ensemble_size]
        if clip:
            assert self.expert_scores is not None and self.generate_scores is not None, \
                "if clip==True, you should call load_min_max() beforehand"
            scores = torch.clamp(scores, min=self.generate_scores, max=self.expert_scores)
            scores = (scores - self.generate_scores) / (self.expert_scores - self.generate_scores + 1e-7) * self.range
        if aggregation == "mean":
            return scores.mean(-1)
        return scores
[docs] def to(self, device):
    self.device = device
    self.matcher_ensemble = self.matcher_ensemble.to(device)
    self.expert_scores = self.expert_scores.to(device)
    self.generate_scores = self.generate_scores.to(device)
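`run_scores` rescales each member's raw output into a common range: scores are clamped per member between that member's recorded generated score and expert score, then min-max normalized and scaled by `self.range`. A minimal sketch of the per-member normalization with made-up numbers:

import torch

# per-member reference scores recorded during training (values are illustrative)
generate_scores = torch.tensor([0.10, 0.20])   # lower anchors, one per ensemble member
expert_scores = torch.tensor([0.90, 0.80])     # upper anchors
rng = max(abs(expert_scores.max()), abs(generate_scores.min()))

raw = torch.tensor([[0.05, 0.95],              # [batch_size, ensemble_size]
                    [0.50, 0.50]])
clipped = torch.clamp(raw, min=generate_scores, max=expert_scores)
normed = (clipped - generate_scores) / (expert_scores - generate_scores + 1e-7) * rng
# row 0 hits the two anchors -> approximately [0.0, rng]; row 1 lands mid-range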
# class SequentialMatcher: # def __init__(self, # matcher_record: Dict[str, deque], # ensemble_size : int = 50, # ensemble_choosing_interval: int = 1, # config : dict = None) -> None: # self.matcher_list = [] # self.single_structure_dict = config # self.matching_nodes = config['matching_nodes'] # self.matching_fit_nodes = config['matching_fit_nodes'] # # self.matching_nodes_fit_dict = config['matching_nodes_fit_index'] # self.device = torch.device("cpu") # self.min_generate_score = 0.0 # self.max_expert_score = 1.0 # self.selected_ind = [] # counter = 0 # for i in reversed(range(len(matcher_record['state_dicts']))): # if counter % ensemble_choosing_interval == 0: # matcher_weight = matcher_record['state_dicts'][i] # self.selected_ind.append(i) # _matcher = FeedForwardMatcher(**config) # _matcher.load_state_dict(matcher_weight) # self.matcher_list.append(_matcher) # if len(self.matcher_list) >= ensemble_size: # break # counter += 1 # self.ensemble_size = len(self.matcher_list) # self.load_min_max(matcher_record['expert_scores'], matcher_record['generated_scores']) # # prepare for vmap func # self.params, self.buffers = stack_module_state(self.matcher_list) # """ # Construct a "stateless" version of one of the models. It is "stateless" in # the sense that the parameters are meta Tensors and do not have storage. # """ # self.base_model = deepcopy(self.matcher_list[0]) # self.base_model = self.base_model.to('meta') # def vmap_forward(self, x): # def vmap_model(params, buffers, x): # return functional_call(self.base_model, (params, buffers), (x,)) # return vmap(vmap_model, in_dims=(0, 0, None))(self.params, self.buffers, x) # def load_min_max(self, expert_scores: deque, generate_scores: deque): # self.expert_scores = torch.tensor(expert_scores, dtype=torch.float32, device=self.device)[self.selected_ind] # [ensemble_size, ] # self.generate_scores = torch.tensor(generate_scores, dtype=torch.float32, device=self.device)[self.selected_ind] # [ensemble_size, ] # self.max_expert_score = max(self.expert_scores) # self.min_generate_score = min(self.generate_scores) # self.range = max(abs(self.max_expert_score), abs(self.min_generate_score)) # def run_raw_scores(self, inputs: torch.Tensor, aggregation=None) -> torch.Tensor: # with torch.no_grad(): # scores = self.vmap_forward(inputs) # [ensemble_size, batch_size, 1], [ensemble_size, horizon, batch_size, 1] # scores = scores.squeeze(-1).transpose(1, 0).detach() # [batch_size, ensemble_size] # if aggregation == "mean": # return scores.mean(-1) # return scores # def run_scores(self, inputs: torch.Tensor, aggregation: str=None, clip: bool=True) -> torch.Tensor: # with torch.no_grad(): # scores = self.vmap_forward(inputs) # [ensemble_size, batch_size, 1] # scores = scores.squeeze(-1).transpose(1, 0).detach() # [batch_size, ensemble_size] # if clip: # assert self.expert_scores is not None and self.generate_scores is not None, "if clip==True, you should call load_min_max() beforehand" # scores = torch.clamp(scores, min=self.generate_scores, max=self.expert_scores) # scores = (scores - self.generate_scores) / (self.expert_scores - self.generate_scores + 1e-7) * self.range # if aggregation == "mean": # return scores.mean(-1) # return scores # def to(self, device): # self.device = device # self.matcher_list = [_matcher.to(device) for _matcher in self.matcher_list] # self.expert_scores = self.expert_scores.to(device) # self.generate_scores = self.generate_scores.to(device) # # prepare for vmap func # self.params, self.buffers = 
#         stack_module_state(self.matcher_list)

# ------------------------------- Others ------------------------------- #
[docs] class VectorizedCritic(VectorizedMLP):
[docs] def forward(self, x : torch.Tensor) -> torch.Tensor:
    x = super(VectorizedCritic, self).forward(x)
    return x.squeeze(-1)
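A minimal usage sketch, assuming the constructor signature matches the VectorizedMLP call used in the tests below (in_features, out_features, hidden_features, hidden_layers, ensemble_size) and that a 3-D input yields one output per ensemble member; the sizes are made up:

import torch
from revive.computation.modules import VectorizedCritic

ensemble_size, batch_size, obs_dim = 4, 32, 11
critic = VectorizedCritic(obs_dim, 1, hidden_features=64, hidden_layers=2, ensemble_size=ensemble_size)

obs = torch.randn(ensemble_size, batch_size, obs_dim)
q_values = critic(obs)                      # squeeze(-1) drops the trailing singleton dim
assert q_values.shape == (ensemble_size, batch_size)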
# ------------------------------- Test ------------------------------- #
if __name__ == "__main__":

    def test_ensemble_mlp():
        rnd = torch.rand((6, 5))
        ensemble_size = 3
        ensemble_linear = VectorizedMLP(5, 4, hidden_features=16, hidden_layers=2, ensemble_size=ensemble_size)
        linear1 = MLP(5, 4, hidden_features=16, hidden_layers=2)
        linear2 = MLP(5, 4, hidden_features=16, hidden_layers=2)
        linear3 = MLP(5, 4, hidden_features=16, hidden_layers=2)

        state_dicts = deque([linear1.state_dict(), linear2.state_dict(), linear3.state_dict()])
        for ind, state_dict in enumerate(state_dicts):
            for key, val in state_dict.items():
                layer_ind = int(key.split('.')[1])
                if key.endswith("weight"):
                    ensemble_linear.net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
                elif key.endswith("bias"):
                    ensemble_linear.net[layer_ind].bias.data[ind, 0, ...].copy_(val)
                else:
                    raise NotImplementedError("Only expect network params with weight and bias")

        ensemble_out = ensemble_linear(rnd)
        linear1_out = linear1(rnd)
        linear2_out = linear2(rnd)
        linear3_out = linear3(rnd)
        # breakpoint()
        print((ensemble_out[0, ...] - linear1_out).abs().max())
        print((ensemble_out[1, ...] - linear2_out).abs().max())
        print((ensemble_out[2, ...] - linear3_out).abs().max())

    def test_ensemble_res_distinct():
        in_fea = 256
        out_fea = 15
        hidden_features = 256
        hidden_layers = 4
        batch_size = 512
        ensemble_size = 3
        rnd1 = torch.rand((batch_size, in_fea))
        rnd2 = torch.rand((batch_size, in_fea))
        rnd3 = torch.rand((batch_size, in_fea))
        rnd = torch.stack([rnd1, rnd2, rnd3])
        ensemble_linear = VectorizedResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers,
                                           ensemble_size=ensemble_size, norm=None)
        linear1 = ResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, norm=None)
        linear2 = ResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, norm=None)
        linear3 = ResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, norm=None)

        state_dicts = deque([linear1.state_dict(), linear2.state_dict(), linear3.state_dict()])
        for ind, state_dict in enumerate(state_dicts):
            for key, val in state_dict.items():
                key_ls = key.split('.')
                if "skip_net" in key_ls:
                    block_ind = int(key_ls[1])
                    if key.endswith("weight"):
                        ensemble_linear.resnet[block_ind].skip_net.weight.data[ind, ...].copy_(val.transpose(1, 0))
                    elif key.endswith("bias"):
                        ensemble_linear.resnet[block_ind].skip_net.bias.data[ind, 0, ...].copy_(val)
                    else:
                        raise NotImplementedError("Only expect network params with weight and bias")
                elif "process_net" in key_ls:
                    block_ind, layer_ind = int(key_ls[1]), int(key_ls[3])
                    if key.endswith("weight"):
                        ensemble_linear.resnet[block_ind].process_net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
                    elif key.endswith("bias"):
                        ensemble_linear.resnet[block_ind].process_net[layer_ind].bias.data[ind, 0, ...].copy_(val)
                    else:
                        raise NotImplementedError("Only expect network params with weight and bias")
                else:  # linear layer at last
                    block_ind = int(key_ls[1])
                    if key.endswith("weight"):
                        ensemble_linear.resnet[block_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
                    elif key.endswith("bias"):
                        ensemble_linear.resnet[block_ind].bias.data[ind, 0, ...].copy_(val)
                    else:
                        raise NotImplementedError("Only expect network params with weight and bias")

        ensemble_out = ensemble_linear(rnd)
        linear1_out = linear1(rnd1)
        linear2_out = linear2(rnd2)
        linear3_out = linear3(rnd3)
        # breakpoint()
        print((ensemble_out[0, ...] - linear1_out).abs().max())
        print((ensemble_out[1, ...] - linear2_out).abs().max())
        print((ensemble_out[2, ...] - linear3_out).abs().max())
[docs] class Value_Net_VectorizedCritic(nn.Module):
    r"""A vectorized (ensemble) Q-value network.

    Args:
        input_dim_dict: The input dimension per node, or a single input dimension.
        q_hidden_features (int): The number of hidden features.
        q_hidden_layers (int): The number of hidden layers.
        num_q_net (int): The number of Q networks in the ensemble.
    """

    def __init__(self, input_dim_dict, q_hidden_features, q_hidden_layers, num_q_net, *args, **kwargs) -> None:
        super().__init__()
        self.kwargs = kwargs
        self.input_dim_dict = input_dim_dict
        if isinstance(input_dim_dict, dict):
            input_dim = sum(input_dim_dict.values())
        else:
            input_dim = input_dim_dict

        self.value = VectorizedMLP(input_dim, 1, q_hidden_features, q_hidden_layers, num_q_net)

        if 'ts_conv_config' in self.kwargs.keys() and self.kwargs['ts_conv_config'] != None:
            other_state_layer_output = q_hidden_features // 2
            ts_state_conv_layer_output = q_hidden_features // 2 + q_hidden_features % 2
            self.other_state_layer = MLP(self.kwargs['ts_conv_net_config']['other_net_input'],
                                         other_state_layer_output, 0, 0)

            all_node_ts = None
            self.conv_ts_node = []
            self.conv_other_node = deepcopy(self.kwargs['ts_conv_config']['no_ts_input'])
            for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
                self.conv_ts_node.append('__history__' + tsnode)
                self.conv_other_node.append('__now__' + tsnode)
                node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
                if all_node_ts == None:
                    all_node_ts = node_ts
                assert all_node_ts == node_ts, f"expect ts step == {all_node_ts}. However got {tsnode} with ts step {node_ts}"

            # ts-1: remove the current-step state from the convolution window
            kernalsize = all_node_ts - 1
            self.ts_state_conv_layer = ConvBlock(self.kwargs['ts_conv_net_config']['conv_net_input'],
                                                 ts_state_conv_layer_output, kernalsize)
[docs] def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
    if hasattr(self, 'conv_ts_node'):
        # deal with all ts nodes that carry a time-step dimension
        assert hasattr(self, 'conv_ts_node')
        assert hasattr(self, 'conv_other_node')
        inputdata = deepcopy(state.detach())  # data is detached from the other nodes
        for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
            assert self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['endpoint'] == True
            batch_size = np.prod(list(inputdata[tsnode].shape[:-1]))
            original_size = list(inputdata[tsnode].shape[:-1])
            node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
            # inputdata[tsnode]
            temp_data = inputdata[tsnode].reshape(batch_size, node_ts, -1)
            inputdata['__history__' + tsnode] = temp_data[..., :-1, :]
            inputdata['__now__' + tsnode] = temp_data[..., -1, :].reshape([*original_size, -1])

        state_other = torch.cat([inputdata[key] for key in self.conv_other_node], dim=-1)
        state_ts_conv = torch.cat([inputdata[key] for key in self.conv_ts_node], dim=-1)

        state_other = self.other_state_layer(state_other)
        original_size = list(state_other.shape[:-1])
        state_ts_conv = self.ts_state_conv_layer(state_ts_conv).reshape([*original_size, -1])

        state = torch.cat([state_other, state_ts_conv], dim=-1)
    else:
        pass
        # state = torch.cat([state[key].detach() for key in self.input_dim_dict.keys()], dim=-1)

    output = self.value(state)
    output = output.squeeze(-1)
    return output
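The time-series branch splits each flattened node into its past steps (fed to the convolutional path) and its current step (fed through the plain MLP path). A minimal sketch of that reshape-and-split with made-up sizes:

import torch

batch_size, node_ts, feat = 32, 5, 6
flat_node = torch.randn(batch_size, node_ts * feat)       # node stored with ts folded into features

temp = flat_node.reshape(batch_size, node_ts, -1)          # (batch, ts, feat)
history = temp[..., :-1, :]                                # (batch, ts - 1, feat) -> conv branch
now = temp[..., -1, :].reshape(batch_size, -1)             # (batch, feat)         -> MLP branch
assert history.shape == (batch_size, node_ts - 1, feat)
assert now.shape == (batch_size, feat)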