Hi @Chen_Shen, one thing I want to note here is that I define a `Buffer` class which is an actor:
```python
buffer = ray.remote(ReplayBufferwithQueue).options(
    name=f"Buffer",
    num_cpus=2,
    max_concurrency=10,
    num_gpus=0,
).remote(
    scheme=scheme,
    groups=groups,
    buffer_size=args.buffer_size,
    max_seq_length=min(env_info["episode_limit"], args.mp_episode_truncate_len) + 1,
    preprocess=preprocess,
    device="cpu" if args.buffer_cpu_only else args.device,
    queue=queue,
    args=args,
)

assert ray.get(buffer.ready.remote())
ray_ws = [buffer.run.remote()]
```
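(Since I keep the object ref in `ray_ws`, a quick check like the hypothetical snippet below, which is not part of my actual code, should surface any exception raised inside `run`.)

```python
# Hypothetical liveness check (not in my real code): if run() already returned
# or raised, ray.wait() reports the ref as ready and ray.get() re-raises the error.
ready, _ = ray.wait(ray_ws, timeout=0)
if ready:
    ray.get(ready)  # surfaces the exception from run(), if any
```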
Here I define the `run` function in this way:
```python
import logging
import time
import traceback


class ReplayBufferwithQueue(ReplayBuffer):
    def __init__(self, scheme, groups, buffer_size, max_seq_length,
                 preprocess=None, device="cpu", queue=None, buffer_queue=None, args=None):
        super().__init__(scheme, groups, buffer_size, max_seq_length, preprocess, device, args=args)
        logging.basicConfig(level="INFO")
        self.queue = queue
        self.buffer_queue = buffer_queue

    def run(self):
        count = 0
        try:
            while True:
                # If the queue is empty, back off briefly before blocking on get().
                if self.queue.qsize() == 0:
                    time.sleep(1)
                st = time.time()
                episode_batch = self.queue.get()
                self.insert_episode_batch(episode_batch)
                logging.debug(f"Got new episode data, now count: {count}, time cost: {(time.time() - st):.2f}")
                count += 1
        except Exception as e:
            logging.error(f"Error in buffer run loop: {e}")
            traceback.print_exc()
```
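As a side note on the loop itself: since `get()` already blocks, the `qsize() == 0` check plus `time.sleep(1)` can add up to a second of extra latency per insert when the queue happens to be empty. Assuming `queue` is a `ray.util.queue.Queue` (that part is an assumption on my side), a blocking `get` with a timeout would be a simpler sketch:

```python
# Sketch only, assuming `queue` is a ray.util.queue.Queue; for other queue types
# the Empty exception comes from a different module.
from ray.util.queue import Empty

def run(self):
    count = 0
    while True:
        try:
            # Block directly on get() with a timeout instead of polling qsize().
            episode_batch = self.queue.get(block=True, timeout=1)
        except Empty:
            continue
        st = time.time()
        self.insert_episode_batch(episode_batch)
        logging.debug(f"Inserted episode {count}, insert time: {(time.time() - st):.2f}")
        count += 1
```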
In the main process, I use `ray.get(buffer.sample.remote(batch_size))` to sample the data. I think the reason `ray.get` is slow here is that the `buffer` actor is busy running `run` in the background, so it has to switch from the `run` method to the `sample` method before it can serve the request, and that is why it takes much longer to get the data. Is that correct?
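To make my hypothesis concrete, here is a minimal, self-contained sketch (a hypothetical `BusyActor`, not my real code) of the experiment I have in mind: a threaded actor whose background `run` loop keeps one thread busy while the driver times a trivial second method. My understanding is that with `max_concurrency > 1` the two methods run on separate threads, but they still share the Python GIL, so a CPU-heavy `run` loop can still slow down `sample`.

```python
# Hypothetical repro sketch, not my real code.
import time
import ray

ray.init()

@ray.remote(max_concurrency=2)
class BusyActor:
    def __init__(self):
        self.running = True

    def run(self):
        # Background loop that keeps one actor thread busy (holds the GIL most of the time).
        while self.running:
            sum(i * i for i in range(100_000))

    def sample(self):
        # Trivial method whose latency is measured from the driver.
        return 42

    def stop(self):
        self.running = False

actor = BusyActor.remote()
actor.run.remote()   # start the background loop
time.sleep(1)        # give it time to start spinning

st = time.time()
ray.get(actor.sample.remote())
print(f"sample() latency while run() is busy: {time.time() - st:.3f}s")

actor.stop.remote()
```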