Bad inference after perfect training. What am I missing?

So I made a mock-up of the real-world combinatorial optimization problem I am working on, so that I can check the results manually and share it publicly.

The .html file of the notebook is on my GitHub.

From the notebook, you can clearly see that training goes well: all 3 metrics I care about (episode_reward_max, episode_reward_mean, episode_reward_min) are rising. However, when I run inference (cell 45) with compute_single_action or compute_action, I get very low rewards, even though I use checkpointing… The results look random, but I really need the best possible solution to be returned.

Also note that I am new to RLlib and Ray in general, so I might be missing something “obvious” :)

P.S. I am aware of this issue and related topics here on Discuss, but playing with unsquash_action and clip_action did not help :( Those were also the suggestions I got on the Slack channel.

How severe does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

Hey @vlainic,

I can’t say for sure, but here are a couple of things that I find worth looking at:

  • You write that you checkpoint only at the end of training in cell 43, but the code before that looks like you checkpoint at the most promising Trainer.step() call?

    • This could already be the issue, and here is why I think so:
      You chose one rollout worker, which makes two workers together with your driver process.
      If your rollouts are extremely short (28 steps) and your train batch size is also very small (28), then, depending on your reward signal, you might get a random rollout in the middle of training where both workers receive a good reward by “accident”, and that will remain your largest reward until the end of training. In other words, although the episodic reward curve looks good, you should not decide on the basis of a single rollout whether to checkpoint! How long are your episodes? That is an important metric to watch to see whether this is your issue.
    • To check whether this is the problem, try checkpointing only at the end via ray.tune!
  • Even though these graphs look good, it is worth noting that when checkpointing, you want an estimate of your policy’s performance with as little variance as possible. Evaluating over multiple episodes is therefore a better way to choose where to checkpoint.

  • One final thought that might not apply to your case: have a look at your KL loss! Is it low at your checkpoint? If it stays high, you have optimized for a stochastic policy and should not disable exploration during evaluation, since that removes the stochasticity. It’s a rare case; I just wanted to mention it since I don’t fully understand your environment.
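To illustrate the variance point above, here is a minimal, self-contained simulation in plain Python (no Ray involved). It assumes that each candidate checkpoint has some hypothetical true mean reward and that one episode's reward is that mean plus Gaussian noise; the function names and numbers are made up for illustration. Choosing the checkpoint with the best single episode often selects a mediocre one, whereas averaging over many evaluation episodes recovers the genuinely best one.

```python
import random
import statistics


def episode_reward(true_mean: float, noise_std: float, rng: random.Random) -> float:
    """Hypothetical stand-in for one evaluation episode: true mean plus noise."""
    return rng.gauss(true_mean, noise_std)


def pick_by_single_episode(true_means, noise_std, rng):
    """Select the checkpoint whose single sampled episode scored highest."""
    scores = [episode_reward(m, noise_std, rng) for m in true_means]
    return max(range(len(true_means)), key=lambda i: scores[i])


def pick_by_mean(true_means, noise_std, n_episodes, rng):
    """Select the checkpoint with the highest reward averaged over n_episodes."""
    means = [
        statistics.fmean(episode_reward(m, noise_std, rng) for _ in range(n_episodes))
        for m in true_means
    ]
    return max(range(len(true_means)), key=lambda i: means[i])


if __name__ == "__main__":
    rng = random.Random(0)
    # Hypothetical true mean rewards of three checkpoints; index 1 is the best.
    true_means = [0.5, 0.8, 0.6]
    noise_std = 0.5  # large episode-to-episode variance
    picks = [pick_by_single_episode(true_means, noise_std, rng) for _ in range(1000)]
    print("fraction of single-episode picks that are the best checkpoint:",
          picks.count(1) / 1000)
    print("checkpoint picked by the mean over 500 episodes:",
          pick_by_mean(true_means, noise_std, 500, rng))
```

In RLlib terms, the analogue is to run several evaluation episodes per candidate (with exploration disabled if the policy is meant to be deterministic) and compare their mean, rather than trusting episode_reward_max from a single training iteration. RLlib has dedicated evaluation settings such as evaluation_interval for this, but the related key names have changed between versions, so check the docs for the version you use.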

Let us know if this helped!


Hello @arturn,

Thanks for the response. Here is what I understand from your answer:

  • Should train_batch_size be much larger than the episode length? In the current mock-up setup, the episode length is fixed at 28. And what about rollout_fragment_length?
  • How do I evaluate over multiple episodes? Is it with checkpoint_freq=n? I understood that parameter as skipping n-1 iterations and checkpointing every n-th one, not as averaging over the last n.
  • Where can I find the KL loss? Is it total_loss, kl_coeff, or something else?

Today I played a bit with train_batch_size=2800 (which also significantly increased episodes_total), rollout_fragment_length in [280, 560], a larger sgd_minibatch_size, both options for checkpoint_at_end, and different values of checkpoint_freq and keep_checkpoints_num… but inference is still the same…
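A quick back-of-the-envelope check of the batch arithmetic in this thread, assuming (as stated above) that every episode lasts exactly 28 steps. The helper names are hypothetical and are not part of RLlib.

```python
EPISODE_LEN = 28  # fixed episode length in the mock-up


def episodes_per_batch(train_batch_size: int, episode_len: int = EPISODE_LEN) -> float:
    """Number of complete episodes that fit into one training batch."""
    return train_batch_size / episode_len


def fragments_per_batch(train_batch_size: int, rollout_fragment_length: int) -> float:
    """How many rollout fragments are needed to fill one training batch."""
    return train_batch_size / rollout_fragment_length


if __name__ == "__main__":
    for batch in (28, 2800):
        print(f"train_batch_size={batch}: {episodes_per_batch(batch):g} episode(s) per batch")
    for frag in (280, 560):
        print(f"rollout_fragment_length={frag}: {fragments_per_batch(2800, frag):g} fragments "
              f"of {frag / EPISODE_LEN:g} episodes each")
```

With train_batch_size=28, each update is based on a single episode's worth of experience, which is exactly the high-variance situation described in the reply above; with 2800, each update covers 100 episodes.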

Why is a short episode a problem, btw?

[SOLUTION] I had to change the observations, i.e. the model input.

Initially, the input was a row of the mask, which changes with every action. Example:

  • Start (cell 10): [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.]
  • From cell 48 on:
    • Action = 0 → [1., 0., 0., -1., -1., -1., -1., -1., -1., -1., -1.]
    • Action = 5 → [1., 0., 0., 0., 0., 1., -1., -1., -1., -1., -1.]
    • Action = 6 → [1., 0., 0., 0., 0., 1., 1., 0., 0., -1., -1.]
    • Action = 9 → well… you can guess it :)

However, that did not work. What finally worked best was simply using the sequence of the last actions taken. Compared with the previous example:

  • Start: [-1, -1]
  • Action = 0 → [0, -1]
  • Action = 5 → [5, 0]
  • Action = 6 → [6, 5]
  • Action = 9 → guess again!
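The new observation can be sketched as a fixed-size window over the action history, most recent action first, padded with -1 until enough actions have been taken. This is a minimal reconstruction from the example above, assuming a window of two actions; last_actions_obs is a hypothetical helper, not code from the notebook.

```python
from typing import List, Sequence

PAD = -1  # placeholder meaning "no action taken yet"


def last_actions_obs(history: Sequence[int], window: int = 2) -> List[int]:
    """Return the last `window` actions, newest first, padded with PAD."""
    # Guard window == 0: history[-0:] would return the whole history.
    recent = list(reversed(history[-window:])) if window > 0 else []
    return recent + [PAD] * (window - len(recent))


if __name__ == "__main__":
    history: List[int] = []
    print("Start:", last_actions_obs(history))
    for action in (0, 5, 6):
        history.append(action)
        print(f"Action = {action} -> {last_actions_obs(history)}")
```

Inside an environment, this list would typically be wrapped in a NumPy array whose bounds match the declared observation space (here, from -1 up to the largest action index) and returned from reset() and step().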

I am not sure why the model found it easier to learn from the latter inputs, but it worked!

P.S. From this, it looks like the problem would be better approached with LSTM/RNN models, but the problem above is heavily simplified; in reality, there will be more side inputs besides this action sequence. Anyway, that is beyond the scope of this issue…
