revive.utils package
Submodules
revive.utils.auth_utils module
- revive.utils.auth_utils.customer_createTrain(machineCode: str, trainModelSimulatorTotalCount: str, trainModelPolicyTotalCount: str, trainDataRowsCount: str, yamlNodeCount: str, yamlFileClientUrl: str, configFileClientUrl: str, logFileClientUrl: str, userPrivateKey: str)
  - Verify the user’s training privileges.
  - API Reference: https://polixir.yuque.com/puhlon/rwxlag/gu7pg8#uFKnl
- revive.utils.auth_utils.customer_uploadTrainFile(trainId: str, accessToken: str, yamlFile: Optional[str] = None, configFile: Optional[str] = None, logFile: Optional[str] = None)
  - Upload the training history (YAML, config, and log files).
  - API Reference: https://polixir.yuque.com/puhlon/rwxlag/gu7pg8#r5IPw
- revive.utils.auth_utils.customer_uploadTrainLog(trainId: str, logFile: str, trainType: str, trainResult: str, trainScore: str, accessToken: str)
  - Upload the log after a trial has been trained.
  - API Reference: https://polixir.yuque.com/puhlon/rwxlag/gu7pg8#KvKWx
revive.utils.casual_graph module
revive.utils.causal_discovery_utils module
- revive.utils.causal_discovery_utils.pc(data: ndarray, indep: str = 'fisherz', thresh: float = 0.05, bg_rules: Optional[BackgroundKnowledge] = None, callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) -> Tuple[ndarray, bool]
- revive.utils.causal_discovery_utils.fci(data: ndarray, indep: str = 'fisherz', thresh: float = 0.05, bg_rules: Optional[BackgroundKnowledge] = None, callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) -> Tuple[ndarray, bool]
- revive.utils.causal_discovery_utils.inter_cit(data: ndarray, indep: str = 'fisherz', inter_classes: Iterable[Iterable[Iterable[int]]] = [], in_parallel: bool = True, parallel_limit: int = 5, callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) -> Tuple[ndarray, bool]
  - Use conditional independence tests (CIT) to discover the relations between variables in different classes (indicated by indices).
- revive.utils.causal_discovery_utils.lingam(data: ndarray, ver: str = 'ica', callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) -> Tuple[ndarray, bool]
- revive.utils.causal_discovery_utils.anm(data: ndarray, kernelX: str = 'Gaussian', kernelY: str = 'Gaussian', callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) -> Tuple[ndarray, bool]
- revive.utils.causal_discovery_utils.ges(data: ndarray, score_func: str = 'BIC', callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) -> Tuple[ndarray, bool]
- revive.utils.causal_discovery_utils.exact_search(data: ndarray, method: str = 'astar', callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs) -> Tuple[ndarray, bool]
- class revive.utils.causal_discovery_utils.Graph(graph: ndarray, is_real: bool = False, thresh_info: Optional[Dict[str, Any]] = None)
  - Bases: object
  - Causal graph class.
  - property graph
    - Raw graph.
  - property thresh_info
    - Information about the threshold.
- class revive.utils.causal_discovery_utils.TransitionGraph(graph: ndarray, state_dim: int, action_dim: int, is_real: bool = False, thresh_info: Optional[Dict[str, Any]] = None)
  - Bases: Graph
  - RL transition graph class.
- class revive.utils.causal_discovery_utils.DiscoveryModule(**kwargs)
  - Bases: ABC
  - Base class for causal discovery modules.
  - abstract fit(data: Any, **kwargs) -> DiscoveryModule
- class revive.utils.causal_discovery_utils.ClassicalDiscovery(alg: str = 'inter_cit', alg_args: Dict[str, Any] = {'in_parallel': False, 'indep': 'kci'}, state_keys: Optional[List[str]] = ['obs'], action_keys: Optional[List[str]] = ['action'], next_state_keys: Optional[List[str]] = ['next_obs'], limit: Optional[int] = 100, use_residual: bool = True, **kwargs)
  - Bases: DiscoveryModule
  - Classical causal discovery algorithms.
  - CLASSICAL_ALGOS = {'anm': <function anm>, 'exact_search': <function exact_search>, 'fci': <function fci>, 'ges': <function ges>, 'lingam': <function lingam>, 'pc': <function pc>}
  - CLASSICAL_ALGOS_TRANSITION = {'inter_cit': <function inter_cit>}
  - CLASSICAL_ALGOS_THRESH_INFO = {'anm': {'common': 0.5, 'max': 1.0, 'min': 0.0}, 'inter_cit': {'common': 0.8, 'max': 1.0, 'min': 0.0}, 'lingam': {'common': 0.01, 'max': inf, 'min': 0.0}}
  - fit(data: Union[Dict[str, ndarray], ndarray], fit_transition: bool = True) -> ClassicalDiscovery
    - Fit the discovery module to transition data or general data.
    - Parameters:
      - data (dict[str, ndarray] | ndarray): transition data dictionary or general data matrix.
    - Returns:
      - The module itself.
- class revive.utils.causal_discovery_utils.AsyncClassicalDiscovery(alg: str = 'inter_cit', alg_args: Dict[str, Any] = {'in_parallel': False, 'indep': 'kci'}, state_keys: Optional[List[str]] = ['obs'], action_keys: Optional[List[str]] = ['action'], next_state_keys: Optional[List[str]] = ['next_obs'], limit: Optional[int] = 100, use_residual: bool = True, callback: Optional[Callable[[int, int, ndarray, bool], Any]] = None, **kwargs)
  - Bases: ClassicalDiscovery
  - Classical causal discovery algorithms (with asynchronous support).
  - set_callback(callback: Callable[[int, int, ndarray, bool], Any])
    - Set a custom callback function.
  - fit(data: Union[Dict[str, ndarray], ndarray], fit_transition: bool = True) -> ClassicalDiscovery
    - Fit the discovery module to transition data or general data.
    - Parameters:
      - data (dict[str, ndarray] | ndarray): transition data dictionary or general data matrix.
    - Returns:
      - The module itself.
revive.utils.common_utils module
- revive.utils.common_utils.update_env_vars(key: str, value: Any)
  - Update environment variables in os.
  - Args:
    - key (str): name of the key.
    - value (Any): value for the key.
  - Returns:
    - Updates os.environ['env_vars'].
- revive.utils.common_utils.get_env_var(key: str, default=None)
  - Get environment variables from os.
  - Args:
    - key (str): name of the key.
    - default: value returned when the key is not set (defaults to None).
  - Returns:
    - The value of the key stored in os.environ['env_vars'], or default.
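A minimal sketch of how update_env_vars and get_env_var could cooperate, assuming the variables are kept as a JSON-encoded dict inside the single os.environ['env_vars'] entry. The encoding REVIVE actually uses is not documented here, and both function names below are hypothetical.

```python
import json
import os
from typing import Any


def update_env_vars_sketch(key: str, value: Any) -> None:
    """Store key/value in a JSON dict kept under os.environ['env_vars'] (assumed encoding)."""
    env_vars = json.loads(os.environ.get("env_vars", "{}"))
    env_vars[key] = value
    # os.environ only accepts strings, so the whole dict is serialized.
    os.environ["env_vars"] = json.dumps(env_vars)


def get_env_var_sketch(key: str, default: Any = None) -> Any:
    """Read key back from os.environ['env_vars'], falling back to default."""
    env_vars = json.loads(os.environ.get("env_vars", "{}"))
    return env_vars.get(key, default)


update_env_vars_sketch("workspace", "/tmp/revive_run")
print(get_env_var_sketch("workspace"))       # /tmp/revive_run
print(get_env_var_sketch("missing", "n/a"))  # n/a
```

Keeping everything in one serialized entry means child processes inherit all values through a single environment variable.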
 
- class revive.utils.common_utils.AttributeDict
  - Bases: dict
  - A dict subclass for getting and setting variables easily.
- revive.utils.common_utils.setup_seed(seed: int)
  - Set the random seed in REVIVE.
  - Args:
    - seed (int): random seed.
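A minimal sketch of a seeding helper in the spirit of setup_seed. Which generators REVIVE actually seeds is an assumption here: this version seeds Python's random module and, only when they are installed, NumPy and PyTorch. The function name is hypothetical.

```python
import os
import random


def setup_seed_sketch(seed: int) -> None:
    """Seed Python's RNG, plus NumPy and PyTorch when they are installed."""
    random.seed(seed)
    # Only affects interpreters started after this call (e.g. worker subprocesses).
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds the generators on all devices
    except ImportError:
        pass


setup_seed_sketch(42)
first = random.random()
setup_seed_sketch(42)
print(random.random() == first)  # True
```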
 
- revive.utils.common_utils.load_npz(filename: str)
  - Load an npz file.
  - Args:
    - filename (str): *.npz file path.
  - Returns:
    - Dict of the data in key: value format.
- revive.utils.common_utils.load_h5(filename: str)
  - Load an h5 file.
  - Args:
    - filename (str): *.h5 file path.
  - Returns:
    - Dict of the data in key: value format.
- revive.utils.common_utils.save_h5(filename: str, data: Dict[str, ndarray])
  - Save data to an h5 file.
  - Args:
    - filename (str): output *.h5 file path.
    - data (Dict[str, ndarray]): data to save.
  - Returns:
    - The output file.
- revive.utils.common_utils.npz2h5(npz_filename: str, h5_filename: str)
  - Convert an npz file to an h5 file.
- revive.utils.common_utils.h52npz(h5_filename: str, npz_filename: str)
  - Convert an h5 file to an npz file.
- revive.utils.common_utils.load_data(data_file: str)
  - Load a data file. REVIVE only supports h5 and npz files as data files.
  - Args:
    - data_file (str): path of the *.h5 or *.npz data file.
  - Returns:
    - Dict of the data in key: value format.
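The extension-based dispatch described for load_data can be sketched as follows. The *_sketch names are hypothetical, and reading only the top-level h5 datasets is an assumption about the file layout; h5py is imported lazily so the npz path works without it.

```python
import os
from typing import Dict

import numpy as np


def load_npz_sketch(filename: str) -> Dict[str, np.ndarray]:
    """Read every array stored in an .npz archive into a plain dict."""
    with np.load(filename) as archive:
        return {key: archive[key] for key in archive.files}


def load_h5_sketch(filename: str) -> Dict[str, np.ndarray]:
    """Read every top-level dataset of an .h5 file (assumed flat layout)."""
    import h5py  # lazy import: only needed for .h5 files

    with h5py.File(filename, "r") as f:
        return {key: f[key][()] for key in f.keys()}


def load_data_sketch(data_file: str) -> Dict[str, np.ndarray]:
    """Pick the loader from the file extension; only .h5 and .npz are accepted."""
    ext = os.path.splitext(data_file)[1].lower()
    if ext == ".npz":
        return load_npz_sketch(data_file)
    if ext == ".h5":
        return load_h5_sketch(data_file)
    raise ValueError(f"unsupported data file {data_file!r}: expected .h5 or .npz")
```

Rejecting other extensions up front gives a clear error instead of a confusing failure deep inside a loader.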
 
- revive.utils.common_utils.find_policy_index(graph: DesicionGraph, policy_name: str)
  - Find the index of a policy node in the whole decision flow graph.
  - Args:
    - graph (DesicionGraph): decision flow graph in REVIVE.
    - policy_name (str): name of the policy node to index.
  - Returns:
    - Index of the policy node in the decision graph.
  - Notice:
    - Only the first policy node name is supported. TODO: support multiple policy indexes.
- revive.utils.common_utils.load_policy(filename: str, policy_name: Optional[str] = None)
  - Load a REVIVE policy file, either in torch format or as a .pkl file of VirtualEnv, VirtualEnvDev, or PolicyModelDev.
  - Args:
    - filename (str): file path.
    - policy_name (str): name of the policy node to index.
  - Returns:
    - Policy model.
- revive.utils.common_utils.download_helper(url: str, filename: str)
  - Download a file from the given URL. Modified from torchvision.datasets.utils.
  - Args:
    - url (str): download URL.
    - filename (str): output file path.
  - Returns:
    - Output path.
- revive.utils.common_utils.import_module_from_file(file_path: str, module_name='module.name')
  - Import expert functions from a file.
  - Args:
    - file_path (str): file path of the expert function.
    - module_name (str): name under which the file is imported as a module.
  - Returns:
    - The imported module, whose expert functions become usable in REVIVE.
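import_module_from_file most likely follows the standard importlib pattern for loading a source file as a module. A sketch under that assumption (the function name is hypothetical):

```python
import importlib.util
import sys
from types import ModuleType


def import_module_from_file_sketch(file_path: str, module_name: str = "module.name") -> ModuleType:
    """Execute a Python source file and return it as a module object."""
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load a module from {file_path!r}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so lookups through sys.modules (e.g. by pickle) find it.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module
```

Functions defined in the file then become attributes of the returned module, e.g. module.my_expert_fn for a hypothetical my_expert_fn defined in that file.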
 
- revive.utils.common_utils.get_reward_fn(reward_file_path: str, config_file: str)
  - Import a user-defined reward function (only for the Matcher reward).
  - Args:
    - reward_file_path (str): file path of the reward function.
    - config_file (str): decision flow *.yml file.
  - Returns:
    - The reward function, usable in REVIVE.
- revive.utils.common_utils.get_module(function_file_path, config_file)
  - Import a user-defined function.
  - Args:
    - function_file_path (str): file path of the expert function.
    - config_file (str): decision flow *.yml file.
  - Returns:
    - The imported user-defined function(s), usable in REVIVE.
- revive.utils.common_utils.create_env(task: str)
  - Initialize a gym environment as the test environment for training.
  - Args:
    - task (str): gym MuJoCo task name.
  - Returns:
    - Gym env.
- revive.utils.common_utils.test_one_trail(env: Env, policy: PolicyModel)
  - Test a REVIVE policy on a gym env.
  - Args:
    - env (Env): initialized gym MuJoCo env.
    - policy (PolicyModel): REVIVE policy to test on the env.
  - Returns:
    - Reward and episode length of the policy.
- revive.utils.common_utils.test_on_real_env(env: Env, policy: PolicyModel, number_of_runs: int = 10)
  - Test a REVIVE policy on a gym env over multiple runs.
  - Args:
    - env (Env): initialized gym MuJoCo env.
    - policy (PolicyModel): REVIVE policy to test on the env.
    - number_of_runs (int): number of trials to run.
  - Returns:
    - Mean reward and mean episode length of the policy.
- revive.utils.common_utils.get_input_dim_from_graph(graph: DesicionGraph, node_name: str, total_dims: dict)
  - Return the total number of dims used to compute the given node on the graph.
  - Args:
    - graph (DesicionGraph): decision flow with user-set nodes.
    - node_name (str): name of the node whose total dimensions are returned.
    - total_dims (dict): dict of the input and output dims of all nodes.
  - Returns:
    - Total number of dimensions of node_name.
- revive.utils.common_utils.get_input_dim_dict_from_graph(graph: DesicionGraph, node_name: str, total_dims: dict)
  - Return, as a dict, the number of dims used to compute the given node on the graph.
  - Args:
    - graph (DesicionGraph): decision flow with user-set nodes.
    - node_name (str): name of the node whose total dimensions are returned.
    - total_dims (dict): dict of the input and output dims of all nodes.
  - Returns:
    - Dict of the number of dimensions of every input of node_name.
- revive.utils.common_utils.normalize(data: ndarray)
  - Normalize data using its mean and std.
  - Args:
    - data (np.ndarray): numpy array.
  - Returns:
    - Normalized data.
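A minimal sketch of mean/std normalization. Normalizing each column (axis=0) and the small eps guard are assumptions; the description above does not specify the axis or how zero variance is handled. The function name is hypothetical.

```python
import numpy as np


def normalize_sketch(data: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Standardize each column to zero mean and unit standard deviation."""
    mean = data.mean(axis=0)
    std = data.std(axis=0)
    # eps keeps constant columns finite (they map to 0) instead of dividing by zero.
    return (data - mean) / (std + eps)


x = np.array([[1.0, 10.0], [3.0, 20.0], [5.0, 30.0]])
z = normalize_sketch(x)
print(z.mean(axis=0), z.std(axis=0))  # means close to 0, standard deviations close to 1
```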
 
- revive.utils.common_utils.plot_traj(traj: dict)
  - Plot all dims of the data as a color map along the trajectory.
  - Args:
    - traj (dict): data stored in a dict.
  - Returns:
    - Shows a plot with the dims on the x axis and the trajectory steps on the y axis.
- revive.utils.common_utils.check_weight(network: Module)
  - Check whether network parameters are NaN or inf.
  - Args:
    - network (torch.nn.Module): network to check.
  - Prints:
    - NaN or inf values found in the network parameters.
- revive.utils.common_utils.get_models_parameters(*models)
  - Return all the parameters of the input models in a list.
  - Args:
    - models (torch.nn.Module): models whose parameters are collected.
  - Returns:
    - List of the parameters of all input models.
- revive.utils.common_utils.get_grad_norm(parameters, norm_type: float = 2)
  - Return the norm of the gradients of the parameters.
  - Args:
    - parameters: parameters of a model.
    - norm_type (float): type of the norm (defaults to 2).
  - Returns:
    - Norm of the gradients (the L2 norm by default).
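The quantity behind get_grad_norm is conventionally the p-norm of all gradient entries taken together, which is also what torch.nn.utils.clip_grad_norm_ computes. A dependency-free sketch on plain lists, assuming get_grad_norm follows this definition and skips parameters without gradients (total_grad_norm is a hypothetical name):

```python
import math
from typing import Iterable, Optional, Sequence


def total_grad_norm(grads: Iterable[Optional[Sequence[float]]], norm_type: float = 2.0) -> float:
    """p-norm of all gradient entries, as if they were concatenated into one vector."""
    present = [g for g in grads if g is not None]  # parameters without a gradient are skipped
    if norm_type == math.inf:
        return max((abs(x) for g in present for x in g), default=0.0)
    total = sum(abs(x) ** norm_type for g in present for x in g)
    return total ** (1.0 / norm_type)


# Two "parameters" with gradients [3.0] and [4.0], plus one without a gradient.
print(total_grad_norm([[3.0], [4.0], None]))  # 5.0
```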
 
- revive.utils.common_utils.get_concat_traj(batch_data: Batch, node_names: List[str])
  - Concatenate the data of node_names.
  - Args:
    - batch_data (Batch): batch of data.
    - node_names (List): list of node names to get data from.
  - Returns:
    - The concatenated data.
- revive.utils.common_utils.get_list_traj(batch_data: Batch, node_names: List[str], nodes_fit_index: Optional[dict] = None) -> list
  - Return all data of node_names from batch_data.
  - Args:
    - batch_data (Batch): batch of data.
    - node_names (List): list of node names to get data from.
    - nodes_fit_index (Dict): dict of fixed indexes for node_names.
  - Returns:
    - The requested data.
- revive.utils.common_utils.generate_rewards(traj: Batch, reward_fn)
  - Add rewards to batch trajectories.
  - Args:
    - traj: batch trajectories.
    - reward_fn: function that generates the rewards.
  - Returns:
    - Batch trajectories with rewards.
- revive.utils.common_utils.generate_rollout(expert_data: revive.data.batch.Batch, graph: revive.computation.graph.DesicionGraph, traj_length: int, sample_fn=<function <lambda>>, adapt_stds=None, clip: typing.Union[bool, float] = False, use_target: bool = False)
  - Generate trajectories based on the current policy.
  - Args:
    - expert_data: samples from the dataset.
    - graph: the computation graph.
    - traj_length: trajectory length.
    - sample_fn: function that samples from a distribution.
  - Returns:
    - Batch trajectories.
  - NOTE: this function maintains the last dimension even if it is 1.
- revive.utils.common_utils.generate_rollout_bc(expert_data: revive.data.batch.Batch, graph: revive.computation.graph.DesicionGraph, traj_length: int, sample_fn=<function <lambda>>, adapt_stds=None, clip: typing.Union[bool, float] = False, use_target: bool = False)
  - Generate trajectories based on the current policy.
  - Args:
    - expert_data: samples from the dataset.
    - graph: the computation graph.
    - traj_length: trajectory length.
    - sample_fn: function that samples from a distribution.
  - Returns:
    - Batch trajectories.
  - NOTE: this function maintains the last dimension even if it is 1.
- revive.utils.common_utils.compute_lambda_return(rewards, values, bootstrap=None, _gamma=0.9, _lambda=0.98)
  - Compute the lambda return for SVG in REVIVE env learning.
  - Args:
    - rewards: reward data for the current states.
    - values: values derived from the value net.
    - bootstrap: bootstrap value for the last time step of next_values.
    - _gamma: discount factor.
    - _lambda: factor balancing future and current returns.
  - Returns:
    - Discounted return for the input rewards.
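The recursion behind compute_lambda_return can be sketched assuming the Dreamer-style definition G_t = r_t + gamma * [(1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}], with G_T initialized to the bootstrap value (zero when omitted). The list-based helper below is hypothetical; its defaults mirror _gamma=0.9 and _lambda=0.98.

```python
from typing import List, Optional, Sequence


def lambda_return_sketch(rewards: Sequence[float], values: Sequence[float],
                         bootstrap: Optional[float] = None,
                         gamma: float = 0.9, lam: float = 0.98) -> List[float]:
    """TD(lambda) returns, computed backwards from the last step.

    values[t] is V(s_t); bootstrap is V(s_T) for the state after the last step.
    """
    if len(rewards) != len(values):
        raise ValueError("rewards and values must have the same length")
    bootstrap = 0.0 if bootstrap is None else bootstrap
    # V(s_{t+1}) for every t: values shifted left by one, with the bootstrap appended.
    next_values = list(values[1:]) + [bootstrap]
    returns = [0.0] * len(rewards)
    running = bootstrap  # G_T starts from the bootstrap value
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * ((1.0 - lam) * next_values[t] + lam * running)
        returns[t] = running
    return returns


print(lambda_return_sketch([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], gamma=0.5, lam=1.0))  # [1.75, 1.5, 1.0]
```

With lam=1.0 the recursion reduces to the plain discounted return, and with lam=0.0 to one-step TD targets r_t + gamma * V(s_{t+1}), which makes both extremes easy to sanity-check.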
 
- revive.utils.common_utils.sinkhorn_gpu(cuda_id)
  - Set the running device for the Sinkhorn function.
  - Args:
    - cuda_id: CUDA device id.
  - Returns:
    - Sinkhorn function.
- revive.utils.common_utils.wasserstein_distance(X, Y, cost_matrix, method='sinkhorn', niter=50000, cuda_id=0)
  - Calculate the Wasserstein distance.
  - Args:
    - X, Y: two arrays.
    - cost_matrix: cost matrix between the two arrays.
    - method: method for calculating the Wasserstein distance.
    - niter: number of iterations.
    - cuda_id: device for calculating the Wasserstein distance.
  - Returns:
    - Wasserstein distance.
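For method='sinkhorn', the standard Sinkhorn-Knopp iteration alternately rescales the rows and columns of the Gibbs kernel K = exp(-C / reg) until the transport plan matches both marginals. A NumPy sketch of that iteration; sinkhorn_distance and its reg/n_iter parameters are illustrative, not REVIVE's sinkhorn function.

```python
import numpy as np


def sinkhorn_distance(a: np.ndarray, b: np.ndarray, cost: np.ndarray,
                      reg: float = 0.2, n_iter: int = 1000) -> float:
    """Entropy-regularized optimal-transport cost between weight vectors a and b.

    a has shape (n,), b has shape (m,), both non-negative and summing to 1;
    cost has shape (n, m).
    """
    K = np.exp(-cost / reg)  # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        # Alternately rescale rows and columns so the plan matches both marginals.
        u = a / (K @ v)
        v = b / (K.T @ u)
    plan = u[:, None] * K * v[None, :]
    return float(np.sum(plan * cost))


X = np.array([[0.0], [1.0]])
Y = X + 2.0  # every point shifted by 2, so the exact squared W2 distance is 4
cost = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean
d = sinkhorn_distance(np.full(2, 0.5), np.full(2, 0.5), cost, reg=0.2)
print(round(d, 4))  # 4.0067
```

With a squared-Euclidean cost the result approximates the squared W2 distance; the entropic term adds a small bias that shrinks as reg decreases, at the price of slower convergence and a risk of underflow in K for very small reg.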
 
- revive.utils.common_utils.compute_w2_dist_to_expert(policy_trajectorys, expert_trajectorys, scaler=None, data_is_standardscaler=False, max_expert_sampes=20000, dist_metric='euclidean', emd_method='emd', processes=None, use_cuda=False, cuda_id_list=None)
  - Compute the Wasserstein-2 distance to expert demonstrations.
  - Args:
    - policy_trajectorys: data generated by the policy.
    - expert_trajectorys: expert data.
    - scaler: scaler for the data.
    - data_is_standardscaler: whether the data is already standard-scaled.
    - max_expert_sampes: number of expert samples to use.
    - dist_metric: distance type.
    - emd_method: method for computing the EMD on CPU.
    - processes: multiprocessing setting.
    - use_cuda: whether to compute on GPU.
    - cuda_id_list: list of CUDA devices.
  - Returns:
    - Wasserstein distance.
- revive.utils.common_utils.dict2parser(config: dict)
  - Convert a dict of operation settings into a parser.
  - Args:
    - config: dict of operations.
  - Returns:
    - Command-line parser.
- revive.utils.common_utils.list2parser(config: List[Dict])
  - Convert a list of dicts of operation settings into a parser.
  - Args:
    - config: list of operations.
  - Returns:
    - Command-line parser.
- revive.utils.common_utils.set_parameter_value(config: List[Dict], name: str, value: Any)
  - Change the value of the entry called name in the config.
  - Args:
    - config: list of dicts of variables.
    - name: name of the entry whose value is changed.
    - value: new value.
  - Returns:
    - The original config with its default value reset.
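A sketch of how list2parser and set_parameter_value could work together, assuming each config entry is a dict with 'name', 'type', 'default', and 'description' keys. That schema, the example entries, and the *_sketch names are assumptions, not a documented REVIVE format.

```python
import argparse
from typing import Any, Dict, List


def list2parser_sketch(config: List[Dict[str, Any]]) -> argparse.ArgumentParser:
    """Expose every config entry as a --<name> command-line option."""
    parser = argparse.ArgumentParser()
    for entry in config:
        parser.add_argument(
            f"--{entry['name']}",
            type=entry.get("type", str),
            default=entry.get("default"),
            help=entry.get("description", ""),
        )
    return parser


def set_parameter_value_sketch(config: List[Dict[str, Any]], name: str, value: Any) -> None:
    """Overwrite, in place, the default of the entry called name."""
    for entry in config:
        if entry["name"] == name:
            entry["default"] = value
            return
    raise KeyError(f"no config entry named {name!r}")


config = [
    {"name": "batch_size", "type": int, "default": 256, "description": "training batch size"},
    {"name": "lr", "type": float, "default": 1e-3, "description": "learning rate"},
]
set_parameter_value_sketch(config, "batch_size", 512)
args = list2parser_sketch(config).parse_args(["--lr", "0.01"])
print(args.batch_size, args.lr)  # 512 0.01
```

One caveat of this pattern: argparse's type=bool treats any non-empty string as True, so boolean entries usually need a custom converter.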
- revive.utils.common_utils.update_description(default_description, custom_description)
  - Update the default description in place with a custom description.
  - Args:
    - default_description: description to update.
    - custom_description: custom description to apply.
- revive.utils.common_utils.find_later(path: str, keyword: str) -> List[str]
  - Find all folders that come after the given keyword in a path.
  - Args:
    - path: file path from which to get the list of folders.
    - keyword: name of the folder after which the remaining folders are listed.
  - Returns:
    - List of folders as paths.
- revive.utils.common_utils.get_node_dim_from_dist_configs(dist_configs: dict, node_name: str)
  - Return the total number of dims of node_name.
  - Args:
    - dist_configs (dict): decision flow with user-set nodes.
    - node_name (str): name of the node whose total dimensions are returned.
  - Returns:
    - Total number of dimensions of node_name.
- revive.utils.common_utils.save_histogram(histogram_path: str, graph: DesicionGraph, data_loader: DataLoader, device: str, scope: str)
  - Save the histogram.
  - Args:
    - histogram_path (str): path to save the histogram to.
    - graph (DesicionGraph): decision graph.
    - data_loader (DataLoader): torch data loader.
    - device (str): device on which to generate data.
    - scope (str): 'train' or 'val', used in the saved file name.
  - Returns:
    - Saves the histogram as a png file to histogram_path.
- revive.utils.common_utils.save_histogram_after_stop(traj_length: int, traj_dir: str, train_dataset, val_dataset)
  - Save the histogram after training has stopped.
  - Args:
    - traj_length (int): length of the horizon.
    - traj_dir (str): saving directory.
    - train_dataset: training dataset.
    - val_dataset: validation dataset.
  - Returns:
    - Saves the histogram as a png file.
- revive.utils.common_utils.tb_data_parse(tensorboard_log_dir: str, keys: list = [])
  - Parse data from a TensorBoard log directory.
  - Args:
    - tensorboard_log_dir (str): path of the TensorBoard log directory.
    - keys (list): list of keys to read from the log directory.
  - Returns:
    - Dict of results containing the values of keys.
- revive.utils.common_utils.double_venv_validation(reward_logs, data_reward={}, img_save_path='')
  - Save the policy double-venv validation image to img_save_path.
  - Args:
    - reward_logs (dict): dict of different rewards.
    - data_reward (dict): mean rewards of the train and val datasets.
    - img_save_path (str): path for saving the image.
  - Returns:
    - Saves the double-venv validation image to the given path.
- revive.utils.common_utils.plt_double_venv_validation(tensorboard_log_dir, reward_train, reward_val, img_save_path)
  - Draw the double-venv validation images.
  - Args:
    - tensorboard_log_dir (str): path of the TensorBoard log directory.
    - reward_train: mean reward of the train dataset.
    - reward_val: mean reward of the val dataset.
    - img_save_path (str): path for saving the image.
  - Returns:
    - Saves the double-venv validation image to the given path.
- revive.utils.common_utils.save_rollout_action(rollout_save_path: str, graph: DesicionGraph, device: str, dataset, nodes_map, horizion_num=10)
  - Save the trajectory rollout.
  - Args:
    - rollout_save_path: path for saving the rollout images and data.
    - graph: decision graph.
    - device: device.
    - dataset: dimensions of the data.
    - nodes_map: graph nodes.
    - horizion_num (int): length of the generated data.
  - Returns:
    - Saves the trajectory rollout.
- revive.utils.common_utils.data_to_dtreeviz(data: pandas.core.frame.DataFrame, target: pandas.core.frame.DataFrame, target_type: (typing.List[str], <class 'str'>), orientation: ('TD', 'LR') = 'TD', fancy: bool = True, max_depth: int = 3, output: typing.Optional[str] = None)
  - Convert pandas data to a decision tree.
  - Args:
    - data: dataset as a pandas DataFrame.
    - target: target as a pandas DataFrame.
    - target_type: continuous or discrete.
    - orientation: top-down ('TD') or left-right ('LR').
    - fancy: value of the fancy flag passed to dtreeviz.
    - max_depth (int): depth of the tree.
    - output: path for saving the dtreeviz result, if given.
  - Returns:
    - The decision tree visualization (saved to output when given).
- revive.utils.common_utils.net_to_tree(tree_save_path: str, graph: DesicionGraph, device: str, dataset, nodes)
  - Derive a decision tree from the network model.
  - Args:
    - tree_save_path: path for saving the result.
    - graph: decision flow of type DesicionGraph.
    - device: device for generating data.
    - dataset: dataset for deriving the decision tree.
    - nodes: nodes in the decision flow from which to derive the decision tree.
  - Returns:
    - Saves the decision tree.
- revive.utils.common_utils.generate_response_inputs(expert_data: Batch, dataset: Batch, graph: DesicionGraph, obs_sample_num=16)
- revive.utils.common_utils.generate_response_outputs(generated_inputs: defaultdict, expert_data: Batch, venv_train: VirtualEnvDev, venv_val: VirtualEnvDev)
- revive.utils.common_utils.plot_response_curve(response_curve_path, graph_train, graph_val, dataset, device, obs_sample_num=16)
- revive.utils.common_utils.response_curve(response_curve_path, venv, dataset, device='cuda', obs_sample_num=16)
- revive.utils.common_utils.create_unit_vector(d_model)
  - Create a normalized random unit vector.
  - Args:
    - size: tuple, (num_vecs, dim) or (dim, ).
  - Returns:
    - A random vector whose elements sum to 1.
 
revive.utils.license_utils module
- revive.utils.license_utils.get_machine_info(output='./machine_info.json', online=False)
  - Retrieves machine information using pyarmor.
  - Args:
    - output (str): the path for saving the machine information as a json file. Defaults to "./machine_info.json".
    - online (bool): whether to return the machine information as a string instead of saving it to a file. Defaults to False.
  - Returns:
    - str: a string containing the machine information if online is True, else None.
revive.utils.raysgd_utils module
- class revive.utils.raysgd_utils.AverageMeter
  - Bases: object
  - Computes and stores the average and current value.
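The conventional AverageMeter (popularized by the PyTorch ImageNet example) keeps a running sum and count so that batch means can be weighted by batch size. A sketch assuming REVIVE's class follows that pattern; AverageMeterSketch and its reset/update method names are not taken from the documentation above.

```python
class AverageMeterSketch:
    """Tracks the most recent value and a count-weighted running average."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val: float, n: int = 1) -> None:
        # n weights val, e.g. a batch-mean loss weighted by the batch size.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = AverageMeterSketch()
meter.update(2.0)
meter.update(4.0, n=3)
print(meter.val, meter.avg)  # 4.0 3.5
```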
revive.utils.server_utils module
- class revive.utils.server_utils.DataBufferEnv(venv_max_num: int = 10)
  - Bases: object
  - set_best_venv(venv: VirtualEnv)
  - get_best_venv() -> VirtualEnv
  - update_metric(task_id: int, metric: Dict[int, Union[float, VirtualEnvDev]])
  - get_venv_list() -> List[VirtualEnvDev]
- class revive.utils.server_utils.DataBufferPolicy
  - Bases: object
  - set_best_policy(policy: PolicyModel)
  - get_best_policy() -> PolicyModel
  - update_metric(task_id: int, metric: Dict[str, Union[float, PolicyModelDev]])
- class revive.utils.server_utils.Logger
  - Bases: object
  - Logs key-value pairs.
- class revive.utils.server_utils.TuneVenvTrain(config, venv_logger, command=None)
  - Bases: object
- class revive.utils.server_utils.TunePolicyTrain(config, policy_logger, venv_logger=None, command=None)
  - Bases: object
revive.utils.tune_utils module
- class revive.utils.tune_utils.SysStopper(workspace, max_iter: int = 0, stop_callback=None)
  - Bases: Stopper
  - Customizes the stopping mechanism of Ray Tune training.
  - Reference: https://docs.ray.io/en/latest/tune/api/stoppers.html
- class revive.utils.tune_utils.TuneTBLoggerCallback
  - Bases: LoggerCallback
  - Custom TensorBoard logger for Ray Tune, modified from ray.tune.logger.TBXLogger.
  - Reference: https://docs.ray.io/en/latest/tune/api/doc/ray.tune.logger.LoggerCallback.html
- class revive.utils.tune_utils.CLIReporter(*, metric_columns: Optional[Union[List[str], Dict[str, str]]] = None, parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None, total_samples: Optional[int] = None, max_progress_rows: int = 20, max_error_rows: int = 20, max_column_length: int = 20, max_report_frequency: int = 5, infer_limit: int = 3, print_intermediate_tables: Optional[bool] = None, metric: Optional[str] = None, mode: Optional[str] = None, sort_by_metric: bool = False)
  - Bases: CLIReporter
  - Modifies the command-line reporter to support logging to loguru.
  - Reference: https://docs.ray.io/en/latest/tune/api/doc/ray.tune.CLIReporter.html
- class revive.utils.tune_utils.CustomSearchGenerator(searcher: Searcher)
  - Bases: SearchGenerator
  - Customizes the SearchGenerator by placing tags in the spec’s config.
  - Reference: https://github.com/ray-project/ray/blob/master/python/ray/tune/search/search_generator.py
- class revive.utils.tune_utils.TrialIterator(uuid_prefix: str, num_samples: int, unresolved_spec: dict, constant_grid_search: bool = False, points_to_evaluate: Optional[List] = None, lazy_eval: bool = False, start: int = 0, random_state: Optional[Union[Generator, RandomState, int]] = None)
  - Bases: _TrialIterator
  - Customizes the _TrialIterator by placing tags in the spec’s config.
  - Reference: https://github.com/ray-project/ray/blob/master/python/ray/tune/search/basic_variant.py
- class revive.utils.tune_utils.CustomBasicVariantGenerator(points_to_evaluate: Optional[List[Dict]] = None, max_concurrent: int = 0, constant_grid_search: bool = False, random_state: Optional[Union[Generator, RandomState, int]] = None)
  - Bases: BasicVariantGenerator
  - Uses the custom TrialIterator instead of _TrialIterator.
  - Reference: https://github.com/ray-project/ray/blob/master/python/ray/tune/search/basic_variant.py
- class revive.utils.tune_utils.Parameter(*args, **kwargs)
  - Bases: Parameter
  - Customizes the ZOOpt resource allocation method to fully utilize resources.
- class revive.utils.tune_utils.ZOOptSearch(algo: str = 'asracos', budget: Optional[int] = None, dim_dict: Optional[Dict] = None, metric: Optional[str] = None, mode: Optional[str] = None, points_to_evaluate: Optional[List[Dict]] = None, parallel_num: int = 1, **kwargs)
  - Bases: ZOOptSearch
  - Customizes the ZOOpt resource allocation method to fully utilize resources.