How severely does this issue affect your experience of using Ray?
- High: It blocks me from completing my task.
Hi, I’m having a problem with a custom attention network. For context, I’m adapting an attention model from the paper: FEDformer: Frequency Enhanced Decomposed Transformer for Long-term Series Forecasting
Link: arXiv:2202.07125 (Transformers in Time Series: A Survey)
GitHub: MAZiqing/FEDformer
The issue I’m having is that training is extremely slow with this attention model. For comparison, the default RLlib LSTM or attention models with fully connected layers of sizes [1024, 2048, 2048, 4096, 4096, 2048, 2048, 1024, 512] (about 150 million parameters) average roughly 300 seconds per iteration with a train_batch_size of 250, whereas this model (about 50 million parameters) averages roughly 90 minutes per iteration.
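For reference, the baseline I compare against is set up roughly like this (a sketch; apart from fcnet_hiddens, train_batch_size, and the LSTM/attention flags mentioned above, the values are illustrative):

# Sketch of the baseline config (illustrative; not my exact setup).
baseline_config = {
    "framework": "torch",
    "train_batch_size": 250,
    "model": {
        "fcnet_hiddens": [1024, 2048, 2048, 4096, 4096, 2048, 2048, 1024, 512],
        # One of the built-in wrappers is enabled for the comparison:
        "use_lstm": False,
        "use_attention": True,
    },
}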
So I have some questions:
- What part of the adaptation could be causing each iteration to take so long?
- What determines the time per iteration in RLlib?
Below is how I’ve adapted the model for use with RLlib:
# Imports assumed for the snippet below (paths per recent RLlib versions; adjust to your Ray version):
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import tree  # dm_tree
from gym.spaces import Discrete, MultiDiscrete, Space  # or gymnasium.spaces, depending on Ray version
from torch import Tensor

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.torch_utils import flatten_inputs_to_1d_tensor
from ray.rllib.utils.typing import ModelConfigDict, TensorType

from models.FEDformer import Model  # FEDformer's Model class; path depends on how the repo is vendored


class Configs(object):
    ab = 0
    mode_select = "random"
    version = "Wavelets"
    moving_avg = [7, 12, 14, 26, 48]
    L = 4  # Complexity Computations & Analysis - MWT_CZ1d class
    # Whenever 'L < 4', MultiWaveCross returns NaNs on the next decomposition levels
    modes = 64  # Complexity Computations & Analysis - FourierCrossAttentionW
    base = "legendre"
    cross_activation = "silu"
    seq_len = 75
    label_len = 25
    pred_len = 25
    output_attention = True
    enc_in = 239
    dec_in = 239
    c_out = 239
    d_model = 512
    d_ff = 512
    embed = "timeF"
    dropout = 0.05
    freq = "c_in"
    n_heads = 4
    encoder_layers = 2
    decoder_layers = 1
    activation = "silu"
    stride = 1
    wavelet = 0
    factor = 1
    # MWT
    k_mwt = 4
    alpha_mwt = 16
    c_mwt = 128
    nCZ_mwt = 1
    # MWC
    c_mwc = 64
    k_mwc = 8


CUSTOM_CONFIG = {
    "max_seq_len": 20,
    "cell_size": 4096,
    "time_major": False,
    "use_prev_action": True,
    "use_prev_reward": True,
    "use_prev_obs": False,
    "num_frames": 100,
    "self.logit_bias": True,
    "logit_activation": "swish",
}
class FEDformer(TorchModelV2, nn.Module):
    """FEDformer serving as interface between the custom model and Ray."""

    def __init__(
        self,
        obs_space: Space,
        action_space: Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        nn.Module.__init__(self)
        super().__init__(
            obs_space, action_space, num_outputs, model_config, name  # type:ignore
        )  # type:ignore
        self.custom_config = CUSTOM_CONFIG
        self.use_prev_action = self.custom_config.get("use_prev_action", False)
        self.use_prev_reward = self.custom_config.get("use_prev_reward", False)
        self.use_prev_obs = self.custom_config.get("use_prev_obs", False)
        self.logit_activation: str = self.custom_config["logit_activation"]
        self.num_frames = int(self.custom_config["num_frames"])
        self.configs = Configs()
        self._num_outputs = 0
        self.action_space_struct = get_base_struct_from_space(self.action_space)
        # Flatten the (possibly nested) action space to get the total action dim.
        self.action_dim = 0
        for space in tree.flatten(self.action_space_struct):
            if isinstance(space, Discrete):
                self.action_dim += space.n
            elif isinstance(space, MultiDiscrete):
                self.action_dim += int(np.sum(space.nvec))
            elif space.shape is not None:
                self.action_dim += int(np.prod(space.shape))
            else:
                self.action_dim += int(len(space))
        # Add prev-action/reward nodes to the input of the post-processing layer.
        if self.use_prev_action:
            self._num_outputs += int(self.action_dim)
        if self.use_prev_reward:
            self._num_outputs += 1
        # Add obs nodes to the input of the post-processing linear layer, since obs
        # feature dimensions are constantly changing in our env.
        self._num_outputs += int(self.obs_space.shape[1])  # type: ignore
        # print(f"num_outputs with action, reward and obs dimensions: {self._num_outputs}")
        # Initialize the attention module.
        self.fedformer = Model(self.configs, self._num_outputs)
        # Postprocess attention output with another hidden layer and compute values
        # (including correct initialization logic).
        self._num_outputs *= self.configs.pred_len
        self.logits = SlimFC(
            in_size=self._num_outputs,
            out_size=num_outputs,
            initializer=torch.nn.init.xavier_uniform_,
            activation_fn=self.logit_activation,
        )
        self.values_out = SlimFC(
            in_size=self._num_outputs,
            out_size=1,
            initializer=torch.nn.init.xavier_uniform_,
            activation_fn=None,
        )
        # Set final num_outputs to the correct value (depending on the action space).
        self.num_outputs = num_outputs
        # Set up trajectory views for the previous num_frames observations, actions and rewards.
        if self.use_prev_action:
            self.view_requirements[
                SampleBatch.PREV_ACTIONS
            ] = ViewRequirement(  # type:ignore
                data_col=SampleBatch.ACTIONS,
                space=self.action_space,
                shift=f"-{self.num_frames}:-1",
            )
        if self.use_prev_reward:
            self.view_requirements[
                SampleBatch.PREV_REWARDS
            ] = ViewRequirement(  # type:ignore
                data_col=SampleBatch.REWARDS, shift=f"-{self.num_frames}:-1"
            )
        # For test environments, e.g. CartPole, whose obs space is not a 3d tensor.
        if self.use_prev_obs:
            self.view_requirements[
                "prev_obs"  # custom view key (not a SampleBatch constant)
            ] = ViewRequirement(  # type:ignore
                data_col=SampleBatch.OBS,
                space=self.obs_space,
                shift=f"-{self.num_frames}:-1",
            )
    @override(TorchModelV2)
    def forward(
        self,
        input_dict: Dict[str, Tensor],
        state: List[Tensor],
        seq_lens: Tensor,
    ) -> Tuple[TensorType, List[TensorType]]:
        # Concat. prev-action/reward/observations if required.
        prev_a_r_o = []
        # Prev actions.
        if self.use_prev_action:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS].float()
            actions = flatten_inputs_to_1d_tensor(
                inputs=prev_a,
                spaces_struct=self.action_space_struct,
                time_axis=True,
            )
            # print(f"flat actions: {actions.shape}")
            # flat actions: torch.Size([32, 100, 55])
            prev_a_r_o.append(actions)
        # Prev rewards.
        if self.use_prev_reward:
            prev_r = input_dict[SampleBatch.PREV_REWARDS].float()
            rewards = prev_r.unsqueeze(2)  # type:ignore
            # print(f"rewards: {rewards}")
            # print(f"prev_r: {prev_r.shape}")
            # output of print: rewards: torch.Size([32, 100, 1])
            prev_a_r_o.append(rewards)
        # Concat prev. actions + rewards to the "main" input.
        if prev_a_r_o:
            cat_inputs = torch.cat(
                [input_dict[SampleBatch.OBS].float()] + prev_a_r_o, dim=2
            )
            # print(f"cat_inputs with rewards and actions: {cat_inputs}")
            # cat_inputs: torch.Size([32, 100, 239])
        else:
            cat_inputs = input_dict[SampleBatch.OBS].float()
            # print(f"cat_inputs: {cat_inputs}")
            # cat_inputs: torch.Size([32, 100, 183])
        # Push everything through our Model.
        output, new_state = self.forward_rnn(cat_inputs, state, seq_lens)
        output = torch.reshape(output, [-1, self.num_outputs])
        return output, new_state
    def forward_rnn(
        self, inputs: Tensor, state: List[Tensor], seq_lens: Tensor
    ) -> Tuple[Tensor, List[TensorType]]:
        # Slice the inputs into encoder/decoder segments.
        s_begin = 0
        s_end = s_begin + self.configs.seq_len
        r_begin = s_end - self.configs.label_len
        r_end = r_begin + self.configs.label_len + self.configs.pred_len
        x_enc = inputs[:, s_begin:s_end, :]  # type:ignore
        x_dec = inputs[:, r_begin:r_end, :]  # type: ignore
        x_mark_enc = inputs[:, s_begin:s_end, :]
        x_mark_dec = inputs[:, r_begin:r_end, :]  # type:ignore
        # print(f"x_enc: {x_enc.shape}")
        # print(f"x_dec: {x_dec.shape}")
        # print(f"x_enc_mark: {x_mark_enc.shape}")
        # print(f"x_dec_mark: {x_mark_dec.shape}")
        # x_enc: torch.Size([32, 75, 239])
        # x_dec: torch.Size([32, 50, 239])
        # x_enc_mark: torch.Size([32, 75, 239])
        # x_dec_mark: torch.Size([32, 50, 239])
        # Forward through the attention network.
        self._features, attn = self.fedformer(x_enc, x_mark_enc, x_dec, x_mark_dec)
        # print(f"pre reshape features: {self._features}")
        # print(f"attn: {attn}")
        # pre reshape features: torch.Size([32, 25, 239])
        # attn: torch.Size([None, None])
        # Reshape the output tensor - merge the 2nd and 3rd dimensions.
        self._features = self._features.view(
            [-1, self._features.shape[1] * self._features.shape[2]]
        )
        # print(f"select features: {self._features.shape}")
        # print(f"post reshape features: {self._features}")
        # reshaped features: torch.Size([32, 239])
        # Forward through the last output layer.
        # model_out = F.silu(self.logits(self._features))
        model_out = self.logits(self._features)
        # print(f"model_out: {model_out}")
        # model_out pre reshape: torch.Size([32, 55])
        return model_out, state
    @override(ModelV2)
    def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]:  # type: ignore
        # Get the device where the model is located
        # device_ = next(self.parameters()).device
        return [
            torch.zeros(
                self.view_requirements[SampleBatch.OBS].space.shape  # type: ignore
            )
        ]

    @override(ModelV2)
    def value_function(self) -> TensorType:
        if self._features is None:
            raise ValueError(
                "self._features cannot be None. Must call forward first AND must have value branch!"
            )
        return torch.reshape(self.values_out(self._features), [-1])  # type: ignore
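For completeness, the model is plugged into the trainer roughly as follows (a sketch; the registered name "fedformer" and the surrounding config are placeholders, not my exact setup):

# Registration sketch (names and surrounding config values are placeholders).
from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model("fedformer", FEDformer)

config = {
    "framework": "torch",
    "train_batch_size": 250,
    "model": {
        "custom_model": "fedformer",
        "max_seq_len": CUSTOM_CONFIG["max_seq_len"],
    },
}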
PS: I’ve printed the intermediate tensors to check whether the forward methods are working. They are populated (not empty), which suggests the forward passes are producing values rather than NaNs.
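A more direct check would be something like this (not in the code above), run on the outputs of forward_rnn:

# Hypothetical NaN check on the tensors printed above (not part of the model code).
assert not torch.isnan(self._features).any(), "FEDformer features contain NaNs"
assert not torch.isnan(model_out).any(), "logits contain NaNs"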
Any help will be appreciated