"""
POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
import os
import sys
import copy
import json
import logging
import warnings
import itertools
from typing import Dict, List, Union
from collections import defaultdict
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from zoopt.parameter import ToolFunction
from ray.tune import Stopper
from ray.tune import CLIReporter as _CLIReporter
from ray.tune.error import TuneError
from ray.tune.logger import LoggerCallback, CSVLoggerCallback, JsonLoggerCallback
from ray.tune.utils import merge_dicts, flatten_dict
from ray.tune.experiment import _convert_to_experiment_list
from ray.tune.experiment.trial import Trial
from ray.tune.experiment.config_parser import _create_trial_from_spec
from ray.tune.search import BasicVariantGenerator, SearchGenerator, Searcher
from ray.tune.search.basic_variant import _count_spec_samples, _count_variants, _TrialIterator, SERIALIZATION_THRESHOLD
from ray.tune.search.variant_generator import format_vars, _resolve_nested_dict, _flatten_resolved_vars
from ray.tune.search.zoopt import ZOOptSearch as _ZOOptSearch
from ray.tune.search.zoopt.zoopt_search import DEFAULT_METRIC, Solution, zoopt
logger = logging.getLogger(__name__)
VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64]
class SysStopper(Stopper):
"""Customizing the training mechanism of ray
Reference : https://docs.ray.io/en/latest/tune/api/stoppers.html
"""
def __init__(self, workspace, max_iter: int = 0, stop_callback=None):
self._workspace = workspace
self._max_iter = max_iter
self._iter = defaultdict(int)
self.stop_callback = stop_callback
# Customizing the stopping mechanism for a single trial
def __call__(self, trial_id, result):
if self._max_iter > 0:
self._iter[trial_id] += 1
if self._iter[trial_id] >= self._max_iter:
return True
if result["stop_flag"]:
if self.stop_callback:
self.stop_callback()
return True
return False
# Customize the stopping mechanism for the entire training process
def stop_all(self):
env_file = os.path.join(self._workspace, '.env.json')
if os.path.exists(env_file):
with open(env_file, 'r') as f:
_data = json.load(f)
if _data.get("REVIVE_STOP", False):
if self.stop_callback:
self.stop_callback()
return _data.get("REVIVE_STOP", False)
return False
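# A minimal usage sketch (``my_trainable`` and the workspace path are
# illustrative placeholders): a ``Stopper`` instance is passed to Ray Tune via
# the ``stop`` argument, so ``__call__`` is checked on every reported result
# and ``stop_all`` can halt the entire run.
#
#   from ray import tune
#   stopper = SysStopper(workspace="./workspace", max_iter=100)
#   tune.run(my_trainable, stop=stopper)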
class TuneTBLoggerCallback(LoggerCallback):
r"""
custom tensorboard logger for ray tune
modified from ray.tune.logger.TBXLogger
Reference: https://docs.ray.io/en/latest/tune/api/doc/ray.tune.logger.LoggerCallback.html
"""
def _init(self):
self._file_writer = SummaryWriter(self.logdir)
self.last_result = None
self.step = 0
def on_result(self, result):
self.step += 1
tmp = result.copy()
flat_result = flatten_dict(tmp, delimiter="/")
for k, v in flat_result.items():
if type(v) in VALID_SUMMARY_TYPES:
self._file_writer.add_scalar(k, float(v), global_step=self.step)
elif isinstance(v, torch.Tensor):
v = v.view(-1)
self._file_writer.add_histogram(k, v, global_step=self.step)
self.last_result = flat_result
self.flush()
def flush(self):
if self._file_writer is not None:
self._file_writer.flush()
def get_tune_callbacks():
"""Instantiate the logger callbacks (CSV, JSON and TensorBoard) used for Ray Tune runs."""
callbacks = [CSVLoggerCallback, JsonLoggerCallback, TuneTBLoggerCallback]
return [callback_cls() for callback_cls in callbacks]
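# Sketch of wiring these callbacks into a run (``my_trainable`` is a
# placeholder trainable):
#
#   from ray import tune
#   tune.run(my_trainable, callbacks=get_tune_callbacks())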
class CLIReporter(_CLIReporter):
"""Modifying the Command line reporter to support logging to loguru
Reference : https://docs.ray.io/en/latest/tune/api/doc/ray.tune.CLIReporter.html
"""
def report(self, trials: List, done: bool, *sys_info: Dict):
message = self._progress_str(trials, done, *sys_info)
from loguru import logger
logger.info(message)
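# Sketch of using this reporter so the progress table goes through loguru
# instead of stdout (``my_trainable`` is a placeholder):
#
#   from ray import tune
#   tune.run(my_trainable, progress_reporter=CLIReporter())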
class CustomSearchGenerator(SearchGenerator):
"""
Customize the SearchGenerator by placing tags in the spec's config
Reference : https://github.com/ray-project/ray/blob/master/python/ray/tune/search/search_generator.py
"""
def create_trial_if_possible(self, experiment_spec: Dict):
logger.debug("creating trial")
trial_id = Trial.generate_id()
suggested_config = self.searcher.suggest(trial_id)
if suggested_config == Searcher.FINISHED:
self._finished = True
logger.debug("Searcher has finished.")
return
if suggested_config is None:
return
spec = copy.deepcopy(experiment_spec)
spec["config"] = merge_dicts(spec["config"], copy.deepcopy(suggested_config))
# Flatten the suggested config and build a unique experiment tag from it
flattened_config = _resolve_nested_dict(spec["config"])
self._counter += 1
tag = "{0}_{1}".format(str(self._counter), format_vars(flattened_config))
# New: set tag in spec['config']
spec['config']['tag'] = tag
trial = _create_trial_from_spec(
spec,
self._parser,
evaluated_params=flatten_dict(suggested_config),
experiment_tag=tag,
trial_id=trial_id,
)
return trial
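# Since the tag is also written into spec['config'], a trainable can read its
# own experiment tag from the config it receives. Sketch (trainable and tag
# value are illustrative):
#
#   def my_trainable(config):
#       tag = config["tag"]  # e.g. "3_lr=0.001"
#       ...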
class TrialIterator(_TrialIterator):
"""
Customize the _TrialIterator by placing tags in the spec's config
Reference : https://github.com/ray-project/ray/blob/master/python/ray/tune/search/basic_variant.py
"""
def create_trial(self, resolved_vars, spec):
trial_id = self.uuid_prefix + ("%05d" % self.counter)
experiment_tag = str(self.counter)
# Append the resolved vars to the experiment tag
if resolved_vars:
experiment_tag += "_{}".format(format_vars(resolved_vars))
spec['config']['tag'] = experiment_tag
self.counter += 1
return _create_trial_from_spec(
spec,
self.output_path,
self.parser,
evaluated_params=_flatten_resolved_vars(resolved_vars),
trial_id=trial_id,
experiment_tag=experiment_tag)
class CustomBasicVariantGenerator(BasicVariantGenerator):
"""
Using custom TrialIterator instead _TrialIterator
Reference : https://github.com/ray-project/ray/blob/master/python/ray/tune/search/basic_variant.py
"""
def add_configurations(
self, experiments: Union["Experiment", List["Experiment"], Dict[str, Dict]]
):
"""Chains generator given experiment specifications.
Arguments:
experiments (Experiment | list | dict): Experiments to run.
"""
experiment_list = _convert_to_experiment_list(experiments)
for experiment in experiment_list:
grid_vals = _count_spec_samples(experiment.spec, num_samples=1)
lazy_eval = grid_vals > SERIALIZATION_THRESHOLD
if lazy_eval:
warnings.warn(
f"The number of pre-generated samples ({grid_vals}) "
"exceeds the serialization threshold "
f"({int(SERIALIZATION_THRESHOLD)}). Resume ability is "
"disabled. To fix this, reduce the number of "
"dimensions/size of the provided grid search.")
previous_samples = self._total_samples
points_to_evaluate = copy.deepcopy(self._points_to_evaluate)
self._total_samples += _count_variants(experiment.spec,
points_to_evaluate)
iterator = TrialIterator(
uuid_prefix=self._uuid_prefix,
num_samples=experiment.spec.get("num_samples", 1),
unresolved_spec=experiment.spec,
constant_grid_search=self._constant_grid_search,
output_path=experiment.dir_name,
points_to_evaluate=points_to_evaluate,
lazy_eval=lazy_eval,
start=previous_samples)
self._iterators.append(iterator)
self._trial_generator = itertools.chain(self._trial_generator,
iterator)
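# Sketch of using the custom generator so grid/random-search trials also carry
# their tag in the config (trainable and search space are placeholders):
#
#   from ray import tune
#   tune.run(
#       my_trainable,
#       config={"lr": tune.grid_search([1e-3, 1e-4])},
#       search_alg=CustomBasicVariantGenerator(),
#   )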
class Parameter(zoopt.Parameter):
"""
Customize Zoom resource allocation method to fully utilize resources
"""
def __init__(self, *args, **kwargs):
self.parallel_num = kwargs.pop('parallel_num')
super(Parameter, self).__init__(*args, **kwargs)
def auto_set(self, budget):
"""
Set train_size, positive_size, negative_size by following rules:
budget < 3 --> error;
budget < 3 --> train_size = p, positive_size = (0.2*self.parallel_num);
:param budget: number of calls to the objective function
:return: no return value
"""
if budget < 3:
ToolFunction.log('parameter.py: budget too small')
sys.exit(1)
else:
if self.parallel_num < 4:
super(Parameter, self).auto_set(budget)
return
else:
# Because this subclass is also named ``Parameter``, name mangling of
# ``self.__train_size`` etc. resolves to the same ``_Parameter__*``
# attributes defined by the zoopt base class, so these assignments
# override the base class's sizes.
self.__train_size = self.parallel_num
self.__positive_size = max(int(0.2 * self.parallel_num), 1)
self.__negative_size = self.__train_size - self.__positive_size
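# Worked example of the sizing rule above: with parallel_num=10 and
# budget >= 3, auto_set yields train_size = 10,
# positive_size = max(int(0.2 * 10), 1) = 2 and negative_size = 10 - 2 = 8,
# so every parallel worker always has a candidate solution to evaluate.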
class ZOOptSearch(_ZOOptSearch):
"""
Customize Zoom resource allocation method to fully utilize resources
"""
def _setup_zoopt(self):
if self._metric is None and self._mode:
# If only a mode was passed, use anonymous metric
self._metric = DEFAULT_METRIC
_dim_list = []
for k in self._dim_dict:
self._dim_keys.append(k)
_dim_list.append(self._dim_dict[k])
init_samples = None
if self._points_to_evaluate:
logger.warning(
"`points_to_evaluate` is ignored by ZOOpt in versions <= 0.4.1."
)
init_samples = [
Solution(x=tuple(point[dim] for dim in self._dim_keys))
for point in self._points_to_evaluate
]
dim = zoopt.Dimension2(_dim_list)
par = Parameter(budget=self._budget, init_samples=init_samples, parallel_num=self.parallel_num)
if self._algo in ("sracos", "asracos"):
from zoopt.algos.opt_algorithms.racos.sracos import SRacosTune
self.optimizer = SRacosTune(
dimension=dim,
parameter=par,
parallel_num=self.parallel_num,
**self.kwargs
)
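# Sketch of constructing the customized searcher (search space and trainable
# are placeholders; the dimension spec follows ZOOpt's
# (ValueType, range, precision) convention, and ``parallel_num`` should match
# the number of concurrently running trials):
#
#   from ray import tune
#   from zoopt import ValueType
#   dim_dict = {"lr": (ValueType.CONTINUOUS, [1e-5, 1e-2], 1e-6)}
#   searcher = ZOOptSearch(algo="asracos", budget=20, dim_dict=dim_dict,
#                          metric="loss", mode="min", parallel_num=4)
#   tune.run(my_trainable, search_alg=searcher, num_samples=20)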