Trainable class from SACConfig.build()?

I am migrating a TensorFlow SAC trainer written from scratch to RLlib, but I am struggling with some of RLlib's concepts. I am trying to use Ray Tune with a SAC algorithm and plugged in some config parameters like:

register_env("Env", Env)
sac_config = SACConfig().environment("Env")
sac_config = sac_config.training(
    clip_actions = True,
    gamma = 0.99,
    optimization_config = {
        "actor_learning_rate": 0.01,
        "critic_learning_rate": 0.01,
        "entropy_learning_rate": 0.01,
    },
    policy_model_config = policy_model_config,
    q_model_config = q_model_config,
    store_buffer_in_checkpoints = True,
    target_network_update_freq = 1,
    tau = 0.005,
    train_batch_size = 256,
    twin_q = True,
)

Then I run:

sac = sac_config.build()

I first thought sac would be a trainable class that Tune can accept under its trainable class API, since the build function in AlgorithmConfig (ray/algorithm_config.py at master · ray-project/ray · GitHub) returns an Algorithm, and Algorithm seems to be a subclass of Trainable (ray/algorithm.py at master · ray-project/ray · GitHub). So as the next step I start Tune:

tuner = tune.Tuner(sac, 
    param_space = param_space,  #for now, param_space is just a grid search of tau 
    tune_config = TuneConfig(
        scheduler = ASHAScheduler(metric = "agg_reward", mode = "max"), #agg_reward is reported in step of environment
        num_samples = 1,
        ),
    run_config = RunConfig(
        stop = {"training_iteration": 100},
        local_dir = ospath.join(ospath.dirname(__file__),"ray_results"),
        verbose = 1,
        log_to_file = True,
        checkpoint_config = CheckpointConfig(
            checkpoint_frequency = 10
            ),
        ),
    )

Then the error pops up:
raise TuneError("Improper 'run' - not string nor trainable.")

If I register sac as a trainable via register_trainable("sac_trainable", sac()), it says TypeError: 'SAC' object is not callable.

I suspect Tune is telling me that sac is not a trainable class, but if not, what type of object does build return? Could someone please point out what went wrong above and how I can make Tune work? Thank you!

Hi @Teenforever

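I think the TypeError is telling you not to call the object you pass in. Note, though, that sac is an already-built Algorithm instance, while register_trainable expects a class (or a function), so it should look more like this, with SAC imported from ray.rllib.algorithms.sac:
register_trainable("sac_trainable", SAC)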
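
To illustrate the class-versus-instance distinction behind both error messages, here is a minimal sketch in plain Python. The Trainable, Algorithm, Config, and is_valid_run names are hypothetical stand-ins that only mimic the idea of Tune's check; they are not Ray's actual implementation.

```python
# Hypothetical stand-ins; these are NOT Ray's real classes.
class Trainable:
    """Plays the role of ray.tune.Trainable in this sketch."""


class Algorithm(Trainable):
    """Plays the role of an RLlib algorithm class such as SAC."""


class Config:
    """Plays the role of SACConfig: build() returns an instance."""

    def build(self):
        return Algorithm()


def is_valid_run(obj):
    """Rough imitation of the check: accept a registered name (str),
    a Trainable subclass, or a callable such as a training function."""
    if isinstance(obj, str):
        return True
    if isinstance(obj, type):
        return issubclass(obj, Trainable)
    return callable(obj)


sac = Config().build()

print(isinstance(sac, Trainable))  # True: sac is an instance of a Trainable subclass
print(is_valid_run(sac))           # False: an instance is neither a class, a name, nor callable
print(is_valid_run(Algorithm))     # True: the class itself qualifies
print(is_valid_run("SAC"))         # True: so does a registered name
print(callable(sac))               # False: hence "'SAC' object is not callable" for sac()
```

In other words, under these assumptions, build() hands back an object that is a Trainable, whereas Tune wants something it can instantiate itself for each trial.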

In your case I think the more canonical approach would be to put this in the sac_config:
tau=tune.grid_search([...])

Then do not call build(); pass the config to the Tuner like this:

tuner = tune.Tuner("SAC",
    param_space = sac_config, 
   ... 
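
Put together, the approach might look like the sketch below. It is only an illustration under some assumptions: the tau candidates are made up, "Pendulum-v1" stands in for the registered "Env", and RunConfig is imported from ray.air as in Ray 2.x releases of that period (newer releases may expose it elsewhere).

```python
# Hypothetical tau candidates for the grid search; pick values that suit your task.
TAU_CANDIDATES = [0.001, 0.005, 0.01]


def build_tuner():
    """Return a Tuner that sweeps tau for SAC without calling build().

    Imports are local so the sketch can be loaded without Ray installed.
    Import paths assume a Ray 2.x release where RunConfig lives in ray.air.
    """
    from ray import air, tune
    from ray.rllib.algorithms.sac import SACConfig

    sac_config = (
        SACConfig()
        .environment("Pendulum-v1")  # assumption: stand-in for the registered "Env"
        .training(
            gamma=0.99,
            tau=tune.grid_search(TAU_CANDIDATES),  # search space lives in the config
            train_batch_size=256,
            twin_q=True,
        )
    )
    # Pass the registered algorithm name plus the config object itself;
    # Tune then creates the SAC instances for each trial.
    return tune.Tuner(
        "SAC",
        param_space=sac_config,
        run_config=air.RunConfig(stop={"training_iteration": 100}),
    )
```

Calling build_tuner().fit() would then start the trials and return the results.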

Works like a breeze, thanks @mannyv.

For anyone in the future wondering whether it is necessary to convert sac_config to a dict using .to_dict(): I tried, and it isn't necessary. In fact, if you convert the config to a dict and it contains a grid_search (and probably other Tune search-space instructions), it raises an error.
