Source code for revive.common.next_ts_policy_function

import torch
from typing import Dict


def next_ts_placeholder_policy_function(data: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Append ``data['placeholder']`` to ``data['ts_placeholder']`` and trim the front.

    The two tensors are concatenated along the last axis (``ts_placeholder``
    first), then the leading ``data['placeholder'].shape[-1]`` features of the
    concatenation are dropped. The result therefore has the same last-axis
    size as ``data['ts_placeholder']``.

    NOTE(review): presumably ``ts_placeholder`` holds a window of stacked past
    steps and this slides the window forward by one step -- confirm against
    the caller.

    Args:
        data: Mapping that must contain the keys ``'placeholder'`` and
            ``'ts_placeholder'``. Both tensors must agree on every axis except
            the last so they can be concatenated.

    Returns:
        The concatenated tensor with its first ``placeholder.shape[-1]``
        last-axis features removed.

    Raises:
        KeyError: If either key is missing from ``data``.
    """
    action = data['placeholder']
    ts_action = data['ts_' + 'placeholder']

    # Number of leading features to discard after concatenation.
    action_dim = action.shape[-1]

    # Concatenate first and slice afterwards: this keeps the output width
    # equal to ts_action's width even when ts_action is narrower than action.
    stacked = torch.cat([ts_action, action], dim=-1)
    return stacked[..., action_dim:]