Hello,
I'm trying to implement an encoder-decoder with multi-head attention (MHA) in my custom model, inside the forward() and value_function() methods. When I try to test it with something like trainer = algo.build(), I get an error from the postprocessing function:
"ray\rllib\evaluation\postprocessing.py", line 195, in compute_gae_for_sample_batch [repeated 9x across cluster]
(PPO pid=3512) assert vf_preds.shape == rewards.shape [repeated 9x across cluster]
(PPO pid=3512) AssertionError [repeated 9x across cluster]
This is strange, because on the first run, when the model is initialized, all tensor dimensions match, and I don't feed any reward into the model manually. The error only appears after the forward() call, at the very end, during policy initialization. What exactly does this assertion check, and why does it fail?
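From the traceback, I assume the check is roughly the following (my own simplified sketch, not the actual RLlib source), i.e. the value predictions collected during the rollout must have exactly the same shape as the per-timestep rewards:

# my rough sketch of what compute_gae_for_sample_batch seems to assert
# (SampleBatch.VF_PREDS / SampleBatch.REWARDS are the real RLlib keys,
#  everything else is simplified)
vf_preds = sample_batch[SampleBatch.VF_PREDS]  # filled from value_function()
rewards = sample_batch[SampleBatch.REWARDS]    # shape [T], one scalar per timestep
assert vf_preds.shape == rewards.shape         # so vf_preds must also be a flat [T] tensor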
Below is the code, and thank you in advance!
import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env

ray.init()

env = GraphEnvironment()
register_env("env_graph", lambda config: GraphEnvironment())
ModelCatalog.register_custom_model("EncoderDecoder", my_model)

algo = (
    PPOConfig()
    .rollouts(num_rollout_workers=0, batch_mode="complete_episodes", num_envs_per_worker=1)
    .framework("torch")
    .resources(num_gpus=0)
    .environment(env="env_graph")
    .training(
        train_batch_size=10,
        sgd_minibatch_size=1,
        model={
            "custom_model": "EncoderDecoder",
            "_disable_preprocessor_api": False,
            "_disable_action_flattening": True,
        },
    )
    .experimental(
        _enable_new_api_stack=False,
    )
)
trainer = algo.build()
And the model looks like this:
import torch
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override


class my_model(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        self.conf = model_config["custom_model_config"]
        conf = self.conf
        self.state_buffer = []
        self.num_of_nodes = 11
        node_feature_count = 3
        hidden_size = 128
        context_size = 129
        self.batch_size = 10
        self.init_state = self.get_initial_state()

        self.lin_in = nn.Linear(node_feature_count, hidden_size)  # .to(self.device)
        self.graph_emb = embed_netsEmbedding(hidden_size)
        self.graph_enc = attent_nets.GraphAttentionEncoder()
        self.graph_dec = decoder_.Decoder()
        self.value_out = MLP(context_size, 1)

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        # (extraction of nodes, last_node and node_mask from input_dict is omitted here)
        embeddings = self.graph_emb(nodes)

        # select the embeddings of the previous node indices: shape [B, 1, 128],
        # i.e. one node embedding per batch element
        last_node_ = last_node.clone().detach().requires_grad_(False).type(dtype=torch.int64)
        last_node_expanded = last_node_.unsqueeze(-1).expand(-1, -1, 128)
        selected_embedings = torch.gather(embeddings, 1, last_node_expanded)

        # concatenate along the last dimension
        state_tensor = state[-1][:, None, None]
        current_context = torch.cat((selected_embedings, state_tensor), dim=-1)

        encoded_in, _ = self.graph_enc(embeddings)
        cached_embeddings = self.graph_dec.precompute(encoded_in)
        log_p, mask = self.graph_dec.decode(cached_embeddings, last_node, node_mask)

        # compute value.
        self._value = self.value_out(current_context)
        return log_p, state  # previous state

    @override(TorchModelV2)
    def value_function(self):
        assert self._value is not None, "Must call `forward()` first!"
        value = self._value.squeeze(-1)  # B x T x h_dim --MLP--> B x T x 1
        return value
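If it helps, my own guess: self._value comes out of the MLP with shape [B, 1, 1], so even after squeeze(-1) it is [B, 1] rather than the flat [B] tensor that the assertion above seems to compare against the rewards. Would something along these lines be the right direction (a minimal sketch, keeping the rest of my model unchanged)?

@override(TorchModelV2)
def value_function(self):
    assert self._value is not None, "Must call `forward()` first!"
    # flatten [B, 1, 1] down to [B] so it matches the per-timestep rewards
    return self._value.reshape(-1)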