Is it normal for tf2 to be 6-8 times slower than pytorch? Also, tf2 uses 2-3 times more memory, CPU and GPU resources.
- Medium: It contributes to significant difficulty to complete my task, but I can work around it.
This depends very much on the setting. For different algorithms, torch, tf and tf2 will result in different throughputs and different wall clock times.
What algorithm are you using? Can you post a complete reproduction script?
Here is an example. With pytorch it uses 10GB of RAM; with tf2 it goes up to 100GB. Even if I lower the number of workers to 1 and the number of envs per worker to 1, torch is still at least 5 times faster.
I am using an NVIDIA RTX 3060 12GB GPU; I don't have a high-end AI GPU.
import gym
import numpy as np
import ray
from ray import tune
from ray.tune.logger import pretty_print
from ray.rllib.agents import impala
import random

class MyEnv(gym.Env):
    def __init__(self, config=None):
        super(MyEnv, self).__init__()
        self.action_space = gym.spaces.Box(
            low=-1, high=1, shape=(10,), dtype=np.float32)
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(7000,),
                                                dtype=np.float32)

    def _next_observation(self):
        obs = np.random.normal(0,1,7000)
        return obs

    def _take_action(self, action):
        self._reward = random.randrange(-1,1)

    def step(self, action):
        # Execute one time step within the environment
        self._reward = 0
        self._take_action(action)
        done = False
        obs = self._next_observation()
        return obs, self._reward, done, {}

    def reset(self):
        self._reward = 0
        self.total_reward = 0
        self.visualization = None
        return self._next_observation()

if __name__ == "__main__":
    ray.init()
    cfg = impala.DEFAULT_CONFIG.copy()
    cfg["env"] = MyEnv
    cfg["num_gpus"] = 1
    cfg["num_workers"] = 5
    cfg["num_envs_per_worker"] = 5
    cfg["framework"] = "tf2"
    cfg["horizon"] = 500
    cfg["model"] = {
        "fcnet_hiddens": [256, 256],
    }
    agent = impala.ImpalaTrainer(config=cfg, env=MyEnv)
    i = 0
    while True:
        result = agent.train()
        if i % 35 == 0:
            checkpoint_path = agent.save()
            print(pretty_print(result))
            print(checkpoint_path)
        i += 1
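
One detail in the script above that may be worth ruling out: `np.random.normal` returns `float64`, while the observation space declares `float32`. The following is a minimal NumPy-only sketch of the per-observation size difference. The helper name `sample_obs` is hypothetical, and whether this mismatch contributes to the RAM gap between frameworks is an assumption, not something confirmed in this thread.

```python
import numpy as np

OBS_DIM = 7000  # same observation size as MyEnv in the script above


def sample_obs(dtype=None):
    """Sample an observation like MyEnv._next_observation, optionally casting it.

    Hypothetical helper for illustration only.
    """
    obs = np.random.normal(0, 1, OBS_DIM)
    return obs if dtype is None else obs.astype(dtype)


raw = sample_obs()             # what the env currently returns
cast = sample_obs(np.float32)  # matches the declared observation space

print(raw.dtype, raw.nbytes)    # float64 56000
print(cast.dtype, cast.nbytes)  # float32 28000
```

Casting inside `_next_observation` would make the returned array match the declared space; it would not by itself explain a 10x difference.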
Throughput can differ vastly between frameworks from case to case. Can you try TF1? Since you chose “It contributes to significant difficulty to complete my task, but I can work around it.”, I gather that you want to use TF?
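
To compare frameworks on equal terms, it may help to build one config per framework from the same base settings and time a few training iterations for each. The sketch below is only one way to do that. `BASE`, `make_config` and `time_iterations` are hypothetical helpers, not RLlib functions. The framework strings `"torch"`, `"tf"` (TF1 static graph) and `"tf2"` are RLlib's own, and `eager_tracing` is an RLlib config key that has not been mentioned in this thread; whether it changes the picture for this script is untested.

```python
import time

# Settings copied from the reproduction script; everything except
# "framework" stays identical across runs.
BASE = {
    "num_gpus": 1,
    "num_workers": 5,
    "num_envs_per_worker": 5,
    "horizon": 500,
    "model": {"fcnet_hiddens": [256, 256]},
}


def make_config(framework, eager_tracing=False):
    """Return a copy of BASE for one framework ("torch", "tf" or "tf2")."""
    cfg = dict(BASE)
    cfg["model"] = dict(BASE["model"])  # avoid sharing the nested dict
    cfg["framework"] = framework
    if framework == "tf2":
        # Assumption: tracing the eager code may reduce tf2 overhead here.
        cfg["eager_tracing"] = eager_tracing
    return cfg


def time_iterations(train_fn, n=3):
    """Average wall-clock seconds per call of train_fn over n calls."""
    start = time.perf_counter()
    for _ in range(n):
        train_fn()
    return (time.perf_counter() - start) / n


configs = {fw: make_config(fw) for fw in ("torch", "tf", "tf2")}
```

With the script above, each config could be merged into `impala.DEFAULT_CONFIG.copy()` and `agent.train` passed as `train_fn`, so that all three runs use the same workers, envs and model.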
It looks like TF1 is much faster: about 50% faster than torch, with much better RAM usage than TF2 (about 4 times lower). I also see GPU usage increase from 35% with torch to 60% with TF1, maybe because of TF's internal parallelization?
I prefer to use torch where possible because, in my experience, it is much faster overall and wastes less RAM, CPU and GPU resources. However, with torch it crashes after a few million steps. I have posted this problem here: RLlib crashes with more workers and envs - #2 by christy