''''''
"""
POLIXIR REVIVE, copyright (C) 2021-2023 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
from revive.algo.venv.base import VenvOperator
class AlgorithmOperator(VenvOperator):
    """Skeleton for a custom algorithm operator built on ``VenvOperator``.

    Every hook below is a placeholder that raises ``NotImplementedError``.
    Work through the numbered TODO comments and replace each placeholder
    with a real implementation.
    """

    # TODO 1: Define the name of this algorithm.
    NAME = ''

    # TODO 2: Define the hyper-parameters of this algorithm.
    # 2.1 Either write the `PARAMETER_DESCRIPTION`.
    PARAMETER_DESCRIPTION = [
    ]

    # 2.2 Or overwrite `get_parameters` and `get_tune_parameters`.
    # NOTE(review): as long as the two placeholders below stay in the class
    # they shadow whatever the base class provides, so option 2.1 presumably
    # requires deleting them -- confirm against `VenvOperator`.
    @classmethod
    def get_parameters(cls, command=None, **kargs):
        """
        Return the hyper-parameters of this algorithm.

        :param command: optional command input; exact format is defined by
            the base class -- TODO confirm
        :param kargs: extra keyword arguments
        :raises NotImplementedError: always, until overwritten
        """
        raise NotImplementedError

    @classmethod
    def get_tune_parameters(cls, config, **kargs):
        """
        Return the tunable hyper-parameters of this algorithm.

        :param config: configuration parameters
        :param kargs: extra keyword arguments
        :raises NotImplementedError: always, until overwritten
        """
        raise NotImplementedError

    # TODO 3: Define the creator of model, optimizer, data (optional).
    def model_creator(self, config):
        """
        Create all the models.

        :param config: configuration parameters
        :return: list of models
        :raises NotImplementedError: always, until overwritten
        """
        raise NotImplementedError

    def optimizer_creator(self, models, config):
        """
        Create Optimizers.

        :param models: list of all the models
        :param config: configuration parameters
        :return: list of optimizers
        :raises NotImplementedError: always, until overwritten
        """
        raise NotImplementedError

    def data_creator(self, config):
        """
        Create DataLoaders.

        :param config: configuration parameters
        :return: train_loader and val_loader
        :raises NotImplementedError: always, until overwritten
        """
        # Raise (rather than return) the exception so that an unimplemented
        # hook fails immediately, exactly like the other placeholders,
        # instead of handing the exception class back as if it were data.
        # NOTE(review): TODO 3 marks the data creator as optional; if the
        # base-class behaviour is wanted, presumably this override should
        # simply be removed -- confirm against `VenvOperator`.
        raise NotImplementedError

    # TODO 4: Define the training loop, either by overwrite `train_epoch`
    # or `train_batch`.
    def train_epoch(self, iterator, info):
        """
        Run one training epoch.

        :param iterator: source of training batches -- TODO confirm type
        :param info: extra information passed in by the caller
        :raises NotImplementedError: always, until overwritten
        """
        # you can find reference in base
        raise NotImplementedError

    def train_batch(self, expert_data, batch_info, scope):
        """
        Run one training step on a single batch.

        :param expert_data: the batch to train on -- TODO confirm structure
        :param batch_info: extra information about the current batch
        :param scope: presumably selects the train/validation context --
            TODO confirm against the base class
        :raises NotImplementedError: always, until overwritten
        """
        raise NotImplementedError