Hi, could someone suggest how to maintain the same ordering of data between sampling on each worker (`on_sample_end`) and the learning step (`on_learn_on_batch`), especially for LSTM networks? Below is an example for PPO:
import datetime

from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch

class MyCallbacks(DefaultCallbacks):
    def on_sample_end(self, *, worker: RolloutWorker, samples: SampleBatch, **kwargs):
        print("worker index {}: sample batch of size: {}".format(worker.worker_index, samples.count))
        # The last column of each observation is a Unix timestamp.
        print("worker index {}: list_of_dates: {}".format(worker.worker_index, [datetime.datetime.fromtimestamp(dt) for dt in samples["obs"][:, -1]]))

    def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs):
        print("policy.learn_on_batch(): list_of_dates: {}".format([datetime.datetime.fromtimestamp(dt) for dt in train_batch["obs"][:, -1]]))
Here is the sample output:
(pid=3779) worker index1: sample batch of size:10
(pid=3779) worker index1: list_of_dates:[datetime.datetime(2019, 2, 23, 23, 30, 8), datetime.datetime(2019, 2, 24, 0, 0), datetime.datetime(2019, 2, 24, 0, 29, 52), datetime.datetime(2019, 2, 24, 0, 59, 44), datetime.datetime(2019, 2, 24, 1, 29, 36), datetime.datetime(2019, 2, 24, 1, 59, 28), datetime.datetime(2019, 2, 24, 2, 29, 20), datetime.datetime(2019, 2, 24, 2, 59, 12), datetime.datetime(2019, 2, 24, 3, 29, 4), datetime.datetime(2019, 2, 24, 3, 58, 56)]
(pid=3780) worker index2: sample batch of size:10
(pid=3780) worker index2: list_of_dates:[datetime.datetime(2019, 2, 23, 23, 30, 8), datetime.datetime(2019, 2, 24, 0, 0), datetime.datetime(2019, 2, 24, 0, 29, 52), datetime.datetime(2019, 2, 24, 0, 59, 44), datetime.datetime(2019, 2, 24, 1, 29, 36), datetime.datetime(2019, 2, 24, 1, 59, 28), datetime.datetime(2019, 2, 24, 2, 29, 20), datetime.datetime(2019, 2, 24, 2, 59, 12), datetime.datetime(2019, 2, 24, 3, 29, 4), datetime.datetime(2019, 2, 24, 3, 58, 56)]
(pid=3782) policy.learn_on_batch(): list_of_dates:[datetime.datetime(2019, 2, 24, 0, 59, 44), datetime.datetime(2019, 2, 24, 0, 29, 52), datetime.datetime(2019, 2, 24, 3, 58, 56), datetime.datetime(2019, 2, 24, 3, 58, 56), datetime.datetime(2019, 2, 24, 1, 29, 36), datetime.datetime(2019, 2, 24, 0, 29, 52), datetime.datetime(2019, 2, 24, 2, 59, 12), datetime.datetime(2019, 2, 23, 23, 30, 8), datetime.datetime(2019, 2, 24, 1, 29, 36), datetime.datetime(2019, 2, 24, 2, 29, 20)]
(pid=3782) policy.learn_on_batch(): list_of_dates:[datetime.datetime(2019, 2, 23, 23, 30, 8), datetime.datetime(2019, 2, 24, 2, 59, 12), datetime.datetime(2019, 2, 24, 1, 59, 28), datetime.datetime(2019, 2, 24, 0, 0), datetime.datetime(2019, 2, 24, 3, 29, 4), datetime.datetime(2019, 2, 24, 1, 59, 28), datetime.datetime(2019, 2, 24, 0, 0), datetime.datetime(2019, 2, 24, 0, 59, 44), datetime.datetime(2019, 2, 24, 3, 29, 4), datetime.datetime(2019, 2, 24, 2, 29, 20)]
As you can see, workers 1 and 2 both print the dates in sampling order, whereas in policy.learn_on_batch() the date list is shuffled. (With a train batch of 20 timesteps and sgd_minibatch_size = 10, the two learn_on_batch prints correspond to the two SGD minibatches.)
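If I understand RLlib's recurrent handling correctly, a recurrent train batch also carries a seq_lens field, and ordering only needs to hold within each chunk of max_seq_len timesteps, even if the chunks themselves are reordered. Here is a minimal sketch to check that (assuming seq_lens is actually present in the batch the callback receives):

def on_learn_on_batch(self, *, policy, train_batch, result, **kwargs):
    # seq_lens should only appear for recurrent models (use_lstm=True).
    # Whole sequences may be shuffled, but timestamps inside each chunk
    # should stay in sampling order.
    if "seq_lens" in train_batch:
        print("seq_lens:", train_batch["seq_lens"])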
Here are the config parameters used:
config['num_workers'] = 2
config['num_envs_per_worker'] = 1
config['num_cpus_for_driver'] = 1
config['rollout_fragment_length'] = 10
# 10 * 2 * 1 = 20 timesteps per train batch
config['train_batch_size'] = config['rollout_fragment_length'] * config['num_workers'] * config['num_envs_per_worker']
# 10 timesteps per SGD minibatch, i.e. two minibatches per train batch
config['sgd_minibatch_size'] = config['rollout_fragment_length']
config['num_sgd_iter'] = 1
config['batch_mode'] = 'truncate_episodes'
config['shuffle_sequences'] = False
Even setting the parameter shuffle_sequences to False doesn't do the trick…
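One more thing I plan to try, on the (unconfirmed) assumption that the reordering comes from the multi-GPU optimizer path rather than from shuffle_sequences itself:

# Assumption: the default multi-GPU optimizer may reorder minibatches on its
# own, so forcing the simple optimizer might preserve the sampling order.
config['simple_optimizer'] = True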