How to update/rebuild an algorithm to add new offline datasets?

Hi everyone! I was wondering what the proper way is to further train an RL algorithm on additional offline datasets.
Let’s say that I have 3 offline datasets (for simplicity each one is in one .json file):
```python
offline_ds = ['a.json', 'b.json', 'c.json']
```
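As far as I know, RLlib's JSON offline files hold one JSON-encoded batch per line, so a list of files is effectively one long stream of records. The stdlib-only helper `iter_batches` below is a hypothetical illustration of reading such line-delimited files in sequence. It is not RLlib's `JsonReader`, and the `{"rewards": [...]}` records are made-up placeholders:

```python
import json
import os
import tempfile
from typing import Iterator, List


def iter_batches(paths: List[str]) -> Iterator[dict]:
    """Hypothetical helper: yield one decoded record per non-empty line, file by file."""
    for path in paths:
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:  # skip blank lines
                    yield json.loads(line)


# Usage: write two tiny fake datasets and read them back as a single stream.
tmp_dir = tempfile.mkdtemp()
paths = []
for name, rewards in [("a.json", [1.0, 0.0]), ("b.json", [0.5])]:
    path = os.path.join(tmp_dir, name)
    with open(path, "w", encoding="utf-8") as f:
        for r in rewards:
            f.write(json.dumps({"rewards": [r]}) + "\n")
    paths.append(path)

records = list(iter_batches(paths))
print(len(records))  # 3
```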
Then I build and train an offline RL algorithm:

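```python
from ray.rllib.algorithms.marwil import MARWILConfig
from tqdm import tqdm

# `env` is an environment instance used only for its observation/action spaces.
config = (
    MARWILConfig()
    .environment(env=None, observation_space=env.observation_space, action_space=env.action_space)
    .offline_data(input_=offline_ds)
)
algo = config.build()
for _ in tqdm(range(20)):
    algo.train()
```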

Next, I get a new dataset, d.json, which I want to add to input_ before continuing training. However, if I try to do so, for example like this:
```python
algo.config.update_from_dict({'input_': ['a.json', 'b.json', 'c.json', 'd.json']})
```
I get the error `Cannot set attribute (input_) of an already frozen AlgorithmConfig!`.
So the question is: what is the proper way to update or rebuild an algorithm with new offline data for further training when the algorithm has already been built and trained?
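The error suggests the config is frozen once the algorithm is built, so later edits are rejected. The toy class below is a hypothetical stand-in, not RLlib's `AlgorithmConfig`. It reproduces that behavior and shows the usual workaround: edit an unfrozen copy and build a new object from it. In recent Ray 2.x releases, `AlgorithmConfig.copy(copy_frozen=False)` appears to play the role of `copy()` here. The trained weights would then need to be carried over to the new algorithm, for example via `get_weights()`/`set_weights()`. Treat both points as assumptions to check against your Ray version.

```python
import copy


class ToyConfig:
    """Hypothetical stand-in for a config that freezes itself on build()."""

    def __init__(self):
        self._frozen = False
        self.input_ = []

    def __setattr__(self, name, value):
        # Private bookkeeping fields are always allowed; public ones only before freezing.
        if getattr(self, "_frozen", False) and not name.startswith("_"):
            raise AttributeError(
                f"Cannot set attribute ({name}) of an already frozen config!"
            )
        super().__setattr__(name, value)

    def offline_data(self, input_):
        self.input_ = list(input_)
        return self  # builder-style chaining

    def copy(self, copy_frozen=True):
        new = copy.deepcopy(self)
        if not copy_frozen:
            new._frozen = False
        return new

    def build(self):
        self._frozen = True
        return {"inputs": list(self.input_)}  # stands in for the built algorithm


config = ToyConfig().offline_data(["a.json", "b.json", "c.json"])
algo = config.build()

try:
    config.input_ = ["a.json", "b.json", "c.json", "d.json"]
except AttributeError as err:
    print(err)  # Cannot set attribute (input_) of an already frozen config!

# Workaround pattern: edit an unfrozen copy, then build a new object from it.
new_config = config.copy(copy_frozen=False).offline_data(
    ["a.json", "b.json", "c.json", "d.json"]
)
new_algo = new_config.build()
print(new_algo["inputs"])  # ['a.json', 'b.json', 'c.json', 'd.json']
```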


I eventually resolved this issue by using a ReplayBuffer. The solution is to run the following after the code from the original post:

```python
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer

# Create a ReplayBuffer and fill it with every batch from all four files
replay_buffer = ReplayBuffer()
json_reader = JsonReader(['a.json', 'b.json', 'c.json', 'd.json'])
for sample_batch in json_reader.read_all_files():
    replay_buffer.add(sample_batch)

# Train the policy directly on minibatches sampled from the buffer
train_batch_size = 32
nb_train_batches = 50
policy = algo.get_policy()
for _ in range(nb_train_batches):
    policy.learn_on_batch(replay_buffer.sample(train_batch_size))
```
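Two details of this approach are worth keeping in mind. First, the buffer holds a fixed number of timesteps and evicts the oldest ones once it is full. Second, each `learn_on_batch` call sees a uniformly sampled minibatch rather than the files in order. As far as I know, RLlib's `ReplayBuffer` defaults to `capacity=10000` timesteps, so the earliest files of a large dataset may be silently dropped. The stdlib-only `ToyReplayBuffer` below is a hypothetical illustration of that FIFO-and-sample behavior, not RLlib's implementation:

```python
import random
from collections import deque
from typing import Dict, List


class ToyReplayBuffer:
    """Hypothetical FIFO buffer that stores individual timesteps."""

    def __init__(self, capacity: int = 10000, seed: int = 0):
        self._storage = deque(maxlen=capacity)  # oldest timesteps drop out first
        self._rng = random.Random(seed)

    def add(self, batch: Dict[str, List]) -> None:
        # Split a column-oriented batch into per-timestep rows.
        n = len(next(iter(batch.values())))
        for i in range(n):
            self._storage.append({k: v[i] for k, v in batch.items()})

    def sample(self, num_items: int) -> Dict[str, List]:
        # Uniform sampling with replacement, reassembled into columns.
        rows = [self._rng.choice(self._storage) for _ in range(num_items)]
        return {k: [row[k] for row in rows] for k in rows[0]}

    def __len__(self) -> int:
        return len(self._storage)


# Usage: four "files" of 5 timesteps each, but room for only 12 timesteps.
buffer = ToyReplayBuffer(capacity=12)
for file_id in ["a", "b", "c", "d"]:
    buffer.add({"obs": [f"{file_id}{t}" for t in range(5)], "rewards": [1.0] * 5})

print(len(buffer))  # 12: the 8 oldest timesteps (all of "a", part of "b") were evicted
train_batch = buffer.sample(4)
print(len(train_batch["obs"]))  # 4
```

If the combined datasets are larger than the default capacity, passing a larger `capacity` to `ReplayBuffer(...)` should keep the first files from being dropped.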

Here is another, more precise solution that swaps out the JsonReader directly and keeps training through `algo.train()`:

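```python
from ray.rllib.offline.json_reader import JsonReader
from tqdm import tqdm

new_offline_ds = ['a.json', 'b.json', 'c.json', 'd.json']

# Substitute the JsonReader wrapped by the local worker's input reader,
# reusing the old reader's IOContext
old_json_reader = algo.workers._local_worker.input_reader.child
new_json_reader = JsonReader(new_offline_ds, ioctx=old_json_reader.ioctx)
algo.workers._local_worker.input_reader.child = new_json_reader

# Continue training
for _ in tqdm(range(5)):
    algo.train()
```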
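This presumably works because the local worker reads through a wrapper whose `child` attribute is the JsonReader, and that wrapper looks up `child` on every read. The hypothetical `ToyFileReader`/`ToyInputWrapper` pair below consists of stand-ins, not RLlib classes, and shows the same delegation pattern in isolation: reassigning `child` changes what subsequent reads return without rebuilding anything else. Because the real solution touches internal attributes (`_local_worker`, `child`), it may break across Ray versions. It also only replaces the local worker's reader, so any remote workers configured to read input would presumably keep their old files.

```python
from typing import Dict, Iterator, List


class ToyFileReader:
    """Hypothetical reader that cycles endlessly over in-memory 'files'."""

    def __init__(self, files: Dict[str, List[str]]):
        self.files = files
        self._records = self._cycle()

    def _cycle(self) -> Iterator[str]:
        while True:
            for records in self.files.values():
                yield from records

    def next(self) -> str:
        return next(self._records)


class ToyInputWrapper:
    """Hypothetical wrapper that forwards every read to its `child` reader."""

    def __init__(self, child: ToyFileReader):
        self.child = child  # the attribute the forum solution reassigns

    def next(self) -> str:
        # `self.child` is looked up at call time, so swapping it redirects later reads.
        return self.child.next()


old_reader = ToyFileReader({"a.json": ["a0", "a1"]})
wrapper = ToyInputWrapper(old_reader)
before = [wrapper.next() for _ in range(3)]  # ['a0', 'a1', 'a0']

wrapper.child = ToyFileReader({"a.json": ["a0", "a1"], "d.json": ["d0"]})
after = [wrapper.next() for _ in range(3)]  # ['a0', 'a1', 'd0']
```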