revive.computation
- class revive.computation.inference.VirtualEnv(env_list: List[VirtualEnvDev])
Bases: object
- to(device)
Move model to the device specified by the parameter.
Examples:
>>> venv_model.to("cpu")
>>> venv_model.to("cuda")
>>> venv_model.to("cuda:1")
- check_version()
Check if the revive version of the saved model and the current revive version match.
- reset() → None
If the model was trained with an RNN, call this method before reusing the model to reset the hidden-layer state.
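A minimal sketch of where reset() fits when a recurrent model is reused across independent rollouts. `StubRecurrentEnv` is a hypothetical stand-in, not part of revive; it only mimics the reset() and infer_one_step() methods documented on this page so the pattern can run without a trained model. The "next_obs" key is likewise an assumption.

```python
from typing import Dict, List

import numpy as np


class StubRecurrentEnv:
    """Hypothetical stand-in for a recurrent VirtualEnv (not part of revive)."""

    def __init__(self) -> None:
        self.hidden = 0.0  # plays the role of the RNN hidden state

    def reset(self) -> None:
        self.hidden = 0.0

    def infer_one_step(self, state: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        # The output depends on call history through the hidden state.
        self.hidden += float(state["obs"].sum())
        return {"next_obs": state["obs"] + self.hidden}


def run_episodes(env, initial_states: List[Dict[str, np.ndarray]]) -> List[Dict[str, np.ndarray]]:
    """Run one step per episode, clearing recurrent state before each episode."""
    outputs = []
    for state in initial_states:
        env.reset()  # drop hidden state left over from the previous episode
        outputs.append(env.infer_one_step(state))
    return outputs


starts = [{"obs": np.ones(2)}, {"obs": np.ones(2)}]
results = run_episodes(StubRecurrentEnv(), starts)
```

With the reset in place, identical starting states give identical outputs in this stub; without it, the second episode would inherit the hidden state of the first.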
- property target_policy_name: str
Get the target policy name.
- replace_policy(policy: PolicyModel) → None
Replace the target policy with the given policy.
- infer_k_steps(states: Dict[str, ndarray] | List[Dict[str, ndarray]], k: int | None = None, deterministic: bool = True, clip: bool = True) → List[Dict[str, ndarray]]
Generate k steps of interaction data.
- Args:
- states:
a dict of initial input nodes
- k:
how many steps to generate
- deterministic:
if True, the most likely actions are generated; if False, actions are generated by sample. Default: True
- Return:
k steps interactive data dict
Examples:
>>> state = {"obs": obs_array, "static_obs": static_obs_array}
>>> ten_step_output = venv_model.infer_k_steps(state, k=10)
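Because infer_k_steps returns a list with one dict per step, it can be convenient to stack the result into one array per node. The sketch below assumes every step dict has the same keys and array shapes; `stack_steps`, the node names, and the fake `steps` data are illustrative and not part of revive.

```python
from typing import Dict, List

import numpy as np


def stack_steps(steps: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
    """Stack per-step outputs into arrays with a leading time axis."""
    if not steps:
        return {}
    return {name: np.stack([step[name] for step in steps], axis=0) for name in steps[0]}


# Fake 3-step output standing in for venv_model.infer_k_steps(state, k=3).
steps = [{"action": np.full((4, 2), t), "next_obs": np.full((4, 5), t)} for t in range(3)]
stacked = stack_steps(steps)
```

Each stacked array then has shape (k, ...) followed by the per-step shape, which makes slicing by time step straightforward.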
- infer_one_step(state: Dict[str, ndarray], deterministic: bool = True, clip: bool = True, **kwargs) → Dict[str, ndarray]
Generate one step of interaction data given the action.
- Args:
- state:
a dict of input nodes
- deterministic:
if True, the most likely actions are generated; if False, actions are generated by sample. Default: True
- Return:
one step outputs
Examples:
>>> state = {"obs": obs_array, "static_obs": static_obs_array}
>>> one_step_output = venv_model.infer_one_step(state)
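The deterministic flag switches between most-likely outputs and sampled ones. As a rough sketch of how sampled calls might be used to gauge output spread, the code below draws repeated one-step outputs with deterministic=False. `StubEnv` is a hypothetical stand-in that only copies the documented infer_one_step signature; its Gaussian noise and the "next_obs" key are assumptions, not revive behaviour.

```python
from typing import Dict

import numpy as np


class StubEnv:
    """Hypothetical stand-in exposing the documented infer_one_step signature."""

    def __init__(self, seed: int = 0) -> None:
        self.rng = np.random.default_rng(seed)

    def infer_one_step(self, state, deterministic: bool = True, clip: bool = True):
        mean = state["obs"] * 2.0
        if deterministic:
            return {"next_obs": mean}
        # Sampling branch: add noise around the most likely output.
        return {"next_obs": mean + self.rng.normal(size=mean.shape)}


def sample_node(env, state: Dict[str, np.ndarray], node: str, n: int) -> np.ndarray:
    """Draw n stochastic one-step outputs for one node, stacked on axis 0."""
    draws = [env.infer_one_step(state, deterministic=False)[node] for _ in range(n)]
    return np.stack(draws, axis=0)


state = {"obs": np.zeros((1, 3))}
samples = sample_node(StubEnv(), state, "next_obs", n=100)
spread = samples.std(axis=0)  # per-dimension spread of the sampled outputs
```

With a real model, the same loop would call venv_model.infer_one_step(state, deterministic=False) in place of the stub.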
- node_infer(node_name: str, state: Dict[str, ndarray], deterministic: bool = True, clip: bool = True) → Dict[str, ndarray]
Generate one step of interaction data for the node given by node_name.
- Args:
- state:
a dict of input nodes
- deterministic:
if True, the most likely actions are generated; if False, actions are generated by sample. Default: True
- Return:
one step node output
Examples:
>>> state = {"obs": obs_array, "static_obs": static_obs_array}
>>> action_output = venv_model.node_infer("action", state)
- node_dist(node_name: str, state: Dict[str, ndarray]) → Dict[str, ndarray]
Generate the one-step output distribution of the node given by node_name.
- Args:
- state:
a dict of input nodes
- Return:
one step node distribution
Examples:
>>> state = {"obs": obs_array, "static_obs": static_obs_array}
>>> action_dist = venv_model.node_dist("action", state)
- export2onnx(onnx_file: str, verbose: bool = True)
Export the model to ONNX format.
Reference: https://pytorch.org/docs/stable/onnx.html
- Args:
- onnx_file:
the onnx model file save path.
- verbose:
if True, prints a description of the model being exported to stdout. In addition, the final ONNX graph will include the field doc_string from the exported model, which mentions the source code locations for the model.
- class revive.computation.inference.PolicyModel(policy_model_dev: PolicyModelDev, post_process: Callable[[Dict[str, ndarray], ndarray], ndarray] | None = None)
Bases: object
- to(device: str)
Move model to the device specified by the parameter.
Examples:
>>> policy_model.to("cpu")
>>> policy_model.to("cuda")
>>> policy_model.to("cuda:1")
- check_version()
Check if the revive version of the saved model and the current revive version match.
- reset()
If the model was trained with an RNN, call this method before reusing the model to reset the hidden-layer state.
- property target_policy_name: None
Get the target policy name.
- infer(state: Dict[str, ndarray], deterministic: bool = True, clip: bool = True, additional_info: Dict[str, ndarray] | None = None) → ndarray
Generate an action according to the policy.
- Args:
- state:
a dict containing ALL the input nodes of the policy node
- deterministic:
if True, the most likely actions are generated; if False, actions are generated by sample. Default: True
- clip:
if True, the output is clipped to the range set in the yaml file; if False, the output is not clipped. Default: True
- additional_info:
a dict of additional info for post process
- Return:
action
Examples:
>>> state = {"obs": obs_array, "static_obs": static_obs_array}
>>> action = policy_model.infer(state)
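The PolicyModel constructor accepts an optional post_process callable typed as Callable[[Dict[str, ndarray], ndarray], ndarray]. The sketch below shows one function that fits that type. It assumes the callable receives the state dict and the raw action in that order, which is inferred only from the annotation; the "active" key and the masking rule are hypothetical.

```python
from typing import Dict

import numpy as np


def mask_when_inactive(state: Dict[str, np.ndarray], action: np.ndarray) -> np.ndarray:
    """Zero the action for rows whose (hypothetical) 'active' flag is 0."""
    active = state["active"].astype(bool)  # assumed shape: (batch, 1)
    # Broadcast the per-row flag across the action dimensions.
    return np.where(active, action, 0.0)


state = {"active": np.array([[1], [0]])}
raw_action = np.array([[0.5, -0.2], [0.3, 0.9]])
processed = mask_when_inactive(state, raw_action)
```

A function like this could then be passed as the post_process argument when the PolicyModel is constructed.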
- export2onnx(onnx_file: str, verbose: bool = True)
Export the model to ONNX format.
Reference: https://pytorch.org/docs/stable/onnx.html
- Args:
- onnx_file:
the onnx model file save path.
- verbose:
if True, prints a description of the model being exported to stdout. In addition, the final ONNX graph will include the field doc_string from the exported model, which mentions the source code locations for the model.