Jump-Start Reinforcement Learning

Hello,

I wanted to try something out: I want to “Jump-Start” my agent as described here: [2204.02372] Jump-Start Reinforcement Learning
The way they go about it is that, some of the time, they inject actions from a source other than the model. You would therefore store the injected action and also send it to the environment. As time goes on, you let the model take more control.

I was wondering how I would do this. I was trying to find where in the code the actions passed to step() are appended for training, since I figured that would be the best spot, but I can’t seem to find it.

I am currently using PPOConfig(), not sure if this matters.

It seems something like this was discussed here before, but it doesn’t look like anything came of it.

Hi @Samuel_Fipps,

Welcome to the forum.

This is how I would recommend designing jump-start in rllib.

  1. I would create a custom environment that wraps the underlying environment and has the known transitions for the guide policy. It also keeps track of which h from H is currently being used and when it should switch from the guide actions to the exploration actions. While it is in the guide regime it places the guide action in the info dictionary. You could also always place the guide action in the info dictionary and add a second entry indicating whether the policy should use it; this would let you define custom metrics comparing the known action and the exploration action. (A short sketch of this wrapper follows the list.)

  2. Create a custom policy that either produces the guide action or uses the neural network policy, depending on what is in the info dictionary. The policy does not return actions directly but rather the distribution inputs (“logits”) for each action. For a categorical action you could return the one-hot encoding of the action, and for a box action you could return a mean of the desired value and a std of 0.0001. (See the sketch at the end of this post.)

  3. Create an rllib callback (note: not a tune callback, they are different) and use on_evaluate_end to tell the environments to adjust h based on the evaluation results. This will probably be the trickiest part to figure out because you are going to have to use the algorithm’s data structures and APIs to find and modify the environments on the workers. I usually have to do this in an interactive debug session.
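
Here is a minimal sketch of the wrapper idea from step 1, assuming a gymnasium-style environment. `JumpStartWrapper`, `guide_policy`, and `set_guide_horizon` are names I am making up for illustration, not RLlib APIs:

    # Illustrative only: `guide_policy` is a placeholder callable (obs -> action).
    # `guide_horizon` is the current h: how many steps to follow the guide.
    import gymnasium as gym


    class JumpStartWrapper(gym.Wrapper):
        def __init__(self, env, guide_policy, guide_horizon):
            super().__init__(env)
            self.guide_policy = guide_policy
            self.guide_horizon = guide_horizon
            self._t = 0

        def reset(self, **kwargs):
            obs, info = self.env.reset(**kwargs)
            self._t = 0
            info["guide_action"] = self.guide_policy(obs)
            info["use_guide"] = self._t < self.guide_horizon
            return obs, info

        def step(self, action):
            obs, reward, terminated, truncated, info = self.env.step(action)
            self._t += 1
            # Always expose the guide action plus a flag; the policy decides
            # whether to use it, and you can log custom comparison metrics.
            info["guide_action"] = self.guide_policy(obs)
            info["use_guide"] = self._t < self.guide_horizon
            return obs, reward, terminated, truncated, info

        def set_guide_horizon(self, h):
            # Called from a callback (e.g. on_evaluate_end) to shrink h over time.
            self.guide_horizon = h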

Hopefully this helps. I am happy to follow up if something doesn’t make sense or you get stuck.

Note: you could also probably use observation and action connectors for 1 and 2. But I don’t use them so I am not sure how I would recommend doing it with those.
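
For step 2, here is a rough sketch of how a known guide action could be turned into distribution inputs, assuming RLlib’s default Categorical (logits) and DiagGaussian ([means, log-stds]) conventions. Treat it as illustrative, not a drop-in policy:

    import torch


    def guide_action_to_dist_inputs(action, space_type, num_actions=None):
        # Turn a known guide action into "logits" that the action distribution
        # will sample (near-)deterministically.
        if space_type == "categorical":
            # A large gap between logits puts ~all softmax mass on the guide action.
            logits = torch.full((1, num_actions), -1e3)
            logits[0, action] = 1e3
            return logits
        elif space_type == "box":
            # DiagGaussian inputs are [means, log_stds]; log(1e-4) ~= -9.2 gives
            # a std of ~0.0001, i.e. effectively the guide action itself.
            means = torch.as_tensor(action, dtype=torch.float32).reshape(1, -1)
            log_stds = torch.full_like(means, -9.2)
            return torch.cat([means, log_stds], dim=-1)
        raise ValueError(space_type)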

I might do something a little easier for a proof of concept for now. I think I can override the forward method and modify the direct output of the model, making sure the outputs are correctly propagated through the rest of the code.

So basically if I want all my actions to be 0.4 I can do this:

    @override(TorchModelV2)
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> (TensorType, List[TensorType]):
        obs = input_dict["obs_flat"].float()
        self._last_flat_in = obs.reshape(obs.shape[0], -1)
        self._features = self._hidden_layers(self._last_flat_in)
        logits = self._logits(self._features) if self._logits else self._features
        if self.free_log_std:
            logits = self._append_free_log_std(logits)

        # Modify the logits to produce the desired action
        desired_mean = 0.4  # Target action value for both dimensions
        deterministic_variance = -100.0  # Very negative log-std, so actions are near-deterministic

        logits[:, :2] = desired_mean  # Set the means to 0.4
        logits[:, 2:] = deterministic_variance  # Set the log-stds to -100.0

        print(logits, flush=True)
        return logits, state

I think for testing and seeing how well the “jumpstarting” works I am going to do this.

Do you see any issue with this? Or will messing with the variance mess up the training?

Hi @Samuel_Fipps,

If you want to check that the mechanics work out, that should be fine, but you are going to run into two issues when trying to actually learn anything.

The first is that the paper assumes the jump-start trajectories H are useful trajectories. Having the guide policy produce the same action repeatedly probably does not satisfy that assumption.

Second, during the training phase of the learning process you are going to need the exploration policy, which means you need a way to decide, in the model’s forward, whether you are computing rollout logits using the guide (fixed mean and std) or the exploration policy (logits from the NN). I don’t think there is a way in rllib to tell that from within the model’s forward.

Edit: As a proof of concept, a very hacky way you could do this is to check whether the model is in training mode. It should be the case that during rollouts the policy will be set to model.eval() and during the loss calculation it will be in model.train(). You can check this with something like if self.training: ...
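
A tiny, self-contained illustration of that check in plain PyTorch (not tied to RLlib; the guide mean and log-std values are just placeholders):

    import torch
    import torch.nn as nn


    class GuideAwareHead(nn.Module):
        def __init__(self, in_dim=8, num_outputs=4, guide_mean=0.4, guide_log_std=-9.0):
            super().__init__()
            self.net = nn.Linear(in_dim, num_outputs)
            self.guide_mean = guide_mean
            self.guide_log_std = guide_log_std

        def forward(self, obs):
            logits = self.net(obs)
            if not self.training:
                # Rollouts (eval mode): overwrite means and log-stds with the guide.
                logits = logits.clone()
                logits[:, :2] = self.guide_mean
                logits[:, 2:] = self.guide_log_std
            # Loss calculation (train mode): return the network's own logits.
            return logits


    head = GuideAwareHead()
    head.eval()
    print(head(torch.randn(1, 8)))  # guide-injected logits
    head.train()
    print(head(torch.randn(1, 8)))  # the network's learned logits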

Yes, sorry, I meant that I was going to implement the “guided assistant” and the proper mechanism to switch between training and eval mode in the forward() method. I was still trying to figure out where I wanted to do this.

I would just be able to override the forward instead of having to make a whole new policy, correct?

I’ll keep you updated once I get this done and tested. It’s for an agent that flies a jet in a video game, so it might take a while to train once I get there. I already have the agent learning just fine without it, but I want to measure how much better it does with it.

@Samuel_Fipps,

I did not mean to offend. We get a variety of skill levels and experiences here so I try not to assume anything other than what is in the post.

Good luck and yes please share the results if you get some.

Oh no, you’re good, you didn’t offend. I should have some results by the end of the week to update this.

I have some results!

Here is the code I used:

import os
from typing import Dict, List

import torch

# Imports below assume the old ModelV2 API stack in RLlib.
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TensorType


class CustomModel(FullyConnectedNetwork):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        self.total_training_steps = 500_000  # Define the total number of steps for decay
        self.current_step = 0  # Track the current training step
        self.log_file_path = os.path.join(os.getcwd(), "training_steps.txt")

        # Write initialization step to the file
        #with open(self.log_file_path, "a") as log_file:
            #log_file.write(f"Initialized: Current training step = {self.current_step}\n")


    @override(TorchModelV2)
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> (TensorType, List[TensorType]):
        obs = input_dict["obs_flat"].float()
        self._last_flat_in = obs.reshape(obs.shape[0], -1)
        self._features = self._hidden_layers(self._last_flat_in)
        logits = self._logits(self._features) if self._logits else self._features
        if self.free_log_std:
            logits = self._append_free_log_std(logits)


        # Modify the logits to produce the expert-guided action
        deterministic_variance = -5.0  # Very negative log-std, so actions are near-deterministic
        # Probability of using the expert decays linearly over total_training_steps
        switch_probability = max(0.0, 1.0 - self.current_step / self.total_training_steps)
        switch_probability = 0  # NOTE: this overrides the decayed probability above and disables the expert
        #with open(self.log_file_path, "a") as log_file:
            #log_file.write(f"Current training step: {self.current_step}\n")
        if self.current_step < 500_000:
            self.current_step += 1

        if logits.shape == torch.Size([1, 4]) and torch.rand(1).item() < switch_probability and not self.training and self.current_step < 500_000:
            # get expert guided actions
            # logits shape torch.Size([1, 6])
            action1, action2, action3 = self.get_expert_action(obs)
            if action1 is not None:
                #with open(self.log_file_path, "a") as log_file:
                    #log_file.write(f"Inside loop: Current training step = {self.current_step}\n")
                tensor_action1 = torch.tensor(action1, dtype=torch.float32, device=logits.device)
                tensor_action2 = torch.tensor(action2, dtype=torch.float32, device=logits.device)

                tensor_action1 = torch.clamp(tensor_action1, min=-1.0, max=1.0)
                tensor_action2 = torch.clamp(tensor_action2, min=-1.0, max=1.0)
                #action3 = torch.clamp(action3, min=-1.0, max=1.0)

                logits[0][0] = tensor_action1
                logits[0][1] = tensor_action2
                #logits[0][2] = action3 #already a tensor 
                #logits[:, 3:] = deterministic_variance
                logits[:, 2:] = deterministic_variance  # Set the log-stds to -5.0 (near-deterministic)
                

        return logits, state

Note: anything over 5 million steps of expert help caused exploding gradients, so I experimented with different numbers of steps of expert help.

My environment is a jet fighting another jet; each jet just needs to get the target into the firing range of a missile. Right now the AI only controls the heading and altitude; the speed is hard-set.

The red line on the graph is when the expert help was cut off.
Note: I didn’t run evaluation, so I am not sure if my code is eval-safe. This is just how often the AI won during training.

The jump-start results. (I had to make these 2 posts since I can’t upload 2 pictures in one post.)

It looks like the jump-start has helped, in that you have a higher win rate and the variance of the win rate appears much lower. Any idea why the performance degrades over training in both cases?

I feel like a broken record because I point a lot of people in this direction, but there can be some issues training PPO with continuous actions. If you have not seen it, you might want to give this post a read to see if you are experiencing this issue.

I have no idea why it degrades; I am new to reinforcement learning. I have done machine learning for over 8 years, but never reinforcement learning, so all feedback is welcome. Thank you!

Therefore I should change my code to this?

        means = self._logits(self._features) if self._logits else self._features
        clamped_log_stds = torch.clamp(self.log_stds, -1.0, 1.0)
        clamped_log_stds = clamped_log_stds.unsqueeze(0).expand_as(means)
        logits = torch.cat([means, clamped_log_stds], dim=-1)

Thank you so much for pointing that out!

Hey @Samuel_Fipps, I saw that you were looking at clamping the log_stds, like @mannyv has been pointing people to for what seems like years now lol.

So, the actual clamping range of the log_stds depends on your action space. In the example given above from the PPO NaN logits post, I unknowingly clamped the log_stds to (-1.0, 1.0), but I will give you a breakdown here of what it should be (if this is your problem).

If you have your means from the policy logits normalized between (-1, 1), then your log_stds should be clamped like this:

        logits, _ = self.actor_fcnet(input_dict, state, seq_lens)
        means, log_stds = torch.chunk(logits, 2, -1)
        # assuming means are normalized between -1 and 1
        # even if normalized, I have noticed sometimes my actions in the batch are outside these bounds
        means_clamped = torch.clamp(means, -1, 1)
        # this is based on the means being in [-1, 1], so the std_dev range is roughly (0, 1]
        # where exp(-10) and exp(0) give the lower and upper ends of that range
        log_stds_clamped = torch.clamp(log_stds, -10, 0)
        logits = torch.cat((means_clamped, log_stds_clamped), dim=-1)
        self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)

So if the mean of your normalized action space is 0, then your standard deviation’s minimum can be near 0 (but not 0) and its maximum can be 1. Since the network outputs log standard deviations rather than standard deviations, we need clamp bounds whose exp() lands near 0 and 1.

So, what I have concluded is that I can use the following:

  • to get near 0, we can set the lower bound as exp(-10) which is ~ 0.00004539992
  • to get 1, we can set the upper bound as exp(0) which is 1

Therefore, our clamped log_stds bounds should be chosen with reference to the action space; initially I picked (-1, 1) arbitrarily, like you have. That might work, but it doesn’t seem to be correct.
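
A quick sanity check of those bounds (output values are approximate):

    import torch

    log_stds = torch.tensor([-20.0, -10.0, -3.0, 0.0, 2.0])
    stds = torch.exp(torch.clamp(log_stds, -10, 0))
    print(stds)  # ~[4.54e-05, 4.54e-05, 4.98e-02, 1.0, 1.0]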

@mannyv feel free to correct me if you think this is wrong.

I tried to do this and it greatly slowed my training. I’m guessing gradient clipping wouldn’t work for this issue?

However yes, this is my issue: I am getting the PPO NaN logits.

@Samuel_Fipps,

Do you have a sense of what is making it slow? I would not expect that the clamping and shape wrangling would have that large an effect. I suppose if a lot of values are being clamped that could cause a slowdown, because clamped values will not contribute to learning since they do not have any gradients.

I don’t think gradient clipping is going to help based on what I think is happening.

I think the NaN originates from here:

One of those two terms ends up taking log(0), which returns NaN and blows everything up. Clipping a gradient will not fix that.
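
To illustrate the failure mode with a generic diagonal-Gaussian log-probability (not the exact code being referenced): a standard deviation that hits 0 makes both the squared-error term and the log term blow up.

    import math
    import torch

    mean = torch.tensor([0.4])
    std = torch.tensor([0.0])   # a std that has collapsed to zero
    action = torch.tensor([0.3])
    # log N(action | mean, std) = -(action - mean)^2 / (2 std^2) - log(std) - 0.5 log(2 pi)
    log_prob = -((action - mean) ** 2) / (2 * std ** 2) - torch.log(std) - 0.5 * math.log(2 * math.pi)
    print(log_prob)  # tensor([nan]): -inf from the squared term plus +inf from -log(0)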

Just had an idea. Another option you have rather than clamping is to apply a threshold operator and use a straight-through estimator.

Riffing on @tlaurie99’s example, it would be something like this:

logits, _ = self.actor_fcnet(input_dict, state, seq_lens)
means, log_stds = torch.chunk(logits, 2, -1)
# prevent NaN from log(0) by limiting log_stds >= -10
log_stds = torch.where(log_stds >= -10, log_stds, log_stds + (-10 - log_stds).detach())
logits = torch.cat((means, log_stds), dim = -1)
self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)

Can I just use the 2 options that are talked about here?

PPO nan in actor logits - #6 by tlaurie99.

@Samuel_Fipps,

Oh, I thought that was what you were saying was slow. Yeah, you can give those a try.

Well, I did it myself following @tlaurie99’s example. I am hoping that I somehow messed it up or just had something go wrong with my run, and that using the built-in method will work. Thanks for all the help, guys!