revive.utils package

Submodules

revive.utils.auth_utils module

revive.utils.auth_utils.customer_createTrain(machineCode: str, trainModelSimulatorTotalCount: str, trainModelPolicyTotalCount: str, trainDataRowsCount: str, yamlNodeCount: str, yamlFileClientUrl: str, configFileClientUrl: str, logFileClientUrl: str, userPrivateKey: str)[source]

Verify the user’s training privileges.

API Reference: https://polixir.yuque.com/puhlon/rwxlag/gu7pg8#uFKnl

revive.utils.auth_utils.customer_uploadTrainFile(trainId: str, accessToken: str, yamlFile: Optional[str] = None, configFile: Optional[str] = None, logFile: Optional[str] = None)[source]

Upload the history train log.

API Reference: https://polixir.yuque.com/puhlon/rwxlag/gu7pg8#r5IPw

revive.utils.auth_utils.customer_uploadTrainLog(trainId: str, logFile: str, trainType: str, trainResult: str, trainScore: str, accessToken: str)[source]

Upload the log after a trial is trained.

API Reference: https://polixir.yuque.com/puhlon/rwxlag/gu7pg8#KvKWx

revive.utils.auth_utils.check_license(cls)[source]

revive.utils.casual_graph module

revive.utils.causal_discovery_utils module

revive.utils.causal_discovery_utils.pc(data: ndarray, indep: str = 'fisherz', thresh: float = 0.05, bg_rules: Optional[BackgroundKnowledge] = None, callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) Tuple[ndarray, bool][source]
revive.utils.causal_discovery_utils.fci(data: ndarray, indep: str = 'fisherz', thresh: float = 0.05, bg_rules: Optional[BackgroundKnowledge] = None, callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) Tuple[ndarray, bool][source]
revive.utils.causal_discovery_utils.inter_cit(data: ndarray, indep: str = 'fisherz', inter_classes: Iterable[Iterable[Iterable[int]]] = [], in_parallel: bool = True, parallel_limit: int = 5, callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) Tuple[ndarray, bool][source]

Use conditional independence tests (CIT) to discover the relations of variables across different classes (indicated by indices).

revive.utils.causal_discovery_utils.lingam(data: ndarray, ver: str = 'ica', callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) Tuple[ndarray, bool][source]
revive.utils.causal_discovery_utils.anm(data: ndarray, kernelX: str = 'Gaussian', kernelY: str = 'Gaussian', callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) Tuple[ndarray, bool][source]
revive.utils.causal_discovery_utils.ges(data: ndarray, score_func: str = 'BIC', callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) Tuple[ndarray, bool][source]
class revive.utils.causal_discovery_utils.Graph(graph: ndarray, is_real: bool = False, thresh_info: Optional[Dict[str, Any]] = None)[source]

Bases: object

Causal graph class

property graph

raw graph

property thresh_info

information about threshold

get_adj_matrix()[source]

return transition graph [S+A+S, S] (binary or real)

get_binary_adj_matrix(thresh=None)[source]

return binary transition graph (with threshold specified)

get_binary_adj_matrix_by_sparsity(sparsity=None)[source]

return binary transition graph (with sparsity specified)
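
A minimal usage sketch of the Graph class (the array values below are made up for illustration):

   import numpy as np
   from revive.utils.causal_discovery_utils import Graph

   raw = np.array([[0.0, 0.9], [0.1, 0.0]])      # real-valued adjacency scores
   g = Graph(raw, is_real=True)
   adj = g.get_adj_matrix()                      # raw (real-valued) graph
   binary = g.get_binary_adj_matrix(thresh=0.5)  # binarize with a threshold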

class revive.utils.causal_discovery_utils.TransitionGraph(graph: ndarray, state_dim: int, action_dim: int, is_real: bool = False, thresh_info: Optional[Dict[str, Any]] = None)[source]

Bases: Graph

RL transition graph class

class revive.utils.causal_discovery_utils.DiscoveryModule(**kwargs)[source]

Bases: ABC

Base class for causal discovery modules

abstract fit(data: Any, **kwargs) DiscoveryModule[source]
property graph: Optional[Graph]
class revive.utils.causal_discovery_utils.ClassicalDiscovery(alg: str = 'inter_cit', alg_args: Dict[str, Any] = {'in_parallel': False, 'indep': 'kci'}, state_keys: Optional[List[str]] = ['obs'], action_keys: Optional[List[str]] = ['action'], next_state_keys: Optional[List[str]] = ['next_obs'], limit: Optional[int] = 100, use_residual: bool = True, **kwargs)[source]

Bases: DiscoveryModule

Classical causal discovery algorithms

CLASSICAL_ALGOS = {'anm': <function anm>, 'exact_search': <function exact_search>, 'fci': <function fci>, 'ges': <function ges>, 'lingam': <function lingam>, 'pc': <function pc>}
CLASSICAL_ALGOS_TRANSITION = {'inter_cit': <function inter_cit>}
CLASSICAL_ALGOS_THRESH_INFO = {'anm': {'common': 0.5, 'max': 1.0, 'min': 0.0}, 'inter_cit': {'common': 0.8, 'max': 1.0, 'min': 0.0}, 'lingam': {'common': 0.01, 'max': inf, 'min': 0.0}}
fit(data: Union[Dict[str, ndarray], ndarray], fit_transition: bool = True) ClassicalDiscovery[source]

Fit the discovery module to transition data or general data.

Args:

data (dict[str, ndarray] | ndarray): transition data dictionary or general data matrix

Returns

the module itself
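
A minimal usage sketch of ClassicalDiscovery on transition data (array shapes are made up; the keys follow the default state_keys/action_keys/next_state_keys):

   import numpy as np
   from revive.utils.causal_discovery_utils import ClassicalDiscovery

   data = {
       'obs': np.random.randn(500, 4),
       'action': np.random.randn(500, 2),
       'next_obs': np.random.randn(500, 4),
   }
   module = ClassicalDiscovery().fit(data, fit_transition=True)
   graph = module.graph  # Graph object describing the discovered relations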

class revive.utils.causal_discovery_utils.AsyncClassicalDiscovery(alg: str = 'inter_cit', alg_args: Dict[str, Any] = {'in_parallel': False, 'indep': 'kci'}, state_keys: Optional[List[str]] = ['obs'], action_keys: Optional[List[str]] = ['action'], next_state_keys: Optional[List[str]] = ['next_obs'], limit: Optional[int] = 100, use_residual: bool = True, callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs)[source]

Bases: ClassicalDiscovery

Classical causal discovery algorithms (asynchronous version supported)

set_callback(callback: Callable[[int, int, ndarray, bool], Any])[source]

set custom callback function

fit(data: Union[Dict[str, ndarray], ndarray], fit_transition: bool = True) ClassicalDiscovery[source]

Fit the discovery module to transition data or general data.

Args:

data (dict[str, ndarray] | ndarray): transition data dictionary or general data matrix

Returns

the module itself
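
A minimal sketch of registering a progress callback; the parameter meanings below are an assumption based on the Callable[[int, int, ndarray, bool], Any] annotation:

   from revive.utils.causal_discovery_utils import AsyncClassicalDiscovery

   def on_progress(step, total, graph, done):
       # invoked as discovery progresses, with the partial graph so far
       print(f'{step}/{total}, finished: {done}')

   module = AsyncClassicalDiscovery()
   module.set_callback(on_progress)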

revive.utils.common_utils module

revive.utils.common_utils.update_env_vars(key: str, value: Any)[source]

Update an environment variable stored in os.environ['env_vars'].

Args:

key (str): name of the key
value (Any): value for the key

Returns:

None; updates os.environ['env_vars'] in place

revive.utils.common_utils.get_env_var(key: str, default=None)[source]

Get an environment variable stored in os.environ['env_vars'].

Args:

key (str): name of the key
default: value returned when the key is not found. Defaults to None.

Returns:

the value of the key, or default if the key is not present

class revive.utils.common_utils.AttributeDict[source]

Bases: dict

A dict subclass for getting and setting variables easily.
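
A minimal sketch of the usual attribute-dict pattern, for illustration only (the actual class may differ in details):

   class AttributeDict(dict):
       """dict subclass exposing keys as attributes."""
       def __getattr__(self, name):
           try:
               return self[name]
           except KeyError as e:
               raise AttributeError(name) from e
       def __setattr__(self, name, value):
           self[name] = value

   cfg = AttributeDict()
   cfg.lr = 1e-3            # same as cfg['lr'] = 1e-3
   print(cfg['lr'], cfg.lr)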

revive.utils.common_utils.setup_seed(seed: int)[source]

Setting the random seed in REVIVE.

Args:

seed: random seed

revive.utils.common_utils.load_npz(filename: str)[source]

Loading npz file

Args:

filename(str): *.npz file path

Return:

Dict of data in key: value format

revive.utils.common_utils.load_h5(filename: str)[source]

Loading h5 file

Args:

filename(str): *.h5 file path

Return:

Dict of data in key: value format

revive.utils.common_utils.save_h5(filename: str, data: Dict[str, ndarray])[source]

Saving data to h5 file

Args:

filename (str): output *.h5 file path
data (Dict[str, ndarray]): data to save, keyed by name

Return:

output file

revive.utils.common_utils.npz2h5(npz_filename: str, h5_filename: str)[source]

Transforming npz file to h5 file

Args:

npz_filename (str): input *.npz file path
h5_filename (str): output *.h5 file path

Return:

output file

revive.utils.common_utils.h52npz(h5_filename: str, npz_filename: str)[source]

Transforming h5 file to npz file

Args:

h5_filename (str): input *.h5 file path
npz_filename (str): output *.npz file path

Return:

output file

revive.utils.common_utils.load_data(data_file: str)[source]

Loading a data file. Only h5 and npz files are supported as data files in REVIVE.

Args:

data_file (str): input *.h5 or *.npz file path

Return:

Dict of data in key: value format
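
A minimal usage sketch of the loading and conversion helpers above (file names are hypothetical):

   import numpy as np
   from revive.utils.common_utils import load_data, save_h5, npz2h5

   data = {'obs': np.random.randn(100, 4), 'action': np.random.randn(100, 2)}
   save_h5('demo.h5', data)            # write a dict of arrays to .h5
   loaded = load_data('demo.h5')       # returns a dict keyed by name
   np.savez('demo.npz', **data)        # make an .npz to convert
   npz2h5('demo.npz', 'demo_copy.h5')  # convert the .npz to .h5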

revive.utils.common_utils.find_policy_index(graph: DesicionGraph, policy_name: str)[source]

Find the index of a policy node in the whole decision flow graph.

Args:

graph (DesicionGraph): decision flow graph in REVIVE
policy_name (str): the policy node name to be indexed

Return:

index of the policy node in decision graph

Notice:

only the first matching policy node name is supported. TODO: multi policy indexes

revive.utils.common_utils.load_policy(filename: str, policy_name: Optional[str] = None)[source]

Load a policy file for REVIVE, either a torch file or a .pkl of VirtualEnv, VirtualEnvDev, or PolicyModelDev.

Args:

filename (str): file path
policy_name (str): the policy node name to be indexed

Return:

Policy model

revive.utils.common_utils.download_helper(url: str, filename: str)[source]

Download a file from a given url. Modified from torchvision.datasets.utils.

Args:

url (str): download url
filename (str): output file path

Return:

Output path

revive.utils.common_utils.import_module_from_file(file_path: str, module_name='module.name')[source]

Import an expert function from a file.

Args:

file_path (str): file path of the expert function
module_name (str): name under which to register the imported module

Return:

the expert function as a usable function in REVIVE
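
A minimal sketch of the standard importlib pattern such a helper typically wraps (illustrative, not necessarily the exact implementation):

   import importlib.util

   def import_module_from_file(file_path, module_name='module.name'):
       spec = importlib.util.spec_from_file_location(module_name, file_path)
       module = importlib.util.module_from_spec(spec)
       spec.loader.exec_module(module)
       return module

   # e.g.: mod = import_module_from_file('reward.py', 'user_reward')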

revive.utils.common_utils.get_reward_fn(reward_file_path: str, config_file: str)[source]

Import a user-defined reward function (only for the Matcher reward).

Args:

reward_file_path (str): file path of the reward function
config_file (str): decision flow *.yml file

Return:

the reward function as a usable function in REVIVE

revive.utils.common_utils.get_module(function_file_path, config_file)[source]

Import a user-defined function.

Args:

function_file_path (str): file path of the user-defined function
config_file (str): decision flow *.yml file

Return:

the user-defined function as a usable function in REVIVE

revive.utils.common_utils.create_env(task: str)[source]

Initiating a gym environment as the testing env for training.

Args:

task (str): gym mujoco task name

Return:

gym env

revive.utils.common_utils.test_one_trail(env: Env, policy: PolicyModel)[source]

Testing a REVIVE policy on a gym env.

Args:

env (Env): initialized gym mujoco env
policy (PolicyModel): REVIVE policy used for testing on the env

Return:

reward and running length of the policy

revive.utils.common_utils.test_on_real_env(env: Env, policy: PolicyModel, number_of_runs: int = 10)[source]

Testing a REVIVE policy over multiple trials on a gym env.

Args:

env (Env): initialized gym mujoco env
policy (PolicyModel): REVIVE policy used for testing on the env
number_of_runs (int): the number of trials for testing

Return:

mean value of reward and running length of the policy
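
A minimal end-to-end sketch using the helpers above (the task name and policy path are hypothetical):

   from revive.utils.common_utils import create_env, load_policy, test_on_real_env

   env = create_env('Hopper-v3')              # gym mujoco task
   policy = load_policy('logs/policy.pkl')    # trained REVIVE policy
   # returns the mean reward and mean running length over the runs
   result = test_on_real_env(env, policy, number_of_runs=10)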

revive.utils.common_utils.get_input_dim_from_graph(graph: DesicionGraph, node_name: str, total_dims: dict)[source]

Return the total number of dims used to compute the given node on the graph.

Args:

graph (DesicionGraph): decision flow with user-defined nodes
node_name (str): name of the node to get total dimensions for
total_dims (dict): dict of input and output dims of all nodes

Return:

total number of dimensions of the node_name

revive.utils.common_utils.get_input_dim_dict_from_graph(graph: DesicionGraph, node_name: str, total_dims: dict)[source]

Return, as a dict, the number of dims used to compute the given node on the graph.

Args:

graph (DesicionGraph): decision flow with user-defined nodes
node_name (str): name of the node to get total dimensions for
total_dims (dict): dict of input and output dims of all nodes

Return:

dict of dimensions for each input of node_name

revive.utils.common_utils.normalize(data: ndarray)[source]

Normalization of data using mean and std.

Args:

data (np.ndarray): numpy array

Return:

normalized data

revive.utils.common_utils.plot_traj(traj: dict)[source]

Plot all dims of the data as a color map along the trajectory.

Args:

traj (dict): data stored in dict

Return:

a plot with the x axis as dims and the y axis as trajectory step

revive.utils.common_utils.check_weight(network: Module)[source]

Check whether network parameters are nan or inf.

Args:

network (torch.nn.Module): torch.nn.Module

Print:

nan or inf in network params

revive.utils.common_utils.get_models_parameters(*models)[source]

Return all the parameters of the input models in a list.

Args:

models (torch.nn.Module): all models to get parameters from

Return:

list of parameters for all input models

revive.utils.common_utils.get_grad_norm(parameters, norm_type: float = 2)[source]

Return the gradient norm of the parameters.

Args:

parameters: parameters of a model

Return:

L2 norm of the gradient

revive.utils.common_utils.get_concat_traj(batch_data: Batch, node_names: List[str])[source]

Concatenate the data from node_names.

Args:

batch_data (Batch): batch of data
node_names (List): list of node names to get data from

Return:

the concatenated data

revive.utils.common_utils.get_list_traj(batch_data: Batch, node_names: List[str], nodes_fit_index: Optional[dict] = None) list[source]

Return all data of node_names from batch_data.

Args:

batch_data (Batch): batch of data
node_names (List): list of node names to get data from
nodes_fit_index (Dict): dict of fit indices for node_names

Return:

a list of the requested data

revive.utils.common_utils.generate_rewards(traj: Batch, reward_fn)[source]

Add rewards for batch trajectories.

Args:

traj: batch trajectories
reward_fn: function that generates the rewards

Return:

batch trajectories with rewards.

revive.utils.common_utils.generate_rollout(expert_data: ~revive.data.batch.Batch, graph: ~revive.computation.graph.DesicionGraph, traj_length: int, sample_fn=<function <lambda>>, adapt_stds=None, clip: ~typing.Union[bool, float] = False, use_target: bool = False)[source]

Generate trajectories based on the current policy.

Args:

expert_data: samples from the dataset
graph: the computation graph
traj_length: trajectory length
sample_fn: sample from a distribution

Return:

batch trajectories.

NOTE: this function will maintain the last dimension even if it is 1

revive.utils.common_utils.generate_rollout_bc(expert_data: ~revive.data.batch.Batch, graph: ~revive.computation.graph.DesicionGraph, traj_length: int, sample_fn=<function <lambda>>, adapt_stds=None, clip: ~typing.Union[bool, float] = False, use_target: bool = False)[source]

Generate trajectories based on the current policy.

Args:

expert_data: samples from the dataset
graph: the computation graph
traj_length: trajectory length
sample_fn: sample from a distribution

Return:

batch trajectories.

NOTE: this function will maintain the last dimension even if it is 1

revive.utils.common_utils.compute_lambda_return(rewards, values, bootstrap=None, _gamma=0.9, _lambda=0.98)[source]

Generate the lambda return for SVG in REVIVE env learning.

Args:

rewards: reward data for the current state
values: values derived from the value net
bootstrap: bootstrap for the last time step of next_values
_gamma: discount factor
_lambda: factor for balancing future versus current return

Return:

discounted return for the input rewards.
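
A minimal sketch of the standard lambda-return recursion such a helper typically implements (illustrative; the actual tensor handling in REVIVE may differ):

   def lambda_return(rewards, values, bootstrap=0.0, _gamma=0.9, _lambda=0.98):
       # G_t = r_t + gamma * ((1 - lambda) * V(s_{t+1}) + lambda * G_{t+1})
       next_values = list(values[1:]) + [bootstrap]
       returns, g = [], bootstrap  # approximate the final G by the bootstrap value
       for r, v_next in zip(reversed(rewards), reversed(next_values)):
           g = r + _gamma * ((1 - _lambda) * v_next + _lambda * g)
           returns.append(g)
       return list(reversed(returns))

   print(lambda_return([1.0, 1.0, 1.0], [0.5, 0.4, 0.3], bootstrap=0.2))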

revive.utils.common_utils.sinkhorn_gpu(cuda_id)[source]

Set the running device for the sinkhorn computation.

Args:

cuda_id: cuda device id

Return:

sinkhorn function

revive.utils.common_utils.wasserstein_distance(X, Y, cost_matrix, method='sinkhorn', niter=50000, cuda_id=0)[source]

Calculate the Wasserstein distance.

Args:

X & Y: two arrays
cost_matrix: cost matrix between the two arrays
method: method for calculating the Wasserstein distance
niter: number of iterations
cuda_id: device for calculating the Wasserstein distance

Return:

wasserstein distance

revive.utils.common_utils.compute_w2_dist_to_expert(policy_trajectorys, expert_trajectorys, scaler=None, data_is_standardscaler=False, max_expert_sampes=20000, dist_metric='euclidean', emd_method='emd', processes=None, use_cuda=False, cuda_id_list=None)[source]

Computes the Wasserstein-2 distance to expert demonstrations.

Args:

policy_trajectorys: data generated by the policy
expert_trajectorys: expert data
scaler: scaler for the data
data_is_standardscaler: whether the data is already standard-scaled
max_expert_sampes: number of expert samples to use
dist_metric: distance type
emd_method: method for computing the earth mover's distance on CPU
processes: multi-processing setting
use_cuda: use the GPU for computing
cuda_id_list: cuda devices as a list

Return:

wasserstein distance

revive.utils.common_utils.dict2parser(config: dict)[source]

Transform a dict of options into an argument parser.

Args:

config: dict of options

Return:

the argument parser

revive.utils.common_utils.list2parser(config: List[Dict])[source]

Transform a list of option dicts into an argument parser.

Args:

config: list of option dicts

Return:

the argument parser
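
A minimal sketch of the dict-to-parser idea (illustrative; the actual option handling in REVIVE may differ):

   import argparse

   def dict_to_parser(config):
       parser = argparse.ArgumentParser()
       for key, value in config.items():
           # each dict entry becomes a --key option with the value as default
           parser.add_argument(f'--{key}', type=type(value), default=value)
       return parser

   args = dict_to_parser({'lr': 1e-3, 'epochs': 10}).parse_args([])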

revive.utils.common_utils.set_parameter_value(config: List[Dict], name: str, value: Any)[source]

Change the value of a named entry in the config.

Args:

config: list of dicts of variables
name: the key whose value is to be changed
value: the value to change it to

Return:

the config with the updated default value

revive.utils.common_utils.update_description(default_description, custom_description)[source]

Update in place the default description with a custom description.

Args:

default_description: the default description to update
custom_description: the custom description to merge in

Return:

None; the default description is updated in place

revive.utils.common_utils.find_later(path: str, keyword: str) List[str][source]

Find all the folders that come after the given keyword in a path.

Args:

path: a file path to get the list of folders from
keyword: the name of the folder after which the later folders are collected

Return:

a list of folders as paths

revive.utils.common_utils.get_node_dim_from_dist_configs(dist_configs: dict, node_name: str)[source]

Return the total number of dims of the node_name.

Args:

dist_configs (dict): distribution configs of the nodes
node_name (str): name of the node to get total dimensions for

Return:

total number of dimensions of the node_name

revive.utils.common_utils.save_histogram(histogram_path: str, graph: DesicionGraph, data_loader: DataLoader, device: str, scope: str)[source]

Save the histogram.

Args:

histogram_path (str): the path to save the histogram
graph (DesicionGraph): DesicionGraph
data_loader (DataLoader): torch data loader
device (str): device on which to generate data
scope (str): 'train' or 'val', used in the saved file name

Return:

Saving the histogram as png file to the histogram_path

revive.utils.common_utils.save_histogram_after_stop(traj_length: int, traj_dir: str, train_dataset, val_dataset)[source]

Save the histogram after the training is stopped.

Args:

traj_length (int): length of the horizon
traj_dir (str): saving directory
train_dataset: training dataset
val_dataset: validation dataset

Return:

Saving the histogram as png file to the histogram_path

revive.utils.common_utils.tb_data_parse(tensorboard_log_dir: str, keys: list = [])[source]

Parse data from a tensorboard logdir.

Args:

tensorboard_log_dir (str): path of the tensorboard logdir
keys (list): list of keys to read from the tb logdir

Return:

a dict of results containing the values of the keys

revive.utils.common_utils.double_venv_validation(reward_logs, data_reward={}, img_save_path='')[source]

Save the policy double venv validation images to the image path.

Args:

reward_logs (dict): dict of different rewards
data_reward (dict): dataset mean rewards of the train and val datasets
img_save_path (str): path for saving the image

Return:

saves the double venv validation image to the given path

revive.utils.common_utils.plt_double_venv_validation(tensorboard_log_dir, reward_train, reward_val, img_save_path)[source]

Drawing double_venv_validation images.

Args:

tensorboard_log_dir (str): path of the tensorboard information
reward_train: dataset mean reward of the train dataset
reward_val: dataset mean reward of the val dataset
img_save_path (str): path for saving the image

Return:

saves the double venv validation image to the given path

revive.utils.common_utils.save_rollout_action(rollout_save_path: str, graph: DesicionGraph, device: str, dataset, nodes_map, horizion_num=10)[source]

Save the trajectory rollout.

Args:

rollout_save_path: path for saving the rollout images
graph: decision graph
device: device
dataset: the dataset
nodes_map: graph nodes
horizion_num (int): rollout length for generated data

Return:

saves the trajectory rollout

revive.utils.common_utils.data_to_dtreeviz(data: ~pandas.core.frame.DataFrame, target: ~pandas.core.frame.DataFrame, target_type: (typing.List[str], <class 'str'>), orientation: ('TD', 'LR') = 'TD', fancy: bool = True, max_depth: int = 3, output: ~typing.Optional[str] = None)[source]

Fit a decision tree to pandas data and visualize it.

Args:

data: dataset in pandas form
target: target in pandas form
target_type: continuous or discrete
orientation: left-to-right ('LR') or top-down ('TD')
fancy: true or false, passed to the dtreeviz function
max_depth (int): depth of the tree
output: optional path to save the dtreeviz result

Return:

the dtreeviz visualization (optionally saved to output)

revive.utils.common_utils.net_to_tree(tree_save_path: str, graph: DesicionGraph, device: str, dataset, nodes)[source]

Distill the net model into a decision tree.

Args:

tree_save_path: result saving path
graph: decision flow as a DesicionGraph
device: device on which to generate data
dataset: dataset for deriving the decision tree
nodes: nodes in the decision flow to derive decision trees for

Return:

save decision tree

revive.utils.common_utils.generate_response_inputs(expert_data: Batch, dataset: Batch, graph: DesicionGraph, obs_sample_num=16)[source]
revive.utils.common_utils.generate_response_outputs(generated_inputs: defaultdict, expert_data: Batch, venv_train: VirtualEnvDev, venv_val: VirtualEnvDev)[source]
revive.utils.common_utils.plot_response_curve(response_curve_path, graph_train, graph_val, dataset, device, obs_sample_num=16)[source]
revive.utils.common_utils.response_curve(response_curve_path, venv, dataset, device='cuda', obs_sample_num=16)[source]
revive.utils.common_utils.create_unit_vector(d_model)[source]

Create a normalized random unit vector.

Args:

size: tuple, (num_vecs, dim) / (dim, )

Return:

a random vector whose sum equals 1

revive.utils.common_utils.generate_bin_encoding(traj_num)[source]

Transform trajectory ids into binary codes in the form of a list.

Args:

traj_num: the number of trajectories in the data

Return:

binary vectors of the trajectory ids

class revive.utils.common_utils.PositionalEncoding(d_model: int, max_len: int = 5000)[source]

Bases: object

encode(pos) ndarray[source]

Args:

pos: Array, shape [seq_len, 1] / scalar

Output:

pe: Array, shape [seq_len, d_model] / [d_model, ]
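
A minimal sketch of the standard sinusoidal positional encoding suggested by the d_model/max_len signature (an assumption, not the verified implementation; d_model is taken to be even):

   import numpy as np

   def sinusoidal_encoding(pos, d_model):
       # pe[2i] = sin(pos / 10000^(2i/d_model)), pe[2i+1] = cos(same angle)
       i = np.arange(d_model // 2)
       angles = pos / np.power(10000.0, 2 * i / d_model)
       pe = np.zeros(d_model)
       pe[0::2] = np.sin(angles)
       pe[1::2] = np.cos(angles)
       return pe

   print(sinusoidal_encoding(3, 8))  # shape [d_model,] for a scalar pos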

revive.utils.license_utils module

revive.utils.license_utils.get_machine_info(output='./machine_info.json', online=False)[source]

Retrieves machine information using pyarmor.

Args:

output (str): The path to save the machine information as a json file. Defaults to "./machine_info.json".
online (bool): Whether to return the machine information as a string instead of saving it to a file. Defaults to False.

Returns:

str: A string containing machine information if online is True, else None.

revive.utils.raysgd_utils module

class revive.utils.raysgd_utils.AverageMeter[source]

Bases: object

Computes and stores the average and current value.

reset()[source]
update(val, n=1)[source]

Update current value, total sum, and average.

class revive.utils.raysgd_utils.AverageMeterCollection[source]

Bases: object

Calculates and stores the average metrics for a collection of meters.

update(metrics, n=1)[source]

Does one batch of updates for the provided metrics.

summary()[source]

Returns a dict of average and most recent values for each metric.
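
A minimal usage sketch of AverageMeter (illustrative values):

   from revive.utils.raysgd_utils import AverageMeter

   meter = AverageMeter()
   for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
       meter.update(batch_loss, n=batch_size)  # weight by batch size
   # the meter now tracks the current value, total sum, and running average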

revive.utils.server_utils module

class revive.utils.server_utils.DataBufferEnv(venv_max_num: int = 10)[source]

Bases: object

update_status(task_id: int, status: str, message: str = '')[source]
get_status() Dict[int, Tuple[str, str]][source]
set_best_venv(venv: VirtualEnv)[source]
get_best_venv() VirtualEnv[source]
get_best_model_workspace() str[source]
set_total_trials(trials: int)[source]
inc_trial() int[source]
get_num_of_trial() int[source]
update_venv_deque_dict(task_id, venv_train, venv_val)[source]
delet_deque_item(task_id, index)[source]
update_metric(task_id: int, metric: Dict[int, Union[float, VirtualEnvDev]])[source]
get_max_acc() float[source]
get_least_metric() float[source]
get_best_id() int[source]
get_venv_list() List[VirtualEnvDev][source]
get_dict()[source]
write(filename: str)[source]
class revive.utils.server_utils.DataBufferPolicy[source]

Bases: object

update_status(task_id: int, status: str, message: str = '')[source]
get_status() Dict[int, Tuple[str, str]][source]
set_best_policy(policy: PolicyModel)[source]
get_best_policy() PolicyModel[source]
get_best_model_workspace() str[source]
set_total_trials(trials: int)[source]
inc_trial() int[source]
get_num_of_trial() int[source]
update_metric(task_id: int, metric: Dict[str, Union[float, PolicyModelDev]])[source]
get_max_reward()[source]
get_best_id()[source]
get_dict()[source]
write(filename: str)[source]
class revive.utils.server_utils.DataBufferTuner(mode: str, budget: int)[source]

Bases: object

get_state()[source]
update(parameter: Dict[str, ndarray], metric: float)[source]
class revive.utils.server_utils.Logger[source]

Bases: object

Logs key-value pairs.

get_log()[source]
update(key, value)[source]
revive.utils.server_utils.trial_str_creator(trial)[source]
revive.utils.server_utils.catch_error(func)[source]

Push the training error message to data buffer
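
A minimal sketch of the decorator pattern such a helper typically follows (illustrative; the real version pushes the message to the data buffer instead of printing):

   import functools

   def catch_error(func):
       @functools.wraps(func)
       def wrapper(*args, **kwargs):
           try:
               return func(*args, **kwargs)
           except Exception as e:
               print(f'training error in {func.__name__}: {e}')
               raise
       return wrapper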

class revive.utils.server_utils.TuneVenvTrain(config, venv_logger, command=None)[source]

Bases: object

train(*args, **kwargs)[source]
class revive.utils.server_utils.TunePolicyTrain(config, policy_logger, venv_logger=None, command=None)[source]

Bases: object

train(*args, **kwargs)[source]
class revive.utils.server_utils.VenvTrain(config, venv_logger, command=None)[source]

Bases: object

train(*args, **kwargs)[source]
class revive.utils.server_utils.PolicyTrain(config, policy_logger, venv_logger=None, command=None)[source]

Bases: object

train(*args, **kwargs)[source]
revive.utils.server_utils.default_evaluate(config)[source]
class revive.utils.server_utils.ParameterTuner(config, mode, initial_state, logger, venv_logger=None)[source]

Bases: object

run()[source]

revive.utils.tune_utils module

class revive.utils.tune_utils.SysStopper(workspace, max_iter: int = 0, stop_callback=None)[source]

Bases: Stopper

Customizing the stopping mechanism of Ray Tune

Reference : https://docs.ray.io/en/latest/tune/api/stoppers.html

__call__(trial_id, result)[source]

Returns true if the trial should be terminated given the result.

stop_all()[source]

Returns true if the experiment should be terminated.
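
A minimal sketch of plugging a custom Stopper into Ray Tune (the trainable and argument values are hypothetical):

   from ray import tune
   from revive.utils.tune_utils import SysStopper

   def trainable(config):
       for step in range(1000):
           tune.report(score=step)  # SysStopper sees each reported result

   tune.run(trainable, stop=SysStopper(workspace='./logs', max_iter=100))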

class revive.utils.tune_utils.TuneTBLoggerCallback[source]

Bases: LoggerCallback

Custom tensorboard logger for Ray Tune, modified from ray.tune.logger.TBXLogger

Reference: https://docs.ray.io/en/latest/tune/api/doc/ray.tune.logger.LoggerCallback.html

on_result(result)[source]
flush()[source]
revive.utils.tune_utils.get_tune_callbacks()[source]
class revive.utils.tune_utils.CLIReporter(*, metric_columns: Optional[Union[List[str], Dict[str, str]]] = None, parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None, total_samples: Optional[int] = None, max_progress_rows: int = 20, max_error_rows: int = 20, max_column_length: int = 20, max_report_frequency: int = 5, infer_limit: int = 3, print_intermediate_tables: Optional[bool] = None, metric: Optional[str] = None, mode: Optional[str] = None, sort_by_metric: bool = False)[source]

Bases: CLIReporter

Modifying the command line reporter to support logging to loguru

Reference : https://docs.ray.io/en/latest/tune/api/doc/ray.tune.CLIReporter.html

report(trials: List, done: bool, *sys_info: Dict)[source]

Reports progress across trials.

Args:

trials: Trials to report on.
done: Whether this is the last progress report attempt.
sys_info: System info.

class revive.utils.tune_utils.CustomSearchGenerator(searcher: Searcher)[source]

Bases: SearchGenerator

Customize the SearchGenerator by placing tags in the spec’s config

Reference : https://github.com/ray-project/ray/blob/master/python/ray/tune/search/search_generator.py

create_trial_if_possible(experiment_spec, output_path)[source]
class revive.utils.tune_utils.TrialIterator(uuid_prefix: str, num_samples: int, unresolved_spec: dict, constant_grid_search: bool = False, points_to_evaluate: Optional[List] = None, lazy_eval: bool = False, start: int = 0, random_state: Optional[Union[Generator, RandomState, int]] = None)[source]

Bases: _TrialIterator

Customize the _TrialIterator by placing tags in the spec’s config

Reference : https://github.com/ray-project/ray/blob/master/python/ray/tune/search/basic_variant.py

create_trial(resolved_vars, spec)[source]
class revive.utils.tune_utils.CustomBasicVariantGenerator(points_to_evaluate: Optional[List[Dict]] = None, max_concurrent: int = 0, constant_grid_search: bool = False, random_state: Optional[Union[Generator, RandomState, int]] = None)[source]

Bases: BasicVariantGenerator

Using the custom TrialIterator instead of _TrialIterator

Reference : https://github.com/ray-project/ray/blob/master/python/ray/tune/search/basic_variant.py

add_configurations(experiments: Union[Experiment, List[Experiment], Dict[str, Dict]])[source]

Chains generator given experiment specifications.

Arguments:

experiments (Experiment | list | dict): Experiments to run.

class revive.utils.tune_utils.Parameter(*args, **kwargs)[source]

Bases: Parameter

Customize the ZOOpt resource allocation method to fully utilize resources

auto_set(budget)[source]
Set train_size, positive_size, and negative_size by the following rules:

budget < 3 --> error; budget >= 3 --> train_size = p, positive_size = (0.2*self.parallel_num);

Parameters

budget – number of calls to the objective function

Returns

no return value

class revive.utils.tune_utils.ZOOptSearch(algo: str = 'asracos', budget: Optional[int] = None, dim_dict: Optional[Dict] = None, metric: Optional[str] = None, mode: Optional[str] = None, points_to_evaluate: Optional[List[Dict]] = None, parallel_num: int = 1, **kwargs)[source]

Bases: ZOOptSearch

Customize the ZOOpt resource allocation method to fully utilize resources

Module contents