Resuming training from big models in Ray Train leads to `grpc` error

How severely does this issue affect your experience of using Ray?

  • Medium: It contributes to significant difficulty in completing my task, but I can work around it.

Hello,

I have the following use case:

  • I start a training run using the TorchTrainer and save regular checkpoints (file size > 400 MB, since it is a big model).
  • I want to be able to continue from one of these checkpoints, i.e. load it in a completely fresh run.

This is using the new Ray 2.0.0 TorchTrainer API. With the old ray-train Trainer this was not an issue.

I tried to apply what is described here to my code:
https://docs.ray.io/en/latest/train/dl_guide.html#loading-checkpoints

However, in my code I get the following error:

grpc._channel._InactiveRpcError: <_InactiveRpcError of RPC that terminated with:
	status = StatusCode.RESOURCE_EXHAUSTED
	details = "Received message larger than max (500029835 vs. 262144000)"
	debug_error_string = "{"created":"@1663946219.205844259","description":"Error received from peer ipv4:xxx.xxx.xxx.xxx:xxxxx","file":"src/core/lib/surface/call.cc","file_line":1074,"grpc_message":"Received message larger than max (500029835 vs. 262144000)","grpc_status":8}">

This apparently tells me that my model is too big to be loaded.
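For reference, the two numbers in the error are consistent with that reading if both are interpreted as byte counts, with 262144000 presumably being the gRPC message-size cap that Ray configures (it is exactly 250 MiB):

# assumption: both values from the error message are byte counts
payload_size = 500_029_835  # size of the message carrying the checkpoint (~477 MiB)
grpc_limit = 262_144_000    # exactly 250 * 1024 * 1024 bytes, i.e. a 250 MiB cap
print(payload_size / 1024 ** 2)  # ~476.9
print(grpc_limit / 1024 ** 2)    # 250.0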

I can easily reproduce this error with the provided “how to load model checkpoints” example from the Deep Learning User Guide — Ray 2.0.0 by modifying the code slightly (making the model big):

import ray.train.torch
from ray.air import session, Checkpoint, ScalingConfig
from ray.train.torch import TorchTrainer

import torch
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from torch.optim import Adam
import numpy as np

def train_func(config):
    n = 100
    # create a toy dataset
    # data   : X - dim = (n, 4)
    # target : Y - dim = (n, 1)
    X = torch.Tensor(np.random.normal(0, 1, size=(n, 4)))
    Y = torch.Tensor(np.random.uniform(0, 1, size=(n, 1)))

    # toy neural network : 1-layer
    model = nn.Linear(4, 25000000)  # Artificially huge model (does not run on laptop, need a slightly bigger machine for that)
    criterion = nn.MSELoss()
    optimizer = Adam(model.parameters(), lr=3e-4)
    start_epoch = 0

    checkpoint = session.get_checkpoint()
    if checkpoint:
        # assume that we have run the session.report() example
        # and successfully saved some model weights
        checkpoint_dict = checkpoint.to_dict()
        model.load_state_dict(checkpoint_dict.get("model_weights"))
        start_epoch = checkpoint_dict.get("epoch", -1) + 1

    # wrap the model in DDP
    model = ray.train.torch.prepare_model(model)
    for epoch in range(start_epoch, config["num_epochs"]):
        y = model.forward(X)
        # compute loss
        loss = criterion(y, Y)
        # back-propagate loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        state_dict = model.state_dict()
        consume_prefix_in_state_dict_if_present(state_dict, "module.")
        checkpoint = Checkpoint.from_dict(
            dict(epoch=epoch, model_weights=state_dict)
        )
        session.report({}, checkpoint=checkpoint)

trainer = TorchTrainer(
    train_func,
    train_loop_config={"num_epochs": 2},
    scaling_config=ScalingConfig(num_workers=1),
)
# save a checkpoint
result = trainer.fit()

# load checkpoint
trainer = TorchTrainer(
    train_func,
    train_loop_config={"num_epochs": 10},
    scaling_config=ScalingConfig(num_workers=2),
    resume_from_checkpoint=result.checkpoint,
)
result = trainer.fit()

print(result.checkpoint.to_dict())

This raises exactly the error that I mentioned above.

I think the difference between the old and the new API is that Ray Tune is now used in the background, and Ray Tune handles this differently from the old ray-train.

I can work around this by not telling the TorchTrainer to resume from a checkpoint and instead loading the checkpoint myself at the point where I actually define my model on the worker (rough sketch below).
This does not seem to be the intended solution, though.
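For illustration, here is a minimal sketch of that workaround (omitting the training loop, which is unchanged from the example above). It assumes the checkpoint from the first run has been written to a directory that the workers can reach; the checkpoint_dir config key and the /tmp path are made up for this example:

from ray.air import Checkpoint, ScalingConfig
from ray.train.torch import TorchTrainer
import torch.nn as nn

def train_func(config):
    model = nn.Linear(4, 25000000)
    start_epoch = 0
    # load the weights ourselves instead of passing resume_from_checkpoint,
    # so the large state dict is not part of the arguments sent to the Trainer
    ckpt_dir = config.get("checkpoint_dir")
    if ckpt_dir:
        checkpoint_dict = Checkpoint.from_directory(ckpt_dir).to_dict()
        model.load_state_dict(checkpoint_dict["model_weights"])
        start_epoch = checkpoint_dict.get("epoch", -1) + 1
    # ... same training loop and session.report() calls as in the example above ...

# `result` is the Result returned by the first trainer.fit() above
ckpt_dir = result.checkpoint.to_directory("/tmp/resume_checkpoint")
trainer = TorchTrainer(
    train_func,
    train_loop_config={"num_epochs": 10, "checkpoint_dir": ckpt_dir},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()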

Note that this happens with checkpoint files of size ~400MB, which is not too unusual for deep learning models.

I was wondering what the recommended way to do this is.
Is this a bug in ray-train?

Thank you!

Hey @M_S, you are right, this is an issue with AIR! This will be resolved in 2.1 when [AIR] Support large checkpoints and other arguments by amogkam · Pull Request #28826 · ray-project/ray · GitHub is merged.

In the meantime, you can work around this by putting the Checkpoint in the object store manually:

import ray.train.torch
from ray.air import session, Checkpoint, ScalingConfig
from ray.train.torch import TorchTrainer

import torch
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from torch.optim import Adam
import numpy as np

def train_func(config):
    n = 100
    # create a toy dataset
    # data   : X - dim = (n, 4)
    # target : Y - dim = (n, 1)
    X = torch.Tensor(np.random.normal(0, 1, size=(n, 4)))
    Y = torch.Tensor(np.random.uniform(0, 1, size=(n, 1)))

    # toy neural network : 1-layer
    model = nn.Linear(4, 25000000)  # Artificially huge model (does not run on laptop, need a slightly bigger machine for that)
    criterion = nn.MSELoss()
    optimizer = Adam(model.parameters(), lr=3e-4)
    start_epoch = 0

    checkpoint = session.get_checkpoint()
    if checkpoint:
        # assume that we have run the session.report() example
        # and successfully saved some model weights;
        # the resume checkpoint only wraps an object ref, so fetch the
        # actual checkpoint from the object store first
        actual_checkpoint = ray.get(checkpoint.to_dict()["obj_ref"])
        checkpoint_dict = actual_checkpoint.to_dict()
        model.load_state_dict(checkpoint_dict.get("model_weights"))
        start_epoch = checkpoint_dict.get("epoch", -1) + 1

    # wrap the model in DDP
    model = ray.train.torch.prepare_model(model)
    for epoch in range(start_epoch, config["num_epochs"]):
        y = model.forward(X)
        # compute loss
        loss = criterion(y, Y)
        # back-propagate loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        state_dict = model.state_dict()
        consume_prefix_in_state_dict_if_present(state_dict, "module.")
        checkpoint = Checkpoint.from_dict(
            dict(epoch=epoch, model_weights=state_dict)
        )
        session.report({}, checkpoint=checkpoint)

trainer = TorchTrainer(
    train_func,
    train_loop_config={"num_epochs": 2},
    scaling_config=ScalingConfig(num_workers=1),
)
# save a checkpoint
result = trainer.fit()

checkpoint_object_ref = ray.put(result.checkpoint)
checkpoint_to_resume_from = Checkpoint.from_dict({"obj_ref": checkpoint_object_ref})

# load checkpoint
trainer = TorchTrainer(
    train_func,
    train_loop_config={"num_epochs": 10},
    scaling_config=ScalingConfig(num_workers=2),
    resume_from_checkpoint=checkpoint_to_resume_from,
)
result = trainer.fit()

print(result.checkpoint.to_dict())

Hi @amogkam,

thanks for the reply, I’m looking forward to 2.1 then.

I am already using the other workaround I mentioned in my post, but good to know that there are other ways to do it!

Thank you.