How to use my pretrained model as policy and value network

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

Hi, I’m relatively new to RL, but thanks to RLlib I’m getting a bit more familiar with it.

I am doing a small project where I train my agent in a supervised fashion (more like behavior cloning) and then want to further train this network using RL.

I have separately trained a policy and a value function (they also have different architectures), and now I want to load the model parameters from the pretrained networks and further train them in a reinforcement-learning fashion.

Is this possible using RLlib?

Thank you and have a good day!

Cheers,
Taehyoung

Hi @cubpaw,

We have recently rolled out a new module and learning stack inside RLlib (RLModules) that gives you a lot of flexibility in achieving what you just explained. Here is the example code.

We recommend using the nightly releases if you want to use the RLModule features.

Having said that, it is still possible to do this in the current stable release; it just requires more work. It involves two steps:

  1. You need to define a custom model that initializes the actor and value function using the architectures they were pretrained with. At this stage, make sure that you can train an algorithm with this custom model via RLlib. Example:
    ray/rllib/examples/custom_rnn_model.py at master · ray-project/ray · GitHub
  2. You need to use the callbacks’ on_algorithm_init hook to load the pretrained weights into your custom model during initialization (a minimal sketch follows below the list).
    Related discourse q: Updating policy_mapping_fn while using tune.run() and restoring from a checkpoint - #3 by Muff2n
    Related example: ray/rllib/examples/restore_1_of_n_agents_from_checkpoint.py at master · ray-project/ray · GitHub
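
For step 2, here is a minimal sketch, assuming a torch policy and a checkpoint that is simply a torch-saved state_dict (the path and the state_dict key below are placeholders for your own):

import torch
from ray.rllib.algorithms.callbacks import DefaultCallbacks


class LoadPretrainedWeights(DefaultCallbacks):
    def on_algorithm_init(self, *, algorithm, **kwargs):
        # Load the pretrained weights into the torch model behind the local
        # (default) policy. Path and key names below are placeholders.
        policy = algorithm.get_policy()
        checkpoint = torch.load("/path/to/pretrained_checkpoint.pt")
        policy.model.load_state_dict(checkpoint["state_dict"], strict=False)
        # Broadcast the updated weights to all remote rollout workers.
        algorithm.workers.sync_weights()

You would then register this class on the config via .callbacks(LoadPretrainedWeights) before building the algorithm.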

cc @avnishn


Thanks for the input! I’ll first try with the current stable release and then move over to the nightly releases!

Hi @kourosh, I just wanted to ask you a few more questions regarding the implementation I’m trying to do. In a nutshell, I have trained a policy which roughly does what I want it to do (take actions) and a critic which evaluates the sampled action based on the current state.

So the input sizes for my actor and critic are as follows:

actor - (B, C, H, W), with B being the batch size; the input is an image. The actor also outputs a 1-channel image, but this is flattened for action sampling.

critic - (B, C+1, H, W); the sampled action should be reshaped back into an image array and then concatenated with the current state, and the critic should output a reward estimate for that state-action pair.

To achieve this, I have created a custom model class as follows:

import torch
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

# PretrainedActor / PretrainedCritic are my own pretrained network classes, defined elsewhere.


class VoxelgymModel(TorchModelV2, nn.Module):
    """Example of a PyTorch custom model."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kwargs):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        # access custom_model_config directly from model_config
        custom_model_config = model_config.get('custom_model_config', {})

        # get pretrained paths from custom_model_config
        pretrained_actor_path = custom_model_config.get('pretrained_actor_path')
        pretrained_critic_path = custom_model_config.get('pretrained_critic_path')

        if pretrained_actor_path is None or pretrained_critic_path is None:
            raise ValueError("Both pretrained actor and critic models must be provided.")

        print(pretrained_actor_path)
        print(pretrained_critic_path)

        # Instantiate model
        self.pretrained_actor = PretrainedActor() 
        self.pretrained_critic = PretrainedCritic()

        # Load state_dicts from checkpoints
        actor_checkpoint = torch.load(pretrained_actor_path)
        critic_checkpoint = torch.load(pretrained_critic_path)

        self.pretrained_actor.load_state_dict(actor_checkpoint['state_dict'])
        self.pretrained_critic.load_state_dict(critic_checkpoint['state_dict'])        
        
    # Note: both pretrained networks return SemanticSegmenterOutput objects:
    #   Actor  -> SemanticSegmenterOutput(loss=loss, logits=logits,
    #                                     hidden_states=None, attentions=None)
    #   Critic -> SemanticSegmenterOutput(loss=loss, logits=reward,
    #                                     hidden_states=None, attentions=None)

    def forward(self, input_dict, state, seq_lens):
        input_dict["obs"] = input_dict["obs"].float().to('cuda')
        # action_output is a SemanticSegmenterOutput
        action_output = self.pretrained_actor(input_dict["obs"])
        logits = action_output.logits
        # Flatten the (B, 1, H, W) action image to (B, H*W) for action sampling.
        flattened_logits = logits.view(logits.size(0), -1)
        # Cache what value_function() will need later in the same pass.
        self.input_dict = input_dict
        self.actions = flattened_logits
        return flattened_logits, state

    def value_function(self):
        input_dict = self.input_dict  # cached during the forward pass
        actions = self.actions  # cached during the forward pass

        input_dict["obs"] = input_dict["obs"].float().to('cuda')
        # The flat actions have to be reshaped back into a 1-channel image
        # before they can be concatenated with the observation along dim=1.
        reshaped_actions = actions.view(-1, 1, 42, 42)
        augmented_obs = torch.cat((input_dict["obs"], reshaped_actions), dim=1)
        # value_output is the logits field of a SemanticSegmenterOutput
        value_output = self.pretrained_critic(augmented_obs).logits
        return value_output.reshape(-1)
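
Before building the algo I also register this custom model with RLlib’s ModelCatalog so it can be referenced by name in the config:

from ray.rllib.models import ModelCatalog

# Register the custom model under the name used in the training config below.
ModelCatalog.register_custom_model("voxelgym_model", VoxelgymModel)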

Then I use the following code to build the algo:

algo = (
    DDPGConfig()
    .resources(num_gpus=1)
    .framework('torch')
    .environment(env="voxel-v0")
    .training(
        model={
            "custom_model": "voxelgym_model",
            "custom_model_config": {"pretrained_actor_path": actor_checkpoint_file_path,
                                    'pretrained_critic_path': critic_checkpoint_file_path
                                    },
            
        }
    )
    .build()
)

I think the custom model is defined well up to the forward pass, but since I’m trying to use the sampled action within the value function, things get problematic there. Is there an algorithm that uses the state-action value directly to estimate the reward? Is there a way to implement the custom critic as I explained above?

Any sort of input would be appreciated. Thank you and have a great day!

Hey @cubpaw ,

It seems like what you need is a Q-function, not a value function. For reference, the value function V(s) measures the expected return obtainable under policy π when starting from state s. The Q-function Q(s, a), on the other hand, measures the expected return obtainable under policy π if you start in state s and take action a.
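
In symbols (with discount factor γ and per-step reward r_t):

$$
V^{\pi}(s) = \mathbb{E}_{\pi}\!\left[\sum_{t=0}^{\infty} \gamma^{t} r_t \,\middle|\, s_0 = s\right],
\qquad
Q^{\pi}(s, a) = \mathbb{E}_{\pi}\!\left[\sum_{t=0}^{\infty} \gamma^{t} r_t \,\middle|\, s_0 = s,\ a_0 = a\right]
$$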

I think you should look at how things are implemented in the DDPG model by default:

There is a get_q_function that is implemented and probably used in the loss function of its policy. You probably want to follow that design pattern.
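
To make the Q(s, a) idea concrete, here is a toy plain-PyTorch sketch in the spirit of your critic (the action is reshaped into an extra image channel). The layer sizes are illustrative only, and this is not wired into RLlib:

import torch
import torch.nn as nn


class ConcatQNetwork(nn.Module):
    """Toy Q(s, a): the flat action is reshaped into one extra image channel
    and concatenated with the observation, then mapped to a scalar value."""

    def __init__(self, obs_channels: int = 3, img_size: int = 42):
        super().__init__()
        self.img_size = img_size
        self.conv = nn.Sequential(
            nn.Conv2d(obs_channels + 1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # (B, 16, 1, 1)
        )
        self.head = nn.Linear(16, 1)  # scalar Q-value per sample

    def forward(self, obs: torch.Tensor, flat_action: torch.Tensor) -> torch.Tensor:
        # obs: (B, C, H, W); flat_action: (B, H*W) -> (B, 1, H, W)
        action_img = flat_action.view(-1, 1, self.img_size, self.img_size)
        x = torch.cat([obs, action_img], dim=1)  # (B, C+1, H, W)
        return self.head(self.conv(x).flatten(start_dim=1)).squeeze(-1)  # (B,)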


Hey @kourosh, thanks for the valuable input again!

Yep, you are right: the function I needed to realize my critic was a Q-function, not a value function. I was trying to leverage the custom_model_config feature of RLlib to simply use my pretrained networks as the policy and Q-function, and it seems like SAC is the only algorithm that supports this. Is this also possible for other algorithms, or, in the case of DDPG or TD3, does the custom model have to be created manually as you mentioned in your first answer?

This is how I build my algo when using SAC:

# Register your custom models
ModelCatalog.register_custom_model("custom_policy_model", CustomPolicyModel)
ModelCatalog.register_custom_model("custom_q_model",CustomQModel)

# Set up your SAC agent configuration
algo = (
    SACConfig()
    .resources(num_gpus=1)
    .framework('torch')
    .environment(env="customenv-v0", normalize_actions=False)
    .training(
        # model={
        #     "custom_model": "custom_model",
        #     "custom_model_config": {},
        # },
        policy_model_config={"custom_model": "custom_policy_model",
                             "custom_model_config": {'pretrained_actor_path': actor_checkpoint_file_path}},
        q_model_config={"custom_model": "custom_q_model",
                        "custom_model_config": {'pretrained_critic_path': critic_checkpoint_file_path}}
    )
    .build()
)

Thank you, and have a good one!

Hi, I notice that Ray has been updated, but train_w_bc_finetune_w_ppo.py does not run in the new version. Could you please provide the requirements for the version this tutorial was written for, or update the script?