Restore from checkpoint gives tf not present error

Hi, below is my code. Previously, with Ray 2.0.0, this worked fine. Since moving to Ray 2.2.0, I am not able to load the trained model correctly.

This is how I save the checkpoints

# Train PPO on the registered "MA_env" environment and periodically checkpoint.
agent = ppo.PPOTrainer(config, env="MA_env")
n_iter = 2800
for n in range(n_iter):
    result = agent.train()

    # Every 5th iteration, write a checkpoint; `save()` returns the path of the
    # checkpoint it wrote, which is what gets printed.
    if n % 5 == 0:
        checkpoint = agent.save()
        print("checkpoint saved at", checkpoint)

This is my evaluation code:

This gives an error saying that TensorFlow (tf) is not installed, but I am using torch, and torch is installed in my conda environment.

import os, pdb, matplotlib, tempfile, sys
import numpy as np
import matplotlib.pyplot as plt
import time
from datetime import datetime
from sklearn.metrics import mean_squared_error
import inspect
from pathlib import Path

import gym, ray, natsort
from gym import spaces
from scipy.spatial import distance
from ase import Atoms
# from gpaw import GPAW, PW, FD
from ase.optimize import QuasiNewton, BFGS
from import Trajectory
from import read, write
from import minimize_rotation_and_translation

import torch
import torchani

from ray import tune
from typing import Dict
from ray.tune.logger import pretty_print
from ray.tune.logger import Logger, UnifiedLogger
from ray.rllib.utils.annotations import override
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env
import ray.rllib.agents.ppo as ppo
from ray.rllib.env import BaseEnv
from ray.rllib.policy import Policy
# from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule, ExponentialSchedule, PiecewiseSchedule
from ray.rllib.algorithms.algorithm import Algorithm

from eval_utils import do_optim, read_traj, calc_rmsd

# NOTE(review): `PolicyNetwork` is neither defined nor imported in this snippet —
# presumably the custom torch model class used at training time; confirm it is
# imported here. gen_policy("A") below refers to this model as "modelA".
model_A = PolicyNetwork
ModelCatalog.register_custom_model("modelA", model_A)

# Action space: 3-element continuous vector bounded to [-0.05, 0.05].
act_space = spaces.Box(low=-0.05,high=0.05, shape=(3,))
# Observation space: flat vector of length 128+3+3+1 (=135) bounded to [-1000, 1000].
obs_space = spaces.Box(low=-1000,high=1000, shape=(128+3+3+1,))

def gen_policy(atom):
    """Build the multi-agent policy spec tuple for the policy of one atom type.

    Returns ``(None, obs_space, act_space, policy_config)`` where the policy
    class is left as ``None`` and ``policy_config`` selects the custom model
    registered under the name ``"model<atom>"``.
    """
    policy_config = {"model": {"custom_model": f"model{atom}"}}
    return None, obs_space, act_space, policy_config

# A single policy, backed by the custom model "modelA".
policies = {"policy_A": gen_policy("A")}
policy_ids = [*policies]


def policy_mapping_fn(agent_id, **kwargs):
    """Send every agent to the one shared policy, "policy_A".

    ``agent_id`` and any extra keyword arguments are accepted but unused.
    """
    return "policy_A"

def env_creator(env_config):
    """Create one environment instance from ``env_config``.

    NOTE(review): `ma_env` is not imported in this snippet — presumably the
    project's environment module; confirm it is imported before this runs.
    """
    return ma_env.MA_env(env_config)  # return an env instance

# Register under "MA_env", the same name the training script passes as `env=`.
register_env("MA_env", env_creator)

# Evaluation-time PPO config.
# NOTE(review): in the restore code in this file, this dict is not passed to
# `Algorithm.from_checkpoint(...)`, so the restored algorithm runs with the
# config saved inside the checkpoint. Here only
# config["multiagent"]["policy_mapping_fn"] is read during the rollout loop.
# Also note `.copy()` is shallow: nested dicts are shared with
# ppo.DEFAULT_CONFIG, so mutate only top-level keys.
config = ppo.DEFAULT_CONFIG.copy()

config["multiagent"] = {
    "policy_mapping_fn": policy_mapping_fn,
    "policies": policies,
    "policies_to_train": ["policy_A"],  # , "policy_N", "policy_O", "policy_H"],
}

config["in_evaluation"] = True
config["explore"] = False  # intended: deterministic actions during evaluation
config["log_level"] = "WARN"
config["framework"] = "torch"
# config["num_gpus"] = int(os.environ.get("RLLIB_NUM_GPUS", "0"))
config["num_gpus"] = 1
config["num_workers"] = 1
config["env_config"] = {"atoms": ["C", "H"]}
config["rollout_fragment_length"] = 200
# NOTE(review): newer RLlib versions may expect this under config["model"];
# confirm against the installed Ray version.
config["vf_share_layers"] = True

def custom_log_creator(custom_path, custom_str):
    """Return a ``logger_creator`` callable that logs under ``custom_path``.

    Each call of the returned function creates a fresh, uniquely named
    subdirectory of ``custom_path`` whose name starts with
    ``"<custom_str>_<YYYY-mm-dd_HH-MM-SS>"`` and returns a ``UnifiedLogger``
    writing into it.

    Args:
        custom_path: Parent directory for log directories; created if missing.
        custom_str: Prefix for each log directory's name.

    Returns:
        A one-argument callable ``logger_creator(config)``.
    """
    # The timestamp is fixed when the factory is built, not per logger.
    timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
    logdir_prefix = "{}_{}".format(custom_str, timestr)

    def logger_creator(config):
        # exist_ok=True avoids the exists-then-create race of a separate check.
        os.makedirs(custom_path, exist_ok=True)
        logdir = tempfile.mkdtemp(prefix=logdir_prefix, dir=custom_path)
        return UnifiedLogger(config, logdir, loggers=None)

    return logger_creator


# `from_checkpoint` is a static factory and must be called on the class.
# `Algorithm().from_checkpoint(...)` first constructs a throw-away default
# Algorithm. Its default config requests framework "tf", which raises the
# "TensorFlow was specified ... no installation was found" ImportError inside
# Algorithm.__init__ before the (torch-trained) checkpoint is ever read.
checkpoint_dir = f"checkpoint_{self.iteration_num:06d}/"
self.agent = Algorithm.from_checkpoint(os.path.join(model_restore, checkpoint_dir))
self.env = ma_env.MA_env({})

obs = self.env.reset()
for _ in range(20):
    # Fresh action dict each step: only agents present in `obs` get an action.
    action = {}
    for agent_id, agent_obs in obs.items():
        policy_id = config['multiagent']['policy_mapping_fn'](agent_id)
        # The local `config` (explore=False) is not applied to the restored
        # algorithm, so request deterministic actions explicitly.
        action[agent_id] = self.agent.compute_single_action(
            agent_obs, policy_id=policy_id, explore=False
        )
    obs, rew, done, info = self.env.step(action)
    if done["__all__"]:
        break

Hi @Rohit_Modee,

Can you share the error trace you get when you run it?

Hey @mannyv, below is the trace:

Traceback (most recent call last):
  File "/home/rohit/ssds/nnp/proj7_RL4Opt/aev_delta_forces_embed/", line 311, in <module>
    optimizer = setup_optimization(root_dir, testset_name, cpk_idx, all_time)
  File "/home/rohit/ssds/nnp/proj7_RL4Opt/aev_delta_forces_embed/", line 157, in __init__
    self.agent = Algorithm().from_checkpoint(model_restore + checkpoint_dir)
  File "/home/rohit/anaconda3/envs/rl2/lib/python3.9/site-packages/ray/rllib/algorithms/", line 368, in __init__
  File "/home/rohit/anaconda3/envs/rl2/lib/python3.9/site-packages/ray/rllib/algorithms/", line 555, in validate
    self._check_if_correct_nn_framework_installed(_tf1, _tf, _torch)
  File "/home/rohit/anaconda3/envs/rl2/lib/python3.9/site-packages/ray/rllib/algorithms/", line 2463, in _check_if_correct_nn_framework_installed
    raise ImportError(
ImportError: TensorFlow was specified as the framework to use (via `config.framework([tf|tf2])`)! However, no installation was found. You can install TensorFlow via `pip install tensorflow`

Hi @Rohit_Modee,

Do you know if the checkpoint was trained with tf? Looking at the code and traceback you shared, it looks like the config being used is not the one you defined but the one that was saved with the checkpoint.

Hey @mannyv,

I have trained the model using a custom PyTorch policy.

In my training script, I have defined the framework as
config["framework"] = "torch"

I am not using tensorflow.

Also, in result.json the framework is torch.

Help would be appreciated!!!

I’d like to open a bug to keep track of this issue.

Can you fill this out, @Rohit_Modee?

ray new issue form

We’ll probably be able to resolve your issue if you give us a script that we can run ourselves to reproduce it. I think you might be able to remove your custom logger and custom environment.

For the environment, you could try using the multi-agent random env (ray-project/ray on Sourcegraph).

Is there a template to create a minimum working example (MWE) to reproduce this issue?

Any guidelines on how to create it would be helpful.