How severe does this issue affect your experience of using Ray?
- Medium: It contributes to significant difficulty to complete my task, but I can work around it.
Hello everyone,
I am trying to use DQN and SAC to solve a small, simple MiniGrid environment. As expected, DQN handles it without issue, but SAC seems to struggle. While searching for a cause, I noticed a large difference between the networks the two algorithms build, even though I configured both with exactly the same model settings.
Here are my configurations for both algorithms:
from ray.rllib.algorithms.sac import SACConfig

class DoubleDefineSACConfig(SACConfig):
    def __init__(self, algo_class=None):
        super().__init__(algo_class)
        self.rollouts(
            num_rollout_workers=2,
            num_envs_per_worker=4,
        )
        # Each conv filter is [out_channels, [kernel_h, kernel_w], stride].
        self.training(
            policy_model_config={
                "fcnet_hiddens": [256, 256],
                "conv_filters": [[32, [7, 7], 3], [32, [3, 3], 3]],
            },
            q_model_config={
                "fcnet_hiddens": [256, 256],
                "conv_filters": [[32, [7, 7], 3], [32, [3, 3], 3]],
            },
        )
        # Disable the RLModule API for this config.
        self.rl_module(_enable_rl_module_api=False)
        self.framework(framework='tf')
        self.evaluation(evaluation_interval=5, evaluation_duration=50)
from ray.rllib.algorithms.dqn import DQNConfig

class ConvDQNConfig(DQNConfig):
    def __init__(self, algo_class=None):
        super().__init__(algo_class)
        self.rollouts(
            num_rollout_workers=2,
            num_envs_per_worker=4,
        )
        # Each conv filter is [out_channels, [kernel_h, kernel_w], stride].
        self.training(
            model={
                "fcnet_hiddens": [256, 256],
                "conv_filters": [[32, [7, 7], 3], [32, [3, 3], 3]],
            }
        )
        # Disable the RLModule API for this config.
        self.rl_module(_enable_rl_module_api=False)
        self.framework(framework='tf')
        self.evaluation(evaluation_interval=5, evaluation_duration=50)
When I load a checkpoint and inspect the models with "policy.model.base_model.summary()" for DQN, and with "policy.model.action_model.base_model.summary()" and "policy.model.q_net.base_model.summary()" for SAC, I get the following output for SAC:
Model: "model"
__________________________________________________________________________________________________
Layer (type)                Output Shape         Param #   Connected to
==================================================================================================
observations (InputLayer)   [(None, 7, 7, 3)]    0         []
conv1 (Conv2D)              (None, 3, 3, 32)     4736      ['observations[0][0]']
conv2 (Conv2D)              (None, 1, 1, 32)     9248      ['conv1[0][0]']
lambda (Lambda)             (None, 32)           0         ['conv2[0][0]']
conv_out (Conv2D)           (None, 1, 1, 7)      231       ['conv2[0][0]']
value_out (Dense)           (None, 1)            33        ['lambda[0][0]']
==================================================================================================
Total params: 14,248
Trainable params: 14,248
Non-trainable params: 0
__________________________________________________________________________________________________
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                Output Shape         Param #   Connected to
==================================================================================================
observations (InputLayer)   [(None, 7, 7, 3)]    0         []
conv1 (Conv2D)              (None, 3, 3, 32)     4736      ['observations[0][0]']
conv2 (Conv2D)              (None, 1, 1, 32)     9248      ['conv1[0][0]']
lambda_1 (Lambda)           (None, 32)           0         ['conv2[0][0]']
conv_out (Conv2D)           (None, 1, 1, 7)      231       ['conv2[0][0]']
value_out (Dense)           (None, 1)            33        ['lambda_1[0][0]']
==================================================================================================
Total params: 14,248
Trainable params: 14,248
Non-trainable params: 0
__________________________________________________________________________________________________
and the following for DQN:
Model: "model"
_________________________________________________________________
Layer (type)                Output Shape         Param #
=================================================================
observations (InputLayer)   [(None, 7, 7, 3)]    0
conv1 (Conv2D)              (None, 3, 3, 32)     4736
conv_out (Conv2D)           (None, 1, 1, 256)    73984
lambda (Lambda)             (None, 256)          0
value_out (Dense)           (None, 1)            257
=================================================================
Total params: 78,977
Trainable params: 78,977
Non-trainable params: 0
_________________________________________________________________
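To make the difference concrete, the parameter counts in the summaries above can be reproduced by hand. This is only a rough sketch: it assumes the usual counting rule for Conv2D and Dense layers with a bias term, and the kernel sizes of the two conv_out layers are inferred from the reported counts, not read from RLlib's source. The helper names are made up for illustration.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, out_channels):
    """Weights plus one bias per output channel."""
    return kernel_h * kernel_w * in_channels * out_channels + out_channels


def dense_params(in_features, out_features):
    """Weights plus one bias per output unit."""
    return in_features * out_features + out_features


# Shared first layer: 7x7 kernel, 3 input channels, 32 filters.
conv1 = conv2d_params(7, 7, 3, 32)

# SAC (both summaries are identical): the second configured filter keeps
# 32 channels, followed by a conv_out whose count matches a 1x1 kernel (inferred).
sac_conv2 = conv2d_params(3, 3, 32, 32)
sac_conv_out = conv2d_params(1, 1, 32, 7)
sac_value_out = dense_params(32, 1)
sac_total = conv1 + sac_conv2 + sac_conv_out + sac_value_out

# DQN: there is no conv2; conv_out's count matches a 3x3 kernel
# with 256 filters (inferred).
dqn_conv_out = conv2d_params(3, 3, 32, 256)
dqn_value_out = dense_params(256, 1)
dqn_total = conv1 + dqn_conv_out + dqn_value_out

print("SAC:", conv1, sac_conv2, sac_conv_out, sac_value_out, sac_total)
print("DQN:", conv1, dqn_conv_out, dqn_value_out, dqn_total)
```

Read this way, SAC keeps my second filter at 32 channels and appends a 1x1 convolution with 7 outputs, while DQN replaces the second filter with a 3x3 convolution with 256 outputs. Neither summary shows Dense layers matching fcnet_hiddens.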
Why do the two SAC networks not look like the DQN network?
PS: I am aware that SAC is not the best choice for solving a gridworld; I just want to be sure that the issue is not on my end.