Ray Tune vs Ray Train Inheritance

How severe does this issue affect your experience of using Ray?

  • Medium: It contributes to significant difficulty to complete my task, but I can work around it.

Ray Train should not be based on the Ray Tune implementation of a trainer.
Ray Tune should use a Ray Train implemented trainer.

Prior to Ray 1.12 I was using a custom sgd.torch TorchTrainer/TrainingOperator to implement an async batch iterator which looked like this:

class AsyncTrainOperator(ray.util.sgd.torch.TrainingOperator):
    async def train_epoch(self, iterator, ... ):
        it = enumerate(iterator)
        ait = AsyncIteratorExecutor(it)
        async for batch_idx, batch in ait:
            metrics = self.train_batch(batch, batch_info=batch_info)

AsyncIteratorExecutor is just a convenience wrapper that converts a regular iterator into an async one. I did this because a training worker was blocking execution for another worker on the same machine. The specifics of the implementation aren’t important: I’m not concerned about the regression from sgd being deprecated, since mine was most likely an extreme edge case. What’s important is the difference between implementing this in sgd and in the new Ray 1.12 train framework.
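For readers who have not seen this pattern, here is a minimal sketch of what such a wrapper could look like. It is not the implementation used in the post, only an assumed one: each blocking next() call is handed to a thread pool so the event loop stays free for other coroutines. The sentinel object and the collect helper are illustrative names.

```python
import asyncio

_SENTINEL = object()  # marks exhaustion, since StopIteration cannot travel through a Future


class AsyncIteratorExecutor:
    """Assumed stand-in: wraps a blocking iterator as an async iterator."""

    def __init__(self, iterator, executor=None):
        self._iterator = iter(iterator)
        self._executor = executor  # None means the event loop's default thread pool

    def __aiter__(self):
        return self

    async def __anext__(self):
        loop = asyncio.get_running_loop()
        # next() runs in a worker thread, so the event loop is not blocked while waiting
        item = await loop.run_in_executor(self._executor, next, self._iterator, _SENTINEL)
        if item is _SENTINEL:
            raise StopAsyncIteration
        return item


async def collect(iterable):
    # mirrors the `async for batch_idx, batch in ait` loop from the post
    return [pair async for pair in AsyncIteratorExecutor(enumerate(iterable))]


batches = asyncio.run(collect(["a", "b", "c"]))
```

The wrapper only helps when something else is scheduled on the same event loop; on its own it does not make the iteration faster.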

In sgd it required 4 custom classes to implement:

class AsyncDistributedTorchRunner(ray.util.sgd.torch.distributed_torch_runner.DistributedTorchRunner)
class AsyncWorkerGroup(ray.util.sgd.torch.worker_group.RemoteWorkerGroup)
class AsyncTorchTrainer(ray.util.sgd.torch.TorchTrainer)
class AsyncTrainOperator(ray.util.sgd.torch.TrainingOperator)

In the new train framework, the way I was attempting to do this was to change the train_loop_per_worker to be async, which would look like this:

async def async_train_loop(config):
    ... ( load datasets, models, optimizers) ...
    for epoch_idx in range(num_epochs):
        it = iter(trn_ds)
        ait = AsyncIteratorExecutor(it)
        async for batch in ait:
            ... ( train batch logic: forward / backward / optimizer.step ) ...

This train function is then passed into an AsyncTrainer for the train_loop_per_worker:

trainer = AsyncTrainer(train_loop_per_worker=async_train_loop)

From that point the code gets very messy, very quickly. So far I’ve had to create async versions of 8 objects and it still doesn’t work.

class AsyncTorchTrainer(ray.ml.train.integrations.torch.torch_trainer)
class AsyncTuner(ray.tune.tuner.Tuner)
class AsyncTunerInternal(ray.tune.impl.tuner_internal.TunerInternal)
async def async_tune_run [def ray.tune.tune.run]
class AsyncTrialRunner(ray.tune.trial_runner.TrialRunner)
class AsyncTrialExecutor(ray.tune.ray_trial_executor.RayTrialExecutor)
async def async_wrap_function [def ray.tune.function_runner.wrap_function ]
class AsyncFunctionRunner(ray.tune.function_runner.FunctionRunner)

Again, I’m not concerned with the regression, but the complexity of the train system seems to have grown exponentially by forcing it to use the Tune trainer framework. The Tune system on its own has grown as it has incorporated other features, yet having the Tune system implement an entire trainer system doesn’t seem right to me.

Is there a reason why the Train functions are being built to inherit from the Tune trainer functions besides the fact that the Tune code already existed?

The problem is best seen from this line of code from ray.ml.trainer:

from ray.tune import Trainable

That functionality is backwards. A tune object should be a special form of a train object not the other way around. This backwards inheritance is part of the complexity problem I’m dealing with in my current implementation and will only cause more complexity issues in the future. Changing this would take a decent amount of refactoring, but since these internals are currently being reworked for the AIR functionality, this would be an ideal time to do that before the code gets too deep to rework.

The “simple” fix is to move the trainer code from the tune module into the train module and modify tune to work accordingly.
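A toy sketch of the dependency direction being argued for: the trainer owns the training loop and knows nothing about tuning, while the tuner simply drives a trainer once per candidate config. Every name here (Trainer, Tuner, toy_loop) is hypothetical and none of it is Ray API.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class Trainer:
    """Hypothetical Train-side object: owns the loop, has no Tune dependency."""
    train_loop: Callable[[Dict], float]

    def fit(self, config: Dict) -> float:
        return self.train_loop(config)


@dataclass
class Tuner:
    """Hypothetical Tune-side object: calls a Trainer once per candidate config."""
    trainer: Trainer
    search_space: List[Dict]

    def fit(self) -> Dict:
        results = [(self.trainer.fit(cfg), cfg) for cfg in self.search_space]
        return min(results, key=lambda r: r[0])[1]  # config with the lowest loss


def toy_loop(config: Dict) -> float:
    return (config["lr"] - 0.01) ** 2  # pretend loss, minimised at lr=0.01


best = Tuner(Trainer(toy_loop), [{"lr": 0.1}, {"lr": 0.01}, {"lr": 0.001}]).fit()
```

With this layout, making the training loop async would only touch the Trainer side.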

An alternative solution is to change the way Ray handles training altogether.

Ray Train Server
While digging through the internals of Tune, I stumbled across the tune server: ray.tune.web_server. Why I had to import this to do what I wanted in the first place, I don’t know, but I think the idea behind its function would act as a good base for re-implementing Ray Train functionality when combined with the functionality behind Ray Serve.

Ray would provide a train server that would schedule and manage train jobs. Jobs could use any available backend (torch, tensorflow, RLlib, etc.). When a job is submitted to ray.train, it would go to a running train server if one exists. If one doesn’t exist, the server would be initialized, or the train job could be run in a local mode that attempts to run and manage the job from the local machine only. This server would be implemented as a native Ray actor, similar to Ray Serve. Train jobs could be generated by any interface that connects to a cluster: clients, standalone trainers, tune jobs, etc.

When scheduling a job it could either be run in a blocking or non-blocking way. A blocking submission would wait to receive callback updates / logging from the server about the submitted job. This blocking would only exist in the calling process; the server would remain able to receive new jobs and send out updates to other clients / callbacks. A non-blocking submission could include a callback for the server to use for results, subscribe to / call into a resource the server provides for updates, or just ignore updates altogether and run the job.
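As an illustration of the two submission modes, here is a small sketch built on the standard library's thread pool rather than on Ray. The TrainServer class and its submit/shutdown methods are hypothetical; the point is only that blocking happens in the caller, never in the server.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor


class TrainServer:
    """Hypothetical server: jobs run on worker threads, submit never blocks."""

    def __init__(self, max_jobs=2):
        self._pool = ThreadPoolExecutor(max_workers=max_jobs)

    def submit(self, job, on_result=None) -> Future:
        future = self._pool.submit(job)
        if on_result is not None:
            # the callback fires once the job has finished
            future.add_done_callback(lambda f: on_result(f.result()))
        return future

    def shutdown(self):
        self._pool.shutdown(wait=True)  # waits for running jobs and their callbacks


def toy_job(name, seconds):
    time.sleep(seconds)
    return f"{name} done"


server = TrainServer()
received = []

# Non-blocking: hand over a callback and carry on.
server.submit(lambda: toy_job("job-a", 0.05), on_result=received.append)

# Blocking: only this caller waits; the server had already accepted job-a.
blocking_result = server.submit(lambda: toy_job("job-b", 0.01)).result()

server.shutdown()
```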

Once the job is submitted to the server, the server handles the resource allocation for the process groups and other resources based on what’s requested by the job or what’s available to use, including scheduling future resource use. Multiple jobs could be run concurrently, and running jobs could be interacted with through the server for stopping, modifications, or viewing training progress.
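To make the allocation idea concrete, a deliberately simple sketch: a strict first-in-first-out scheduler over a fixed pool of GPUs. A real server would deal with CPUs, memory, placement, and priorities; ResourceScheduler and its methods are hypothetical names used only for illustration.

```python
from collections import deque


class ResourceScheduler:
    """Hypothetical FIFO scheduler over a fixed pool of GPUs."""

    def __init__(self, total_gpus: int):
        self.free_gpus = total_gpus
        self.running = {}        # job name -> GPUs held
        self.pending = deque()   # (job name, GPUs requested), in arrival order

    def submit(self, name: str, gpus: int):
        self.pending.append((name, gpus))
        self._start_ready_jobs()

    def finish(self, name: str):
        self.free_gpus += self.running.pop(name)  # release, then refill from the queue
        self._start_ready_jobs()

    def _start_ready_jobs(self):
        # strict FIFO: stop at the first queued job that does not fit
        while self.pending and self.pending[0][1] <= self.free_gpus:
            name, gpus = self.pending.popleft()
            self.free_gpus -= gpus
            self.running[name] = gpus


scheduler = ResourceScheduler(total_gpus=4)
scheduler.submit("job-a", 3)
scheduler.submit("job-b", 2)   # has to wait: only 1 GPU is free
running_before = sorted(scheduler.running)
scheduler.finish("job-a")      # frees 3 GPUs, so job-b can start
running_after = sorted(scheduler.running)
```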

Running a training server in this way also provides a convenient way to provide a dashboard into the training. The server would manage current jobs, history, progress, and other extendable information types. This interface could then be accessible from the web dashboard in the same way tune jobs are currently. Introspection into training can be accomplished using external callbacks, but a native Ray version would be more useful for providing Ray-specific information and would still be able to provide information to these external methods. Another useful feature would be a more accessible way to see where, how, and when GPU/CPU/memory resources are being used by train jobs. Besides a web dashboard, the server would also provide a point for trainer shutdowns, forced stopping, and ensuring that resources of jobs are freed in a clean and specific way.

Submitting jobs to the server would require a model object, a dataset object, and settings objects. The model object and dataset object would include a setup function to run on the worker or on the server. I think the model object would be best served by implementing a new unified Ray Model object. The goal of this would be to create a single backend agnostic place for any type of model to live inside Ray. The backend settings and implementation details for a model would be handled by the object/worker when it is used.

Ray Model Object
A RayModel would include the specifics of the model: backend setup, internal model module structure, helper functions, default parameters, etc. It would be nice for this object to imitate pytorch-lightning’s implementation of hooks into key spots of the runtime loop such as: before_batch, before_epoch, before_setup, etc. We wouldn’t need as many hooks as pytorch-lightning provides since our targeted hook points would be much more general and the hooks would often have some relation to the backend being used, so they would need to be implemented in RayModel subclasses. This method of using hooks isn’t necessary, and a simpler system of overriding internal function calls could be implemented for the key steps in the loop, but the functionality provided by the pytorch-lightning method is attractive.
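A rough sketch of the hook idea, assuming a base class whose hooks are no-ops until a subclass overrides them. The hook names come from the paragraph above; the fit loop, train_batch, and RecordingModel are hypothetical and only show where the hooks would be called.

```python
class RayModel:
    """Hypothetical base class: hooks do nothing unless a subclass overrides them."""

    def before_setup(self):
        pass

    def before_epoch(self, epoch):
        pass

    def before_batch(self, batch):
        return batch

    def setup(self):
        pass

    def train_batch(self, batch):
        raise NotImplementedError

    def fit(self, batches, num_epochs):
        # the runtime loop calls each hook at a fixed, general point
        self.before_setup()
        self.setup()
        for epoch in range(num_epochs):
            self.before_epoch(epoch)
            for batch in batches:
                self.train_batch(self.before_batch(batch))


class RecordingModel(RayModel):
    def __init__(self):
        self.calls = []

    def before_epoch(self, epoch):
        self.calls.append(f"epoch {epoch}")

    def before_batch(self, batch):
        return batch * 2  # e.g. a backend-specific transform

    def train_batch(self, batch):
        self.calls.append(f"batch {batch}")


model = RecordingModel()
model.fit(batches=[1, 2], num_epochs=2)
```

The simpler alternative mentioned above, overriding the loop steps directly, would drop the hook methods and let subclasses replace fit itself.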

By inheriting from a base RayModel object it would be possible to have backend-specific implementations providing classes such as: RayTorchModel, RayTensorflowModel, RayXGBoostModel, RayLightGBMModel, RayRLlibModel, etc. These models would work to isolate the implementation and setup details of the different backends to allow the different types of models to work interchangeably in train, tune, serve, or any future system while still allowing access to the low-level features of each backend. User models would be created as subclasses of these objects. The user model would only need to contain the explicit setup for the model in the backend language and a call/forward function to handle running the model for inference.

A RayModel sub-class would then be bundled with a train_function to make a RayTrainableModel which would be able to be submitted to the train server. This train_function would need to be able to work with the specific backend to access optimizers, schedulers, or whatever other backend specific parts of the train loop are needed.

A RayTrainableModel could then be bundled with a tune_settings to create a RayTunableModel which could be run as a tune job that handles creating multiple RayTrainableModels to send to the train server. The tune scheduler would be able to run in much the same way it currently does, just removing the explicit spawning of workers/tasks for an experiment. Each experiment would be bundled into a RayTrainableModel with the experiment model settings and sent to the train server. The tuner would then receive the updates and continue scheduling based on the results.

Finally a RayModel could be bundled with a server_settings to create a RayServableModel that would be able to get directly pushed to Ray Serve.

Using one base model object that can be used interchangeably for the entire lifecycle of the model on Ray will greatly increase usability and simplicity. Some of this functionality is similar to parts of the Ray AIR RFC, but the main differences are the implementation of Ray Train as the base system as opposed to Ray Tune and the introduction of the Ray Model object to encapsulate a model.

Some pseudocode to demonstrate:

# Model creation
class MyTorchModel(RayTorchModel):
    def __init__(self, config={}):
        self.config = self._config(config)

    def setup(self):
        self.net = nn.Sequential(nn.Identity())
        self.preprocessor = transforms.Compose([...])   # transform list elided
        self.postprocessor = transforms.Compose([...])  # transform list elided

    def forward(self, x):
        # x has already passed through self.preprocess
        return self.net(x)
        # return value is passed through self.postprocess

    def on_preprocess(self, x):
        return self.preprocessor(x)

    def on_postprocess(self, x):
        return self.postprocessor(x)

    def _config(self, config):
        default_config = {'size':256}

        # dict.update() returns None, so merge instead; user-supplied values override the defaults
        return {**default_config, **config}

my_model = MyTorchModel()

z = my_model.forward(x)

# Model Training
class MyTrainData:
    def __init__(self, config={}):
        self.config = config

    def setup(self):
        train_dataset = ray.data.read_binary_files(self.config['train_file_location'])
        valid_dataset = ray.data.read_binary_files(self.config['valid_file_location'])

        return {'train':train_dataset, 'valid':valid_dataset}

class MyTrainableModel(RayTrainableModel):
    def __init__(self, config={}):
        self.model_class = MyTorchModel
        self.config = self._config(config)

    def setup(self):
        self.model = self.model_class(self.config)
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.config['lr'])
        self.loss = F.mse_loss

    def train_step(self, x):
        # postprocess hook is disabled for training
        z = self.model.forward(x, no_postprocess=True)
        loss = self.loss(z, x)
        train.report(loss=loss, source='train_step')

    def _config(self, config):
        default_config = {'lr':0.01}

        # dict.update() returns None, so merge instead; user-supplied values override the defaults
        return {**default_config, **config}

model_config = {'size':512}
trainable_config = {'lr':0.001}
train_settings = {'num_workers':3, 'use_gpu':True, 'num_epochs':10, 'batch_size':10}

train_results_location = ray.train(MyTrainableModel, train_settings=train_settings, dataset=MyTrainData, block=False)

# blocks until the train job is finished
# the train server streams the logging output back to the client
result = ray.get(train_results_location)

# Model Tuning
class MyTunableModel(RayTunableModel):
    def __init__(self, config={}):
        self.trainable_class = MyTrainableModel
        self.config = config

        # disable StreamingOutputLogger
        # logging from the tune job takes over instead
        self.disable_training_callbacks = True

model_parameters = {'size':tune.choice([32, 256, 1024])}

tune_settings = {'num_workers':4,
                 'scheduler':tune.schedulers.ASHAScheduler(metric='val_loss', mode='min')}

ray.tune(MyTunableModel, tune_settings=tune_settings, dataset=MyTrainData)

# Model Serving
class MyServableModel(RayServableModel):
    def __init__(self, config={}):
        self.model_class = MyTorchModel
        self.model_checkpoint_location = '/my_trainable_model_checkpoint'

serve_settings = {'route_prefix':'/my_model'}

ray.serve(MyServableModel, serve_settings=serve_settings)

This demo implementation does not actually solve my original problem of being able to run an async train_epoch function. It works around it by removing the blocking from the calling process. It would still be useful to be able to use an async train_step in the MyTrainableModel class, but that requires multiple layers of async-aware code. All those layers are concentrated around a single object, the RayTrainableModel, so implementation would most likely be easier.

Ray Model Server
As I’ve been typing this out, I started thinking it might be possible to use the RayModel object as a base object on a Ray-implemented server. Ray could provide a model server containing RayModel objects that can be used by train, tune, serve, or any other functionality added in the future. A user creates a model, registers it with the cluster model server, and is then able to use it with any of the other systems by providing the reference to the model and the run-time configuration settings. This might be the most straightforward method. When building the RayModel, the train loop, hyperparameter space, and any other extended considerations could be built in at the start, and any part that isn’t needed by the system currently using the model or by the user creating the model could just be ignored.

This idea starts going against some of what has been laid out for the Ray AIR RFC, specifically the parts dealing with external model / experiment stores. This system would not be designed for global scope; it would be a cluster-scoped system that could then interact with global-scoped systems by either pulling down models or pushing models out.

Pseudocode for RayModel Server Usage:

# Create a Model
class MyModel(RayTorchModel):
    def __init__(self, config={}):
        self.config = self._config(config)

    def setup(self):
        self.model = nn.Sequential(nn.Identity())
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.config['lr'])
        self.loss = F.mse_loss
        self.preprocessor = transforms.Compose([...])   # transform list elided
        self.postprocessor = transforms.Compose([...])  # transform list elided

    def forward(self, x, no_postprocess=False):
        z = self.preprocessor(x)
        z = self.model(z)

        if no_postprocess:
            return z

        return self.postprocessor(z)

    def train_step(self, x):
        z = self.forward(x, no_postprocess=True)
        loss = self.loss(z, x)
        train.report(loss=loss, source='train_step')

    def _config(self, config):
        default_config = {'size':512}

        # dict.update() returns None, so merge instead; user-supplied values override the defaults
        return {**default_config, **config}

    def hyperparameter_space(self):
        config = {'size':tune.choice([32, 256, 1024])}

        return config

# Register Model
model_reference = ray.model.put(MyModel)

# Build Model from Reference
model_config = {'size':256}
my_model = ray.model.get(model_reference, model_config=model_config)
z = my_model.forward(x)

# Train Model From Reference
class MyTrainData:
    def __init__(self, config={}):
        self.config = config

    def setup(self):
        train_dataset = ray.data.read_binary_files(self.config['train_file_location'])
        valid_dataset = ray.data.read_binary_files(self.config['valid_file_location'])

        return {'train':train_dataset, 'valid':valid_dataset}

model_config = {'size':256, 'lr':0.001}
train_settings = {'num_workers':3, 'use_gpu':True, 'num_epochs':10, 'batch_size':10}

train_results = ray.train(model_reference, train_settings=train_settings, dataset=MyTrainData)

# Tune Model From Reference
tune_settings = {'num_workers':4, 
                   'scheduler':tune.schedulers.ASHAScheduler(metric='val_loss', mode='min')}

tune_results = ray.tune(model_reference, tune_settings=tune_settings, dataset=MyTrainData)

# Serve Model From Reference
serve_settings = {'route_prefix':'/my_model'}

ray.serve(model_reference, serve_settings=serve_settings)

In this form, it would still be useful to implement a train server as well to act as the base system for both training and tuning.


  • Having Ray Train inherit from Ray Tune is backwards
  • A better way to work would be a tune object calling a train method to run an experiment
  • Ray Tune has a tune_server that would make a good blueprint for a train_server when combined with the functional layout of Ray Serve
  • Changing to this method gives the opportunity to implement new functionality for both training and tuning
  • Implementing this functionality lends itself to implementing a backend agnostic RayModel object
  • RayModel objects can be used interchangeably across Ray
  • It could be possible to use a RayModel server to centralize model usage on a Ray cluster

While my code is based in pytorch, the concepts should be extendable to other backends.
The main goal here is still an answer to this question:
Is there a reason why the Train functions are being built to inherit from the Tune trainer functions besides the fact that the Tune code already existed?

I put together a demonstration of how a model server could work. It’s getting away from the point of this thread, but maybe someone will still come and answer my original question. If it’s better to move this code someplace else to discuss it, let me know.

# ModelServer Demonstration

# Running in Jupyter Lab consoles connected to a Ray cluster
# Need to change the addresses to your system => RAY_ADDRESS / SERVE_HOST / SERVE_PORT
# First have a running Ray system with a running Serve instance
# Then run the ModelServer deployment code
# Next run the MyModel deployment code
# Finally run the client code
# Was built on my ray 1.12.0 not current master - so there could be issues

# all steps can be run from different clients connected to the same cluster 
# ie: different consoles on JupyterLab
# after registering a model you can call the HTTP endpoint for the ModelServer
# and it will return a list of registered models

# ModelServer Deployment
import asyncio
import uuid
from collections import OrderedDict

import ray
from ray import serve

# placeholder values: change these to match your system
RAY_ADDRESS = 'auto'
SERVE_HOST = '127.0.0.1'
SERVE_PORT = 8000

if not ray.is_initialized():
    ray.init(namespace='test_ns', address=RAY_ADDRESS)

# clean up serve environment
# - ModelServer was sometimes not updating when redeploying and was causing an error afterwards
serve_client = serve.start(http_options={'host':SERVE_HOST, 'port':SERVE_PORT}, 
    location='HeadOnly', dedicated_cpu=True, detached=True)

# Model Server
class ModelServer:
    def __init__(self):
        # repository would need to live outside of the server to survive redeploys
        self.repository = dict()

    def __call__(self, request):
        # call function returns a list of currently registered models
        return str([(z, list(self.repository[z].keys())) for z in self.repository.keys()])

    async def push(self, model, model_key, version=None):
        # push a model to the server
        # - saves model under model_key
        # - if version is provided uses version key to save/overwrite that version
        # - if version key isn't provided, creates a new one to put into an OrderedDict for simple version order

        # would be necessary to try and do some sort of sanity checking on a model here
        # 1. check that the model is a valid sub-class of RayModel
        # 2. a dummy pass of the model sub-functions would be needed to make sure no errors
        # that would need information though about the model that might not be in the base class: input_shape etc
        # could live in a setup function? 
        # or maybe just call it sanity_check or runtime_check
        # and run it from a stub in the RayModel object

        if version is None:
            version = str(uuid.uuid4())

        if model_key not in self.repository:
            self.repository[model_key] = OrderedDict()

        self.repository[model_key][version] = model

        return {'model_key':model_key, 'version':version}

    async def pull(self, model_key, version=None):
        # pull a model from the server
        # - attempts to pull model_key from self.repository
        # - if a version isn't provided, will default to the last item in the OrderedDict

        if version is None:
            version = list(self.repository[model_key].keys())[-1]

        return self.repository[model_key][version]

    async def list(self, model_key=None):
        # lists the current registered models
        # - if model key is provided will try to provide the model_key's versions
        # - keys are copied into plain lists so the result can be serialized
        if model_key is None:
            return list(self.repository.keys())

        return {'model_key':list(self.repository[model_key].keys())}

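# Deploy the model server
# - assumption: Ray 1.12 Serve API; wraps the class as a deployment named 'ModelServer'
serve.deployment(ModelServer).deploy()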

# END ModelServer Deployment

# Model Deployment
# MyModel Class => Push Model => Pull Model => Run Model
# - build a simple Torch model using TorchPredictor (requires torch)
# - TorchModel class would include stubs for setup, train, validate, etc
# - could use a decorator in the same way as serve.deployment
import asyncio  # used by the async train/validate stubs below

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import ray
from ray import serve
from ray.ml.predictors.integrations.torch import TorchPredictor
#from ray import model
#from ray.ml.models.integrations.torch import TorchModel

RAY_ADDRESS = 'auto'  # placeholder: change to match your system

if not ray.is_initialized():
    ray.init(namespace='test_ns', address=RAY_ADDRESS)

# class MyModel(TorchModel):
class MyModel(TorchPredictor):
    def __init__(self, shape=(2,), dim=1, depth=None):
        self.shape = shape
        self.dim = dim
        self.depth = depth

        # internal model setup
        if depth is not None:
            model_list = [nn.Linear(shape[0], dim)]
            model_list += [nn.Linear(dim, dim) for n in range(depth)]
            model_list += [nn.Linear(dim, shape[0])]
            self.model = nn.Sequential(*model_list)
        else:
            # no depth given: fall back to a two-layer model
            self.model = nn.Sequential(nn.Linear(shape[0], dim), nn.Linear(dim, shape[0]))

        self.preprocessor = None
        self.loss = F.mse_loss
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def forward(self, x):
        # inference function
        if not torch.is_tensor(x):
            x = torch.tensor(x, dtype=torch.float)

        return self.model(x)

    async def train(self, x, y):
        # use a dummy train function for now
        # - will crash if tries to train because of missing gradient
        #z = self.model(x)
        #loss = self.loss(z, y)

        await asyncio.sleep(1)
        loss = 1

        return loss

    async def validate(self, x, y):
        # dummy validate function
        #z = self.model(x)
        #loss = self.loss(z, y)

        await asyncio.sleep(1)
        loss = 0

        return loss

# Get the ModelServer
# could be replaced by the Ray model object
# model.push()
# model.pull()
model_server = serve.get_deployment('ModelServer').get_handle(sync=False)

# Push the model to the server
# MyModel.push(model_key='MyModel')
# - why do these remote calls return an objectref object that needs to be awaited?
# - causes double awaiting
# - comes from the way remote works in the RayServeHandle with async methods
# => ray.serve.handle => RayServeHandle => remote => _remote
m_ref = await model_server.push.remote(model=MyModel, model_key='MyModel')
m_ref = await m_ref

# Pull the model from the server
m_class = await model_server.pull.remote(**m_ref)
m_class = await m_class

# Build an instance of the model
model = m_class()

# Run model
z = model.forward(np.zeros((32,2)))

# END Model Deployment

# Client Code
# - must have registered the model before running this
# - could be interesting trying to use a model from a backend missing on the client machine
# - shouldn't work but how much of pytorch gets bundled with the MyModel object?
# => how much of pytorch can we bundle with the MyModel object?
import ray
from ray import serve
#from ray import model

import numpy as np

RAY_ADDRESS = 'auto'  # placeholder: change to match your system

if not ray.is_initialized():
    ray.init(namespace='test_ns', address=RAY_ADDRESS)

# Get ModelServer
model_server = serve.get_deployment('ModelServer').get_handle(sync=False)

# Get MyModel class
# => double await from RayServeHandle._remote
# - doesn't provide a version key so will pull the most recent version
# m_class = model.pull(model_key='MyModel')
m_class = await model_server.pull.remote('MyModel')
m_class = await m_class

# Build instance of MyModel
# - can inject configuration here
# - this would affect checkpointing or other methods of loading a pretrained model
# - checkpoints would need these parameters bundled
# model = model.pull(model_key='MyModel', version='0',
#       config_dict={'shape':(16,), 'dim':8}, checkpoint='checkpoint_path')
model = m_class(shape=(16,), dim=8)

# Run Model
z = model.forward(np.zeros((32,16)))

# END Client Code

Doesn’t do much, but it works and I like it. I’ll try to extend this to work with a train server next.
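On the double await noted in the demo comments: with an async Serve handle in Ray 1.x, remote() is itself a coroutine that resolves to a reference, and the reference then has to be awaited to get the value. The same shape can be reproduced with plain asyncio. The sketch below is a toy stand-in, not Serve code; Handle and its methods are hypothetical, and an asyncio Task plays the role of the ObjectRef.

```python
import asyncio


class Handle:
    """Toy stand-in for an async handle whose remote() returns a reference."""

    async def remote(self, value):
        # first await: the call is routed and scheduled, and a reference comes back
        return asyncio.ensure_future(self._run(value))

    async def _run(self, value):
        await asyncio.sleep(0)  # pretend the work happens elsewhere
        return value * 2


async def main():
    ref = await Handle().remote(21)                    # resolves to a Task, not the result
    return isinstance(ref, asyncio.Future), await ref  # second await yields the value


is_reference, result = asyncio.run(main())
```

Collapsing the two steps into `await (await handle.method.remote(...))` is equivalent and a little shorter.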