How severely does this issue affect your experience of using Ray?
- High: It blocks me from completing my task.
I trained two multi-agent RL models and saved both as checkpoints. Both are PPO models trained with torch on CPU, and the action space is MultiDiscrete. When I deploy the two models simultaneously (both loaded in the same process), I get very strange action outputs during inference.
The simplified inference code looks like this:
```python
import torch
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.policy import Policy

# Net1 and Net2 are my custom torch model classes used during training (defined elsewhere).
ModelCatalog.register_custom_model("model1", Net1)
ModelCatalog.register_custom_model("model2", Net2)

pol1 = Policy.from_checkpoint('path/to/model1/checkpoint', ['policy1'])
act1 = pol1.compute_single_action(obs=torch.zeros(10), state=torch.zeros(1), explore=False)
print(act1[0])  # [0 5 0 1]

pol2 = Policy.from_checkpoint('path/to/model2/checkpoint', ['policy2'])
act2 = pol2.compute_single_action(obs=torch.zeros(10), state=torch.zeros(1), explore=False)
print(act2[0])  # [2 3 1 0]
```
So far so good. I set `explore=False` and pass a constant all-zeros observation, so I should always get the same output. And now the strange things start.

If I call `pol1.compute_single_action(obs=torch.zeros(10), state=torch.zeros(1), explore=False)` again with the same inputs, the output is completely different, e.g. `[12 0 1 1]`. However, calling `pol2.compute_single_action()` still gives the same, correct result. On the other hand, if I call `Policy.from_checkpoint('path/to/model1/checkpoint', ['policy1'])` again, without even assigning the result to `pol1`, then `pol1.compute_single_action()` gives the correct output again, but `pol2.compute_single_action()` now gives something different.

This is very strange and I don’t know how to resolve it. Why would I need to call `Policy.from_checkpoint()` before computing an action?
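One pattern that would produce exactly this "the checkpoint restored last wins" behavior is state shared between the two models: for example, a layer built once at module level, or stored as a class attribute on a base class that both Net1 and Net2 inherit from, so that both models hold the same parameter objects. PyTorch's `load_state_dict` copies values into the existing parameters in place, so, assuming the policy restore goes through it, restoring either checkpoint would overwrite the weights the other policy is using. This is only a hypothesis, not a confirmed diagnosis. The stdlib-only sketch below illustrates the idea with hypothetical `BuggyNet*`/`FixedNet*` classes and plain dicts standing in for torch modules; it is not RLlib code.

```python
class SharedBase:
    # Hypothetical bug: one class-level dict shared by every subclass and
    # instance that does not define its own `weights`.
    weights = {}

    def load(self, state):
        # In-place update: for BuggyNet* this mutates SharedBase.weights.
        self.weights.update(state)

    def act(self, obs):
        return tuple(self.weights["w"] * x + self.weights["b"] for x in obs)


class BuggyNet1(SharedBase):
    pass


class BuggyNet2(SharedBase):
    pass


class IsolatedBase(SharedBase):
    def __init__(self):
        # Per-instance dict created in __init__: loading one model
        # cannot touch another model's weights.
        self.weights = {}


class FixedNet1(IsolatedBase):
    pass


class FixedNet2(IsolatedBase):
    pass


def restore_both(cls1, cls2, obs=(0, 1)):
    """Restore two models one after the other, then query both."""
    m1, m2 = cls1(), cls2()
    m1.load({"w": 1, "b": 0})  # stands in for restoring checkpoint 1
    m2.load({"w": 2, "b": 1})  # stands in for restoring checkpoint 2
    return m1.act(obs), m2.act(obs)


print(restore_both(BuggyNet1, BuggyNet2))  # ((1, 3), (1, 3)): model 1 now uses model 2's weights
print(restore_both(FixedNet1, FixedNet2))  # ((0, 1), (1, 3))
```

In this sketch the fix is simply to create the container inside `__init__`; the torch analogue would be constructing every layer inside each model's own `__init__` rather than reusing a shared object.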
This is the code version of the problem described above:
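```python
# Continuing from the snippet above (pol1 and pol2 are both loaded):
act1 = pol1.compute_single_action(obs=torch.zeros(10), state=torch.zeros(1), explore=False)
print(act1[0])  # [12 0 1 1]  <- differs from the first call

# Restore model 1 again without assigning the result:
Policy.from_checkpoint('path/to/model1/checkpoint', ['policy1'])

act1 = pol1.compute_single_action(obs=torch.zeros(10), state=torch.zeros(1), explore=False)
print(act1[0])  # [0 5 0 1]  <- correct again

act2 = pol2.compute_single_action(obs=torch.zeros(10), state=torch.zeros(1), explore=False)
print(act2[0])  # [7 4 0 1]  <- pol2 is now wrong (was [2 3 1 0])
```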
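To narrow this down systematically, a small check can record each policy's action immediately after it is restored and compare it with the action once all restores have happened. The sketch below is stdlib-only; `changed_after_later_loads` and `FakePolicy` are hypothetical names, and `FakePolicy` deliberately shares one dict between instances to simulate the suspected problem.

```python
from typing import Callable, Dict, Hashable, List


def changed_after_later_loads(
    loaders: Dict[str, Callable[[], object]],
    probe: Callable[[object], Hashable],
) -> List[str]:
    """Restore policies one by one, record each policy's action right after
    its own restore, then probe all of them again once every restore is done.
    Returns the names whose action changed in between."""
    policies: Dict[str, object] = {}
    reference: Dict[str, Hashable] = {}
    for name, load in loaders.items():
        policies[name] = load()
        reference[name] = probe(policies[name])
    return sorted(
        name for name, pol in policies.items() if probe(pol) != reference[name]
    )


# Hypothetical stand-in: every FakePolicy reads the same module-level dict,
# simulating weights that are accidentally shared between two models.
_SHARED: Dict[str, int] = {}


class FakePolicy:
    def __init__(self, weights: Dict[str, int]):
        _SHARED.update(weights)  # "restoring" overwrites the shared state

    def act(self, obs):
        return tuple(_SHARED["w"] * x + _SHARED["b"] for x in obs)


unstable = changed_after_later_loads(
    {
        "policy1": lambda: FakePolicy({"w": 1, "b": 0}),
        "policy2": lambda: FakePolicy({"w": 2, "b": 1}),
    },
    probe=lambda pol: pol.act([0, 1]),
)
print(unstable)  # ['policy1']
```

With the real policies, each loader could be something like `lambda: Policy.from_checkpoint('path/to/model1/checkpoint', ['policy1'])` (adjusted to however your code obtains the Policy object), and the probe `lambda pol: tuple(pol.compute_single_action(obs=torch.zeros(10), state=torch.zeros(1), explore=False)[0].tolist())`, assuming the returned action is a NumPy array. A non-empty result would mean that restoring one policy changed another policy's behavior.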