End-to-end training of RAG retriever with Ray

I need some help implementing an end-to-end retriever training feature for RAG with Ray.

How can I run document encoding and indexing with an updated doc encoder (the context encoder network that is kept frozen in the original RAG) using a Ray actor that is separate from the main training process?

How can I access the document index inside the Ray actors during training, in case I want to update the index, say, every 5000 steps?
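The non-blocking pattern asked about here (encoding on a worker separate from the training loop) can be sketched without Ray. This is a pure-Python stand-in under stated assumptions: a single background thread from `concurrent.futures` plays the role of the dedicated Ray actor, and `EmbeddingWorker` and `make_encoder` are hypothetical names wrapping a toy encoder. The training loop submits the job, keeps going, and only blocks when it actually needs the new embeddings; with Ray, the analogous steps would be calling an actor method with `.remote()` and collecting the result with `ray.get`.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, List

Vector = List[float]
Encoder = Callable[[str], Vector]


class EmbeddingWorker:
    """Hypothetical stand-in for a dedicated embedding actor (no Ray here).

    A single background thread re-encodes the corpus so the caller
    (the training loop) is not blocked while it runs.
    """

    def __init__(self, docs: List[str]) -> None:
        self.docs = docs
        self._pool = ThreadPoolExecutor(max_workers=1)

    def _encode_all(self, encoder: Encoder) -> List[Vector]:
        return [encoder(doc) for doc in self.docs]

    def encode_async(self, encoder: Encoder) -> Future:
        # Returns immediately; the work runs on the background thread.
        return self._pool.submit(self._encode_all, encoder)

    def shutdown(self) -> None:
        self._pool.shutdown(wait=True)


def make_encoder(scale: float) -> Encoder:
    # Toy "context encoder" that captures a fixed copy of the parameter,
    # so later updates in the training loop cannot change a running job.
    return lambda doc: [scale * len(doc), float(doc.count(" "))]


docs = ["ray actors", "dense passage retrieval", "faiss"]
worker = EmbeddingWorker(docs)
future = worker.encode_async(make_encoder(scale=2.0))
# ... the training loop would keep running here ...
embeddings = future.result()  # block only when the new embeddings are needed
worker.shutdown()
print(embeddings)  # [[20.0, 1.0], [46.0, 2.0], [10.0, 0.0]]
```

Handing the worker a frozen snapshot of the encoder matters because the main process keeps updating the real encoder while the job runs.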

Hey @Shamane_Siriwardhana, thanks a bunch for trying out RAG on Ray! I saw your GitHub issue and will take a closer look at it later today.

Can you explain a bit more about what you’re trying to do? Do you need to train the document encoder in a distributed setting? Currently there is no gradient flow in the Ray actors; they are only used to access documents, so it might take some work to include training in this process as well.

If you already have an example of this in a non-distributed setting, I’d be happy to help with getting it working on Ray. In fact, in the current Ray implementation, if only 1 training worker is used, the index is loaded into the same process as the worker and not on a remote process (see transformers/distributed_ray_retriever.py at master · huggingface/transformers · GitHub).

Hi, thanks a lot for your quick reply.

What I’m proposing does not need any gradient flow to the Ray actors. We only need to run a Ray actor that uses the updated context encoder to recalculate the embeddings for the document set every 1000 steps.

The context encoder will be updated in the main process. This can be done simply by plugging the context encoder into the doc-score calculation, which I have already done. Now I need a way to periodically use that updated context encoder to recompute the embeddings and re-initialize the indexed dataset with them.
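To make "recompute the embeddings and re-initialize the index" concrete, here is a minimal pure-Python sketch. It assumes exact maximum-inner-product search as a stand-in for a flat inner-product FAISS index, and `FlatIPIndex`, `rebuild_index`, and `make_encoder` are hypothetical names with a toy encoder. Rebuilding simply means re-encoding every document with the current encoder weights and replacing the stored vectors, after which the same query can retrieve a different document.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


class FlatIPIndex:
    """Exact maximum-inner-product search over all document embeddings."""

    def __init__(self, embeddings: List[Vector]) -> None:
        self.embeddings = embeddings

    def search(self, query: Vector, k: int) -> List[Tuple[int, float]]:
        # Score every document, then keep the k highest (doc_id, score) pairs.
        scored = [(i, dot(query, e)) for i, e in enumerate(self.embeddings)]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored[:k]


def make_encoder(w_a: float, w_e: float) -> Callable[[str], Vector]:
    # Toy context encoder: weighted counts of the letters "a" and "e".
    return lambda doc: [w_a * doc.count("a"), w_e * doc.count("e")]


def rebuild_index(docs: List[str], encoder: Callable[[str], Vector]) -> FlatIPIndex:
    # Recompute every document embedding with the (updated) context encoder.
    return FlatIPIndex([encoder(d) for d in docs])


docs = ["ray actor", "faiss index", "context encoder"]
query = [1.0, 1.0]
old_index = rebuild_index(docs, make_encoder(1.0, 0.0))
new_index = rebuild_index(docs, make_encoder(0.0, 1.0))
print(old_index.search(query, k=1))  # [(0, 2.0)]
print(new_index.search(query, k=1))  # [(2, 3.0)]
```

With a real FAISS-backed dataset, the rebuild step is the expensive part, which is why running it away from the training process is attractive.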

My idea is to first initialize another actor that works with the updated context encoder, similar to what you have done here.

If I’m understanding correctly, would you be able to just pass the updated context encoder as an argument into the retrieve call (transformers/distributed_ray_retriever.py at master · huggingface/transformers · GitHub) and ultimately to the Ray actors? Then the actors can update their index and return the relevant documents.
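One way to read this suggestion, sketched in pure Python without Ray: `RetrievalWorker`, `encode`, and the `version` argument are hypothetical and not part of the transformers retriever API. The caller passes the current encoder weights along with a retrieve call, and the worker re-encodes its corpus only when it sees a newer version number, so repeated calls with the same weights do not trigger repeated re-encoding.

```python
from typing import List, Optional, Tuple

Vector = List[float]
Weights = Tuple[float, float]


def encode(doc: str, w: Weights) -> Vector:
    # Toy context encoder: weighted counts of the letters "a" and "e".
    return [w[0] * doc.count("a"), w[1] * doc.count("e")]


class RetrievalWorker:
    """Hypothetical retrieval worker that refreshes its own index on demand."""

    def __init__(self, docs: List[str], weights: Weights) -> None:
        self.docs = docs
        self.version = 0
        self.reencode_count = 0
        self._reindex(weights)

    def _reindex(self, weights: Weights) -> None:
        self.embeddings = [encode(d, weights) for d in self.docs]
        self.reencode_count += 1

    def retrieve(self, query: Vector, k: int,
                 weights: Optional[Weights] = None, version: int = 0) -> List[str]:
        # Re-encode only if the caller sent newer encoder weights.
        if weights is not None and version > self.version:
            self._reindex(weights)
            self.version = version
        scores = [sum(q * e for q, e in zip(query, emb)) for emb in self.embeddings]
        ranked = sorted(range(len(self.docs)), key=lambda i: scores[i], reverse=True)
        return [self.docs[i] for i in ranked[:k]]


worker = RetrievalWorker(["ray actor", "faiss index", "context encoder"], (1.0, 0.0))
q = [1.0, 1.0]
first = worker.retrieve(q, 1)                                   # ['ray actor']
second = worker.retrieve(q, 1, weights=(0.0, 1.0), version=1)   # ['context encoder']
third = worker.retrieve(q, 1, weights=(0.0, 1.0), version=1)    # no re-encoding
print(first, second, third, worker.reencode_count)
```

If this were moved onto Ray actors, sending large encoder weights with every call could be costly; sending them only when the version changes (or placing them in the object store once with `ray.put`) may be cheaper.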

Do you perhaps have some pseudocode of what this would look like in a single process? That might help me understand this better. Disclaimer: I have not read through the REALM paper yet.

Exactly, this is what I want.

Besides the usual document retrieval actors, we have to assign an actor that specifically calculates document embeddings with the updated context encoder. I am thinking of swapping in the updated context encoder every N steps, let’s say 5000. Then maybe every 7500 steps we reload the dataset with the FAISS index.

from typing import Dict

def training_step(self, batch, batch_idx) -> Dict:
    if batch_idx != 0 and batch_idx % 5000 == 0:
        # Call the Ray actor that uses the updated context encoder
        # to recalculate the document embeddings.
        ...

    if batch_idx != 0 and batch_idx % 7500 == 0:
        # Reload the index with the usual Ray retriever initialization.
        self.model.rag.retriever.init_retrieval()

    # ... usual loss computation and return value ...

P.S. I am using two different step intervals for computing the embeddings and loading the index in order to reduce latency.
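The effect of using two intervals can be checked with a small schedule simulation. This is a sketch under assumptions: `refresh_schedule` is a hypothetical helper, step 0 is skipped to match the `batch_idx != 0` guard in the pseudocode, and an embedding run started at step s is assumed to be finished by any later reload but not by a reload in the same step, so each reload loads the most recent run started at an earlier step.

```python
from typing import List, Optional, Tuple

Event = Tuple[int, str, Optional[int]]


def refresh_schedule(num_steps: int, embed_every: int = 5000,
                     reload_every: int = 7500) -> List[Event]:
    """List (step, action, source) events; for "reload", source is the
    step whose embeddings get loaded (None if no run has started yet)."""
    events: List[Event] = []
    last_embed_step: Optional[int] = None
    for step in range(1, num_steps + 1):
        # Reload first: an embedding run started this very step is not ready.
        if step % reload_every == 0:
            events.append((step, "reload", last_embed_step))
        if step % embed_every == 0:
            events.append((step, "embed", None))
            last_embed_step = step
    return events


for event in refresh_schedule(15000):
    print(event)
# (5000, 'embed', None)
# (7500, 'reload', 5000)
# (10000, 'embed', None)
# (15000, 'reload', 10000)
# (15000, 'embed', None)
```

Under these assumptions, the loaded embeddings lag the encoder by 2500 or 5000 steps at each reload, and some runs are never used: the embeddings started at step 15000 are superseded because the next reload, at step 22500, picks up the step-20000 run. If the offset is only meant to give the embedding job time to finish, reloading a fixed number of steps after each embedding run would avoid the wasted runs.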

Hey @Shamane_Siriwardhana, were you able to get this working?

Hey @amogkam, yes, I have implemented it, but I would love to get a review: I don’t know for sure whether it works, since I haven’t tested it with very large-scale training.

Awesome, that’s great to hear!!! If you can make a PR on the Hugging Face repo, I can take a look and try it out :slight_smile:

OK, I was waiting for your response. Will do it soon :slight_smile:

Hi @amogkam ,

Please refer to the following pull request: rag with end-to-end retriever training with ray by shamanez · Pull Request #10410 · huggingface/transformers · GitHub.

Please check the comments addressed to @amogkam in the code.

Thanks, I will take a look later today!


Sorry for spamming... are there any updates? This would be a big help for me.