'numpy.ndarray' object has no attribute 'float'

I'm working through "Getting Started with RLlib" in the documentation.
When I run the example "Example: Querying a policy's action distribution", it works well, but if I change the framework from "tf2" to "torch", the following error occurs:

Traceback (most recent call last):
  File "/home/mjmj/Play/ray/rllib_queaction.py", line 19, in <module>
    logits, _ = policy.model({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
  File "/home/mjmj/anaconda3/envs/ray/lib/python3.9/site-packages/ray/rllib/models/modelv2.py", line 259, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/home/mjmj/anaconda3/envs/ray/lib/python3.9/site-packages/ray/rllib/models/torch/fcnet.py", line 144, in forward
    obs = input_dict["obs_flat"].float()
AttributeError: 'numpy.ndarray' object has no attribute 'float'

Is there something extra I have to do when using torch as the framework?

Hi @keep9oing ,

Thanks for this input. This is expected at the moment.
If you call the policy's model directly, you are effectively calling a torch.nn.Module, which expects torch tensors as input.
You might get away with passing a NumPy array in some cases, if the model converts the input under the hood, but not in this case.
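A minimal sketch of the conversion, for reference. `TinyFCNet` below is a hypothetical stand-in for RLlib's torch fcnet model (it is not RLlib code); like the `forward()` in the traceback, it calls `.float()` on `input_dict["obs_flat"]`, so it fails on a NumPy array and works once the observation is turned into a tensor with `torch.from_numpy(...)`.

```python
import numpy as np
import torch
import torch.nn as nn


class TinyFCNet(nn.Module):
    """Hypothetical stand-in for a torch model like RLlib's fcnet.

    Its forward() calls .float() on the input, which only exists on
    torch tensors, not on NumPy arrays.
    """

    def __init__(self, obs_dim: int = 4, num_outputs: int = 2):
        super().__init__()
        self.linear = nn.Linear(obs_dim, num_outputs)

    def forward(self, input_dict):
        obs = input_dict["obs_flat"].float()  # AttributeError for np.ndarray
        return self.linear(obs), []


model = TinyFCNet()
obs_np = np.array([[0.1, 0.2, 0.3, 0.4]])  # float64 NumPy array

# Passing the raw NumPy array reproduces the error from the traceback.
try:
    model({"obs_flat": obs_np})
except AttributeError as err:
    print(err)  # 'numpy.ndarray' object has no attribute 'float'

# Convert to a float32 torch tensor first, then call the model.
obs_t = torch.from_numpy(obs_np).float()
with torch.no_grad():
    logits, _ = model({"obs_flat": obs_t})
print(logits.shape)  # torch.Size([1, 2])
```

With a real RLlib torch policy, the analogous change would presumably be passing `{"obs": torch.from_numpy(obs_np).float()}` to `policy.model(...)` (an assumption, not checked against every Ray version). If the model's parameters live on a GPU, the tensor also has to be moved to that same device.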

Cheers

Hi, how did you solve this?