Hi @kourosh, I just wanted to ask you a few more questions regarding the implementation I'm trying to do. In a nutshell, I have trained a policy which roughly does what I want it to do (take actions) and a critic which evaluates the sampled action based on the current state.
So the input sizes for my actor and critic are as follows:
actor - (B, C, H, W), B being the batch size; the input is an image. The actor also outputs a 1-channel image, but this is flattened for action sampling.
critic - (B, C+1, H, W); the sampled action should be reshaped back into an image array and then concatenated with the current state (q, a), and this should output a reward estimate.
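To make the shape handling concrete, this is roughly what the flatten/reshape/concat looks like on dummy tensors (a sketch only, with B, C, H, W assumed to be 8, 3, 42, 42 to match the 42x42 reshape in value_function below):

import torch

B, C, H, W = 8, 3, 42, 42                        # assumed sizes for illustration
obs = torch.zeros(B, C, H, W)                    # current state (image)
actor_logits = torch.zeros(B, 1, H, W)           # actor output: 1-channel image
flat_action = actor_logits.view(B, -1)           # (B, H*W), flattened for action sampling
action_img = flat_action.view(B, 1, H, W)        # sampled action reshaped back into an image
critic_in = torch.cat((obs, action_img), dim=1)  # (B, C+1, H, W), fed to the critic
print(critic_in.shape)                           # torch.Size([8, 4, 42, 42])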
To achieve this, I have created a custom model class as follows:
import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class VoxelgymModel(TorchModelV2, nn.Module):
    """Example of a PyTorch custom model."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name,
                 pretrained_actor_path=None, pretrained_critic_path=None, **kwargs):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        # Access custom_model_config directly from model_config.
        custom_model_config = model_config.get("custom_model_config", {})

        # Get the pretrained checkpoint paths from custom_model_config.
        pretrained_actor_path = custom_model_config.get("pretrained_actor_path")
        pretrained_critic_path = custom_model_config.get("pretrained_critic_path")
        if pretrained_actor_path is None or pretrained_critic_path is None:
            raise ValueError("Both pretrained actor and critic models must be provided.")
        print(pretrained_actor_path)
        print(pretrained_critic_path)

        # Instantiate the models.
        self.pretrained_actor = PretrainedActor()
        self.pretrained_critic = PretrainedCritic()

        # Load state_dicts from the checkpoints.
        actor_checkpoint = torch.load(pretrained_actor_path)
        critic_checkpoint = torch.load(pretrained_critic_path)
        self.pretrained_actor.load_state_dict(actor_checkpoint["state_dict"])
        self.pretrained_critic.load_state_dict(critic_checkpoint["state_dict"])

        # Note that the pretrained models' outputs look like this:
        # Actor  - SemanticSegmenterOutput(loss=loss, logits=logits, hidden_states=None, attentions=None)
        # Critic - SemanticSegmenterOutput(loss=loss, logits=reward, hidden_states=None, attentions=None)

    def forward(self, input_dict, state, seq_lens):
        input_dict["obs"] = input_dict["obs"].float().to("cuda")
        self.input_dict = input_dict  # cache the obs so value_function() can reuse it
        action_output = self.pretrained_actor(input_dict["obs"])  # SemanticSegmenterOutput
        logits = action_output.logits
        print(logits.shape)
        flattened_logits = logits.view(logits.size(0), -1)  # flatten the 1-channel image for action sampling
        print(flattened_logits.shape)
        return flattened_logits, state

    def value_function(self):
        input_dict = self.input_dict  # cached during the forward pass
        actions = self.actions  # the sampled actions -- not actually set anywhere yet, which is the problem described below
        input_dict["obs"] = input_dict["obs"].float().to("cuda")
        # The actions have to be reshaped back into an image before concatenation.
        reshaped_actions = actions.view(-1, 1, 42, 42)
        print(actions.shape)
        print(reshaped_actions.shape)
        augmented_obs = torch.cat((input_dict["obs"], reshaped_actions), dim=1)  # (B, C+1, H, W)
        print(augmented_obs.shape)
        value_output = self.pretrained_critic(augmented_obs).logits  # SemanticSegmenterOutput
        return value_output.reshape(-1)  # maybe you don't even have to do this
Then I use the following code to build the algo:
algo = (
    DDPGConfig()
    .resources(num_gpus=1)
    .framework("torch")
    .environment(env="voxel-v0")
    .training(
        model={
            "custom_model": "voxelgym_model",
            "custom_model_config": {
                "pretrained_actor_path": actor_checkpoint_file_path,
                "pretrained_critic_path": critic_checkpoint_file_path,
            },
        }
    )
    .build()
)
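For reference, the "voxelgym_model" string is resolved by registering the custom model with RLlib's ModelCatalog before the algo is built, along these lines (a minimal sketch, assuming VoxelgymModel is importable in the same script):

from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model("voxelgym_model", VoxelgymModel)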
I think the custom model is defined correctly up to the forward pass, but since I'm trying to use the sampled action within the value function, things get problematic from there. Is there an algorithm that uses the state-action value directly to estimate the reward? Is there a way to implement the custom critic as I explained above?
Any sort of input would be appreciated. Thank you and have a great day!