I am using action masking in combinatorial action spaces, which means I need to sample from a number of possible hierarchical actions during the action_sample_fn.
The hierarchical action space can be compressed to a dict, and the sampling function then walks this tree to create the appropriate mask for the chosen actions.
My problem is that I can’t seem to find a way to pass this dict structure (which is built during step()) to the action_sample_fn.
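To make the setup concrete, here is a purely illustrative sketch (the structure and names are made up, not from my actual env) of such a tree dict and how a sampler might turn one branch of it into a flat mask:

```python
import numpy as np

# Illustrative tree of valid sub-actions per top-level action (built in step()).
mask_tree = {
    "move":  {"valid_targets": [0, 1, 3]},
    "build": {"valid_unit_types": [2, 5]},
}

def branch_to_mask(valid_ids, size):
    # Turn the list of valid indices for the chosen branch into a 0/1 mask.
    mask = np.zeros(size, dtype=np.float32)
    mask[valid_ids] = 1.0
    return mask

print(branch_to_mask(mask_tree["build"]["valid_unit_types"], size=8))
# -> [0. 0. 1. 0. 0. 1. 0. 0.]
```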
Things I have tried:
1. Pass the dict as a custom space: input_dict['obs']['action_mask'] = {.....}
   Result: the preprocessor only accepts tensors, since it assumes all observation data is tensor-shaped.
2. Compress the tree into a variable-length tensor and pass it as input_dict['obs']['action_mask'] = [....]
   Result: the preprocessor tries to flatten the array, but expects it to be of fixed length.
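For contrast, a quick sketch of the only kind of action mask the preprocessor has handled for me out of the box: a fixed-shape array (e.g. a Box) inside the obs Dict (the shapes here are made up):

```python
from gym.spaces import Box, Dict

obs_space = Dict({
    "action_mask": Box(0.0, 1.0, shape=(10,)),   # fixed length -> flattens fine
    "real_obs": Box(-1.0, 1.0, shape=(24,)),
})
# A per-step, variable-structure Python dict (or a variable-length array) has no
# fixed flattened size, so the flattening preprocessor cannot handle it, as
# described above.
```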
I've read that I might be able to use the info returned by env.step(), but that is not included in the input_dict by the time it reaches the action_sample_fn. (Potentially this is available in 2.0.0.dev.)
Hey @Bam4d, thanks for the question. This is indeed handled quite confusingly right now, because our action_sampler_fn behaves slightly differently for torch and tf.
I'm assuming you are on tf, in which case your action_sampler_fn receives the flattened "obs" (which includes the action mask) as its obs_batch argument. You can simply do the following inside your custom action_sampler_fn:
from ray.rllib.models.modelv2 import restore_original_dimensions
orig_obs = restore_original_dimensions(obs_batch, policy.observation_space, "tf")
Now, orig_obs should be a dict containing the action mask.
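For reference, a minimal sketch of how that could look inside a custom action_sampler_fn. The argument list and the sample_masked_actions helper are placeholders (the exact signature differs between RLlib versions); only the restore_original_dimensions call is from the snippet above:

```python
from ray.rllib.models.modelv2 import restore_original_dimensions


def my_action_sampler_fn(policy, model, obs_batch, **kwargs):
    # obs_batch arrives flattened by the preprocessor; rebuild the original
    # Dict observation so the mask component is accessible again.
    orig_obs = restore_original_dimensions(obs_batch, policy.observation_space, "tf")
    action_mask = orig_obs["action_mask"]  # assumes an "action_mask" key in the obs Dict

    # From here: mask out invalid logits (e.g. add a large negative value where
    # action_mask == 0), sample, and return (actions, logp) as RLlib expects.
    return sample_masked_actions(model, orig_obs, action_mask)  # hypothetical helper
```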
I'm using torch, and the problem isn't that the data is flattened; it's that the data is deleted because it's not a fixed-size tensor.
My action mask is either a variable-length tensor, or it needs to be a dict object.
It seems as though it's not possible to put non-tensor data (i.e. a dict object) into the input_dict?
I've worked around this issue in a pretty ugly way for now: overriding compute_actions with a mixin, setting _use_trajectory_view_api to false, and putting the dict into the info_batch instead (rough sketch below).
It would be much better if the info_batch could just be seen by the action_sampler_fn, OR if input_dict could store dict data.
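In case it helps anyone, a rough sketch of that workaround (the class and attribute names here are just placeholders, not a recommended API; the super() call assumes the mixin is listed before the base policy class):

```python
class InfoBatchMixin:
    """Stashes the per-env info dicts so a custom sampler can read the mask tree."""

    def compute_actions(self, obs_batch, state_batches=None,
                        prev_action_batch=None, prev_reward_batch=None,
                        info_batch=None, episodes=None, **kwargs):
        # Keep the info dicts (which carry the hierarchical mask dict built in
        # step()) around for the action sampling code to look up.
        self.last_info_batch = info_batch
        return super().compute_actions(
            obs_batch,
            state_batches=state_batches,
            prev_action_batch=prev_action_batch,
            prev_reward_batch=prev_reward_batch,
            info_batch=info_batch,
            episodes=episodes,
            **kwargs)
```

Combined with setting _use_trajectory_view_api to False in the config, the info dicts then actually reach compute_actions, as described above.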
Ah, got it. Yeah, I'll create a PR that always force-adds the info dicts, no matter what the view requirements say. This is similar to rewards and dones, which are always part of the train batch whether or not they are needed in the loss.