Multi-Path Custom Networks

How severely does this issue affect your experience of using Ray?

  • Low: It annoys or frustrates me for a moment.

Let's say I want to route data with different tags (set by the environment) through different heads of my NN, e.g., when the final dimension of the observation is 0, take path 1, and when the final dimension of the observation is 1, take path 2.

Both paths output action logits and have a value function. So, we’re just trying to let different heads of the NN specialize to their task id.


Does my custom model need to take batch information into account, since a single batch could easily need to be routed through different parts of the larger network? If so, would an implementation like this work, or is there something really obvious I've missed? For example, would this properly preserve gradients?

import gym
import torch
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import ModelConfigDict


class MLP(TorchModelV2, nn.Module):
    """Simple MLP that produces action logits and a value estimate."""
    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):

        TorchModelV2.__init__(self,
                              obs_space,
                              action_space,
                              num_outputs,
                              model_config,
                              name)
        nn.Module.__init__(self)

        self.hidden_size = 128

        self.fc1 = nn.Linear(obs_space.shape[0], self.hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(self.hidden_size, self.hidden_size)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(self.hidden_size, num_outputs)
        self.value_fn = nn.Linear(self.hidden_size, 1)

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        x = input_dict['obs']
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        # Cache the value estimate so value_function() can return it later.
        self.value = self.value_fn(x)
        x = self.fc3(x)
        return x, state

    @override(TorchModelV2)
    def value_function(self):
        return self.value.squeeze(1)

class ConditionalMultiHeadMLP(TorchModelV2, nn.Module):
    """MLP"""
    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):

        TorchModelV2.__init__(self,
                              obs_space,
                              action_space,
                              num_outputs,
                              model_config,
                              name)
        nn.Module.__init__(self)

        self.obs_space = obs_space
        self.action_space = action_space
        self.num_outputs = num_outputs
        self.hidden_size = 128

        self.encoder = nn.Sequential(
            nn.Linear(obs_space.shape[0], self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, obs_space.shape[0]))

        self.opt_network = MLP(obs_space, action_space, num_outputs, model_config, name)
        self.nov_network = MLP(obs_space, action_space, num_outputs, model_config, name)

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):

        # input_dict['obs'] has shape [batch_size, 126]
        x = self.encoder(input_dict['obs'])
        batch_size = x.shape[0]
        # The last observation dimension is the task flag set by the environment.
        opt_mask = input_dict['obs'][..., -1] == 0
        nov_mask = input_dict['obs'][..., -1] == 1

        to_opt = x[opt_mask]
        to_nov = x[nov_mask]

        x_opt, state = self.opt_network.forward({'obs': to_opt}, state, seq_lens)
        self.opt_value = self.opt_network.value

        x_nov, state = self.nov_network.forward({'obs': to_nov}, state, seq_lens)
        self.nov_value = self.nov_network.value

        # Scatter the per-path outputs back into full-batch tensors, created on
        # the same device/dtype as x so this also works on GPU.
        logits = torch.zeros(batch_size, self.num_outputs, device=x.device, dtype=x.dtype)
        logits[opt_mask] += x_opt
        logits[nov_mask] += x_nov
        # Stack the values similarly.
        self.value = torch.zeros(batch_size, 1, device=x.device, dtype=x.dtype)
        self.value[opt_mask] += self.opt_value
        self.value[nov_mask] += self.nov_value

        return logits, state

    @override(TorchModelV2)
    def value_function(self):
        return self.value.squeeze(1)

Hi @aadharna,

This should work fine, although you do not need to add the result to zeros; you can just assign the result to the appropriate indices.
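For instance, a minimal sketch of that simplification (illustrative only, reusing the names from your forward method):

    # Assign each path's output directly into the pre-allocated tensors;
    # boolean-mask assignment still propagates gradients to x_opt / x_nov.
    logits = torch.zeros(batch_size, self.num_outputs, device=x.device, dtype=x.dtype)
    logits[opt_mask] = x_opt
    logits[nov_mask] = x_nov

    self.value = torch.zeros(batch_size, 1, device=x.device, dtype=x.dtype)
    self.value[opt_mask] = self.opt_value
    self.value[nov_mask] = self.nov_value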

Here is a Colab notebook with a short example showing correct autodiff on a simple problem.
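In case the notebook link is unavailable, here is a rough standalone sketch of the kind of gradient check it describes (toy shapes and names are illustrative, not the notebook's actual code):

    import torch

    # Two tiny "heads" and a batch whose last feature is a 0/1 task flag.
    head_a = torch.nn.Linear(3, 2)
    head_b = torch.nn.Linear(3, 2)
    obs = torch.randn(5, 3)
    obs[:, -1] = torch.tensor([0., 1., 0., 1., 0.])

    mask_a = obs[:, -1] == 0
    mask_b = obs[:, -1] == 1

    # Route each row through its head and scatter back by mask.
    out = torch.zeros(obs.shape[0], 2)
    out[mask_a] = head_a(obs[mask_a])
    out[mask_b] = head_b(obs[mask_b])

    # Backprop through the routed output: each head only receives gradients
    # from the rows that were routed to it.
    out.sum().backward()
    print(head_a.weight.grad)
    print(head_b.weight.grad)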

Fantastic – thanks @mannyv!