# Source code for revive.algo.venv.template

''''''
"""
    POLIXIR REVIVE, copyright (C) 2021-2023 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""

from revive.algo.venv.base import VenvOperator

class AlgorithmOperator(VenvOperator):
    """Template for implementing a new venv-training algorithm on top of ``VenvOperator``.

    Copy this class and fill in each numbered ``TODO`` section below.
    Every stub method raises ``NotImplementedError`` until it is implemented.
    """

    # TODO 1: Define the name of this algorithm.
    NAME = ''

    # TODO 2: Define the hyper-parameters of this algorithm.
    # 2.1 Either write the `PARAMETER_DESCRIPTION`.
    PARAMETER_DESCRIPTION = [

    ]

    # 2.2 Or overwrite `get_parameters` and `get_tune_parameters`.
    # NOTE(review): the two stubs below always raise. If you choose route 2.1, you
    # presumably need to delete them so the VenvOperator versions are used
    # instead. Confirm against revive.algo.venv.base.
    @classmethod
    def get_parameters(cls, command=None, **kargs):
        """Return the hyper-parameters of this algorithm.

        :param command: forwarded by the caller (optional); see
            ``VenvOperator.get_parameters`` for its expected form
        :param kargs: extra keyword arguments forwarded by the caller
        :raises NotImplementedError: until implemented
        """
        raise NotImplementedError

    @classmethod
    def get_tune_parameters(cls, config, **kargs):
        """Return the hyper-parameter search space used for tuning.

        :param config: configuration parameters
        :param kargs: extra keyword arguments forwarded by the caller
        :raises NotImplementedError: until implemented
        """
        raise NotImplementedError

    # TODO 3: Define the creator of model, optimizer, data (optional).
    def model_creator(self, config):
        """
        Create all the models.

        :param config: configuration parameters
        :return: list of models
        :raises NotImplementedError: until implemented
        """
        raise NotImplementedError

    def optimizer_creator(self, models, config):
        """
        Create Optimizers.

        :param models: list of all the models
        :param config: configuration parameters
        :return: list of optimizers
        :raises NotImplementedError: until implemented
        """
        raise NotImplementedError

    def data_creator(self, config):
        """
        Create DataLoaders.

        :param config: configuration parameters
        :return: train_loader and val_loader
        :raises NotImplementedError: until implemented
        """
        # Previously this *returned* the NotImplementedError class, so a caller
        # silently received a class object instead of the loaders. Raise it like
        # every other stub in this template.
        # NOTE(review): data creation is marked optional above. If the base class
        # already provides a suitable data_creator, delete this override instead.
        raise NotImplementedError

    # TODO 4: Define the training loop, either by overwriting `train_epoch` or
    # `train_batch`.
    # NOTE(review): both stubs always raise. Presumably you should delete the
    # one you do not implement, so the VenvOperator version is used. Confirm
    # against revive.algo.venv.base.
    def train_epoch(self, iterator, info):
        """Run one training epoch.

        :param iterator: passed in by the caller; see ``VenvOperator.train_epoch``
        :param info: passed in by the caller; see ``VenvOperator.train_epoch``
        :raises NotImplementedError: until implemented
        """
        # you can find reference in base
        raise NotImplementedError

    def train_batch(self, expert_data, batch_info, scope):
        """Run one training step on a single batch.

        :param expert_data: batch data supplied by the caller
        :param batch_info: per-batch information supplied by the caller
        :param scope: supplied by the caller; see ``VenvOperator.train_batch``
        :raises NotImplementedError: until implemented
        """
        raise NotImplementedError