# Source code for revive.common.next_ts_transition_function
import torch
from typing import Dict
# [docs]
def next_ts_placeholder_transition_function(data: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Shift a time-stacked window by one step, appending the newest frame.

    Concatenates the window tensor ``data['ts_placeholder']`` with the newest
    frame ``data['next_placeholder']`` along the last axis. It then drops the
    first ``F`` features, where ``F`` is the size of the newest frame's last
    axis. The result therefore always has the same last-axis size as the
    window input.

    NOTE(review): the function name and the ``'...'+'placeholder'`` key
    construction suggest this is a template whose ``placeholder`` token is
    substituted with a node name elsewhere — confirm before changing those
    literals. It presumably implements a sliding history buffer storing
    ``k`` frames of width ``F`` back to back. The window width is not
    checked to be a multiple of ``F``.

    Args:
        data: Mapping that must contain ``'next_placeholder'`` and
            ``'ts_placeholder'``. All dimensions except the last must match,
            as required by ``torch.cat``.

    Returns:
        Tensor with the inputs' leading shape and the window's last-axis size.

    Raises:
        KeyError: If either key is missing (for a plain ``dict``).
        RuntimeError: If the leading dimensions do not match (from ``torch.cat``).
    """
    # Keep the key expressions verbatim; see the template note above.
    newest = data['next_'+'placeholder']
    window = data['ts_'+'placeholder']
    frame_width = newest.shape[-1]
    # Append the newest frame, then discard the oldest ``frame_width`` features.
    shifted = torch.cat([window, newest], dim=-1)
    return shifted[..., frame_width:]