Hi, thanks for the answer.
I found it easier to flatten the observations this way:
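import torch
from typing import Dict

def flatten_obs(obs_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
    # Flatten each (batch, ...) tensor to (batch, features), then concatenate along the feature dim.
    # Entries are taken in dict insertion order, which fixes the layout of the output vector.
    tensors = [obs.flatten(start_dim=1) for obs in obs_dict.values()]
    return torch.cat(tensors, dim=1)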
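A quick way to sanity-check the flattening is to run it on a small batch. The key names and shapes below are hypothetical, chosen only for illustration. This sketch flattens every entry to (batch, features) before concatenating, so non-vector observations such as images also fit; plain torch.cat(..., dim=1) on the raw tensors would fail when their ranks differ.

```python
import torch
from typing import Dict

def flatten_obs(obs_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
    # Flatten every (batch, ...) entry to (batch, features), then join along the feature dim.
    return torch.cat([obs.flatten(start_dim=1) for obs in obs_dict.values()], dim=1)

# Hypothetical batch of 4 observations: a 3-d vector and a 1x2x2 "image".
obs = {
    "position": torch.randn(4, 3),
    "image": torch.randn(4, 1, 2, 2),
}
flat = flatten_obs(obs)
print(flat.shape)  # torch.Size([4, 7])
```

The first 3 columns hold "position" and the remaining 4 hold the flattened "image", following the dict's insertion order.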