revive.utils package
Submodules
revive.utils.auth_utils module
- revive.utils.auth_utils.customer_createTrain(machineCode: str, trainModelSimulatorTotalCount: str, trainModelPolicyTotalCount: str, trainDataRowsCount: str, yamlNodeCount: str, yamlFileClientUrl: str, configFileClientUrl: str, logFileClientUrl: str, userPrivateKey: str)
Verify the user's training privileges.
API Reference: https://polixir.yuque.com/puhlon/rwxlag/gu7pg8#uFKnl
- revive.utils.auth_utils.customer_uploadTrainFile(trainId: str, accessToken: str, yamlFile: Optional[str] = None, configFile: Optional[str] = None, logFile: Optional[str] = None)
Upload the history files of a training run (YAML, config, and log files).
API Reference: https://polixir.yuque.com/puhlon/rwxlag/gu7pg8#r5IPw
- revive.utils.auth_utils.customer_uploadTrainLog(trainId: str, logFile: str, trainType: str, trainResult: str, trainScore: str, accessToken: str)
Upload the log after a trial has been trained.
API Reference: https://polixir.yuque.com/puhlon/rwxlag/gu7pg8#KvKWx
revive.utils.casual_graph module
revive.utils.causal_discovery_utils module
- revive.utils.causal_discovery_utils.pc(data: ndarray, indep: str = 'fisherz', thresh: float = 0.05, bg_rules: Optional[BackgroundKnowledge] = None, callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) -> Tuple[ndarray, bool]
- revive.utils.causal_discovery_utils.fci(data: ndarray, indep: str = 'fisherz', thresh: float = 0.05, bg_rules: Optional[BackgroundKnowledge] = None, callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) -> Tuple[ndarray, bool]
- revive.utils.causal_discovery_utils.inter_cit(data: ndarray, indep: str = 'fisherz', inter_classes: Iterable[Iterable[Iterable[int]]] = [], in_parallel: bool = True, parallel_limit: int = 5, callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) -> Tuple[ndarray, bool]
Use conditional independence tests (CIT) to discover relations between variables that belong to different classes (each class is given as a group of variable indices).
- revive.utils.causal_discovery_utils.lingam(data: ndarray, ver: str = 'ica', callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) -> Tuple[ndarray, bool]
- revive.utils.causal_discovery_utils.anm(data: ndarray, kernelX: str = 'Gaussian', kernelY: str = 'Gaussian', callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) -> Tuple[ndarray, bool]
- revive.utils.causal_discovery_utils.ges(data: ndarray, score_func: str = 'BIC', callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) -> Tuple[ndarray, bool]
- revive.utils.causal_discovery_utils.exact_search(data: ndarray, method: str = 'astar', callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) -> Tuple[ndarray, bool]
- class revive.utils.causal_discovery_utils.Graph(graph: ndarray, is_real: bool = False, thresh_info: Optional[Dict[str, Any]] = None)
Bases: object
Causal graph class.
- property graph
The raw graph.
- property thresh_info
Information about the threshold.
- class revive.utils.causal_discovery_utils.TransitionGraph(graph: ndarray, state_dim: int, action_dim: int, is_real: bool = False, thresh_info: Optional[Dict[str, Any]] = None)
Bases: Graph
Transition graph class for RL.
- class revive.utils.causal_discovery_utils.DiscoveryModule(**kwargs)
Bases: ABC
Base class for causal discovery modules.
- abstract fit(data: Any, **kwargs) -> DiscoveryModule
- class revive.utils.causal_discovery_utils.ClassicalDiscovery(alg: str = 'inter_cit', alg_args: Dict[str, Any] = {'in_parallel': False, 'indep': 'kci'}, state_keys: Optional[List[str]] = ['obs'], action_keys: Optional[List[str]] = ['action'], next_state_keys: Optional[List[str]] = ['next_obs'], limit: Optional[int] = 100, use_residual: bool = True, **kwargs)
Bases: DiscoveryModule
Classical causal discovery algorithms.
- CLASSICAL_ALGOS = {'anm': <function anm>, 'exact_search': <function exact_search>, 'fci': <function fci>, 'ges': <function ges>, 'lingam': <function lingam>, 'pc': <function pc>}
- CLASSICAL_ALGOS_TRANSITION = {'inter_cit': <function inter_cit>}
- CLASSICAL_ALGOS_THRESH_INFO = {'anm': {'common': 0.5, 'max': 1.0, 'min': 0.0}, 'inter_cit': {'common': 0.8, 'max': 1.0, 'min': 0.0}, 'lingam': {'common': 0.01, 'max': inf, 'min': 0.0}}
- fit(data: Union[Dict[str, ndarray], ndarray], fit_transition: bool = True) -> ClassicalDiscovery
Fit the discovery module to transition data or general data.
- Parameters:
data (dict[str, ndarray] | ndarray): transition data dictionary or general data matrix
- Returns:
the module itself
- class revive.utils.causal_discovery_utils.AsyncClassicalDiscovery(alg: str = 'inter_cit', alg_args: Dict[str, Any] = {'in_parallel': False, 'indep': 'kci'}, state_keys: Optional[List[str]] = ['obs'], action_keys: Optional[List[str]] = ['action'], next_state_keys: Optional[List[str]] = ['next_obs'], limit: Optional[int] = 100, use_residual: bool = True, callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs)
Bases: ClassicalDiscovery
Classical causal discovery algorithms with asynchronous support.
- set_callback(callback: Callable[[int, int, ndarray, bool], Any])
Set a custom callback function.
- fit(data: Union[Dict[str, ndarray], ndarray], fit_transition: bool = True) -> ClassicalDiscovery
Fit the discovery module to transition data or general data.
- Parameters:
data (dict[str, ndarray] | ndarray): transition data dictionary or general data matrix
- Returns:
the module itself
revive.utils.common_utils module
- revive.utils.common_utils.update_env_vars(key: str, value: Any)
Update an environment variable stored in os.environ.
- Args:
key (str): name of the key
value (Any): value for the key
- Returns:
updates os.environ['env_vars']
- revive.utils.common_utils.get_env_var(key: str, default=None)
Get an environment variable stored in os.environ.
- Args:
key (str): name of the key
default: value returned if the key is not set (defaults to None)
- Returns:
the value of the key in os.environ['env_vars']
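The two helpers above keep settings under a single os.environ['env_vars'] entry. The sketch below shows that pattern. It assumes the entry holds a JSON-encoded dict. The storage format and the names update_env_vars_sketch and get_env_var_sketch are hypothetical, not REVIVE's code.

```python
import json
import os


def update_env_vars_sketch(key, value):
    # Assumption: all settings live in one JSON-encoded dict under os.environ["env_vars"].
    env_vars = json.loads(os.environ.get("env_vars", "{}"))
    env_vars[key] = value
    os.environ["env_vars"] = json.dumps(env_vars)


def get_env_var_sketch(key, default=None):
    # Decode the dict again and fall back to `default` when the key is missing.
    env_vars = json.loads(os.environ.get("env_vars", "{}"))
    return env_vars.get(key, default)


update_env_vars_sketch("workspace", "/tmp/revive_run")
print(get_env_var_sketch("workspace"))       # /tmp/revive_run
print(get_env_var_sketch("missing", "n/a"))  # n/a
```

os.environ only stores strings, so a scheme like this needs JSON-serializable values.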
- class revive.utils.common_utils.AttributeDict
Bases: dict
A dict subclass that makes getting and setting variables easy (attribute-style access).
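The sketch below shows a typical attribute-access dict, assuming AttributeDict maps attribute reads and writes onto dict items. AttributeDictSketch is a hypothetical name, and REVIVE's class may differ in detail.

```python
class AttributeDictSketch(dict):
    """dict whose items can also be read and written as attributes (hypothetical sketch)."""

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so dict methods still work.
        try:
            return self[name]
        except KeyError:
            # Raise AttributeError so hasattr() and getattr(obj, name, default) behave normally.
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None


config = AttributeDictSketch(lr=1e-3)
config.batch_size = 256
print(config["batch_size"], config.lr)  # 256 0.001
```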
- revive.utils.common_utils.setup_seed(seed: int)
Set the random seed in REVIVE.
- Args:
seed: random seed
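Seeding usually has to cover every random number generator in use. The sketch below assumes the generators are Python's random, NumPy, and (if installed) PyTorch. Which generators setup_seed actually seeds is an assumption, and setup_seed_sketch is a hypothetical name.

```python
import random

import numpy as np


def setup_seed_sketch(seed: int) -> None:
    # Seed Python's and NumPy's global generators.
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch  # optional: PyTorch is seeded only when it is installed
    except ImportError:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe to call even without a GPU


setup_seed_sketch(42)
first = np.random.rand(3)
setup_seed_sketch(42)
second = np.random.rand(3)
print(np.array_equal(first, second))  # True
```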
- revive.utils.common_utils.load_npz(filename: str)
Load an .npz file.
- Args:
filename (str): *.npz file path
- Return:
dict of the data as key: value pairs
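The sketch below shows how an .npz file can be turned into the key: value dict described above, using only NumPy. load_npz_sketch is hypothetical and may differ from REVIVE's implementation (for example, in how it handles pickled objects).

```python
import os
import tempfile

import numpy as np


def load_npz_sketch(filename: str) -> dict:
    # np.load returns a lazy NpzFile; copy every array out so the file can be closed.
    with np.load(filename) as npz:
        return {key: npz[key] for key in npz.files}


path = os.path.join(tempfile.mkdtemp(), "demo.npz")
np.savez(path, obs=np.zeros((4, 3)), action=np.ones((4, 1)))
data = load_npz_sketch(path)
print(sorted(data), data["obs"].shape)  # ['action', 'obs'] (4, 3)
```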
- revive.utils.common_utils.load_h5(filename: str)
Load an .h5 file.
- Args:
filename (str): *.h5 file path
- Return:
dict of the data as key: value pairs
- revive.utils.common_utils.save_h5(filename: str, data: Dict[str, ndarray])
Save data to an .h5 file.
- Args:
filename (str): output *.h5 file path
data (Dict[str, ndarray]): data to save
- Return:
output file
- revive.utils.common_utils.npz2h5(npz_filename: str, h5_filename: str)
Convert an .npz file to an .h5 file.
- revive.utils.common_utils.h52npz(h5_filename: str, npz_filename: str)
Convert an .h5 file to an .npz file.
- revive.utils.common_utils.load_data(data_file: str)
Load a data file. Only .h5 and .npz files are supported as data files in REVIVE.
- Args:
data_file (str): path of the data file
- Return:
dict of the data as key: value pairs
- revive.utils.common_utils.find_policy_index(graph: DesicionGraph, policy_name: str)
Find the index of a policy node in the whole decision flow graph.
- Args:
graph (DesicionGraph): decision flow graph in REVIVE
policy_name (str): name of the policy node to index
- Return:
index of the policy node in the decision graph
- Notice:
only the first policy node name is supported (TODO: support multiple policy indexes)
- revive.utils.common_utils.load_policy(filename: str, policy_name: Optional[str] = None)
Load a REVIVE policy file, saved either in torch format or as a .pkl file of VirtualEnv, VirtualEnvDev, or PolicyModelDev.
- Args:
filename (str): file path
policy_name (str): name of the policy node to index
- Return:
policy model
- revive.utils.common_utils.download_helper(url: str, filename: str)
Download a file from the given URL. Modified from torchvision.datasets.utils.
- Args:
url (str): download URL
filename (str): output file path
- Return:
output path
- revive.utils.common_utils.import_module_from_file(file_path: str, module_name='module.name')
Import an expert function from a file.
- Args:
file_path (str): file path of the expert function
module_name (str): function name in the file
- Return:
the expert function, usable in REVIVE
- revive.utils.common_utils.get_reward_fn(reward_file_path: str, config_file: str)
Import a user-defined reward function (used only for the Matcher reward).
- Args:
reward_file_path (str): file path of the reward function
config_file (str): decision flow *.yml file
- Return:
the reward function, usable in REVIVE
- revive.utils.common_utils.get_module(function_file_path, config_file)
Import a user-defined function.
- Args:
function_file_path (str): file path of the function
config_file (str): decision flow *.yml file
- Return:
the function, usable in REVIVE
- revive.utils.common_utils.create_env(task: str)
Initialize a gym environment to serve as the test environment during training.
- Args:
task (str): gym MuJoCo task name
- Return:
gym env
- revive.utils.common_utils.test_one_trail(env: Env, policy: PolicyModel)
Test a REVIVE policy on a gym environment for one trial.
- Args:
env (Env): initialized gym MuJoCo env
policy (PolicyModel): REVIVE policy to test on the env
- Return:
reward and episode length of the policy
- revive.utils.common_utils.test_on_real_env(env: Env, policy: PolicyModel, number_of_runs: int = 10)
Test a REVIVE policy on a gym environment over multiple runs.
- Args:
env (Env): initialized gym MuJoCo env
policy (PolicyModel): REVIVE policy to test on the env
number_of_runs (int): number of trials to run
- Return:
mean reward and mean episode length of the policy
- revive.utils.common_utils.get_input_dim_from_graph(graph: DesicionGraph, node_name: str, total_dims: dict)
Return the total number of input dimensions used to compute the given node in the graph.
- Args:
graph (DesicionGraph): decision flow with user-defined nodes
node_name (str): name of the node whose total dimensions are returned
total_dims (dict): input and output dims of all nodes
- Return:
total number of dimensions of node_name
- revive.utils.common_utils.get_input_dim_dict_from_graph(graph: DesicionGraph, node_name: str, total_dims: dict)
Return, as a dict, the number of dimensions of each input used to compute the given node in the graph.
- Args:
graph (DesicionGraph): decision flow with user-defined nodes
node_name (str): name of the node whose total dimensions are returned
total_dims (dict): input and output dims of all nodes
- Return:
dict with the number of dimensions of every input of node_name
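The sketch below shows how get_input_dim_from_graph and get_input_dim_dict_from_graph relate. It assumes a hypothetical plain-dict graph that maps each node to its input nodes, and a total_dims layout of {"input": int, "output": int} per node. Neither structure is REVIVE's DesicionGraph API. The total input size is the sum of the output sizes of the node's inputs.

```python
def get_input_dim_dict_sketch(graph, node_name, total_dims):
    # Hypothetical layout: graph[node] lists the input nodes of `node`, and
    # total_dims[node]["output"] is the dimension of that node's data.
    return {name: total_dims[name]["output"] for name in graph[node_name]}


def get_input_dim_sketch(graph, node_name, total_dims):
    # Total input dimension = sum of the dims of all input nodes.
    return sum(get_input_dim_dict_sketch(graph, node_name, total_dims).values())


graph = {"action": ["obs"], "next_obs": ["obs", "action"]}
total_dims = {
    "obs": {"input": 0, "output": 11},
    "action": {"input": 11, "output": 3},
    "next_obs": {"input": 14, "output": 11},
}
print(get_input_dim_dict_sketch(graph, "next_obs", total_dims))  # {'obs': 11, 'action': 3}
print(get_input_dim_sketch(graph, "next_obs", total_dims))       # 14
```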
- revive.utils.common_utils.normalize(data: ndarray)
Normalize data using its mean and standard deviation.
- Args:
data (np.ndarray): numpy array
- Return:
normalized data
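The sketch below shows mean/std normalization. It assumes statistics are computed per feature along axis 0, and it adds a small eps to avoid division by zero. Both the axis and eps are assumptions not stated in the source. normalize_sketch is a hypothetical name.

```python
import numpy as np


def normalize_sketch(data: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # Standardize each column (feature) to zero mean and unit standard deviation.
    mean = data.mean(axis=0)
    std = data.std(axis=0)
    return (data - mean) / (std + eps)


x = np.array([[1.0, 10.0], [3.0, 30.0]])
print(normalize_sketch(x))  # approximately [[-1. -1.] [ 1.  1.]]
```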
- revive.utils.common_utils.plot_traj(traj: dict)
Plot all dimensions of the data as a color map along the trajectory.
- Args:
traj (dict): data stored in a dict
- Return:
shows the plot, with dimensions on the x axis and trajectory steps on the y axis
- revive.utils.common_utils.check_weight(network: Module)
Check whether any network parameters are nan or inf.
- Args:
network (torch.nn.Module): network to check
- Print:
nan or inf values in the network params
- revive.utils.common_utils.get_models_parameters(*models)
Return all the parameters of the input models as a list.
- Args:
models (torch.nn.Module): models whose parameters are collected
- Return:
list of the parameters of all input models
- revive.utils.common_utils.get_grad_norm(parameters, norm_type: float = 2)
Return the norm of the gradients of the given parameters.
- Args:
parameters: parameters of a model
norm_type (float): type of the norm (defaults to 2)
- Return:
norm of the gradient (the L2 norm by default)
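The total gradient norm is usually computed over all parameters together: (sum over p of ||g_p||^q)^(1/q), with the max absolute value used for q = inf. This is the formula used by torch.nn.utils.clip_grad_norm_. The sketch below uses NumPy arrays as stand-ins for the p.grad tensors so that it runs without PyTorch. grad_norm_sketch is hypothetical, and REVIVE's function operates on torch parameters.

```python
import math

import numpy as np


def grad_norm_sketch(grads, norm_type: float = 2.0) -> float:
    # grads: gradient arrays, one per parameter; None means "no gradient" and is skipped.
    grads = [np.asarray(g, dtype=float) for g in grads if g is not None]
    if not grads:
        return 0.0
    if math.isinf(norm_type):
        return float(max(np.abs(g).max() for g in grads))
    total = sum(np.sum(np.abs(g) ** norm_type) for g in grads)
    return float(total ** (1.0 / norm_type))


print(grad_norm_sketch([np.array([3.0]), np.array([[4.0]])]))  # 5.0
```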
- revive.utils.common_utils.get_concat_traj(batch_data: Batch, node_names: List[str])
Concatenate the data of node_names.
- Args:
batch_data (Batch): batch of data
node_names (List): list of node names to get data from
- Return:
the concatenated data
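The sketch below shows the concatenation step. It uses a plain dict of NumPy arrays in place of a REVIVE Batch, and it assumes the node arrays are joined along the last (feature) axis. The dict stand-in, the axis, and the name get_concat_traj_sketch are all assumptions.

```python
import numpy as np


def get_concat_traj_sketch(batch_data: dict, node_names):
    # Assumption: each node maps to an array shaped (..., dim); join along the feature axis.
    return np.concatenate([batch_data[name] for name in node_names], axis=-1)


batch = {"obs": np.zeros((5, 8, 11)), "action": np.ones((5, 8, 3))}
print(get_concat_traj_sketch(batch, ["obs", "action"]).shape)  # (5, 8, 14)
```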
- revive.utils.common_utils.get_list_traj(batch_data: Batch, node_names: List[str], nodes_fit_index: Optional[dict] = None) -> list
Return the data of all node_names in batch_data as a list.
- Args:
batch_data (Batch): batch of data
node_names (List): list of node names to get data from
nodes_fit_index (Dict): dict of fixed indexes for node_names
- Return:
the requested data
- revive.utils.common_utils.generate_rewards(traj: Batch, reward_fn)
Add rewards to batch trajectories.
- Args:
traj: batch trajectories
reward_fn: function that generates the rewards
- Return:
batch trajectories with rewards
- revive.utils.common_utils.generate_rollout(expert_data: ~revive.data.batch.Batch, graph: ~revive.computation.graph.DesicionGraph, traj_length: int, sample_fn=<function <lambda>>, adapt_stds=None, clip: ~typing.Union[bool, float] = False, use_target: bool = False)
Generate trajectories with the current policy.
- Args:
expert_data: samples from the dataset
graph: the computation graph
traj_length: trajectory length
sample_fn: function that samples from a distribution
- Return:
batch trajectories
NOTE: this function keeps the last dimension even if its size is 1.
- revive.utils.common_utils.generate_rollout_bc(expert_data: ~revive.data.batch.Batch, graph: ~revive.computation.graph.DesicionGraph, traj_length: int, sample_fn=<function <lambda>>, adapt_stds=None, clip: ~typing.Union[bool, float] = False, use_target: bool = False)
Generate trajectories with the current policy.
- Args:
expert_data: samples from the dataset
graph: the computation graph
traj_length: trajectory length
sample_fn: function that samples from a distribution
- Return:
batch trajectories
NOTE: this function keeps the last dimension even if its size is 1.
- revive.utils.common_utils.compute_lambda_return(rewards, values, bootstrap=None, _gamma=0.9, _lambda=0.98)
Compute the λ-return for SVG in REVIVE environment learning.
- Args:
rewards: reward data for the current states
values: values predicted by the value network
bootstrap: bootstrap value for the last time step of next_values
_gamma: discount factor
_lambda: factor balancing future and current returns
- Return:
discounted return for the input rewards
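The sketch below computes a λ-return with the Dreamer-style backward recursion G_t = r_t + gamma * ((1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}). The value after the last step is taken from bootstrap. Whether compute_lambda_return uses exactly this recursion, and a time-major layout, is an assumption. lambda_return_sketch is a hypothetical name. With lambda = 1 the result is the Monte Carlo return; with lambda = 0 it is the one-step TD target.

```python
import numpy as np


def lambda_return_sketch(rewards, values, bootstrap=None, gamma=0.9, lam=0.98):
    # rewards[t] = r_t and values[t] = V(s_t) for t = 0..T-1 (time on axis 0).
    rewards = np.asarray(rewards, dtype=float)
    values = np.asarray(values, dtype=float)
    if bootstrap is None:
        bootstrap = np.zeros_like(values[-1])
    bootstrap = np.asarray(bootstrap, dtype=float)
    # V(s_{t+1}): shift values one step forward and append the bootstrap value.
    next_values = np.concatenate([values[1:], bootstrap[None]], axis=0)
    returns = np.zeros_like(rewards)
    last = bootstrap
    # Backward recursion over time.
    for t in reversed(range(len(rewards))):
        last = rewards[t] + gamma * ((1.0 - lam) * next_values[t] + lam * last)
        returns[t] = last
    return returns


print(lambda_return_sketch([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], bootstrap=0.0, gamma=0.5, lam=1.0))
# [1.75 1.5  1.  ]
```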
- revive.utils.common_utils.sinkhorn_gpu(cuda_id)
Set the device on which the Sinkhorn function runs.
- Args:
cuda_id: CUDA device id
- Return:
sinkhorn function
- revive.utils.common_utils.wasserstein_distance(X, Y, cost_matrix, method='sinkhorn', niter=50000, cuda_id=0)
Calculate the Wasserstein distance.
- Args:
X, Y: the two arrays
cost_matrix: cost matrix between the two arrays
method: method for calculating the distance
niter: number of iterations
cuda_id: device used for the calculation
- Return:
Wasserstein distance
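To show what method='sinkhorn' computes, the sketch below solves entropy-regularized optimal transport between two uniform point clouds. It uses log-domain Sinkhorn updates on the dual potentials f and g, which keeps exp() from underflowing when costs are large relative to the regularization strength. This is not REVIVE's implementation. The names, reg, and niter defaults are assumptions. The returned value is the transport cost sum(P * C), which approximates the squared W2 distance when C holds squared Euclidean costs.

```python
import numpy as np


def _logsumexp(a, axis):
    # Numerically stable log(sum(exp(a))) along `axis`.
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def sinkhorn_distance_sketch(cost_matrix, reg=0.05, niter=500):
    # Uniform weights on both point sets; f and g are the dual potentials.
    n, m = cost_matrix.shape
    log_a = np.full(n, -np.log(n))
    log_b = np.full(m, -np.log(m))
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(niter):
        # Alternately match the row marginals (f) and the column marginals (g).
        f = reg * (log_a - _logsumexp((g[None, :] - cost_matrix) / reg, axis=1))
        g = reg * (log_b - _logsumexp((f[:, None] - cost_matrix) / reg, axis=0))
    plan = np.exp((f[:, None] + g[None, :] - cost_matrix) / reg)
    return float(np.sum(plan * cost_matrix))


rng = np.random.default_rng(0)
X = rng.normal(size=(20, 2))
Y = X + 3.0
cost = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean costs
print(sinkhorn_distance_sketch(cost))
```

For a pure translation by (3, 3), the regularized cost should be slightly above 3**2 + 3**2 = 18.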
- revive.utils.common_utils.compute_w2_dist_to_expert(policy_trajectorys, expert_trajectorys, scaler=None, data_is_standardscaler=False, max_expert_sampes=20000, dist_metric='euclidean', emd_method='emd', processes=None, use_cuda=False, cuda_id_list=None)
Compute the Wasserstein-2 distance to expert demonstrations.
- Args:
policy_trajectorys: data generated by the policy
expert_trajectorys: expert data
scaler: scaler applied to the data
data_is_standardscaler: whether the data is already standard-scaled
max_expert_sampes: number of expert samples to use
dist_metric: distance metric
emd_method: EMD method used when computing on the CPU
processes: multiprocessing setting
use_cuda: whether to compute on the GPU
cuda_id_list: list of CUDA device ids
- Return:
Wasserstein distance
- revive.utils.common_utils.dict2parser(config: dict)
Convert a dict of operation settings into a command-line parser.
- Args:
config: dict of operations
- Return:
command-line parser
- revive.utils.common_utils.list2parser(config: List[Dict])
Convert a list of dicts of operation settings into a command-line parser.
- Args:
config: list of operations
- Return:
command-line parser
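The sketch below turns a list of settings dicts into an argparse parser. It assumes a hypothetical schema where each dict has "name", "type" (as a string), "default", and "description" keys. REVIVE's config keys may differ, and list2parser_sketch is not its API. bool is left out of the type map on purpose: argparse's type=bool treats any non-empty string as True.

```python
import argparse

# Hypothetical schema: {"name": ..., "type": "int" | "float" | "str", "default": ..., "description": ...}
_TYPES = {"int": int, "float": float, "str": str}


def list2parser_sketch(config):
    parser = argparse.ArgumentParser()
    for entry in config:
        parser.add_argument(
            f"--{entry['name']}",
            type=_TYPES[entry.get("type", "str")],
            default=entry.get("default"),
            help=entry.get("description", ""),
        )
    return parser


config = [
    {"name": "epochs", "type": "int", "default": 100, "description": "training epochs"},
    {"name": "lr", "type": "float", "default": 1e-3},
]
args = list2parser_sketch(config).parse_args(["--epochs", "200"])
print(args.epochs, args.lr)  # 200 0.001
```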
- revive.utils.common_utils.set_parameter_value(config: List[Dict], name: str, value: Any)
Change the value of parameter name in the config.
- Args:
config: list of dicts of variables
name: name of the parameter to change
value: the new value
- Return:
the original config with the default value of name updated
- revive.utils.common_utils.update_description(default_description, custom_description)
Update the default description in place with a custom description.
- Args:
default_description: default description to update
custom_description: custom description to apply
- revive.utils.common_utils.find_later(path: str, keyword: str) -> List[str]
Find all the folders that come after the given keyword in a path.
- Args:
path: file path to split into folders
keyword: name of the folder after which folders are collected
- Return:
list of folders in the path
- revive.utils.common_utils.get_node_dim_from_dist_configs(dist_configs: dict, node_name: str)
Return the total number of dimensions of node_name.
- Args:
dist_configs (dict): decision flow with user-defined node settings
node_name (str): name of the node whose total dimensions are returned
- Return:
total number of dimensions of node_name
- revive.utils.common_utils.save_histogram(histogram_path: str, graph: DesicionGraph, data_loader: DataLoader, device: str, scope: str)
Save histograms.
- Args:
histogram_path (str): path for saving the histograms
graph (DesicionGraph): decision graph
data_loader (DataLoader): torch data loader
device (str): device on which data is generated
scope (str): 'train' or 'val', used in the saved file name
- Return:
saves the histograms as png files to histogram_path
- revive.utils.common_utils.save_histogram_after_stop(traj_length: int, traj_dir: str, train_dataset, val_dataset)
Save histograms after training stops.
- Args:
traj_length (int): length of the horizon
traj_dir (str): saving directory
train_dataset: training dataset
val_dataset: validation dataset
- Return:
saves the histograms as png files
- revive.utils.common_utils.tb_data_parse(tensorboard_log_dir: str, keys: list = [])
Parse data from a TensorBoard log directory.
- Args:
tensorboard_log_dir (str): TensorBoard log directory
keys (list): keys to read from the log directory
- Return:
dict of results with the values of keys
- revive.utils.common_utils.double_venv_validation(reward_logs, data_reward={}, img_save_path='')
Run the double venv validation of a policy and save the result as an image.
- Args:
reward_logs (dict): dict of different rewards
data_reward (dict): mean reward of the train and val datasets
img_save_path (str): path for saving the image
- Return:
saves the double venv validation image to img_save_path
- revive.utils.common_utils.plt_double_venv_validation(tensorboard_log_dir, reward_train, reward_val, img_save_path)
Draw double venv validation images.
- Args:
tensorboard_log_dir (str): path of the TensorBoard logs
reward_train: mean reward of the training dataset
reward_val: mean reward of the validation dataset
img_save_path (str): path for saving the image
- Return:
saves the double venv validation image to img_save_path
- revive.utils.common_utils.save_rollout_action(rollout_save_path: str, graph: DesicionGraph, device: str, dataset, nodes_map, horizion_num=10)
Save trajectory rollouts.
- Args:
rollout_save_path: path for saving the rollout data
graph: decision graph
device: device
dataset: dimensions of the data
nodes_map: graph nodes
horizion_num (int): length of the generated data
- Return:
saves the trajectory rollouts
- revive.utils.common_utils.data_to_dtreeviz(data: ~pandas.core.frame.DataFrame, target: ~pandas.core.frame.DataFrame, target_type: (typing.List[str], <class 'str'>), orientation: ('TD', 'LR') = 'TD', fancy: bool = True, max_depth: int = 3, output: ~typing.Optional[str] = None)
Convert pandas data to a decision tree visualized with dtreeviz.
- Args:
data: dataset as a pandas DataFrame
target: target as a pandas DataFrame
target_type: continuous or discrete
orientation: top-down ('TD') or left-right ('LR')
fancy: value of the dtreeviz fancy option (True or False)
max_depth (int): depth of the tree
output: path for saving the dtreeviz result
- Return:
the decision tree visualization
- revive.utils.common_utils.net_to_tree(tree_save_path: str, graph: DesicionGraph, device: str, dataset, nodes)
Derive a decision tree from the network model.
- Args:
tree_save_path: path for saving the result
graph: decision flow as a DesicionGraph
device: device used to generate data
dataset: dataset used to derive the decision tree
nodes: nodes in the decision flow to derive decision trees for
- Return:
saves the decision tree
- revive.utils.common_utils.generate_response_inputs(expert_data: Batch, dataset: Batch, graph: DesicionGraph, obs_sample_num=16)
- revive.utils.common_utils.generate_response_outputs(generated_inputs: defaultdict, expert_data: Batch, venv_train: VirtualEnvDev, venv_val: VirtualEnvDev)
- revive.utils.common_utils.plot_response_curve(response_curve_path, graph_train, graph_val, dataset, device, obs_sample_num=16)
- revive.utils.common_utils.response_curve(response_curve_path, venv, dataset, device='cuda', obs_sample_num=16)
- revive.utils.common_utils.create_unit_vector(d_model)
Create a normalized random vector.
- Args:
d_model: size as a tuple, (num_vecs, dim) or (dim,)
- Return:
a random vector whose sum equals 1
revive.utils.license_utils module
- revive.utils.license_utils.get_machine_info(output='./machine_info.json', online=False)
Retrieve machine information using pyarmor.
- Args:
output (str): path for saving the machine information as a JSON file. Defaults to "./machine_info.json".
online (bool): whether to return the machine information as a string instead of saving it to a file. Defaults to False.
- Returns:
str: a string containing the machine information if online is True, else None.
revive.utils.raysgd_utils module
- class revive.utils.raysgd_utils.AverageMeter
Bases: object
Computes and stores the average and current value.
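The sketch below shows a common average-meter pattern, the one popularized by the PyTorch ImageNet example. It assumes REVIVE's AverageMeter follows the same val/avg/sum/count layout, but its exact attributes are not documented here. AverageMeterSketch is a hypothetical name.

```python
class AverageMeterSketch:
    """Tracks the latest value and the running average (hypothetical sketch)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # `val` is an average over `n` samples (e.g. a batch loss), so it is weighted by n.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = AverageMeterSketch()
meter.update(2.0, n=2)
meter.update(5.0)
print(meter.val, meter.avg)  # 5.0 3.0
```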
revive.utils.server_utils module
- class revive.utils.server_utils.DataBufferEnv(venv_max_num: int = 10)
Bases: object
- set_best_venv(venv: VirtualEnv)
- get_best_venv() -> VirtualEnv
- update_metric(task_id: int, metric: Dict[int, Union[float, VirtualEnvDev]])
- get_venv_list() -> List[VirtualEnvDev]
- class revive.utils.server_utils.DataBufferPolicy
Bases: object
- set_best_policy(policy: PolicyModel)
- get_best_policy() -> PolicyModel
- update_metric(task_id: int, metric: Dict[str, Union[float, PolicyModelDev]])
- class revive.utils.server_utils.Logger
Bases: object
Logger that records key-value pairs.
- class revive.utils.server_utils.TuneVenvTrain(config, venv_logger, command=None)
Bases: object
- class revive.utils.server_utils.TunePolicyTrain(config, policy_logger, venv_logger=None, command=None)
Bases: object
revive.utils.tune_utils module
- class revive.utils.tune_utils.SysStopper(workspace, max_iter: int = 0, stop_callback=None)
Bases: Stopper
Customizes the stopping mechanism of Ray Tune training.
Reference: https://docs.ray.io/en/latest/tune/api/stoppers.html
- class revive.utils.tune_utils.TuneTBLoggerCallback
Bases: LoggerCallback
Custom TensorBoard logger for Ray Tune, modified from ray.tune.logger.TBXLogger.
Reference: https://docs.ray.io/en/latest/tune/api/doc/ray.tune.logger.LoggerCallback.html
- class revive.utils.tune_utils.CLIReporter(*, metric_columns: Optional[Union[List[str], Dict[str, str]]] = None, parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None, total_samples: Optional[int] = None, max_progress_rows: int = 20, max_error_rows: int = 20, max_column_length: int = 20, max_report_frequency: int = 5, infer_limit: int = 3, print_intermediate_tables: Optional[bool] = None, metric: Optional[str] = None, mode: Optional[str] = None, sort_by_metric: bool = False)
Bases: CLIReporter
Command-line reporter modified to support logging through loguru.
Reference: https://docs.ray.io/en/latest/tune/api/doc/ray.tune.CLIReporter.html
- class revive.utils.tune_utils.CustomSearchGenerator(searcher: Searcher)
Bases: SearchGenerator
Customizes SearchGenerator by placing tags in the spec's config.
Reference: https://github.com/ray-project/ray/blob/master/python/ray/tune/search/search_generator.py
- class revive.utils.tune_utils.TrialIterator(uuid_prefix: str, num_samples: int, unresolved_spec: dict, constant_grid_search: bool = False, points_to_evaluate: Optional[List] = None, lazy_eval: bool = False, start: int = 0, random_state: Optional[Union[Generator, RandomState, int]] = None)
Bases: _TrialIterator
Customizes _TrialIterator by placing tags in the spec's config.
Reference: https://github.com/ray-project/ray/blob/master/python/ray/tune/search/basic_variant.py
- class revive.utils.tune_utils.CustomBasicVariantGenerator(points_to_evaluate: Optional[List[Dict]] = None, max_concurrent: int = 0, constant_grid_search: bool = False, random_state: Optional[Union[Generator, RandomState, int]] = None)
Bases: BasicVariantGenerator
Uses the custom TrialIterator instead of _TrialIterator.
Reference: https://github.com/ray-project/ray/blob/master/python/ray/tune/search/basic_variant.py
- class revive.utils.tune_utils.Parameter(*args, **kwargs)
Bases: Parameter
Customizes the ZOOpt resource allocation method to make full use of resources.
- class revive.utils.tune_utils.ZOOptSearch(algo: str = 'asracos', budget: Optional[int] = None, dim_dict: Optional[Dict] = None, metric: Optional[str] = None, mode: Optional[str] = None, points_to_evaluate: Optional[List[Dict]] = None, parallel_num: int = 1, **kwargs)
Bases: ZOOptSearch
Customizes the ZOOpt resource allocation method to make full use of resources.