Source code for revive.computation.modules

"""
    POLIXIR REVIVE, copyright (C) 2021-2025 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""
from loguru import logger
import time
import torch
import math
import warnings
import torch.nn.functional as F
from copy import deepcopy
from torch import nn#, vmap
# from torch.func import stack_module_state, functional_call
from typing import Optional, Union, List, Dict, cast
from collections import OrderedDict, deque

from revive.computation.dists import *
from revive.computation.utils import *


def reglu(x):
    a, b = x.chunk(2, dim=-1)
    return a * F.relu(b)


def geglu(x):
    a, b = x.chunk(2, dim=-1)
    return a * F.gelu(b)


class Swish(nn.Module):
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)


ACTIVATION_CREATORS = {
    'relu' : lambda dim: nn.ReLU(inplace=True),
    'elu' : lambda dim: nn.ELU(),
    'leakyrelu' : lambda dim: nn.LeakyReLU(negative_slope=0.1, inplace=True),
    'tanh' : lambda dim: nn.Tanh(),
    'sigmoid' : lambda dim: nn.Sigmoid(),
    'identity' : lambda dim: nn.Identity(),
    'prelu' : lambda dim: nn.PReLU(dim),
    'gelu' : lambda dim: nn.GELU(),
    'reglu' : reglu,
    'geglu' : geglu,
    'swish' : lambda dim: Swish(),
}

# -------------------------------- Layers -------------------------------- #
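# Usage sketch (illustrative helper added for documentation; not part of the original module).
# Most entries in ACTIVATION_CREATORS above are factories that take the feature dim and
# return an nn.Module, while 'reglu' / 'geglu' are plain functions that halve the last dim.
def _example_activation_creators():
    act = ACTIVATION_CREATORS['leakyrelu'](64)   # factory -> nn.Module (dim ignored here)
    y = act(torch.randn(8, 64))                  # shape [8, 64]
    z = reglu(torch.randn(8, 64))                # gated activation: shape [8, 32]
    return y.shape, z.shape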
class Value_Net(nn.Module):
    r""" Value network mapping (a dict of) state tensors to a scalar value.

    Args:
        input_dim_dict: a dict mapping input node names to their feature dims, or a single int.
        value_hidden_features (int): The number of hidden features of the value MLP.
        value_hidden_layers (int): The number of hidden layers of the value MLP.
        value_normalization (str): The normalization method used between hidden layers.
        value_activation (str): The activation function used in hidden layers.
    """
    def __init__(self,
                 input_dim_dict,
                 value_hidden_features,
                 value_hidden_layers,
                 value_normalization,
                 value_activation,
                 *args, **kwargs) -> None:
        super().__init__()
        self.kwargs = kwargs
        self.input_dim_dict = input_dim_dict

        if isinstance(input_dim_dict, dict):
            input_dim = sum(input_dim_dict.values())
        else:
            input_dim = input_dim_dict

        self.value = MLP(input_dim, 1, value_hidden_features, value_hidden_layers,
                         norm=value_normalization, hidden_activation=value_activation)

        if 'ts_conv_config' in self.kwargs.keys() and self.kwargs['ts_conv_config'] != None:
            other_state_layer_output = input_dim // 2
            ts_state_conv_layer_output = input_dim // 2 + input_dim % 2
            self.other_state_layer = MLP(self.kwargs['ts_conv_net_config']['other_net_input'],
                                         other_state_layer_output, 0, 0,
                                         output_activation=value_activation)

            all_node_ts = None
            self.conv_ts_node = []
            self.conv_other_node = deepcopy(self.kwargs['ts_conv_config']['no_ts_input'])
            for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
                self.conv_ts_node.append('__history__' + tsnode)
                self.conv_other_node.append('__now__' + tsnode)
                node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
                if all_node_ts == None:
                    all_node_ts = node_ts
                assert all_node_ts == node_ts, f"expect ts step == {all_node_ts}. However got {tsnode} with ts step {node_ts}"

            # ts-1: the current-step state is removed from the convolution window
            kernalsize = all_node_ts - 1
            self.ts_state_conv_layer = ConvBlock(self.kwargs['ts_conv_net_config']['conv_net_input'],
                                                 ts_state_conv_layer_output, kernalsize,
                                                 output_activation=value_activation)

    def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
        if hasattr(self, 'conv_ts_node'):
            # deal with all ts nodes that carry a time-step dim
            assert hasattr(self, 'conv_ts_node')
            assert hasattr(self, 'conv_other_node')
            inputdata = deepcopy(state.detach())  # data is detached from the other nodes
            for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
                # assert self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['endpoint'] == True
                batch_size = np.prod(list(inputdata[tsnode].shape[:-1]))
                original_size = list(inputdata[tsnode].shape[:-1])
                node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
                temp_data = inputdata[tsnode].reshape(batch_size, node_ts, -1)
                inputdata['__history__' + tsnode] = temp_data[..., :-1, :]
                inputdata['__now__' + tsnode] = temp_data[..., -1, :].reshape([*original_size, -1])

            state_other = torch.cat([inputdata[key] for key in self.conv_other_node], dim=-1)
            state_ts_conv = torch.cat([inputdata[key] for key in self.conv_ts_node], dim=-1)

            state_other = self.other_state_layer(state_other)
            original_size = list(state_other.shape[:-1])
            state_ts_conv = self.ts_state_conv_layer(state_ts_conv).reshape([*original_size, -1])

            state = torch.cat([state_other, state_ts_conv], dim=-1)
        else:
            state = torch.cat([state[key].detach() for key in self.input_dim_dict.keys()], dim=-1)

        output = self.value(state)
        return output
class ConvBlock(nn.Module):
    r""" Initializes a 1D-convolution block applied over the time dimension.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        conv_time_step (int): The kernel size of the convolution along the time axis.
        output_activation (str, optional): The activation applied after the linear projection. Default is 'identity'.
    """
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 conv_time_step: int,
                 output_activation : str = 'identity'):
        super().__init__()
        output_activation_creator = ACTIVATION_CREATORS[output_activation]
        self.mlp_net = nn.Sequential(
            nn.Linear(in_features, out_features),
            output_activation_creator(out_features)
        )
        self.conv_net = nn.Conv1d(in_channels=out_features,
                                  out_channels=out_features,
                                  groups=out_features,
                                  kernel_size=conv_time_step)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mlp_net(x)
        # exchange the time-step dim with the feature dim
        x = x.transpose(-1, -2)
        # apply the grouped convolution and drop the collapsed time dim
        out = self.conv_net(x).squeeze(-1)
        return out
class VectorizedLinear(nn.Module):
    r""" Initializes a vectorized linear layer instance.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        ensemble_size (int): The number of ensembles to use.
    """
    def __init__(self, in_features: int, out_features: int, ensemble_size: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features))
        self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))

        self.reset_parameters()

    def reset_parameters(self):
        # default pytorch init for nn.Linear module
        for layer in range(self.ensemble_size):
            nn.init.kaiming_uniform_(self.weight[layer], a=math.sqrt(5))

        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        input: [ensemble_size, batch_size, input_size]
        weight: [ensemble_size, input_size, out_size]
        out: [ensemble_size, batch_size, out_size]
        """
        out = x @ self.weight + self.bias
        return out
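# Shape sketch (illustrative helper added for documentation; not part of the original module).
# Each ensemble member applies its own weight matrix to its own batch slice.
def _example_vectorized_linear():
    layer = VectorizedLinear(in_features=16, out_features=8, ensemble_size=4)
    x = torch.randn(4, 32, 16)   # [ensemble_size, batch_size, in_features]
    out = layer(x)               # -> [4, 32, 8]
    return out.shape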
class EnsembleLinear(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_ensemble: int,
        weight_decay: float = 0.0
    ) -> None:
        super().__init__()

        self.num_ensemble = num_ensemble

        self.register_parameter("weight", nn.Parameter(torch.zeros(num_ensemble, input_dim, output_dim)))
        self.register_parameter("bias", nn.Parameter(torch.zeros(num_ensemble, 1, output_dim)))

        nn.init.trunc_normal_(self.weight, std=1 / (2 * input_dim ** 0.5))

        self.register_parameter("saved_weight", nn.Parameter(self.weight.detach().clone()))
        self.register_parameter("saved_bias", nn.Parameter(self.bias.detach().clone()))

        self.weight_decay = weight_decay
        self.device = torch.device('cpu')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight
        bias = self.bias

        if len(x.shape) == 2:
            x = torch.einsum('ij,bjk->bik', x, weight)
        elif len(x.shape) == 3:
            if x.shape[0] == weight.data.shape[0]:
                x = torch.einsum('bij,bjk->bik', x, weight)
            else:
                x = torch.einsum('cij,bjk->bcik', x, weight)
        elif len(x.shape) == 4:
            if x.shape[0] == weight.data.shape[0]:
                x = torch.einsum('cbij,cjk->cbik', x, weight)
            else:
                x = torch.einsum('cdij,bjk->bcdik', x, weight)
        elif len(x.shape) == 5:
            x = torch.einsum('bcdij,bjk->bcdik', x, weight)

        assert x.shape[0] == bias.shape[0] and x.shape[-1] == bias.shape[-1]

        if len(x.shape) == 4:
            bias = bias.unsqueeze(1)
        elif len(x.shape) == 5:
            bias = bias.unsqueeze(1)
            bias = bias.unsqueeze(1)

        x = x + bias
        return x

    def load_save(self) -> None:
        self.weight.data.copy_(self.saved_weight.data)
        self.bias.data.copy_(self.saved_bias.data)

    def update_save(self, indexes: List[int]) -> None:
        self.saved_weight.data[indexes] = self.weight.data[indexes]
        self.saved_bias.data[indexes] = self.bias.data[indexes]

    def get_decay_loss(self) -> torch.Tensor:
        decay_loss = self.weight_decay * (0.5 * ((self.weight ** 2).sum()))
        return decay_loss

    def to(self, device):
        if not device == self.device:
            self.device = device
            super().to(device)
            self.weight = self.weight.to(self.device)
            self.bias = self.bias.to(self.device)
            self.saved_weight = self.saved_weight.to(self.device)
            self.saved_bias = self.saved_bias.to(self.device)
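# Usage sketch (illustrative helper added for documentation; not part of the original module).
# `update_save` snapshots selected ensemble members (e.g. the current elites), and
# `load_save` rolls every member back to its last snapshot.
def _example_ensemble_linear():
    layer = EnsembleLinear(input_dim=16, output_dim=8, num_ensemble=5, weight_decay=1e-4)
    out = layer(torch.randn(32, 16))   # 2D input is broadcast to every member -> [5, 32, 8]
    layer.update_save([0, 2])          # snapshot members 0 and 2
    layer.load_save()                  # restore all members from their snapshots
    loss = layer.get_decay_loss()      # optional L2 penalty scaled by weight_decay
    return out.shape, loss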
# -------------------------------- Backbones -------------------------------- #
class MLP(nn.Module):
    r""" Multi-layer Perceptron

    Args:
        in_features : int, number of input features
        out_features : int, number of output features
        hidden_features : int, number of features in each hidden layer
        hidden_layers : int, number of hidden layers
        norm : str, normalization method between hidden layers, default : None
        hidden_activation : str, activation function used in hidden layers, default : 'leakyrelu'
        output_activation : str, activation function used in output layer, default : 'identity'
    """
    def __init__(self,
                 in_features : int,
                 out_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 norm : str = None,
                 hidden_activation : str = 'leakyrelu',
                 output_activation : str = 'identity'):
        super(MLP, self).__init__()

        hidden_activation_creator = ACTIVATION_CREATORS[hidden_activation]
        output_activation_creator = ACTIVATION_CREATORS[output_activation]

        if hidden_layers == 0:
            self.net = nn.Sequential(
                nn.Linear(in_features, out_features),
                output_activation_creator(out_features)
            )
        else:
            net = []
            for i in range(hidden_layers):
                net.append(nn.Linear(in_features if i == 0 else hidden_features, hidden_features))
                if norm:
                    if norm == 'ln':
                        net.append(nn.LayerNorm(hidden_features))
                    elif norm == 'bn':
                        net.append(nn.BatchNorm1d(hidden_features))
                    else:
                        raise NotImplementedError(f'{norm} is not supported!')
                net.append(hidden_activation_creator(hidden_features))
            net.append(nn.Linear(hidden_features, out_features))
            net.append(output_activation_creator(out_features))
            self.net = nn.Sequential(*net)

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        r""" Forward method of MLP; only assumes the last dim of x matches `in_features`. """
        return self.net(x)
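# Usage sketch (illustrative helper added for documentation; not part of the original module).
def _example_mlp():
    net = MLP(in_features=10, out_features=2, hidden_features=64, hidden_layers=2,
              norm='ln', hidden_activation='leakyrelu')
    y = net(torch.randn(32, 10))   # -> [32, 2]
    return y.shape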
class VectorizedMLP(nn.Module):
    r""" Vectorized MLP

    Args:
        in_features : int, number of input features
        out_features : int, number of output features
        hidden_features : int, number of features in each hidden layer
        hidden_layers : int, number of hidden layers
        ensemble_size : int, number of ensemble members
        norm : str, normalization method between hidden layers, default : None
        hidden_activation : str, activation function used in hidden layers, default : 'leakyrelu'
        output_activation : str, activation function used in output layer, default : 'identity'
    """
    def __init__(self,
                 in_features : int,
                 out_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 ensemble_size : int,
                 norm : str = None,
                 hidden_activation : str = 'leakyrelu',
                 output_activation : str = 'identity'):
        super(VectorizedMLP, self).__init__()
        self.ensemble_size = ensemble_size

        hidden_activation_creator = ACTIVATION_CREATORS[hidden_activation]
        output_activation_creator = ACTIVATION_CREATORS[output_activation]

        if hidden_layers == 0:
            self.net = nn.Sequential(
                VectorizedLinear(in_features, out_features, ensemble_size),
                output_activation_creator(out_features)
            )
        else:
            net = []
            for i in range(hidden_layers):
                net.append(VectorizedLinear(in_features if i == 0 else hidden_features, hidden_features, ensemble_size))
                if norm:
                    if norm == 'ln':
                        net.append(nn.LayerNorm(hidden_features))
                    elif norm == 'bn':
                        net.append(nn.BatchNorm1d(hidden_features))
                    else:
                        raise NotImplementedError(f'{norm} is not supported!')
                net.append(hidden_activation_creator(hidden_features))
            net.append(VectorizedLinear(hidden_features, out_features, ensemble_size))
            net.append(output_activation_creator(out_features))
            self.net = nn.Sequential(*net)

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        r""" Forward method of the vectorized MLP; only assumes the last dim of x matches `in_features`. """
        if x.dim() == 2:
            # [batch_size, dim] -> [ensemble_size, batch_size, dim]
            x = x.unsqueeze(0).repeat_interleave(self.ensemble_size, dim=0)
        assert x.dim() == 3
        assert x.shape[0] == self.ensemble_size
        return self.net(x)
class ResBlock(nn.Module):
    """ Initializes a residual block instance.

    Args:
        input_feature (int): The number of input features to the block.
        output_feature (int): The number of output features from the block.
        norm (str, optional): The type of normalization to apply to the block. Default is 'ln' for layer normalization.
        dropout (float, optional): Dropout probability applied after the block. Default is 0 (no dropout).
    """
    def __init__(self, input_feature : int, output_feature : int, norm : str = 'ln', dropout: float = 0):
        super().__init__()
        self.dropout = nn.Dropout(dropout) if dropout else None
        if norm == 'ln':
            norm_class = torch.nn.LayerNorm
            self.process_net = torch.nn.Sequential(
                torch.nn.Linear(input_feature, output_feature),
                norm_class(output_feature),
                torch.nn.ReLU(True),
                torch.nn.Linear(output_feature, output_feature),
                norm_class(output_feature),
                torch.nn.ReLU(True)
            )
        else:
            self.process_net = torch.nn.Sequential(
                torch.nn.Linear(input_feature, output_feature),
                torch.nn.ReLU(True),
                torch.nn.Linear(output_feature, output_feature),
                torch.nn.ReLU(True)
            )

        if not input_feature == output_feature:
            self.skip_net = torch.nn.Linear(input_feature, output_feature)
        else:
            self.skip_net = torch.nn.Identity()

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        '''x should be a 2D Tensor due to batchnorm'''
        if self.dropout is not None:
            return self.dropout(self.process_net(x)) + self.skip_net(x)
        else:
            return self.process_net(x) + self.skip_net(x)
class VectorizedResBlock(nn.Module):
    r""" Initializes a vectorized residual block instance.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        ensemble_size (int): The number of ensembles to use.
        norm (str, optional): The type of normalization to apply. Default is 'ln' for layer normalization.
    """
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 ensemble_size: int,
                 norm : str = 'ln'):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        if norm == 'ln':
            norm_class = torch.nn.LayerNorm
            self.process_net = torch.nn.Sequential(
                VectorizedLinear(in_features, out_features, ensemble_size),
                norm_class(out_features),
                torch.nn.ReLU(True),
                VectorizedLinear(out_features, out_features, ensemble_size),
                norm_class(out_features),
                torch.nn.ReLU(True)
            )
        else:
            self.process_net = torch.nn.Sequential(
                VectorizedLinear(in_features, out_features, ensemble_size),
                torch.nn.ReLU(True),
                VectorizedLinear(out_features, out_features, ensemble_size),
                torch.nn.ReLU(True)
            )

        if not in_features == out_features:
            self.skip_net = VectorizedLinear(in_features, out_features, ensemble_size)
        else:
            self.skip_net = torch.nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # input: [ensemble_size, batch_size, input_size]
        # out: [ensemble_size, batch_size, out_size]
        out = self.process_net(x) + self.skip_net(x)
        return out
class ResNet(torch.nn.Module):
    """ Initializes a residual neural network (ResNet) instance.

    Args:
        in_features (int): The number of input features to the ResNet.
        out_features (int): The number of output features from the ResNet.
        hidden_features (int): The number of hidden features in each residual block.
        hidden_layers (int): The number of residual blocks in the ResNet.
        norm (str, optional): The type of normalization to apply to the ResNet. Default is 'ln' for layer normalization.
        output_activation (str, optional): The type of activation function to apply to the output of the ResNet. Default is 'identity'.
        dropout (float, optional): Dropout probability used inside each residual block. Default is 0.
    """
    def __init__(self,
                 in_features : int,
                 out_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 norm : str = 'ln',
                 output_activation : str = 'identity',
                 dropout: float = 0):
        super().__init__()

        modules = []
        for i in range(hidden_layers):
            if i == 0:
                modules.append(ResBlock(in_features, hidden_features, norm, dropout))
            else:
                modules.append(ResBlock(hidden_features, hidden_features, norm, dropout))
        modules.append(torch.nn.Linear(hidden_features, out_features))
        modules.append(ACTIVATION_CREATORS[output_activation](out_features))

        self.resnet = torch.nn.Sequential(*modules)

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        '''NOTE: reshape is needed since ResBlock only supports 2D tensors'''
        shape = x.shape
        x = x.view(-1, shape[-1])
        output = self.resnet(x)
        output = output.view(*shape[:-1], -1)
        return output
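# Usage sketch (illustrative helper added for documentation; not part of the original module).
# The forward pass flattens leading dims so ResBlock always sees a 2D tensor, then
# restores the original leading shape.
def _example_resnet():
    net = ResNet(in_features=10, out_features=4, hidden_features=64, hidden_layers=2)
    y = net(torch.randn(5, 7, 10))   # -> [5, 7, 4]
    return y.shape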
class VectorizedResNet(torch.nn.Module):
    """ Initializes a vectorized residual neural network (ResNet) instance.

    Args:
        in_features (int): The number of input features to the ResNet.
        out_features (int): The number of output features from the ResNet.
        hidden_features (int): The number of hidden features in each residual block.
        hidden_layers (int): The number of residual blocks in the ResNet.
        ensemble_size (int): The number of ensemble members.
        norm (str, optional): The type of normalization to apply to the ResNet. Default is 'ln' for layer normalization.
        output_activation (str, optional): The type of activation function to apply to the output of the ResNet. Default is 'identity'.
    """
    def __init__(self,
                 in_features : int,
                 out_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 ensemble_size : int,
                 norm : str = 'ln',
                 output_activation : str = 'identity'):
        super().__init__()
        self.ensemble_size = ensemble_size

        modules = []
        for i in range(hidden_layers):
            if i == 0:
                modules.append(VectorizedResBlock(in_features, hidden_features, ensemble_size, norm))
            else:
                modules.append(VectorizedResBlock(hidden_features, hidden_features, ensemble_size, norm))
        modules.append(VectorizedLinear(hidden_features, out_features, ensemble_size))
        modules.append(ACTIVATION_CREATORS[output_activation](out_features))

        self.resnet = torch.nn.Sequential(*modules)

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        '''NOTE: a 2D input is broadcast along a new ensemble dim before the vectorized blocks'''
        shape = x.shape
        if len(shape) == 2:
            # [batch_size, dim] --> [ensemble_size, batch_size, dim]
            x = x.unsqueeze(0).repeat_interleave(self.ensemble_size, dim=0)
        assert x.dim() == 3, f"shape == {shape}, with length == {len(shape)}, expect length 3"
        assert x.shape[0] == self.ensemble_size  # [ensemble_size, batch_size, dim]
        output = self.resnet(x)
        return output
class TAPE(nn.Module):
    def __init__(self, d_model, max_len=200, scale_factor=1.0):
        super(TAPE, self).__init__()
        pe = torch.zeros(max_len, d_model)  # positional encoding
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin((position * div_term) * (d_model / max_len))
        pe[:, 1::2] = torch.cos((position * div_term) * (d_model / max_len))
        pe = scale_factor * pe
        # this stores the variable in the state_dict (used for non-trainable variables)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.shape[1], :]
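# Usage sketch (illustrative helper added for documentation; not part of the original module).
# TAPE adds a fixed sinusoidal encoding over the sequence (second) dimension.
def _example_tape():
    tape = TAPE(d_model=32, max_len=200)
    x = torch.randn(8, 20, 32)   # [batch, seq_len, d_model]
    return tape(x).shape         # shape unchanged: [8, 20, 32]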
class TAttention(nn.Module):
    def __init__(self, d_model, nhead, dropout):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.qtrans = nn.Linear(d_model, d_model, bias=False)
        self.ktrans = nn.Linear(d_model, d_model, bias=False)
        self.vtrans = nn.Linear(d_model, d_model, bias=False)

        self.attn_dropout = []
        if dropout > 0:
            for i in range(nhead):
                self.attn_dropout.append(nn.Dropout(p=dropout))
            self.attn_dropout = nn.ModuleList(self.attn_dropout)

        # input LayerNorm
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
        # FFN LayerNorm
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(d_model, d_model),
            nn.Dropout(p=dropout)
        )

    def forward(self, x):
        x = self.norm1(x)
        q = self.qtrans(x)
        k = self.ktrans(x)
        v = self.vtrans(x)

        dim = int(self.d_model / self.nhead)
        att_output = []
        for i in range(self.nhead):
            if i == self.nhead - 1:
                qh = q[:, :, i * dim:]
                kh = k[:, :, i * dim:]
                vh = v[:, :, i * dim:]
            else:
                qh = q[:, :, i * dim:(i + 1) * dim]
                kh = k[:, :, i * dim:(i + 1) * dim]
                vh = v[:, :, i * dim:(i + 1) * dim]
            atten_ave_matrixh = torch.softmax(torch.matmul(qh, kh.transpose(1, 2)), dim=-1)
            if self.attn_dropout:
                atten_ave_matrixh = self.attn_dropout[i](atten_ave_matrixh)
            att_output.append(torch.matmul(atten_ave_matrixh, vh))
        att_output = torch.concat(att_output, dim=-1)

        # FFN
        xt = x + att_output
        xt = self.norm2(xt)
        att_output = xt + self.ffn(xt)

        return att_output
class TsTransformer(nn.Module):
    def __init__(self, in_features, out_features, hidden_layers, hidden_features, dropout=0.):
        super(TsTransformer, self).__init__()
        layers = []
        layers.append(nn.Linear(in_features, hidden_features))
        layers.append(TAPE(hidden_features))
        for i in range(hidden_layers):
            layers.append(TAttention(hidden_features, nhead=4, dropout=dropout))
        layers.append(nn.Linear(hidden_features, out_features))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        if len(x.shape) == 2:
            x = x.reshape(x.shape[0], -1, 2)
        output = self.layers(x).squeeze(-1)[:, -1:]
        return output
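# Usage sketch (illustrative helper added for documentation; not part of the original module).
# A 2D input is reshaped to [batch, steps, 2], so `in_features` must match the per-step
# feature size (2 in this example); only the last step's output is kept.
def _example_ts_transformer():
    model = TsTransformer(in_features=2, out_features=1, hidden_layers=2, hidden_features=32)
    y = model(torch.randn(8, 20))   # reshaped to [8, 10, 2] -> output [8, 1]
    return y.shape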
class Transformer1D(nn.Module):
    """ This is an experimental backbone. """
    def __init__(self,
                 in_features : int,
                 out_features : int,
                 transformer_features : int = 16,
                 transformer_heads : int = 8,
                 transformer_layers : int = 4):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)
        self.register_parameter('in_weight', torch.nn.Parameter(torch.randn(out_features, transformer_features)))
        self.register_parameter('in_bais', torch.nn.Parameter(torch.zeros(out_features, transformer_features)))
        self.register_parameter('out_weight', torch.nn.Parameter(torch.randn(out_features, transformer_features)))
        self.register_parameter('out_bais', torch.nn.Parameter(torch.zeros(out_features)))
        torch.nn.init.xavier_normal_(self.in_weight)
        torch.nn.init.zeros_(self.out_weight)  # zero initialize
        encoder_layer = torch.nn.TransformerEncoderLayer(transformer_features, transformer_heads, 512)
        self.transformer = torch.nn.TransformerEncoder(encoder_layer, transformer_layers)

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        shape = x.shape
        x = x.view(-1, x.shape[-1])                                  # [B, I]
        x = self.linear(x)                                           # [B, O]
        x = x.unsqueeze(dim=-1)                                      # [B, O, 1]
        x = x * self.in_weight + self.in_bais                        # [B, O, F]
        x = x.permute(1, 0, 2).contiguous()                          # [O, B, F]
        x = self.transformer(x)                                      # [O, B, F]
        x = x.permute(1, 0, 2).contiguous()                          # [B, O, F]
        x = torch.sum(x * self.out_weight, dim=-1) + self.out_bais   # [B, O]
        x = x.view(*shape[:-1], x.shape[-1])
        return x
[docs] class Tokenizer(nn.Module): category_offsets: Optional[torch.Tensor] def __init__( self, d_numerical: int, categories: Optional[List[int]], d_token: int, bias: bool, ) -> None: super().__init__() if categories is None: d_bias = d_numerical self.category_offsets = None self.category_embeddings = None else: d_bias = d_numerical + len(categories) category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0) self.register_buffer('category_offsets', category_offsets) self.category_embeddings = nn.Embedding(sum(categories), d_token) nn.init.kaiming_uniform_(self.category_embeddings.weight, a=math.sqrt(5)) print(f'{self.category_embeddings.weight.shape=}') # take [CLS] token into account self.weight = nn.Parameter(torch.Tensor(d_numerical + 1, d_token)) self.bias = nn.Parameter(torch.Tensor(d_bias, d_token)) if bias else None # The initialization is inspired by nn.Linear nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) if self.bias is not None: nn.init.kaiming_uniform_(self.bias, a=math.sqrt(5)) @property def n_tokens(self) -> int: return len(self.weight) + ( 0 if self.category_offsets is None else len(self.category_offsets) )
[docs] def forward(self, x_num: torch.Tensor, x_cat: Optional[torch.Tensor]): x_some = x_num if x_cat is None else x_cat assert x_some is not None x_num = torch.cat( [torch.ones(len(x_some), 1, device=x_some.device)] # [CLS] + ([] if x_num is None else [x_num]), dim=1, ) x = self.weight[None] * x_num[:, :, None] if x_cat is not None: x = torch.cat( [x, self.category_embeddings(x_cat + self.category_offsets[None])], dim=1, ) if self.bias is not None: bias = torch.cat( [ torch.zeros(1, self.bias.shape[1], device=x.device), self.bias, ] ) x = x + bias[None] return x
[docs] class MultiheadAttention(nn.Module): def __init__( self, d: int, n_heads: int, dropout: float, initialization: str ) -> None: if n_heads > 1: assert d % n_heads == 0 assert initialization in ['xavier', 'kaiming'] super().__init__() self.W_q = nn.Linear(d, d) self.W_k = nn.Linear(d, d) self.W_v = nn.Linear(d, d) self.W_out = nn.Linear(d, d) if n_heads > 1 else None self.n_heads = n_heads self.dropout = nn.Dropout(dropout) if dropout else None for m in [self.W_q, self.W_k, self.W_v]: if initialization == 'xavier' and (n_heads > 1 or m is not self.W_v): # gain is needed since W_qkv is represented with 3 separate layers nn.init.xavier_uniform_(m.weight, gain=1 / math.sqrt(2)) nn.init.zeros_(m.bias) if self.W_out is not None: nn.init.zeros_(self.W_out.bias) def _reshape(self, x): batch_size, n_tokens, d = x.shape d_head = d // self.n_heads return ( x.reshape(batch_size, n_tokens, self.n_heads, d_head) .transpose(1, 2) .reshape(batch_size * self.n_heads, n_tokens, d_head) )
[docs] def forward( self, x_q, x_kv, key_compression: Optional[nn.Linear], value_compression: Optional[nn.Linear], ): q, k, v = self.W_q(x_q), self.W_k(x_kv), self.W_v(x_kv) for tensor in [q, k, v]: assert tensor.shape[-1] % self.n_heads == 0 if key_compression is not None: assert value_compression is not None k = key_compression(k.transpose(1, 2)).transpose(1, 2) v = value_compression(v.transpose(1, 2)).transpose(1, 2) else: assert value_compression is None batch_size = len(q) d_head_key = k.shape[-1] // self.n_heads d_head_value = v.shape[-1] // self.n_heads n_q_tokens = q.shape[1] q = self._reshape(q) k = self._reshape(k) attention = F.softmax(q @ k.transpose(1, 2) / math.sqrt(d_head_key), dim=-1) if self.dropout is not None: attention = self.dropout(attention) x = attention @ self._reshape(v) x = ( x.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value) .transpose(1, 2) .reshape(batch_size, n_q_tokens, self.n_heads * d_head_value) ) if self.W_out is not None: x = self.W_out(x) return x
[docs] class FT_Transformer(nn.Module): """FT_Transformer References: - https://github.com/Yura52/tabular-dl-revisiting-models/blob/main/bin/ft_transformer.py """ def __init__( self, in_features : int, out_features : int, hidden_features : int =256, hidden_layers : int = 5, categories: Optional[List[int]]=None, token_bias: bool = True, # transformer n_heads: int = 8, d_ffn_factor: float = 1.33, attention_dropout: float = 0.0, ffn_dropout: float = 0.0, residual_dropout: float = 0.0, activation: str = "reglu", prenormalization: bool = True, initialization: str = "kaiming", # linformer kv_compression: Optional[float] = None, kv_compression_sharing: Optional[str] = None, ) -> None: d_numerical = in_features d_out = out_features n_layers = hidden_layers d_token = hidden_features assert (kv_compression is None) ^ (kv_compression_sharing is not None) super().__init__() self.tokenizer = Tokenizer(d_numerical, categories, d_token, token_bias) n_tokens = self.tokenizer.n_tokens def make_kv_compression(): assert kv_compression compression = nn.Linear( n_tokens, int(n_tokens * kv_compression), bias=False ) if initialization == 'xavier': nn.init.xavier_uniform_(compression.weight) return compression self.shared_kv_compression = ( make_kv_compression() if kv_compression and kv_compression_sharing == 'layerwise' else None ) def make_normalization(): return nn.LayerNorm(d_token) d_hidden = int(d_token * d_ffn_factor) self.layers = nn.ModuleList([]) for layer_idx in range(n_layers): layer = nn.ModuleDict( { 'attention': MultiheadAttention( d_token, n_heads, attention_dropout, initialization ), 'linear0': nn.Linear( d_token, d_hidden * (2 if activation.endswith('glu') else 1) ), 'linear1': nn.Linear(d_hidden, d_token), 'norm1': make_normalization(), } ) if not prenormalization or layer_idx: layer['norm0'] = make_normalization() if kv_compression and self.shared_kv_compression is None: layer['key_compression'] = make_kv_compression() if kv_compression_sharing == 'headwise': layer['value_compression'] = make_kv_compression() else: assert kv_compression_sharing == 'key-value' self.layers.append(layer) self.activation = ACTIVATION_CREATORS[activation] self.last_activation = ACTIVATION_CREATORS[activation] self.prenormalization = prenormalization self.last_normalization = make_normalization() if prenormalization else None self.ffn_dropout = ffn_dropout self.residual_dropout = residual_dropout self.head = nn.Linear(int(d_token//2), d_out) def _get_kv_compressions(self, layer): return ( (self.shared_kv_compression, self.shared_kv_compression) if self.shared_kv_compression is not None else (layer['key_compression'], layer['value_compression']) if 'key_compression' in layer and 'value_compression' in layer else (layer['key_compression'], layer['key_compression']) if 'key_compression' in layer else (None, None) ) def _start_residual(self, x, layer, norm_idx): x_residual = x if self.prenormalization: norm_key = f'norm{norm_idx}' if norm_key in layer: x_residual = layer[norm_key](x_residual) return x_residual def _end_residual(self, x, x_residual, layer, norm_idx): if self.residual_dropout: x_residual = F.dropout(x_residual, self.residual_dropout, self.training) x = x + x_residual if not self.prenormalization: x = layer[f'norm{norm_idx}'](x) return x
[docs] def forward(self, x_num, x_cat=None): if len(x_num.shape) == 3: x = x_num.view(-1, x_num.shape[-1]) else: x = x_num x = self.tokenizer(x, x_cat) for layer_idx, layer in enumerate(self.layers): is_last_layer = layer_idx + 1 == len(self.layers) layer = cast(Dict[str, nn.Module], layer) x_residual = self._start_residual(x, layer, 0) x_residual = layer['attention']( # for the last attention, it is enough to process only [CLS] (x_residual[:, :1] if is_last_layer else x_residual), x_residual, *self._get_kv_compressions(layer), ) if is_last_layer: x = x[:, : x_residual.shape[1]] x = self._end_residual(x, x_residual, layer, 0) x_residual = self._start_residual(x, layer, 1) x_residual = layer['linear0'](x_residual) x_residual = self.activation(x_residual) if self.ffn_dropout: x_residual = F.dropout(x_residual, self.ffn_dropout, self.training) x_residual = layer['linear1'](x_residual) x = self._end_residual(x, x_residual, layer, 1) assert x.shape[1] == 1 x = x[:, 0] if self.last_normalization is not None: x = self.last_normalization(x) x = self.last_activation(x) x = self.head(x) x = x.squeeze(-1) if len(x_num.shape) == 3: x = x.view(-1, x_num.shape[1], x.shape[-1]) return x
[docs] class DistributionWrapper(nn.Module): r"""wrap output of Module to distribution""" BASE_TYPES = ['normal', 'gmm', 'onehot', 'discrete_logistic'] SUPPORTED_TYPES = BASE_TYPES + ['mix'] def __init__(self, distribution_type : str = 'normal', **params): super().__init__() self.distribution_type = distribution_type self.params = params assert self.distribution_type in self.SUPPORTED_TYPES, f"{self.distribution_type} is not supported!" if self.distribution_type == 'normal': self.max_logstd = nn.Parameter(torch.ones(self.params['dim']) * 0.0, requires_grad=True) self.min_logstd = nn.Parameter(torch.ones(self.params['dim']) * -7, requires_grad=True) if not self.params.get('conditioned_std', True): self.logstd = nn.Parameter(torch.zeros(self.params['dim']), requires_grad=True) elif self.distribution_type == 'gmm': self.max_logstd = nn.Parameter(torch.ones(self.params['mixture'], self.params['dim']) * 0, requires_grad=True) self.min_logstd = nn.Parameter(torch.ones(self.params['mixture'], self.params['dim']) * -10, requires_grad=True) if not self.params.get('conditioned_std', True): self.logstd = nn.Parameter(torch.zeros(self.params['mixture'], self.params['dim']), requires_grad=True) elif self.distribution_type == 'discrete_logistic': self.num = self.params['num'] elif self.distribution_type == 'mix': assert 'dist_config' in self.params.keys(), "You need to provide `dist_config` for Mix distribution" self.dist_config = self.params['dist_config'] self.wrapper_list = [] self.input_sizes = [] self.output_sizes = [] for config in self.dist_config: dist_type = config['type'] assert dist_type in self.SUPPORTED_TYPES, f"{dist_type} is not supported!" assert not dist_type == 'mix', "recursive MixDistribution is not supported!" self.wrapper_list.append(DistributionWrapper(dist_type, **config)) self.input_sizes.append(config['dim']) self.output_sizes.append(config['output_dim']) self.wrapper_list = nn.ModuleList(self.wrapper_list)
[docs] def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, payload : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution: ''' Warp the given tensor to distribution :param adapt_std : it will overwrite the std part of the distribution (optional) :param payload : payload will be applied to the output distribution after built (optional) ''' # [ OTHER ] replace the clip with tanh # [ OTHER ] better std controlling strategy is needed todo if self.distribution_type == 'normal': if self.params.get('conditioned_std', True): mu, logstd = torch.chunk(x, 2, dim=-1) if 'soft_clamp' in kwargs and kwargs['soft_clamp'] == True: # distribution wrapper only for revive_f venv training logstd = soft_clamp(logstd, self.min_logstd, self.max_logstd) std = torch.exp(logstd) else: # distribution wrapper for others logstd_logit = logstd max_std = 0.5 min_std = 0.001 std = (torch.tanh(logstd_logit) + 1) / 2 * (max_std - min_std) + min_std else: mu, logstd = x, self.logstd logstd_logit = self.logstd max_std = 0.5 min_std = 0.001 std = (torch.tanh(logstd_logit) + 1) / 2 * (max_std - min_std) + min_std # std = torch.exp(logstd) if payload is not None: mu = mu + safe_atanh(payload) mu = torch.tanh(mu) """replace tanh node = kwargs["node"] mu = node.processor.process_torch({node.name:mu})[node.name] mu = torch.clamp(mu, min=-1.1, max=1.1) """ if adapt_std is not None: std = torch.ones_like(mu).to(mu) * adapt_std return DiagnalNormal(mu, std) elif self.distribution_type == 'gmm': if self.params.get('conditioned_std', True): logits, mus, logstds = torch.split(x, [self.params['mixture'], self.params['mixture'] * self.params['dim'], self.params['mixture'] * self.params['dim']], dim=-1) mus = mus.view(*mus.shape[:-1], self.params['mixture'], self.params['dim']) logstds = logstds.view(*logstds.shape[:-1], self.params['mixture'], self.params['dim']) else: logits, mus = torch.split(x, [self.params['mixture'], self.params['mixture'] * self.params['dim']], dim=-1) logstds = self.logstd if payload is not None: mus = mus + safe_atanh(payload.unsqueeze(dim=-2)) mus = torch.tanh(mus) stds = adapt_std if adapt_std is not None else torch.exp(soft_clamp(logstds, self.min_logstd, self.max_logstd)) return GaussianMixture(mus, stds, logits) elif self.distribution_type == 'onehot': return Onehot(x) elif self.distribution_type == 'discrete_logistic': mu, logstd = torch.chunk(x, 2, dim=-1) logstd = torch.clamp(logstd, -7, 1) return DiscreteLogistic(mu, torch.exp(logstd), num=self.num) elif self.distribution_type == 'mix': xs = torch.split(x, self.output_sizes, dim=-1) if isinstance(adapt_std, float): adapt_stds = [adapt_std] * len(self.input_sizes) elif adapt_std is not None: adapt_stds = torch.split(adapt_std, self.input_sizes, dim=-1) else: adapt_stds = [None] * len(self.input_sizes) if payload is not None: payloads = torch.split(payload, self.input_sizes + [payload.shape[-1] - sum(self.input_sizes)], dim=-1)[:-1] else: payloads = [None] * len(self.input_sizes) dists = [wrapper(x, _adapt_std, _payload, **kwargs) for x, _adapt_std, _payload, wrapper in zip(xs, adapt_stds, payloads, self.wrapper_list)] return MixDistribution(dists)
[docs] def extra_repr(self) -> str: return 'type={}, dim={}'.format( self.distribution_type, self.params['dim'] if not self.distribution_type == 'mix' else len(self.wrapper_list) )
# --------------------------------- Policies -------------------------------- #
[docs] class FeedForwardPolicy(torch.nn.Module): """Policy for using mlp, resnet and transformer backbone. Args: in_features : The number of input features, or a dictionary describing the input distribution out_features : The number of output features hidden_features : The number of hidden features in each layer hidden_layers : The number of hidden layers dist_config : A list of configurations for the distributions to use in the model norm : The type of normalization to apply to the input features hidden_activation : The activation function to use in the hidden layers backbone_type : The type of backbone to use in the model use_multihead : Whether to use a multihead model use_feature_embed : Whether to use feature embedding """ def __init__(self, in_features : Union[int, dict], out_features : int, hidden_features : int, hidden_layers : int, dist_config : list, norm : str = None, hidden_activation : str = 'leakyrelu', backbone_type : Union[str, np.str_] = 'mlp', use_multihead : bool = False, **kwargs): super().__init__() self.multihead = [] self.dist_config = dist_config self.kwargs = kwargs if isinstance(in_features, dict): in_features = sum(in_features.values()) if not use_multihead: if backbone_type == 'mlp': self.backbone = MLP(in_features, out_features, hidden_features, hidden_layers, norm, hidden_activation) elif backbone_type == 'res': self.backbone = ResNet(in_features, out_features, hidden_features, hidden_layers, norm) elif backbone_type == 'ft_transformer': self.backbone = FT_Transformer(in_features, out_features, hidden_features, hidden_layers=hidden_layers) else: raise NotImplementedError(f'backbone type {backbone_type} is not supported') else: if backbone_type == 'mlp': self.backbone = MLP(in_features, hidden_features, hidden_features, hidden_layers, norm, hidden_activation) elif backbone_type == 'res': self.backbone = ResNet(in_features, hidden_features, hidden_features, hidden_layers, norm) elif backbone_type == 'ft_transformer': self.backbone = FT_Transformer(in_features, out_features, hidden_features, hidden_layers=hidden_layers) else: raise NotImplementedError(f'backbone type {backbone_type} is not supported') for dist in self.dist_config: dist_type = dist["type"] output_dim = dist["output_dim"] if dist_type == "onehot" or dist_type == "discrete_logistic": self.multihead.append(MLP(hidden_features, output_dim, 64, 1, norm=None, hidden_activation='leakyrelu')) elif dist_type == "normal": normal_dim = dist["dim"] for dim in range(normal_dim): self.multihead.append(MLP(hidden_features, int(output_dim // normal_dim), 64, 1, norm=None, hidden_activation='leakyrelu')) else: raise NotImplementedError(f'Dist type {dist_type} is not supported in multihead.') self.multihead = nn.ModuleList(self.multihead) self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config) if "ts_conv_config" in self.kwargs and self.kwargs['ts_conv_config'] != None: other_state_layer_output = in_features//2 ts_state_conv_layer_output = in_features//2 + in_features%2 self.other_state_layer = MLP(self.kwargs['ts_conv_net_config']['other_net_input'], other_state_layer_output, 0, 0, output_activation=hidden_activation) all_node_ts = None self.conv_ts_node = [] self.conv_other_node = deepcopy(self.kwargs['ts_conv_config']['no_ts_input']) for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys(): self.conv_ts_node.append('__history__'+tsnode) self.conv_other_node.append('__now__'+tsnode) node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts'] if all_node_ts == None: 
all_node_ts = node_ts assert all_node_ts==node_ts, f"expect ts step == {all_node_ts}. However got {tsnode} with ts step {node_ts}" #ts-1 remove now state kernalsize = all_node_ts-1 self.ts_state_conv_layer = ConvBlock(self.kwargs['ts_conv_net_config']['conv_net_input'], ts_state_conv_layer_output, kernalsize, output_activation=hidden_activation)
[docs] def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution: if "ts_conv_config" in self.kwargs and self.kwargs['ts_conv_config'] != None: #deal with all ts nodes with ts dim # assert 'input_names' in kwargs assert hasattr(self, 'conv_ts_node') assert hasattr(self, 'conv_other_node') inputdata = deepcopy(state) #data is deteched from other node for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys(): # assert self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['endpoint'] == True batch_size = int(np.prod(list(inputdata[tsnode].shape[:-1]))) original_size = list(inputdata[tsnode].shape[:-1]) node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts'] # inputdata[tsnode] temp_data = inputdata[tsnode].reshape(batch_size, node_ts, -1) inputdata['__history__'+tsnode] = temp_data[..., :-1, :] inputdata['__now__'+tsnode] = temp_data[..., -1, :].reshape([*original_size,-1]) state_other = torch.cat([inputdata[key] for key in self.conv_other_node], dim=-1) state_ts_conv = torch.cat([inputdata[key] for key in self.conv_ts_node], dim=-1) state_other = self.other_state_layer(state_other) original_size = list(state_other.shape[:-1]) state_ts_conv = self.ts_state_conv_layer(state_ts_conv).reshape([*original_size,-1]) state = torch.cat([state_other, state_ts_conv], dim=-1) elif isinstance(state, dict): assert 'input_names' in kwargs state = torch.cat([state[key] for key in kwargs['input_names']], dim=-1) if not self.multihead: output = self.backbone(state) else: backbone_output = self.backbone(state) multihead_output = [] multihead_index = 0 for dist in self.dist_config: dist_type = dist["type"] output_dim = dist["output_dim"] if dist_type == "onehot" or dist_type == "discrete_logistic": multihead_output.append(self.multihead[multihead_index](backbone_output)) multihead_index += 1 elif dist_type == "normal": normal_mode_output = [] normal_std_output = [] for head in self.multihead[multihead_index:]: head_output = head(backbone_output) if head_output.shape[-1] == 1: mode = head_output normal_mode_output.append(mode) else: mode, std = torch.chunk(head_output, 2, axis=-1) normal_mode_output.append(mode) normal_std_output.append(std) normal_output = torch.cat(normal_mode_output, axis=-1) if normal_std_output: normal_std_output = torch.cat(normal_std_output, axis=-1) normal_output = torch.cat([normal_output, normal_std_output], axis=-1) multihead_output.append(normal_output) break else: raise NotImplementedError(f'Dist type {dist_type} is not supported in multihead.') output = torch.cat(multihead_output, axis= -1) soft_clamp_flag = True if "soft_clamp" in self.kwargs and self.kwargs['soft_clamp'] == True else False dist = self.dist_wrapper(output, adapt_std, soft_clamp=soft_clamp_flag, **kwargs) if hasattr(self, "dist_mu_shift"): dist = dist.shift(self.dist_mu_shift) return dist
    def reset(self):
        pass

    @torch.no_grad()
    def get_action(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], deterministic : bool = True):
        dist = self(state)
        return dist.mode if deterministic else dist.sample()
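# Usage sketch (illustrative helper added for documentation; not part of the original module).
# The dist_config entries follow the keys read by DistributionWrapper ('type', 'dim',
# 'output_dim'); for a 'normal' head with conditioned std, output_dim == 2 * dim. The
# sketch assumes the distribution classes from revive.computation.dists behave as
# referenced above (e.g. MixDistribution exposes .mode).
def _example_feedforward_policy():
    dist_config = [{'type': 'normal', 'dim': 2, 'output_dim': 4}]
    policy = FeedForwardPolicy(in_features=10, out_features=4, hidden_features=64,
                               hidden_layers=2, dist_config=dist_config, backbone_type='mlp')
    state = torch.randn(32, 10)
    action = policy.get_action(state, deterministic=True)   # mode of the action distribution
    return action.shape                                      # expected [32, 2]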
class RecurrentPolicy(torch.nn.Module):
    """ Initializes a recurrent policy network instance.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        hidden_features (int): The number of hidden features in the RNN.
        hidden_layers (int): The number of layers in the RNN.
        norm (str): The type of normalization used in the ResNet head.
        dist_config (list): The configuration for the distributions used in the model.
        backbone_type (Union[str, np.str_], optional): The type of RNN to use ('gru' or 'lstm'). Default is 'gru'.
    """
    def __init__(self,
                 in_features : int,
                 out_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 norm : str,
                 dist_config : list,
                 backbone_type : Union[str, np.str_] = 'gru',
                 **kwargs):
        super().__init__()
        RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM
        rnn_hidden_features = 256
        rnn_hidden_layers = min(hidden_layers, 3)
        rnn_output_features = min(out_features, 8)
        self.rnn = RNN(in_features, rnn_hidden_features, rnn_hidden_layers)
        self.rnn_mlp = MLP(rnn_hidden_features, rnn_output_features, 0, 0)
        self.backbone = ResNet(in_features + rnn_output_features, out_features, hidden_features, hidden_layers, norm)
        self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)

    def reset(self):
        self.h = None

    def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
        x_shape = x.shape
        if len(x_shape) == 1:
            rnn_output, self.h = self.rnn(x.unsqueeze(0).unsqueeze(0), self.h)
            rnn_output = rnn_output.squeeze(0).squeeze(0)
        elif len(x_shape) == 2:
            rnn_output, self.h = self.rnn(x.unsqueeze(0), self.h)
            rnn_output = rnn_output.squeeze(0)
        else:
            rnn_output, self.h = self.rnn(x, self.h)
        rnn_output = self.rnn_mlp(rnn_output)
        logits = self.backbone(torch.cat([x, rnn_output], axis=-1))
        return self.dist_wrapper(logits, adapt_std, **kwargs)
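# Usage sketch (illustrative helper added for documentation; not part of the original module).
# The hidden state lives on the module, so call reset() before rolling out a new
# trajectory and then feed one [batch, in_features] slice per step.
def _example_recurrent_policy():
    dist_config = [{'type': 'normal', 'dim': 2, 'output_dim': 4}]
    policy = RecurrentPolicy(in_features=10, out_features=4, hidden_features=64,
                             hidden_layers=2, norm='ln', dist_config=dist_config)
    policy.reset()
    for _ in range(5):                 # one forward call per environment step
        dist = policy(torch.randn(32, 10))
    return dist.mode.shape             # expected [32, 2]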
[docs] class TsRecurrentPolicy(torch.nn.Module): input_is_dict = True """Initializes a recurrent policy network instance for ts_node. Args: in_features (int): The number of input features. out_features (int): The number of output features. hidden_features (int): The number of hidden features in the RNN. hidden_layers (int): The number of layers in the RNN. dist_config (list): The configuration for the distributions used in the model. backbone_type (Union[str, np.str_], optional): The type of RNN to use ('gru' or 'lstm'). Default is 'gru'. """ def __init__(self, in_features : int, out_features : int, hidden_features : int, hidden_layers : int, norm : str, dist_config : list, backbone_type : Union[str, np.str_] ='gru', ts_input_names : dict = {}, bidirectional: bool = True, dropout: float = 0.1, rnn_input_withtime=True, smooth_input=False, **kwargs): super().__init__() RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM self.ts_input_names = ts_input_names self.input_dim_dict = kwargs.get("input_dim_dict",{}) self.max_ts_steps = max([v[1] for v in self.ts_input_names.values()]) self.rnn_input_withtime = rnn_input_withtime self.smooth_input = smooth_input if self.smooth_input: self.smooth_linear_1 = nn.Linear(in_features, in_features) self.smooth_linear_2 = nn.Linear(in_features, in_features) if self.rnn_input_withtime: rnn_in_features = in_features + 1 else: rnn_in_features = in_features self.rnn = RNN(input_size=rnn_in_features, hidden_size=hidden_features, num_layers=hidden_layers, bidirectional=bidirectional, batch_first=True, dropout=dropout if hidden_layers > 1 else 0) if bidirectional: self.linear = nn.Linear(hidden_features * 2, out_features) # self.linear = ResNet(hidden_features * 2, out_features, hidden_features, hidden_layers, norm, output_activation='identity',dropout=dropout) else: self.linear = nn.Linear(hidden_features, out_features) # self.linear = ResNet(hidden_features, out_features, hidden_features, hidden_layers, norm, output_activation='identity',dropout=dropout) self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
[docs] def reset(self): self.h = None
[docs] def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution: x_shape = x[list(x.keys())[0]].shape if len(x_shape) == 3: x = {k:v.reshape(-1,v.shape[-1]) for k,v in x.items()} input = {} for k,v in x.items(): if k in self.ts_input_names: ts_steps = self.ts_input_names.get(k)[1] else: ts_steps = 1 input[k] = v.reshape(v.shape[0],ts_steps,-1) if ts_steps < self.max_ts_steps: # input[k] = torch.cat([input[k][:,:1].repeat(1, self.max_ts_steps+1-ts_steps, 1), input[k]],axis=1) input[k] = torch.cat([input[k], input[k][:,-1:].repeat(1, self.max_ts_steps-ts_steps, 1)],axis=1) if len(input) == 1: rnn_input = list(input.values())[0] else: rnn_input = torch.cat(list(input.values()),axis=-1) if self.smooth_input: rnn_input_pool = F.avg_pool1d(rnn_input.permute(0,2,1), kernel_size=3, stride=1, padding=1).permute(0,2,1) rnn_input = self.smooth_linear_1(rnn_input_pool) + self.smooth_linear_2(rnn_input-rnn_input_pool) if self.rnn_input_withtime: time_steps = torch.arange(rnn_input.shape[1]).unsqueeze(0).unsqueeze(2).expand(rnn_input.shape[0], rnn_input.shape[1], 1).float().to(rnn_input.device) time_steps_normalized = time_steps / (rnn_input.shape[1] - 1) rnn_input = torch.cat([rnn_input, time_steps_normalized], dim=2) out, _ = self.rnn(rnn_input,None) out = self.linear(out[:, -1, :]) if len(x_shape) == 3: logits = out.reshape(x_shape[0],x_shape[1],-1) else: logits = out.reshape(x_shape[0],-1) return self.dist_wrapper(logits, adapt_std, **kwargs)
[docs] class TsTransformerPolicy(torch.nn.Module): input_is_dict = True """Initializes a recurrent policy network instance for ts_node. Args: in_features (int): The number of input features. out_features (int): The number of output features. hidden_features (int): The number of hidden features in the RNN. hidden_layers (int): The number of layers in the RNN. dist_config (list): The configuration for the distributions used in the model. backbone_type (Union[str, np.str_], optional): The type of RNN to use ('gru' or 'lstm'). Default is 'gru'. """ def __init__(self, in_features : int, out_features : int, hidden_features : int, hidden_layers : int, norm : str, dist_config : list, ts_input_names : dict = {}, dropout: float = 0, smooth_input: bool = False, **kwargs): super().__init__() self.smooth_input = smooth_input self.ts_input_names = ts_input_names self.input_dim_dict = kwargs.get("input_dim_dict",{}) self.max_ts_steps = max([v[1] for v in self.ts_input_names.values()]) if self.smooth_input: self.smooth_linear_1 = nn.Linear(in_features, in_features) self.smooth_linear_2 = nn.Linear(in_features, in_features) self.backbone = TsTransformer(in_features=in_features, out_features=out_features, hidden_layers=hidden_layers, hidden_features=hidden_features, dropout=dropout) self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
[docs] def reset(self): self.h = None
[docs] def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution: x_shape = x[list(x.keys())[0]].shape if len(x_shape) == 3: x = {k:v.reshape(-1,v.shape[-1]) for k,v in x.items()} input = {} for k,v in x.items(): if k in self.ts_input_names: ts_steps = self.ts_input_names.get(k)[1] else: ts_steps = 1 input[k] = v.reshape(v.shape[0],ts_steps,-1) if ts_steps < self.max_ts_steps+1: input[k] = torch.cat([input[k][:,:1].repeat(1, self.max_ts_steps+1-ts_steps, 1), input[k]],axis=1) if len(input) == 1: transformer_input = list(input.values())[0] else: transformer_input = torch.cat(list(input.values()),axis=-1) if self.smooth_input: transformer_input_pool = F.avg_pool1d(transformer_input.permute(0,2,1), kernel_size=3, stride=1, padding=1).permute(0,2,1) transformer_input = self.smooth_linear_1(transformer_input_pool) + self.smooth_linear_2(transformer_input-transformer_input_pool) out = self.backbone(transformer_input) if len(x_shape) == 3: logits = out.reshape(x_shape[0],x_shape[1],-1) else: logits = out.reshape(x_shape[0],-1) return self.dist_wrapper(logits, adapt_std, **kwargs)
[docs] class RecurrentRESPolicy(torch.nn.Module): def __init__(self, in_features : int, out_features : int, hidden_features : int, hidden_layers : int, dist_config : list, backbone_type : Union[str, np.str_] ='gru', rnn_hidden_features : int = 64, window_size : int = 0, **kwargs): super().__init__() self.kwargs = kwargs RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM self.in_feature_embed = nn.Sequential( nn.Linear(in_features, hidden_features), ACTIVATION_CREATORS['relu'](hidden_features), nn.Dropout(0.1), nn.LayerNorm(hidden_features) ) self.side_net = nn.Sequential( ResBlock(hidden_features, hidden_features), nn.Dropout(0.1), nn.LayerNorm(hidden_features), ResBlock(hidden_features, hidden_features), nn.Dropout(0.1), nn.LayerNorm(hidden_features) ) self.rnn = RNN(hidden_features, rnn_hidden_features, hidden_layers) self.backbone = MLP(hidden_features + rnn_hidden_features, out_features, hidden_features, 1) self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config) self.hidden_layers = hidden_layers self.hidden_features = hidden_features self.rnn_hidden_features = rnn_hidden_features self.window_size = window_size self.hidden_que = deque(maxlen=self.window_size) # [(h_0, c_0), (h_1, c_1), ...] self.h = None
[docs] def reset(self): self.hidden_que = deque(maxlen=self.window_size) self.h = None
[docs] def preprocess_window(self, a_embed): shape = a_embed.shape if len(self.hidden_que) < self.window_size: # do Not pop if self.hidden_que: h = self.hidden_que[0] else: h = (torch.zeros((self.hidden_layers, shape[1], self.rnn_hidden_features)).to(a_embed), torch.zeros((self.hidden_layers, shape[1], self.rnn_hidden_features)).to(a_embed)) else: h = self.hidden_que.popleft() hidden_cat = torch.concat([h[0]] + [hidden[0] for hidden in self.hidden_que] + [torch.zeros_like(h[0]).to(h[0])], dim=1) # (layers, bs * 4, dim) cell_cat = torch.concat([h[1]] + [hidden[1] for hidden in self.hidden_que] + [torch.zeros_like(h[1]).to(h[1])], dim=1) # (layers, bs * 4, dim) a_embed = torch.repeat_interleave(a_embed, repeats=len(self.hidden_que)+2, dim=0).view(1, (len(self.hidden_que)+2)*shape[1], -1) # (1, bs * 4, dim) return a_embed, hidden_cat, cell_cat
[docs] def postprocess_window(self, rnn_output, hidden_cat, cell_cat): hidden_cat = torch.chunk(hidden_cat, chunks=len(self.hidden_que)+2, dim=1) # tuples of (layers, bs, dim) cell_cat = torch.chunk(cell_cat, chunks=len(self.hidden_que)+2, dim=1) # tuples of (layers, bs, dim) rnn_output = torch.chunk(rnn_output, chunks=len(self.hidden_que)+2, dim=1) # tuples of (1, bs, dim) self.hidden_que = deque(zip(hidden_cat[1:], cell_cat[1:]), maxlen=self.window_size) # important to discrad the first element !!! rnn_output = rnn_output[0] # (1, bs, dim) return rnn_output
[docs] def rnn_forward(self, a_embed, joint_train : bool): if self.window_size > 0: a_embed, hidden_cat, cell_cat = self.preprocess_window(a_embed) if joint_train: rnn_output, (hidden_cat, cell_cat) = self.rnn(a_embed, (hidden_cat, cell_cat)) else: with torch.no_grad(): rnn_output, (hidden_cat, cell_cat) = self.rnn(a_embed, (hidden_cat, cell_cat)) rnn_output = self.postprocess_window(rnn_output, hidden_cat, cell_cat) #(1, bs, dim) else: if joint_train: rnn_output, self.h = self.rnn(a_embed, self.h) #(1, bs, dim) else: with torch.no_grad(): rnn_output, self.h = self.rnn(a_embed, self.h) return rnn_output
[docs] def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, field: str = 'mail', **kwargs) -> ReviveDistribution: if field == 'bc': shape = x.shape joint_train = True if len(shape) == 2: # (bs, dim) x_embed = self.in_feature_embed(x) side_output = self.side_net(x_embed) a_embed = x_embed.unsqueeze(0) # [1, bs, dim] rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim] rnn_output = rnn_output.squeeze(0) # (bs, dim) logits = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (bs, dim) else: assert len(shape) == 3, f"expect len(x.shape) == 3. However got x.shape {shape}" self.reset() output = [] for i in range(shape[0]): a = x[i] a = a.unsqueeze(0) #(1, bs, dim) a_embed = self.in_feature_embed(a) # (1, bs, dim) side_output = self.side_net(a_embed) rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim] backbone_output = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (1, bs, dim) output.append(backbone_output) logits = torch.concat(output, dim=0) elif field == 'mail': shape = x.shape joint_train = False if len(shape) == 2: x_embed = self.in_feature_embed(x) side_output = self.side_net(x_embed) a_embed = x_embed.unsqueeze(0) rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim] rnn_output = rnn_output.squeeze(0) # (bs, dim) logits = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (bs, dim) else: assert len(shape) == 3, f"expect len(x.shape) == 3. However got x.shape {shape}" self.reset() output = [] for i in range(shape[0]): a = x[i] a = a.unsqueeze(0) #(1, bs, dim) a_embed = self.in_feature_embed(a) # (1, bs, dim) side_output = self.side_net(a_embed) rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim] backbone_output = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (1, bs, dim) output.append(backbone_output) logits = torch.concat(output, dim=0) else: raise NotImplementedError(f"unknow field: {field} in RNN training !") return self.dist_wrapper(logits, adapt_std, **kwargs)
[docs] class ContextualPolicy(torch.nn.Module): def __init__(self, in_features : int, out_features : int, hidden_features : int, hidden_layers : int, dist_config : list, backbone_type : Union[str, np.str_] ='contextual_gru', **kwargs): super().__init__() self.kwargs = kwargs RNN = torch.nn.GRU if backbone_type == 'contextual_gru' else torch.nn.LSTM self.preprocess_mlp = MLP(in_features, hidden_features, 0, 0, output_activation='leakyrelu') self.rnn = RNN(hidden_features, hidden_features, 1) self.mlp = MLP(hidden_features + in_features, out_features, hidden_features, hidden_layers) self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
[docs] def reset(self): self.h = None
[docs] def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution: in_feature = x if len(x.shape) == 2: x = x.unsqueeze(0) x = self.preprocess_mlp(x) rnn_output, self.h = self.rnn(x, self.h) rnn_output = rnn_output.squeeze(0) else: x = self.preprocess_mlp(x) rnn_output, self.h = self.rnn(x) logits = self.mlp(torch.cat((in_feature, rnn_output), dim=-1)) return self.dist_wrapper(logits, adapt_std, **kwargs)
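A minimal, self-contained sketch (plain torch.nn.GRU, hypothetical sizes, not the policy class itself) of the contract the two branches of forward above rely on: stepping the recurrent core one timestep at a time while carrying self.h gives the same outputs as a single full-sequence pass, provided reset() is called between trajectories.

def _example_gru_state_carry():
    import torch
    torch.manual_seed(0)
    gru = torch.nn.GRU(8, 16, 1)
    x = torch.randn(5, 2, 8)                    # (horizon, batch, dim)
    full_out, _ = gru(x)                        # full-sequence pass (3-D branch)
    h = None                                    # reset() clears the carried state
    step_out = []
    for t in range(x.shape[0]):                 # step-by-step pass (2-D branch)
        out, h = gru(x[t].unsqueeze(0), h)
        step_out.append(out)
    step_out = torch.cat(step_out, dim=0)
    print((full_out - step_out).abs().max())    # ~0 up to numerical noise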
[docs] class EnsembleFeedForwardPolicy(torch.nn.Module): """Ensemble policy using a vectorized MLP or ResNet backbone. Args: in_features : The number of input features, or a dictionary describing the input distribution out_features : The number of output features hidden_features : The number of hidden features in each layer hidden_layers : The number of hidden layers ensemble_size : The number of ensemble members dist_config : A list of configurations for the distributions to use in the model norm : The type of normalization to apply to the input features hidden_activation : The activation function to use in the hidden layers backbone_type : The type of backbone to use in the model ('EnsembleMLP' or 'EnsembleRES') use_multihead : Whether to use a multihead model use_feature_embed : Whether to use feature embedding """ def __init__(self, in_features : Union[int, dict], out_features : int, hidden_features : int, hidden_layers : int, ensemble_size : int, dist_config : list, norm : str = None, hidden_activation : str = 'leakyrelu', backbone_type : Union[str, np.str_] = 'EnsembleRES', use_multihead : bool = False, **kwargs): super().__init__() self.multihead = [] self.ensemble_size = ensemble_size self.dist_config = dist_config self.kwargs = kwargs if isinstance(in_features, dict): in_features = sum(in_features.values()) if backbone_type == 'EnsembleMLP': self.backbone = VectorizedMLP(in_features, out_features, hidden_features, hidden_layers, ensemble_size, norm, hidden_activation) elif backbone_type == 'EnsembleRES': self.backbone = VectorizedResNet(in_features, out_features, hidden_features, hidden_layers, ensemble_size, norm) else: raise NotImplementedError(f'backbone type {backbone_type} is not supported') self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
[docs] def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution: if isinstance(state, dict): assert 'input_names' in kwargs state = torch.cat([state[key] for key in kwargs['input_names']], dim=-1) output = self.backbone(state) dist = self.dist_wrapper(output, adapt_std, **kwargs) if hasattr(self, "dist_mu_shift"): dist = dist.shift(self.dist_mu_shift) return dist
[docs] def reset(self): pass
[docs] @torch.no_grad() def get_action(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], deterministic : bool = True): dist = self(state) return dist.mode if deterministic else dist.sample()
# ------------------------------- Transitions ------------------------------- #
[docs] class FeedForwardTransition(FeedForwardPolicy): r"""Initializes a feedforward transition instance. Args: in_features (Union[int, dict]): The number of input features or a dictionary of input feature sizes. out_features (int): The number of output features. hidden_features (int): The number of hidden features. hidden_layers (int): The number of hidden layers. dist_config (list): The configuration of the distribution to use. norm (Optional[str]): The normalization method to use. None if no normalization is used. hidden_activation (str): The activation function to use for the hidden layers. backbone_type (Union[str, np.str_]): The type of backbone to use. mode (str): The mode to use for the transition. Either 'global' or 'local'. obs_dim (Optional[int]): The dimension of the observation for 'local' mode. use_feature_embed (bool): Whether to use feature embedding. """ def __init__(self, in_features : Union[int, dict], out_features : int, hidden_features : int, hidden_layers : int, dist_config : list, norm : Optional[str] = None, hidden_activation : str = 'leakyrelu', backbone_type : Union[str, np.str_] = 'mlp', mode : str = 'global', obs_dim : Optional[int] = None, **kwargs): self.mode = mode self.obs_dim = obs_dim if self.mode == 'local': dist_types = [config['type'] for config in dist_config] if 'onehot' in dist_types or 'discrete_logistic' in dist_types: warnings.warn('Detected distribution types that are not compatible with local mode; falling back to global mode!') self.mode = 'global' if self.mode == 'local': assert self.obs_dim is not None, \ "For local mode, the dim of observation should be given!" super(FeedForwardTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers, dist_config, norm, hidden_activation, backbone_type, **kwargs)
[docs] def reset(self): pass
[docs] def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution: dist = super(FeedForwardTransition, self).forward(state, adapt_std, **kwargs) if self.mode == 'local' and self.obs_dim is not None: dist = dist.shift(state[..., :self.obs_dim]) return dist
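An illustrative sketch of the 'local' mode above, using plain tensors rather than the actual ReviveDistribution API: the backbone output is treated as a residual, and dist.shift(state[..., :obs_dim]) adds the current observation back, so the transition effectively predicts next_obs = obs + delta. All names and sizes below are hypothetical.

def _example_local_mode_shift():
    import torch
    obs_dim = 3
    state = torch.randn(8, obs_dim + 2)                 # last 2 dims stand in for action features
    delta_mean = 0.1 * torch.randn(8, obs_dim)          # what the predicted distribution is centered on
    next_obs_mean = delta_mean + state[..., :obs_dim]   # effect of dist.shift(state[..., :obs_dim])
    print(next_obs_mean.shape)                          # torch.Size([8, 3])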
[docs] class RecurrentTransition(RecurrentPolicy): r""" Initializes a recurrent transition instance. Args: in_features (int): The number of input features. out_features (int): The number of output features. hidden_features (int): The number of hidden features. hidden_layers (int): The number of hidden layers. dist_config (list): The distribution configuration. backbone_type (Union[str, np.str_]): The type of backbone to use. mode (str): The mode of the transition. Either 'global' or 'local'. obs_dim (Optional[int]): The dimension of the observation. """ def __init__(self, in_features : int, out_features : int, hidden_features : int, hidden_layers : int, norm : str, dist_config : list, backbone_type : Union[str, np.str_] ='gru', mode : str = 'global', obs_dim : Optional[int] = None, **kwargs): self.mode = mode self.obs_dim = obs_dim if self.mode == 'local': assert not 'onehot' in [config['type'] for config in dist_config], \ "The local mode of transition is not compatible with onehot data! Please fallback to global mode!" if self.mode == 'local': assert self.obs_dim is not None, \ "For local mode, the dim of observation should be given!" super(RecurrentTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers, norm, dist_config, backbone_type, **kwargs)
[docs] def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution: dist = super(RecurrentTransition, self).forward(state, adapt_std, **kwargs) if self.mode == 'local' and self.obs_dim is not None: dist = dist.shift(state[..., :self.obs_dim]) return dist
[docs] class TsRecurrentTransition(TsRecurrentPolicy): input_is_dict = True r""" Initializes a recurrent transition instance. Args: in_features (int): The number of input features. out_features (int): The number of output features. hidden_features (int): The number of hidden features. hidden_layers (int): The number of hidden layers. dist_config (list): The distribution configuration. backbone_type (Union[str, np.str_]): The type of backbone to use. mode (str): The mode of the transition. Either 'global' or 'local'. obs_dim (Optional[int]): The dimension of the observation. """ def __init__(self, in_features : int, out_features : int, hidden_features : int, hidden_layers : int, norm : str, dist_config : list, backbone_type : Union[str, np.str_] ='gru', mode : str = 'global', obs_dim : Optional[int] = None, **kwargs): self.mode = mode self.obs_dim = obs_dim if self.mode == 'local': assert not 'onehot' in [config['type'] for config in dist_config], \ "The local mode of transition is not compatible with onehot data! Please fallback to global mode!" if self.mode == 'local': assert self.obs_dim is not None, \ "For local mode, the dim of observation should be given!" super(TsRecurrentTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers, norm, dist_config, backbone_type, **kwargs)
[docs] def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution: dist = super(TsRecurrentTransition, self).forward(state, adapt_std, **kwargs) if self.mode == 'local' and self.obs_dim is not None: dist = dist.shift(state[..., :self.obs_dim]) return dist
[docs] class TsTransformerTransition(TsTransformerPolicy): input_is_dict = True r""" Initializes a transformer transition instance. Args: in_features (int): The number of input features. out_features (int): The number of output features. hidden_features (int): The number of hidden features. hidden_layers (int): The number of hidden layers. dist_config (list): The distribution configuration. ts_input_names (dict): The names of the time-series inputs. dropout (float): The dropout rate. mode (str): The mode of the transition. Either 'global' or 'local'. obs_dim (Optional[int]): The dimension of the observation. """ def __init__(self, in_features : int, out_features : int, hidden_features : int, hidden_layers : int, norm : str, dist_config : list, mode : str = 'global', obs_dim : Optional[int] = None, ts_input_names : dict = {}, dropout: float = 0, **kwargs): self.mode = mode self.obs_dim = obs_dim if self.mode == 'local': assert not 'onehot' in [config['type'] for config in dist_config], \ "The local mode of transition is not compatible with onehot data! Please fallback to global mode!" if self.mode == 'local': assert self.obs_dim is not None, \ "For local mode, the dim of observation should be given!" super(TsTransformerTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers, norm, dist_config, ts_input_names, dropout, **kwargs)
[docs] def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution: dist = super(TsTransformerTransition, self).forward(state, adapt_std, **kwargs) if self.mode == 'local' and self.obs_dim is not None: dist = dist.shift(state[..., :self.obs_dim]) return dist
[docs] class RecurrentRESTransition(RecurrentRESPolicy): def __init__(self, in_features : int, out_features : int, hidden_features : int, hidden_layers : int, dist_config : list, backbone_type : Union[str, np.str_] = 'mlp', mode : str = 'global', obs_dim : Optional[int] = None, rnn_hidden_features : int = 64, window_size : int = 0, **kwargs): self.mode = mode self.obs_dim = obs_dim if self.mode == 'local': assert not 'onehot' in [config['type'] for config in dist_config], \ "The local mode of transition is not compatible with onehot data! Please fallback to global mode!" if self.mode == 'local': assert self.obs_dim is not None, \ "For local mode, the dim of observation should be given!" super(RecurrentRESTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers, dist_config, backbone_type, rnn_hidden_features=rnn_hidden_features, window_size=window_size, **kwargs)
[docs] def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, field: str = 'mail', **kwargs) -> ReviveDistribution: dist = super(RecurrentRESTransition, self).forward(state, adapt_std, field, **kwargs) if self.mode == 'local' and self.obs_dim is not None: dist = dist.shift(state[..., :self.obs_dim]) return dist
[docs] class EnsembleFeedForwardTransition(EnsembleFeedForwardPolicy): r"""Initializes an ensemble feedforward transition instance. Args: in_features (Union[int, dict]): The number of input features or a dictionary of input feature sizes. out_features (int): The number of output features. hidden_features (int): The number of hidden features. hidden_layers (int): The number of hidden layers. ensemble_size (int): The number of ensemble members. dist_config (list): The configuration of the distribution to use. norm (Optional[str]): The normalization method to use. None if no normalization is used. hidden_activation (str): The activation function to use for the hidden layers. backbone_type (Union[str, np.str_]): The type of backbone to use. mode (str): The mode to use for the transition. Either 'global' or 'local'. obs_dim (Optional[int]): The dimension of the observation for 'local' mode. use_feature_embed (bool): Whether to use feature embedding. """ def __init__(self, in_features : Union[int, dict], out_features : int, hidden_features : int, hidden_layers : int, ensemble_size : int, dist_config : list, norm : Optional[str] = None, hidden_activation : str = 'leakyrelu', backbone_type : Union[str, np.str_] = 'EnsembleRES', mode : str = 'global', obs_dim : Optional[int] = None, **kwargs): self.mode = mode self.obs_dim = obs_dim if self.mode == 'local': dist_types = [config['type'] for config in dist_config] if 'onehot' in dist_types or 'discrete_logistic' in dist_types: warnings.warn('Detected distribution types that are not compatible with local mode; falling back to global mode!') self.mode = 'global' if self.mode == 'local': assert self.obs_dim is not None, \ "For local mode, the dim of observation should be given!" super(EnsembleFeedForwardTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers, ensemble_size, dist_config, norm, hidden_activation, backbone_type, **kwargs)
[docs] def reset(self): pass
[docs] def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution: dist = super(EnsembleFeedForwardTransition, self).forward(state, adapt_std, **kwargs) if self.mode == 'local' and self.obs_dim is not None: dist = dist.shift(state[..., :self.obs_dim]) return dist
# ------------------------------- Matchers ------------------------------- #
[docs] class FeedForwardMatcher(torch.nn.Module): r""" Initializes a feedforward matcher instance. Args: in_features (int): The number of input features. hidden_features (int): The number of hidden features. hidden_layers (int): The number of hidden layers. hidden_activation (str): The activation function to use for the hidden layers. norm (str): The normalization method to use. None if no normalization is used. backbone_type (Union[str, np.str_]): The type of backbone to use. """ def __init__(self, in_features : int, hidden_features : int, hidden_layers : int, hidden_activation : str = 'leakyrelu', norm : str = None, backbone_type : Union[str, np.str_] = 'mlp', **kwargs): super().__init__() if backbone_type == 'mlp': self.backbone = MLP(in_features, 1, hidden_features, hidden_layers, norm, hidden_activation, output_activation='sigmoid') elif backbone_type == 'res': self.backbone = ResNet(in_features, 1, hidden_features, hidden_layers, norm, output_activation='sigmoid') elif backbone_type == 'ft_transformer': self.backbone = FT_Transformer(in_features, 1, hidden_features, hidden_layers=hidden_layers) else: raise NotImplementedError(f'backbone type {backbone_type} is not supported')
[docs] def forward(self, x : torch.Tensor) -> torch.Tensor: # x.requires_grad = True shape = x.shape if len(shape) == 3: # [horizon, batch_size, dim] x = x.view(-1, shape[-1]) # [horizon * batch_size, dim] assert len(x.shape) == 2, f"len(x.shape) == {len(x.shape)}, expect 2" out = self.backbone(x) # [batch_size, dim] if len(shape) == 3: out = out.view(*shape[:-1], -1) # [horizon, batch_size, dim] return out
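A short usage sketch of the matcher above, assuming this module is importable as revive.computation.modules; the feature sizes are hypothetical. A 3-D input is flattened to 2-D internally, scored by the sigmoid-headed backbone, and reshaped back.

def _example_feedforward_matcher():
    import torch
    from revive.computation.modules import FeedForwardMatcher
    matcher = FeedForwardMatcher(in_features=8, hidden_features=64, hidden_layers=2)
    x = torch.randn(10, 32, 8)        # (horizon, batch_size, dim)
    scores = matcher(x)               # (horizon, batch_size, 1), values in (0, 1)
    print(scores.shape, float(scores.min()), float(scores.max()))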
[docs] class RecurrentMatcher(torch.nn.Module): def __init__(self, in_features : int, hidden_features : int, hidden_layers : int, backbone_type : Union[str, np.str_] = 'gru', bidirect : bool = False): super().__init__() RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM self.rnn = RNN(in_features, hidden_features, hidden_layers, bidirectional=bidirect) self.output_layer = MLP(hidden_features * (2 if bidirect else 1), 1, 0, 0, output_activation='sigmoid')
[docs] def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor: x = torch.cat(inputs, dim=-1) rnn_output = self.rnn(x)[0] return self.output_layer(rnn_output)
[docs] class HierarchicalMatcher(torch.nn.Module): def __init__(self, in_features : list, hidden_features : int, hidden_layers : int, hidden_activation : str, norm : str = None): super().__init__() self.in_features = in_features process_layers = [] output_layers = [] feature = self.in_features[0] for i in range(1, len(self.in_features)): feature += self.in_features[i] process_layers.append(MLP(feature, hidden_features, hidden_features, hidden_layers, norm, hidden_activation, hidden_activation)) output_layers.append(torch.nn.Linear(hidden_features, 1)) feature = hidden_features self.process_layers = torch.nn.ModuleList(process_layers) self.output_layers = torch.nn.ModuleList(output_layers)
[docs] def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor: assert len(inputs) == len(self.in_features) last_feature = inputs[0] result = 0 for i, process_layer, output_layer in zip(range(1, len(self.in_features)), self.process_layers, self.output_layers): last_feature = torch.cat([last_feature, inputs[i]], dim=-1) last_feature = process_layer(last_feature) result += output_layer(last_feature) return torch.sigmoid(result)
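A usage sketch under the constructor defined above (hypothetical feature sizes): each successive input is concatenated with the processed feature of the previous stage, every stage contributes a logit, and the summed logits pass through a sigmoid.

def _example_hierarchical_matcher():
    import torch
    from revive.computation.modules import HierarchicalMatcher
    matcher = HierarchicalMatcher([4, 3, 2], hidden_features=16, hidden_layers=1,
                                  hidden_activation='leakyrelu')
    a = torch.randn(32, 4)    # stage-0 features
    b = torch.randn(32, 3)    # joined with the processed features at stage 1
    c = torch.randn(32, 2)    # joined with the processed features at stage 2
    score = matcher(a, b, c)  # (32, 1), sigmoid over the summed per-stage logits
    print(score.shape)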
[docs] class VectorizedMatcher(torch.nn.Module): r""" Initializes a feedforward matcher instance. Args: in_features (int): The number of input features. hidden_features (int): The number of hidden features. hidden_layers (int): The number of hidden layers. hidden_activation (str): The activation function to use for the hidden layers. norm (str): The normalization method to use. None if no normalization is used. backbone_type (Union[str, np.str_]): The type of backbone to use. """ def __init__(self, in_features : int, hidden_features : int, hidden_layers : int, ensemble_size : int, hidden_activation : str = 'leakyrelu', norm : str = None, backbone_type : Union[str, np.str_] = 'mlp', **kwargs): super().__init__() self.backbone_type = backbone_type self.ensemble_size = ensemble_size if backbone_type == 'mlp': self.backbone = VectorizedMLP(in_features, 1, hidden_features, hidden_layers, ensemble_size, norm, hidden_activation, output_activation='sigmoid') elif backbone_type == 'res': self.backbone = VectorizedResNet(in_features, 1, hidden_features, hidden_layers, ensemble_size, norm, output_activation='sigmoid') else: raise NotImplementedError(f'backbone type {backbone_type} is not supported')
[docs] def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor: x = torch.cat(inputs, dim=-1).detach() # x.requires_grad = True shape = x.shape if len(shape) == 2: # [batch_size, dim] x = x.unsqueeze(0).repeat_interleave(self.ensemble_size, dim=0) # [batch_size, dim] --> [ensemble_size, batch_size, dim] assert x.shape[0] == self.ensemble_size out = self.backbone(x) # [ensemble_size, batch_size, dim] return out
[docs] class EnsembleMatcher: def __init__(self, matcher_record: Dict[str, deque], ensemble_size : int = 50, ensemble_choosing_interval: int = 1, config : dict = None) -> None: self.matcher_weights_list = [] self.single_structure_dict = config self.matching_nodes = config['matching_nodes'] self.matching_fit_nodes = config['matching_fit_nodes'] # self.matching_nodes_fit_dict = config['matching_nodes_fit_index'] self.device = torch.device("cpu") self.min_generate_score = 0.0 self.max_expert_score = 1.0 self.selected_ind = [] counter = 0 for i in reversed(range(len(matcher_record['state_dicts']))): if counter % ensemble_choosing_interval == 0: matcher_weight = matcher_record['state_dicts'][i] self.selected_ind.append(i) self.matcher_weights_list.append(matcher_weight) if len(self.matcher_weights_list) >= ensemble_size: break counter += 1 self.ensemble_size = len(self.matcher_weights_list) self.matcher_ensemble = VectorizedMatcher(ensemble_size=self.ensemble_size, **config) self.init_ensemble_network() self.load_min_max(matcher_record['expert_scores'], matcher_record['generated_scores'])
[docs] def load_min_max(self, expert_scores: deque, generate_scores: deque): self.expert_scores = torch.tensor(expert_scores, dtype=torch.float32, device=self.device)[self.selected_ind] # [ensemble_size, ] self.generate_scores = torch.tensor(generate_scores, dtype=torch.float32, device=self.device)[self.selected_ind] # [ensemble_size, ] self.max_expert_score = max(self.expert_scores) self.min_generate_score = min(self.generate_scores) self.range = max(abs(self.max_expert_score), abs(self.min_generate_score))
[docs] def init_ensemble_network(self): if self.matcher_ensemble.backbone_type == 'mlp': with torch.no_grad(): for ind, state_dict in enumerate(self.matcher_weights_list): for key, val in state_dict.items(): layer_ind = int(key.split('.')[2]) if key.endswith("weight"): self.matcher_ensemble.backbone.net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0)) elif key.endswith("bias"): self.matcher_ensemble.backbone.net[layer_ind].bias.data[ind, 0, ...].copy_(val) else: raise NotImplementedError("Only expect network params with weight and bias") elif self.matcher_ensemble.backbone_type == 'res': with torch.no_grad(): for ind, state_dict in enumerate(self.matcher_weights_list): for key, val in state_dict.items(): key_ls = key.split('.') if "skip_net" in key_ls: block_ind = int(key_ls[2]) if key.endswith("weight"): self.matcher_ensemble.backbone.resnet[block_ind].skip_net.weight.data[ind, ...].copy_(val.transpose(1, 0)) elif key.endswith("bias"): self.matcher_ensemble.backbone.resnet[block_ind].skip_net.bias.data[ind, 0, ...].copy_(val) else: raise NotImplementedError("Only expect network params with weight and bias") elif "process_net" in key_ls: block_ind, layer_ind = int(key_ls[2]), int(key_ls[4]) if key.endswith("weight"): self.matcher_ensemble.backbone.resnet[block_ind].process_net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0)) elif key.endswith("bias"): self.matcher_ensemble.backbone.resnet[block_ind].process_net[layer_ind].bias.data[ind, 0, ...].copy_(val) else: raise NotImplementedError("Only expect network params with weight and bias") else: # final linear layer block_ind = int(key_ls[2]) if key.endswith("weight"): self.matcher_ensemble.backbone.resnet[block_ind].weight.data[ind, ...].copy_(val.transpose(1, 0)) elif key.endswith("bias"): self.matcher_ensemble.backbone.resnet[block_ind].bias.data[ind, 0, ...].copy_(val) else: raise NotImplementedError("Only expect network params with weight and bias") else: raise NotImplementedError
[docs] def run_raw_scores(self, inputs: torch.Tensor, aggregation=None) -> torch.Tensor: with torch.no_grad(): scores = self.matcher_ensemble(inputs) # [ensemble_size, batch_size, 1], [ensemble_size, horizon, batch_size, 1] scores = scores.squeeze(-1).transpose(1, 0).detach() # [batch_size, ensemble_size] if aggregation == "mean": return scores.mean(-1) return scores
[docs] def run_scores(self, inputs: torch.Tensor, aggregation: str=None, clip: bool=True) -> torch.Tensor: with torch.no_grad(): scores = self.matcher_ensemble(inputs) # [ensemble_size, batch_size, 1] scores = scores.squeeze(-1).transpose(1, 0).detach() # [batch_size, ensemble_size] if clip: assert self.expert_scores is not None and self.generate_scores is not None, "if clip==True, you should call load_min_max() beforehand" scores = torch.clamp(scores, min=self.generate_scores, max=self.expert_scores) scores = (scores - self.generate_scores) / (self.expert_scores - self.generate_scores + 1e-7) * self.range if aggregation == "mean": return scores.mean(-1) return scores
[docs] def to(self, device): self.device = device self.matcher_ensemble = self.matcher_ensemble.to(device) self.expert_scores = self.expert_scores.to(device) self.generate_scores = self.generate_scores.to(device)
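A plain-tensor sketch of the clamp-and-rescale step inside run_scores(): per-member scores are clipped to the recorded [generated, expert] interval and mapped onto [0, range]. Constructing a real EnsembleMatcher needs a matcher_record and config, so the numbers below are made up.

def _example_score_normalization():
    import torch
    scores = torch.tensor([[0.2, 0.9], [0.6, 0.4]])   # (batch_size, ensemble_size)
    generate_scores = torch.tensor([0.3, 0.1])        # per-member generated-data score
    expert_scores = torch.tensor([0.8, 0.7])          # per-member expert-data score
    rng = max(abs(max(expert_scores)), abs(min(generate_scores)))
    out = torch.clamp(scores, min=generate_scores, max=expert_scores)
    out = (out - generate_scores) / (expert_scores - generate_scores + 1e-7) * rng
    print(out)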
# class SequentialMatcher: # def __init__(self, # matcher_record: Dict[str, deque], # ensemble_size : int = 50, # ensemble_choosing_interval: int = 1, # config : dict = None) -> None: # self.matcher_list = [] # self.single_structure_dict = config # self.matching_nodes = config['matching_nodes'] # self.matching_fit_nodes = config['matching_fit_nodes'] # # self.matching_nodes_fit_dict = config['matching_nodes_fit_index'] # self.device = torch.device("cpu") # self.min_generate_score = 0.0 # self.max_expert_score = 1.0 # self.selected_ind = [] # counter = 0 # for i in reversed(range(len(matcher_record['state_dicts']))): # if counter % ensemble_choosing_interval == 0: # matcher_weight = matcher_record['state_dicts'][i] # self.selected_ind.append(i) # _matcher = FeedForwardMatcher(**config) # _matcher.load_state_dict(matcher_weight) # self.matcher_list.append(_matcher) # if len(self.matcher_list) >= ensemble_size: # break # counter += 1 # self.ensemble_size = len(self.matcher_list) # self.load_min_max(matcher_record['expert_scores'], matcher_record['generated_scores']) # # prepare for vmap func # self.params, self.buffers = stack_module_state(self.matcher_list) # """ # Construct a "stateless" version of one of the models. It is "stateless" in # the sense that the parameters are meta Tensors and do not have storage. # """ # self.base_model = deepcopy(self.matcher_list[0]) # self.base_model = self.base_model.to('meta') # def vmap_forward(self, x): # def vmap_model(params, buffers, x): # return functional_call(self.base_model, (params, buffers), (x,)) # return vmap(vmap_model, in_dims=(0, 0, None))(self.params, self.buffers, x) # def load_min_max(self, expert_scores: deque, generate_scores: deque): # self.expert_scores = torch.tensor(expert_scores, dtype=torch.float32, device=self.device)[self.selected_ind] # [ensemble_size, ] # self.generate_scores = torch.tensor(generate_scores, dtype=torch.float32, device=self.device)[self.selected_ind] # [ensemble_size, ] # self.max_expert_score = max(self.expert_scores) # self.min_generate_score = min(self.generate_scores) # self.range = max(abs(self.max_expert_score), abs(self.min_generate_score)) # def run_raw_scores(self, inputs: torch.Tensor, aggregation=None) -> torch.Tensor: # with torch.no_grad(): # scores = self.vmap_forward(inputs) # [ensemble_size, batch_size, 1], [ensemble_size, horizon, batch_size, 1] # scores = scores.squeeze(-1).transpose(1, 0).detach() # [batch_size, ensemble_size] # if aggregation == "mean": # return scores.mean(-1) # return scores # def run_scores(self, inputs: torch.Tensor, aggregation: str=None, clip: bool=True) -> torch.Tensor: # with torch.no_grad(): # scores = self.vmap_forward(inputs) # [ensemble_size, batch_size, 1] # scores = scores.squeeze(-1).transpose(1, 0).detach() # [batch_size, ensemble_size] # if clip: # assert self.expert_scores is not None and self.generate_scores is not None, "if clip==True, you should call load_min_max() beforehand" # scores = torch.clamp(scores, min=self.generate_scores, max=self.expert_scores) # scores = (scores - self.generate_scores) / (self.expert_scores - self.generate_scores + 1e-7) * self.range # if aggregation == "mean": # return scores.mean(-1) # return scores # def to(self, device): # self.device = device # self.matcher_list = [_matcher.to(device) for _matcher in self.matcher_list] # self.expert_scores = self.expert_scores.to(device) # self.generate_scores = self.generate_scores.to(device) # # prepare for vmap func # self.params, self.buffers = 
#     stack_module_state(self.matcher_list)

# ------------------------------- Others ------------------------------- #
[docs] class VectorizedCritic(VectorizedMLP):
[docs] def forward(self, x : torch.Tensor) -> torch.Tensor: x = super(VectorizedCritic, self).forward(x) return x.squeeze(-1)
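A usage sketch for the vectorized critic, assuming the import path below and hypothetical sizes: the input carries an explicit ensemble dimension, and the trailing singleton output dimension is squeezed away.

def _example_vectorized_critic():
    import torch
    from revive.computation.modules import VectorizedCritic
    critic = VectorizedCritic(5, 1, hidden_features=16, hidden_layers=2, ensemble_size=3)
    q_in = torch.randn(3, 32, 5)      # (ensemble_size, batch_size, state-action dim)
    q_values = critic(q_in)           # (3, 32)
    print(q_values.shape)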
# ------------------------------- Test ------------------------------- # if __name__ == "__main__": def test_ensemble_mlp(): rnd = torch.rand((6, 5)) ensemble_size = 3 ensemble_linear = VectorizedMLP(5, 4, hidden_features=16, hidden_layers=2, ensemble_size=ensemble_size) linear1 = MLP(5, 4, hidden_features=16, hidden_layers=2) linear2 = MLP(5, 4, hidden_features=16, hidden_layers=2) linear3 = MLP(5, 4, hidden_features=16, hidden_layers=2) state_dict = deque([linear1.state_dict(), linear2.state_dict(), linear3.state_dict()]) for ind, state_dict in enumerate(state_dict): for key, val in state_dict.items(): layer_ind = int(key.split('.')[1]) if key.endswith("weight"): ensemble_linear.net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0)) elif key.endswith("bias"): ensemble_linear.net[layer_ind].bias.data[ind, 0, ...].copy_(val) else: raise NotImplementedError("Only except network params with weight and bias") ensemble_out = ensemble_linear(rnd) linear1_out = linear1(rnd) linear2_out = linear2(rnd) linear3_out = linear3(rnd) # breakpoint() print((ensemble_out[0, ...] - linear1_out).abs().max()) print((ensemble_out[1, ...] - linear2_out).abs().max()) print((ensemble_out[2, ...] - linear3_out).abs().max()) def test_ensemble_res_distinct(): in_fea = 256 out_fea = 15 hidden_features=256 hidden_layers=4 batch_size = 512 ensemble_size = 3 rnd1 = torch.rand((batch_size, in_fea)) rnd2 = torch.rand((batch_size, in_fea)) rnd3 = torch.rand((batch_size, in_fea)) rnd = torch.stack([rnd1, rnd2, rnd3]) ensemble_linear = VectorizedResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, ensemble_size=ensemble_size, norm=None) linear1 = ResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, norm=None) linear2 = ResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, norm=None) linear3 = ResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, norm=None) state_dicts = deque([linear1.state_dict(), linear2.state_dict(), linear3.state_dict()]) for ind, state_dict in enumerate(state_dicts): for key, val in state_dict.items(): key_ls = key.split('.') if "skip_net" in key_ls: block_ind = int(key_ls[1]) if key.endswith("weight"): ensemble_linear.resnet[block_ind].skip_net.weight.data[ind, ...].copy_(val.transpose(1, 0)) elif key.endswith("bias"): ensemble_linear.resnet[block_ind].skip_net.bias.data[ind, 0, ...].copy_(val) else: raise NotImplementedError("Only except network params with weight and bias") elif "process_net" in key_ls: block_ind, layer_ind = int(key_ls[1]), int(key_ls[3]) if key.endswith("weight"): ensemble_linear.resnet[block_ind].process_net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0)) elif key.endswith("bias"): ensemble_linear.resnet[block_ind].process_net[layer_ind].bias.data[ind, 0, ...].copy_(val) else: raise NotImplementedError("Only except network params with weight and bias") else: # linear layer at last block_ind = int(key_ls[1]) if key.endswith("weight"): ensemble_linear.resnet[block_ind].weight.data[ind, ...].copy_(val.transpose(1, 0)) elif key.endswith("bias"): ensemble_linear.resnet[block_ind].bias.data[ind, 0, ...].copy_(val) else: raise NotImplementedError("Only except network params with weight and bias") ensemble_out = ensemble_linear(rnd) linear1_out = linear1(rnd1) linear2_out = linear2(rnd2) linear3_out = linear3(rnd3) # breakpoint() print((ensemble_out[0, ...] - linear1_out).abs().max()) print((ensemble_out[1, ...] 
- linear2_out).abs().max()) print((ensemble_out[2, ...] - linear3_out).abs().max())
[docs] class Value_Net_VectorizedCritic(nn.Module): r""" Initializes a vectorized Q-value critic instance. Args: input_dim_dict (Union[int, dict]): The input dimension, or a dictionary mapping input names to their sizes. q_hidden_features (int): The number of hidden features. q_hidden_layers (int): The number of hidden layers. num_q_net (int): The number of Q networks in the ensemble. """ def __init__(self, input_dim_dict, q_hidden_features, q_hidden_layers, num_q_net, *args, **kwargs) -> None: super().__init__() self.kwargs = kwargs self.input_dim_dict = input_dim_dict if isinstance(input_dim_dict, dict): input_dim = sum(input_dim_dict.values()) else: input_dim = input_dim_dict self.value = VectorizedMLP(input_dim, 1, q_hidden_features, q_hidden_layers, num_q_net) if 'ts_conv_config' in self.kwargs.keys() and self.kwargs['ts_conv_config'] != None: other_state_layer_output = q_hidden_features//2 ts_state_conv_layer_output = q_hidden_features//2 + q_hidden_features%2 self.other_state_layer = MLP(self.kwargs['ts_conv_net_config']['other_net_input'], other_state_layer_output, 0, 0) all_node_ts = None self.conv_ts_node = [] self.conv_other_node = deepcopy(self.kwargs['ts_conv_config']['no_ts_input']) for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys(): self.conv_ts_node.append('__history__'+tsnode) self.conv_other_node.append('__now__'+tsnode) node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts'] if all_node_ts == None: all_node_ts = node_ts assert all_node_ts==node_ts, f"expect ts step == {all_node_ts}. However got {tsnode} with ts step {node_ts}" # ts-1: remove the current ('now') step kernalsize = all_node_ts-1 self.ts_state_conv_layer = ConvBlock(self.kwargs['ts_conv_net_config']['conv_net_input'], ts_state_conv_layer_output, kernalsize)
[docs] def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor: if hasattr(self, 'conv_ts_node'): # deal with all ts nodes along the ts dim assert hasattr(self, 'conv_ts_node') assert hasattr(self, 'conv_other_node') inputdata = deepcopy(state.detach()) # data is detached from the other nodes for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys(): assert self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['endpoint'] == True batch_size = np.prod(list(inputdata[tsnode].shape[:-1])) original_size = list(inputdata[tsnode].shape[:-1]) node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts'] # inputdata[tsnode] temp_data = inputdata[tsnode].reshape(batch_size, node_ts, -1) inputdata['__history__'+tsnode] = temp_data[..., :-1, :] inputdata['__now__'+tsnode] = temp_data[..., -1, :].reshape([*original_size,-1]) state_other = torch.cat([inputdata[key] for key in self.conv_other_node], dim=-1) state_ts_conv = torch.cat([inputdata[key] for key in self.conv_ts_node], dim=-1) state_other = self.other_state_layer(state_other) original_size = list(state_other.shape[:-1]) state_ts_conv = self.ts_state_conv_layer(state_ts_conv).reshape([*original_size,-1]) state = torch.cat([state_other, state_ts_conv], dim=-1) else: pass # state = torch.cat([state[key].detach() for key in self.input_dim_dict.keys()], dim=-1) output = self.value(state) output = output.squeeze(-1) return output
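A plain-tensor sketch of the reshape performed in forward() above for a time-series node (sizes are hypothetical): the flat (batch, ts * feat) vector is viewed as (batch, ts, feat), the first ts-1 steps become the '__history__' input for the conv branch, and the last step becomes the '__now__' input.

def _example_history_now_split():
    import torch
    batch, ts, feat = 32, 4, 6
    node = torch.randn(batch, ts * feat)
    tmp = node.reshape(batch, ts, feat)
    history = tmp[..., :-1, :]                  # (batch, ts - 1, feat)
    now = tmp[..., -1, :].reshape(batch, -1)    # (batch, feat)
    print(history.shape, now.shape)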