Default model architecture for tf and torch seems to be different on some algorithms

How severe does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

I am training both torch and tf2 versions of certain algorithms, and I need the tf2 and torch versions of an algorithm to start from the same initial weights before training. I am able to do this for DQN and APEX-DQN by using a callback to save the initial torch weights and then copy them over to tf2 (transposing them where necessary, since torch and tensorflow use different shape conventions):

# save initial torch weights
if algorithm.get_policy().framework == 'torch':
    print(algorithm.get_policy().model)
    for weights in algorithm.get_policy().get_weights().values():
        print(weights.shape)
        assert weights.ndim in (1, 2, 4)
        if weights.ndim == 1: Callbacks.torch_weights[env].append(weights)  # biases: same layout in both frameworks
        elif weights.ndim == 2: Callbacks.torch_weights[env].append(weights.T)  # dense kernel: (out, in) -> (in, out)
        elif weights.ndim == 4: Callbacks.torch_weights[env].append(np.transpose(weights, (2, 3, 1, 0)))  # conv kernel: (out, in, H, W) -> (H, W, in, out)
        
# copy torch weights to tf2
elif algorithm.get_policy().framework == 'tf2':
    print(algorithm.get_policy().model.base_model.summary())
    for weights in algorithm.get_policy().get_weights():
        print(weights.shape)
    assert Callbacks.torch_weights[env]
    algorithm.get_policy().set_weights(Callbacks.torch_weights[env])
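
As a sanity check on the transposes above, here is a small standalone snippet, independent of RLlib, that verifies the convention: a torch Conv2d kernel of shape (out, in, H, W), transposed with (2, 3, 1, 0), gives the (H, W, in, out) layout that a Keras Conv2D expects, and both layers then produce the same output (the layer sizes here are just illustrative):

# Standalone check of the conv kernel transpose convention (illustrative sizes).
import numpy as np
import tensorflow as tf
import torch

x = np.random.rand(1, 4, 84, 84).astype(np.float32)  # NCHW, as torch expects

torch_conv = torch.nn.Conv2d(in_channels=4, out_channels=16, kernel_size=8, stride=4)
with torch.no_grad():
    torch_out = torch_conv(torch.from_numpy(x)).numpy()  # (1, 16, 20, 20)

tf_conv = tf.keras.layers.Conv2D(filters=16, kernel_size=8, strides=4)
tf_conv.build(input_shape=(None, 84, 84, 4))  # NHWC, as tf expects
tf_conv.set_weights([
    np.transpose(torch_conv.weight.detach().numpy(), (2, 3, 1, 0)),  # (out, in, H, W) -> (H, W, in, out)
    torch_conv.bias.detach().numpy(),
])
tf_out = tf_conv(np.transpose(x, (0, 2, 3, 1))).numpy()  # same input, fed in NHWC

# The two outputs should agree up to float precision.
print(np.allclose(torch_out, np.transpose(tf_out, (0, 3, 1, 2)), atol=1e-4))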

I am, however, unable to do this for A2C and other algorithms, because the default model architecture seems to differ between the tf and torch implementations:

rllib_config = A2CConfig()\
    .framework(framework=tune.grid_search(['torch', 'tf2']), eager_tracing=True)\
    .environment(env='ALE/Breakout-v5', render_env=False, env_config={"frameskip": 1})\
    .resources(num_gpus=0.5)\
    .debugging(seed=0)\
    .callbacks(callbacks_class=Callbacks)\
    .rollouts(num_rollout_workers=5, num_envs_per_worker=5)\
    .training(
        microbatch_size=20,
        lr_schedule=[[0, 0.0007],[20000000, 0.000000000001],],
        )

When I print the A2C torch default breakout weights:
(4, 256, 1, 1)
(4,)
(16, 4, 8, 8)
(16,)
(32, 16, 4, 4)
(32,)
(256, 32, 11, 11)
(256,)
(1, 256)
(1,)

When I print the A2C tf2 default breakout weights:
(8, 8, 4, 16)
(16,)
(4, 4, 16, 32)
(32,)
(11, 11, 32, 256)
(256,)
(1, 1, 256, 4)
(4,)
(256, 1)
(1,)

When I print the torch model with algorithm.get_policy().model:

When I print the tf2 model with algorithm.get_policy().model.base_model.summary():

I noticed this difference in architecture/weights in PPO as well.

Questions I have

  1. Shouldn’t both frameworks have the same network architecture and thus similar weight shapes (after the necessary transposing) since I am using the default settings?
  2. How would I go about ensuring they have the same architecture and thus weight shapes?
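
For question 2, the kind of thing I have in mind is pinning the model down explicitly with the standard model config keys and passing the same spec to both frameworks, roughly as sketched below (this is only a sketch of the idea; I am not sure it is enough to make the two default implementations line up, which is part of what I am asking):

# Sketch only: pass the same explicit model spec to both frameworks.
model_spec = {
    "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [256, [11, 11], 1]],
    "conv_activation": "relu",
    "fcnet_hiddens": [256, 256],
    "fcnet_activation": "tanh",
}
rllib_config = (
    A2CConfig()
    .framework(framework=tune.grid_search(["torch", "tf2"]), eager_tracing=True)
    .environment(env="ALE/Breakout-v5")
    .training(model=model_spec)
)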

Versions / Dependencies

Ray 2.2.0
Gym 0.23.1
Python 3.9
Ubuntu

Reproduction script


import numpy as np
import ray
from ray import air, tune
from ray.rllib.algorithms.a2c.a2c import A2CConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.algorithm import Algorithm

class Callbacks(DefaultCallbacks):
    torch_weights = {'Breakout': [], 'BeamRider': [], 'Qbert': [], 'SpaceInvaders': []}
    def on_algorithm_init(
        self,
        *,
        algorithm: Algorithm,
        **kwargs,
    ) -> None:
        for key in Callbacks.torch_weights.keys():
            if key in algorithm.logdir: 
                env = key
                break
        if algorithm.get_policy().framework == 'torch':
            print(algorithm.get_policy().model)
            for weights in algorithm.get_policy().get_weights().values():
                print(weights.shape)
                assert weights.ndim in (1, 2, 4)
                if weights.ndim == 1: Callbacks.torch_weights[env].append(weights)  # biases: same layout in both frameworks
                elif weights.ndim == 2: Callbacks.torch_weights[env].append(weights.T)  # dense kernel: (out, in) -> (in, out)
                elif weights.ndim == 4: Callbacks.torch_weights[env].append(np.transpose(weights, (2, 3, 1, 0)))  # conv kernel: (out, in, H, W) -> (H, W, in, out)
            print('weights extracted:', algorithm.get_policy().framework)
        elif algorithm.get_policy().framework == 'tf2':
            print(algorithm.get_policy().model.base_model.summary())
            for weights in algorithm.get_policy().get_weights():
                print(weights.shape)
            assert Callbacks.torch_weights[env]
            algorithm.get_policy().set_weights(Callbacks.torch_weights[env])
            print('weights set:', algorithm.get_policy().framework)

rllib_config = A2CConfig()\
    .framework(framework=tune.grid_search(['torch', 'tf2']), eager_tracing=True)\
    .environment(env='ALE/Breakout-v5', render_env=False, env_config={"frameskip": 1})\
    .resources(num_gpus=0.5)\
    .debugging(seed=0)\
    .callbacks(callbacks_class=Callbacks)\
    .rollouts(num_rollout_workers=5, num_envs_per_worker=5)\
    .training(
        microbatch_size=20,
        lr_schedule=[[0, 0.0007],[20000000, 0.000000000001],],
        )

air_config = air.RunConfig(
    name='A2C',
    stop={'timesteps_total': 10e6},
    checkpoint_config=air.CheckpointConfig(
        checkpoint_at_end=True
    ),
)

tuner = tune.Tuner(
    'A2C',
    param_space=rllib_config,
    run_config=air_config,
)

ray.init()
tuner.fit()
ray.shutdown()

I have also created a GitHub issue for it here

Update
I printed the algorithm config with print(algorithm.get_policy().config) for both torch and tf2, and found that the same network structure is defined for both of them:

'fcnet_hiddens': [256, 256], 'fcnet_activation': 'tanh',
'conv_filters': [[16, [8, 8], 4], [32, [4, 4], 2], [256, [11, 11], 1]], 'conv_activation': 'relu',
'post_fcnet_hiddens': [], 'post_fcnet_activation': 'relu',
'free_log_std': False, 'no_final_linear': False

The configs were however not entirely the same. They differed in one variable:
torch: 'simple_optimizer': False
tf2: 'simple_optimizer': True
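
For reference, a small sketch of how the two configs could be diffed key by key (diff_configs is just an illustrative helper; it assumes both config dicts have been captured, e.g. inside the callback above):

def diff_configs(torch_cfg: dict, tf2_cfg: dict) -> None:
    # Print every top-level key whose value differs between the two configs.
    for key in sorted(set(torch_cfg) | set(tf2_cfg)):
        if torch_cfg.get(key) != tf2_cfg.get(key):
            print(f"{key}: torch={torch_cfg.get(key)!r} tf2={tf2_cfg.get(key)!r}")

# e.g. diff_configs(torch_policy.config, tf2_policy.config)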

Hi @rajfly,

Cross posting for others who might find this in the future.

What I was suggesting is that you use the weight names and shapes from the call to get_weights(as_dict=True) and match up the parameters that way. If you do that, you should see that the names and ordering differ, but otherwise each weight in tf2 has a corresponding weight in torch with a matching (transposed) shape. Like this:

tf2                                    pytorch
conv1/kernel:0      (8, 8, 4, 16)      _convs.0._model.1.weight       (16, 4, 8, 8)
conv1/bias:0        (16,)              _convs.0._model.1.bias         (16,)
conv2/kernel:0      (4, 4, 16, 32)     _convs.1._model.1.weight       (32, 16, 4, 4)
conv2/bias:0        (32,)              _convs.1._model.1.bias         (32,)
conv3/kernel:0      (11, 11, 32, 256)  _convs.2._model.0.weight       (256, 32, 11, 11)
conv3/bias:0        (256,)             _convs.2._model.0.bias         (256,)
conv_out/kernel:0   (1, 1, 256, 4)     _logits._model.1.weight        (4, 256, 1, 1)
conv_out/bias:0     (4,)               _logits._model.1.bias          (4,)
value_out/kernel:0  (256, 1)           _value_branch._model.0.weight  (1, 256)
value_out/bias:0    (1,)               _value_branch._model.0.bias    (1,)
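
A hedged sketch of what that matching could look like in code (the pairing below is written out by hand from the table above rather than discovered automatically, and the transposes follow the torch (out, in, H, W) -> tf (H, W, in, out) convention):

import numpy as np

# (tf2 weight name, torch weight name) pairs, taken from the table above.
PAIRS = [
    ("conv1/kernel:0", "_convs.0._model.1.weight"),
    ("conv1/bias:0", "_convs.0._model.1.bias"),
    ("conv2/kernel:0", "_convs.1._model.1.weight"),
    ("conv2/bias:0", "_convs.1._model.1.bias"),
    ("conv3/kernel:0", "_convs.2._model.0.weight"),
    ("conv3/bias:0", "_convs.2._model.0.bias"),
    ("conv_out/kernel:0", "_logits._model.1.weight"),
    ("conv_out/bias:0", "_logits._model.1.bias"),
    ("value_out/kernel:0", "_value_branch._model.0.weight"),
    ("value_out/bias:0", "_value_branch._model.0.bias"),
]

def torch_to_tf2(torch_weights: dict) -> list:
    """Reorder and transpose a torch weight dict into the tf2 weight-list order."""
    out = []
    for _, torch_name in PAIRS:
        w = torch_weights[torch_name]
        if w.ndim == 4:    # conv kernel: (out, in, H, W) -> (H, W, in, out)
            w = np.transpose(w, (2, 3, 1, 0))
        elif w.ndim == 2:  # dense kernel: (out, in) -> (in, out)
            w = w.T
        out.append(w)
    return out

# e.g. tf2_policy.set_weights(torch_to_tf2(torch_policy.get_weights()))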