import torch
import math
import warnings
from torch import nn
from typing import Optional, Union, List, Dict, cast
from collections import OrderedDict, deque
from revive.computation.dists import *
from revive.computation.utils import *
def reglu(x):
    """ReGLU gating.

    Splits the last dimension of ``x`` into two chunks (``value`` and ``gate``)
    and returns ``value * relu(gate)``. For an even-sized last dimension the
    output has half as many features as the input.
    """
    value, gate = torch.chunk(x, 2, dim=-1)
    return value * F.relu(gate)
def geglu(x):
    """GEGLU gating.

    Splits the last dimension of ``x`` into two chunks (``value`` and ``gate``)
    and returns ``value * gelu(gate)``. For an even-sized last dimension the
    output has half as many features as the input.
    """
    value, gate = torch.chunk(x, 2, dim=-1)
    return value * F.gelu(gate)
class Swish(nn.Module):
    """Parameter-free Swish activation, computed element-wise as ``x * sigmoid(x)``."""

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(x)
        return gate * x
# Entries of the activation-name -> factory mapping that MLP and VectorizedMLP look up
# as ACTIVATION_CREATORS[name].
# NOTE(review): the `ACTIVATION_CREATORS = {` opening line and the closing `}` are not
# present around these entries in this copy of the file; confirm against upstream.
# Each lambda takes a `dim` argument (presumably the feature size of the layer the
# activation follows); only 'prelu' uses it, giving one learnable slope per feature.
'relu' : lambda dim: nn.ReLU(inplace=True),
'elu' : lambda dim: nn.ELU(),
'leakyrelu' : lambda dim: nn.LeakyReLU(negative_slope=0.1, inplace=True),
'tanh' : lambda dim: nn.Tanh(),
'sigmoid' : lambda dim: nn.Sigmoid(),
'identity' : lambda dim: nn.Identity(),
'prelu' : lambda dim: nn.PReLU(dim),
'gelu' : lambda dim: nn.GELU(),
# NOTE(review): unlike the other entries, 'reglu' and 'geglu' map to the tensor functions
# themselves rather than to module factories. If a caller invokes them as `creator(dim)`
# like the lambdas, they receive an int and fail on `.chunk`. They also halve the last
# dimension. Confirm how these two names are consumed.
'reglu' : reglu,
'geglu' : geglu,
'swish' : lambda dim: Swish(),
# -------------------------------- Backbones -------------------------------- #
[docs]class MLP(nn.Module):
# Fully-connected network wrapped in a single nn.Sequential (`self.net`).
# NOTE(review): the docstring delimiters around the parameter text below, the tail of the
# hidden_layers == 0 Sequential, the `else:` lines, and the norm/activation appends are
# not present in this copy of the file; confirm against upstream before editing.
Multi-layer Perceptron
in_features : int, features numbers of the input
out_features : int, features numbers of the output
hidden_features : int, features numbers of the hidden layers
hidden_layers : int, numbers of the hidden layers
norm : str, normalization method between hidden layers, default : None
hidden_activation : str, activation function used in hidden layers, default : 'leakyrelu'
output_activation : str, activation function used in output layer, default : 'identity'
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
norm : str = None,
hidden_activation : str = 'leakyrelu',
output_activation : str = 'identity'):
super(MLP, self).__init__()
# Resolve activation names via ACTIVATION_CREATORS (an unknown name fails at lookup).
hidden_activation_creator = ACTIVATION_CREATORS[hidden_activation]
output_activation_creator = ACTIVATION_CREATORS[output_activation]
# No hidden layers: a single Linear(in_features, out_features) projection.
if hidden_layers == 0:
self.net = nn.Sequential(
nn.Linear(in_features, out_features),
# NOTE(review): the rest of this Sequential (presumably the output activation) and the
# `else:` that introduces the hidden-layer branch below are missing here.
net = []
for i in range(hidden_layers):
# The first hidden layer maps in_features -> hidden_features; later ones hidden -> hidden.
net.append(nn.Linear(in_features if i == 0 else hidden_features, hidden_features))
if norm:
# Only 'ln' and 'bn' are accepted; anything else raises NotImplementedError.
if norm == 'ln':
elif norm == 'bn':
# NOTE(review): the 'ln'/'bn' branch bodies and the `else:` before this raise are missing.
raise NotImplementedError(f'{norm} does not supported!')
# NOTE(review): no visible line appends hidden_activation_creator(...) or
# output_activation_creator(...) modules; confirm they exist upstream.
# Final projection hidden_features -> out_features.
net.append(nn.Linear(hidden_features, out_features))
self.net = nn.Sequential(*net)
[docs] def forward(self, x : torch.Tensor) -> torch.Tensor:
r""" forward method of MLP only assume the last dim of x matches `in_features` """
return self.net(x)
[docs]class ResBlock(nn.Module):
# Residual block computing process_net(x) + skip_net(x).
# NOTE(review): the docstring delimiters, `super().__init__()`, the remaining Sequential
# elements (norm/activation layers and closing parens) and the `else:` lines are missing
# from this copy of the file.
Initializes a residual block instance.
input_feature (int): The number of input features to the block.
output_feature (int): The number of output features from the block.
norm (str, optional): The type of normalization to apply to the block. Default is 'ln' for layer normalization.
def __init__(self, input_feature : int, output_feature : int, norm : str = 'ln'):
if norm == 'ln':
norm_class = torch.nn.LayerNorm
self.process_net = torch.nn.Sequential(
torch.nn.Linear(input_feature, output_feature),
torch.nn.Linear(output_feature, output_feature),
# Non-'ln' variant of process_net (its `else:` line is not present here).
# NOTE(review): the visible lines only reference LayerNorm; forward's docstring mentions
# batchnorm, so confirm which norm (if any) the non-'ln' path uses.
self.process_net = torch.nn.Sequential(
torch.nn.Linear(input_feature, output_feature),
torch.nn.Linear(output_feature, output_feature),
# The skip path needs a projection only when the feature sizes differ; otherwise it is Identity.
if not input_feature == output_feature:
self.skip_net = torch.nn.Linear(input_feature, output_feature)
self.skip_net = torch.nn.Identity()
[docs] def forward(self, x : torch.Tensor) -> torch.Tensor:
'''x should be a 2D Tensor due to batchnorm'''
return self.process_net(x) + self.skip_net(x)
class VectorizedLinear(nn.Module):
    """A stack of ``ensemble_size`` independent linear layers evaluated with one batched matmul.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        ensemble_size (int): The number of ensemble members (independent layers).
    """

    def __init__(self, in_features: int, out_features: int, ensemble_size: int):
        # nn.Module.__init__ must run before any nn.Parameter is assigned.
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        # Weight is stored as (in, out), the transpose of nn.Linear's (out, in) layout,
        # so forward can compute ``x @ weight`` directly.
        self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features))
        self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))

        # torch.empty leaves the memory uninitialised; initialise immediately.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every ensemble member the way ``nn.Linear`` does.

        ``nn.Linear`` derives fan_in from its (out, in) weight, i.e. fan_in == in_features.
        The weight here is (in, out). The init therefore runs on the transposed view of each
        member, and the bias bound is computed from ``in_features``. As a result,
        bound == 1 / sqrt(in_features) for both weight and bias, as in ``nn.Linear``.
        """
        with torch.no_grad():
            for layer in range(self.ensemble_size):
                nn.init.kaiming_uniform_(self.weight[layer].T, a=math.sqrt(5))
            fan_in = self.in_features
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply each ensemble member to its own slice of ``x``.

        Shapes:
            x:      [ensemble_size, batch_size, in_features]
            weight: [ensemble_size, in_features, out_features]
            out:    [ensemble_size, batch_size, out_features]
        """
        return x @ self.weight + self.bias
[docs]class VectorizedMLP(nn.Module):
# Ensemble of `nums` MLPs built from VectorizedLinear layers and evaluated in one pass.
# NOTE(review): the docstring delimiters, the tail of the hidden_layers == 0 Sequential,
# the `else:` line, and any activation appends are missing from this copy of the file.
# The parameter list below also omits `nums` (the ensemble size).
Vectorized MLP
in_features : int, features numbers of the input
out_features : int, features numbers of the output
hidden_features : int, features numbers of the hidden layers
hidden_layers : int, numbers of the hidden layers
norm : str, normalization method between hidden layers, default : None
hidden_activation : str, activation function used in hidden layers, default : 'leakyrelu'
output_activation : str, activation function used in output layer, default : 'identity'
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
nums : int,
norm : str = None,
hidden_activation : str = 'leakyrelu',
output_activation : str = 'identity'):
super(VectorizedMLP, self).__init__()
self.nums = nums
hidden_activation_creator = ACTIVATION_CREATORS[hidden_activation]
output_activation_creator = ACTIVATION_CREATORS[output_activation]
if hidden_layers == 0:
self.net = nn.Sequential(
VectorizedLinear(in_features, out_features, nums),
net = []
for i in range(hidden_layers):
net.append(VectorizedLinear(in_features if i == 0 else hidden_features, hidden_features, nums),)
# NOTE(review): `norm` and hidden_activation_creator are not used in the visible lines.
net.append(VectorizedLinear(hidden_features, out_features, nums))
self.net = nn.Sequential(*net)
[docs] def forward(self, x : torch.Tensor) -> torch.Tensor:
r""" forward method of MLP only assume the last dim of x matches `in_features` """
# A 2D input is shared by all ensemble members: tile it to (nums, batch, x_dim).
if x.dim() != 3:
assert x.dim() == 2
# [nums, batch_size, x_dim]
x = x.unsqueeze(0).repeat_interleave(self.nums, dim=0)
# A 3D input must already carry one slice per ensemble member on dim 0.
assert x.dim() == 3
assert x.shape[0] == self.nums
return self.net(x)
[docs]class ResNet(torch.nn.Module):
# A stack of `hidden_layers` ResBlocks followed by a Linear(hidden_features, out_features) head.
# NOTE(review): the docstring delimiters, `super().__init__()` and the `else:` before the
# second append in the loop are missing from this copy of the file.
# NOTE(review): `output_activation` is not used by any visible line, although
# FeedForwardMatcher passes output_activation='sigmoid'. Confirm that upstream appends it.
Initializes a residual neural network (ResNet) instance.
in_features (int): The number of input features to the ResNet.
out_features (int): The number of output features from the ResNet.
hidden_features (int): The number of hidden features in each residual block.
hidden_layers (int): The number of residual blocks in the ResNet.
norm (str, optional): The type of normalization to apply to the ResNet. Default is 'bn' for batch normalization.
output_activation (str, optional): The type of activation function to apply to the output of the ResNet. Default is 'identity'.
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
norm : str = 'bn',
output_activation : str = 'identity'):
modules = []
for i in range(hidden_layers):
# The first block maps in_features -> hidden_features; the rest are hidden -> hidden.
if i == 0:
modules.append(ResBlock(in_features, hidden_features, norm))
modules.append(ResBlock(hidden_features, hidden_features, norm))
modules.append(torch.nn.Linear(hidden_features, out_features))
self.resnet = torch.nn.Sequential(*modules)
[docs] def forward(self, x : torch.Tensor) -> torch.Tensor:
'''NOTE: reshape is needed since resblock only support 2D Tensor'''
# Flatten all leading dims to one batch dim, run the blocks, then restore the leading dims.
# NOTE(review): `view` requires a contiguous input; `reshape` would also accept strided views.
shape = x.shape
x = x.view(-1, shape[-1])
output = self.resnet(x)
output = output.view(*shape[:-1], -1)
return output
[docs]class Tokenizer(nn.Module):
# Turns numeric (and optional categorical) features into a sequence of d_token-sized tokens,
# with a leading [CLS] token.
# NOTE(review): `self,` in the signature, `super().__init__()`, the `else:` line, and
# several closing parens/arguments are missing from this copy of the file.
category_offsets: Optional[torch.Tensor]
def __init__(
d_numerical: int,
categories: Optional[List[int]],
d_token: int,
bias: bool,
) -> None:
if categories is None:
d_bias = d_numerical
self.category_offsets = None
self.category_embeddings = None
d_bias = d_numerical + len(categories)
# Offsets let all categorical columns share one embedding table: column j's codes start
# at sum(categories[:j]).
category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0)
self.register_buffer('category_offsets', category_offsets)
self.category_embeddings = nn.Embedding(sum(categories), d_token)
nn.init.kaiming_uniform_(self.category_embeddings.weight, a=math.sqrt(5))
# take [CLS] token into account
self.weight = nn.Parameter(torch.Tensor(d_numerical + 1, d_token))
self.bias = nn.Parameter(torch.Tensor(d_bias, d_token)) if bias else None
# The initialization is inspired by nn.Linear
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
nn.init.kaiming_uniform_(self.bias, a=math.sqrt(5))
# Token count = numeric tokens incl. [CLS] (rows of weight) + one per categorical column.
def n_tokens(self) -> int:
return len(self.weight) + (
0 if self.category_offsets is None else len(self.category_offsets)
[docs] def forward(self, x_num: torch.Tensor, x_cat: Optional[torch.Tensor]):
# At least one of x_num / x_cat must be given; it supplies the batch size and device.
x_some = x_num if x_cat is None else x_cat
assert x_some is not None
# Prepend a constant-1 column so row 0 of `weight` becomes the [CLS] token embedding.
x_num = torch.cat(
[torch.ones(len(x_some), 1, device=x_some.device)] # [CLS]
+ ([] if x_num is None else [x_num]),
# Each numeric feature value scales its own learned token vector.
x = self.weight[None] * x_num[:, :, None]
if x_cat is not None:
x = torch.cat(
[x, self.category_embeddings(x_cat + self.category_offsets[None])],
if self.bias is not None:
# Appears to prepend a zero bias row so the [CLS] token gets no bias (args partly missing here).
bias = torch.cat(
torch.zeros(1, self.bias.shape[1], device=x.device),
x = x + bias[None]
return x
[docs]class MultiheadAttention(nn.Module):
# Scaled dot-product multi-head attention with optional key/value sequence compression.
# NOTE(review): `super().__init__()`, the body of `if self.W_out is not None:`, the
# `self, x_q, x_kv,` parameters of forward, an `else:` and closing parens are missing
# from this copy of the file.
def __init__(
self, d: int, n_heads: int, dropout: float, initialization: str
) -> None:
if n_heads > 1:
assert d % n_heads == 0
assert initialization in ['xavier', 'kaiming']
self.W_q = nn.Linear(d, d)
self.W_k = nn.Linear(d, d)
self.W_v = nn.Linear(d, d)
# With a single head there is no output projection.
self.W_out = nn.Linear(d, d) if n_heads > 1 else None
self.n_heads = n_heads
self.dropout = nn.Dropout(dropout) if dropout else None
for m in [self.W_q, self.W_k, self.W_v]:
# With a single head, W_v keeps PyTorch's default init even under 'xavier'.
if initialization == 'xavier' and (n_heads > 1 or m is not self.W_v):
# gain is needed since W_qkv is represented with 3 separate layers
nn.init.xavier_uniform_(m.weight, gain=1 / math.sqrt(2))
if self.W_out is not None:
# Split heads: (batch, tokens, d) -> (batch * n_heads, tokens, d // n_heads).
def _reshape(self, x):
batch_size, n_tokens, d = x.shape
d_head = d // self.n_heads
return (
x.reshape(batch_size, n_tokens, self.n_heads, d_head)
.transpose(1, 2)
.reshape(batch_size * self.n_heads, n_tokens, d_head)
[docs] def forward(
key_compression: Optional[nn.Linear],
value_compression: Optional[nn.Linear],
q, k, v = self.W_q(x_q), self.W_k(x_kv), self.W_v(x_kv)
for tensor in [q, k, v]:
assert tensor.shape[-1] % self.n_heads == 0
if key_compression is not None:
# Compression Linears act on the token axis (hence transpose(1, 2) around them).
assert value_compression is not None
k = key_compression(k.transpose(1, 2)).transpose(1, 2)
v = value_compression(v.transpose(1, 2)).transpose(1, 2)
assert value_compression is None
batch_size = len(q)
d_head_key = k.shape[-1] // self.n_heads
d_head_value = v.shape[-1] // self.n_heads
n_q_tokens = q.shape[1]
q = self._reshape(q)
k = self._reshape(k)
# Attention weights: softmax(q k^T / sqrt(d_head)) over the key tokens.
attention = F.softmax(q @ k.transpose(1, 2) / math.sqrt(d_head_key), dim=-1)
if self.dropout is not None:
attention = self.dropout(attention)
x = attention @ self._reshape(v)
# Merge heads back: (batch * n_heads, q_tokens, d_head) -> (batch, q_tokens, n_heads * d_head).
x = (
x.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value)
.transpose(1, 2)
.reshape(batch_size, n_q_tokens, self.n_heads * d_head_value)
if self.W_out is not None:
x = self.W_out(x)
return x
[docs]class DistributionWrapper(nn.Module):
# Converts a raw network output tensor into a ReviveDistribution of the configured type.
# NOTE(review): `super().__init__()`, the definition of SUPPORTED_TYPES (referenced below but
# not defined in the visible lines), the `else:` lines, and the code that fills
# input_sizes/output_sizes for 'mix' are missing from this copy of the file.
r"""wrap output of Module to distribution"""
BASE_TYPES = ['normal', 'gmm', 'onehot', 'discrete_logistic']
def __init__(self, distribution_type : str = 'normal', **params):
self.distribution_type = distribution_type
self.params = params
assert self.distribution_type in self.SUPPORTED_TYPES, f"{self.distribution_type} is not supported!"
if self.distribution_type == 'normal':
# Learnable log-std bounds, initialised to [-10, 0] per dimension.
self.max_logstd = nn.Parameter(torch.ones(self.params['dim']) * 0, requires_grad=True)
self.min_logstd = nn.Parameter(torch.ones(self.params['dim']) * -10, requires_grad=True)
# conditioned_std=False: std comes from a free parameter instead of the network output.
if not self.params.get('conditioned_std', True):
self.logstd = nn.Parameter(torch.zeros(self.params['dim']), requires_grad=True)
elif self.distribution_type == 'gmm':
self.max_logstd = nn.Parameter(torch.ones(self.params['mixture'], self.params['dim']) * 0, requires_grad=True)
self.min_logstd = nn.Parameter(torch.ones(self.params['mixture'], self.params['dim']) * -10, requires_grad=True)
if not self.params.get('conditioned_std', True):
self.logstd = nn.Parameter(torch.zeros(self.params['mixture'], self.params['dim']), requires_grad=True)
elif self.distribution_type == 'discrete_logistic':
self.num = self.params['num']
elif self.distribution_type == 'mix':
# 'mix' builds one child wrapper per entry of dist_config; nesting 'mix' is rejected.
assert 'dist_config' in self.params.keys(), "You need to provide `dist_config` for Mix distribution"
self.dist_config = self.params['dist_config']
self.wrapper_list = []
self.input_sizes = []
self.output_sizes = []
for config in self.dist_config:
dist_type = config['type']
assert dist_type in self.SUPPORTED_TYPES, f"{dist_type} is not supported!"
assert not dist_type == 'mix', "recursive MixDistribution is not supported!"
self.wrapper_list.append(DistributionWrapper(dist_type, **config))
self.wrapper_list = nn.ModuleList(self.wrapper_list)
[docs] def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, payload : Optional[torch.Tensor] = None) -> ReviveDistribution:
Warp the given tensor to distribution
:param adapt_std : it will overwrite the std part of the distribution (optional)
:param payload : payload will be applied to the output distribution after built (optional)
# [ OTHER ] replace the clip with tanh
# [ OTHER ] better std controlling strategy is needed todo
if self.distribution_type == 'normal':
# std is squashed by tanh into (min_std, max_std) = (0.001, 0.5).
# NOTE(review): this branch ignores adapt_std (contrary to the docstring above) and does not
# use min_logstd/max_logstd; the soft_clamp version is commented out below.
if self.params.get('conditioned_std', True):
mu, logstd = torch.chunk(x, 2, dim=-1)
logstd_logit = logstd
max_std = 0.5
min_std = 0.001
std = (torch.tanh(logstd_logit) + 1) / 2 * (max_std - min_std) + min_std
mu, logstd = x, self.logstd
logstd_logit = self.logstd
max_std = 0.5
min_std = 0.001
std = (torch.tanh(logstd_logit) + 1) / 2 * (max_std - min_std) + min_std
# std = torch.exp(logstd)
# payload shifts the pre-tanh mean, so tanh(mu + atanh(payload)) stays in (-1, 1).
if payload is not None:
mu = mu + safe_atanh(payload)
mu = torch.tanh(mu)
# std = adapt_std if adapt_std is not None else torch.exp(soft_clamp(logstd, self.min_logstd.to(logstd.device), self.max_logstd.to(logstd.device)))
return DiagnalNormal(mu, std)
elif self.distribution_type == 'gmm':
# Layout of x: [mixture logits | mixture*dim means | mixture*dim log-stds (if conditioned)].
if self.params.get('conditioned_std', True):
logits, mus, logstds = torch.split(x, [self.params['mixture'],
self.params['mixture'] * self.params['dim'],
self.params['mixture'] * self.params['dim']], dim=-1)
mus = mus.view(*mus.shape[:-1], self.params['mixture'], self.params['dim'])
logstds = logstds.view(*logstds.shape[:-1], self.params['mixture'], self.params['dim'])
logits, mus = torch.split(x, [self.params['mixture'], self.params['mixture'] * self.params['dim']], dim=-1)
logstds = self.logstd
if payload is not None:
mus = mus + safe_atanh(payload.unsqueeze(dim=-2))
mus = torch.tanh(mus)
# Unlike 'normal', gmm honours adapt_std and soft-clamps log-std between the learned bounds.
stds = adapt_std if adapt_std is not None else torch.exp(soft_clamp(logstds, self.min_logstd, self.max_logstd))
return GaussianMixture(mus, stds, logits)
elif self.distribution_type == 'onehot':
return Onehot(x)
elif self.distribution_type == 'discrete_logistic':
mu, logstd = torch.chunk(x, 2, dim=-1)
logstd = torch.clamp(logstd, -7, 1)
return DiscreteLogistic(mu, torch.exp(logstd), num=self.num)
elif self.distribution_type == 'mix':
# Split x / adapt_std / payload per child wrapper, then combine the child distributions.
xs = torch.split(x, self.output_sizes, dim=-1)
if adapt_std is not None:
adapt_stds = torch.split(adapt_std, self.input_sizes, dim=-1)
adapt_stds = [None] * len(self.input_sizes)
if payload is not None:
# Any trailing payload columns beyond sum(input_sizes) are split off and discarded.
payloads = torch.split(payload, self.input_sizes + [payload.shape[-1] - sum(self.input_sizes)], dim=-1)[:-1]
payloads = [None] * len(self.input_sizes)
dists = [wrapper(x, _adapt_std, _payload) for x, _adapt_std, _payload, wrapper in zip(xs, adapt_stds, payloads, self.wrapper_list)]
return MixDistribution(dists)
# --------------------------------- Policies -------------------------------- #
[docs]class FeedForwardPolicy(torch.nn.Module):
# Feed-forward policy: a backbone (mlp / res / transformer / ft_transformer) whose output is
# wrapped into a 'mix' distribution, optionally through one small MLP head per output group.
# NOTE(review): the docstring's closing quotes, the `use_feature_embed`/`**kwargs` parameters,
# `super().__init__()`, the `else:` lines, and several appends in forward are missing from this
# copy of the file.
"""Policy for using mlp, resnet and transformer backbone.
in_features : The number of input features, or a dictionary describing the input distribution
out_features : The number of output features
hidden_features : The number of hidden features in each layer
hidden_layers : The number of hidden layers
dist_config : A list of configurations for the distributions to use in the model
norm : The type of normalization to apply to the input features
hidden_activation : The activation function to use in the hidden layers
backbone_type : The type of backbone to use in the model
use_multihead : Whether to use a multihead model
use_feature_embed : Whether to use feature embedding
def __init__(self,
in_features : Union[int, dict],
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list,
norm : str = None,
hidden_activation : str = 'leakyrelu',
backbone_type : Union[str, np.str_] = 'mlp',
use_multihead : bool = False,
self.multihead = []
self.dist_config = dist_config
self.kwargs = kwargs
# A dict of per-input sizes is collapsed to the total input width.
if isinstance(in_features, dict):
in_features = sum(in_features.values())
# Single-head mode: the backbone emits out_features directly.
if not use_multihead:
if backbone_type == 'mlp':
self.backbone = MLP(in_features, out_features, hidden_features, hidden_layers, norm, hidden_activation)
elif backbone_type == 'res':
self.backbone = ResNet(in_features, out_features, hidden_features, hidden_layers, norm)
elif backbone_type == 'transformer':
self.backbone = Transformer1D(in_features, out_features, hidden_features, transformer_layers=hidden_layers)
elif backbone_type == 'ft_transformer':
self.backbone = FT_Transformer(in_features, out_features, hidden_features, hidden_layers=hidden_layers)
raise NotImplementedError(f'backbone type {backbone_type} is not supported')
# Multihead mode: the backbone emits hidden_features, which the heads below consume.
if backbone_type == 'mlp':
self.backbone = MLP(in_features, hidden_features, hidden_features, hidden_layers, norm, hidden_activation)
elif backbone_type == 'res':
self.backbone = ResNet(in_features, hidden_features, hidden_features, hidden_layers, norm)
elif backbone_type == 'transformer':
self.backbone = Transformer1D(in_features, hidden_features, hidden_features, transformer_layers=hidden_layers)
# NOTE(review): unlike the other multihead backbones, this one outputs out_features although
# every head expects hidden_features inputs; likely a size mismatch unless they are equal.
elif backbone_type == 'ft_transformer':
self.backbone = FT_Transformer(in_features, out_features, hidden_features, hidden_layers=hidden_layers)
raise NotImplementedError(f'backbone type {backbone_type} is not supported')
# One head per onehot/discrete_logistic group; one head per dimension of a normal group.
for dist in self.dist_config:
dist_type = dist["type"]
output_dim = dist["output_dim"]
if dist_type == "onehot" or dist_type == "discrete_logistic":
self.multihead.append(MLP(hidden_features, output_dim, 64, 1, norm=None, hidden_activation='leakyrelu'))
elif dist_type == "normal":
normal_dim = dist["dim"]
for dim in range(normal_dim):
self.multihead.append(MLP(hidden_features, int(output_dim // normal_dim), 64, 1, norm=None, hidden_activation='leakyrelu'))
raise NotImplementedError(f'Dist type {dist_type} is not supported in multihead.')
self.multihead = nn.ModuleList(self.multihead)
self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
[docs] def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
# Dict input is concatenated in the order given by kwargs['input_names'].
if isinstance(state, dict):
assert 'input_names' in kwargs
state = torch.cat([state[key] for key in kwargs['input_names']], dim=-1)
if not self.multihead:
output = self.backbone(state)
backbone_output = self.backbone(state)
multihead_output = []
multihead_index = 0
for dist in self.dist_config:
dist_type = dist["type"]
output_dim = dist["output_dim"]
if dist_type == "onehot" or dist_type == "discrete_logistic":
multihead_index += 1
elif dist_type == "normal":
normal_mode_output = []
normal_std_output = []
# NOTE(review): this iterates every remaining head, not just this group's `dim` heads;
# the visible lines show no early exit or slice end. Confirm against upstream.
for head in self.multihead[multihead_index:]:
head_output = head(backbone_output)
# A 1-wide head gives only the mean; a wider head is split into (mean, std) halves.
if head_output.shape[-1] == 1:
mode = head_output
mode, std = torch.chunk(head_output, 2, axis=-1)
# All means first, then all stds, matching the layout DistributionWrapper expects.
normal_output = torch.cat(normal_mode_output, axis=-1)
if normal_std_output:
normal_std_output = torch.cat(normal_std_output, axis=-1)
normal_output = torch.cat([normal_output, normal_std_output], axis=-1)
raise NotImplementedError(f'Dist type {dist_type} is not supported in multihead.')
output = torch.cat(multihead_output, axis= -1)
dist = self.dist_wrapper(output, adapt_std)
# Optional mean shift, applied only if some caller has set `dist_mu_shift` on the module.
if hasattr(self, "dist_mu_shift"):
dist = dist.shift(self.dist_mu_shift)
return dist
[docs] @torch.no_grad()
def get_action(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], deterministic : bool = True):
# Returns the distribution mode when deterministic, otherwise a sample (no gradients).
dist = self(state)
return dist.mode if deterministic else dist.sample()
[docs]class RecurrentPolicy(torch.nn.Module):
# RNN policy: GRU/LSTM followed by a single Linear head, wrapped into a 'mix' distribution.
# NOTE(review): the docstring's closing quotes, `super().__init__()` and the `else:` in forward
# are missing from this copy of the file.
"""Initializes a recurrent policy network instance.
in_features (int): The number of input features.
out_features (int): The number of output features.
hidden_features (int): The number of hidden features in the RNN.
hidden_layers (int): The number of layers in the RNN.
dist_config (list): The configuration for the distributions used in the model.
backbone_type (Union[str, np.str_], optional): The type of RNN to use ('gru' or 'lstm'). Default is 'gru'.
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list,
backbone_type : Union[str, np.str_] ='gru'):
# Any backbone_type other than 'gru' selects LSTM.
RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM
self.rnn = RNN(in_features, hidden_features, hidden_layers)
self.mlp = MLP(hidden_features, out_features, 0, 0)
self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
# Clears the carried hidden state.
# NOTE(review): no visible __init__ line sets self.h, so reset() presumably must run before
# the first single-step (2D) forward call.
[docs] def reset(self):
self.h = None
[docs] def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
# 2D input = one time step: run with the stored hidden state and keep the new one.
if len(x.shape) == 2:
x = x.unsqueeze(0)
rnn_output, self.h = self.rnn(x, self.h)
rnn_output = rnn_output.squeeze(0)
# Otherwise: a full (time-major) sequence run from a zero initial state.
rnn_output, self.h = self.rnn(x)
logits = self.mlp(rnn_output)
return self.dist_wrapper(logits, adapt_std)
[docs]class RecurrentRESPolicy(torch.nn.Module):
# Policy combining a feed-forward ResBlock side path with an RNN path; their outputs are
# concatenated and fed to an MLP whose output is wrapped into a 'mix' distribution.
# NOTE(review): `**kwargs):`, `super().__init__()`, the remaining Sequential elements and
# closing parens, the `else:` lines, and the `output.append(...)` calls in the 3D loops are
# missing from this copy of the file.
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list,
backbone_type : Union[str, np.str_] ='gru',
rnn_hidden_features : int = 64,
window_size : int = 0,
self.kwargs = kwargs
RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM
self.in_feature_embed = nn.Sequential(
nn.Linear(in_features, hidden_features),
self.side_net = nn.Sequential(
ResBlock(hidden_features, hidden_features),
ResBlock(hidden_features, hidden_features),
self.rnn = RNN(hidden_features, rnn_hidden_features, hidden_layers)
self.backbone = MLP(hidden_features + rnn_hidden_features, out_features, hidden_features, 1)
self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
self.hidden_layers = hidden_layers
self.hidden_features = hidden_features
self.rnn_hidden_features = rnn_hidden_features
self.window_size = window_size
# window_size > 0: keep the last `window_size` RNN states so each step can be re-run from
# several past starting states at once.
self.hidden_que = deque(maxlen=self.window_size) # [(h_0, c_0), (h_1, c_1), ...]
self.h = None
[docs] def reset(self):
self.hidden_que = deque(maxlen=self.window_size)
self.h = None
# Build a batched input in which every queued start state (plus a fresh zero state) sees the
# same step input. States are concatenated along the batch dim (dim=1).
# NOTE(review): states are unpacked as (h, c) tuples, so window mode only works with LSTM
# (backbone_type != 'gru'); GRU takes and returns a single hidden tensor.
[docs] def preprocess_window(self, a_embed):
shape = a_embed.shape
if len(self.hidden_que) < self.window_size: # do Not pop
if self.hidden_que:
h = self.hidden_que[0]
h = (torch.zeros((self.hidden_layers, shape[1], self.rnn_hidden_features)).to(a_embed), torch.zeros((self.hidden_layers, shape[1], self.rnn_hidden_features)).to(a_embed))
h = self.hidden_que.popleft()
hidden_cat = torch.concat([h[0]] + [hidden[0] for hidden in self.hidden_que] + [torch.zeros_like(h[0]).to(h[0])], dim=1) # (layers, bs * 4, dim)
cell_cat = torch.concat([h[1]] + [hidden[1] for hidden in self.hidden_que] + [torch.zeros_like(h[1]).to(h[1])], dim=1) # (layers, bs * 4, dim)
a_embed = torch.repeat_interleave(a_embed, repeats=len(self.hidden_que)+2, dim=0).view(1, (len(self.hidden_que)+2)*shape[1], -1) # (1, bs * 4, dim)
return a_embed, hidden_cat, cell_cat
# Undo preprocess_window: split along the batch dim, keep every updated state except the
# oldest one as the new queue, and return the output computed from the oldest start state.
[docs] def postprocess_window(self, rnn_output, hidden_cat, cell_cat):
hidden_cat = torch.chunk(hidden_cat, chunks=len(self.hidden_que)+2, dim=1) # tuples of (layers, bs, dim)
cell_cat = torch.chunk(cell_cat, chunks=len(self.hidden_que)+2, dim=1) # tuples of (layers, bs, dim)
rnn_output = torch.chunk(rnn_output, chunks=len(self.hidden_que)+2, dim=1) # tuples of (1, bs, dim)
self.hidden_que = deque(zip(hidden_cat[1:], cell_cat[1:]), maxlen=self.window_size) # important to discrad the first element !!!
rnn_output = rnn_output[0] # (1, bs, dim)
return rnn_output
# One RNN step; gradients flow through the RNN only when joint_train is True.
[docs] def rnn_forward(self, a_embed, joint_train : bool):
if self.window_size > 0:
a_embed, hidden_cat, cell_cat = self.preprocess_window(a_embed)
if joint_train:
rnn_output, (hidden_cat, cell_cat) = self.rnn(a_embed, (hidden_cat, cell_cat))
with torch.no_grad():
rnn_output, (hidden_cat, cell_cat) = self.rnn(a_embed, (hidden_cat, cell_cat))
rnn_output = self.postprocess_window(rnn_output, hidden_cat, cell_cat) #(1, bs, dim)
if joint_train:
rnn_output, self.h = self.rnn(a_embed, self.h) #(1, bs, dim)
with torch.no_grad():
rnn_output, self.h = self.rnn(a_embed, self.h)
return rnn_output
# field 'bc' trains the RNN jointly; field 'mail' runs the RNN without gradients.
# 2D input = one step; 3D input is stepped one time index at a time along dim 0.
[docs] def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, field: str = 'mail', **kwargs) -> ReviveDistribution:
if field == 'bc':
shape = x.shape
joint_train = True
if len(shape) == 2: # (bs, dim)
x_embed = self.in_feature_embed(x)
side_output = self.side_net(x_embed)
a_embed = x_embed.unsqueeze(0) # [1, bs, dim]
rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim]
rnn_output = rnn_output.squeeze(0) # (bs, dim)
logits = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (bs, dim)
assert len(shape) == 3, f"expect len(x.shape) == 3. However got x.shape {shape}"
output = []
for i in range(shape[0]):
a = x[i]
a = a.unsqueeze(0) #(1, bs, dim)
a_embed = self.in_feature_embed(a) # (1, bs, dim)
side_output = self.side_net(a_embed)
rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim]
backbone_output = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (1, bs, dim)
logits = torch.concat(output, dim=0)
elif field == 'mail':
shape = x.shape
joint_train = False
if len(shape) == 2:
x_embed = self.in_feature_embed(x)
side_output = self.side_net(x_embed)
a_embed = x_embed.unsqueeze(0)
rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim]
rnn_output = rnn_output.squeeze(0) # (bs, dim)
logits = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (bs, dim)
assert len(shape) == 3, f"expect len(x.shape) == 3. However got x.shape {shape}"
output = []
for i in range(shape[0]):
a = x[i]
a = a.unsqueeze(0) #(1, bs, dim)
a_embed = self.in_feature_embed(a) # (1, bs, dim)
side_output = self.side_net(a_embed)
rnn_output = self.rnn_forward(a_embed, joint_train) # [1, bs, dim]
backbone_output = self.backbone(torch.concat([side_output, rnn_output], dim=-1)) # (1, bs, dim)
logits = torch.concat(output, dim=0)
raise NotImplementedError(f"unknow field: {field} in RNN training !")
return self.dist_wrapper(logits, adapt_std)
[docs]class ContextualPolicy(torch.nn.Module):
# Policy that encodes the input history with a one-layer RNN ("context") and feeds the raw input
# concatenated with the context into an MLP, wrapped into a 'mix' distribution.
# NOTE(review): `**kwargs):`, `super().__init__()` and the `else:` in forward are missing
# from this copy of the file.
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list,
backbone_type : Union[str, np.str_] ='contextual_gru',
self.kwargs = kwargs
# Any backbone_type other than 'contextual_gru' selects LSTM.
RNN = torch.nn.GRU if backbone_type == 'contextual_gru' else torch.nn.LSTM
self.preprocess_mlp = MLP(in_features, hidden_features, 0, 0, output_activation='leakyrelu')
self.rnn = RNN(hidden_features, hidden_features, 1)
self.mlp = MLP(hidden_features + in_features, out_features, hidden_features, hidden_layers)
self.dist_wrapper = DistributionWrapper('mix', dist_config=dist_config)
# NOTE(review): self.h is only set here in the visible lines; call reset() before stepping.
[docs] def reset(self):
self.h = None
[docs] def forward(self, x : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
# Keep the un-embedded input; it is concatenated with the RNN context below.
in_feature = x
# 2D input = one time step using the carried hidden state; otherwise a full sequence.
if len(x.shape) == 2:
x = x.unsqueeze(0)
x = self.preprocess_mlp(x)
rnn_output, self.h = self.rnn(x, self.h)
rnn_output = rnn_output.squeeze(0)
x = self.preprocess_mlp(x)
rnn_output, self.h = self.rnn(x)
logits = self.mlp(torch.cat((in_feature, rnn_output), dim=-1))
return self.dist_wrapper(logits, adapt_std)
# ------------------------------- Transitions ------------------------------- #
[docs]class FeedForwardTransition(FeedForwardPolicy):
# FeedForwardPolicy used as a transition model. In 'local' mode the predicted distribution is
# shifted by the first obs_dim features of the input (i.e. it models a delta).
# NOTE(review): the docstring's closing quotes and the `use_feature_embed`/`**kwargs` parameters
# are missing from this copy of the file.
r"""Initializes a feedforward transition instance.
in_features (Union[int, dict]): The number of input features or a dictionary of input feature sizes.
out_features (int): The number of output features.
hidden_features (int): The number of hidden features.
hidden_layers (int): The number of hidden layers.
dist_config (list): The configuration of the distribution to use.
norm (Optional[str]): The normalization method to use. None if no normalization is used.
hidden_activation (str): The activation function to use for the hidden layers.
backbone_type (Union[str, np.str_]): The type of backbone to use.
mode (str): The mode to use for the transition. Either 'global' or 'local'.
obs_dim (Optional[int]): The dimension of the observation for 'local' mode.
use_feature_embed (bool): Whether to use feature embedding.
def __init__(self,
in_features : Union[int, dict],
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list,
norm : Optional[str] = None,
hidden_activation : str = 'leakyrelu',
backbone_type : Union[str, np.str_] = 'mlp',
mode : str = 'global',
obs_dim : Optional[int] = None,
self.mode = mode
self.obs_dim = obs_dim
# Discrete outputs cannot be shifted, so 'local' silently falls back to 'global' (with a warning).
if self.mode == 'local':
dist_types = [config['type'] for config in dist_config]
if 'onehot' in dist_types or 'discrete_logistic' in dist_types:
warnings.warn('Detect distribution type that are not compatible with the local mode, fallback to global mode!')
self.mode = 'global'
if self.mode == 'local': assert self.obs_dim is not None, \
"For local mode, the dim of observation should be given!"
super(FeedForwardTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers,
dist_config, norm, hidden_activation, backbone_type, **kwargs)
[docs] def forward(self, state : Union[torch.Tensor, Dict[str, torch.Tensor]], adapt_std : Optional[torch.Tensor] = None, **kwargs) -> ReviveDistribution:
dist = super(FeedForwardTransition, self).forward(state, adapt_std, **kwargs)
# NOTE(review): the slice below assumes `state` is a tensor; the dict form accepted by the
# signature would not support `[..., :obs_dim]`. Confirm 'local' mode is only used with tensors.
if self.mode == 'local' and self.obs_dim is not None:
dist = dist.shift(state[..., :self.obs_dim])
return dist
[docs]class RecurrentTransition(RecurrentPolicy):
# RecurrentPolicy used as a transition model; 'local' mode shifts the output by the first
# obs_dim input features.
# NOTE(review): the docstring delimiters and the `**kwargs):` line are missing from this copy.
# NOTE(review): the default backbone_type 'mlp' is not 'gru', so RecurrentPolicy builds an LSTM.
# NOTE(review): unlike FeedForwardTransition, 'local' mode asserts on 'onehot' instead of
# falling back, and does not check 'discrete_logistic'.
Initializes a recurrent transition instance.
in_features (int): The number of input features.
out_features (int): The number of output features.
hidden_features (int): The number of hidden features.
hidden_layers (int): The number of hidden layers.
dist_config (list): The distribution configuration.
backbone_type (Union[str, np.str_]): The type of backbone to use.
mode (str): The mode of the transition. Either 'global' or 'local'.
obs_dim (Optional[int]): The dimension of the observation.
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list,
backbone_type : Union[str, np.str_] = 'mlp',
mode : str = 'global',
obs_dim : Optional[int] = None,
self.mode = mode
self.obs_dim = obs_dim
if self.mode == 'local': assert not 'onehot' in [config['type'] for config in dist_config], \
"The local mode of transition is not compatible with onehot data! Please fallback to global mode!"
if self.mode == 'local': assert self.obs_dim is not None, \
"For local mode, the dim of observation should be given!"
super(RecurrentTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers,
dist_config, backbone_type, **kwargs)
[docs] def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None) -> ReviveDistribution:
dist = super(RecurrentTransition, self).forward(state, adapt_std)
if self.mode == 'local' and self.obs_dim is not None:
dist = dist.shift(state[..., :self.obs_dim])
return dist
[docs]class RecurrentRESTransition(RecurrentRESPolicy):
# RecurrentRESPolicy used as a transition model; 'local' mode shifts the output by the first
# obs_dim input features.
# NOTE(review): the `**kwargs):` line and the closing of the super().__init__ call are missing
# from this copy of the file.
# NOTE(review): the default backbone_type 'mlp' is not 'gru', so the parent builds an LSTM.
def __init__(self,
in_features : int,
out_features : int,
hidden_features : int,
hidden_layers : int,
dist_config : list,
backbone_type : Union[str, np.str_] = 'mlp',
mode : str = 'global',
obs_dim : Optional[int] = None,
rnn_hidden_features : int = 64,
window_size : int = 0,
self.mode = mode
self.obs_dim = obs_dim
if self.mode == 'local': assert not 'onehot' in [config['type'] for config in dist_config], \
"The local mode of transition is not compatible with onehot data! Please fallback to global mode!"
if self.mode == 'local': assert self.obs_dim is not None, \
"For local mode, the dim of observation should be given!"
super(RecurrentRESTransition, self).__init__(in_features, out_features, hidden_features, hidden_layers,
dist_config, backbone_type,
rnn_hidden_features=rnn_hidden_features, window_size=window_size,
[docs] def forward(self, state : torch.Tensor, adapt_std : Optional[torch.Tensor] = None, field: str = 'mail', **kwargs) -> ReviveDistribution:
dist = super(RecurrentRESTransition, self).forward(state, adapt_std, field, **kwargs)
if self.mode == 'local' and self.obs_dim is not None:
dist = dist.shift(state[..., :self.obs_dim])
return dist
[docs]class FeedForwardMatcher(torch.nn.Module):
# Discriminator that concatenates its inputs on the last dim and scores them with a
# single-output backbone.
# NOTE(review): the docstring delimiters, `super().__init__()`, the closing paren of the
# transformer Sequential and the `else:` before the raise are missing from this copy.
# NOTE(review): the 'res' branch relies on ResNet honouring output_activation='sigmoid', and
# the visible 'transformer'/'ft_transformer' branches show no sigmoid; confirm output ranges.
Initializes a feedforward matcher instance.
in_features (int): The number of input features.
hidden_features (int): The number of hidden features.
hidden_layers (int): The number of hidden layers.
hidden_activation (str): The activation function to use for the hidden layers.
norm (str): The normalization method to use. None if no normalization is used.
backbone_type (Union[str, np.str_]): The type of backbone to use.
def __init__(self,
in_features : int,
hidden_features : int,
hidden_layers : int,
hidden_activation : str = 'leakyrelu',
norm : str = None,
backbone_type : Union[str, np.str_] = 'mlp'):
if backbone_type == 'mlp':
self.backbone = MLP(in_features, 1, hidden_features, hidden_layers, norm, hidden_activation, output_activation='sigmoid')
elif backbone_type == 'res':
self.backbone = ResNet(in_features, 1, hidden_features, hidden_layers, norm, output_activation='sigmoid')
elif backbone_type == 'transformer':
self.backbone = torch.nn.Sequential(
Transformer1D(in_features, in_features, hidden_features, transformer_layers=hidden_layers),
torch.nn.Linear(in_features, 1),
elif backbone_type == 'ft_transformer':
self.backbone = FT_Transformer(in_features, 1, hidden_features, hidden_layers=hidden_layers)
raise NotImplementedError(f'backbone type {backbone_type} is not supported')
[docs] def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor:
x = torch.cat(inputs, dim=-1)
return self.backbone(x)
class RecurrentMatcher(torch.nn.Module):
    """Sequence discriminator: an RNN followed by a single-output MLP with sigmoid output.

    Args:
        in_features (int): Total feature size after concatenating all inputs on the last dim.
        hidden_features (int): Hidden size of the RNN.
        hidden_layers (int): Number of stacked RNN layers.
        backbone_type: 'gru' selects ``torch.nn.GRU``; any other value selects ``torch.nn.LSTM``.
        bidirect (bool): Whether the RNN is bidirectional (doubles the head's input size).

    The RNN is built with the default ``batch_first=False``, so inputs are time-major.
    """

    def __init__(self,
                 in_features : int,
                 hidden_features : int,
                 hidden_layers : int,
                 backbone_type : Union[str, np.str_] = 'gru',
                 bidirect : bool = False):
        # nn.Module.__init__ must run before any submodule is assigned.
        super().__init__()
        RNN = torch.nn.GRU if backbone_type == 'gru' else torch.nn.LSTM
        self.rnn = RNN(in_features, hidden_features, hidden_layers, bidirectional=bidirect)
        # A bidirectional RNN concatenates forward and backward features.
        self.output_layer = MLP(hidden_features * (2 if bidirect else 1), 1, 0, 0, output_activation='sigmoid')

    def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor:
        """Concatenate ``inputs`` on the last dim and score every time step."""
        x = torch.cat(inputs, dim=-1)
        # Index 0 is the per-step output sequence; the final hidden state is discarded.
        rnn_output = self.rnn(x)[0]
        return self.output_layer(rnn_output)
class HierarchicalMatcher(torch.nn.Module):
    """Discriminator that fuses its input groups one at a time.

    Stage ``i`` (for ``i >= 1``) concatenates the running feature with ``inputs[i]`` and
    passes the result through an MLP. It then adds a linear score; the summed scores are
    squashed by a sigmoid.

    Args:
        in_features (list): Feature size of each input group, in the order the groups are
            passed to ``forward``. At least two groups are needed. With a single group no
            stage is built, and ``forward`` would call ``torch.sigmoid`` on the int ``0``.
        hidden_features (int): Width of each stage's MLP (hidden and output size).
        hidden_layers (int): Number of hidden layers in each stage's MLP.
        hidden_activation (str): Activation name used as both the hidden and the output
            activation of each stage's MLP.
        norm (str, optional): Normalization passed to each stage's MLP.
    """

    def __init__(self,
                 in_features : list,
                 hidden_features : int,
                 hidden_layers : int,
                 hidden_activation : str,
                 norm : str = None):
        # nn.Module.__init__ must run before any submodule is registered.
        super().__init__()
        self.in_features = in_features

        process_layers = []
        output_layers = []
        feature = self.in_features[0]
        for i in range(1, len(self.in_features)):
            # Stage input = running feature (inputs[0] or the previous stage's output) + next group.
            feature += self.in_features[i]
            process_layers.append(MLP(feature, hidden_features, hidden_features, hidden_layers, norm, hidden_activation, hidden_activation))
            output_layers.append(torch.nn.Linear(hidden_features, 1))
            feature = hidden_features

        self.process_layers = torch.nn.ModuleList(process_layers)
        self.output_layers = torch.nn.ModuleList(output_layers)

    def forward(self, *inputs : List[torch.Tensor]) -> torch.Tensor:
        """Score the input groups; expects one tensor per entry of ``in_features``."""
        assert len(inputs) == len(self.in_features)
        last_feature = inputs[0]
        result = 0
        for i, process_layer, output_layer in zip(range(1, len(self.in_features)), self.process_layers, self.output_layers):
            last_feature = torch.cat([last_feature, inputs[i]], dim=-1)
            last_feature = process_layer(last_feature)
            # Each stage contributes an additive logit.
            result += output_layer(last_feature)
        return torch.sigmoid(result)
class VectorizedCritic(VectorizedMLP):
    """Ensemble critic: a VectorizedMLP whose last dimension is dropped when it has size 1."""

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        values = super().forward(x)
        return values.squeeze(-1)