Restoring an RLlib Run Using Tuner.restore

How severe does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

Hi,

I’ve been using RLlib with a multi-agent CARLA environment (similar to this implementation), which crashes from time to time due to memory issues, and I would like to resume training after each crash. The training script and the config file are as follows:

import os
import ray
import yaml
import time
import argparse

from tensorboard import program

from ray import air, tune

from ray.tune.registry import register_env

from carla_env import CarlaEnv

argparser = argparse.ArgumentParser(description='CoPeRL Training Implementation.')

argparser.add_argument('config', help='configuration file')
argparser.add_argument('-d', '--directory',
                       metavar='D',
                       default='/home/coperl/ray_results',
                       help='directory to save the results (default: /home/coperl/ray_results)')
argparser.add_argument('-n', '--name',
                       metavar='N',
                       default='sac_experiment',
                       help='name of the experiment (default: sac_experiment)')
argparser.add_argument('--restore',
                       action='store_true',
                       default=False,
                       help='restore the specified experiment (default: False)')
argparser.add_argument('--tb',
                       action='store_true',
                       default=False,
                       help='activate tensorboard (default: False)')

args = argparser.parse_args()


def parse_config(args):
    '''
    Parse the configuration file.

    Args:
        args: command line arguments.

    Return:
        config: configuration dictionary.
    '''
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    return config

def launch_tensorboard(logdir, host='localhost', port='6006'):
    '''
    Launch TensorBoard.

    Args:
        logdir: directory of the saved results.
        host: host address.
        port: port number.

    Return:

    '''
    tb = program.TensorBoard()
    tb.configure(argv=[None, '--logdir', logdir, '--host', host, '--port', port])
    url = tb.launch()

def env_creator(env_config):
    '''
    Create Gymnasium-like environment.
    
    Args:
        env_config: configuration passed to the environment.

    Return:
        env: environment object.
    '''
    return CarlaEnv(env_config)

def run(args):
    '''
    Run Ray Tuner.

    Args:
        args: command line arguments.

    Return:

    '''
    try:
        ray.init(num_cpus=12, num_gpus=2)

        register_env('carla', env_creator)

        os.system('nvidia-smi')

        if not args.restore:
            tuner = tune.Tuner(
                'SAC',
                run_config=air.RunConfig(
                    name=args.name,
                    storage_path=args.directory,
                    checkpoint_config=air.CheckpointConfig(
                        num_to_keep=2,
                        checkpoint_frequency=1,
                        checkpoint_at_end=True
                    ),
                    stop={'training_iteration': 8192},
                    verbose=2
                ),
                param_space=args.config,
            )
        else:
            tuner = tune.Tuner.restore(os.path.join(args.directory, args.name), 'SAC', resume_errored=True)

        result = tuner.fit().get_best_result()

        print(result)

    except Exception as e:
        print(e)
    finally:
        ray.shutdown()
        time.sleep(10.0)

def main():
    args.config = parse_config(args)

    if args.tb:
        launch_tensorboard(logdir=os.path.join(args.directory, args.name))

    run(args)


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        ray.shutdown()
    finally:
        print('Done.')

And the config file:

framework: 'torch'

env: 'carla'
disable_env_checking: True

num_workers: 1
num_gpus: 1
num_cpus_per_worker: 8
num_gpus_per_worker: 1

train_batch_size: 256

log_level: 'DEBUG'

ignore_worker_failures: True
restart_failed_sub_environments: False

checkpoint_at_end: True
export_native_model_files: True

keep_per_episode_custom_metrics: True

q_model_config:
  fcnet_hiddens: [256, 256]
  dim: 200
  conv_filters: [
    [16, [3, 3], 2],
    [32, [3, 3], 2],
    [32, [3, 3], 2],
    [64, [3, 3], 2],
    [64, [3, 3], 2],
    [128, [3, 3], 2]
  ]
  post_fcnet_hiddens: [256]

policy_model_config:
  fcnet_hiddens: [256, 256]
  dim: 200
  conv_filters: [
    [16, [3, 3], 2],
    [32, [3, 3], 2],
    [32, [3, 3], 2],
    [64, [3, 3], 2],
    [64, [3, 3], 2],
    [128, [3, 3], 2]
  ]
  post_fcnet_hiddens: [256]
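
For reference, the same flat config can also be built into an SACConfig programmatically; the sketch below is hedged, assumes update_from_dict() accepts these flat, old-style keys, and uses 'sac_config.yaml' as a placeholder file name:

import yaml

from ray.rllib.algorithms.sac import SACConfig

# Parse the YAML shown above and let update_from_dict() distribute the flat
# keys onto the corresponding sub-configs of an SACConfig.
with open('sac_config.yaml') as f:
    config_dict = yaml.safe_load(f)

sac_config = SACConfig().update_from_dict(config_dict)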

However, most parameters seem to get reset when I use Tuner.restore after a crash, apart from a few like the number of steps, as shown in the following figure (the same happens for the TD error, mean and max Q, episode reward, etc.).

I tried what was suggested here and it seems that neither get_weights() nor set_weights() gets accessed. I also tried the algo.save() / Algorithm.from_checkpoint() approach and got similar results. I would appreciate it if someone could let me know where I’m going wrong.
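
A minimal sketch of the round-trip check I mean (hedged; algo stands for an SAC Algorithm already built from the config above, and checkpoint_dir for any writable directory):

import numpy as np

from ray.rllib.algorithms.algorithm import Algorithm

# Save the running algorithm, reload it from the checkpoint, and compare the
# default policy's weights before and after.
algo.save_checkpoint(checkpoint_dir)
weights_before = algo.get_policy().get_weights()

restored = Algorithm.from_checkpoint(checkpoint_dir)
weights_after = restored.get_policy().get_weights()

for name, w in weights_before.items():
    if not np.allclose(w, weights_after[name]):
        print(f'mismatch in {name}')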

Hi, this is a mystery to many people on this forum.
Here are some examples of past conversations:

Thanks for letting me know. I’ll try to see if I can figure out where the bug is or if there’s a way around it.

@Goodarz Thanks for posting this and welcome to the forum. I tried to reproduce your issue with this script:

import os
import ray

from ray import air, tune
from ray.rllib.algorithms.sac.sac import SACConfig
from ray.rllib.examples.env.cartpole_crashing import CartPoleCrashing


base_path = "/home/simon/ray_results/test_tuner_restore"

tune.register_env("cartpole_crashing", lambda env_ctx: CartPoleCrashing(env_ctx))

config = (
    SACConfig()
    .environment(
        env=CartPoleCrashing,
        env_config={"crash_after_n_steps": 1000, "p_crash": 0.0001},
    )
    .rollouts(
        rollout_fragment_length=256,
        num_envs_per_worker=1,
        num_rollout_workers=1,
    )
    .fault_tolerance(
        recreate_failed_workers=False,
        restart_failed_sub_environments=False,
        num_consecutive_worker_failures_tolerance=1,
    )
    .training(
        train_batch_size=256,
    )
    .debugging(
        log_level="DEBUG",
        seed=42,
    )
)

if os.path.exists(base_path):
    ray.init(local_mode=True)
    tuner = tune.Tuner.restore(path=base_path, trainable="SAC", resume_errored=True)
    tuner.fit()

# Start fresh.
else:
    ray.init()
    tuner = tune.Tuner(
        "SAC",
        param_space=config,
        run_config=air.RunConfig(
            stop={"num_env_steps_sampled": 100000},
            name="test_tuner_restore",
            checkpoint_config=air.CheckpointConfig(
                checkpoint_frequency=1,
                checkpoint_at_end=True,
            ),
        ),
    )
    tuner.fit()

but could not see any problems (apart from simple counters that do not play a role here). save_checkpoint() is called and so is load_from_checkpoint(), which in turn calls TorchPolicy.set_state(); that restores all optimizers and also calls TorchPolicy.set_weights(). Parameters before and after crashing are the same, and I also do not get any discontinuities in the graphs for td_error or q_x.
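
If you want to double-check on your side whether TorchPolicy.set_weights() is actually reached, a quick trace like the sketch below works when the Algorithm is built in the current process (with the Tuner the trainable runs inside its own actor, so the patch would have to be applied there instead):

from ray.rllib.policy.torch_policy import TorchPolicy

# Debugging sketch, not RLlib API: wrap set_weights() in a logging shim before
# restoring, so every call and the number of arrays it receives gets printed.
_original_set_weights = TorchPolicy.set_weights

def _traced_set_weights(self, weights):
    print(f"TorchPolicy.set_weights() called with {len(weights)} arrays")
    return _original_set_weights(self, weights)

TorchPolicy.set_weights = _traced_set_weights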

Which version of Ray are you running? If you can provide a reproducible example, I can take a look into this.

@Lars_Simon_Zehnder I am using Ray 2.9.0.

It may have something to do with the ComplexInputNetwork model, given that my observation space is

Tuple((Box(low=-1.001, high=1.001, shape=(200, 200, 5)),
       Box(low=-1.001, high=1.001, shape=(12, 3))))

(My action space is simply Box(low=-1.0, high=1.0).)
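
Spelled out as gymnasium spaces (shape=(1,) for the action Box is simply what gymnasium infers from scalar bounds), this is:

import numpy as np

from gymnasium.spaces import Box, Tuple

# A nested/Tuple observation space like this is what makes RLlib's catalog
# wrap the model in a ComplexInputNetwork.
observation_space = Tuple((
    Box(low=-1.001, high=1.001, shape=(200, 200, 5), dtype=np.float32),  # image-like input
    Box(low=-1.001, high=1.001, shape=(12, 3), dtype=np.float32),        # vector input
))

action_space = Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)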

In my debugging I noticed that the native model saved to a checkpoint is, in my case, around ~116 MB. However, when I use Algorithm.from_checkpoint() to reinitialize the model (i.e., load the weights and state parameters from the state file) and then immediately call Algorithm.save_checkpoint() without any training, the native model saved to the new checkpoint is ~66 MB. The tower_stats and _last_output attributes of the new model were not the same as the original model’s, and even after modifying torch_policy.py to save them to the state file and load them from it, the model was still smaller (~88 MB) than the original, indicating some information may still be missing. The only solution I found that actually solves the problem is the following:

# Imports for the snippet below; env_creator and args are the same as in the
# training script above.
import copy
import os

import torch

import ray

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.sac import SACConfig
from ray.rllib.utils import force_list
from ray.tune.registry import register_env


ray.init(num_cpus=12, num_gpus=2)
  
register_env('carla', env_creator)
  
os.system('nvidia-smi')
  
if not os.path.exists(os.path.join(args.directory, args.name)):
    os.mkdir(os.path.join(args.directory, args.name))
  
if not args.restore:
    sac_config = SACConfig().framework(**args.config['framework']) \
        .environment(**args.config['environment']) \
        .callbacks(**args.config['callbacks']) \
        .rollouts(**args.config['rollouts']) \
        .fault_tolerance(**args.config['fault_tolerance']) \
        .resources(**args.config['resources']) \
        .debugging(**args.config['debugging']) \
        .checkpointing(**args.config['checkpointing']) \
        .reporting(**args.config['reporting']) \
        .training(**args.config['training'])
    
    sac_algo = sac_config.build()
else:
    sac_algo = Algorithm.from_checkpoint(os.path.join(args.directory, args.name))
    
    # Load the native torch model exported next to the checkpoint
    # (written because export_native_model_files is set to True in the config).
    model = torch.load(os.path.join(args.directory, args.name, 'policies', 'default_policy', 'model', 'model.pt'))
    
    sac_algo.get_policy().model = copy.deepcopy(model)
    sac_algo.get_policy().target_model = copy.deepcopy(model)
  
    # Rebuild the per-GPU devices, model towers, and target models around the
    # loaded model, similar to what TorchPolicy sets up on construction.
    gpu_ids = list(range(torch.cuda.device_count()))
  
    devices = [
        torch.device("cuda:{}".format(i))
        for i, id_ in enumerate(gpu_ids)
        if i < args.config['resources']['num_gpus']
    ]
  
    sac_algo.get_policy().model_gpu_towers = []
  
    for i, _ in enumerate(gpu_ids):
        model_copy = copy.deepcopy(model)
        sac_algo.get_policy().model_gpu_towers.append(model_copy.to(devices[i]))
  
    sac_algo.get_policy().model_gpu_towers[0] = sac_algo.get_policy().model
    
    sac_algo.get_policy().target_models = {
        m: copy.deepcopy(sac_algo.get_policy().target_model).to(devices[i])
        for i, m in enumerate(sac_algo.get_policy().model_gpu_towers)
    }
  
    sac_algo.get_policy()._state_inputs = sac_algo.get_policy().model.get_initial_state()
  
    sac_algo.get_policy()._is_recurrent = len(sac_algo.get_policy()._state_inputs) > 0
  
    sac_algo.get_policy()._update_model_view_requirements_from_init_state()
  
    sac_algo.get_policy().view_requirements.update(sac_algo.get_policy().model.view_requirements)
  
    sac_algo.get_policy().unwrapped_model = model
  
    # Re-create the optimizers so they reference the restored model's parameters.
    sac_algo.get_policy()._optimizers = force_list(sac_algo.get_policy().optimizer())
  
    sac_algo.get_policy().multi_gpu_param_groups = []
  
    main_params = {p: i for i, p in enumerate(sac_algo.get_policy().model.parameters())}
    
    for o in sac_algo.get_policy()._optimizers:
        param_indices = []
        
        for pg_idx, pg in enumerate(o.param_groups):
            for p in pg["params"]:
                param_indices.append(main_params[p])
        
        sac_algo.get_policy().multi_gpu_param_groups.append(set(param_indices))
  
# Plain training loop, checkpointing every 8 iterations.
for i in range(32768):
    print(f'Iteration: {i}')
  
    sac_algo.train()
  
    if i % 8 == 0:
        sac_algo.save_checkpoint(os.path.join(args.directory, args.name))

I’m essentially plugging the old model back in and re-initializing everything built around it (like the optimizers and the GPU towers). It’s probably not the best solution, but it seems to be the only one that works for me at the moment.
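
For completeness, the checkpoint-size comparison mentioned above boils down to something like this (hedged; model.pt only exists because export_native_model_files is enabled, and the '_resaved' directory is just a placeholder):

import os

from ray.rllib.algorithms.algorithm import Algorithm

MODEL_PT = os.path.join('policies', 'default_policy', 'model', 'model.pt')

original_ckpt = os.path.join(args.directory, args.name)
resaved_ckpt = os.path.join(args.directory, args.name + '_resaved')
os.makedirs(resaved_ckpt, exist_ok=True)

# Reload the original checkpoint and immediately re-save it without training.
algo = Algorithm.from_checkpoint(original_ckpt)
algo.save_checkpoint(resaved_ckpt)

# Compare the sizes of the two native models.
for ckpt in (original_ckpt, resaved_ckpt):
    size_mb = os.path.getsize(os.path.join(ckpt, MODEL_PT)) / 2**20
    print(f'{ckpt}: model.pt is {size_mb:.1f} MB')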