SAC Networks Not Matching the Model Definition in the Configuration

How severely does this issue affect your experience of using Ray?

  • Medium: It causes significant difficulty in completing my task, but I can work around it.

Hello everyone,

I am trying to use DQN and SAC to solve a small, simple MiniGrid environment. As expected, DQN handles it without problems, but SAC seems to struggle. While searching for a solution, I noticed a large difference between the networks the two algorithms build, even though I configured them with exactly the same model settings.

Here are my configurations for both algorithms:

```python
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.sac import SACConfig


class DoubleDefineSACConfig(SACConfig):
    def __init__(self, algo_class=None):
        super().__init__(algo_class)
        self.rollouts(
            num_rollout_workers=2,
            num_envs_per_worker=4,
        )

        # SAC takes separate model configs for the policy and the Q-network.
        # conv_filters entries are [out_channels, [kernel_h, kernel_w], stride].
        self.training(
            policy_model_config={
                "fcnet_hiddens": [256, 256],
                "conv_filters": [[32, [7, 7], 3], [32, [3, 3], 3]],
            },
            q_model_config={
                "fcnet_hiddens": [256, 256],
                "conv_filters": [[32, [7, 7], 3], [32, [3, 3], 3]],
            },
        )

        # Stay on the old ModelV2 API stack (no RLModules).
        self.rl_module(_enable_rl_module_api=False)

        self.framework(framework="tf")

        self.evaluation(evaluation_interval=5, evaluation_duration=50)


class ConvDQNConfig(DQNConfig):
    def __init__(self, algo_class=None):
        super().__init__(algo_class)
        self.rollouts(
            num_rollout_workers=2,
            num_envs_per_worker=4,
        )

        # DQN takes a single model config via the generic `model` setting.
        self.training(
            model={
                "fcnet_hiddens": [256, 256],
                "conv_filters": [[32, [7, 7], 3], [32, [3, 3], 3]],
            }
        )

        # Stay on the old ModelV2 API stack (no RLModules).
        self.rl_module(_enable_rl_module_api=False)

        self.framework(framework="tf")

        self.evaluation(evaluation_interval=5, evaluation_duration=50)
```
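The spatial sizes reported in the summaries can be checked by hand with standard Keras output-size arithmetic on the 7x7x3 observation. This is a sketch under assumptions: judging from the reported shapes, it assumes conv1 uses "same" padding and the last conv layer uses "valid" padding, and `conv_output_size` is a hypothetical helper, not an RLlib function.

```python
import math


def conv_output_size(size: int, kernel: int, stride: int, padding: str) -> int:
    """Spatial output size of a 2D convolution along one axis (Keras conventions)."""
    if padding == "same":
        # "same" padding: output size depends only on input size and stride.
        return math.ceil(size / stride)
    if padding == "valid":
        # "valid" padding: no padding, the kernel must fit inside the input.
        return (size - kernel) // stride + 1
    raise ValueError(f"unknown padding: {padding!r}")


# conv_filters from the configs above: [out_channels, [kernel_h, kernel_w], stride].
conv_filters = [[32, [7, 7], 3], [32, [3, 3], 3]]
# Assumption inferred from the summaries: every layer but the last uses "same",
# the last one uses "valid".
paddings = ["same"] * (len(conv_filters) - 1) + ["valid"]

size = 7  # observation shape from the summaries: (7, 7, 3)
shapes = []
for (out_channels, (kernel, _), stride), padding in zip(conv_filters, paddings):
    size = conv_output_size(size, kernel, stride, padding)
    shapes.append((size, size, out_channels))

print(shapes)  # [(3, 3, 32), (1, 1, 32)]
```

With "valid" padding on conv1, the 7x7 input would already shrink to 1x1 ((7 - 7) // 3 + 1 = 1), which contradicts the reported (None, 3, 3, 32); that is why the sketch assumes "same" padding for the earlier layers.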

When I load a checkpoint and inspect the models with `policy.model.base_model.summary()` for DQN, and with `policy.model.action_model.base_model.summary()` and `policy.model.q_net.base_model.summary()` for SAC, I get the following output for SAC:

```text
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to
==================================================================================================
 observations (InputLayer)      [(None, 7, 7, 3)]    0           []

 conv1 (Conv2D)                 (None, 3, 3, 32)     4736        ['observations[0][0]']

 conv2 (Conv2D)                 (None, 1, 1, 32)     9248        ['conv1[0][0]']

 lambda (Lambda)                (None, 32)           0           ['conv2[0][0]']

 conv_out (Conv2D)              (None, 1, 1, 7)      231         ['conv2[0][0]']

 value_out (Dense)              (None, 1)            33          ['lambda[0][0]']

==================================================================================================
Total params: 14,248
Trainable params: 14,248
Non-trainable params: 0
__________________________________________________________________________________________________
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to
==================================================================================================
 observations (InputLayer)      [(None, 7, 7, 3)]    0           []

 conv1 (Conv2D)                 (None, 3, 3, 32)     4736        ['observations[0][0]']

 conv2 (Conv2D)                 (None, 1, 1, 32)     9248        ['conv1[0][0]']

 lambda_1 (Lambda)              (None, 32)           0           ['conv2[0][0]']

 conv_out (Conv2D)              (None, 1, 1, 7)      231         ['conv2[0][0]']

 value_out (Dense)              (None, 1)            33          ['lambda_1[0][0]']

==================================================================================================
Total params: 14,248
Trainable params: 14,248
Non-trainable params: 0
__________________________________________________________________________________________________
```

and the following for DQN:

```text
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 observations (InputLayer)   [(None, 7, 7, 3)]         0

 conv1 (Conv2D)              (None, 3, 3, 32)          4736

 conv_out (Conv2D)           (None, 1, 1, 256)         73984

 lambda (Lambda)             (None, 256)               0

 value_out (Dense)           (None, 1)                 257

=================================================================
Total params: 78,977
Trainable params: 78,977
Non-trainable params: 0
_________________________________________________________________
```
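As a sanity check, the parameter counts in both summaries can be reproduced from the layer shapes. The kernel sizes of the two `conv_out` layers are not printed in the summaries; this sketch infers them from the parameter counts (1x1 for SAC, 3x3 for DQN), so treat them as assumptions. `conv2d_params` and `dense_params` are hypothetical helpers.

```python
def conv2d_params(kernel: int, in_channels: int, filters: int) -> int:
    """Trainable parameters of a square Conv2D layer with bias."""
    return kernel * kernel * in_channels * filters + filters


def dense_params(in_features: int, units: int) -> int:
    """Trainable parameters of a Dense layer with bias."""
    return in_features * units + units


# Layers as listed in the SAC summaries (policy and Q-network look identical).
sac_layers = {
    "conv1": conv2d_params(kernel=7, in_channels=3, filters=32),
    "conv2": conv2d_params(kernel=3, in_channels=32, filters=32),
    "conv_out": conv2d_params(kernel=1, in_channels=32, filters=7),  # 1x1 assumed
    "value_out": dense_params(in_features=32, units=1),
}

# Layers as listed in the DQN summary: no conv2, and conv_out has 256 filters.
dqn_layers = {
    "conv1": conv2d_params(kernel=7, in_channels=3, filters=32),
    "conv_out": conv2d_params(kernel=3, in_channels=32, filters=256),  # 3x3 assumed
    "value_out": dense_params(in_features=256, units=1),
}

print(sum(sac_layers.values()))  # 14248
print(sum(dqn_layers.values()))  # 78977
```

So the whole size difference (about 65k parameters) sits in the last convolutional stage: 73,984 parameters in DQN's 256-filter `conv_out` versus 9,248 + 231 in SAC's `conv2` plus its 7-filter `conv_out`. Neither summary contains a Dense layer corresponding to the configured `fcnet_hiddens` of [256, 256].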

Why do the two SAC networks not look like the DQN network, even though all three were configured with the same `fcnet_hiddens` and `conv_filters`?

PS: I am aware that SAC is not the best choice for a gridworld; I just want to make sure the issue is not on my end.