haofeixu / aanet

[CVPR'20] AANet: Adaptive Aggregation Network for Efficient Stereo Matching
Apache License 2.0

can you share your training log #81

Closed q5390498 closed 2 years ago

q5390498 commented 2 years ago

Thank you for your great work. I have tried using the code to train on a custom dataset, but the trained network always outputs 0, and the disp_loss quickly converges to around 0.1. Could you share the log from your training? Thank you very much. @haofeixu @all

Below is my training log from a run on a small dataset. Does anything abnormal happen in it?

[2022-02-08 08:45:08,330] => 100 training samples found in the training set /opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py:475: UserWarning: This DataLoader will create 32 worker processes in total. Our suggested max number of worker in current system is 12, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. warnings.warn(_create_warning_msg( [2022-02-08 08:45:10,523] => Loading pretrained AANet: pretrained/aanet+_sceneflow-d3e13ef0.pth [2022-02-08 08:45:10,754] => Number of trainable parameters: 8514130 [2022-02-08 08:45:10,770] => Start training... /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3499: UserWarning: The default behavior for interpolate/upsample with float scale_factor changed in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, instead of relying on the computed output size. If you wish to restore the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details. warnings.warn( /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3825: UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details. warnings.warn( [2022-02-08 08:45:15,963] cur_loss: 84.850 [2022-02-08 08:45:18,179] cur_loss: 59.475 [2022-02-08 08:45:20,229] cur_loss: 37.604 [2022-02-08 08:45:22,276] cur_loss: 5.961 [2022-02-08 08:45:24,318] cur_loss: 0.305 [2022-02-08 08:45:26,360] cur_loss: 0.130 [2022-02-08 08:45:28,412] cur_loss: 0.098 [2022-02-08 08:45:30,464] cur_loss: 0.120 [2022-02-08 08:45:32,511] cur_loss: 0.117 [2022-02-08 08:45:34,553] cur_loss: 0.133 [2022-02-08 08:45:36,181] Epoch: [ 1/200] [ 10/ 25] time: 25.41s disp_loss: 0.133 [2022-02-08 08:45:36,598] cur_loss: 0.095 [2022-02-08 08:45:38,647] cur_loss: 0.124 [2022-02-08 08:45:40,705] cur_loss: 0.094 [2022-02-08 08:45:42,752] cur_loss: 0.115 [2022-02-08 08:45:44,795] cur_loss: 0.101 [2022-02-08 08:45:46,841] cur_loss: 0.093 [2022-02-08 08:45:48,884] cur_loss: 0.099 [2022-02-08 08:45:50,932] cur_loss: 0.110 [2022-02-08 08:45:52,974] cur_loss: 0.095 [2022-02-08 08:45:55,018] cur_loss: 0.095 [2022-02-08 08:45:56,643] Epoch: [ 1/200] [ 20/ 25] time: 20.46s disp_loss: 0.095 [2022-02-08 08:45:57,060] cur_loss: 0.102 [2022-02-08 08:45:59,103] cur_loss: 0.115 [2022-02-08 08:46:01,143] cur_loss: 0.094 [2022-02-08 08:46:03,180] cur_loss: 0.116 [2022-02-08 08:46:05,216] cur_loss: 0.143 [2022-02-08 08:46:10,908] cur_loss: 0.108 [2022-02-08 08:46:12,967] cur_loss: 0.102 [2022-02-08 08:46:15,004] cur_loss: 0.093 [2022-02-08 08:46:17,046] cur_loss: 0.100 [2022-02-08 08:46:19,091] cur_loss: 0.089 [2022-02-08 08:46:20,719] Epoch: [ 2/200] [ 5/ 25] time: 13.18s disp_loss: 0.089 [2022-02-08 08:46:21,135] cur_loss: 0.101 [2022-02-08 08:46:23,172] cur_loss: 0.114 [2022-02-08 08:46:25,208] cur_loss: 0.116 [2022-02-08 08:46:27,253] cur_loss: 0.117 [2022-02-08 08:46:29,297] cur_loss: 0.107 [2022-02-08 08:46:31,338] cur_loss: 0.127 [2022-02-08 08:46:33,379] cur_loss: 0.085 [2022-02-08 08:46:35,417] cur_loss: 0.091 [2022-02-08 08:46:37,459] cur_loss: 0.103 [2022-02-08 08:46:39,500] cur_loss: 0.092 [2022-02-08 08:46:41,127] Epoch: [ 2/200] [ 15/ 25] time: 20.41s disp_loss: 0.092 [2022-02-08 08:46:41,541] cur_loss: 
0.166 [2022-02-08 08:46:43,579] cur_loss: 0.097 [2022-02-08 08:46:45,625] cur_loss: 0.142 [2022-02-08 08:46:47,671] cur_loss: 0.105 [2022-02-08 08:46:49,709] cur_loss: 0.103 [2022-02-08 08:46:51,747] cur_loss: 0.128 [2022-02-08 08:46:53,784] cur_loss: 0.104 [2022-02-08 08:46:55,824] cur_loss: 0.135 [2022-02-08 08:46:57,861] cur_loss: 0.094 [2022-02-08 08:46:59,900] cur_loss: 0.108 [2022-02-08 08:47:01,525] Epoch: [ 2/200] [ 25/ 25] time: 20.40s disp_loss: 0.108 [2022-02-08 08:47:05,687] cur_loss: 0.099 [2022-02-08 08:47:07,731] cur_loss: 0.110 [2022-02-08 08:47:09,768] cur_loss: 0.138 [2022-02-08 08:47:11,809] cur_loss: 0.157 [2022-02-08 08:47:13,840] cur_loss: 0.115 [2022-02-08 08:47:15,875] cur_loss: 0.100 [2022-02-08 08:47:17,914] cur_loss: 0.105 [2022-02-08 08:47:19,957] cur_loss: 0.118 [2022-02-08 08:47:22,000] cur_loss: 0.098 [2022-02-08 08:47:24,035] cur_loss: 0.096 [2022-02-08 08:47:25,654] Epoch: [ 3/200] [ 10/ 25] time: 23.38s disp_loss: 0.096 [2022-02-08 08:47:26,068] cur_loss: 0.106 [2022-02-08 08:47:28,103] cur_loss: 0.081 [2022-02-08 08:47:30,138] cur_loss: 0.106 [2022-02-08 08:47:32,179] cur_loss: 0.098 [2022-02-08 08:47:34,215] cur_loss: 0.105 [2022-02-08 08:47:36,248] cur_loss: 0.100 [2022-02-08 08:47:38,293] cur_loss: 0.090 [2022-02-08 08:47:40,332] cur_loss: 0.098 [2022-02-08 08:47:42,371] cur_loss: 0.099 [2022-02-08 08:47:44,404] cur_loss: 0.173 [2022-02-08 08:47:46,028] Epoch: [ 3/200] [ 20/ 25] time: 20.37s disp_loss: 0.173 [2022-02-08 08:47:46,440] cur_loss: 0.102 [2022-02-08 08:47:48,475] cur_loss: 0.102 [2022-02-08 08:47:50,510] cur_loss: 0.116 [2022-02-08 08:47:52,545] cur_loss: 0.100 [2022-02-08 08:47:54,581] cur_loss: 0.126 [2022-02-08 08:48:00,155] cur_loss: 0.106 [2022-02-08 08:48:02,237] cur_loss: 0.111 [2022-02-08 08:48:04,284] cur_loss: 0.086 [2022-02-08 08:48:06,334] cur_loss: 0.081 [2022-02-08 08:48:08,378] cur_loss: 0.137 [2022-02-08 08:48:10,010] Epoch: [ 4/200] [ 5/ 25] time: 12.94s disp_loss: 0.137 [2022-02-08 08:48:10,424] cur_loss: 0.117 [2022-02-08 08:48:12,468] cur_loss: 0.090 [2022-02-08 08:48:14,507] cur_loss: 0.102 [2022-02-08 08:48:16,549] cur_loss: 0.125 [2022-02-08 08:48:18,592] cur_loss: 0.100 [2022-02-08 08:48:20,630] cur_loss: 0.100 [2022-02-08 08:48:22,669] cur_loss: 0.126 [2022-02-08 08:48:24,706] cur_loss: 0.101 [2022-02-08 08:48:26,753] cur_loss: 0.098 [2022-02-08 08:48:28,792] cur_loss: 0.099 [2022-02-08 08:48:30,418] Epoch: [ 4/200] [ 15/ 25] time: 20.41s disp_loss: 0.099 [2022-02-08 08:48:30,830] cur_loss: 0.114 [2022-02-08 08:48:32,869] cur_loss: 0.082 [2022-02-08 08:48:34,906] cur_loss: 0.103 [2022-02-08 08:48:36,942] cur_loss: 0.098 [2022-02-08 08:48:38,984] cur_loss: 0.104 [2022-02-08 08:48:41,025] cur_loss: 0.103 [2022-02-08 08:48:43,068] cur_loss: 0.119 [2022-02-08 08:48:45,107] cur_loss: 0.161 [2022-02-08 08:48:47,159] cur_loss: 0.101 [2022-02-08 08:48:49,201] cur_loss: 0.167 [2022-02-08 08:48:50,829] Epoch: [ 4/200] [ 25/ 25] time: 20.41s disp_loss: 0.167 [2022-02-08 08:48:56,882] cur_loss: 0.111 [2022-02-08 08:48:58,936] cur_loss: 0.082 [2022-02-08 08:49:00,979] cur_loss: 0.096 [2022-02-08 08:49:03,021] cur_loss: 0.185 [2022-02-08 08:49:05,058] cur_loss: 0.142 [2022-02-08 08:49:07,103] cur_loss: 0.088 [2022-02-08 08:49:09,147] cur_loss: 0.099 [2022-02-08 08:49:11,194] cur_loss: 0.109 [2022-02-08 08:49:13,241] cur_loss: 0.101 [2022-02-08 08:49:15,281] cur_loss: 0.130 [2022-02-08 08:49:16,915] Epoch: [ 5/200] [ 10/ 25] time: 23.52s disp_loss: 0.130 [2022-02-08 08:49:17,333] cur_loss: 0.112 [2022-02-08 08:49:19,379] cur_loss: 
0.093 [2022-02-08 08:49:21,421] cur_loss: 0.111 [2022-02-08 08:49:23,462] cur_loss: 0.134 [2022-02-08 08:49:25,505] cur_loss: 0.108 [2022-02-08 08:49:27,544] cur_loss: 0.109 [2022-02-08 08:49:29,586] cur_loss: 0.102 [2022-02-08 08:49:31,633] cur_loss: 0.101 [2022-02-08 08:49:33,674] cur_loss: 0.092 [2022-02-08 08:49:35,714] cur_loss: 0.108 [2022-02-08 08:49:37,348] Epoch: [ 5/200] [ 20/ 25] time: 20.43s disp_loss: 0.108 [2022-02-08 08:49:37,763] cur_loss: 0.104 [2022-02-08 08:49:39,804] cur_loss: 0.096 [2022-02-08 08:49:41,847] cur_loss: 0.099 [2022-02-08 08:49:43,883] cur_loss: 0.112 [2022-02-08 08:49:45,921] cur_loss: 0.121 [2022-02-08 08:49:51,606] cur_loss: 0.101 [2022-02-08 08:49:53,676] cur_loss: 0.192 [2022-02-08 08:49:55,715] cur_loss: 0.114 [2022-02-08 08:49:57,758] cur_loss: 0.173 [2022-02-08 08:49:59,806] cur_loss: 0.111 [2022-02-08 08:50:01,439] Epoch: [ 6/200] [ 5/ 25] time: 12.84s disp_loss: 0.111 [2022-02-08 08:50:01,873] cur_loss: 0.104 [2022-02-08 08:50:03,921] cur_loss: 0.086 [2022-02-08 08:50:05,970] cur_loss: 0.106 [2022-02-08 08:50:08,020] cur_loss: 0.098 [2022-02-08 08:50:10,081] cur_loss: 0.096 [2022-02-08 08:50:12,130] cur_loss: 0.096 [2022-02-08 08:50:14,167] cur_loss: 0.100 [2022-02-08 08:50:16,211] cur_loss: 0.104 [2022-02-08 08:50:18,253] cur_loss: 0.091 [2022-02-08 08:50:20,295] cur_loss: 0.103 [2022-02-08 08:50:21,922] Epoch: [ 6/200] [ 15/ 25] time: 20.48s disp_loss: 0.103 [2022-02-08 08:50:22,339] cur_loss: 0.102 [2022-02-08 08:50:24,379] cur_loss: 0.114 [2022-02-08 08:50:26,417] cur_loss: 0.084 [2022-02-08 08:50:28,459] cur_loss: 0.128 [2022-02-08 08:50:30,500] cur_loss: 0.089 [2022-02-08 08:50:32,546] cur_loss: 0.105 [2022-02-08 08:50:34,592] cur_loss: 0.101 [2022-02-08 08:50:36,636] cur_loss: 0.098 [2022-02-08 08:50:38,688] cur_loss: 0.102 [2022-02-08 08:50:40,733] cur_loss: 0.125 [2022-02-08 08:50:42,361] Epoch: [ 6/200] [ 25/ 25] time: 20.44s disp_loss: 0.125 [2022-02-08 08:50:46,862] cur_loss: 0.103 [2022-02-08 08:50:48,899] cur_loss: 0.189 [2022-02-08 08:50:50,941] cur_loss: 0.110 [2022-02-08 08:50:52,983] cur_loss: 0.117 [2022-02-08 08:50:55,020] cur_loss: 0.085 [2022-02-08 08:50:57,056] cur_loss: 0.094 [2022-02-08 08:50:59,092] cur_loss: 0.130 [2022-02-08 08:51:01,128] cur_loss: 0.097 [2022-02-08 08:51:03,167] cur_loss: 0.145 [2022-02-08 08:51:05,200] cur_loss: 0.094 [2022-02-08 08:51:06,824] Epoch: [ 7/200] [ 10/ 25] time: 23.69s disp_loss: 0.094 [2022-02-08 08:51:07,238] cur_loss: 0.119 [2022-02-08 08:51:09,279] cur_loss: 0.091 [2022-02-08 08:51:11,320] cur_loss: 0.091 [2022-02-08 08:51:13,358] cur_loss: 0.089 [2022-02-08 08:51:15,396] cur_loss: 0.097 [2022-02-08 08:51:17,436] cur_loss: 0.101 [2022-02-08 08:51:19,484] cur_loss: 0.112 [2022-02-08 08:51:21,525] cur_loss: 0.079 [2022-02-08 08:51:23,565] cur_loss: 0.120 [2022-02-08 08:51:25,603] cur_loss: 0.101 [2022-02-08 08:51:27,229] Epoch: [ 7/200] [ 20/ 25] time: 20.40s disp_loss: 0.101 [2022-02-08 08:51:27,641] cur_loss: 0.098 [2022-02-08 08:51:29,690] cur_loss: 0.099 [2022-02-08 08:51:31,730] cur_loss: 0.107 [2022-02-08 08:51:33,768] cur_loss: 0.137 [2022-02-08 08:51:35,805] cur_loss: 0.121 [2022-02-08 08:51:41,497] cur_loss: 0.104 [2022-02-08 08:51:43,559] cur_loss: 0.108 [2022-02-08 08:51:45,594] cur_loss: 0.088 [2022-02-08 08:51:47,641] cur_loss: 0.105 [2022-02-08 08:51:49,685] cur_loss: 0.108 [2022-02-08 08:51:51,310] Epoch: [ 8/200] [ 5/ 25] time: 12.95s disp_loss: 0.108 [2022-02-08 08:51:51,724] cur_loss: 0.103 [2022-02-08 08:51:53,760] cur_loss: 0.121 [2022-02-08 08:51:55,799] cur_loss: 
0.110 [2022-02-08 08:51:57,839] cur_loss: 0.170 [2022-02-08 08:51:59,879] cur_loss: 0.136 [2022-02-08 08:52:01,925] cur_loss: 0.118 [2022-02-08 08:52:03,966] cur_loss: 0.108 [2022-02-08 08:52:06,004] cur_loss: 0.111 [2022-02-08 08:52:08,044] cur_loss: 0.098 [2022-02-08 08:52:10,090] cur_loss: 0.135 [2022-02-08 08:52:11,718] Epoch: [ 8/200] [ 15/ 25] time: 20.41s disp_loss: 0.135 [2022-02-08 08:52:12,134] cur_loss: 0.107 [2022-02-08 08:52:14,178] cur_loss: 0.101 [2022-02-08 08:52:16,223] cur_loss: 0.091 [2022-02-08 08:52:18,265] cur_loss: 0.111 [2022-02-08 08:52:20,308] cur_loss: 0.117 [2022-02-08 08:52:22,347] cur_loss: 0.095 [2022-02-08 08:52:24,386] cur_loss: 0.089 [2022-02-08 08:52:26,428] cur_loss: 0.098 [2022-02-08 08:52:28,469] cur_loss: 0.101 [2022-02-08 08:52:30,512] cur_loss: 0.095 [2022-02-08 08:52:32,145] Epoch: [ 8/200] [ 25/ 25] time: 20.43s disp_loss: 0.095 [2022-02-08 08:52:37,867] cur_loss: 0.093 [2022-02-08 08:52:39,931] cur_loss: 0.089 [2022-02-08 08:52:41,978] cur_loss: 0.118 [2022-02-08 08:52:44,025] cur_loss: 0.083 [2022-02-08 08:52:46,066] cur_loss: 0.097 [2022-02-08 08:52:48,111] cur_loss: 0.103 [2022-02-08 08:52:50,152] cur_loss: 0.106 [2022-02-08 08:52:52,198] cur_loss: 0.099 [2022-02-08 08:52:54,243] cur_loss: 0.091 [2022-02-08 08:52:56,287] cur_loss: 0.122
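A disp_loss that settles around 0.1 while the predictions drift toward all zeros is consistent with most ground-truth disparity pixels being invalid (zero), since the masked loss then barely penalizes an all-zero output. Below is a minimal sanity check along those lines, not taken from this thread; it assumes KITTI-style 16-bit disparity PNGs (disparity = pixel value / 256, 0 = invalid) and a hypothetical folder `my_dataset/disp_occ_0/`, so adjust the loader to your format.

```python
# Hedged sanity check (paths and PNG convention are assumptions): estimate how
# many ground-truth disparity pixels are actually valid in a custom dataset.
import glob
import numpy as np
from PIL import Image

def valid_ratio(disp_png_path, max_disp=192):
    # KITTI-style encoding assumed: disparity = pixel / 256, 0 marks invalid pixels.
    disp = np.array(Image.open(disp_png_path), dtype=np.float32) / 256.0
    valid = (disp > 0) & (disp < max_disp)   # the usual validity mask in stereo losses
    return valid.mean()

ratios = [valid_ratio(p) for p in sorted(glob.glob('my_dataset/disp_occ_0/*.png'))]
print(f'mean valid ratio: {np.mean(ratios):.3f}, min: {np.min(ratios):.3f}')
# Very small ratios mean the masked loss can look "converged" even when the
# network predicts zero almost everywhere.
```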

q5390498 commented 2 years ago

I printed the network's predictions and found that they are normal (non-zero) values at the very beginning, but quite quickly all outputs become zero. What could be causing this?

[2022-02-08 09:20:52,391] cur_loss: 84.850 [2022-02-08 09:20:54,613] cur_loss: 59.421 [2022-02-08 09:20:56,657] cur_loss: 37.729 [2022-02-08 09:20:58,700] cur_loss: 5.955

pred: tensor([[[ 60.3958, 57.9909, 55.7744, ..., 120.9461, 118.6578, 124.5530], [ 57.4691, 56.0375, 54.9852, ..., 120.7754, 118.1667, 121.8229], [ 57.2455, 54.5802, 54.8936, ..., 125.5730, 122.2032, 124.1629], ..., [136.7134, 136.7585, 135.2419, ..., 176.4765, 176.3001, 167.1669], [134.0021, 130.6874, 132.8948, ..., 173.2103, 170.2460, 163.4151], [142.1232, 130.4962, 136.2892, ..., 168.4543, 164.3932, 154.5437]],

    [[ 35.1972,  36.5312,  38.4757,  ...,  77.5482,  77.8897,  78.8743],
     [ 35.8069,  38.0379,  39.2738,  ...,  75.5471,  75.5091,  78.4552],
     [ 38.3917,  39.2442,  40.5765,  ...,  75.8006,  75.4843,  78.0489],
     ...,
     [144.0496, 156.2293, 153.2752,  ..., 135.3938, 128.4638, 131.3442],
     [147.0923, 154.5704, 153.6310,  ..., 126.2524, 122.4596, 131.9489],
     [143.0517, 150.0596, 151.9824,  ..., 127.7324, 128.8004, 139.6858]],

    [[ 46.7300,  46.4000,  45.8144,  ...,  49.4122,  46.3503,  43.4428],
     [ 46.4703,  45.8650,  45.8478,  ...,  50.3494,  48.6170,  46.1489],
     [ 46.0063,  44.8311,  45.6043,  ...,  51.4486,  50.0680,  48.6887],
     ...,
     [166.9636, 169.2554, 158.5244,  ...,  93.4978,  93.9326,  95.7670],
     [165.7707, 166.1769, 159.4405,  ...,  90.3437,  90.5251,  93.7652],
     [161.5703, 163.0487, 163.8259,  ...,  91.9817,  94.8161,  96.8855]],

    [[ 38.9851,  37.0821,  35.8220,  ...,  58.3030,  50.8931,  42.7645],
     [ 37.5203,  35.6132,  35.4405,  ...,  61.0706,  56.8674,  40.9830],
     [ 37.2854,  34.9211,  35.0649,  ...,  59.9734,  57.2902,  50.2336],
     ...,
     [141.0944, 141.6223, 137.9460,  ...,  54.5146,  49.8646,  57.1323],
     [139.3730, 139.5795, 137.5088,  ...,  51.2052,  51.0434,  65.5106],
     [134.8223, 139.8197, 140.4545,  ...,  55.1063,  62.1394,  76.7297]]],
   device='cuda:0', grad_fn=<SqueezeBackward1>)

pred: tensor([[[ 40.2422, 40.3682, 40.5317, ..., 28.9843, 28.3127, 38.7624], [ 40.1813, 40.3274, 40.6232, ..., 28.8020, 26.8436, 30.5479], [ 40.0127, 39.8091, 40.5936, ..., 31.1058, 28.9942, 33.5787], ..., [153.3121, 156.9409, 152.2289, ..., 27.2218, 27.1960, 28.7927], [151.1767, 156.0058, 152.4844, ..., 28.0512, 29.0036, 29.6318], [150.5991, 149.9173, 155.7352, ..., 29.9039, 30.3969, 31.0580]],

    [[ 40.6195,  40.2955,  39.5429,  ...,  65.1209,  67.3683,  86.9687],
     [ 39.8015,  38.8467,  38.6322,  ...,  59.8054,  62.3455,  75.8580],
     [ 39.0024,  37.6076,  38.0337,  ...,  69.2202,  67.5764,  77.9588],
     ...,
     [ 70.1766,  75.0723,  73.6379,  ...,  92.3470,  92.3131,  88.9738],
     [ 76.8734,  75.2559,  74.9278,  ...,  90.8862,  90.8725,  87.8768],
     [ 71.8072,  76.8179,  76.4584,  ...,  90.5880,  89.3136,  90.2819]],

    [[ 58.1412,  55.9844,  58.0036,  ...,  39.2926,  30.7917,  36.1742],
     [ 60.6683,  58.0760,  57.4861,  ...,  36.8728,  39.1118,  35.9307],
     [ 60.9955,  58.8031,  58.7510,  ...,  39.8082,  39.2346,  44.6219],
     ...,
     [153.8498, 151.0643, 149.7025,  ...,  12.8411,  13.3434,  13.2162],
     [147.4372, 150.5965, 149.0699,  ...,  12.6299,  13.1242,  13.3502],
     [154.4095, 144.5931, 154.6027,  ...,  13.0644,  13.2323,  13.4366]],

    [[ 39.8573,  37.9619,  36.9057,  ...,  42.1044,  40.2077,  38.0919],
     [ 38.7509,  36.6201,  36.5197,  ...,  40.4639,  40.0419,  38.8953],
     [ 38.3336,  36.4496,  36.7688,  ...,  39.2797,  39.6609,  40.1514],
     ...,
     [135.7857, 134.6154, 130.9226,  ...,  47.5453,  38.8995,  47.3962],
     [133.2933, 133.7791, 131.4800,  ...,  40.8988,  40.8693,  63.3618],
     [125.8759, 128.7854, 134.1934,  ...,  49.0825,  59.6753,  80.7865]]],
   device='cuda:0', grad_fn=<SqueezeBackward1>)

pred: tensor([[[ 4.4194, 3.2357, 0.7156, ..., 8.0940, 7.6042, 14.6319], [ 3.2366, 0.8199, 0.0000, ..., 7.5350, 8.9960, 7.4030], [ 0.0000, 0.0000, 0.0000, ..., 7.4874, 8.2638, 10.1045], ..., [153.2203, 151.7722, 149.0503, ..., 161.7336, 155.4834, 159.0959], [155.0462, 151.8277, 150.1341, ..., 152.6343, 156.0988, 158.0346], [138.7147, 151.4463, 153.9102, ..., 164.6467, 159.8742, 163.0121]],

    [[ 59.7358,  59.2537,  58.3708,  ...,   0.0000,   0.0000,   7.6454],
     [ 59.0069,  58.3980,  58.2597,  ...,   0.0000,   0.0000,   0.6457],
     [ 58.4347,  56.7561,  58.0880,  ...,   0.0000,   0.6698,   1.9930],
     ...,
     [102.2050, 103.8153, 103.2228,  ...,  10.0267,   9.6739,   5.5513],
     [106.9606, 107.4507, 105.4501,  ...,  11.8747,  10.8758,  10.4535],
     [112.8723, 111.1960, 109.1483,  ...,  15.4556,  13.3552,  10.8301]],

    [[ 19.4347,  21.1684,  23.7017,  ...,   0.0000,   0.0000,  19.9356],
     [ 23.6737,  23.8128,  24.1997,  ...,   0.0000,   0.0000,   3.3105],
     [ 24.8388,  24.0135,  25.7692,  ...,   0.0000,   0.0000,   5.0559],
     ...,
     [125.8863, 122.7824, 121.6627,  ...,   1.2391,   2.4334,   0.0000],
     [123.5106, 124.1855, 120.8899,  ...,   3.2927,   4.2453,   5.3043],
     [123.9110, 122.3902, 126.0438,  ...,   5.5579,   5.8330,   5.9820]],

    [[ 31.4857,  23.5120,  16.3238,  ...,  17.8765,  18.2859,  30.9049],
     [ 23.6049,  17.3306,  16.0096,  ...,  15.9916,  14.7557,  18.0256],
     [ 17.1505,  15.0129,  17.0562,  ...,  18.3809,  15.6091,  18.1920],
     ...,
     [124.2791, 122.0681, 123.2237,  ...,  43.1251,  41.7653,  47.6129],
     [126.5202, 123.0051, 123.7755,  ...,  40.4210,  41.6319,  59.4113],
     [125.0918, 127.3637, 126.1499,  ...,  48.7767,  56.3044,  72.6417]]],
   device='cuda:0', grad_fn=<SqueezeBackward1>)

pred: tensor([[[ 0.0000, 0.0000, 0.0000, ..., 7.7598, 6.5291, 9.6075], [ 0.0000, 0.0000, 0.0000, ..., 8.8233, 7.7605, 7.0710], [ 0.0000, 0.0000, 0.0000, ..., 9.2834, 9.4017, 8.5750], ..., [13.2084, 10.5547, 10.5498, ..., 0.0000, 0.0000, 0.0000], [12.2839, 9.8691, 9.9626, ..., 0.0000, 0.0000, 0.0000], [13.4774, 11.9916, 11.7469, ..., 0.0000, 0.0000, 0.0000]],

    [[12.0873,  5.2150,  0.0000,  ...,  0.0000,  0.0000, 14.3237],
     [ 9.5041,  2.6964,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 4.3040,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  1.6176],
     ...,
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  4.2367],
     [22.3126,  0.0000,  0.0000,  ...,  0.0000,  3.7081, 15.4229]],

    [[ 2.8575,  3.1858,  1.6512,  ..., 10.4605,  9.9471, 18.4150],
     [ 2.6541,  0.7860,  0.4965,  ...,  9.5509,  9.4959, 12.1756],
     [ 0.0000,  0.0000,  0.0000,  ..., 12.3192, 11.3461, 10.9253],
     ...,
     [ 0.0000,  0.0000,  0.0000,  ...,  0.2062,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  2.5944],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.2506,  0.0000,  5.2095]],

    [[11.0264,  7.9926,  4.4777,  ..., 14.1673, 24.5804, 56.6707],
     [ 7.7415,  5.2060,  5.0052,  ...,  1.9215,  9.1792, 28.5566],
     [ 4.4257,  4.1824,  5.7704,  ..., 11.4801, 10.9685, 26.1688],
     ...,
     [13.3924, 12.6206, 12.4879,  ...,  7.4414,  4.4492, 26.7554],
     [13.6897, 11.7324, 11.0781,  ..., 12.2885, 16.3833, 54.2250],
     [14.5986, 12.1807, 11.6549,  ..., 31.7092, 47.3135, 77.3596]]],
   device='cuda:0', grad_fn=<SqueezeBackward1>)

pred: tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 7.7891], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], ..., [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.8271]],

    [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
     [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
     [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
     ...,
     [  1.7379,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
     [  9.1776,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
     [ 28.4450,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

    [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,  12.1848,  75.6153],
     [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,  32.7908],
     [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   4.8962],
     ...,
     [111.9232, 108.7729, 108.7321,  ...,   0.0000,   0.0000,   0.0000],
     [111.6686, 109.6710, 109.7158,  ...,   0.0000,   0.0000,   0.0000],
     [118.3945, 114.3893, 110.5331,  ...,   0.0000,   0.0000,   0.0000]],

    [[  0.2929,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
     [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
     [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
     ...,
     [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
     [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   6.9630],
     [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,  25.9977]]],
   device='cuda:0', grad_fn=<SqueezeBackward1>)
[2022-02-08 09:21:00,747] cur_loss: 0.313

[2022-02-08 09:21:02,795] cur_loss: 0.126 [2022-02-08 09:21:04,843] cur_loss: 0.097 [2022-02-08 09:21:06,884] cur_loss: 0.120 [2022-02-08 09:21:08,931] cur_loss: 0.117 [2022-02-08 09:21:10,978] cur_loss: 0.133 [2022-02-08 09:21:12,607] Epoch: [ 1/200] [ 10/ 25] time: 25.51s disp_loss: 0.133

pred: tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], ..., [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],

    [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000, 32.0340],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.2943],
     ...,
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.6711,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

    [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     ...,
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

    [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     ...,
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  9.7637],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000, 32.0763],
     [10.1621,  0.0000,  0.0000,  ...,  0.0000, 20.9321, 60.8215]]],
   device='cuda:0', grad_fn=<SqueezeBackward1>)

pred: tensor([[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], ..., [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.3589]],

    [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     ...,
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

    [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     ...,
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

    [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     ...,
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]],
   device='cuda:0', grad_fn=<SqueezeBackward1>)

pred: tensor([[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], ..., [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],

    [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     ...,
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

    [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     ...,
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 6.0604]],

    [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     ...,
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]],
   device='cuda:0', grad_fn=<SqueezeBackward1>)

pred: tensor([[[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]],

    [[0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     ...,
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.]],

    [[0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     ...,
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.]],

    [[0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     ...,
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0',
   grad_fn=<SqueezeBackward1>)

pred: tensor([[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], ..., [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],

    [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     ...,
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

    [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     ...,
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 1.9761]],

    [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     ...,
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
     [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]],
   device='cuda:0', grad_fn=<SqueezeBackward1>)

pred: tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 35.2818], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 5.8052], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], ..., [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],

    [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     ...,
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

    [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     ...,
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

    [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     ...,
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],
   device='cuda:0', grad_fn=<SqueezeBackward1>)
[2022-02-08 09:21:13,025] cur_loss: 0.095

[2022-02-08 09:21:15,073] cur_loss: 0.124 [2022-02-08 09:21:17,117] cur_loss: 0.094 [2022-02-08 09:21:19,168] cur_loss: 0.115 [2022-02-08 09:21:21,215] cur_loss: 0.101 [2022-02-08 09:21:23,256] cur_loss: 0.093

pred: tensor([[[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]],

    [[0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     ...,
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.]],

    [[0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     ...,
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.]],

    [[0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     ...,
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.],
     [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0',
   grad_fn=<SqueezeBackward1>)

pred: tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], ..., [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],

    [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     ...,
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

    [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     ...,
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

    [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000, 23.1598],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     ...,
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
     [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],
   device='cuda:0', grad_fn=<SqueezeBackward1>)
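To tell a data problem apart from an optimization problem (for example a learning rate that drives the disparity head to zero), it can help to log prediction statistics and the gradient norm every iteration. The sketch below uses placeholder names (`pred`, `model`) rather than AANet's actual training-loop variables, and is meant to be called right after `loss.backward()`.

```python
# Hedged debugging sketch, not AANet's actual code: correlate a collapse of the
# predictions with the size of the gradients.
import torch

def log_step(step, pred, model):
    with torch.no_grad():
        nonzero = (pred > 0).float().mean().item()
        print(f'step {step}: pred mean {pred.mean().item():.2f}, '
              f'max {pred.max().item():.2f}, nonzero fraction {nonzero:.3f}')
    grads = [p.grad.detach().pow(2).sum() for p in model.parameters() if p.grad is not None]
    grad_norm = torch.sqrt(torch.stack(grads).sum()).item() if grads else 0.0
    print(f'step {step}: grad norm {grad_norm:.2f}')
    # If the gradient norm spikes right before the outputs die, a lower learning
    # rate or torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) is worth trying.
```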
haofeixu commented 2 years ago

Sorry, the training log is no longer available because this project is quite old. I would suggest first trying our model on our dataset, and then on your dataset, to see whether the problem is related to the dataset.
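One way to follow this suggestion on the data side, before any fine-tuning, is to verify that a custom pair is rectified, ordered, and scaled as the model expects: warping the right image with the ground-truth disparity should reproduce the left image much better than using the right image unwarped. The sketch below is only an illustration; the file paths are placeholders and the / 256 scaling again assumes KITTI-style disparity PNGs.

```python
# Hedged check of left/right order, rectification and disparity scale on one
# custom-dataset sample (all paths are placeholders).
import numpy as np
from PIL import Image

left  = np.array(Image.open('my_dataset/left/000000.png'),  dtype=np.float32)
right = np.array(Image.open('my_dataset/right/000000.png'), dtype=np.float32)
disp  = np.array(Image.open('my_dataset/disp_occ_0/000000.png'), dtype=np.float32) / 256.0

h, w = disp.shape
cols = np.arange(w)[None, :].repeat(h, axis=0)
src_x = np.clip(np.round(cols - disp).astype(int), 0, w - 1)
warped = right[np.arange(h)[:, None], src_x]      # sample the right image at x - d

valid = disp > 0
err_warped   = np.abs(warped - left)[valid].mean()
err_unwarped = np.abs(right  - left)[valid].mean()
print(f'photometric error: warped {err_warped:.1f} vs unwarped {err_unwarped:.1f}')
# The warped error should be clearly lower; if it is not, the pair order,
# rectification, or disparity scale of the custom data is suspect.
```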