"""
    POLIXIR REVIVE, copyright (C) 2021-2025 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.
    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""
from loguru import logger
import time
import torch
import math
import warnings
import torch.nn.functional as F
from copy import deepcopy
from torch import nn#, vmap
# from torch.func import stack_module_state, functional_call
from typing import Optional, Union, List, Dict, cast
from collections import OrderedDict, deque
from revive.computation.dists import *
from revive.computation.utils import *
def reglu(x):
    a, b = x.chunk(2, dim=-1)
    return a * F.relu(b) 
def geglu(x):
    a, b = x.chunk(2, dim=-1)
    return a * F.gelu(b) 
class Swish(nn.Module):
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x) 
 
ACTIVATION_CREATORS = {
    'relu' : lambda dim: nn.ReLU(inplace=True),
    'elu' : lambda dim: nn.ELU(),
    'leakyrelu' : lambda dim: nn.LeakyReLU(negative_slope=0.1, inplace=True),
    'tanh' : lambda dim: nn.Tanh(),
    'sigmoid' : lambda dim: nn.Sigmoid(),
    'identity' : lambda dim: nn.Identity(),
    'prelu' : lambda dim: nn.PReLU(dim),
    'gelu' : lambda dim: nn.GELU(),
    'reglu' : reglu,
    'geglu' : geglu,
    'swish' : lambda dim: Swish(),
}
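# Usage sketch (illustrative only, not executed): each creator is called with the
# layer width and returns an nn.Module, e.g.
#     act = ACTIVATION_CREATORS['prelu'](64)   # PReLU with 64 channels
#     act = ACTIVATION_CREATORS['relu'](64)    # the width argument is ignored here
# Note that 'reglu' and 'geglu' map to the plain functions above rather than
# creator lambdas, so they cannot be instantiated through the same call pattern.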
# -------------------------------- Layers -------------------------------- #
class Value_Net(nn.Module):
    r"""
        Initializes a vectorized linear layer instance.
        Args:
            in_features (int): The number of input features.
            out_features (int): The number of output features.
            ensemble_size (int): The number of ensembles to use.
    """
    def __init__(self, 
                 input_dim_dict,
                 value_hidden_features,
                 value_hidden_layers,
                 value_normalization,
                 value_activation,
                 *args, **kwargs) -> None:
        super().__init__()
        self.kwargs = kwargs
        self.input_dim_dict = input_dim_dict
        if isinstance(input_dim_dict, dict):
            input_dim = sum(input_dim_dict.values())
        else:
            input_dim = input_dim_dict
        self.value = MLP(input_dim, 1, value_hidden_features, value_hidden_layers,
                            norm=value_normalization, hidden_activation=value_activation)
        if 'ts_conv_config' in self.kwargs.keys() and self.kwargs['ts_conv_config'] is not None:
            other_state_layer_output = input_dim // 2
            ts_state_conv_layer_output = input_dim // 2 + input_dim % 2
            self.other_state_layer = MLP(self.kwargs['ts_conv_net_config']['other_net_input'], other_state_layer_output, 0, 0, output_activation=value_activation)
            all_node_ts = None
            self.conv_ts_node = []
            self.conv_other_node = deepcopy(self.kwargs['ts_conv_config']['no_ts_input'])
            for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
                self.conv_ts_node.append('__history__' + tsnode)
                self.conv_other_node.append('__now__' + tsnode)
                node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
                if all_node_ts is None:
                    all_node_ts = node_ts
                assert all_node_ts == node_ts, f"expect ts step == {all_node_ts}. However got {tsnode} with ts step {node_ts}"

            # the conv kernel spans ts-1 steps: the current ('now') step is handled by the other-state branch
            kernel_size = all_node_ts - 1
            self.ts_state_conv_layer = ConvBlock(self.kwargs['ts_conv_net_config']['conv_net_input'], ts_state_conv_layer_output, kernel_size, output_activation=value_activation)
    def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
        if hasattr(self, 'conv_ts_node'):
            # handle all nodes that carry a time-step dimension
            assert hasattr(self, 'conv_ts_node')
            assert hasattr(self, 'conv_other_node')
            inputdata = deepcopy(state.detach())  # detach so no gradients flow back through the other nodes
            for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
                # assert self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['endpoint'] == True
                batch_size = np.prod(list(inputdata[tsnode].shape[:-1]))
                original_size = list(inputdata[tsnode].shape[:-1])
                node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
                # inputdata[tsnode]
                temp_data = inputdata[tsnode].reshape(batch_size, node_ts, -1)
                inputdata['__history__'+tsnode] = temp_data[..., :-1, :]
                inputdata['__now__'+tsnode] = temp_data[..., -1, :].reshape([*original_size,-1])      
                      
            state_other = torch.cat([inputdata[key] for key in self.conv_other_node], dim=-1)            
            state_ts_conv = torch.cat([inputdata[key] for key in self.conv_ts_node], dim=-1)
            
            state_other = self.other_state_layer(state_other)
            original_size = list(state_other.shape[:-1])
            state_ts_conv = self.ts_state_conv_layer(state_ts_conv).reshape([*original_size,-1])
            state = torch.cat([state_other, state_ts_conv], dim=-1)
        else:
            state = torch.cat([state[key].detach() for key in self.input_dim_dict.keys()], dim=-1)
        output = self.value(state)
        return output 
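
# Usage sketch (illustrative only; assumes no ts_conv_config is passed): the value
# head concatenates the dict inputs along the last dim and returns one value per sample.
#     value_net = Value_Net({'obs': 4, 'action': 2}, value_hidden_features=64,
#                           value_hidden_layers=2, value_normalization=None,
#                           value_activation='leakyrelu')
#     v = value_net({'obs': torch.randn(32, 4), 'action': torch.randn(32, 2)})
#     # v.shape == (32, 1)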
 
class ConvBlock(nn.Module):
    r"""
        Initializes a vectorized linear layer instance.
        Args:
            in_features (int): The number of input features.
            out_features (int): The number of output features.
            ensemble_size (int): The number of ensembles to use.
    """
    def __init__(self, 
                 in_features: int, 
                 out_features: int,
                 conv_time_step: int,
                 output_activation : str = 'identity',
                 ):
        
        
        super().__init__()
        output_activation_creator = ACTIVATION_CREATORS[output_activation]
        # self.in_features = in_features
        # self.out_features = out_features
        self.mlp_net = nn.Sequential(
                    nn.Linear(in_features, out_features),
                    output_activation_creator(out_features)
                )
        self.conv_net = nn.Conv1d(in_channels=out_features, out_channels=out_features, groups=out_features, kernel_size=conv_time_step)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mlp_net(x)
        # swap the time-step dim with the feature dim: (..., time, features) -> (..., features, time)
        x = x.transpose(-1, -2)
        # depthwise conv over the time dimension; squeeze drops the resulting length-1 time dim
        out = self.conv_net(x).squeeze(-1)
        return out 
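
# Shape sketch (illustrative only): with kernel size equal to the number of time
# steps, the depthwise Conv1d collapses the time dimension entirely.
#     block = ConvBlock(in_features=6, out_features=8, conv_time_step=5)
#     y = block(torch.randn(32, 5, 6))   # (batch, time, features)
#     # y.shape == (32, 8)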
 
class VectorizedLinear(nn.Module):
    r"""
        Initializes a vectorized linear layer instance.
        Args:
            in_features (int): The number of input features.
            out_features (int): The number of output features.
            ensemble_size (int): The number of ensembles to use.
    """
    def __init__(self, in_features: int, out_features: int, ensemble_size: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size
        self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features))
        self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))
        self.reset_parameters()
    def reset_parameters(self):
        # default pytorch init for nn.Linear module
        for layer in range(self.ensemble_size):
            nn.init.kaiming_uniform_(self.weight[layer], a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.bias, -bound, bound) 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        input: [ensemble_size, batch_size, input_size]
        weight: [ensemble_size, input_size, out_size]
        out: [ensemble_size, batch_size, out_size]
        """
        # out = torch.einsum("kij,kjl->kil", x, self.weight) + self.bias
        out = x @ self.weight + self.bias
        # out = torch.bmm(x, self.weight) + self.bias
        return out 
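
# Shape sketch (illustrative only): a single batched matmul applies every ensemble
# member to its own batch slice.
#     vec_lin = VectorizedLinear(in_features=16, out_features=4, ensemble_size=7)
#     y = vec_lin(torch.randn(7, 32, 16))
#     # y.shape == (7, 32, 4)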
 
class EnsembleLinear(nn.Module):
    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            num_ensemble: int,
            weight_decay: float = 0.0
    ) -> None:
        super().__init__()
        self.num_ensemble = num_ensemble
        self.register_parameter("weight", nn.Parameter(torch.zeros(num_ensemble, input_dim, output_dim)))
        self.register_parameter("bias", nn.Parameter(torch.zeros(num_ensemble, 1, output_dim)))
        nn.init.trunc_normal_(self.weight, std=1 / (2 * input_dim ** 0.5))
        self.register_parameter("saved_weight", nn.Parameter(self.weight.detach().clone()))
        self.register_parameter("saved_bias", nn.Parameter(self.bias.detach().clone()))
        self.weight_decay = weight_decay
        self.device = torch.device('cpu')
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight
        bias = self.bias
        if len(x.shape) == 2:
            x = torch.einsum('ij,bjk->bik', x, weight)
        elif len(x.shape) == 3:
            if x.shape[0] == weight.data.shape[0]:
                x = torch.einsum('bij,bjk->bik', x, weight)
            else:
                x = torch.einsum('cij,bjk->bcik', x, weight)
        elif len(x.shape) == 4:
            if x.shape[0] == weight.data.shape[0]:
                x = torch.einsum('cbij,cjk->cbik', x, weight)
            else:
                x = torch.einsum('cdij,bjk->bcdik', x, weight)
        elif len(x.shape) == 5:
            x = torch.einsum('bcdij,bjk->bcdik', x, weight)
        assert x.shape[0] == bias.shape[0] and x.shape[-1] == bias.shape[-1]
        if len(x.shape) == 4:
            bias = bias.unsqueeze(1)
        elif len(x.shape) == 5:
            bias = bias.unsqueeze(1)
            bias = bias.unsqueeze(1)
        x = x + bias
        return x 
    def load_save(self) -> None:
        self.weight.data.copy_(self.saved_weight.data)
        self.bias.data.copy_(self.saved_bias.data) 
    def update_save(self, indexes: List[int]) -> None:
        self.saved_weight.data[indexes] = self.weight.data[indexes]
        self.saved_bias.data[indexes] = self.bias.data[indexes] 
    def get_decay_loss(self) -> torch.Tensor:
        decay_loss = self.weight_decay * (0.5 * ((self.weight ** 2).sum()))
        return decay_loss 
    def to(self, device):
        if not device == self.device:
            self.device = device
            super().to(device)
            self.weight = self.weight.to(self.device)
            self.bias = self.bias.to(self.device)
            self.saved_weight = self.saved_weight.to(self.device)
            self.saved_bias = self.saved_bias.to(self.device) 
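
# Shape sketch (illustrative only): a plain 2D input is broadcast across all
# ensemble members by the first einsum branch above.
#     ens_lin = EnsembleLinear(input_dim=16, output_dim=4, num_ensemble=7)
#     y = ens_lin(torch.randn(32, 16))
#     # y.shape == (7, 32, 4)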
 
# -------------------------------- Backbones -------------------------------- #
class MLP(nn.Module):
    r"""
        Multi-layer Perceptron
        Args:
        
            in_features : int, features numbers of the input
            out_features : int, features numbers of the output
            hidden_features : int, features numbers of the hidden layers
            hidden_layers : int, numbers of the hidden layers
            norm : str, normalization method between hidden layers, default : None
            hidden_activation : str, activation function used in hidden layers, default : 'leakyrelu'
            output_activation : str, activation function used in output layer, default : 'identity'
    """
    def __init__(self, 
                 in_features : int, 
                 out_features : int, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 norm : str = None, 
                 hidden_activation : str = 'leakyrelu', 
                 output_activation : str = 'identity'):
        super(MLP, self).__init__()
        hidden_activation_creator = ACTIVATION_CREATORS[hidden_activation]
        output_activation_creator = ACTIVATION_CREATORS[output_activation]
        if hidden_layers == 0:
            self.net = nn.Sequential(
                nn.Linear(in_features, out_features),
                output_activation_creator(out_features)
            )
        else:
            net = []
            for i in range(hidden_layers):
                net.append(nn.Linear(in_features if i == 0 else hidden_features, hidden_features))
                if norm:
                    if norm == 'ln':
                        net.append(nn.LayerNorm(hidden_features))
                    elif norm == 'bn':
                        net.append(nn.BatchNorm1d(hidden_features))
                    else:
                        raise NotImplementedError(f'{norm} is not supported!')
                net.append(hidden_activation_creator(hidden_features))
            net.append(nn.Linear(hidden_features, out_features))
            net.append(output_activation_creator(out_features))
            self.net = nn.Sequential(*net)
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        r""" forward method of MLP only assume the last dim of x matches `in_features` """
        return self.net(x) 
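
# Usage sketch (illustrative only): only the last dimension is constrained, so any
# number of leading batch dims is accepted.
#     net = MLP(in_features=10, out_features=3, hidden_features=64, hidden_layers=2)
#     y = net(torch.randn(8, 32, 10))
#     # y.shape == (8, 32, 3)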
 
class VectorizedMLP(nn.Module):
    r"""
        Vectorized MLP
        Args:
            in_features : int, features numbers of the input
            out_features : int, features numbers of the output
            hidden_features : int, features numbers of the hidden layers
            hidden_layers : int, numbers of the hidden layers
            norm : str, normalization method between hidden layers, default : None
            hidden_activation : str, activation function used in hidden layers, default : 'leakyrelu'
            output_activation : str, activation function used in output layer, default : 'identity'
    """
    def __init__(self, 
                 in_features : int, 
                 out_features : int, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 ensemble_size : int,
                 norm : str = None, 
                 hidden_activation : str = 'leakyrelu', 
                 output_activation : str = 'identity'):
        super(VectorizedMLP, self).__init__()
        self.ensemble_size = ensemble_size
        hidden_activation_creator = ACTIVATION_CREATORS[hidden_activation]
        output_activation_creator = ACTIVATION_CREATORS[output_activation]
        if hidden_layers == 0:
            self.net = nn.Sequential(
                VectorizedLinear(in_features, out_features, ensemble_size),
                output_activation_creator(out_features)
            )
        else:
            net = []
            for i in range(hidden_layers):
                net.append(VectorizedLinear(in_features if i == 0 else hidden_features, hidden_features, ensemble_size))
                if norm:
                    if norm == 'ln':
                        net.append(nn.LayerNorm(hidden_features))
                    elif norm == 'bn':
                        net.append(nn.BatchNorm1d(hidden_features))
                    else:
                        raise NotImplementedError(f'{norm} is not supported!')
                net.append(hidden_activation_creator(hidden_features))
            net.append(VectorizedLinear(hidden_features, out_features, ensemble_size))
            net.append(output_activation_creator(out_features))
            self.net = nn.Sequential(*net)
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        r""" forward method of MLP only assume the last dim of x matches `in_features` """
        if x.dim() == 2:  # [batch_size, dim]
            x = x.unsqueeze(0).repeat_interleave(self.ensemble_size, dim=0)  # [ensemble_size, batch_size, x_dim]
        assert x.dim() == 3
        assert x.shape[0] == self.ensemble_size
        return self.net(x) 
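
# Usage sketch (illustrative only): a 2D input is tiled to [ensemble_size, batch, dim]
# before being pushed through the vectorized layers.
#     vnet = VectorizedMLP(in_features=10, out_features=3, hidden_features=64,
#                          hidden_layers=2, ensemble_size=5)
#     y = vnet(torch.randn(32, 10))
#     # y.shape == (5, 32, 3)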
 
class ResBlock(nn.Module):
    """
    Initializes a residual block instance.
    Args:
        input_feature (int): The number of input features to the block.
        output_feature (int): The number of output features from the block.
        norm (str, optional): The type of normalization to apply to the block. Default is 'ln' for layer normalization.
    """
    def __init__(self, input_feature : int, output_feature : int, norm : str = 'ln', dropout: float = 0):
        super().__init__()
        self.dropout = nn.Dropout(dropout) if dropout else None
        if norm == 'ln':
            norm_class = torch.nn.LayerNorm
            self.process_net = torch.nn.Sequential(
                torch.nn.Linear(input_feature, output_feature),
                norm_class(output_feature),
                torch.nn.ReLU(True),
                torch.nn.Linear(output_feature, output_feature),
                norm_class(output_feature),
                torch.nn.ReLU(True)
            )
        else:
            self.process_net = torch.nn.Sequential(
                torch.nn.Linear(input_feature, output_feature),
                torch.nn.ReLU(True),
                torch.nn.Linear(output_feature, output_feature),
                torch.nn.ReLU(True)
            )
        if not input_feature == output_feature:
            self.skip_net = torch.nn.Linear(input_feature, output_feature)
        else:
            self.skip_net = torch.nn.Identity()
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        '''x is expected to be a 2D tensor; ResNet flattens higher-dimensional inputs before calling this block'''
        if self.dropout is not None:
            return self.dropout(self.process_net(x)) + self.skip_net(x) 
        else:
            return self.process_net(x) + self.skip_net(x) 
 
class VectorizedResBlock(nn.Module):
    r"""
        Initializes a vectorized linear layer instance.
        Args:
            in_features (int): The number of input features.
            out_features (int): The number of output features.
            ensemble_size (int): The number of ensembles to use.
    """
    def __init__(self, 
                 in_features: int, 
                 out_features: int, 
                 ensemble_size: int, 
                 norm : str = 'ln'):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size
        if norm == 'ln':
            norm_class = torch.nn.LayerNorm
            self.process_net = torch.nn.Sequential(
                VectorizedLinear(in_features, out_features, ensemble_size),
                norm_class(out_features),
                torch.nn.ReLU(True),
                VectorizedLinear(out_features, out_features, ensemble_size),
                norm_class(out_features),
                torch.nn.ReLU(True)
            )
        else:
            self.process_net = torch.nn.Sequential(
                VectorizedLinear(in_features, out_features, ensemble_size),
                torch.nn.ReLU(True),
                VectorizedLinear(out_features, out_features, ensemble_size),
                torch.nn.ReLU(True)
            )
        if not in_features == out_features:
            self.skip_net = VectorizedLinear(in_features, out_features, ensemble_size)
        else:
            self.skip_net = torch.nn.Identity()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # input: [ensemble_size, batch_size, input_size]
        # out: [ensemble_size, batch_size, out_size]
        out = self.process_net(x) + self.skip_net(x)
        return out 
 
class ResNet(torch.nn.Module):
    """
    Initializes a residual neural network (ResNet) instance.
    
    Args:
        in_features (int): The number of input features to the ResNet.
        out_features (int): The number of output features from the ResNet.
        hidden_features (int): The number of hidden features in each residual block.
        hidden_layers (int): The number of residual blocks in the ResNet.
        norm (str, optional): The type of normalization to apply to the ResNet. Default is 'ln' for layer normalization.
        output_activation (str, optional): The type of activation function to apply to the output of the ResNet. Default is 'identity'.
        dropout (float, optional): The dropout probability used inside each residual block. Default is 0.
    """
    def __init__(self, 
                 in_features : int, 
                 out_features : int, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 norm : str = 'ln',
                 output_activation : str = 'identity',
                 dropout: float = 0):
        super().__init__()
        modules = []
        for i in range(hidden_layers):
            if i == 0:
                modules.append(ResBlock(in_features, hidden_features, norm, dropout))
            else:
                modules.append(ResBlock(hidden_features, hidden_features, norm, dropout))
        modules.append(torch.nn.Linear(hidden_features, out_features))
        modules.append(ACTIVATION_CREATORS[output_activation](out_features))
        self.resnet = torch.nn.Sequential(*modules)
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        '''NOTE: reshape is needed since ResBlock only supports 2D tensors'''
        shape = x.shape
        x = x.view(-1, shape[-1])
        output = self.resnet(x)
        output = output.view(*shape[:-1], -1)
        return output  
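
# Usage sketch (illustrative only): extra leading dims are flattened for the
# residual blocks and restored afterwards.
#     resnet = ResNet(in_features=10, out_features=3, hidden_features=64, hidden_layers=2)
#     y = resnet(torch.randn(4, 32, 10))
#     # y.shape == (4, 32, 3)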
 
class VectorizedResNet(torch.nn.Module):
    """
    Initializes a residual neural network (ResNet) instance.
    
    Args:
        in_features (int): The number of input features to the ResNet.
        out_features (int): The number of output features from the ResNet.
        hidden_features (int): The number of hidden features in each residual block.
        hidden_layers (int): The number of residual blocks in the ResNet.
        norm (str, optional): The type of normalization to apply to the ResNet. Default is 'bn' for batch normalization.
        output_activation (str, optional): The type of activation function to apply to the output of the ResNet. Default is 'identity'.
    """
    def __init__(self, 
                 in_features : int, 
                 out_features : int, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 ensemble_size : int,
                 norm : str = 'ln',
                 output_activation : str = 'identity'):
        super().__init__()
        self.ensemble_size = ensemble_size
        modules = []
        for i in range(hidden_layers):
            if i == 0:
                modules.append(VectorizedResBlock(in_features, hidden_features, ensemble_size, norm))
            else:
                modules.append(VectorizedResBlock(hidden_features, hidden_features, ensemble_size, norm))
        modules.append(VectorizedLinear(hidden_features, out_features, ensemble_size))
        modules.append(ACTIVATION_CREATORS[output_activation](out_features))
        self.resnet = torch.nn.Sequential(*modules)
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        '''NOTE: a 2D input is tiled to [ensemble_size, batch_size, dim] since VectorizedResBlock expects a 3D tensor'''
        shape = x.shape
        if len(shape) == 2:  
            x = x.unsqueeze(0).repeat_interleave(self.ensemble_size, dim=0)  # [batch_size, dim] --> [ensemble_size, batch_size, dim]
        assert x.dim() == 3, f"shape == {shape}, with length == {len(shape)}, expect length 3"
        assert x.shape[0] == self.ensemble_size  # [ensemble_size, batch_size, dim]
        output = self.resnet(x)
        return output 
 
class TAPE(nn.Module):
    def __init__(self, d_model, max_len=200, scale_factor=1.0):
        super(TAPE, self).__init__()
        pe = torch.zeros(max_len, d_model)  # positional encoding
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin((position * div_term)*(d_model/max_len))
        pe[:, 1::2] = torch.cos((position * div_term)*(d_model/max_len))
        pe = scale_factor * pe
        self.register_buffer('pe', pe)  # this stores the variable in the state_dict (used for non-trainable variables)
    def forward(self, x):
        return x + self.pe[:x.shape[1], :] 
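
# Usage sketch (illustrative only): adds a fixed sinusoidal encoding along the
# sequence dimension, assuming x is (batch, seq_len, d_model) with seq_len <= max_len
# and an even d_model.
#     tape = TAPE(d_model=64)
#     y = tape(torch.randn(8, 50, 64))   # same shape as the input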
 
    
class TAttention(nn.Module):
    def __init__(self, d_model, nhead, dropout):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.qtrans = nn.Linear(d_model, d_model, bias=False)
        self.ktrans = nn.Linear(d_model, d_model, bias=False)
        self.vtrans = nn.Linear(d_model, d_model, bias=False)
        self.attn_dropout = []
        if dropout > 0:
            for i in range(nhead):
                self.attn_dropout.append(nn.Dropout(p=dropout))
            self.attn_dropout = nn.ModuleList(self.attn_dropout)
        # input LayerNorm
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
        # FFN layerNorm
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(d_model, d_model),
            nn.Dropout(p=dropout)
        )
    def forward(self, x):
        x = self.norm1(x)
        q = self.qtrans(x)
        k = self.ktrans(x)
        v = self.vtrans(x)
        dim = int(self.d_model / self.nhead)
        att_output = []
        for i in range(self.nhead):
            if i==self.nhead-1:
                qh = q[:, :, i * dim:]
                kh = k[:, :, i * dim:]
                vh = v[:, :, i * dim:]
            else:
                qh = q[:, :, i * dim:(i + 1) * dim]
                kh = k[:, :, i * dim:(i + 1) * dim]
                vh = v[:, :, i * dim:(i + 1) * dim]
            atten_ave_matrixh = torch.softmax(torch.matmul(qh, kh.transpose(1, 2)), dim=-1)
            if self.attn_dropout:
                atten_ave_matrixh = self.attn_dropout[i](atten_ave_matrixh)
            att_output.append(torch.matmul(atten_ave_matrixh, vh))
        att_output = torch.concat(att_output, dim=-1)
        # FFN
        xt = x + att_output
        xt = self.norm2(xt)
        att_output = xt + self.ffn(xt)
        return att_output 
 
    
class Tokenizer(nn.Module):
    category_offsets: Optional[torch.Tensor]
    def __init__(
        self,
        d_numerical: int,
        categories: Optional[List[int]],
        d_token: int,
        bias: bool,
    ) -> None:
        super().__init__()
        if categories is None:
            d_bias = d_numerical
            self.category_offsets = None
            self.category_embeddings = None
        else:
            d_bias = d_numerical + len(categories)
            category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0)
            self.register_buffer('category_offsets', category_offsets)
            self.category_embeddings = nn.Embedding(sum(categories), d_token)
            nn.init.kaiming_uniform_(self.category_embeddings.weight, a=math.sqrt(5))
            logger.debug(f'{self.category_embeddings.weight.shape=}')
        # take [CLS] token into account
        self.weight = nn.Parameter(torch.Tensor(d_numerical + 1, d_token))
        self.bias = nn.Parameter(torch.Tensor(d_bias, d_token)) if bias else None
        # The initialization is inspired by nn.Linear
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn.init.kaiming_uniform_(self.bias, a=math.sqrt(5))
    @property
    def n_tokens(self) -> int:
        return len(self.weight) + (
            0 if self.category_offsets is None else len(self.category_offsets)
        )
    def forward(self, x_num: torch.Tensor, x_cat: Optional[torch.Tensor]):
        x_some = x_num if x_cat is None else x_cat
        assert x_some is not None
        x_num = torch.cat(
            [torch.ones(len(x_some), 1, device=x_some.device)]  # [CLS]
            + ([] if x_num is None else [x_num]),
            dim=1,
        )
        x = self.weight[None] * x_num[:, :, None]
        if x_cat is not None:
            x = torch.cat(
                [x, self.category_embeddings(x_cat + self.category_offsets[None])],
                dim=1,
            )
        if self.bias is not None:
            bias = torch.cat(
                [
                    torch.zeros(1, self.bias.shape[1], device=x.device),
                    self.bias,
                ]
            )
            x = x + bias[None]
        return x 
 
class MultiheadAttention(nn.Module):
    def __init__(
        self, d: int, n_heads: int, dropout: float, initialization: str
    ) -> None:
        if n_heads > 1:
            assert d % n_heads == 0
        assert initialization in ['xavier', 'kaiming']
        super().__init__()
        self.W_q = nn.Linear(d, d)
        self.W_k = nn.Linear(d, d)
        self.W_v = nn.Linear(d, d)
        self.W_out = nn.Linear(d, d) if n_heads > 1 else None
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout) if dropout else None
        for m in [self.W_q, self.W_k, self.W_v]:
            if initialization == 'xavier' and (n_heads > 1 or m is not self.W_v):
                # gain is needed since W_qkv is represented with 3 separate layers
                nn.init.xavier_uniform_(m.weight, gain=1 / math.sqrt(2))
            nn.init.zeros_(m.bias)
        if self.W_out is not None:
            nn.init.zeros_(self.W_out.bias)
    def _reshape(self, x):
        batch_size, n_tokens, d = x.shape
        d_head = d // self.n_heads
        return (
            x.reshape(batch_size, n_tokens, self.n_heads, d_head)
            .transpose(1, 2)
            .reshape(batch_size * self.n_heads, n_tokens, d_head)
        )
    def forward(
        self,
        x_q,
        x_kv,
        key_compression: Optional[nn.Linear],
        value_compression: Optional[nn.Linear],
    ):
        q, k, v = self.W_q(x_q), self.W_k(x_kv), self.W_v(x_kv)
        for tensor in [q, k, v]:
            assert tensor.shape[-1] % self.n_heads == 0
        if key_compression is not None:
            assert value_compression is not None
            k = key_compression(k.transpose(1, 2)).transpose(1, 2)
            v = value_compression(v.transpose(1, 2)).transpose(1, 2)
        else:
            assert value_compression is None
        batch_size = len(q)
        d_head_key = k.shape[-1] // self.n_heads
        d_head_value = v.shape[-1] // self.n_heads
        n_q_tokens = q.shape[1]
        q = self._reshape(q)
        k = self._reshape(k)
        attention = F.softmax(q @ k.transpose(1, 2) / math.sqrt(d_head_key), dim=-1)
        if self.dropout is not None:
            attention = self.dropout(attention)
        x = attention @ self._reshape(v)
        x = (
            x.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value)
            .transpose(1, 2)
            .reshape(batch_size, n_q_tokens, self.n_heads * d_head_value)
        )
        if self.W_out is not None:
            x = self.W_out(x)
        return x 
 
class DistributionWrapper(nn.Module):
    r"""wrap output of Module to distribution"""
    BASE_TYPES = ['normal', 'gmm', 'onehot', 'discrete_logistic']
    SUPPORTED_TYPES = BASE_TYPES + ['mix']
    def __init__(self, distribution_type : str = 'normal', **params):
        super().__init__()
        self.distribution_type = distribution_type
        self.params = params
        assert self.distribution_type in self.SUPPORTED_TYPES, f"{self.distribution_type} is not supported!"
        if self.distribution_type == 'normal':
            self.max_logstd = nn.Parameter(torch.ones(self.params['dim']) * 0.0, requires_grad=True)
            self.min_logstd = nn.Parameter(torch.ones(self.params['dim']) * -7, requires_grad=True)
            if not self.params.get('conditioned_std', True):
                self.logstd = nn.Parameter(torch.zeros(self.params['dim']), requires_grad=True)
        elif self.distribution_type == 'gmm':
            self.max_logstd = nn.Parameter(torch.ones(self.params['mixture'], self.params['dim']) * 0, requires_grad=True)
            self.min_logstd = nn.Parameter(torch.ones(self.params['mixture'], self.params['dim']) * -10, requires_grad=True)            
            if not self.params.get('conditioned_std', True):
                self.logstd = nn.Parameter(torch.zeros(self.params['mixture'], self.params['dim']), requires_grad=True)
        elif self.distribution_type == 'discrete_logistic':
            self.num = self.params['num']
        elif self.distribution_type == 'mix':
            assert 'dist_config' in self.params.keys(), "You need to provide `dist_config` for Mix distribution"
            self.dist_config = self.params['dist_config']
            self.wrapper_list = []
            self.input_sizes = []
            self.output_sizes = []
            for config in self.dist_config:
                dist_type = config['type']
                assert dist_type in self.SUPPORTED_TYPES, f"{dist_type} is not supported!"
                assert not dist_type == 'mix', "recursive MixDistribution is not supported!"
                self.wrapper_list.append(DistributionWrapper(dist_type, **config))
                self.input_sizes.append(config['dim'])
                self.output_sizes.append(config['output_dim'])
                
            self.wrapper_list = nn.ModuleList(self.wrapper_list)                                     
    def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, payload : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
        '''
            Wrap the given tensor into a distribution
            :param adapt_std : overrides the std part of the distribution if provided (optional)
            :param payload : applied to the output distribution after it is built (optional)
        '''
        # [ OTHER ] replace the clip with tanh
        # [ OTHER ] TODO: a better std-controlling strategy is needed
        if self.distribution_type == 'normal':
            if self.params.get('conditioned_std', True):
                mu, logstd = torch.chunk(x, 2, dim=-1)
                if 'soft_clamp' in kwargs and kwargs['soft_clamp'] == True:
                    # distribution wrapper only for revive_f venv training
                    logstd = soft_clamp(logstd, self.min_logstd, self.max_logstd)
                    std = torch.exp(logstd)
                else:
                    # distribution wrapper for others
                    logstd_logit = logstd
                    max_std = 0.5
                    min_std = 0.001
                    std = (torch.tanh(logstd_logit) + 1) / 2 * (max_std - min_std) + min_std
            else:
                mu, logstd = x, self.logstd
                logstd_logit = self.logstd
                max_std = 0.5
                min_std = 0.001
                std = (torch.tanh(logstd_logit) + 1) / 2 * (max_std - min_std) + min_std
                # std = torch.exp(logstd)
            if payload is not None:
                mu = mu + safe_atanh(payload)
            mu = torch.tanh(mu)
            """replace tanh
            node = kwargs["node"]
            mu = node.processor.process_torch({node.name:mu})[node.name]
            mu = torch.clamp(mu, min=-1.1, max=1.1)
            """
            if adapt_std is not None:
                std = torch.ones_like(mu).to(mu) * adapt_std
            return DiagnalNormal(mu, std)
        elif self.distribution_type == 'gmm':
            if self.params.get('conditioned_std', True):
                logits, mus, logstds = torch.split(x, [self.params['mixture'], 
                                                       self.params['mixture'] * self.params['dim'], 
                                                       self.params['mixture'] * self.params['dim']], dim=-1)
                mus = mus.view(*mus.shape[:-1], self.params['mixture'], self.params['dim'])      
                logstds = logstds.view(*logstds.shape[:-1], self.params['mixture'], self.params['dim'])
            else:
                logits, mus = torch.split(x, [self.params['mixture'], self.params['mixture'] * self.params['dim']], dim=-1)
                logstds = self.logstd
            if payload is not None:
                mus = mus + safe_atanh(payload.unsqueeze(dim=-2))
            mus = torch.tanh(mus)
            stds = adapt_std if adapt_std is not None else torch.exp(soft_clamp(logstds, self.min_logstd, self.max_logstd))
            return GaussianMixture(mus, stds, logits)
        elif self.distribution_type == 'onehot':
            return Onehot(x)
        elif self.distribution_type == 'discrete_logistic':
            mu, logstd = torch.chunk(x, 2, dim=-1)
            logstd = torch.clamp(logstd, -7, 1)
            return DiscreteLogistic(mu, torch.exp(logstd), num=self.num)
        elif self.distribution_type == 'mix':
            xs = torch.split(x, self.output_sizes, dim=-1)
            
            if isinstance(adapt_std, float):
                adapt_stds = [adapt_std] * len(self.input_sizes)
            elif adapt_std is not None:
                adapt_stds = torch.split(adapt_std, self.input_sizes, dim=-1)
            else:
                adapt_stds = [None] * len(self.input_sizes)
            
            if payload is not None:
                payloads = torch.split(payload, self.input_sizes + [payload.shape[-1] - sum(self.input_sizes)], dim=-1)[:-1]
            else:
                payloads = [None] * len(self.input_sizes)
            dists = [wrapper(x, _adapt_std, _payload, **kwargs) for x, _adapt_std, _payload, wrapper in zip(xs, adapt_stds, payloads, self.wrapper_list)]
            return MixDistribution(dists) 
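
# Usage sketch (illustrative only; assumes conditioned_std keeps its default True
# and that DiagnalNormal behaves like a standard torch distribution): for a
# 'normal' wrapper the input is chunked into (mu, logstd), so its last dim must be 2 * dim.
#     wrapper = DistributionWrapper('normal', dim=3)
#     dist = wrapper(torch.randn(32, 6))
#     a = dist.sample()   # a.shape == (32, 3)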
 
# --------------------------------- Policies -------------------------------- #
class FeedForwardPolicy(torch.nn.Module):
    """Policy for using mlp, resnet  and transformer backbone.
    
    Args:
        in_features : The number of input features, or a dictionary describing the input distribution
        out_features : The number of output features
        hidden_features : The number of hidden features in each layer
        hidden_layers : The number of hidden layers
        dist_config : A list of configurations for the distributions to use in the model
        norm : The type of normalization to apply to the input features
        hidden_activation : The activation function to use in the hidden layers
        backbone_type : The type of backbone to use in the model
        use_multihead : Whether to use a multihead model
        use_feature_embed : Whether to use feature embedding
    """
    def __init__(self, 
                 in_features : Union[int, dict], 
                 out_features : int, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 dist_config : list,
                 norm : str = None, 
                 hidden_activation : str = 'leakyrelu', 
                 backbone_type : Union[str, np.str_] = 'mlp',
                 use_multihead : bool = False,
                 **kwargs):
        super().__init__()
        self.multihead = []
        self.dist_config = dist_config
        self.kwargs = kwargs
        if isinstance(in_features, dict):
            in_features = sum(in_features.values())
        if not use_multihead:
            if backbone_type == 'mlp':
                self.backbone = MLP(in_features, out_features, hidden_features, hidden_layers, norm, hidden_activation)
            elif backbone_type == 'res':
                self.backbone = ResNet(in_features, out_features, hidden_features, hidden_layers, norm)
            elif backbone_type == 'ft_transformer':
                self.backbone = FT_Transformer(in_features, out_features, hidden_features, hidden_layers=hidden_layers)
            else:
                raise NotImplementedError(f'backbone type {backbone_type} is not supported')
        else:
            if backbone_type == 'mlp':
                self.backbone = MLP(in_features, hidden_features, hidden_features, hidden_layers, norm, hidden_activation)
            elif backbone_type == 'res':
                self.backbone = ResNet(in_features, hidden_features, hidden_features, hidden_layers, norm)
            elif backbone_type == 'ft_transformer':
                self.backbone = FT_Transformer(in_features, out_features, hidden_features, hidden_layers=hidden_layers)
            else:
                raise NotImplementedError(f'backbone type {backbone_type} is not supported')
        
            for dist in self.dist_config:
                dist_type = dist["type"]
                output_dim = dist["output_dim"] 
                if dist_type == "onehot" or dist_type == "discrete_logistic":
                    self.multihead.append(MLP(hidden_features, output_dim, 64, 1, norm=None, hidden_activation='leakyrelu')) 
                elif dist_type == "normal":
                    normal_dim = dist["dim"]
                    for dim in range(normal_dim):  
                        self.multihead.append(MLP(hidden_features, int(output_dim // normal_dim), 64, 1, norm=None, hidden_activation='leakyrelu')) 
                else:
                    raise NotImplementedError(f'Dist type {dist_type} is not supported in multihead.')
                
            self.multihead = nn.ModuleList(self.multihead)
        self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
        if "ts_conv_config" in self.kwargs and self.kwargs['ts_conv_config'] != None:
            other_state_layer_output = in_features//2
            ts_state_conv_layer_output = in_features//2 + in_features%2
            self.other_state_layer   = MLP(self.kwargs['ts_conv_net_config']['other_net_input'], other_state_layer_output, 0, 0, output_activation=hidden_activation)
            all_node_ts = None
            self.conv_ts_node = []
            self.conv_other_node = deepcopy(self.kwargs['ts_conv_config']['no_ts_input'])
            for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
                self.conv_ts_node.append('__history__'+tsnode)
                self.conv_other_node.append('__now__'+tsnode)  
                node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
                if all_node_ts == None:
                    all_node_ts = node_ts
                assert all_node_ts==node_ts, f"expect ts step == {all_node_ts}. However got {tsnode} with ts step {node_ts}"
            
            #ts-1 remove now state
            kernalsize = all_node_ts-1 
            self.ts_state_conv_layer = ConvBlock(self.kwargs['ts_conv_net_config']['conv_net_input'], ts_state_conv_layer_output, kernalsize, output_activation=hidden_activation)
    def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
        if "ts_conv_config" in self.kwargs  and self.kwargs['ts_conv_config'] != None:
            #deal with all ts nodes with ts dim
            # assert 'input_names' in kwargs
            assert hasattr(self, 'conv_ts_node')
            assert hasattr(self, 'conv_other_node')
            inputdata = deepcopy(state) #data is deteched from other node 
            for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
                # assert self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['endpoint'] == True
                batch_size = int(np.prod(list(inputdata[tsnode].shape[:-1])))
                original_size = list(inputdata[tsnode].shape[:-1])
                node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
                # inputdata[tsnode]
                temp_data = inputdata[tsnode].reshape(batch_size, node_ts, -1)
                inputdata['__history__'+tsnode] = temp_data[..., :-1, :]
                inputdata['__now__'+tsnode] = temp_data[..., -1, :].reshape([*original_size,-1])
            state_other = torch.cat([inputdata[key] for key in self.conv_other_node], dim=-1)
            state_ts_conv = torch.cat([inputdata[key] for key in self.conv_ts_node], dim=-1)
            state_other = self.other_state_layer(state_other)
            original_size = list(state_other.shape[:-1])
            state_ts_conv = self.ts_state_conv_layer(state_ts_conv).reshape([*original_size,-1])
            state = torch.cat([state_other, state_ts_conv], dim=-1)
        elif isinstance(state, dict):
            assert 'input_names' in kwargs
            state = torch.cat([state[key] for key in kwargs['input_names']], dim=-1)
        if not self.multihead:
            output = self.backbone(state)
        else:
            backbone_output = self.backbone(state) 
            multihead_output = []
            multihead_index = 0
            for dist in self.dist_config:
                dist_type = dist["type"]
                output_dim = dist["output_dim"] 
                if dist_type == "onehot" or dist_type == "discrete_logistic":
                    multihead_output.append(self.multihead[multihead_index](backbone_output))
                    multihead_index += 1
                elif dist_type == "normal":
                    normal_mode_output = []
                    normal_std_output = []
                    for head in self.multihead[multihead_index:]:
                        head_output = head(backbone_output)
                        if head_output.shape[-1] == 1:
                            mode = head_output
                            normal_mode_output.append(mode)
                        else:
                            mode, std = torch.chunk(head_output, 2, dim=-1)
                            normal_mode_output.append(mode)
                            normal_std_output.append(std)
                    normal_output = torch.cat(normal_mode_output, dim=-1)
                    if normal_std_output:
                        normal_std_output = torch.cat(normal_std_output, dim=-1)
                        normal_output = torch.cat([normal_output, normal_std_output], dim=-1)
                        
                    multihead_output.append(normal_output)
                    break
                else:
                    raise NotImplementedError(f'Dist type {dist_type} is not supported in multihead.')                
            output = torch.cat(multihead_output, dim=-1)
        soft_clamp_flag = True if "soft_clamp" in self.kwargs and self.kwargs['soft_clamp'] == True else False
        dist = self.dist_wrapper(output, adapt_std, soft_clamp=soft_clamp_flag, **kwargs)
        if hasattr(self, "dist_mu_shift"):
            dist = dist.shift(self.dist_mu_shift)
        return dist 
    
    @torch.no_grad()
    def get_action(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], deterministic : bool = True):
        dist = self(state)
        return dist.mode if deterministic else dist.sample() 
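
# Usage sketch (illustrative only; assumes a single 'normal' head, so output_dim is
# 2 * dim to carry both mean and log-std, and that the mixture distribution exposes
# .mode and .sample as used by get_action above):
#     dist_config = [{'type': 'normal', 'dim': 2, 'output_dim': 4}]
#     policy = FeedForwardPolicy(in_features=8, out_features=4, hidden_features=64,
#                                hidden_layers=2, dist_config=dist_config)
#     action = policy.get_action(torch.randn(32, 8))   # action.shape == (32, 2)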
 
class RecurrentPolicy(torch.nn.Module):
    """Initializes a recurrent policy network instance.
    
    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        hidden_features (int): The number of hidden features in the RNN.
        hidden_layers (int): The number of layers in the RNN.
        norm (str): The normalization method used in the ResNet backbone.
        dist_config (list): The configuration for the distributions used in the model.
        backbone_type (Union[str, np.str_], optional): The type of RNN to use ('gru' or 'lstm'). Default is 'gru'.
    """
    def __init__(self, 
                 in_features : int, 
                 out_features : int, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 norm : str,
                 dist_config : list, 
                 backbone_type : Union[str, np.str_] ='gru',
                 **kwargs):
        super().__init__()
        RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM
        rnn_hidden_features = 256
        rnn_hidden_layers = min(hidden_layers, 3)
        rnn_output_features = min(out_features,8)
        self.rnn = RNN(in_features, rnn_hidden_features, rnn_hidden_layers)
        self.rnn_mlp = MLP(rnn_hidden_features, rnn_output_features, 0, 0)
        
        self.backbone = ResNet(in_features+rnn_output_features, out_features, hidden_features, hidden_layers, norm)
        
        self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
    def reset(self):
        self.h = None 
    def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
        x_shape = x.shape
        if len(x_shape) == 1:
            rnn_output, self.h = self.rnn(x.unsqueeze(0).unsqueeze(0), self.h)
            rnn_output = rnn_output.squeeze(0).squeeze(0)
        elif len(x_shape) == 2:
            rnn_output, self.h = self.rnn(x.unsqueeze(0), self.h)
            rnn_output = rnn_output.squeeze(0)
        else:
            rnn_output, self.h = self.rnn(x, self.h)
        rnn_output = self.rnn_mlp(rnn_output)
        logits = self.backbone(torch.cat([x, rnn_output], dim=-1))
        
        return self.dist_wrapper(logits, adapt_std, **kwargs) 
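
# Usage sketch (illustrative only; assumes a step-by-step rollout with the default
# GRU backbone and the same hypothetical dist_config as above): reset() clears the
# recurrent state before a new trajectory, and a 2D input is treated as one time step.
#     policy = RecurrentPolicy(in_features=8, out_features=4, hidden_features=64,
#                              hidden_layers=2, norm='ln',
#                              dist_config=[{'type': 'normal', 'dim': 2, 'output_dim': 4}])
#     policy.reset()
#     dist = policy(torch.randn(32, 8))   # distribution over 2 action dims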
 
class TsRecurrentPolicy(torch.nn.Module):
    """Initializes a recurrent policy network instance for ts_node.
    
    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        hidden_features (int): The number of hidden features in the RNN.
        hidden_layers (int): The number of layers in the RNN.
        dist_config (list): The configuration for the distributions used in the model.
        backbone_type (Union[str, np.str_], optional): The type of RNN to use ('gru' or 'lstm'). Default is 'gru'.
    """
    input_is_dict = True
    def __init__(self, 
                 in_features : int, 
                 out_features : int, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 norm : str,
                 dist_config : list, 
                 backbone_type : Union[str, np.str_] ='gru',
                 ts_input_names : dict  = {},
                 bidirectional: bool = True,
                 dropout: float = 0.1,
                 rnn_input_withtime=True,
                 smooth_input=False,
                 **kwargs):
        super().__init__()
        RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM
        self.ts_input_names = ts_input_names
        self.input_dim_dict = kwargs.get("input_dim_dict",{})
        self.max_ts_steps = max([v[1] for v in self.ts_input_names.values()])
        self.rnn_input_withtime = rnn_input_withtime
        self.smooth_input = smooth_input
        if self.smooth_input:
            self.smooth_linear_1 = nn.Linear(in_features, in_features)
            self.smooth_linear_2 = nn.Linear(in_features, in_features)
        if self.rnn_input_withtime:
            rnn_in_features = in_features + 1
        else:
            rnn_in_features = in_features
        self.rnn = RNN(input_size=rnn_in_features, 
                        hidden_size=hidden_features, 
                        num_layers=hidden_layers, 
                        bidirectional=bidirectional, 
                        batch_first=True,
                        dropout=dropout if hidden_layers > 1 else 0)
        if bidirectional:
            self.linear = nn.Linear(hidden_features * 2, out_features) 
            # self.linear = ResNet(hidden_features * 2, out_features, hidden_features, hidden_layers, norm, output_activation='identity',dropout=dropout)
        else:
            self.linear = nn.Linear(hidden_features, out_features)
            # self.linear = ResNet(hidden_features, out_features, hidden_features, hidden_layers, norm, output_activation='identity',dropout=dropout)
        
        self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
    def reset(self):
        self.h = None 
    def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
        x_shape = x[list(x.keys())[0]].shape
        if len(x_shape) == 3:
            x = {k:v.reshape(-1,v.shape[-1]) for k,v in x.items()}
        input = {}
        for k,v in x.items():
            if k in self.ts_input_names:
                ts_steps = self.ts_input_names.get(k)[1]
            else:
                ts_steps = 1
            input[k] = v.reshape(v.shape[0],ts_steps,-1)
            if ts_steps < self.max_ts_steps:
                # input[k] = torch.cat([input[k][:,:1].repeat(1, self.max_ts_steps+1-ts_steps, 1),  input[k]],axis=1)
                input[k] = torch.cat([input[k], input[k][:,-1:].repeat(1, self.max_ts_steps-ts_steps, 1)],axis=1)
        if len(input) == 1:
            rnn_input = list(input.values())[0]
        else:
            rnn_input = torch.cat(list(input.values()),axis=-1)
        if self.smooth_input:
            rnn_input_pool = F.avg_pool1d(rnn_input.permute(0,2,1), kernel_size=3, stride=1, padding=1).permute(0,2,1)
            rnn_input = self.smooth_linear_1(rnn_input_pool) + self.smooth_linear_2(rnn_input-rnn_input_pool)
        if self.rnn_input_withtime:
            time_steps = torch.arange(rnn_input.shape[1]).unsqueeze(0).unsqueeze(2).expand(rnn_input.shape[0], rnn_input.shape[1], 1).float().to(rnn_input.device)
            time_steps_normalized = time_steps / (rnn_input.shape[1] - 1)
            rnn_input = torch.cat([rnn_input, time_steps_normalized], dim=2)
        out, _ = self.rnn(rnn_input,None)
        out = self.linear(out[:, -1, :])  
        if len(x_shape) == 3:
            logits = out.reshape(x_shape[0],x_shape[1],-1)
        else:
            logits = out.reshape(x_shape[0],-1)
        
        return self.dist_wrapper(logits, adapt_std, **kwargs) 
 
class RecurrentRESPolicy(torch.nn.Module):
    def __init__(self, 
                 in_features : int, 
                 out_features : int, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 dist_config : list,
                 backbone_type : Union[str, np.str_] ='gru',
                 rnn_hidden_features : int = 64,
                 window_size : int = 0,
                 **kwargs):
        super().__init__()
        self.kwargs = kwargs
        RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM
        self.in_feature_embed = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            ACTIVATION_CREATORS['relu'](hidden_features),
            nn.Dropout(0.1),
            nn.LayerNorm(hidden_features)
        )
        self.side_net = nn.Sequential(
            ResBlock(hidden_features, hidden_features),
            nn.Dropout(0.1),
            nn.LayerNorm(hidden_features),
            ResBlock(hidden_features, hidden_features),
            nn.Dropout(0.1),
            nn.LayerNorm(hidden_features)
        )
        self.rnn = RNN(hidden_features, rnn_hidden_features, hidden_layers)
        self.backbone = MLP(hidden_features + rnn_hidden_features, out_features, hidden_features, 1)
        self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
        self.hidden_layers = hidden_layers
        self.hidden_features = hidden_features
        self.rnn_hidden_features = rnn_hidden_features
        self.window_size = window_size
        self.hidden_que = deque(maxlen=self.window_size)  # [(h_0, c_0), (h_1, c_1), ...]
        self.h = None
    def reset(self):
        self.hidden_que = deque(maxlen=self.window_size)
        self.h = None 
    def preprocess_window(self, a_embed):
        shape = a_embed.shape
        if len(self.hidden_que) < self.window_size:  # do Not pop
            if self.hidden_que:
                h = self.hidden_que[0]
            else:
                h = (torch.zeros((self.hidden_layers, shape[1], self.rnn_hidden_features)).to(a_embed), torch.zeros((self.hidden_layers, shape[1], self.rnn_hidden_features)).to(a_embed))
        else:
            h = self.hidden_que.popleft()
        hidden_cat = torch.concat([h[0]] + [hidden[0] for hidden in self.hidden_que] + [torch.zeros_like(h[0]).to(h[0])], dim=1)  # (layers, bs * (len(self.hidden_que) + 2), dim)
        cell_cat = torch.concat([h[1]] + [hidden[1] for hidden in self.hidden_que] + [torch.zeros_like(h[1]).to(h[1])], dim=1)  # (layers, bs * (len(self.hidden_que) + 2), dim)
        a_embed = torch.repeat_interleave(a_embed, repeats=len(self.hidden_que)+2, dim=0).view(1, (len(self.hidden_que)+2)*shape[1], -1)  # (1, bs * (len(self.hidden_que) + 2), dim)
        return a_embed, hidden_cat, cell_cat 
[docs]
    def postprocess_window(self, rnn_output, hidden_cat, cell_cat):
        hidden_cat = torch.chunk(hidden_cat, chunks=len(self.hidden_que)+2, dim=1)  # tuples of (layers, bs, dim)
        cell_cat = torch.chunk(cell_cat, chunks=len(self.hidden_que)+2, dim=1)  # tuples of (layers, bs, dim)
        rnn_output = torch.chunk(rnn_output, chunks=len(self.hidden_que)+2, dim=1)  # tuples of (1, bs, dim)
        self.hidden_que = deque(zip(hidden_cat[1:], cell_cat[1:]), maxlen=self.window_size)  # important to discard the first element!
        rnn_output = rnn_output[0]  # (1, bs, dim)
        return rnn_output 
[docs]
    def rnn_forward(self, a_embed, joint_train : bool):
        if self.window_size > 0:
            a_embed, hidden_cat, cell_cat = self.preprocess_window(a_embed)
            if joint_train:
                rnn_output, (hidden_cat, cell_cat) = self.rnn(a_embed, (hidden_cat, cell_cat))
            else:
                with torch.no_grad():
                    rnn_output, (hidden_cat, cell_cat) = self.rnn(a_embed, (hidden_cat, cell_cat))
            rnn_output = self.postprocess_window(rnn_output, hidden_cat, cell_cat)  #(1, bs, dim)
        else:
            if joint_train:
                rnn_output, self.h = self.rnn(a_embed, self.h)  #(1, bs, dim)
            else:
                with torch.no_grad():
                    rnn_output, self.h = self.rnn(a_embed, self.h)
        return rnn_output 
[docs]
    def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, field: str = 'mail', **kwargs) -> ReviveDistribution:
        # 'bc' trains the RNN jointly with the rest of the network, while
        # 'mail' keeps the RNN frozen (no gradients flow through self.rnn).
        if field == 'bc':
            joint_train = True
        elif field == 'mail':
            joint_train = False
        else:
            raise NotImplementedError(f"unknown field: {field} in RNN training!")
        shape = x.shape
        if len(shape) == 2:  # (bs, dim)
            x_embed = self.in_feature_embed(x)
            side_output = self.side_net(x_embed)
            a_embed = x_embed.unsqueeze(0)  # (1, bs, dim)
            rnn_output = self.rnn_forward(a_embed, joint_train)  # (1, bs, dim)
            rnn_output = rnn_output.squeeze(0)  # (bs, dim)
            logits = self.backbone(torch.concat([side_output, rnn_output], dim=-1))  # (bs, dim)
        else:
            assert len(shape) == 3, f"expect len(x.shape) == 3. However got x.shape {shape}"
            self.reset()
            output = []
            for i in range(shape[0]):  # iterate over the horizon dimension
                a = x[i].unsqueeze(0)  # (1, bs, dim)
                a_embed = self.in_feature_embed(a)  # (1, bs, dim)
                side_output = self.side_net(a_embed)
                rnn_output = self.rnn_forward(a_embed, joint_train)  # (1, bs, dim)
                backbone_output = self.backbone(torch.concat([side_output, rnn_output], dim=-1))  # (1, bs, dim)
                output.append(backbone_output)
            logits = torch.concat(output, dim=0)
        return self.dist_wrapper(logits, adapt_std, **kwargs) 
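
# Usage sketch (illustrative only, not part of the shipped API): driving the
# step-wise hidden-state window of RecurrentRESPolicy. The dist_config below is
# a placeholder -- the exact schema is defined by DistributionWrapper in
# revive.computation.dists. Note that preprocess_window stores (hidden, cell)
# pairs, so the windowed path assumes an LSTM-style backbone when window_size > 0.
#
# policy = RecurrentRESPolicy(in_features=16, out_features=4,
#                             hidden_features=64, hidden_layers=1,
#                             dist_config=dist_config, backbone_type='lstm',
#                             rnn_hidden_features=64, window_size=3)
# policy.reset()                           # clear the hidden-state window
# for obs in trajectory:                   # obs: (batch_size, 16)
#     dist = policy(obs, field='mail')     # hidden states are updated internally
#     action = dist.mode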
 
[docs]
class ContextualPolicy(torch.nn.Module):
    def __init__(self, 
                 in_features : int, 
                 out_features : int, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 dist_config : list, 
                 backbone_type : Union[str, np.str_] ='contextual_gru',
                 **kwargs):
        super().__init__()
        self.kwargs = kwargs
        RNN = torch.nn.GRU if backbone_type == 'contextual_gru' else torch.nn.LSTM
        self.preprocess_mlp = MLP(in_features, hidden_features, 0, 0, output_activation='leakyrelu')
        self.rnn = RNN(hidden_features, hidden_features, 1)
        self.mlp = MLP(hidden_features + in_features, out_features, hidden_features, hidden_layers)
        self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
        self.h = None
[docs]
    def reset(self):
        self.h = None 
[docs]
    def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
        in_feature = x
        if len(x.shape) == 2:
            x = x.unsqueeze(0)
            x = self.preprocess_mlp(x)
            rnn_output, self.h = self.rnn(x, self.h)
            rnn_output = rnn_output.squeeze(0)
        else:
            x = self.preprocess_mlp(x)
            rnn_output, self.h = self.rnn(x)
        logits = self.mlp(torch.cat((in_feature, rnn_output), dim=-1))
        return self.dist_wrapper(logits, adapt_std, **kwargs) 
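
# Usage sketch (illustrative only): ContextualPolicy concatenates the raw input
# with a GRU summary of the history before the final MLP. A 2D input advances
# the hidden state one step at a time; a 3D (horizon, batch, dim) input is
# processed as a whole sequence. Sizes and dist_config are placeholders.
#
# policy = ContextualPolicy(in_features=16, out_features=4, hidden_features=64,
#                           hidden_layers=2, dist_config=dist_config)
# policy.reset()
# dist_step = policy(torch.randn(32, 16))      # step-wise rollout, (batch, dim)
# dist_seq = policy(torch.randn(10, 32, 16))   # full sequence, (horizon, batch, dim)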
 
[docs]
class EnsembleFeedForwardPolicy(torch.nn.Module):
    """Policy for using mlp, resnet  and transformer backbone.
    
    Args:
        in_features : The number of input features, or a dictionary describing the input distribution
        out_features : The number of output features
        hidden_features : The number of hidden features in each layer
        hidden_layers : The number of hidden layers
        dist_config : A list of configurations for the distributions to use in the model
        norm : The type of normalization to apply to the input features
        hidden_activation : The activation function to use in the hidden layers
        backbone_type : The type of backbone to use in the model
        use_multihead : Whether to use a multihead model
        use_feature_embed : Whether to use feature embedding
    """
    def __init__(self, 
                 in_features : Union[int, dict], 
                 out_features : int, 
                 hidden_features : int, 
                 hidden_layers : int,
                 ensemble_size : int,
                 dist_config : list,
                 norm : str = None, 
                 hidden_activation : str = 'leakyrelu', 
                 backbone_type : Union[str, np.str_] = 'EnsembleRES',
                 use_multihead : bool = False,
                 **kwargs):
        super().__init__()
        self.multihead = []
        self.ensemble_size = ensemble_size
        self.dist_config = dist_config
        self.kwargs = kwargs
        if isinstance(in_features, dict):
            in_features = sum(in_features.values())
        if backbone_type == 'EnsembleMLP':
            self.backbone = VectorizedMLP(in_features, out_features, hidden_features, hidden_layers, ensemble_size, norm, hidden_activation)
        elif backbone_type == 'EnsembleRES':
            self.backbone = VectorizedResNet(in_features, out_features, hidden_features, hidden_layers, ensemble_size, norm)
        else:
            raise NotImplementedError(f'backbone type {backbone_type} is not supported')
        self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
[docs]
    def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
        if isinstance(state, dict):
            assert 'input_names' in kwargs
            state = torch.cat([state[key] for key in kwargs['input_names']], dim=-1)
        output = self.backbone(state)
        dist = self.dist_wrapper(output, adapt_std, **kwargs)
        if hasattr(self, "dist_mu_shift"):
            dist = dist.shift(self.dist_mu_shift)
        return dist 
    
[docs]
    @torch.no_grad()
    def get_action(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], deterministic : bool = True):
        dist = self(state)
        return dist.mode if deterministic else dist.sample() 
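
# Usage sketch (illustrative only): the ensemble policy evaluates all members in
# one vectorized pass. With the 'EnsembleMLP' backbone a plain (batch, dim) state
# is broadcast across the ensemble (see the tests at the bottom of this file), so
# the logits carry a leading ensemble dimension. Sizes are placeholders.
#
# policy = EnsembleFeedForwardPolicy(in_features=16, out_features=4,
#                                    hidden_features=256, hidden_layers=4,
#                                    ensemble_size=5, dist_config=dist_config,
#                                    backbone_type='EnsembleMLP')
# state = torch.randn(32, 16)
# dist = policy(state)                                    # ensemble of distributions
# action = policy.get_action(state, deterministic=True)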
 
# ------------------------------- Transitions ------------------------------- #
[docs]
class FeedForwardTransition(FeedForwardPolicy):
    r"""Initializes a feedforward transition instance.
    
    Args:
        in_features (Union[int, dict]): The number of input features or a dictionary of input feature sizes.
        out_features (int): The number of output features.
        hidden_features (int): The number of hidden features.
        hidden_layers (int): The number of hidden layers.
        dist_config (list): The configuration of the distribution to use.
        norm (Optional[str]): The normalization method to use. None if no normalization is used.
        hidden_activation (str): The activation function to use for the hidden layers.
        backbone_type (Union[str, np.str_]): The type of backbone to use.
        mode (str): The mode to use for the transition. Either 'global' or 'local'.
        obs_dim (Optional[int]): The dimension of the observation for 'local' mode.
        use_feature_embed (bool): Whether to use feature embedding.
    """
    def __init__(self, 
                 in_features : Union[int, dict], 
                 out_features : int, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 dist_config : list,
                 norm : Optional[str] = None, 
                 hidden_activation : str = 'leakyrelu', 
                 backbone_type : Union[str, np.str_] = 'mlp',
                 mode : str = 'global',
                 obs_dim : Optional[int] = None,
                 **kwargs):
        
        self.mode = mode
        self.obs_dim = obs_dim
        if self.mode == 'local': 
            dist_types = [config['type'] for config in dist_config]
            if 'onehot' in dist_types or 'discrete_logistic' in dist_types:
                warnings.warn('Detected a distribution type that is not compatible with local mode, falling back to global mode!')
                self.mode = 'global'
        if self.mode == 'local':
            assert self.obs_dim is not None, "For local mode, the dim of observation should be given!"
        
        super(FeedForwardTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers, 
                                                    dist_config, norm, hidden_activation, backbone_type, **kwargs)
[docs]
    def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
        dist = super(FeedForwardTransition, self).forward(state, adapt_std, **kwargs)
        if self.mode == 'local' and self.obs_dim is not None:
            dist = dist.shift(state[..., :self.obs_dim])
        return dist 
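
# Usage sketch (illustrative only): in 'local' mode the backbone predicts a
# residual and the returned distribution is shifted by the current observation,
# so its mode is roughly obs + delta. This assumes the observation occupies the
# first obs_dim entries of the concatenated input; sizes are placeholders.
#
# transition = FeedForwardTransition(in_features=obs_dim + act_dim,
#                                    out_features=obs_dim, hidden_features=256,
#                                    hidden_layers=4, dist_config=dist_config,
#                                    mode='local', obs_dim=obs_dim)
# dist = transition(torch.cat([obs, act], dim=-1))
# next_obs = dist.mode    # centered around obs + the predicted delta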
 
[docs]
class RecurrentTransition(RecurrentPolicy):
    r"""
    Initializes a recurrent transition instance.
    
    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        hidden_features (int): The number of hidden features.
        hidden_layers (int): The number of hidden layers.
        norm (str): The normalization method to use. None if no normalization is used.
        dist_config (list): The distribution configuration.
        backbone_type (Union[str, np.str_]): The type of backbone to use.
        mode (str): The mode of the transition. Either 'global' or 'local'.
        obs_dim (Optional[int]): The dimension of the observation.
    """
    def __init__(self, 
                 in_features : int, 
                 out_features : int, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 norm : str,
                 dist_config : list, 
                 backbone_type : Union[str, np.str_] ='gru',
                 mode : str = 'global',
                 obs_dim : Optional[int] = None,
                 **kwargs):
        self.mode = mode
        self.obs_dim = obs_dim
        if self.mode == 'local':
            assert 'onehot' not in [config['type'] for config in dist_config], \
                "The local mode of transition is not compatible with onehot data! Please fallback to global mode!"
            assert self.obs_dim is not None, "For local mode, the dim of observation should be given!"
        super(RecurrentTransition, self).__init__(in_features, 
                                                  out_features, 
                                                  hidden_features, 
                                                  hidden_layers, 
                                                  norm,
                                                  dist_config, 
                                                  backbone_type, 
                                                  **kwargs)
[docs]
    def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
        dist = super(RecurrentTransition, self).forward(state, adapt_std, **kwargs)
        if self.mode == 'local' and self.obs_dim is not None:
            dist = dist.shift(state[..., :self.obs_dim])
        return dist 
 
[docs]
class TsRecurrentTransition(TsRecurrentPolicy):
    r"""
    Initializes a recurrent transition instance that takes dictionary inputs.
    
    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        hidden_features (int): The number of hidden features.
        hidden_layers (int): The number of hidden layers.
        norm (str): The normalization method to use. None if no normalization is used.
        dist_config (list): The distribution configuration.
        backbone_type (Union[str, np.str_]): The type of backbone to use.
        mode (str): The mode of the transition. Either 'global' or 'local'.
        obs_dim (Optional[int]): The dimension of the observation.
    """
    input_is_dict = True
    def __init__(self, 
                 in_features : int, 
                 out_features : int, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 norm : str,
                 dist_config : list, 
                 backbone_type : Union[str, np.str_] ='gru',
                 mode : str = 'global',
                 obs_dim : Optional[int] = None,
                 **kwargs):
        self.mode = mode
        self.obs_dim = obs_dim
        if self.mode == 'local':
            assert 'onehot' not in [config['type'] for config in dist_config], \
                "The local mode of transition is not compatible with onehot data! Please fallback to global mode!"
            assert self.obs_dim is not None, "For local mode, the dim of observation should be given!"
        super(TsRecurrentTransition, self).__init__(in_features, 
                                                  out_features, 
                                                  hidden_features, 
                                                  hidden_layers, 
                                                  norm,
                                                  dist_config, 
                                                  backbone_type, 
                                                  **kwargs)
[docs]
    def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
        dist = super(TsRecurrentTransition, self).forward(state, adapt_std, **kwargs)
        if self.mode == 'local' and self.obs_dim is not None:
            dist = dist.shift(state[..., :self.obs_dim])
        return dist 
 
    
[docs]
class RecurrentRESTransition(RecurrentRESPolicy):
    def __init__(self, 
                 in_features : int, 
                 out_features : int, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 dist_config : list,
                 backbone_type : Union[str, np.str_] = 'mlp',
                 mode : str = 'global',
                 obs_dim : Optional[int] = None,
                 rnn_hidden_features : int = 64,
                 window_size : int = 0,
                 **kwargs):
        
        self.mode = mode
        self.obs_dim = obs_dim
        if self.mode == 'local':
            assert 'onehot' not in [config['type'] for config in dist_config], \
                "The local mode of transition is not compatible with onehot data! Please fallback to global mode!"
            assert self.obs_dim is not None, "For local mode, the dim of observation should be given!"
        super(RecurrentRESTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers, 
                                                  dist_config, backbone_type, 
                                                  rnn_hidden_features=rnn_hidden_features, window_size=window_size,
                                                  **kwargs)
[docs]
    def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, field: str = 'mail', **kwargs) -> ReviveDistribution:
        dist = super(RecurrentRESTransition, self).forward(state, adapt_std, field, **kwargs)
        if self.mode == 'local' and self.obs_dim is not None:
            dist = dist.shift(state[..., :self.obs_dim])
        return dist 
 
[docs]
class EnsembleFeedForwardTransition(EnsembleFeedForwardPolicy):
    r"""Initializes a feedforward transition instance.
    
    Args:
        in_features (Union[int, dict]): The number of input features or a dictionary of input feature sizes.
        out_features (int): The number of output features.
        hidden_features (int): The number of hidden features.
        hidden_layers (int): The number of hidden layers.
        dist_config (list): The configuration of the distribution to use.
        norm (Optional[str]): The normalization method to use. None if no normalization is used.
        hidden_activation (str): The activation function to use for the hidden layers.
        backbone_type (Union[str, np.str_]): The type of backbone to use.
        mode (str): The mode to use for the transition. Either 'global' or 'local'.
        obs_dim (Optional[int]): The dimension of the observation for 'local' mode.
        use_feature_embed (bool): Whether to use feature embedding.
    """
    def __init__(self, 
                 in_features : Union[int, dict], 
                 out_features : int, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 ensemble_size : int,
                 dist_config : list,
                 norm : Optional[str] = None, 
                 hidden_activation : str = 'leakyrelu', 
                 backbone_type : Union[str, np.str_] = 'EnsembleRES',
                 mode : str = 'global',
                 obs_dim : Optional[int] = None,
                 **kwargs):
        
        self.mode = mode
        self.obs_dim = obs_dim
        if self.mode == 'local': 
            dist_types = [config['type'] for config in dist_config]
            if 'onehot' in dist_types or 'discrete_logistic' in dist_types:
                warnings.warn('Detected a distribution type that is not compatible with local mode, falling back to global mode!')
                self.mode = 'global'
        if self.mode == 'local':
            assert self.obs_dim is not None, "For local mode, the dim of observation should be given!"
        
        super(EnsembleFeedForwardTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers, ensemble_size,
                                                    dist_config, norm, hidden_activation, backbone_type, **kwargs)
[docs]
    def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
        dist = super(EnsembleFeedForwardTransition, self).forward(state, adapt_std, **kwargs)
        if self.mode == 'local' and self.obs_dim is not None:
            dist = dist.shift(state[..., :self.obs_dim])
        return dist 
 
# ------------------------------- Matchers ------------------------------- #
[docs]
class FeedForwardMatcher(torch.nn.Module):
    r"""
    Initializes a feedforward matcher instance.
    
    Args:
        in_features (int): The number of input features.
        hidden_features (int): The number of hidden features.
        hidden_layers (int): The number of hidden layers.
        hidden_activation (str): The activation function to use for the hidden layers.
        norm (str): The normalization method to use. None if no normalization is used.
        backbone_type (Union[str, np.str_]): The type of backbone to use.
    """
    def __init__(self, 
                 in_features : int, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 hidden_activation : str = 'leakyrelu', 
                 norm : str = None,
                 backbone_type : Union[str, np.str_] = 'mlp',
                 **kwargs):
        super().__init__()
        if backbone_type == 'mlp':
            self.backbone = MLP(in_features, 1, hidden_features, hidden_layers, norm, hidden_activation, output_activation='sigmoid')
        elif backbone_type == 'res':
            self.backbone = ResNet(in_features, 1, hidden_features, hidden_layers, norm, output_activation='sigmoid')
        elif backbone_type == 'ft_transformer':
            self.backbone = FT_Transformer(in_features, 1, hidden_features, hidden_layers=hidden_layers)
        else:
            raise NotImplementedError(f'backbone type {backbone_type} is not supported')
[docs]
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        # x.requires_grad = True
        shape = x.shape
        if len(shape) == 3:  # [horizon, batch_size, dim]
            x = x.view(-1, shape[-1])  # [horizon * batch_size, dim]
        assert len(x.shape) == 2, f"len(x.shape) == {len(x.shape)}, expect 2"
        out = self.backbone(x)  # [batch_size, dim]
        if len(shape) == 3:
            out = out.view(*shape[:-1], -1)  # [horizon, batch_size, dim]
        return out 
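
# Usage sketch (illustrative only): the matcher scores concatenated node values
# with a sigmoid head; a 3D (horizon, batch, dim) input is flattened, scored,
# and reshaped back. Sizes are placeholders.
#
# matcher = FeedForwardMatcher(in_features=obs_dim + act_dim,
#                              hidden_features=256, hidden_layers=4)
# x = torch.randn(horizon, batch_size, obs_dim + act_dim)
# score = matcher(x)    # (horizon, batch_size, 1), values in (0, 1)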
 
[docs]
class RecurrentMatcher(torch.nn.Module):
    def __init__(self, 
                 in_features : int, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 backbone_type : Union[str, np.str_] = 'gru', 
                 bidirect : bool = False):
        super().__init__()
        RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM
        self.rnn = RNN(in_features, hidden_features, hidden_layers, bidirectional=bidirect)
        self.output_layer = MLP(hidden_features * (2 if bidirect else 1), 1, 0, 0, output_activation='sigmoid')
[docs]
    def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor:
        x = torch.cat(inputs, dim=-1)
        rnn_output = self.rnn(x)[0]
        return self.output_layer(rnn_output) 
 
[docs]
class HierarchicalMatcher(torch.nn.Module):
    def __init__(self, 
                 in_features : list, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 hidden_activation : str, 
                 norm : str = None):
        super().__init__()
        self.in_features = in_features
        process_layers = []
        output_layers = []
        feature = self.in_features[0]
        for i in range(1, len(self.in_features)):
            feature += self.in_features[i]
            process_layers.append(MLP(feature, hidden_features, hidden_features, hidden_layers, norm, hidden_activation, hidden_activation))
            output_layers.append(torch.nn.Linear(hidden_features, 1))
            feature = hidden_features
        self.process_layers = torch.nn.ModuleList(process_layers)
        self.output_layers = torch.nn.ModuleList(output_layers)
[docs]
    def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor:
        assert len(inputs) == len(self.in_features)
        last_feature = inputs[0]
        result = 0
        for i, process_layer, output_layer in zip(range(1, len(self.in_features)), self.process_layers, self.output_layers):
            last_feature = torch.cat([last_feature, inputs[i]], dim=-1)
            last_feature = process_layer(last_feature)
            result += output_layer(last_feature)
        return torch.sigmoid(result) 
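
# Usage sketch (illustrative only): HierarchicalMatcher consumes its inputs
# incrementally -- each stage concatenates one more input, refines the shared
# feature, and contributes a logit; the final score is a sigmoid over the sum of
# the per-stage logits. The three-way split below is an assumption.
#
# matcher = HierarchicalMatcher(in_features=[obs_dim, act_dim, obs_dim],
#                               hidden_features=128, hidden_layers=1,
#                               hidden_activation='leakyrelu')
# score = matcher(obs, act, next_obs)    # (..., 1), values in (0, 1)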
 
[docs]
class VectorizedMatcher(torch.nn.Module):
    r"""
    Initializes a feedforward matcher instance.
    
    Args:
        in_features (int): The number of input features.
        hidden_features (int): The number of hidden features.
        hidden_layers (int): The number of hidden layers.
        hidden_activation (str): The activation function to use for the hidden layers.
        norm (str): The normalization method to use. None if no normalization is used.
        backbone_type (Union[str, np.str_]): The type of backbone to use.
    """
    def __init__(self, 
                 in_features : int, 
                 hidden_features : int, 
                 hidden_layers : int, 
                 ensemble_size : int,
                 hidden_activation : str = 'leakyrelu', 
                 norm : str = None,
                 backbone_type : Union[str, np.str_] = 'mlp',
                 **kwargs):
        super().__init__()
        self.backbone_type = backbone_type
        self.ensemble_size = ensemble_size
        if backbone_type == 'mlp':
            self.backbone = VectorizedMLP(in_features, 1, hidden_features, hidden_layers, ensemble_size, norm, hidden_activation, output_activation='sigmoid')
        elif backbone_type == 'res':
            self.backbone = VectorizedResNet(in_features, 1, hidden_features, hidden_layers, ensemble_size, norm, output_activation='sigmoid')
        else:
            raise NotImplementedError(f'backbone type {backbone_type} is not supported')
[docs]
    def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor:
        x = torch.cat(inputs, dim=-1).detach()
        # x.requires_grad = True
        shape = x.shape
        if len(shape) == 2:  # [batch_size, dim]
            x = x.unsqueeze(0).repeat_interleave(self.ensemble_size, dim=0)  # [batch_size, dim] --> [ensemble_size, batch_size, dim]
        assert x.shape[0] == self.ensemble_size
        out = self.backbone(x)  # [ensemble_size, batch_size, dim]
        return out 
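
# Usage sketch (illustrative only): VectorizedMatcher evaluates every ensemble
# member in one pass. A 2D (batch, dim) input is detached and repeated across
# members; a 3D input must already carry the ensemble as its leading dimension.
#
# matcher = VectorizedMatcher(in_features=obs_dim + act_dim, hidden_features=256,
#                             hidden_layers=4, ensemble_size=8)
# scores = matcher(torch.randn(32, obs_dim + act_dim))    # (8, 32, 1)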
 
[docs]
class EnsembleMatcher:
    def __init__(self, 
                 matcher_record: Dict[str, deque], 
                 ensemble_size : int = 50,
                 ensemble_choosing_interval: int = 1,
                 config : dict = None) -> None:
        
        self.matcher_weights_list = []
        self.single_structure_dict = config
        self.matching_nodes = config['matching_nodes']
        self.matching_fit_nodes = config['matching_fit_nodes']
        # self.matching_nodes_fit_dict = config['matching_nodes_fit_index']
        
        self.device = torch.device("cpu")
        self.min_generate_score = 0.0
        self.max_expert_score = 1.0
        self.selected_ind = []
        counter = 0
        for i in reversed(range(len(matcher_record['state_dicts']))):
            if counter % ensemble_choosing_interval == 0:
                matcher_weight = matcher_record['state_dicts'][i]
                self.selected_ind.append(i)
                self.matcher_weights_list.append(matcher_weight)
                if len(self.matcher_weights_list) >= ensemble_size:
                    break
            counter += 1
        self.ensemble_size = len(self.matcher_weights_list)
        self.matcher_ensemble = VectorizedMatcher(ensemble_size=self.ensemble_size, **config)
        self.init_ensemble_network()
        self.load_min_max(matcher_record['expert_scores'], matcher_record['generated_scores'])
[docs]
    def load_min_max(self, expert_scores: deque, generate_scores: deque):
        self.expert_scores = torch.tensor(expert_scores, dtype=torch.float32, device=self.device)[self.selected_ind]  # [ensemble_size, ]
        self.generate_scores = torch.tensor(generate_scores, dtype=torch.float32, device=self.device)[self.selected_ind]  # [ensemble_size, ]
        self.max_expert_score = max(self.expert_scores)
        self.min_generate_score = min(self.generate_scores)
        self.range = max(abs(self.max_expert_score), abs(self.min_generate_score)) 
[docs]
    def init_ensemble_network(self,):
        if self.matcher_ensemble.backbone_type == 'mlp':
            with torch.no_grad():
                for ind, state_dict in enumerate(self.matcher_weights_list):
                    for key, val in state_dict.items():
                        layer_ind = int(key.split('.')[2])
                        if key.endswith("weight"):
                            self.matcher_ensemble.backbone.net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
                        elif key.endswith("bias"):
                            self.matcher_ensemble.backbone.net[layer_ind].bias.data[ind, 0, ...].copy_(val)
                        else:
                            raise NotImplementedError("Only except network params with weight and bias")
        elif self.matcher_ensemble.backbone_type == 'res':
            with torch.no_grad():
                for ind, state_dict in enumerate(self.matcher_weights_list):
                    for key, val in state_dict.items():
                        key_ls = key.split('.')
                        if "skip_net" in key_ls:
                            block_ind = int(key_ls[2])
                            if key.endswith("weight"):
                                self.matcher_ensemble.backbone.resnet[block_ind].skip_net.weight.data[ind, ...].copy_(val.transpose(1, 0))
                            elif key.endswith("bias"):
                                self.matcher_ensemble.backbone.resnet[block_ind].skip_net.bias.data[ind, 0, ...].copy_(val)
                            else:
                                raise NotImplementedError("Only except network params with weight and bias")
                        elif "process_net" in key_ls:
                            block_ind, layer_ind = int(key_ls[2]), int(key_ls[4])
                            if key.endswith("weight"):
                                self.matcher_ensemble.backbone.resnet[block_ind].process_net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
                            elif key.endswith("bias"):
                                self.matcher_ensemble.backbone.resnet[block_ind].process_net[layer_ind].bias.data[ind, 0, ...].copy_(val)
                            else:
                                raise NotImplementedError("Only except network params with weight and bias")
                        else:
                            # linear layer at last
                            block_ind = int(key_ls[2])
                            if key.endswith("weight"):
                                self.matcher_ensemble.backbone.resnet[block_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
                            elif key.endswith("bias"):
                                self.matcher_ensemble.backbone.resnet[block_ind].bias.data[ind, 0, ...].copy_(val)
                            else:
                                raise NotImplementedError("Only except network params with weight and bias")
        else:
            raise NotImplementedError 
[docs]
    def run_raw_scores(self, inputs: torch.Tensor, aggregation=None) -> torch.Tensor:
        with torch.no_grad():
            scores = self.matcher_ensemble(inputs)  # [ensemble_size, batch_size, 1], [ensemble_size, horizon, batch_size, 1]
            scores = scores.squeeze(-1).transpose(1, 0).detach()  # [batch_size, ensemble_size]
            if aggregation == "mean":
                return scores.mean(-1)
            return scores 
[docs]
    def run_scores(self, inputs: torch.Tensor, aggregation: str=None, clip: bool=True) -> torch.Tensor:
        with torch.no_grad():
            scores = self.matcher_ensemble(inputs)  # [ensemble_size, batch_size, 1]
            scores = scores.squeeze(-1).transpose(1, 0).detach()  # [batch_size, ensemble_size]
            if clip:
                assert self.expert_scores is not None and self.generate_scores is not None, "if clip==True, you should call load_min_max() beforehand"
                scores = torch.clamp(scores, min=self.generate_scores, max=self.expert_scores)
                scores = (scores - self.generate_scores) / (self.expert_scores - self.generate_scores + 1e-7) * self.range
            if aggregation == "mean":
                return scores.mean(-1)
            return scores 
[docs]
    def to(self, device):
        self.device = device
        self.matcher_ensemble = self.matcher_ensemble.to(device)
        self.expert_scores = self.expert_scores.to(device)
        self.generate_scores = self.generate_scores.to(device) 
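
# Scoring sketch (illustrative only): run_scores clamps each member's raw score
# to its recorded [generated, expert] baseline and rescales it. For example, with
# a generated baseline of 0.2, an expert baseline of 0.8 and range == 0.8, a raw
# score of 0.5 maps to (0.5 - 0.2) / (0.8 - 0.2 + 1e-7) * 0.8 ≈ 0.4.
#
# ensemble = EnsembleMatcher(matcher_record, ensemble_size=50, config=matcher_config)
# ensemble.to(torch.device('cpu'))
# rewards = ensemble.run_scores(inputs, aggregation='mean')    # (batch_size,)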
 
# class SequentialMatcher:
#     def __init__(self, 
#                  matcher_record: Dict[str, deque], 
#                  ensemble_size : int = 50,
#                  ensemble_choosing_interval: int = 1,
#                  config : dict = None) -> None:
        
#         self.matcher_list = []
#         self.single_structure_dict = config
#         self.matching_nodes = config['matching_nodes']
#         self.matching_fit_nodes = config['matching_fit_nodes']
#         # self.matching_nodes_fit_dict = config['matching_nodes_fit_index']
        
#         self.device = torch.device("cpu")
#         self.min_generate_score = 0.0
#         self.max_expert_score = 1.0
#         self.selected_ind = []
#         counter = 0
#         for i in reversed(range(len(matcher_record['state_dicts']))):
#             if counter % ensemble_choosing_interval == 0:
#                 matcher_weight = matcher_record['state_dicts'][i]
#                 self.selected_ind.append(i)
#                 _matcher = FeedForwardMatcher(**config)
#                 _matcher.load_state_dict(matcher_weight)
#                 self.matcher_list.append(_matcher)
#                 if len(self.matcher_list) >= ensemble_size:
#                     break
#             counter += 1
#         self.ensemble_size = len(self.matcher_list)
#         self.load_min_max(matcher_record['expert_scores'], matcher_record['generated_scores'])
#         # prepare for vmap func
#         self.params, self.buffers = stack_module_state(self.matcher_list)
#         """
#         Construct a "stateless" version of one of the models. It is "stateless" in
#         the sense that the parameters are meta Tensors and do not have storage.
#         """
#         self.base_model = deepcopy(self.matcher_list[0])
#         self.base_model = self.base_model.to('meta')
#     def vmap_forward(self, x):
#         def vmap_model(params, buffers, x):
#             return functional_call(self.base_model, (params, buffers), (x,))
        
#         return vmap(vmap_model, in_dims=(0, 0, None))(self.params, self.buffers, x)
#     def load_min_max(self, expert_scores: deque, generate_scores: deque):
#         self.expert_scores = torch.tensor(expert_scores, dtype=torch.float32, device=self.device)[self.selected_ind]  # [ensemble_size, ]
#         self.generate_scores = torch.tensor(generate_scores, dtype=torch.float32, device=self.device)[self.selected_ind]  # [ensemble_size, ]
#         self.max_expert_score = max(self.expert_scores)
#         self.min_generate_score = min(self.generate_scores)
#         self.range = max(abs(self.max_expert_score), abs(self.min_generate_score))
#     def run_raw_scores(self, inputs: torch.Tensor, aggregation=None) -> torch.Tensor:
#         with torch.no_grad():
#             scores = self.vmap_forward(inputs)  # [ensemble_size, batch_size, 1], [ensemble_size, horizon, batch_size, 1]
#             scores = scores.squeeze(-1).transpose(1, 0).detach()  # [batch_size, ensemble_size]
#             if aggregation == "mean":
#                 return scores.mean(-1)
#             return scores
#     def run_scores(self, inputs: torch.Tensor, aggregation: str=None, clip: bool=True) -> torch.Tensor:
#         with torch.no_grad():
#             scores = self.vmap_forward(inputs)  # [ensemble_size, batch_size, 1]
#             scores = scores.squeeze(-1).transpose(1, 0).detach()  # [batch_size, ensemble_size]
#             if clip:
#                 assert self.expert_scores is not None and self.generate_scores is not None, "if clip==True, you should call load_min_max() beforehand"
#                 scores = torch.clamp(scores, min=self.generate_scores, max=self.expert_scores)
#                 scores = (scores - self.generate_scores) / (self.expert_scores - self.generate_scores + 1e-7) * self.range
#             if aggregation == "mean":
#                 return scores.mean(-1)
#             return scores
#     def to(self, device):
#         self.device = device
#         self.matcher_list = [_matcher.to(device) for _matcher in self.matcher_list]
#         self.expert_scores = self.expert_scores.to(device)
#         self.generate_scores = self.generate_scores.to(device)
#         # prepare for vmap func
#         self.params, self.buffers = stack_module_state(self.matcher_list)
# ------------------------------- Others ------------------------------- #
[docs]
class VectorizedCritic(VectorizedMLP):
[docs]
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        x = super(VectorizedCritic, self).forward(x)  
        return x.squeeze(-1) 
 
# ------------------------------- Test ------------------------------- #
if __name__ == "__main__":
    def test_ensemble_mlp():
        rnd = torch.rand((6, 5))
        ensemble_size = 3
        ensemble_linear = VectorizedMLP(5, 4, hidden_features=16, hidden_layers=2, ensemble_size=ensemble_size)
        linear1 = MLP(5, 4, hidden_features=16, hidden_layers=2)
        linear2 = MLP(5, 4, hidden_features=16, hidden_layers=2)
        linear3 = MLP(5, 4, hidden_features=16, hidden_layers=2)
        state_dicts = deque([linear1.state_dict(), linear2.state_dict(), linear3.state_dict()])
        for ind, state_dict in enumerate(state_dicts):
            for key, val in state_dict.items():
                layer_ind = int(key.split('.')[1])
                if key.endswith("weight"):
                    ensemble_linear.net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
                elif key.endswith("bias"):
                    ensemble_linear.net[layer_ind].bias.data[ind, 0, ...].copy_(val)
                else:
                    raise NotImplementedError("Only except network params with weight and bias")
        ensemble_out = ensemble_linear(rnd)
        linear1_out = linear1(rnd)
        linear2_out = linear2(rnd)
        linear3_out = linear3(rnd)
        # breakpoint()
        print((ensemble_out[0, ...] - linear1_out).abs().max())
        print((ensemble_out[1, ...] - linear2_out).abs().max())
        print((ensemble_out[2, ...] - linear3_out).abs().max())
    def test_ensemble_res_distinct():
        in_fea = 256
        out_fea = 15
        hidden_features=256
        hidden_layers=4
        batch_size = 512
        ensemble_size = 3
        rnd1 = torch.rand((batch_size, in_fea))
        rnd2 = torch.rand((batch_size, in_fea))
        rnd3 = torch.rand((batch_size, in_fea))
        rnd = torch.stack([rnd1, rnd2, rnd3])
        ensemble_linear = VectorizedResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, ensemble_size=ensemble_size, norm=None)
        linear1 = ResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, norm=None)
        linear2 = ResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, norm=None)
        linear3 = ResNet(in_fea, out_fea, hidden_features=hidden_features, hidden_layers=hidden_layers, norm=None)
        state_dicts = deque([linear1.state_dict(), linear2.state_dict(), linear3.state_dict()])
        for ind, state_dict in enumerate(state_dicts):
            for key, val in state_dict.items():
                key_ls = key.split('.')
                if "skip_net" in key_ls:
                    block_ind = int(key_ls[1])
                    if key.endswith("weight"):
                        ensemble_linear.resnet[block_ind].skip_net.weight.data[ind, ...].copy_(val.transpose(1, 0))
                    elif key.endswith("bias"):
                        ensemble_linear.resnet[block_ind].skip_net.bias.data[ind, 0, ...].copy_(val)
                    else:
                        raise NotImplementedError("Only except network params with weight and bias")
                elif "process_net" in key_ls:
                    block_ind, layer_ind = int(key_ls[1]), int(key_ls[3])
                    if key.endswith("weight"):
                        ensemble_linear.resnet[block_ind].process_net[layer_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
                    elif key.endswith("bias"):
                        ensemble_linear.resnet[block_ind].process_net[layer_ind].bias.data[ind, 0, ...].copy_(val)
                    else:
                        raise NotImplementedError("Only except network params with weight and bias")
                else:
                    # linear layer at last
                    block_ind = int(key_ls[1])
                    if key.endswith("weight"):
                        ensemble_linear.resnet[block_ind].weight.data[ind, ...].copy_(val.transpose(1, 0))
                    elif key.endswith("bias"):
                        ensemble_linear.resnet[block_ind].bias.data[ind, 0, ...].copy_(val)
                    else:
                        raise NotImplementedError("Only except network params with weight and bias")
        ensemble_out = ensemble_linear(rnd)
        linear1_out = linear1(rnd1)
        linear2_out = linear2(rnd2)
        linear3_out = linear3(rnd3)
        # breakpoint()
        print((ensemble_out[0, ...] - linear1_out).abs().max())
        print((ensemble_out[1, ...] - linear2_out).abs().max())
        print((ensemble_out[2, ...] - linear3_out).abs().max())
    test_ensemble_mlp()
    test_ensemble_res_distinct()
[docs]
class Value_Net_VectorizedCritic(nn.Module):
    r"""
        Initializes a vectorized linear layer instance.
        Args:
            in_features (int): The number of input features.
            out_features (int): The number of output features.
            ensemble_size (int): The number of ensembles to use.
    """
    def __init__(self, 
                 input_dim_dict,
                 q_hidden_features,
                 q_hidden_layers,
                 num_q_net,
                 *args, **kwargs) -> None:
        super().__init__()
        self.kwargs = kwargs
        self.input_dim_dict = input_dim_dict
        if isinstance(input_dim_dict, dict):
            input_dim = sum(input_dim_dict.values())
        else:
            input_dim = input_dim_dict
        self.value = VectorizedMLP(input_dim,
                                   1,
                                   q_hidden_features,
                                   q_hidden_layers,
                                   num_q_net)
        
        if 'ts_conv_config' in self.kwargs and self.kwargs['ts_conv_config'] is not None:
            other_state_layer_output = q_hidden_features//2
            ts_state_conv_layer_output = q_hidden_features//2 + q_hidden_features%2
            self.other_state_layer   = MLP(self.kwargs['ts_conv_net_config']['other_net_input'], other_state_layer_output, 0, 0)
            all_node_ts = None
            self.conv_ts_node = []
            self.conv_other_node = deepcopy(self.kwargs['ts_conv_config']['no_ts_input'])
            for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
                self.conv_ts_node.append('__history__'+tsnode)
                self.conv_other_node.append('__now__'+tsnode)  
                node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
                if all_node_ts is None:
                    all_node_ts = node_ts
                assert all_node_ts==node_ts, f"expect ts step == {all_node_ts}. However got {tsnode} with ts step {node_ts}"
            
            # the conv kernel covers the ts-1 history steps (the current step is handled separately)
            kernel_size = all_node_ts - 1
            self.ts_state_conv_layer = ConvBlock(self.kwargs['ts_conv_net_config']['conv_net_input'], ts_state_conv_layer_output, kernel_size)
[docs]
    def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
        if hasattr(self, 'conv_ts_node'):
            # deal with all nodes that carry a time-series (ts) dimension
            assert hasattr(self, 'conv_other_node')
            inputdata = deepcopy(state.detach())  # data is detached from the other nodes
            for tsnode in self.kwargs['ts_conv_config']['with_ts_input'].keys():
                assert self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['endpoint'] == True
                batch_size = np.prod(list(inputdata[tsnode].shape[:-1]))
                original_size = list(inputdata[tsnode].shape[:-1])
                node_ts = self.kwargs['ts_conv_config']['with_ts_input'][tsnode]['ts']
                # inputdata[tsnode]
                temp_data = inputdata[tsnode].reshape(batch_size, node_ts, -1)
                inputdata['__history__'+tsnode] = temp_data[..., :-1, :]
                inputdata['__now__'+tsnode] = temp_data[..., -1, :].reshape([*original_size,-1])      
                      
            state_other = torch.cat([inputdata[key] for key in self.conv_other_node], dim=-1)            
            state_ts_conv = torch.cat([inputdata[key] for key in self.conv_ts_node], dim=-1)
            
            state_other = self.other_state_layer(state_other)
            original_size = list(state_other.shape[:-1])
            state_ts_conv = self.ts_state_conv_layer(state_ts_conv).reshape([*original_size,-1])
            state = torch.cat([state_other, state_ts_conv], dim=-1)
        else:
            pass
            # state = torch.cat([state[key].detach() for key in self.input_dim_dict.keys()], dim=-1)
        output = self.value(state)
        output = output.squeeze(-1)
        return output
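
# Shape sketch (illustrative only, names are placeholders): how the history/now
# split in the forward pass above treats a time-series node with ts == 4 and
# per-step feature size f.
#
# flat = torch.randn(batch_size, 4 * f)                  # flattened ts node
# temp = flat.reshape(batch_size, 4, f)
# history, now = temp[..., :-1, :], temp[..., -1, :]     # (batch, 3, f) and (batch, f)
# # 'history' feeds the ConvBlock branch, 'now' joins the non-ts features, and the
# # concatenated embedding is scored by the num_q_net vectorized Q heads.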