Wrapping RLlib's Built-In Wrappers

Hi all, I’m trying to add an action mask to an LSTM/attention model that works with the use_lstm or use_attention parameters. My basic setup is a wrapper model that contains a FullyConnectedNetwork. The inner network is configured with use_attention=True, while the wrapper isn’t.

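import torch.nn as nn
from gym.spaces.utils import flatten_space
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement


class GenericModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name,
                 use_attention=False, use_lstm=False, agent_type="generic", **kwargs):
        nn.Module.__init__(self)
        super(GenericModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name
        )
        # RLlib flattens Dict observations; the Dict itself is kept on .original_space.
        orig_space = getattr(obs_space, "original_space", obs_space)
        self.real_obs_space = flatten_space(orig_space["real_obs"])
        if use_attention:
            model_config["use_attention"] = True
        if use_lstm:
            model_config["use_lstm"] = True
        self.agent_type = agent_type
        self.fc_embed = FullyConnectedNetwork(
            self.real_obs_space, action_space, model_config["fcnet_hiddens"][-1], model_config, name, **kwargs
        )
        self.view_requirements = self.fc_embed.view_requirements
        self.view_requirements[SampleBatch.OBS] = ViewRequirement(shift=0, space=obs_space)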

From what I understand, this should wrap a smaller model that uses an attention wrapper. However, it’s completely ignoring the attention wrapping and just using a plain FullyConnectedNetwork, and I’m struggling to figure out why. Does Ray only apply attention wrappers to the top-level model that’s passed into the trainer?

I know I’m taking an approach that the RLlib docs (“RLlib Models, Preprocessors, and Action Distributions”, Ray v2.0.0.dev0) don’t recommend, but I literally only need to add an action mask to an otherwise barebones model, so I figured that wrapping an attention model would make the most sense.

Apologies for reposting this thread from the Ray Slack; I figured the official forums would be a better place to discuss it.


The reason it’s ignoring “use_lstm”/“use_attention” is that these config settings are only looked at by our ModelCatalog, which builds, e.g., your GenericModel in the ModelCatalog.get_model_v2 utility method. The catalog then also takes care of the wrapping. LSTM/attention wrapping is not(!) done inside our built-in default models (such as FullyConnectedNetwork).
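To make the division of work above concrete, here is a purely illustrative toy in plain Python. ToyFCNet, ToyAttentionWrapper and toy_get_model are hypothetical stand-ins, not RLlib classes; the sketch only mirrors the behavior described: the catalog-level factory reads use_attention and wraps, while a default model constructed directly ignores the flag.

```python
# Hypothetical toy, NOT RLlib code: it only mimics where the wrapping happens.


class ToyFCNet:
    """Stand-in for a built-in default model: it never looks at wrapping flags."""

    def __init__(self, config):
        self.config = config

    def describe(self):
        return "FCNet"


class ToyAttentionWrapper:
    """Stand-in for the attention wrapper that the factory puts around a model."""

    def __init__(self, inner):
        self.inner = inner

    def describe(self):
        return "Attention(" + self.inner.describe() + ")"


def toy_get_model(config):
    """Stand-in for the catalog: the only place that reads use_attention."""
    model = ToyFCNet(config)
    if config.get("use_attention"):
        model = ToyAttentionWrapper(model)
    return model


config = {"use_attention": True}
# Built directly, the flag is silently ignored.
print(ToyFCNet(config).describe())
# Built through the factory, the wrapper is applied.
print(toy_get_model(config).describe())
```

Constructing FullyConnectedNetwork yourself corresponds to the first call; routing the inner model through ModelCatalog.get_model_v2 corresponds to the second.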

Looking at your example (and not being sure exactly what architecture you are trying to build), you could try calling ModelCatalog.get_model_v2 inside your model to have it take care of the auto-wrapping. Just be careful not to pass the custom_model option to that call; otherwise you will probably end up in an infinite recursion.
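One way to follow that advice is to derive the inner model's config from a copy of the outer one. The helper below, make_inner_model_config, is a hypothetical name of our own, a minimal sketch assuming model_config is a plain dict: it drops custom_model (to avoid the recursion mentioned above) and custom_model_config (meant for the outer custom model), and sets the wrapping flag, all without mutating the dict RLlib handed to the outer model.

```python
import copy


def make_inner_model_config(model_config, use_attention=False, use_lstm=False):
    """Hypothetical helper: build a config for an inner get_model_v2 call.

    - "custom_model" is removed so the catalog builds a default model
      instead of constructing the outer custom model again (recursion).
    - "custom_model_config" is removed since it is meant for the outer model.
    - The caller's dict is deep-copied, so the original stays untouched.
    """
    inner = copy.deepcopy(model_config)
    inner.pop("custom_model", None)
    inner.pop("custom_model_config", None)
    if use_attention:
        inner["use_attention"] = True
    if use_lstm:
        inner["use_lstm"] = True
    return inner
```

Inside the custom model's __init__, the result would then be passed as the model_config argument of ModelCatalog.get_model_v2 (together with the inner observation space, action_space, num_outputs and framework="torch"). Copying rather than popping in place keeps the outer model's config intact for anything else that reads it later.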


This is exactly what I’m looking for, thank you so much!


Using the ModelCatalog.get_model_v2 call seems to have done the trick: the attention wrapper is now applied. However, I’m now struggling to figure out what ViewRequirements the GenericModel should have. I’m getting the following error:

File "/home/jack/anaconda3/lib/python3.8/site-packages/ray/rllib/policy/sample_batch.py", line 78, in __init__
  self.data[k] = np.array(v)
ValueError: could not broadcast input array from shape (50,64) into shape (50)
Result for trial 2021.04.28, 19-14-40:
  {}

This occurs when the sample batch is created for the state_in_0 key in sample_batch.py. About 99% of the values for v have shape (50, 64), but occasionally a v has shape (50,), which breaks the conversion to a NumPy array. I suspect I’m not handling the view requirements for my model properly.
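While debugging this kind of error, it can help to find exactly which entries have the odd shape before NumPy tries to combine them. The sketch below uses a hypothetical helper, find_shape_outliers (not part of RLlib), on synthetic data that only mimics the shapes reported above; in practice you would feed it the list of values collected for state_in_0.

```python
from collections import Counter

import numpy as np


def find_shape_outliers(values):
    """Hypothetical debugging helper.

    Returns (expected_shape, outliers): expected_shape is the most common
    shape in `values`, and outliers lists (index, shape) for every entry
    whose shape differs from it, i.e. the entries that break stacking.
    """
    shapes = [np.shape(v) for v in values]
    expected, _ = Counter(shapes).most_common(1)[0]
    outliers = [(i, s) for i, s in enumerate(shapes) if s != expected]
    return expected, outliers


# Synthetic data shaped like the post: mostly (50, 64), one stray (50,).
values = [np.zeros((50, 64)) for _ in range(9)] + [np.zeros(50)]

expected, outliers = find_shape_outliers(values)
print(expected)   # (50, 64)
print(outliers)   # [(9, (50,))]

# Combining arrays of unequal shape fails with a ValueError.
try:
    np.stack(values)
except ValueError as err:
    print("np.stack failed:", err)
```

Knowing the index of the stray (50,) entry makes it easier to trace which step of the rollout produced a state value of the wrong shape.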

Here’s my current setup.

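from typing import List, Union

import numpy as np
import torch
import torch.nn as nn
from gym.spaces.utils import flatten_space
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.torch_utils import FLOAT_MAX, FLOAT_MIN  # ray.rllib.utils.torch_ops in older Ray versions
from ray.rllib.utils.typing import TensorType


class GenericModel(TorchModelV2, nn.Module):

    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kwargs):
        nn.Module.__init__(self)
        super(GenericModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name
        )
        # RLlib flattens Dict observations; the Dict itself is kept on .original_space.
        orig_space = getattr(obs_space, "original_space", obs_space)
        self.real_obs_space = flatten_space(orig_space["real_obs"])
        self.agent_type = model_config["custom_model_config"]["agent_type"]
        # ray default models that we are wrapping don't accept our custom kwargs.
        # custom_model is dropped too, so get_model_v2 does not rebuild GenericModel
        # (infinite recursion). Work on a copy so the shared config is not mutated.
        inner_config = dict(model_config)
        inner_config.pop("custom_model", None)
        inner_config.pop("custom_model_config", None)
        if self.agent_type == "attention":
            inner_config["use_attention"] = True

        self.model = ModelCatalog.get_model_v2(
            self.real_obs_space, action_space, num_outputs, inner_config, framework="torch"
        )

        self.view_requirements = self.model.view_requirements
        self.view_requirements[SampleBatch.OBS] = ViewRequirement(shift=0, space=obs_space)

    def forward(self, input_dict, state, seq_lens):
        # log(1) = 0 keeps legal actions; log(0) = -inf, clamped to FLOAT_MIN, masks illegal ones.
        inf_mask = torch.clamp(torch.log(input_dict["obs"]["action_mask"]), FLOAT_MIN, FLOAT_MAX)
        obs_flat = torch.cat(input_dict["obs"]["real_obs"], dim=1)
        obs = {"obs": obs_flat}
        outs, memory_outs = self.model(obs, state, seq_lens)
        return outs + inf_mask, memory_outs

    def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]:
        # Delegate to the wrapped model so the initial state matches its layout.
        return self.model.get_initial_state()

    def value_function(self) -> TensorType:
        return self.model.value_function()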

Is there something else that I’m missing? I’m using a FullyConnectedNetwork as the underlying model to wrap attention around. I get the same error when wrapping with attention or with LSTM, and I don’t get it when I’m not wrapping the model.
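For reference, the masking line in forward (torch.clamp(torch.log(mask), FLOAT_MIN, FLOAT_MAX) added to the logits) can be mirrored in NumPy to see what it does. This is a sketch, not the post’s torch code: the FLOAT_MIN/FLOAT_MAX values here are stand-ins based on the float32 limits and are assumed, not imported from RLlib, and apply_action_mask is a hypothetical name.

```python
import numpy as np

# Stand-ins for RLlib's FLOAT_MIN / FLOAT_MAX (assumed: float32 limits).
FLOAT_MIN = float(np.finfo(np.float32).min)
FLOAT_MAX = float(np.finfo(np.float32).max)


def apply_action_mask(logits, action_mask):
    """Add log(mask) to the logits.

    Legal actions (mask 1) get log(1) = 0 added; illegal actions (mask 0)
    get log(0) = -inf, clamped to a huge negative finite number so later
    arithmetic does not produce NaNs.
    """
    with np.errstate(divide="ignore"):  # silence the log(0) warning
        inf_mask = np.clip(np.log(action_mask), FLOAT_MIN, FLOAT_MAX)
    return logits + inf_mask


logits = np.array([[1.0, 2.0, 3.0]])
mask = np.array([[1.0, 0.0, 1.0]])
masked = apply_action_mask(logits, mask)

# Softmax over the masked logits: the illegal action ends up with probability 0.
probs = np.exp(masked - masked.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)
print(probs)
```

Because the mask is simply added to the logits, it only works if the mask and the model outputs line up row for row and have the same number of columns (num_outputs).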