Yes, I found that the part that actually blocks is train.report.
Because of the uneven inputs, some workers finish their dataloader loop early, but train.report must be called the same number of times by every worker, so the remaining workers hang.
How can I handle this with ray.train?
```python
with model.join():  # torch DDP context manager for uneven inputs
    iter_idx = 0
    for batch in dataloader:
        loss = model(batch).sum()
        iter_idx += 1
        if iter_idx % report_num == 0:
            train.report(metrics)  # blocks until every worker calls it
    ...
```
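
One workaround I've been sketching (not sure it's the intended ray.train approach, so treat it as an assumption): track how many times each worker has reported, then after the loop have the early-finishing workers pad with filler train.report calls until everyone matches the worker that reported the most. The helper name `report_in_lockstep` and the `reports_made` counter are hypothetical; only train.report and torch.distributed.all_reduce are real APIs, and this assumes torch.distributed is already initialized on the workers (as it is under TorchTrainer):

```python
import torch
import torch.distributed as dist
from ray import train

# Hypothetical helper: pad early-finishing workers with extra
# train.report calls so every worker reports the same number of times.
# Assumes the gloo backend (CPU tensors); with nccl, move the tensor to GPU.
def report_in_lockstep(local_report_count: int, metrics: dict) -> None:
    counts = torch.tensor([local_report_count])
    dist.all_reduce(counts, op=dist.ReduceOp.MAX)  # max reports across workers
    for _ in range(counts.item() - local_report_count):
        train.report(metrics)  # filler reports to keep workers in sync

# Called once on every worker after the `with model.join():` block:
# report_in_lockstep(reports_made, metrics)
```

Would something like this work, or does ray.train have a built-in way to handle uneven report counts?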