I'm running the example below and get the following warning:
Function checkpointing is disabled. ... set the train function arguments to be `func(config, checkpoint_dir=None)`
After adding `checkpoint_dir=None`, I get another warning saying that this argument will be deprecated and that I should use `session.report()` instead — but the example already uses `session.report()`.
import time
import argparse
from ray import air, tune
from ray.air import session
from ray.air.checkpoint import Checkpoint
def evaluation_fn(step, width, height):
    """Return a synthetic loss value for one training step.

    Sleeps for 0.1 s to simulate real work, then returns
    ``1 / (0.1 + width * step / 100) + height * 0.1``. For non-negative
    ``width`` and ``step`` the first term shrinks as training progresses,
    while the ``height`` term is a constant offset for the trial.
    """
    time.sleep(0.1)
    denominator = 0.1 + width * step / 100
    return denominator ** (-1) + height * 0.1
def train_func(config):
    """Tune trainable that reports a synthetic loss and a checkpoint every step.

    Reads ``width`` and ``height`` from *config*, plus ``steps`` (exclusive
    upper bound of the step loop; defaults to 100 when absent). If a
    checkpoint was restored for this trial, training resumes from the step
    after the one stored in it.

    Each iteration reports ``{"iterations", "mean_loss"}`` together with a
    dict checkpoint holding the current ``step``.

    NOTE(review): some Ray 2.0.x releases log "Function checkpointing is
    disabled ... `func(config, checkpoint_dir=None)`" for trainables with this
    signature even though checkpoints are reported via ``session.report``.
    Adding ``checkpoint_dir`` only triggers a deprecation warning in favour of
    the session API. Confirm against the installed Ray version before
    changing the signature.
    """
    width, height = config["width"], config["height"]
    # Honour the "steps" entry of the search space instead of a fixed bound.
    num_steps = config.get("steps", 100)

    start = 0
    # Fetch the restored checkpoint once rather than calling get_checkpoint() twice.
    loaded_checkpoint = session.get_checkpoint()
    if loaded_checkpoint is not None:
        # The checkpoint stores the last reported step; continue after it.
        start = loaded_checkpoint.to_dict()["step"] + 1
        print('==== existing step: ', start, width, height)

    for step in range(start, num_steps):
        print(' ---- step: ', step, width, height)
        intermediate_score = evaluation_fn(step, width, height)
        checkpoint = Checkpoint.from_dict({"step": step})
        session.report(
            {"iterations": step, "mean_loss": intermediate_score},
            checkpoint=checkpoint,
        )
if __name__ == "__main__":
    # Script entry point: parse CLI flags, optionally connect to a Ray Client
    # server, run the Tune experiment and report the best trial.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    parser.add_argument(
        "--server-address",
        type=str,
        default=None,
        required=False,
        help="The address of server to connect to if using Ray Client.",
    )
    parser.add_argument(
        "--cpus-per-trial",
        type=int,
        default=1,
        help="Number of CPUs to reserve for each trial.",
    )
    args, _ = parser.parse_known_args()

    if args.server_address:
        import ray

        ray.init(f"ray://{args.server_address}")

    # The resource-wrapped trainable must be the one handed to the Tuner;
    # building the wrapper without passing it on has no effect.
    trainable_with_resources = tune.with_resources(
        train_func, {"cpu": args.cpus_per_trial}
    )

    # NOTE(review): the run is named "hyperband_test" but no scheduler is
    # configured in TuneConfig, so Tune's default scheduling is used.
    tuner = tune.Tuner(
        trainable_with_resources,
        run_config=air.RunConfig(
            name="hyperband_test",
            stop={"training_iteration": 1 if args.smoke_test else 10},
        ),
        tune_config=tune.TuneConfig(
            metric="mean_loss",
            mode="min",
            num_samples=5,
        ),
        param_space={
            "steps": 10,
            "width": tune.randint(10, 100),
            "height": tune.loguniform(10, 100),
        },
    )
    results = tuner.fit()

    best_result = results.get_best_result()
    print("Best hyperparameters: ", best_result.config)

    # A result's checkpoint is optional; guard before unpacking it.
    best_checkpoint = best_result.checkpoint
    if best_checkpoint is None:
        print("Best trial did not report a checkpoint.")
    else:
        checkpoint_data = best_checkpoint.to_dict()
        print("Best checkpoint data: ", checkpoint_data)