Policy.compute_single_action() wrong outputs

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

I have trained two multi-agent RL models and saved them via checkpoints. Both are PPO models with multi-discrete action spaces, trained with torch on CPU. When I now try to deploy the two models simultaneously, I get really weird action outputs during inference.
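For context, the training and checkpointing for each of the two models looked roughly like this (a minimal sketch using the Ray 2.x PPOConfig API; the env name, policy mapping function, and iteration count are simplified placeholders, not my exact setup):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .framework("torch")
    .environment(env="my_multi_agent_env")  # placeholder for my actual env
    .multi_agent(
        policies={"policy1"},  # "policy2" in the second model's config
        policy_mapping_fn=lambda agent_id, *args, **kwargs: "policy1",
    )
    .training(model={"custom_model": "model1"})  # custom model registered via ModelCatalog
)
algo = config.build()
for _ in range(10):  # simplified training loop
    algo.train()
algo.save('path/to/model1/checkpoint')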

The simplified inference code looks like this:

ModelCatalog.register_custom_model("model1", Net1)
ModelCatalog.register_custom_model("model2", Net2)

pol1 = Policy.from_checkpoint( 'path/to/model1/checkpoint', ['policy1'])
act1 = pol1.compute_single_action(obs=torch.zeros(10), state=torch.zeros(1), explore=False)
print(act1[0]) # [0 5 0 1]

pol2 = Policy.from_checkpoint( 'path/to/model2/checkpoint', ['policy2'])
act2 = pol2.compute_single_action(obs=torch.zeros(10), state=torch.zeros(1), explore=False)
print(act2[0]) # [2 3 1 0]

So far, so good. With explore=False and a constant all-zeros observation, every call should be deterministic, so I should always get the same output. And now the strange things start:

If I call pol1.compute_single_action(obs=torch.zeros(10), state=torch.zeros(1), explore=False) again with exactly the same inputs, the output is something completely different, e.g. [12 0 1 1], while pol2.compute_single_action() still gives the same, correct result. On the other hand, if I call Policy.from_checkpoint('path/to/model1/checkpoint', ['policy1']) again (without even assigning the return value to pol1), then pol1.compute_single_action() once more produces the correct output, but now pol2.compute_single_action() gives something different. I don't understand this behavior and don't know how to resolve it. Why would I need to call Policy.from_checkpoint() right before computing an action?

This is the code version of the problem described above:

act1 = pol1.compute_single_action(obs=torch.zeros(10), state=torch.zeros(1), explore=False)
print(act1[0])  # [12 0 1 1]  <- wrong, should be [0 5 0 1]

# Reloading the checkpoint (return value deliberately discarded) ...
Policy.from_checkpoint('path/to/model1/checkpoint', ['policy1'])
# ... makes pol1 produce the correct output again:
act1 = pol1.compute_single_action(obs=torch.zeros(10), state=torch.zeros(1), explore=False)
print(act1[0])  # [0 5 0 1]  <- correct again

# ... but now pol2 is off:
act2 = pol2.compute_single_action(obs=torch.zeros(10), state=torch.zeros(1), explore=False)
print(act2[0])  # [7 4 0 1]  <- wrong, should be [2 3 1 0]
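The only pattern that gives me stable outputs right now is to reload the policy from its checkpoint immediately before every inference call, which is obviously far too expensive for real deployment. A minimal sketch of that workaround (compute_action_fresh is just an illustrative helper name, not an RLlib API):

def compute_action_fresh(checkpoint_path, policy_id, obs, state):
    # Reload the policy right before inference, since a fresh
    # from_checkpoint() call is what restores the correct outputs above.
    pol = Policy.from_checkpoint(checkpoint_path, [policy_id])
    return pol.compute_single_action(obs=obs, state=state, explore=False)

act1 = compute_action_fresh('path/to/model1/checkpoint', 'policy1',
                            torch.zeros(10), torch.zeros(1))
print(act1[0])  # [0 5 0 1], stable across repeated calls

Surely this can't be the intended usage. Why do the two loaded policies interfere with each other at all?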