Jump-Start Reinforcement Learning

So I enabled the two options instead of implementing them myself:

{"log_std_clip_param": 1}
{"free_log_std": True}

It didn’t seem to slow down the training this time (as in how long it takes to process the data), but now my model just didn’t learn anything.

Update: I commented out {"free_log_std": True} for now and set log_std_clip_param to 20; if I still get NaNs I can work down from there.
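For anyone else hitting this, a rough sketch of where these settings live (assuming a recent RLlib with the PPOConfig builder and that log_std_clip_param is supported as a model config key in your version; the env name is just a placeholder):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("Pendulum-v1")  # placeholder continuous-action env
    .training(
        model={
            # "free_log_std": True,     # commented out for now, as above
            "log_std_clip_param": 20,   # start high, work down if NaNs persist
        },
    )
)
algo = config.build()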

Hey @Samuel_Fipps, that seems to be a good strategy. Setting {"free_log_std": True} makes the log_stds a parameter of the model, and just like you said, I found that the agent normally doesn’t learn anything. As you lower log_std_clip_param you’ll hopefully find a spot where the agent learns but no NaNs appear. What sort of action space are you working with? Is it normalized?
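As a side note, here is a toy illustration (not RLlib’s actual implementation) of what "free" log_stds mean: the log_std becomes a learned, state-independent parameter appended to the mean output, instead of a second head computed from the observation.

import torch
import torch.nn as nn

class FreeLogStdHead(nn.Module):
    """Toy policy head with state-independent (free) log_stds."""

    def __init__(self, in_dim: int, act_dim: int):
        super().__init__()
        self.mean = nn.Linear(in_dim, act_dim)
        # One learned log_std per action dimension, shared across all observations.
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        mean = self.mean(features)
        # Same layout as the usual [mean, log_std] concatenation.
        return torch.cat([mean, self.log_std.expand_as(mean)], dim=-1)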

best,

Tyler

@tlaurie99, @Samuel_Fipps

This is interesting to know. This is the way cleanrl and sb3 handle it. I wonder what is different in their implementation. Do either of you have a setup you could share with me? I would like to experiment with a few things, but I don’t have a problem handy that is experiencing this issue.

I have done some experimenting with hand-coded policies designed to produce the NaNs, and I can confidently say that a log_std value less than -25 will produce NaNs in the backward pass. I set the log_std_clip_param to -12 and even this did not prevent NaNs.
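If anyone wants to poke at this themselves, a minimal probe along these lines (a sketch of the setup, not my exact test harness) looks like this; the exact threshold will depend on the dtype and on how far the sampled action sits from the mean:

import torch

# Build the action distribution the way the Torch policy does,
# Normal(mean, exp(log_std)), with a very negative log_std, then check
# whether the log-prob and its gradients stay finite.
mean = torch.zeros(1, requires_grad=True)
log_std = torch.tensor([-30.0], requires_grad=True)

dist = torch.distributions.Normal(mean, torch.exp(log_std))
logp = dist.log_prob(torch.tensor([0.1]))  # an action far outside the tiny std
logp.sum().backward()

print("logp finite:", torch.isfinite(logp).all().item())
print("mean grad finite:", torch.isfinite(mean.grad).all().item())
print("log_std grad finite:", torch.isfinite(log_std.grad).all().item())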

One way I did manage to prevent the NaNs was to make the following modification, depending on whether you are using RLModules or not:

Without rl_module:

self.dist = torch.distributions.normal.Normal(mean, torch.exp(log_std) + 0.00001)

With rl_module:

return torch.distributions.normal.Normal(loc, scale + 0.00001)

Here the value of 0.00001 is an epsilon to prevent the NaNs. It corresponds to a log_std of -11.512925464970229; I picked it because it effectively floors the std around that value while having no appreciable effect on the distribution for larger stds.
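A quick way to see where that number comes from: the epsilon acts as a floor on the std, and in log-space that floor is log(1e-5).

import math

# The epsilon added to the scale floors the std near 1e-5, which in
# log-space is log(1e-5) ≈ -11.5129.
print(math.log(1e-5))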

I have verified that it does eliminate the NaNs in my cases, but I do not actually have a learning problem with this issue, so I do not know how it will affect learning. I would expect that it would not, but :man_shrugging:.

I can try that at some point. Have you tried something like this (see below)? Sadly, my setup is not easily shareable.

    @override(TorchModelV2)
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> (TensorType, List[TensorType]):
        obs = input_dict["obs_flat"].float()
        self._last_flat_in = obs.reshape(obs.shape[0], -1)
        self._features = self._hidden_layers(self._last_flat_in)
        logits = self._logits(self._features) if self._logits else self._features
        if self.free_log_std:
            logits = self._append_free_log_std(logits)
 
 
        # ---------------------------------------------------------------------
        # Check for NaNs and replace them with a small constant (e.g. 1e-5).
        # ---------------------------------------------------------------------
        if torch.isnan(logits).any():
            logits[torch.isnan(logits)] = 1e-5

        return logits, state

Or checking this before it hits the log(0):

logp_ratio = torch.exp(
    curr_action_dist.logp(batch[Columns.ACTIONS]) - batch[Columns.ACTION_LOGP]
)

Checking for zeros and replacing them with 0.000001? Just a small number.
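One variant of that idea is sketched below (this is not RLlib’s actual loss code, just an assumption of how a guard could look): instead of replacing exact zeros after the fact, clamp the log-prob difference before exponentiating, so a -inf log-prob (a log(0)) cannot blow the ratio up into 0, inf, or NaN.

import torch

def safe_logp_ratio(curr_logp: torch.Tensor, old_logp: torch.Tensor,
                    clip: float = 20.0) -> torch.Tensor:
    # Clamp the difference of log-probs before exp(), so a -inf on either
    # side cannot produce a 0, inf, or NaN ratio.
    diff = torch.clamp(curr_logp - old_logp, min=-clip, max=clip)
    return torch.exp(diff)

# Example: an old log-prob of -inf would otherwise give ratio = inf.
print(safe_logp_ratio(torch.tensor([-1.0]), torch.tensor([float("-inf")])))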

This seems promising; my model needs a lot of steps to learn, so I won’t know until Monday.

Also, something I noticed (and mentioned already): log_std_clip_param can slow down my progress by 200%. I’m guessing it’s having to do a lot of clipping during training? Not too sure. I gather 51,200 steps before training, so that is quite a bit to loop through and clip. However, it still shouldn’t slow it down that much, unless the way it goes about it is not very efficient.

For example, numpy.append has to recreate the array every time you call it, so I am wondering if something similar is happening here.
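A quick illustration of the numpy.append behavior (this is just a generic NumPy fact, nothing RLlib-specific):

import numpy as np

# np.append always allocates a new array and copies the old data, so calling
# it in a loop costs a full copy per call and O(n^2) overall.
a = np.arange(3)
b = np.append(a, 4)
print(np.shares_memory(a, b))  # False: b is a fresh copy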

@Samuel_Fipps,

Fingers crossed.

I am not sure why log_std_clip_param is slowing down training so much. The torch clamp operation itself is not very computationally intensive, so I would not expect that to be the cause.

I am kind of taking a shot in the dark here, but my guess would be that there are a lot of instances being clamped. The clamping operation is not expensive, but a clamped value gets a zero gradient, so it kills the gradient for that sample. Any sample that is clamped will not be used to train any of the layers below it. That is my guess as to why training slows.
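A small check of the gradient point (a generic torch.clamp fact, not RLlib-specific): elements that land outside the clamp range get a zero gradient.

import torch

x = torch.tensor([-30.0, -5.0, 2.0], requires_grad=True)
y = torch.clamp(x, min=-12.0, max=12.0)
y.sum().backward()
# The first element was clamped, so its gradient is zero and that sample no
# longer trains the layers that produced it.
print(x.grad)  # tensor([0., 1., 1.])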