RuntimeError with ray.train.torch.prepare_model

Hello,

When using ray.train.torch.prepare_model on a BERT model from the transformers library, I came across this error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.LongTensor [1, 100]] is at version 4; expected version 3 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

With torch.autograd.set_detect_anomaly(True) enabled, this is the full stack trace; the failure appears to happen during the backward pass:

(RayTrainWorker pid=17974) .../lib/python3.9/site-packages/torch/autograd/__init__.py:200: UserWarning: Error detected in EmbeddingBackward0. Traceback of forward call that caused the error:
(RayTrainWorker pid=17974)   File ".../.pyenv/versions/3.9.16/lib/python3.9/threading.py", line 937, in _bootstrap
(RayTrainWorker pid=17974)     self._bootstrap_inner()
(RayTrainWorker pid=17974)   File ".../.pyenv/versions/3.9.16/lib/python3.9/threading.py", line 980, in _bootstrap_inner
(RayTrainWorker pid=17974)     self.run()
(RayTrainWorker pid=17974)   File ".../lib/python3.9/site-packages/ray/air/_internal/util.py", line 88, in run
(RayTrainWorker pid=17974)     self._ret = self._target(*self._args, **self._kwargs)
(RayTrainWorker pid=17974)   File ".../lib/python3.9/site-packages/ray/train/_internal/utils.py", line 138, in train_fn
(RayTrainWorker pid=17974)     return wrapped_train_func(config)
(RayTrainWorker pid=17974)   File ".../lib/python3.9/site-packages/ray/train/_internal/utils.py", line 118, in discard_return_wrapper
(RayTrainWorker pid=17974)     train_func(*args, **kwargs)
(RayTrainWorker pid=17974)   File ".../github/negotiatus/data-sandbox/science/ml_infrastructure/ray_mlflow/ray/error.py", line 152, in train
(RayTrainWorker pid=17974)     embedding_1 = model(**model_inputs_1).pooler_output
(RayTrainWorker pid=17974)   File ".../lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
(RayTrainWorker pid=17974)     return forward_call(*args, **kwargs)
(RayTrainWorker pid=17974)   File ".../lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
(RayTrainWorker pid=17974)     output = self._run_ddp_forward(*inputs, **kwargs)
(RayTrainWorker pid=17974)   File ".../lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1113, in _run_ddp_forward
(RayTrainWorker pid=17974)     return module_to_run(*inputs, **kwargs)
(RayTrainWorker pid=17974)   File ".../lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
(RayTrainWorker pid=17974)     return forward_call(*args, **kwargs)
(RayTrainWorker pid=17974)   File ".../lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 1006, in forward
(RayTrainWorker pid=17974)     embedding_output = self.embeddings(
(RayTrainWorker pid=17974)   File ".../lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
(RayTrainWorker pid=17974)     return forward_call(*args, **kwargs)
(RayTrainWorker pid=17974)   File ".../lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 237, in forward
(RayTrainWorker pid=17974)     position_embeddings = self.position_embeddings(position_ids)
(RayTrainWorker pid=17974)   File ".../lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
(RayTrainWorker pid=17974)     return forward_call(*args, **kwargs)
(RayTrainWorker pid=17974)   File ".../lib/python3.9/site-packages/torch/nn/modules/sparse.py", line 162, in forward
(RayTrainWorker pid=17974)     return F.embedding(
(RayTrainWorker pid=17974)   File ".../lib/python3.9/site-packages/torch/nn/functional.py", line 2210, in embedding
(RayTrainWorker pid=17974)     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
(RayTrainWorker pid=17974)  (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:119.)
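
From the anomaly trace, the failing op is EmbeddingBackward0 for position_embeddings, and the [torch.LongTensor [1, 100]] in the error matches max_length=100, so it looks like the position_ids tensor that BertEmbeddings slices from its registered buffer. My guess (which may be wrong) is that the DDP wrapper added by prepare_model broadcasts buffers at the start of each forward call, so calling the model twice before backward() changes that buffer's version counter in between. Below is a small diagnostic sketch that could confirm this from inside the training loop; the attribute path embeddings.position_ids and the use of the internal _version counter are assumptions on my part, not something I have verified:

# Diagnostic sketch only, not part of the repro further down: check whether the
# suspect buffer's autograd version counter changes between the two forward passes.
# `embeddings.position_ids` is assumed from the traceback.
base_model = model.module if hasattr(model, "module") else model
position_ids_buffer = base_model.embeddings.position_ids

print("version before forwards:", position_ids_buffer._version)
embedding_1 = model(**model_inputs_1).pooler_output
print("version after first forward:", position_ids_buffer._version)
embedding_2 = model(**model_inputs_2).pooler_output
print("version after second forward:", position_ids_buffer._version)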

This is the code to reproduce the issue. If you comment out model = ray.train.torch.prepare_model(model=model), it runs fine. I may be doing something really dumb here, but I thought I would raise it.
Package versions (the issue also occurs with older transformers versions and newer torch versions):

transformers==4.37.0
ray==2.9.0
torch==2.0.1
pandas==2.0.3

import ray
import ray.train.torch
import torch
import pandas as pd
from ray import data
from ray.train import ScalingConfig
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

texts = {
    "text1": [
        "This is an example text I would like to embed with a NLP model.",
        "This is also an example text, but different, I would like to embed with a NLP model.",
    ],
    "text2": [
        "This is an example text I would like to embed with a NLP model. Another sentence here.",
        "This is also an example text, but different, I would like to embed with a NLP model.",
    ],
    "label": [-1, 1],
}

df_texts = pd.DataFrame(texts)

ray_data = data.from_pandas(df_texts)


def tokenize_batch(batch: pd.DataFrame, tokenizer: BertTokenizer, max_length: int):
    encoding1_output = {}
    encoding2_output = {}
    labels = []

    for _, row in batch.iterrows():
        text1, text2, label = (
            row["text1"],
            row["text2"],
            row["label"],
        )
        encoding1 = tokenizer(
            text1,
            add_special_tokens=True,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )
        encoding2 = tokenizer(
            text2,
            add_special_tokens=True,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )

        # Both encodings have the same keys
        if len(encoding1_output) == 0:
            for key in encoding1.keys():
                encoding1_output[key] = []
                encoding1_output[key].append(encoding1[key].squeeze())
                encoding2_output[key] = []
                encoding2_output[key].append(encoding2[key].squeeze())
        else:
            for key in encoding1.keys():
                encoding1_output[key].append(encoding1[key].squeeze())
                encoding2_output[key].append(encoding2[key].squeeze())

        labels.append(torch.tensor(label))

    for key in encoding1_output.keys():
        encoding1_output[key] = torch.stack(encoding1_output[key])

    for key in encoding2_output.keys():
        encoding2_output[key] = torch.stack(encoding2_output[key])

    labels = torch.stack(labels)

    # One valid output for map_batches is a dictionary with numpy arrays as the values
    combined_output = {}
    for key in encoding1_output.keys():
        combined_output[f"encoding_1_{key}"] = encoding1_output[key].numpy()
        combined_output[f"encoding_2_{key}"] = encoding2_output[key].numpy()
    combined_output["labels"] = labels.numpy()

    return combined_output


train_data_tokenized = ray_data.map_batches(
    fn=tokenize_batch,
    fn_kwargs={"tokenizer": tokenizer, "max_length": 100},
    batch_format="pandas",
)

# Datasets keyed by name
datasets = {"train": train_data_tokenized}

scaling_config = ScalingConfig(num_workers=2, use_gpu=False)
train_loop_config = {
    "batch_size": 32,
    "num_epochs": 1,
    "lr": 1e-4,
    "model": model,
    "tokenizer": tokenizer,
}


def train(config):
    # Configurations
    lr = config["lr"]
    batch_size = config["batch_size"]
    num_epochs = config["num_epochs"]
    model = config["model"]

    # PyTorch objects
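    # Commenting out this prepare_model call makes the script run without the error.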
    model = ray.train.torch.prepare_model(model=model)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
    optimizer = ray.train.torch.prepare_optimizer(optimizer=optimizer)
    criterion = torch.nn.CosineEmbeddingLoss()

    # Define a custom collate function to process the dictionary returned from tokenize_batch
    def iter_collate_fn(rows: dict):

        # Convert from numpy arrays to PyTorch tensors
        model_inputs_1 = {
            "input_ids": torch.tensor(rows["encoding_1_input_ids"]),
            "token_type_ids": torch.tensor(rows["encoding_1_token_type_ids"]),
            "attention_mask": torch.tensor(rows["encoding_1_attention_mask"]),
        }
        model_inputs_2 = {
            "input_ids": torch.tensor(rows["encoding_2_input_ids"]),
            "token_type_ids": torch.tensor(rows["encoding_2_token_type_ids"]),
            "attention_mask": torch.tensor(rows["encoding_2_attention_mask"]),
        }
        labels = torch.tensor(rows["labels"])

        return model_inputs_1, model_inputs_2, labels

    with torch.autograd.set_detect_anomaly(True):
        for epoch in range(num_epochs):
            # Data
            train_data_shard = ray.train.get_dataset_shard("train")
            for (
                model_inputs_1,
                model_inputs_2,
                labels,
            ) in train_data_shard.iter_torch_batches(
                batch_size=batch_size, collate_fn=iter_collate_fn
            ):
                optimizer.zero_grad()

                embedding_1 = model(**model_inputs_1).pooler_output
                embedding_2 = model(**model_inputs_2).pooler_output
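                # The anomaly trace points at these two forward passes through the
                # DDP-wrapped model (error.py line 152 in the trace is the embedding_1
                # call above).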
                loss = criterion(embedding_1, embedding_2, labels)
                loss.backward()
                optimizer.step()

            metrics = {"loss": loss.item(), "epoch": epoch}
            print(metrics)


trainer = ray.train.torch.TorchTrainer(
    train_loop_per_worker=train,
    scaling_config=scaling_config,
    datasets=datasets,
    train_loop_config=train_loop_config,
)

trainer.fit()
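
For completeness, a possible mitigation I have not actually verified: if the in-place change really does come from DDP's buffer broadcasting, prepare_model appears to accept parallel_strategy_kwargs that are forwarded to DistributedDataParallel, so disabling broadcast_buffers might sidestep it. This is speculation on my part, not a confirmed fix:

# Speculative sketch, not verified: pass broadcast_buffers=False through to the
# DistributedDataParallel wrapper that prepare_model creates.
model = ray.train.torch.prepare_model(
    model=model,
    parallel_strategy_kwargs={"broadcast_buffers": False},
)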