'numpy.ndarray' object has no attribute 'float'

I’m working through “Getting Started with RLlib” in the documentation and running the example “Querying a policy’s action distribution”.

It works well with “tf2”, but if I change the framework to “torch”, the following error occurs:

Traceback (most recent call last):
  File "/home/mjmj/Play/ray/rllib_queaction.py", line 19, in <module>
    logits, _ = policy.model({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
  File "/home/mjmj/anaconda3/envs/ray/lib/python3.9/site-packages/ray/rllib/models/modelv2.py", line 259, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/home/mjmj/anaconda3/envs/ray/lib/python3.9/site-packages/ray/rllib/models/torch/fcnet.py", line 144, in forward
    obs = input_dict["obs_flat"].float()
AttributeError: 'numpy.ndarray' object has no attribute 'float'

Is there something I have to do differently when using torch as the framework?

Hi @keep9oing ,

Thanks for this input. This is expected at the moment.
If you call the model of the policy directly, you are effectively calling the torch.nn.Module that expects a torch tensor as input.
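In some cases you might get away with passing a NumPy array, if the model converts its input under the hood, but not here: as the traceback shows, the torch fully connected model calls .float() on the observation, which is a tensor method that NumPy arrays do not have.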
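
To illustrate, here is a minimal sketch that does not use RLlib at all. `TinyModel` is a hypothetical stand-in for the policy's model (the name, layer sizes, and the "obs" key are assumptions made for the example); it only mimics the `.float()` call from the traceback. It shows the NumPy input failing and the same data working once it is wrapped in a torch tensor.

```python
import numpy as np
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the policy's torch model; not RLlib code."""

    def __init__(self, obs_dim=4, num_outputs=2):
        super().__init__()
        self.layer = nn.Linear(obs_dim, num_outputs)

    def forward(self, input_dict):
        # Mirrors the failing line in the traceback: .float() exists on
        # torch tensors but not on NumPy arrays.
        obs = input_dict["obs"].float()
        return self.layer(obs)


model = TinyModel()
obs_np = np.array([[0.1, 0.2, 0.3, 0.4]])

# Passing the NumPy array directly raises the AttributeError from the question.
try:
    model({"obs": obs_np})
    numpy_call_failed = False
except AttributeError:
    numpy_call_failed = True

# Converting to a tensor first avoids the error.
obs_t = torch.from_numpy(obs_np)
logits = model({"obs": obs_t})
```

In your script the analogous change would be to pass `torch.from_numpy(np.array([[0.1, 0.2, 0.3, 0.4]]))` as the "obs" value when calling `policy.model`. I have not run this against your Ray version, and if the model lives on a GPU the tensor would also need to be moved to the same device.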

Cheers