Environment error: ValueError: The two structures don't have the same nested structure

Hi

I am trying to use the GTrXL transformer architecture with a stock market trading environment, which you can find here.

This worked a while back and I got results from the transformer. Recently it stopped working, and I am getting the following error:

 First structure: type=tuple str=(array([2.2788222e+02, 2.9840982e-01, 0.0000000e+00, 2.7170289e-01,
        1.4102402e+00, 1.2217970e+00, 1.8251243e+00, 1.0809784e+00,
        8.5718751e-01, 2.5914115e-01, 1.3132960e+00, 1.0967580e+00,
        8.3406353e-01, 2.3491893e+00, 1.0383877e+00, 1.1138891e+00,
        1.8886892e+00, 3.0702892e-01, 1.1083301e+00, 7.0789695e-01,
        4.7711730e-01, 1.1759629e+00, 1.6288950e+00, 5.5352813e-01,
        4.9202183e-01, 5.5303156e-01, 9.6180123e-01, 1.1277000e+00,
        1.0114505e+00, 8.0985188e-01, 4.9966720e-01, 6.7659909e-01,
        1.0117418e+00, 2.0312500e-01, 2.8125000e-01, 0.0000000e+00,
        2.8125000e-01, 7.8125000e-01, 7.9687500e-01, 9.3750000e-02,
        3.5937500e-01, 2.9687500e-01, 6.8750000e-01, 4.6875000e-02,
        5.4687500e-01, 2.8125000e-01, 6.8750000e-01, 1.2500000e-01,
        4.8437500e-01, 7.5000000e-01, 6.8750000e-01, 6.8750000e-01,
        5.3125000e-01, 2.1875000e-01, 8.4375000e-01, 6.4062500e-01,
        8.7500000e-01, 5.7812500e-01, 7.0312500e-01, 6.8750000e-01,
        9.2187500e-01, 9.8437500e-01, 2.8125000e-01, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.5433314e-01,
        1.5433314e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 9.0468752e-01,
        9.0468752e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.9882810e-01,
        6.9882810e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.0677344e+00,
        1.0677344e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 7.0210940e-01,
        7.0210940e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.2859375e-01,
        4.2859375e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.7187500e-01,
        1.7187500e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 9.6984375e-01,
        9.6984375e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 5.9585935e-01,
        5.9585935e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.4359374e-01,
        4.4359374e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.3819531e+00,
        1.3819531e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.4078122e-01,
        6.4078122e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.7272139e-01,
        6.7272139e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.3857105e+00,
        1.3857105e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.0148438e-01,
        2.0148438e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 7.1117187e-01,
        7.1117187e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.5476562e-01,
        4.5476562e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 3.1765625e-01,
        3.1765625e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 7.5320315e-01,
        7.5320315e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.0791407e+00,
        1.0791407e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 3.6893189e-01,
        3.6893189e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.9031250e-01,
        2.9031250e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 3.0562499e-01,
        3.0562499e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.2921876e-01,
        6.2921876e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.9789064e-01,
        6.9789064e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 5.8257812e-01,
        5.8257812e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.3166015e-01,
        4.3166015e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 3.8281250e-01,
        3.8281250e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.4242188e-01,
        4.4242188e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.1648440e-01,
        6.1648440e-01], dtype=float32), {})
 
 Second structure: type=ndarray str=[ -792.2486   -2144.3738    -129.80092   2915.458    -1308.43
   1269.5146    2074.0398    1901.5674    2501.5889    1164.0255
   -322.9323    1597.6952   -2573.6863   -2746.5698    2014.0723
  -1753.0637    -850.8143    2718.6846    2060.1797     258.4059
    545.0819    2851.2747   -1561.0944     349.4782   -1130.0878
   -403.55383   1070.6107    1372.8518   -2732.747     2289.9187
   2586.476    -1875.0045   -1937.3246    1576.0361    1310.3674
   -774.18475   -803.65875   2640.113     -840.6407    2802.322
     19.350145  2128.231     -979.46356     31.213558 -1505.1682
    208.86848   -780.9078   -1926.5431    2011.087     -629.80927
   1189.3198    2154.1665    2446.9292   -1517.6228    -167.33015
   -185.99898   1706.5778     824.91113   1430.8196     289.54074
    -83.10947  -2880.832    -1074.4584    2531.774     -290.6278
    329.9323   -1801.7854   -2637.3906      16.099054  -168.64241
   -125.876114 -1034.4384    1642.2655     351.62454    509.90973
    360.6945    -697.22485  -1484.3866   -1492.6287    -459.1929
  -2391.1506   -1728.1654    -764.98883  -1659.634    -2860.677
   -773.3472    2364.6418     916.4717     507.7752    1981.0371
   -257.88364   -637.90393   -901.27203  -1614.2391    -721.26996
   1833.0618   -1154.7478     969.5041     175.16115    223.94052
  -2363.3438     286.42517   1502.3491   -1839.2252     990.35913
   1965.7655    2190.7688     -74.32482  -1554.0822   -1204.554
    359.37643   1198.4722    -960.0168    1388.8052    1451.6226
  -2596.9067    -851.9495   -2175.7922    1587.5688    -322.92984
   1295.961    -2309.2183    2583.9165    2652.8699    2667.5493
   2502.1238    2813.4495    -842.62805   -242.12244   1219.3545
   2137.8694    -829.9557    1579.9534    2924.1921   -1282.7042
   2918.0173   -1075.8888    1481.3325    1201.7566     689.5607
  -1504.4034    1773.555     2896.91      2875.2483    -138.64302
  -2744.6174    -629.5382    1995.0635   -1050.1123     -73.988846
   -825.6564   -2097.2258    -809.13696  -2182.993    -2932.378
  -1921.0228    2604.5742    1624.0813   -2238.1482   -2156.0967
  -1990.5931   -2145.9197    2912.961     -726.6033    1492.1472
   1311.7047   -1292.6624     676.77026   -419.49207  -1351.4688
  -2679.772      920.6515   -1701.5579   -2236.0894    2608.5308
   -522.4806    -729.45245  -1825.2494    1263.7957    -882.4481
    338.66553  -2045.2618   -1575.5939     839.1834    -764.1029
   -687.5354   -2897.113    -2659.8728   -1605.8586    1714.4854
   1740.4427    2742.6748     367.4229    1100.9697   -2560.8313
   1445.6852   -1679.0385   -2801.7346     895.27704   1424.2039
  -2685.7834     634.448     1515.6146   -2839.2126   -2124.1313
   -331.8107   -2572.0156    2901.6733    1867.6338    -889.7915
   2168.2368   -2892.2544    2699.993    -1446.0782   -2005.3076
   -287.88196   -571.70496  -1412.371     1470.7764    1845.962
    321.17136    -32.960793  -486.28726  -2803.8381   -2929.3594
    934.08813    501.7609     391.37888    420.77762  -1653.7288
  -2707.625    -2399.6025    -514.26886   2206.8079   -1886.1824
   2690.6538    -398.5362     715.6277    1739.5406   -2561.3362
     74.25057   -945.2912     634.98395    663.10583    114.414795
    384.81158  -2277.277     1975.9971    1328.6526     819.96796
   1112.7996      55.950527 -1017.8577   -1035.914     1011.9785
    714.5928    1143.4572    -772.09064   1600.602     -283.70132
   -463.30814   -519.4983     -96.57187    662.2404     518.65125
    440.34192   -543.07623   1639.1079    1398.8292    2889.071
  -2044.9774     284.34378   1523.3135   -2078.8965     269.63492
   -990.6091    -844.06244  -1025.7518     769.9258   -2966.9124
  -1770.2169   -2029.7166     -11.077788 -1954.0278     146.9773
   1446.9406    2624.8562    2317.0337      51.89259    603.17914
  -1323.7025    1419.9907   -2526.2334    1230.5354    2883.4102
  -1753.4723    1794.5089     613.6288    2328.141     2853.8284
   2666.7334   -1498.047     1295.5944     967.1999   -2580.7476
  -2294.3696    1557.7621    2936.6091   -1691.9373   -1991.3083
   -895.6827    2094.7136   -1278.9231     -96.7892    -472.42078
   1660.7162      75.07274  -2942.1597   -1975.431    -2479.6326
   2608.9258    2222.379     1456.2847   -1376.8156      83.49044
   -281.9596   -1511.0402    -807.6998   -1640.7294     673.7634
  -1975.97      2338.9548   -2384.6924  ]
 
 More specifically: Substructure "type=tuple str=(array([2.2788222e+02, 2.9840982e-01, 0.0000000e+00, 2.7170289e-01,
        1.4102402e+00, 1.2217970e+00, 1.8251243e+00, 1.0809784e+00,
        8.5718751e-01, 2.5914115e-01, 1.3132960e+00, 1.0967580e+00,
        8.3406353e-01, 2.3491893e+00, 1.0383877e+00, 1.1138891e+00,
        1.8886892e+00, 3.0702892e-01, 1.1083301e+00, 7.0789695e-01,
        4.7711730e-01, 1.1759629e+00, 1.6288950e+00, 5.5352813e-01,
        4.9202183e-01, 5.5303156e-01, 9.6180123e-01, 1.1277000e+00,
        1.0114505e+00, 8.0985188e-01, 4.9966720e-01, 6.7659909e-01,
        1.0117418e+00, 2.0312500e-01, 2.8125000e-01, 0.0000000e+00,
        2.8125000e-01, 7.8125000e-01, 7.9687500e-01, 9.3750000e-02,
        3.5937500e-01, 2.9687500e-01, 6.8750000e-01, 4.6875000e-02,
        5.4687500e-01, 2.8125000e-01, 6.8750000e-01, 1.2500000e-01,
        4.8437500e-01, 7.5000000e-01, 6.8750000e-01, 6.8750000e-01,
        5.3125000e-01, 2.1875000e-01, 8.4375000e-01, 6.4062500e-01,
        8.7500000e-01, 5.7812500e-01, 7.0312500e-01, 6.8750000e-01,
        9.2187500e-01, 9.8437500e-01, 2.8125000e-01, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.5433314e-01,
        1.5433314e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 9.0468752e-01,
        9.0468752e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.9882810e-01,
        6.9882810e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.0677344e+00,
        1.0677344e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 7.0210940e-01,
        7.0210940e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.2859375e-01,
        4.2859375e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.7187500e-01,
        1.7187500e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 9.6984375e-01,
        9.6984375e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 5.9585935e-01,
        5.9585935e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.4359374e-01,
        4.4359374e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.3819531e+00,
        1.3819531e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.4078122e-01,
        6.4078122e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.7272139e-01,
        6.7272139e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.3857105e+00,
        1.3857105e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.0148438e-01,
        2.0148438e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 7.1117187e-01,
        7.1117187e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.5476562e-01,
        4.5476562e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 3.1765625e-01,
        3.1765625e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 7.5320315e-01,
        7.5320315e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.0791407e+00,
        1.0791407e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 3.6893189e-01,
        3.6893189e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.9031250e-01,
        2.9031250e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 3.0562499e-01,
        3.0562499e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.2921876e-01,
        6.2921876e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.9789064e-01,
        6.9789064e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 5.8257812e-01,
        5.8257812e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.3166015e-01,
        4.3166015e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 3.8281250e-01,
        3.8281250e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 4.4242188e-01,
        4.4242188e-01, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 6.1648440e-01,
        6.1648440e-01], dtype=float32), {})" is a sequence, while substructure "type=ndarray str=[ -792.2486   -2144.3738    -129.80092   2915.458    -1308.43
   1269.5146    2074.0398    1901.5674    2501.5889    1164.0255
   -322.9323    1597.6952   -2573.6863   -2746.5698    2014.0723
  -1753.0637    -850.8143    2718.6846    2060.1797     258.4059
    545.0819    2851.2747   -1561.0944     349.4782   -1130.0878
   -403.55383   1070.6107    1372.8518   -2732.747     2289.9187
   2586.476    -1875.0045   -1937.3246    1576.0361    1310.3674
   -774.18475   -803.65875   2640.113     -840.6407    2802.322
     19.350145  2128.231     -979.46356     31.213558 -1505.1682
    208.86848   -780.9078   -1926.5431    2011.087     -629.80927
   1189.3198    2154.1665    2446.9292   -1517.6228    -167.33015
   -185.99898   1706.5778     824.91113   1430.8196     289.54074
    -83.10947  -2880.832    -1074.4584    2531.774     -290.6278
    329.9323   -1801.7854   -2637.3906      16.099054  -168.64241
   -125.876114 -1034.4384    1642.2655     351.62454    509.90973
    360.6945    -697.22485  -1484.3866   -1492.6287    -459.1929
  -2391.1506   -1728.1654    -764.98883  -1659.634    -2860.677
   -773.3472    2364.6418     916.4717     507.7752    1981.0371
   -257.88364   -637.90393   -901.27203  -1614.2391    -721.26996
   1833.0618   -1154.7478     969.5041     175.16115    223.94052
  -2363.3438     286.42517   1502.3491   -1839.2252     990.35913
   1965.7655    2190.7688     -74.32482  -1554.0822   -1204.554
    359.37643   1198.4722    -960.0168    1388.8052    1451.6226
  -2596.9067    -851.9495   -2175.7922    1587.5688    -322.92984
   1295.961    -2309.2183    2583.9165    2652.8699    2667.5493
   2502.1238    2813.4495    -842.62805   -242.12244   1219.3545
   2137.8694    -829.9557    1579.9534    2924.1921   -1282.7042
   2918.0173   -1075.8888    1481.3325    1201.7566     689.5607
  -1504.4034    1773.555     2896.91      2875.2483    -138.64302
  -2744.6174    -629.5382    1995.0635   -1050.1123     -73.988846
   -825.6564   -2097.2258    -809.13696  -2182.993    -2932.378
  -1921.0228    2604.5742    1624.0813   -2238.1482   -2156.0967
  -1990.5931   -2145.9197    2912.961     -726.6033    1492.1472
   1311.7047   -1292.6624     676.77026   -419.49207  -1351.4688
  -2679.772      920.6515   -1701.5579   -2236.0894    2608.5308
   -522.4806    -729.45245  -1825.2494    1263.7957    -882.4481
    338.66553  -2045.2618   -1575.5939     839.1834    -764.1029
   -687.5354   -2897.113    -2659.8728   -1605.8586    1714.4854
   1740.4427    2742.6748     367.4229    1100.9697   -2560.8313
   1445.6852   -1679.0385   -2801.7346     895.27704   1424.2039
  -2685.7834     634.448     1515.6146   -2839.2126   -2124.1313
   -331.8107   -2572.0156    2901.6733    1867.6338    -889.7915
   2168.2368   -2892.2544    2699.993    -1446.0782   -2005.3076
   -287.88196   -571.70496  -1412.371     1470.7764    1845.962
    321.17136    -32.960793  -486.28726  -2803.8381   -2929.3594
    934.08813    501.7609     391.37888    420.77762  -1653.7288
  -2707.625    -2399.6025    -514.26886   2206.8079   -1886.1824
   2690.6538    -398.5362     715.6277    1739.5406   -2561.3362
     74.25057   -945.2912     634.98395    663.10583    114.414795
    384.81158  -2277.277     1975.9971    1328.6526     819.96796
   1112.7996      55.950527 -1017.8577   -1035.914     1011.9785
    714.5928    1143.4572    -772.09064   1600.602     -283.70132
   -463.30814   -519.4983     -96.57187    662.2404     518.65125
    440.34192   -543.07623   1639.1079    1398.8292    2889.071
  -2044.9774     284.34378   1523.3135   -2078.8965     269.63492
   -990.6091    -844.06244  -1025.7518     769.9258   -2966.9124
  -1770.2169   -2029.7166     -11.077788 -1954.0278     146.9773
   1446.9406    2624.8562    2317.0337      51.89259    603.17914
  -1323.7025    1419.9907   -2526.2334    1230.5354    2883.4102
  -1753.4723    1794.5089     613.6288    2328.141     2853.8284
   2666.7334   -1498.047     1295.5944     967.1999   -2580.7476
  -2294.3696    1557.7621    2936.6091   -1691.9373   -1991.3083
   -895.6827    2094.7136   -1278.9231     -96.7892    -472.42078
   1660.7162      75.07274  -2942.1597   -1975.431    -2479.6326
   2608.9258    2222.379     1456.2847   -1376.8156      83.49044
   -281.9596   -1511.0402    -807.6998   -1640.7294     673.7634
  -1975.97      2338.9548   -2384.6924  ]" is not
 Entire first structure:
 (., {})
 Entire second structure:
 .

I am not sure why I am getting this error now, since it did not occur earlier. The observation space is for 30 stocks. I also do not understand the error message itself. Can someone help me with it?

python==3.10.4
ray==2.3.0
OS: WSL2.0
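
Read literally, the message compares two values with different nesting: a tuple of the form (array, {}) and a bare array. The reset() method shown later in this thread returns such a tuple (the observation plus an empty info dict). The following is a minimal sketch of that difference only, not a diagnosis of where the comparison happens; structure_of is a hypothetical helper written for this illustration, and the 4-element arrays are placeholders.

```python
import numpy as np


def structure_of(value):
    """Describe how a value is nested, printing every leaf as '.'."""
    if isinstance(value, tuple):
        return "(" + ", ".join(structure_of(v) for v in value) + ")"
    if isinstance(value, dict):
        items = ", ".join(f"{k!r}: {structure_of(v)}" for k, v in value.items())
        return "{" + items + "}"
    return "."  # arrays and scalars count as leaves


# Placeholder observation; the real one has state_dim entries.
obs = np.zeros(4, dtype=np.float32)

reset_return = (obs, {})                      # Gymnasium-style reset(): (observation, info)
bare_observation = np.zeros(4, dtype=np.float32)

print(structure_of(reset_return))      # (., {})
print(structure_of(bare_observation))  # .
```

The two printed strings match the "Entire first structure" and "Entire second structure" lines of the error message.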

Hey @Athe-kunal, looks like this might be a question about RLlib? I’ve moved the category accordingly, let me know if this is a mistake.

Hi @matthewdeng
Yes, this belongs to the RLlib category. Sorry I missed it. Thanks

Hi @Athe-kunal ,

can you show the complete error message? Where exactly are the objects returned from? At first sight it looks as if a different structure is expected than the one returned. I guess it’s coming from Keras. Are you re-using a saved model?

Could it be you are missing the state in your input to compute_actions()? The Transformer needs to see the previous n memory outputs.

Hi @Lars_Simon_Zehnder, sorry for the delayed reply. This is what my step and reset methods look like in the environment:

def step(self, actions):
        actions = (actions * self.max_stock).astype(int)

        self.day += 1
        price = self.price_ary[self.day]
        self.stocks_cool_down += 1

        if self.turbulence_bool[self.day] == 0:
            min_action = int(self.max_stock * self.min_stock_rate)  # stock_cd
            for index in np.where(actions < -min_action)[0]:  # sell_index:
                if price[index] > 0:  # Sell only if current asset is > 0
                    sell_num_shares = min(self.stocks[index], -actions[index])
                    self.stocks[index] -= sell_num_shares
                    self.amount += (
                        price[index] * sell_num_shares * (1 - self.sell_cost_pct)
                    )
                    self.stocks_cool_down[index] = 0
            for index in np.where(actions > min_action)[0]:  # buy_index:
                if (
                    price[index] > 0
                ):  # Buy only if the price is > 0 (no missing data in this particular date)
                    buy_num_shares = min(self.amount // price[index], actions[index])
                    self.stocks[index] += buy_num_shares
                    self.amount -= (
                        price[index] * buy_num_shares * (1 + self.buy_cost_pct)
                    )
                    self.stocks_cool_down[index] = 0

        else:  # sell all when turbulence
            self.amount += (self.stocks * price).sum() * (1 - self.sell_cost_pct)
            self.stocks[:] = 0
            self.stocks_cool_down[:] = 0

        state = self.get_state(price)
        total_asset = self.amount + (self.stocks * price).sum()
        reward = (total_asset - self.total_asset) * self.reward_scaling
        self.total_asset = total_asset

        self.gamma_reward = self.gamma_reward * self.gamma + reward
        done = self.day == self.max_step
        if done:
            reward = self.gamma_reward
            self.episode_return = total_asset / self.initial_total_asset
        # This env has no separate truncation condition, so done is reused
        # for the new Gymnasium-style return signature.
        truncated = done
        return state, reward, done, truncated, dict()

def reset(self, seed=None, options=None):
        self.day = 0
        price = self.price_ary[self.day]

        if self.if_train:
            self.stocks = (
                self.initial_stocks + rd.randint(0, 64, size=self.initial_stocks.shape)
            ).astype(np.float32)
            self.stocks_cool_down = np.zeros_like(self.stocks)
            self.amount = (
                self.initial_capital * rd.uniform(0.95, 1.05)
                - (self.stocks * price).sum()
            )
        else:
            self.stocks = self.initial_stocks.astype(np.float32)
            self.stocks_cool_down = np.zeros_like(self.stocks)
            self.amount = self.initial_capital

        self.total_asset = self.amount + (self.stocks * price).sum()
        self.initial_total_asset = self.total_asset
        self.gamma_reward = 0.0
        return self.get_state(price), {}

So for the transformer architecture, do I need to make any changes to these two methods? It is complaining about the structure of the array, so I hope this code helps you find the issue; if not, please let me know. Thanks in advance.

@Athe-kunal, I might have been a little ambiguous with my naming in my last answer.

By state as input to the compute_actions()/compute_single_action() method of your policy or algorithm, I meant the state of the Transformer, not the state of the environment. The GTrXLNet holds an internal state and it appears that this state might be missing. The error message comes from Keras.

How do you run the algorithm? What kind of script are you running there? It is hard to say from the outside where exactly this error comes from. Do you have a stacktrace of the error?

Hi @Lars_Simon_Zehnder

My state space consists of the amount that the automated stock trading agent holds, the price and number of shares of each stock, and the technical indicators array. So it has the following dimension:

self.state_dim = 1 + 2 + 3 * stock_dim + self.tech_ary.shape[1]

The algorithm receives a state of this dimension through a Gymnasium space:

self.observation_space = gym.spaces.Box(
            low=-3000, high=3000, shape=(self.state_dim,), dtype=np.float32
        )

I do realize the mistake I am making: the transformer takes a different form of state vector. I was able to dig up the shape of the transformer state, but I am not sure how to integrate it into my environment:

init_state = state = [
    np.zeros([100, 64], np.float32) for _ in range(num_transformers)
]

And during the episode, update it as

state = [
        np.concatenate([state[i], [state_out[i]]], axis=0)[1:]
        for i in range(num_transformers)
    ]

Do I need to make these changes? Also, I don’t understand these state dimensions. Can you explain them or point me to resources that do? Thanks in advance.

@Athe-kunal, alright, this helps. So the observation state does not play a role here.
It is about the state of the model (basically the model’s memory), as mentioned above, and you have already figured out what this state could look like for the GTrXLNet.
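
As a rough illustration of that memory: in the snippet quoted above, each transformer unit appears to hold a [100, 64] array, presumably 100 past memory outputs of width 64 each (both readings are assumptions, not confirmed in this thread). The sketch below uses only NumPy and a made-up state_out to show that the update appends the newest row and drops the oldest, so the shape never changes; roll_memory is a hypothetical helper name.

```python
import numpy as np

# Assumed values, taken from the [100, 64] snippet quoted in the thread.
memory_len = 100       # number of past outputs kept per transformer unit
attention_dim = 64     # width of one memory row
num_transformers = 1   # one unit

# At the start of an episode the memory is all zeros.
state = [
    np.zeros([memory_len, attention_dim], np.float32)
    for _ in range(num_transformers)
]


def roll_memory(state, state_out):
    """Append each unit's newest output as the last row and drop the oldest row."""
    return [
        np.concatenate([state[i], [state_out[i]]], axis=0)[1:]
        for i in range(len(state))
    ]


# Made-up model output for one step: one vector of width attention_dim per unit.
state_out = [np.ones(attention_dim, np.float32) for _ in range(num_transformers)]
state = roll_memory(state, state_out)

print(state[0].shape)  # (100, 64)
```

After one step the last row holds the newest output and all earlier rows are still zero, which is why the memory can be passed back to the model at every step with a fixed shape.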

If you look into the documentation of the compute_actions() and compute_single_action() methods of the Algorithm class, you see that they take in an observation (that is, the observation from the POMDP) and the last state of the Transformer. The function then returns not only the action, but also the Transformer’s next state.

So you take the last state of the Transformer and pass it back in:

action, state, _ = algo.compute_single_action(obs, state)

The only question that remains is how to get the initial state of the Transformer. You can get it by calling model.get_initial_state().

For the GTrXLNet this is defined as:

def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]:
        return [
            np.zeros(self.gtrxl.view_requirements["state_in_{}".format(i)].space.shape)
            for i in range(self.gtrxl.num_transformer_units)
        ]

The easiest way to get this state is possibly via:

init_state = algo.get_policy().get_initial_state()
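
Put together, a manual rollout could look roughly like the sketch below. It is only an outline under assumptions: StubAlgo and StubEnv are hypothetical stand-ins so the loop runs without Ray, the initial state is passed in by the caller (from the call above or from the zeros construction quoted earlier; whether the two have the same shape is not verified here), and the loop applies the rolling memory update quoted earlier in the thread instead of feeding the returned state back directly.

```python
import numpy as np

MEMORY_LEN, ATTENTION_DIM, NUM_UNITS = 100, 64, 1  # assumed, from the quoted snippet


class StubAlgo:
    """Hypothetical stand-in for a trained algorithm; returns dummy values."""

    def compute_single_action(self, obs, state):
        action = np.zeros(3, np.float32)  # dummy action
        state_out = [np.ones(ATTENTION_DIM, np.float32) for _ in state]
        return action, state_out, {}


class StubEnv:
    """Hypothetical stand-in for the trading env, with Gymnasium-style returns."""

    def __init__(self, episode_len=5):
        self.episode_len = episode_len
        self.t = 0

    def reset(self, seed=None, options=None):
        self.t = 0
        return np.zeros(4, np.float32), {}

    def step(self, action):
        self.t += 1
        terminated = self.t >= self.episode_len
        return np.zeros(4, np.float32), 1.0, terminated, False, {}


def run_episode(algo, env, init_state):
    """Roll out one episode, passing the model state into every action call."""
    state = init_state
    obs, _info = env.reset()
    total_reward, done = 0.0, False
    while not done:
        action, state_out, _ = algo.compute_single_action(obs, state)
        obs, reward, terminated, truncated, _info = env.step(action)
        done = terminated or truncated
        total_reward += reward
        # Rolling update as quoted in the thread: append newest, drop oldest.
        state = [
            np.concatenate([state[i], [state_out[i]]], axis=0)[1:]
            for i in range(len(state))
        ]
    return total_reward, state


init_state = [np.zeros([MEMORY_LEN, ATTENTION_DIM], np.float32) for _ in range(NUM_UNITS)]
total_reward, final_state = run_episode(StubAlgo(), StubEnv(), init_state)
print(total_reward, final_state[0].shape)
```

The environment never sees the model state here; it only receives actions and returns observations, so step() and reset() stay as they are.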

Hi @Lars_Simon_Zehnder
Sorry for the delayed response, I got caught up. Thank you for the extensive reply; I understood your point. However, I get the state space from my RL environment, whereas, per your suggestion, I need to get the state from the RL algorithm, e.g. the RLlib PPO agent. How can I pass the agent to my environment (as an argument?) to get the initial state?
Also, please note that I am using Ray Tune for hyperparameter optimization and I am just passing the trainable as “PPO”. For instance, I would need num_transformers to build the state array. So how can I facilitate the interaction between the agent and the environment and make the model config (e.g. from PPOConfig) available to the environment? Should I pass it as an argument, or are there other ways?
Thanks in advance

@Athe-kunal , the number of transformer units can be set in the configuration of your algorithm:

from ray.rllib.algorithms.ppo.ppo import PPOConfig

config = (
    PPOConfig()
    .training(
        # model is a keyword argument of .training() and takes a dict.
        model={
            "attention_num_transformer_units": 1,  # 1 unit is the default!
        },
    )
)

See here for an example of how to set up the GTrXL net.
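
One possible way to avoid passing the agent into the environment is to keep the model settings in a single dict in the script that runs the policy, and to derive the number of units from that dict. The sketch below is a hedged illustration: the keys are RLlib model config keys as far as known, the value 100 and the assumed memory shape (attention_memory_inference x attention_dim) are taken from the [100, 64] snippet earlier in the thread and are not confirmed here, and zero_memory is a hypothetical helper.

```python
import numpy as np

# One shared dict; the same object could be passed as
# PPOConfig().training(model=model_config) when building the algorithm.
model_config = {
    "use_attention": True,
    "attention_num_transformer_units": 1,
    "attention_dim": 64,
    "attention_memory_inference": 100,  # assumed value, see lead-in
}


def zero_memory(cfg):
    """Build one all-zero memory array per transformer unit (assumed shape)."""
    shape = [cfg["attention_memory_inference"], cfg["attention_dim"]]
    return [
        np.zeros(shape, np.float32)
        for _ in range(cfg["attention_num_transformer_units"])
    ]


init_state = zero_memory(model_config)
print(len(init_state), init_state[0].shape)
```

With this layout the environment does not need to know num_transformers at all; only the code that calls the policy reads it from the config.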

Got it, Thanks @Lars_Simon_Zehnder

Adding this post by Eric for reference: [rllib] How does the use_lstm option work? · Issue #2536 · ray-project/ray · GitHub