I have an agent learning in an evolving environment. I have implemented the environment as an ExternalEnv. The agent observes a state and takes an action. I am training an attention network with the A2C algorithm.
Currently, the initial state in each episode is zeroed out. I would like to use the same attention network to get a better representation of the initial state. Basically, I want the agent to observe the environment for the first few timesteps without taking any action. Once enough observations are collected, the agent should start using the policy to generate actions. This would make sure that the initial state is not fully zeroed out when the agent starts acting.
This is what I am planning (a rough code sketch follows the list):
Start the episode. The agent is only observing.
Use the .log_action method of ExternalEnv to record observations and a dummy action. Do not record any rewards.
Once enough observations are collected, switch to calling the .get_action method and take the suggested action.
When computing the loss, zero out the rows with dummy actions.
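Here is a minimal sketch of what that loop could look like inside an ExternalEnv subclass. The warm-up length, the dummy action, and the external simulator (`self.sim`) are placeholders I made up for illustration; only `start_episode`, `get_action`, `log_action`, `log_returns`, and `end_episode` are actual ExternalEnv methods:

```python
from ray.rllib.env.external_env import ExternalEnv


class WarmupExternalEnv(ExternalEnv):
    """Hypothetical ExternalEnv that only observes for a few steps before acting."""

    WARMUP_STEPS = 5  # assumed number of observe-only steps per episode

    def __init__(self, sim, action_space, observation_space, dummy_action):
        super().__init__(action_space, observation_space)
        self.sim = sim                    # hypothetical external simulator
        self.dummy_action = dummy_action  # e.g. a no-op action

    def run(self):
        while True:
            episode_id = self.start_episode()
            obs = self.sim.reset()
            done = False
            t = 0
            while not done:
                if t < self.WARMUP_STEPS:
                    # Observe only: log the observation with a dummy action
                    # and no reward, so the attention state gets filled in.
                    self.log_action(episode_id, obs, self.dummy_action)
                    obs, _, done = self.sim.step(self.dummy_action)
                else:
                    # Enough context collected: query the policy for actions.
                    action = self.get_action(episode_id, obs)
                    obs, reward, done = self.sim.step(action)
                    self.log_returns(episode_id, reward)
                t += 1
            self.end_episode(episode_id, obs)
```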
Is this the right way to do this? Is there a way to avoid sending the dummy actions through the neural network at training time entirely?
What you are describing sounds simple, but it actually interferes with the RL paradigm, and our sampling routines are not easily altered. However, you can do something very similar by writing a callback to postprocess the samples you collected, like here: ray/custom_metrics_and_callbacks.py at master · ray-project/ray · GitHub
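As a hedged sketch of that idea (not taken from the linked example), an `on_postprocess_trajectory` callback could mask out the observe-only rows before they reach the A2C loss. The `"dummy_action"` info key is an assumption; your environment would have to set it on warm-up steps, and the import path depends on your Ray version:

```python
import numpy as np
from ray.rllib.algorithms.callbacks import DefaultCallbacks  # ray.rllib.agents.callbacks on older Ray
from ray.rllib.policy.sample_batch import SampleBatch


class MaskDummyStepsCallbacks(DefaultCallbacks):
    """Zero out the observe-only rows so they do not contribute to the loss."""

    def on_postprocess_trajectory(
        self, *, worker, episode, agent_id, policy_id, policies,
        postprocessed_batch, original_batches, **kwargs
    ):
        infos = postprocessed_batch[SampleBatch.INFOS]
        # 1.0 for real (policy-generated) steps, 0.0 for warm-up steps.
        keep = np.array(
            [not info.get("dummy_action", False) for info in infos],
            dtype=np.float32,
        )
        # Mask rewards and any advantage-related columns that feed the loss.
        for key in (SampleBatch.REWARDS, "advantages", "value_targets"):
            if key in postprocessed_batch:
                postprocessed_batch[key] = postprocessed_batch[key] * keep
```

You would then register this class under the `callbacks` entry of your algorithm config.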
This is how I would try to implement this in RLlib.
I would use a dictionary observation space. One of the entries of the observation, let's call it prompt, would only contain valid data when the environment resets. This entry would hold the input used to create the initial state.
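For illustration, such a space could look like the following; the key names, shapes, and the gymnasium import are assumptions, not part of the suggestion above:

```python
import numpy as np
from gymnasium import spaces  # `from gym import spaces` on older Ray versions

observation_space = spaces.Dict({
    # Regular per-step observation.
    "obs": spaces.Box(-np.inf, np.inf, shape=(16,), dtype=np.float32),
    # Only filled with real data on reset; all zeros on every later step.
    "prompt": spaces.Box(-np.inf, np.inf, shape=(5, 16), dtype=np.float32),
})
```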
I would also use a custom model. In the forward method, I would check whether the prompt is non-empty.
If it is, I would, in a loop, feed each embedding of the prompt through the model, discarding the outputs and carrying the state forward.
Finally, I would use that state to generate the first token out and return that token along with the state resulting from generating it.
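A very rough sketch of that model idea, using a GRU cell as a stand-in for the attention memory; the layer sizes, the observation keys, and the all-zeros prompt check are all assumptions for illustration, not RLlib requirements:

```python
import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class PromptWarmupModel(TorchModelV2, nn.Module):
    """Hypothetical model that rolls the reset-time prompt into its state."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.hidden_size = 64                         # assumed state size
        self.cell = nn.GRUCell(16, self.hidden_size)  # assumed feature size 16
        self.logits_layer = nn.Linear(self.hidden_size, num_outputs)
        self.value_layer = nn.Linear(self.hidden_size, 1)
        self._value = None

    def get_initial_state(self):
        return [torch.zeros(self.hidden_size)]

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"]["obs"]        # [B, 16]
        prompt = input_dict["obs"]["prompt"]  # [B, 5, 16]
        h = state[0]
        # If the prompt carries data (i.e. this is the reset step), feed each
        # prompt embedding through the cell, discarding the outputs and
        # keeping only the resulting state.
        if torch.any(prompt != 0.0):
            for t in range(prompt.shape[1]):
                h = self.cell(prompt[:, t, :], h)
        # Process the current observation with the warmed-up state.
        h = self.cell(obs, h)
        self._value = self.value_layer(h).squeeze(-1)
        return self.logits_layer(h), [h]

    def value_function(self):
        return self._value
```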
There may be issues with this that I haven't appreciated yet (there usually are), but this is how I would start.