revive package
Subpackages
- revive.algo package
- revive.common package
- revive.computation package
- Submodules
- revive.computation.dists module
all_equal()
exportable_broadcast()
ReviveDistributionMixin
ReviveDistribution
ExportableNormal
ExportableCategorical
DiagnalNormal
TransformedDistribution
DiscreteLogistic
DiscreteLogistic.arg_constraints
DiscreteLogistic.support
DiscreteLogistic.has_rsample
DiscreteLogistic.log_prob()
DiscreteLogistic.sample()
DiscreteLogistic.rsample()
DiscreteLogistic.cdf()
DiscreteLogistic.icdf()
DiscreteLogistic.round()
DiscreteLogistic.mode
DiscreteLogistic.std
DiscreteLogistic.entropy()
Onehot
GaussianMixture
MixDistribution
- revive.computation.funs_parser module
- revive.computation.graph module
DesicionNode
NetworkDecisionNode
NetworkDecisionNode.node_type
NetworkDecisionNode.set_network()
NetworkDecisionNode.get_network()
NetworkDecisionNode.initialize_network()
NetworkDecisionNode.__call__()
NetworkDecisionNode.to()
NetworkDecisionNode.requires_grad_()
NetworkDecisionNode.train()
NetworkDecisionNode.eval()
NetworkDecisionNode.reset()
FunctionDecisionNode
DesicionGraph
DesicionGraph.register_node()
DesicionGraph.learnable_node_names
DesicionGraph.register_target_nodes()
DesicionGraph.del_target_nodes()
DesicionGraph.use_target_network()
DesicionGraph.not_use_target_network()
DesicionGraph.update_target_network()
DesicionGraph.mark_tunable()
DesicionGraph.register_processor()
DesicionGraph.get_node()
DesicionGraph.compute_node()
DesicionGraph.get_relation_node_names()
DesicionGraph.summary_nodes()
DesicionGraph.collect_models()
DesicionGraph.is_equal_venv()
DesicionGraph.is_equal_structure()
DesicionGraph.copy_graph_node()
DesicionGraph.get_leaf()
DesicionGraph.sort_graph()
DesicionGraph.to()
DesicionGraph.requires_grad_()
DesicionGraph.eval()
DesicionGraph.reset()
DesicionGraph.__getitem__()
DesicionGraph.keys()
DesicionGraph.values()
DesicionGraph.items()
DesicionGraph.__len__()
DesicionGraph.__call__()
DesicionGraph.state_transition()
DesicionGraph.export2onnx()
- revive.computation.inference module
VirtualEnvDev
VirtualEnvDev.to()
VirtualEnvDev.check_version()
VirtualEnvDev.reset()
VirtualEnvDev.set_target_policy_name()
VirtualEnvDev.infer_k_steps()
VirtualEnvDev.infer_one_step()
VirtualEnvDev.node_infer()
VirtualEnvDev.node_dist()
VirtualEnvDev.node_pre_computation()
VirtualEnvDev.node_post_computation()
VirtualEnvDev.forward()
VirtualEnvDev.pre_computation()
VirtualEnvDev.post_computation()
VirtualEnvDev.export2onnx()
VirtualEnvDev.training
VirtualEnv
VirtualEnv.to()
VirtualEnv.check_version()
VirtualEnv.reset()
VirtualEnv.set_env()
VirtualEnv.target_policy_name
VirtualEnv.set_target_policy_name()
VirtualEnv.replace_policy()
VirtualEnv.infer_k_steps()
VirtualEnv.infer_one_step()
VirtualEnv.node_pre_computation()
VirtualEnv.node_post_computation()
VirtualEnv.node_infer()
VirtualEnv.node_dist()
VirtualEnv.export2onnx()
PolicyModelDev
PolicyModel
- revive.computation.inference_cn module
VirtualEnvDev
VirtualEnvDev.to()
VirtualEnvDev.check_version()
VirtualEnvDev.reset()
VirtualEnvDev.set_target_policy_name()
VirtualEnvDev.infer_k_steps()
VirtualEnvDev.infer_one_step()
VirtualEnvDev.node_infer()
VirtualEnvDev.forward()
VirtualEnvDev.pre_computation()
VirtualEnvDev.post_computation()
VirtualEnvDev.export2onnx()
VirtualEnvDev.training
VirtualEnv
PolicyModelDev
PolicyModel
- revive.computation.modules module
reglu()
geglu()
Swish
MLP
ResBlock
VectorizedLinear
VectorizedMLP
ResNet
Transformer1D
Tokenizer
MultiheadAttention
FT_Transformer
DistributionWrapper
FeedForwardPolicy
RecurrentPolicy
RecurrentRESPolicy
ContextualPolicy
FeedForwardTransition
RecurrentTransition
RecurrentRESTransition
FeedForwardMatcher
RecurrentMatcher
HierarchicalMatcher
VectorizedCritic
- revive.computation.operators module
- revive.computation.utils module
- Module contents
- revive.conf package
- revive.data package
- revive.utils package
- Submodules
- revive.utils.auth_utils module
- revive.utils.casual_graph module
- revive.utils.causal_discovery_utils module
- revive.utils.common_utils module
update_env_vars()
get_env_var()
AttributeDict
setup_seed()
load_npz()
load_h5()
save_h5()
npz2h5()
h52npz()
load_data()
find_policy_index()
load_policy()
download_helper()
import_module_from_file()
get_reward_fn()
get_module()
create_env()
test_one_trail()
test_on_real_env()
get_input_dim_from_graph()
get_input_dim_dict_from_graph()
normalize()
plot_traj()
check_weight()
get_models_parameters()
get_grad_norm()
get_concat_traj()
get_list_traj()
generate_rewards()
generate_rollout()
generate_rollout_bc()
compute_lambda_return()
sinkhorn_gpu()
wasserstein_distance()
compute_w2_dist_to_expert()
dict2parser()
list2parser()
set_parameter_value()
update_description()
find_later()
get_node_dim_from_dist_configs()
save_histogram()
save_histogram_after_stop()
tb_data_parse()
double_venv_validation()
plt_double_venv_validation()
save_rollout_action()
data_to_dtreeviz()
net_to_tree()
generate_response_inputs()
generate_response_outputs()
plot_response_curve()
response_curve()
create_unit_vector()
generate_bin_encoding()
PositionalEncoding
- revive.utils.license_utils module
- revive.utils.raysgd_utils module
- revive.utils.server_utils module
DataBufferEnv
DataBufferEnv.update_status()
DataBufferEnv.get_status()
DataBufferEnv.set_best_venv()
DataBufferEnv.get_best_venv()
DataBufferEnv.get_best_model_workspace()
DataBufferEnv.set_total_trials()
DataBufferEnv.inc_trial()
DataBufferEnv.get_num_of_trial()
DataBufferEnv.update_venv_deque_dict()
DataBufferEnv.delet_deque_item()
DataBufferEnv.update_metric()
DataBufferEnv.get_max_acc()
DataBufferEnv.get_least_metric()
DataBufferEnv.get_best_id()
DataBufferEnv.get_venv_list()
DataBufferEnv.get_dict()
DataBufferEnv.write()
DataBufferPolicy
DataBufferPolicy.update_status()
DataBufferPolicy.get_status()
DataBufferPolicy.set_best_policy()
DataBufferPolicy.get_best_policy()
DataBufferPolicy.get_best_model_workspace()
DataBufferPolicy.set_total_trials()
DataBufferPolicy.inc_trial()
DataBufferPolicy.get_num_of_trial()
DataBufferPolicy.update_metric()
DataBufferPolicy.get_max_reward()
DataBufferPolicy.get_best_id()
DataBufferPolicy.get_dict()
DataBufferPolicy.write()
DataBufferTuner
Logger
trial_str_creator()
catch_error()
TuneVenvTrain
TunePolicyTrain
VenvTrain
PolicyTrain
default_evaluate()
ParameterTuner
- revive.utils.tune_utils module
- Module contents
Submodules
revive.server module
- class revive.server.ReviveServer(dataset_file_path: str, dataset_desc_file_path: str, val_file_path: Optional[str] = None, user_module_file_path: Optional[str] = None, matcher_reward_file_path: Optional[str] = None, reward_file_path: Optional[str] = None, target_policy_name: Optional[str] = None, log_dir: Optional[str] = None, run_id: Optional[str] = None, address: Optional[str] = None, venv_mode: str = 'tune', policy_mode: str = 'tune', tuning_mode: str = 'None', tune_initial_state: Optional[Dict[str, ndarray]] = None, debug: bool = False, revive_config_file_path: Optional[str] = None, **kwargs)
Bases: object
A class that uses ray to manage all training tasks. It can automatically search for optimal hyperparameters.
ReviveServer performs five steps during initialization:
1. Create or connect to a ray cluster. This behavior is controlled by the address parameter: if address is None, it creates its own cluster; if address is specified, it connects to the existing cluster.
2. Load the training config. The default config is stored in revive/config.py. You can change these parameters by editing the file, by passing them on the command line, or through custom_config.
3. Load the data and its config, and register the reward function. The data files are specified by the parameters dataset_file_path, dataset_desc_file_path and val_file_path. val_file_path is optional; if it is not specified, revive splits off part of the training data for validation. All data is put into the ray object store so it can be shared across the whole cluster.
4. Create the folders that store results. The top-level log folder is controlled by the log_dir parameter; if it is not provided, the default is the logs folder under the revive repository. The second-level folder is controlled by the run_id parameter in the training config; if it is not specified, a random id is generated for the folder. All training results are placed in the second-level folder.
5. Create the result server as a ray actor and try to load existing results from the log folder. This is useful when you want to train a policy or tune parameters on top of an already trained simulator.
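Step 3 says that when val_file_path is not given, revive splits the training data itself. The exact split rule is not documented here, so the following is only a minimal sketch under stated assumptions: `split_by_trajectory`, its `traj_ends` argument (the exclusive end offset of each trajectory) and the 20% default ratio are hypothetical. It holds out whole trajectories rather than individual time steps, which avoids leaking parts of one sequence into both splits.

```python
import numpy as np


def split_by_trajectory(data, traj_ends, val_ratio=0.2, seed=0):
    """Hypothetical sketch: hold out whole trajectories for validation.

    data:      dict of arrays whose first axis indexes time steps
               (e.g. the arrays loaded from an .npz file).
    traj_ends: exclusive end offset of each trajectory (assumed layout),
               e.g. [3, 6, 10] for three trajectories over 10 steps.
    """
    traj_ends = np.asarray(traj_ends)
    n_traj = len(traj_ends)
    if n_traj < 2:
        raise ValueError("need at least two trajectories to split")
    starts = np.concatenate(([0], traj_ends[:-1]))

    # Shuffle trajectory ids, then reserve at least one (but not all) for validation.
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_traj)
    n_val = min(n_traj - 1, max(1, int(round(n_traj * val_ratio))))
    val_ids = np.sort(order[:n_val])
    train_ids = np.sort(order[n_val:])

    def gather(ids):
        # Concatenate the time-step rows of the selected trajectories.
        rows = np.concatenate([np.arange(starts[i], traj_ends[i]) for i in ids])
        return {k: v[rows] for k, v in data.items()}

    return gather(train_ids), gather(val_ids)
```

Splitting by trajectory keeps every array aligned row by row, since all arrays are indexed with the same `rows`.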
Initialize a ReviveServer.
- Args:
- dataset_file_path (str)
The file path where the training dataset is stored. If val_file_path is None, part of the training dataset will be held out as the validation dataset (e.g., “/data/data.npz”).
- dataset_desc_file_path (str)
The file path where the data description file is stored. (e.g., “/data/test.yaml” )
- val_file_path (str)
The file path where the validation dataset is stored. If it is None, the validation dataset will be split off from the training dataset.
- reward_file_path (str)
The storage path of the file that defines the reward function.
- target_policy_name (str)
Name of the target policy to be optimized; the policy is optimized to maximize the defined reward. If it is None, the first policy in the graph will be chosen.
- log_dir (str)
Folder for storing training logs and saved models.
- run_id (str)
The ID of the current experiment, used to distinguish different training runs. If it is not provided, an ID will be generated automatically.
- address (str)
The address of the ray cluster. If address is None, it will create its own cluster.
- venv_mode ("tune", "once", "None")
Controls the mode of venv training: tune conducts a hyperparameter search; once trains with the default hyperparameters; None skips venv training.
- policy_mode ("tune", "once", "None")
Controls the mode of policy training: tune conducts a hyperparameter search; once trains with the default hyperparameters; None skips policy training.
- tuning_mode ("max", "min", "None")
Controls the mode of parameter tuning: max and min enable tuning and set its direction; None skips tuning. This feature is currently unstable.
- tune_initial_state (Dict[str, ndarray])
Initial state for parameter tuning; required when tuning mode is enabled.
- debug (bool)
If True, run in debug mode.
- custom_config
Path to a JSON file whose contents override the default parameters.
- kwargs
Keyword arguments that override the default parameters.
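Both custom_config and kwargs override default parameters, but their relative precedence is not stated here. The sketch below assumes the order built-in defaults, then the custom_config JSON file, then explicit keyword arguments; `merge_config` is a hypothetical helper for illustration, not part of revive.

```python
import json
from typing import Any, Dict, Optional


def merge_config(defaults: Dict[str, Any],
                 custom_config: Optional[str] = None,
                 **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical sketch of layering configuration sources.

    Assumed precedence (lowest to highest): built-in defaults,
    the JSON file given as custom_config, then keyword arguments.
    """
    config = dict(defaults)  # copy so the caller's defaults are not mutated
    if custom_config is not None:
        with open(custom_config, "r", encoding="utf-8") as f:
            config.update(json.load(f))  # JSON values override defaults
    config.update(kwargs)  # explicit keyword arguments win last
    return config
```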
- revive_config_file_path
preprocess config
- train(env_save_path: Optional[str] = None)
Train the virtual environment and the policy. Steps:
1. Start ray workers to train the virtual environment on the data.
2. Start ray workers to train the policy in the virtual environment.
- train_policy(env_save_path: Optional[str] = None)
Start ray workers to train the policy in the virtual environment.
- Args:
- env_save_path
Path to the saved virtual environments.
Note
Before training the policy, the environment models and the reward function must be provided.
- tune_parameter(env_save_path: Optional[str] = None)
Tune parameters on the specified virtual environments.
- Args:
- env_save_path
Path to the saved virtual environments.
Note
This feature is currently unstable.
- get_virtualenv_env() → Tuple[VirtualEnv, Dict[str, Union[str, float]], Dict[int, Tuple[str, str]]]
Get the virtual environment models and training log.
- Returns
The virtual environment models and training log.
- get_policy_model() → Tuple[PolicyModel, Dict[str, Union[str, float]], Dict[int, Tuple[str, str]]]
Get the policy trained on the specified virtual environments.
- Returns
The policy models and training log.
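Putting the documented methods together, a minimal end-to-end sketch might look like the following. `train_and_collect` is a hypothetical helper; it receives the server class as a parameter so the call sequence can be exercised with a stand-in object and without starting ray. Whether train() blocks until both stages finish is not stated here, so calling the getters immediately afterwards is an assumption.

```python
from typing import Any, Dict, Tuple


def train_and_collect(server_cls: Any,
                      dataset_file_path: str,
                      dataset_desc_file_path: str,
                      **server_kwargs: Any) -> Tuple[Any, Dict, Any, Dict]:
    """Hypothetical helper: train venv and policy, then fetch both results.

    server_cls is expected to behave like revive.server.ReviveServer; it is
    injected so this function neither imports revive nor starts ray itself.
    """
    server = server_cls(dataset_file_path=dataset_file_path,
                        dataset_desc_file_path=dataset_desc_file_path,
                        **server_kwargs)
    # Documented behavior: train the virtual environment, then the policy.
    server.train()
    # Each getter returns a 3-tuple: (model, train log, Dict[int, Tuple[str, str]]).
    venv, venv_log, _ = server.get_virtualenv_env()
    policy, policy_log, _ = server.get_policy_model()
    return venv, venv_log, policy, policy_log
```

With revive installed, you would pass `revive.server.ReviveServer` as `server_cls` together with the required paths (such as the “/data/data.npz” and “/data/test.yaml” examples from the Args list) and a reward_file_path, since policy training requires a reward function.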
revive.server_cn module
- class revive.server_cn.ReviveServer(dataset_file_path: str, dataset_desc_file_path: str, val_file_path: Optional[str] = None, reward_file_path: Optional[str] = None, target_policy_name: Optional[str] = None, log_dir: Optional[str] = None, run_id: Optional[str] = None, address: Optional[str] = None, venv_mode: str = 'tune', policy_mode: str = 'tune', tuning_mode: str = 'None', tune_initial_state: Optional[Dict[str, ndarray]] = None, debug: bool = False, revive_config_file_path: Optional[str] = None, **kwargs)
Bases: object
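ReviveServer is the training entry point of the Revive SDK; it launches and manages all training tasks.
ReviveServer performs four steps to complete initialization:
1. Create or connect to a ray cluster. The cluster address is controlled by the `address` parameter. If `address` is `None`, it creates its own cluster. If `address` is specified, it connects to the existing cluster.
2. Load the training configuration file. The default configuration is provided in `config.json`; the default parameters can be changed by editing this file.
3. Load the decision flow graph, the npz data and the functions. The data files are specified by the `dataset_file_path`, `dataset_desc_file_path` and `val_file_path` parameters.
4. Create log folders to store training results. The top-level log folder is controlled by the `log_dir` parameter; if it is not provided, a `logs` folder is generated by default. The second-level folder is controlled by the `run_id` parameter in the training config; if it is not specified, a random id is generated for the folder. All training logs and models are placed in the second-level folder.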
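Step 4 describes a two-level layout, log_dir/run_id. A minimal sketch of that resolution logic follows; `resolve_log_folder`, the `logs` default root and the uuid-based random id are assumptions for illustration, and revive's actual default location and id format may differ.

```python
import os
import uuid
from typing import Optional


def resolve_log_folder(log_dir: Optional[str] = None,
                       run_id: Optional[str] = None,
                       default_root: str = "logs") -> str:
    """Hypothetical sketch of the <log_dir>/<run_id> layout.

    default_root and the uuid-based id are assumptions, not revive's code.
    """
    top = log_dir if log_dir is not None else default_root
    rid = run_id if run_id is not None else uuid.uuid4().hex  # random id
    folder = os.path.join(top, rid)
    os.makedirs(folder, exist_ok=True)  # create both levels if missing
    return folder
```

Generating the id only when run_id is absent means that passing the same run_id again points at the same folder, which is what lets an existing run's results be found and reused.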
- Args:
- dataset_file_path (str)
File path of the training data (a .npz or .h5 file).
- dataset_desc_file_path (str)
File path of the decision flow graph (a .yaml file).
- val_file_path (str)
File path of the validation data (optional).
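- reward_file_path (str)
Path of the file that defines the reward function.
- target_policy_name (str)
Name of the policy node to be optimized. If it is None, the first network node in the decision flow graph will be chosen as the policy node.
- log_dir (str)
Folder for storing models and training logs.
- run_id (str)
Experiment ID, used to name the log folder and distinguish different experiments. If it is not provided, the system generates one automatically.
- address (str)
Address of the ray cluster. If `address` is `None`, it creates its own cluster. If `address` is specified, it connects to the existing cluster.
- venv_mode ("tune", "once", "None")
Training mode for the virtual environment: tune trains the virtual environment model with hyperparameter search, which consumes a large amount of compute and time to find hyperparameters that yield better models; once trains the virtual environment model with the default parameters; None does not train the virtual environment model.
- policy_mode ("tune", "once", "None")
Training mode for the policy model: tune trains the policy model with hyperparameter search, which consumes a large amount of compute and time to find hyperparameters that yield better models; once trains the policy model with the default parameters; None does not train the policy model.
- custom_config
Path to a hyperparameter configuration file, which can be used to override the default parameters.
- kwargs
Keyword arguments that can be used to override the default parameters.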
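venv_mode and policy_mode each take one of three strings, and "None" here is a string (the signature default for tuning_mode is likewise the string 'None'). The hypothetical helpers below validate these values and list which stages would run; `normalize_mode`, `planned_stages` and the mapping of Python's None to "None" are assumptions for illustration, not documented revive behavior.

```python
from typing import List, Optional, Tuple

VALID_MODES = ("tune", "once", "None")


def normalize_mode(mode: Optional[str]) -> str:
    """Hypothetical helper: validate a venv_mode / policy_mode value."""
    if mode is None:
        mode = "None"  # assumption: treat Python None like the string 'None'
    if mode not in VALID_MODES:
        raise ValueError(f"mode must be one of {VALID_MODES}, got {mode!r}")
    return mode


def planned_stages(venv_mode: Optional[str] = "tune",
                   policy_mode: Optional[str] = "tune") -> List[Tuple[str, str]]:
    """Return (stage, strategy) pairs for the stages that would run."""
    stages = []
    for stage, mode in (("venv", venv_mode), ("policy", policy_mode)):
        mode = normalize_mode(mode)
        if mode == "tune":
            stages.append((stage, "hyperparameter search"))
        elif mode == "once":
            stages.append((stage, "default hyperparameters"))
        # mode == "None": the stage is skipped
    return stages
```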
- revive_config_file_path
preprocess config
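- train(env_save_path: Optional[str] = None)
Train the virtual environment and the policy. Steps:
1. Load the data and parameter configuration, and start ray actors to train the virtual environment.
2. Load the data, the parameters and the trained virtual environment, and start ray actors to train the policy.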
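- train_policy(env_save_path: Optional[str] = None)
Load the data, the parameters and the trained virtual environment, and start ray actors to train the policy.
- Args:
- env_save_path
Save path of the virtual environment. Defaults to None, in which case the virtual environment files are located automatically based on run_id.
Note
Before training the policy, the trained virtual environment models and the reward function must be provided.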
- get_virtualenv_env() → Tuple[VirtualEnv, Dict[str, Union[str, float]], Dict[int, Tuple[str, str]]]
Get the current best virtual environment model and training log.
- get_policy_model() → Tuple[PolicyModel, Dict[str, Union[str, float]], Dict[int, Tuple[str, str]]]
Get the current best policy model and training log.