Hi everyone! I was wondering what the proper way is to further train an offline RL algorithm on additional offline datasets.
Let’s say that I have 3 offline datasets (for simplicity, each one is in its own .json file):
offline_ds = ['a.json', 'b.json', 'c.json']
Then I build and train an offline RL algorithm:
config = (
    MARWILConfig()
    .environment(env=None, observation_space=env.observation_space, action_space=env.action_space)
    .offline_data(input_=offline_ds)
)
algo = config.build()
for _ in tqdm(range(20)):
    algo.train()
Next I get a new dataset d.json, which I want to add somehow to input_ and then continue training. However, if I try to do so (say, in the following way):
algo.config.update_from_dict({'input_': ['a.json', 'b.json', 'c.json', 'd.json']})
I get the error: Cannot set attribute (input_) of an already frozen AlgorithmConfig!
Thus the question: what is the proper way to “update/rebuild” an algorithm with new offline data for further training, when the algorithm has already been trained (and thus already built)?
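For context, the error comes from RLlib freezing the config once it has been used to build an algorithm. A toy pure-Python sketch of that freeze pattern (this is an illustration, not RLlib's actual `AlgorithmConfig` implementation) reproduces the behavior:

```python
class FrozenableConfig:
    """Toy stand-in for a config object that freezes on build."""

    def __init__(self):
        # Write through __dict__ to avoid triggering __setattr__.
        self.__dict__["_frozen"] = False
        self.__dict__["input_"] = []

    def __setattr__(self, key, value):
        if self._frozen:
            raise AttributeError(
                f"Cannot set attribute ({key}) of an already frozen config!"
            )
        self.__dict__[key] = value

    def freeze(self):
        self.__dict__["_frozen"] = True


cfg = FrozenableConfig()
cfg.input_ = ["a.json", "b.json", "c.json"]
cfg.freeze()  # roughly what config.build() does internally
try:
    cfg.input_ = ["a.json", "b.json", "c.json", "d.json"]
except AttributeError as e:
    print(e)  # Cannot set attribute (input_) of an already frozen config!
```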
I eventually resolved this issue by using a ReplayBuffer. The solution is to do the following after the block of code from the OP:
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
# Create and fill the ReplayBuffer
replay_buffer = ReplayBuffer()
json_reader = JsonReader(['a.json', 'b.json', 'c.json', 'd.json'])
all_sample_batches = list(json_reader.read_all_files())
for sample_batch in all_sample_batches:
    replay_buffer.add(sample_batch)
# Train the policy directly
train_batch_size = 32
nb_train_batches = 50
policy = algo.get_policy()
for _ in range(nb_train_batches):
    policy.learn_on_batch(replay_buffer.sample(train_batch_size))
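The add-then-sample pattern above, stripped of the RLlib classes, looks like this (a toy sketch with a hypothetical `SimpleBuffer`; RLlib's `ReplayBuffer` additionally handles SampleBatch objects and capacity bookkeeping for you):

```python
import random


class SimpleBuffer:
    """Toy FIFO replay buffer: stores items, samples uniformly at random."""

    def __init__(self, capacity=10_000):
        self.capacity = capacity
        self.storage = []

    def add(self, item):
        if len(self.storage) >= self.capacity:
            self.storage.pop(0)  # evict the oldest item
        self.storage.append(item)

    def sample(self, n):
        # Sample with replacement, so n may exceed the buffer size.
        return [random.choice(self.storage) for _ in range(n)]


buf = SimpleBuffer()
for transition in range(100):  # stand-in for SampleBatch objects
    buf.add(transition)
batch = buf.sample(32)
print(len(batch))  # 32
```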
Here is another, more faithful solution that substitutes the JsonReader directly, so the algorithm keeps using its normal training loop:
new_offline_ds = ['a.json', 'b.json', 'c.json', 'd.json']
# Substitute the JsonReader on the local worker
old_json_reader = algo.workers._local_worker.input_reader.child
new_json_reader = JsonReader(new_offline_ds, ioctx=old_json_reader.ioctx)
algo.workers._local_worker.input_reader.child = new_json_reader
# Continue training
for _ in tqdm(range(5)):
    algo.train()