RLlib training step customization

I want to customize the training step of the algorithm. By default, the action is selected as:

a = argmax_a Q(f(s), a; θ)

Then the environment transitions from the old state to the new one, and the old state, action, reward, and new state are stored in a replay buffer.
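
In rough pseudocode (policy, env, and replay_buffer here are stand-ins for RLlib's internals):

    # Default behavior (sketch): every transition is stored immediately.
    action = policy.compute_action(obs)  # a = argmax_a Q(f(s), a; θ)
    new_obs, rew, done, info = env.step(action)
    replay_buffer.add(obs, action, rew, new_obs, done)
    obs = new_obs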

I want to change the get-action function so that it looks like this:

while True:
    action = policy.compute_action(obs)
    new_obs, rew, done, info = env.step(action)

    # Keep stepping until the custom condition is met; only the
    # final transition is written to the buffer below.
    if condition:
        break
    obs = new_obs

replay_buffer.add(obs, action, rew, new_obs, done)
obs = new_obs

How do I customize this?

@sven1977 could you give some ideas for this?


Did you try specifying a custom action_sampler_fn in your build_(tf|torch)_policy?
Take a look at Dreamer’s torch policy: rllib/agents/dreamer/dreamer_torch_policy.py, where we use this customization option.
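
Roughly, the wiring looks like this (a minimal sketch, not Dreamer's actual logic: it assumes the (policy, model, input_dict, state, explore, timestep) signature the torch template uses in Ray 1.x, and my_loss_fn plus the argmax selection rule are placeholders):

    import torch

    from ray.rllib.policy.torch_policy_template import build_torch_policy


    def my_action_sampler_fn(policy, model, input_dict, state, explore, timestep):
        # Fully replaces RLlib's default action computation. The torch
        # template expects (actions, log-likelihoods, state-out) back.
        logits, state_out = model(input_dict, state, None)
        actions = torch.argmax(logits, dim=-1)  # placeholder selection rule
        logp = None  # only needed if something downstream uses log-probs
        return actions, logp, state_out


    def my_loss_fn(policy, model, dist_class, train_batch):
        # Placeholder: plug in your algorithm's real loss here.
        return torch.tensor(0.0)


    MyTorchPolicy = build_torch_policy(
        name="MyTorchPolicy",
        loss_fn=my_loss_fn,
        action_sampler_fn=my_action_sampler_fn,
    )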


Thanks for the answer, but this is not exactly what I want. My goal is to enable an RL agent to model a varying number of actions using a fixed-size representation. The training algorithm is identical to the generally accepted one, except for one main difference: I use a custom function that replaces the conventional training step. Therefore, I need to change not how the policy chooses the action, but how the replay buffer is filled.

I want to clarify: some of the agent’s actions can be unsuccessful in the environment, so I want to fill the replay buffer “with my own hands”, because the action space is quite large.

@sven1977 could you give some more ideas for this?

@rsv, got it. So you probably have to re-define your algo’s “execution plan”. If you look at e.g. DQN’s execution plan in rllib/agents/dqn/dqn.py, you can see that in there we collect rollouts (basically a concat’d SampleBatch from different RolloutWorkers) and store these in a ReplayBuffer. You may simply want to define your own StoreToReplayBuffer op (see rllib/execution/replay_ops.py for the default implementation of the StoreToReplayBuffer callable class).
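
Something along these lines could work (a hypothetical sketch; make_filtered_store_op and filter_fn are placeholder names, not RLlib API):

    from ray.rllib.execution.replay_ops import StoreToReplayBuffer


    def make_filtered_store_op(local_buffer, filter_fn):
        # Wraps the default StoreToReplayBuffer so that only the transitions
        # you consider "successful" ever reach the buffer. filter_fn takes a
        # SampleBatch and returns the (sub-)batch that should be kept.
        store = StoreToReplayBuffer(local_buffer=local_buffer)

        def op(batch):
            filtered = filter_fn(batch)
            if filtered.count > 0:
                store(filtered)
            return batch  # pass the original batch on through the plan

        return op

In your execution plan, you would then call rollouts.for_each(make_filtered_store_op(local_replay_buffer, my_filter)) where DQN's plan calls rollouts.for_each(StoreToReplayBuffer(local_buffer=local_replay_buffer)).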
Hope this helps. Let me know if not. 🙂