When I return [self.conv.weight.new(1, self.cell_size).zero_().squeeze(0)] from get_initial_state and remove the hacky batch-fixing code, I get this error: Input batch size 32 doesn't match hidden0 batch size 4
When I add the following back in:
    if h_in.shape[0] != x.shape[0]:
        missing_h = self.conv.weight.new(x.shape[0] - h_in.shape[0], h_in.shape[1]).zero_()
        h_in = torch.vstack((h_in, missing_h))
and keep the squeeze, it seems to run. But that still leaves my original question: why is seq_len four 8s rather than thirty-two 1s, given that the batch size appears to be determined by len(seq_len) rather than sum(seq_len)?
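To make the mismatch concrete, here is a minimal, self-contained sketch of what I believe is happening, with assumed shapes and values (seq_len = [8, 8, 8, 8], a 16-wide observation, a 64-wide hidden state; torch.zeros stands in for self.conv.weight.new(...).zero_() so the snippet runs outside the model). The RNN's batch dimension is len(seq_len) (the number of padded sequences, 4), while the flat input batch is sum(seq_len) timesteps (32), and the workaround zero-pads the hidden state up to the flat size:

```python
import torch

# Assumed values mirroring the error message: 4 sequences of length 8
# arrive as one flat batch of 32 timesteps.
seq_len = [8, 8, 8, 8]
rnn_batch_size = len(seq_len)   # 4  -> what hidden0's batch dim matches
total_timesteps = sum(seq_len)  # 32 -> what the flat input batch matches

x = torch.randn(total_timesteps, 16)  # flattened observations, one row per timestep
h_in = torch.zeros(1, 64)             # squeezed-then-restacked initial state

# The quoted workaround: zero-pad the hidden state's batch dimension
# until it matches the flat input batch.
if h_in.shape[0] != x.shape[0]:
    missing_h = torch.zeros(x.shape[0] - h_in.shape[0], h_in.shape[1])
    h_in = torch.vstack((h_in, missing_h))

print(rnn_batch_size, total_timesteps, tuple(h_in.shape))  # 4 32 (32, 64)
```

This is only an illustration of the shape arithmetic, not the library's actual batching code; it just shows why a hidden state built per-sequence (batch 4) collides with an input built per-timestep (batch 32).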