Time per iteration is high when using a custom model

How severe does this issue affect your experience of using Ray?

  • High: It blocks me to complete my task.

Hi, I’m having a problem implementing a custom attention network. For context, I’m adapting the attention network from the paper: FEDformer: Frequency Enhanced Decomposed Transformer for Long-term Series Forecasting
Link: [2202.07125] Transformers in Time Series: A Survey
Github: GitHub - MAZiqing/FEDformer

The issue I’m having is that trials are extremely slow with this attention model. For comparison, the default RLlib LSTM or attention models with fully connected layers of sizes [1024, 2048, 2048, 4096, 4096, 2048, 2048, 1024, 512] (about 150 million parameters) average 300 seconds per iteration with train_batch_size = 250. This model (about 50 million parameters) averages 90 minutes per iteration.
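
For reference, this is roughly how I’m inspecting the per-iteration timing breakdown (a minimal sketch, assuming the standard result dict returned each training iteration; `trainer` is a placeholder for whichever RLlib Trainer/Algorithm is configured):

result = trainer.train()  # placeholder: the configured RLlib Trainer/Algorithm
timers = result.get("timers", {})
# Sampling time: env rollouts plus model forward passes on the rollout workers.
print("sample_time_ms:", timers.get("sample_time_ms"))
# Learner time: SGD/optimizer updates on the driver.
print("learn_time_ms:", timers.get("learn_time_ms"))
print("learn_throughput:", timers.get("learn_throughput"))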

So I have some questions:

  1. What part of the adaptation could be causing each iteration to take so long?
  2. What determines the time per iteration in RLlib?

Below is how I’ve adapted the model for use with RLlib:

class Configs(object):
    ab = 0
    mode_select = "random"
    version = "Wavelets"
    moving_avg = [7, 12, 14, 26, 48]
    L = 4  # Complexity Computations & Analysis - MWT_CZ1d class
    # Whenever 'L < 4', MultiWaveCross returns NaNs on the next decomposition levels
    modes = 64  # Complexity Computations & Analysis - FourierCrossAttentionW
    base = "legendre"
    cross_activation = "silu"
    seq_len = 75
    label_len = 25
    pred_len = 25
    output_attention = True
    enc_in = 239
    dec_in = 239
    c_out = 239
    d_model = 512
    d_ff = 512
    embed = "timeF"
    dropout = 0.05
    freq = "c_in"
    n_heads = 4
    encoder_layers = 2
    decoder_layers = 1
    activation = "silu"
    stride = 1
    wavelet = 0 
    factor = 1  

    # MWT
    k_mwt = 4
    alpha_mwt = 16
    c_mwt = 128
    nCZ_mwt = 1
    # MWC
    c_mwc = 64
    k_mwc = 8


CUSTOM_CONFIG = {
    "max_seq_len": 20,
    "cell_size": 4096,
    "time_major": False,
    "use_prev_action": True,
    "use_prev_reward": True,
    "use_prev_obs": False,
    "num_frames": 100,
    "self.logit_bias": True,
    "logit_activation": "swish",
}

class FEDformer(TorchModelV2, nn.Module):
    """FEDFormer serving as interface between custom model and Ray."""

    def __init__(
        self,
        obs_space: Space,
        action_space: Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        nn.Module.__init__(self)
        super().__init__(
            obs_space, action_space, num_outputs, model_config, name  # type:ignore
        )  # type:ignore

        self.custom_config = CUSTOM_CONFIG
        self.use_prev_action = self.custom_config.get("use_prev_action", False)
        self.use_prev_reward = self.custom_config.get("use_prev_reward", False)
        self.use_prev_obs = self.custom_config.get("use_prev_obs", False)
        self.logit_activation: str = self.custom_config["logit_activation"]
        self.num_frames = int(self.custom_config["num_frames"])
        self.configs = Configs()

        self._num_outputs = 0
        self.action_space_struct = get_base_struct_from_space(self.action_space)
        self.action_dim = 0

        for space in tree.flatten(self.action_space_struct):
            if isinstance(space, Discrete):
                self.action_dim += space.n
            elif isinstance(space, MultiDiscrete):
                self.action_dim += np.sum(space.nvec)
            elif space.shape is not None:
                self.action_dim += int(np.prod(space.shape))
            else:
                self.action_dim += int(len(space))

        # Add prev-action/reward nodes to input for post processing layer.
        if self.use_prev_action:
            self._num_outputs += int(self.action_dim)
        if self.use_prev_reward:
            self._num_outputs += 1

        # Add obs nodes to input for post processing linear layer since obs
        #  feature dimensions are constantly changing in our env
        self._num_outputs += int(self.obs_space.shape[1])  # type: ignore
        # print(f"num_outputs with action, reward and obs dimensions: {self._num_outputs}")

        # initialize attention module
        self.fedformer = Model(self.configs, self._num_outputs)

        # Postprocess attention output with another hidden layer and compute values (including correct initialization logic).
        self._num_outputs *= self.configs.pred_len
        self.logits = SlimFC(
            in_size=self._num_outputs,
            out_size=num_outputs,
            initializer=torch.nn.init.xavier_uniform_,
            activation_fn=self.logit_activation,
        )

        self.values_out = SlimFC(
            in_size=self._num_outputs,
            out_size=1,
            initializer=torch.nn.init.xavier_uniform_,
            activation_fn=None,
        )

        # Set final num_outputs to correct value (depending on action space).
        self.num_outputs = num_outputs

        # Setup trajectory views for the previous num_frames observations, actions and rewards
        if self.use_prev_action:
            self.view_requirements[
                SampleBatch.PREV_ACTIONS
            ] = ViewRequirement(  # type:ignore
                data_col=SampleBatch.ACTIONS,
                space=self.action_space,
                shift=f"-{self.num_frames}:-1",
            )
        if self.use_prev_reward:
            self.view_requirements[
                SampleBatch.PREV_REWARDS
            ] = ViewRequirement(  # type:ignore
                data_col=SampleBatch.REWARDS, shift=f"-{self.num_frames}:-1"
            )

        # For test environments, e.g. CartPole, whose observations are not 3d tensors,
        # request the previous num_frames observations as well.
        if self.use_prev_obs:
            self.view_requirements["prev_n_obs"] = ViewRequirement(  # type:ignore
                data_col=SampleBatch.OBS,
                space=self.obs_space,
                shift=f"-{self.num_frames}:-1",
            )

    @override(TorchModelV2)
    def forward(
        self,
        input_dict: Dict[str, Tensor],
        state: List[Tensor],
        seq_lens: Tensor,
    ) -> Tuple[TensorType, List[TensorType]]:
        # Concat. prev-action/reward/observations if required.
        prev_a_r_o = []

        # Prev actions.
        if self.use_prev_action:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS].float()
            actions = flatten_inputs_to_1d_tensor(
                inputs=prev_a,
                spaces_struct=self.action_space_struct,
                time_axis=True,
            )
            # print(f"flat actions: {actions.shape}")
            # flat actions: torch.Size([32, 100, 55])
            prev_a_r_o.append(actions)

        # Prev rewards.
        if self.use_prev_reward:
            prev_r = input_dict[SampleBatch.PREV_REWARDS].float()
            rewards = prev_r.unsqueeze(2)  # type:ignore
            # print(f"rewards: {rewards}")
            # print(f"prev_r: {prev_r.shape}")
            # output of print: rewards: torch.Size([32, 100, 1])
            prev_a_r_o.append(rewards)
            

        # Concat prev. actions + rewards to the "main" input.
        if prev_a_r_o:
            cat_inputs = torch.cat(
                [input_dict[SampleBatch.OBS].float()] + prev_a_r_o, dim=2
            ) 
            # print(f"cat_inputs with rewards and actions: {cat_inputs}")
            # cat_inputs: torch.Size([32, 100, 239])
        else:
            cat_inputs = input_dict[SampleBatch.OBS].float()
            # print(f"cat_inputs: {cat_inputs}")
            # cat_inputs: torch.Size([32, 100, 183])

        # Push everything through our Model.
        output, new_state = self.forward_rnn(cat_inputs, state, seq_lens)
        output = torch.reshape(output, [-1, self.num_outputs])
        return output, new_state

    def forward_rnn(
        self, inputs: Tensor, state: List[Tensor], seq_lens: Tensor
    ) -> Tuple[Tensor, List[TensorType]]:
        # slice the inputs accordingly
        s_begin = 0
        s_end = s_begin + self.configs.seq_len
        r_begin = s_end - self.configs.label_len
        r_end = r_begin + self.configs.label_len + self.configs.pred_len


        x_enc = inputs[:, s_begin:s_end, :]  # type:ignore
        x_dec = inputs[:, r_begin:r_end, :]  # type: ignore

        x_mark_enc = inputs[:, s_begin:s_end, :] 
        x_mark_dec = inputs[:, r_begin:r_end, :]  # type:ignore

        #print(f"x_enc: {x_enc.shape}")
        #print(f"x_dec: {x_dec.shape}")
        #print(f"x_enc_mark: {x_mark_enc.shape}")
        #print(f"x_dec_mark: {x_mark_dec.shape}")
        #x_enc: torch.Size([32, 75, 239])
        #x_dec: torch.Size([32, 50, 239])
        #x_enc_mark: torch.Size([32, 75, 239])
        #x_dec_mark: torch.Size([32, 50, 239])

        # Forward through Attention Network
        self._features, attn = self.fedformer(x_enc, x_mark_enc, x_dec, x_mark_dec)
        # print(f"pre reshape features: {self._features}")
        # print(f"attn: {attn}")

        # pre reshape features: torch.Size([32, 25, 239])
        # attn: torch.Size([None, None])

        # reshape output tensor - merge the 2nd and 3rd dimensions
        self._features = self._features.view(
            [-1, self._features.shape[1] * self._features.shape[2]]
        )
        # print(f"select features: {self._features.shape}")
        # print(f"post reshape features: {self._features}")
        # reshaped features: torch.Size([32, 239])

        # Forward through the last output layer
        # model_out= F.silu(self.logits(self._features))
        model_out = self.logits(self._features)
        # print(f"model_out: {model_out}")
        # model_out pre reshape: torch.Size([32, 55])

        return model_out, state

    @override(ModelV2)
    def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]:  # type: ignore
        # Get the device where the model is located
        # device_ = next(self.parameters()).device
        return [
            torch.zeros(
                self.view_requirements[SampleBatch.OBS].space.shape # type: ignore
            ) 
        ]
    
    @override(ModelV2)
    def value_function(self) -> TensorType:
        if self._features is None:
            raise ValueError(
                "self._features cannot be None. Must call forward first AND must have value branch!"
            )
        
        return torch.reshape(self.values_out(self._features), [-1])  # type: ignore
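
For completeness, the model is plugged into RLlib roughly like this (a minimal sketch; the algorithm and the rest of the training config are omitted, and the registered name is a placeholder):

from ray.rllib.models import ModelCatalog

# Register the custom model under a name the config can refer to.
ModelCatalog.register_custom_model("fedformer", FEDformer)

config = {
    "framework": "torch",
    "train_batch_size": 250,
    "model": {
        "custom_model": "fedformer",
        "max_seq_len": 20,
    },
}
# `config` is then passed to the Trainer/Algorithm or to tune.run() as usual.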

PS: I’ve printed the tensors to try to determine whether the forward methods are working. They are not empty, so that rules out the possibility that the forward methods are returning NaNs.
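
A more direct check than eyeballing the prints would be something like this right after the fedformer call in forward_rnn (sketch):

# Explicit NaN/Inf check on the attention output.
if torch.isnan(self._features).any() or torch.isinf(self._features).any():
    raise ValueError("FEDformer features contain NaNs/Infs")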

Any help will be appreciated

As a sanity check to confirm whether the code adaptation is the problem, I tried a different custom model. The problem is the same: the time per iteration for the custom LSTM below is 50 minutes. What could be the problem?

CUSTOM_CONFIG = {
    "max_seq_len": 20,
    "cell_size": 2048,
    "num_layers": 1,
    "time_major": False,
    "use_prev_action": False,
    "use_prev_reward": False,
    "use_prev_obs": False,
    "num_frames": 100,
    "self.logit_bias": True,
    "logit_activation": "swish",
}

configs = Configs()


class LSTM(RecurrentNetwork, nn.Module):
    """LSTM serving as interface between custom model and Ray."""

    def __init__(
        self,
        obs_space: Space,
        action_space: Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        nn.Module.__init__(self)
        super().__init__(
            obs_space, action_space, num_outputs, model_config, name  # type:ignore
        )  # type:ignore

        self.custom_config = CUSTOM_CONFIG
        self.cell_size: int = self.custom_config["cell_size"]
        self.num_layers: int = self.custom_config["num_layers"]
        self.time_major = model_config.get("_time_major", False)
        self.use_prev_action = self.custom_config["use_prev_action"]
        self.use_prev_reward = self.custom_config["use_prev_reward"]
        self.use_prev_obs = self.custom_config["use_prev_obs"]
        self.logit_activation: str = self.custom_config["logit_activation"]

        self.num_frames = int(self.custom_config["num_frames"])

        self._num_outputs = 0
        self.action_space_struct = get_base_struct_from_space(self.action_space)
        self.action_dim = 0

        for space in tree.flatten(self.action_space_struct):
            if isinstance(space, Discrete):
                self.action_dim += space.n
            elif isinstance(space, MultiDiscrete):
                self.action_dim += np.sum(space.nvec)
            elif space.shape is not None:
                self.action_dim += int(np.prod(space.shape))
            else:
                self.action_dim += int(len(space))

        # Add prev-action/reward nodes to input for post processing layer.
        if self.use_prev_action:
            self._num_outputs += int(self.action_dim)
        if self.use_prev_reward:
            self._num_outputs += 1

        # Add obs nodes to input for post processing linear layer since obs
        #  feature dimensions are constantly changing in our env
        self._num_outputs += int(self.obs_space.shape[1])  # type: ignore
        self._num_outputs *= self.num_frames
        # print(f"num_outputs with action, reward and obs dimensions: {self._num_outputs}")

        # initialize lstm module
        self.lstm = nn.LSTM(
            input_size=self._num_outputs,
            hidden_size=self.cell_size,
            num_layers=self.num_layers,
            bias=True,
            batch_first=not self.time_major,
            dropout=0.4 if self.num_layers > 1 else 0,
            bidirectional=False,
            proj_size=0,
        )

        # Postprocess lstm output with another hidden layer and compute values (including correct initialization logic).
        # Set self.num_outputs to the number of output nodes desired by the
        # caller of this constructor.
        self.num_outputs = num_outputs

        self.logits = SlimFC(
            in_size=self.cell_size,
            out_size=self.num_outputs,
            activation_fn=self.logit_activation,
            initializer=torch.nn.init.xavier_uniform_,
        )
        self.values_out = SlimFC(
            in_size=self.cell_size,
            out_size=1,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )

        # Setup trajectory views for the previous num_frames observations, actions and rewards
        if self.use_prev_action:
            self.view_requirements[
                SampleBatch.PREV_ACTIONS
            ] = ViewRequirement(  # type:ignore
                data_col=SampleBatch.ACTIONS,
                space=self.action_space,
                shift=f"-{self.num_frames}:-1",
            )
        if self.use_prev_reward:
            self.view_requirements[
                SampleBatch.PREV_REWARDS
            ] = ViewRequirement(  # type:ignore
                data_col=SampleBatch.REWARDS, shift=f"-{self.num_frames}:-1"
            )

        # For test environments, e.g. CartPole, whose observations are not 3d tensors,
        # request the previous num_frames observations as well.
        if self.use_prev_obs:
            self.view_requirements["prev_n_obs"] = ViewRequirement(  # type:ignore
                data_col=SampleBatch.OBS,
                space=self.obs_space,
                shift=f"-{self.num_frames}:-1",
            )

    @override(RecurrentNetwork)
    def forward(
        self,
        input_dict: Dict[str, Tensor],
        state: List[Tensor],
        seq_lens: Tensor,
    ) -> Tuple[TensorType, List[TensorType]]:
        # Concat. prev-action/reward/observations if required.
        prev_a_r_o = []

        # Prev actions.
        if self.use_prev_action:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS].float()
            actions = flatten_inputs_to_1d_tensor(
                inputs=prev_a,
                spaces_struct=self.action_space_struct,
                time_axis=True,
            )
            # print(f"flat actions: {actions.shape}")
            # flat actions: torch.Size([32, 100, 55])
            prev_a_r_o.append(actions)

        # Prev rewards.
        if self.use_prev_reward:
            prev_r = input_dict[SampleBatch.PREV_REWARDS].float()
            rewards = prev_r.unsqueeze(2)  # type:ignore
            # print(f"rewards: {rewards}")
            # print(f"prev_r: {prev_r.shape}")
            # output of print: rewards: torch.Size([32, 100, 1])
            prev_a_r_o.append(rewards)

        # Concat prev. actions + rewards to the "main" input.
        if prev_a_r_o:
            cat_inputs = torch.cat(
                [input_dict[SampleBatch.OBS].float()] + prev_a_r_o, dim=2
            )
            # print(f"cat_inputs with rewards and actions: {cat_inputs}")
            # cat_inputs: torch.Size([32, 100, 239])
        else:
            cat_inputs = input_dict[SampleBatch.OBS].float()
            # print(f"cat_inputs: {cat_inputs}")
            # cat_inputs: torch.Size([32, 100, 183])

        # Push everything through our Model.
        cat_inputs = torch.reshape(
            cat_inputs, [-1, cat_inputs.shape[1] * cat_inputs.shape[2]]
        )

        input_dict["obs_flat"] = cat_inputs
        # print(f"cat_inputs obs_flat: {input_dict['obs_flat'].shape}")
        return super().forward(input_dict, state, seq_lens)

    @override(RecurrentNetwork)
    def forward_rnn(
        self, inputs: Tensor, state: List[Tensor], seq_lens: Tensor
    ) -> Tuple[Tensor, List[TensorType]]:
        # print(f"inputs: {inputs.shape}")

        # Forward through LSTM
        if self.num_layers > 1:
            h_states = [s for s in state]
            c_states = [s for s in state]
            # print(f"initial h_states: {[h_.shape for h_ in h_states]}")
            # print(f"initial c_states: {[h_.shape for h_ in c_states]}")
            h_states = torch.stack(h_states, dim=0)
            c_states = torch.stack(c_states, dim=0)
            # print(f"post stack h_states: {h_states.shape}")
            # print(f"post stack c_states: {c_states.shape}")
            # post stack h_states: torch.Size([5, 1, 4096])
            # post stack c_states: torch.Size([5, 1, 4096])
        else:
            h_states = torch.unsqueeze(state[0], 0)
            c_states = torch.unsqueeze(state[1], 0)

        self._features, (h, c) = self.lstm(inputs, [h_states, c_states])
        # print(f"lstm output: {self._features.shape}")

        # Forward through the last output layer
        model_out = self.logits(self._features)
        # print(f"model_out: {model_out.shape}")

        return model_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]

    @override(ModelV2)
    def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]:
        # Place hidden states on same device as model.
        linear = next(self.logits._model.children())
        if self.num_layers > 1:
            h = []
            # c = []
            for _ in range(self.num_layers):
                h.append(
                    linear.weight.new(1, self.cell_size) # type:ignore
                    .zero_()
                    .squeeze(0)  # type:ignore
                )
            # TODO: Fix this as it returns `np.ndarray has no attribute `device`.
            # TODO: Find out a way around to initialize the hidden and cell states for multi layer LSTMs
        else:
            h = [
                linear.weight.new(1, self.cell_size).zero_().squeeze(0),  # type:ignore
                linear.weight.new(1, self.cell_size).zero_().squeeze(0),  # type:ignore
            ]
        return h

    @override(ModelV2)
    def value_function(self) -> TensorType:
        if self._features is None:
            raise ValueError(
                "self._features cannot be None. Must call forward first AND must have value branch!"
            )

        return torch.reshape(self.values_out(self._features), [-1])  # type: ignore
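
For scale, a back-of-the-envelope parameter count for this LSTM (a sketch assuming a per-frame obs dimension of 183, as in the printed cat_inputs shape; because `self._num_outputs *= self.num_frames`, the flattened frame stack becomes the LSTM's input_size):

num_frames = 100
obs_dim = 183                        # assumption: per-frame feature size without prev actions/rewards
input_size = num_frames * obs_dim    # 18,300 features fed to nn.LSTM per timestep
cell_size = 2048

w_ih = 4 * cell_size * input_size    # input-hidden weights: ~150M
w_hh = 4 * cell_size * cell_size     # hidden-hidden weights: ~16.8M
biases = 2 * 4 * cell_size
print(f"~{(w_ih + w_hh + biases) / 1e6:.0f}M parameters in the single LSTM layer")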

Am I doing something wrong?