CQL restore question - td_error spikes after restore

How severely does this issue affect your experience of using Ray?

  • Low: It annoys or frustrates me for a moment.

When I resume a CQL Tuner job, the policy's tower_stats['q_t'] entry is an all-zero tensor.
After calling tuner.fit(), the td_error spikes to roughly the square of mean_q. I suspect the Bellman error function is computing against these zeroed q_t values, producing an artificially high error.
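For what it's worth, the arithmetic is consistent with that: if q_t is zeroed while the TD target is still on the scale of mean_q, a squared Bellman error lands near mean_q ** 2. A toy illustration (all values made up):

import torch

mean_q = 50.0
q_t = torch.zeros(32)                  # what tower_stats shows after the restore
td_target = torch.full((32,), mean_q)  # target network still predicts ~mean_q
td_error = (q_t - td_target) ** 2      # squared Bellman error
print(td_error.mean())                 # tensor(2500.) == mean_q ** 2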

Is there a way to restore a Tuner job and clear the tower_stats?
How can I prevent the td_error from spiking after a restore?

import os
from ray import tune

# Restore the interrupted Tuner run from its checkpoint directory.
tuner = tune.Tuner.restore(
    INPUT_CHECKPOINT_DIR,
    trainable=trainable,
    resume_errored=True,
)

# Pick the checkpoint with the highest mean Q-value and write it to disk.
best_result = tuner.get_results().get_best_result(
    metric="info/learner/default_policy/learner_stats/mean_q", mode="max"
)
best_checkpoint = best_result.checkpoint
best_checkpoint.to_directory(os.path.join(MODEL_DIR, 'checkpoint'))

# Rebuild the algorithm from the same config and load the checkpoint into it.
algorithm = config.build()
algorithm.restore(os.path.join(MODEL_DIR, 'checkpoint'))

policy = algorithm.get_policy()
print(policy.model_gpu_towers[0].tower_stats['q_t'])

Result:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<SqueezeBackward1>)
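
The naive workaround I'm considering is clearing those cached stats by hand right after the restore. A minimal sketch, assuming tower_stats is a plain dict on each tower model and that dropping the stale entries is safe:

# Hypothetical workaround: drop stale cached tower stats after restoring.
for tower in policy.model_gpu_towers:
    tower.tower_stats.clear()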

@Walt_Mayfield Could you check whether this issue also occurs when you run this without Tune?

E.g.

# Train briefly, checkpoint, and tear the algorithm down.
algo = config.build()
for _ in range(10):
    algo.train()
checkpoint = algo.save(checkpoint_dir)
algo.stop()
del algo

# Rebuild from the same config, restore, and inspect the cached q_t stats.
algo = config.build()
algo.restore(checkpoint)
policy = algo.get_policy()
print(policy.model_gpu_towers[0].tower_stats['q_t'])
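
To see whether the spike reproduces without Tune, it would also help to run one more training step after the restore and inspect the learner stats (the key path below is taken from the metric string in your post; exact stat names can differ across RLlib versions):

result = algo.train()
print(result["info"]["learner"]["default_policy"]["learner_stats"])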