# Source code for revive.computation.operators

''''''
"""
    POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""
""" The function of this document is mainly for use on the SAAS platform, 
and maintenance will be temporarily suspended in the future."""
import torch


def cat(*args):
    """Concatenate the positional tensors along their last dimension.

    Args:
        *args: Tensors forwarded, in order, to ``torch.cat``.

    Returns:
        The concatenation produced by ``torch.cat`` with ``dim=-1``.
    """
    pieces = args
    return torch.cat(pieces, dim=-1)
def sum(*args):
    """Sum several operands, or reduce a single tensor over its last dimension.

    With one argument, return ``torch.sum`` over the last dimension with
    ``keepdim=True``. With several arguments, return their element-wise sum,
    combined with the ``+`` operator so the usual broadcasting and type
    promotion of the operands apply.

    Raises:
        NotImplementedError: if called with no arguments.
    """
    if not args:
        raise NotImplementedError
    if len(args) == 1:
        return torch.sum(args[0], axis=-1, keepdim=True)
    # Start from 0 (as before) so promotion of the first operand is unchanged.
    total = 0
    for arg in args:
        # Out-of-place add: an in-place ``+=`` on the running total fails when a
        # later operand broadcasts to a larger shape or promotes to a wider dtype.
        total = total + arg
    return total
def sub(arg1, arg2):
    """Subtract ``arg2`` from ``arg1`` using the ``-`` operator.

    Returns:
        The value of ``arg1 - arg2``.
    """
    difference = arg1 - arg2
    return difference
def mul(*args):
    """Multiply several operands, or take the product over a single tensor's last dimension.

    With one argument ``x``, multiply the slices ``x[..., i:i+1]`` together,
    so the result keeps a last dimension of size 1. If that dimension is
    empty, the integer ``1`` is returned. With several arguments, return their
    element-wise product, combined with the ``*`` operator so the usual
    broadcasting and type promotion of the operands apply.

    Raises:
        NotImplementedError: if called with no arguments.
    """
    if not args:
        raise NotImplementedError
    product = 1
    if len(args) == 1:
        x = args[0]
        for index in range(x.shape[-1]):
            product = product * x[..., index:index + 1]
        return product
    for arg in args:
        # Out-of-place multiply: an in-place ``*=`` on the running product fails
        # when a later operand broadcasts to a larger shape or a wider dtype.
        product = product * arg
    return product
def div(arg1, arg2, eps=1e-8):
    """Divide ``arg1`` by ``arg2`` with a small stabilizing offset on the denominator.

    The offset ``eps`` pushes the denominator *away* from zero: it is
    subtracted where ``arg2`` is negative and added elsewhere, so zero is
    treated as positive. Adding ``eps`` unconditionally would turn a
    denominator of ``-eps`` into exactly zero, and would flip the sign of
    small negative denominators.

    Args:
        arg1: Numerator.
        arg2: Denominator; a tensor or a Python number. Other objects fall
            back to the unconditional ``arg2 + eps``.
        eps: Magnitude of the stabilizing offset (default ``1e-8``).

    Returns:
        ``arg1 / (arg2 ± eps)``.
    """
    if torch.is_tensor(arg2):
        # Build both candidates with ``+`` so every dtype ``arg2 + eps`` accepted
        # (including bool) still works, then pick per element by sign.
        denominator = torch.where(arg2 < 0, arg2 + (-eps), arg2 + eps)
    elif isinstance(arg2, (int, float)) and arg2 < 0:
        denominator = arg2 - eps
    else:
        denominator = arg2 + eps
    return arg1 / denominator
def mean(*args):
    """Average several tensors, or reduce a single tensor over its last dimension.

    With one argument, return ``torch.mean`` over the last dimension with
    ``keepdim=True``. With several arguments, return their element-wise mean.
    The inputs are broadcast to a common shape, stacked along a new trailing
    dimension and averaged over it. This means broadcastable shapes are
    accepted, consistent with ``sum`` and ``mul``.

    Raises:
        NotImplementedError: if called with no arguments.
    """
    if not args:
        raise NotImplementedError
    if len(args) == 1:
        return torch.mean(args[0], dim=-1, keepdim=True)
    # For equal shapes this matches cat([a.unsqueeze(-1), ...], -1).
    stacked = torch.stack(torch.broadcast_tensors(*args), dim=-1)
    return torch.mean(stacked, dim=-1)
def min(*args):
    """Take minima over several tensors, or over a single tensor's last dimension.

    With one argument, return the minimum values over the last dimension,
    keeping that dimension with size 1. With several arguments, return their
    element-wise minimum. The inputs are broadcast to a common shape, stacked
    along a new trailing dimension and reduced over it. This means
    broadcastable shapes are accepted, consistent with ``sum`` and ``mul``.

    Raises:
        NotImplementedError: if called with no arguments.
    """
    if not args:
        raise NotImplementedError
    if len(args) == 1:
        return torch.min(args[0], dim=-1, keepdim=True).values
    # For equal shapes this matches cat([a.unsqueeze(-1), ...], -1).
    stacked = torch.stack(torch.broadcast_tensors(*args), dim=-1)
    return torch.min(stacked, dim=-1).values
def max(*args):
    """Take maxima over several tensors, or over a single tensor's last dimension.

    With one argument, return the maximum values over the last dimension,
    keeping that dimension with size 1. With several arguments, return their
    element-wise maximum. The inputs are broadcast to a common shape, stacked
    along a new trailing dimension and reduced over it. This means
    broadcastable shapes are accepted, consistent with ``sum`` and ``mul``.

    Raises:
        NotImplementedError: if called with no arguments.
    """
    if not args:
        raise NotImplementedError
    if len(args) == 1:
        return torch.max(args[0], dim=-1, keepdim=True).values
    # For equal shapes this matches cat([a.unsqueeze(-1), ...], -1).
    stacked = torch.stack(torch.broadcast_tensors(*args), dim=-1)
    return torch.max(stacked, dim=-1).values
def abs(arg):
    """Return the element-wise absolute value of ``arg`` via ``torch.abs``."""
    return torch.abs(input=arg)
def clip(arg, min_v=None, max_v=None):
    """Clamp ``arg`` into the range ``[min_v, max_v]``.

    Delegates to ``torch.clamp``, of which ``torch.clip`` is an alias. A
    bound passed as ``None`` is not applied.

    Args:
        arg: Input tensor.
        min_v: Lower bound, or ``None``.
        max_v: Upper bound, or ``None``.
    """
    lower, upper = min_v, max_v
    return torch.clamp(arg, min=lower, max=upper)
def exp(arg):
    """Return the element-wise exponential of ``arg`` via ``torch.exp``."""
    return torch.exp(input=arg)
def log(arg):
    """Return the element-wise natural logarithm of ``arg`` via ``torch.log``.

    No epsilon is applied, so zeros map to ``-inf``.
    """
    return torch.log(input=arg)
""" def log(arg): assert len(args) == 1 return torch.log(arg) def floor(arg, base): pass def fmod(arg): pass """ __all__ = [ "sum", "cat", "sub", "mul", "div", "min", "mean", "max", "abs", "clip", "exp", "log", ]