How severely does this issue affect your experience of using Ray?
- Low: It annoys or frustrates me for a moment.
Let's say I want to route data with different tags (set by the environment) through different heads of my NN. For example, when the last element of the observation is 0, take path 1, and when it is 1, take path 2.
Both paths output action logits and have their own value function. In short, we're trying to let each head of the NN specialize in its task ID.
Does my custom model need to take batch information into account? One can easily imagine a single batch containing samples that must be routed through different parts of the larger network. If so, would an implementation like the one below work, or is there something obvious I've missed? For example, would it properly preserve gradients?
class MLP(TorchModelV2, nn.Module):
    """Fully connected policy/value network with two hidden layers.

    A shared trunk (Linear -> ReLU -> Linear -> ReLU, width ``hidden_size``)
    feeds two linear heads:

    * ``fc3`` produces ``num_outputs`` action logits;
    * ``value_fn`` produces one value estimate per sample.
    """

    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        # Input width is taken from the first axis of the observation space;
        # presumably observations are flat vectors — TODO confirm.
        in_features = obs_space.shape[0]
        width = 128
        self.hidden_size = width

        # Registration order is deliberate: it fixes parameter ordering and
        # the sequence in which random initial weights are drawn.
        self.fc1 = nn.Linear(in_features, width)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(width, width)
        self.relu2 = nn.ReLU()
        # Heads: action logits and scalar value.
        self.fc3 = nn.Linear(width, num_outputs)
        self.value_fn = nn.Linear(width, 1)

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Compute action logits for ``input_dict['obs']``.

        ``state`` is returned unchanged and ``seq_lens`` is unused.
        Side effect: the value-head output (last dim of size 1) is cached in
        ``self.value`` for a subsequent ``value_function()`` call.
        """
        first_hidden = self.relu1(self.fc1(input_dict['obs']))
        trunk_out = self.relu2(self.fc2(first_hidden))
        self.value = self.value_fn(trunk_out)
        logits = self.fc3(trunk_out)
        return logits, state

    @override(TorchModelV2)
    def value_function(self):
        """Return the values cached by the last ``forward`` with dim 1 squeezed.

        Raises ``AttributeError`` if ``forward`` has not been called yet.
        """
        return self.value.squeeze(dim=1)
class ConditionalMultiHeadMLP(TorchModelV2, nn.Module):
    """Route each sample through one of two ``MLP`` heads based on a task flag.

    Every observation first passes through a shared ``encoder`` whose output
    has the same width as the observation. The *raw* observation's last
    element is then used as a task flag:

    * rows where it equals 0 go to ``opt_network``;
    * rows where it equals 1 go to ``nov_network``.

    Each head's logits and values are scattered back into the original batch
    order, so a single batch may freely mix both tasks.

    Gradients: boolean-mask indexing (gather) and masked assignment (scatter)
    are both differentiable. Each head therefore receives gradients only from
    the rows routed to it, and the shared encoder receives gradients from
    every routed row.

    NOTE(review): a row whose flag is neither exactly 0 nor exactly 1 is
    routed to neither head. It silently gets all-zero logits and a zero
    value; confirm the environment only ever emits 0/1 flags.
    """

    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):
        TorchModelV2.__init__(self,
                              obs_space,
                              action_space,
                              num_outputs,
                              model_config,
                              name)
        nn.Module.__init__(self)
        # NOTE(review): TorchModelV2.__init__ likely stores these already;
        # kept for explicitness — confirm and drop if redundant.
        self.obs_space = obs_space
        self.action_space = action_space
        self.num_outputs = num_outputs
        self.hidden_size = 128

        obs_dim = obs_space.shape[0]
        # Shared encoder maps obs_dim -> obs_dim so its output can be fed to
        # MLP heads built for the original observation space.
        self.encoder = nn.Sequential(
            nn.Linear(obs_dim, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, obs_dim))
        self.opt_network = MLP(obs_space, action_space, num_outputs,
                               model_config, name)
        self.nov_network = MLP(obs_space, action_space, num_outputs,
                               model_config, name)

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Return per-row logits from whichever head each row's flag selects.

        Side effects: caches ``self.opt_value``, ``self.nov_value`` (the
        per-head values for their routed subsets) and ``self.value`` (all rows
        in batch order, shape ``[batch, 1]``) for ``value_function()``.
        """
        obs = input_dict['obs']  # presumably [batch_size, obs_dim] — confirm
        x = self.encoder(obs)
        batch_size = x.shape[0]

        # Route on the raw observation, not the encoded features: the encoder
        # rewrites every column, including the flag.
        flag = obs[..., -1]
        opt_mask = flag == 0
        nov_mask = flag == 1

        # Each head sees only its own rows; an empty subset ([0, obs_dim])
        # is valid input to nn.Linear.
        x_opt, state = self.opt_network.forward({'obs': x[opt_mask]}, state,
                                                seq_lens)
        self.opt_value = self.opt_network.value
        x_nov, state = self.nov_network.forward({'obs': x[nov_mask]}, state,
                                                seq_lens)
        self.nov_value = self.nov_network.value

        # Allocate outputs on x's device and dtype. A bare torch.zeros(...)
        # would land on the CPU and fail when the model runs on a GPU.
        # The masks are disjoint, so plain assignment is enough; it is
        # differentiable w.r.t. the assigned head outputs.
        logits = x.new_zeros(batch_size, self.num_outputs)
        logits[opt_mask] = x_opt
        logits[nov_mask] = x_nov

        self.value = x.new_zeros(batch_size, 1)
        self.value[opt_mask] = self.opt_value
        self.value[nov_mask] = self.nov_value
        return logits, state

    @override(TorchModelV2)
    def value_function(self):
        """Return the batch-ordered values from the last ``forward``, shape ``[batch]``."""
        return self.value.squeeze(1)