SXQ233 opened 2 years ago
I added code to the hsic_train function to calculate the training accuracy:
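```python
import numpy as np

def hsic_data(hiddens, h_data, h_target):
    # ... (the rest of this function is not shown in the post) ...

    train_num_correct = 0
    for batch_idx, (data, target) in pbar:
        # ... (existing HSIC training step that produces `output`) ...

        # output.max(1) returns (values, indices); the indices are the predicted classes
        _, pred = output.max(1)
        # np.int was deprecated in NumPy 1.20 and removed in 1.24; use the builtin int
        a = pred.detach().cpu().numpy().astype(int)
        b = target.detach().cpu().numpy().astype(int)
        num_correct = int(np.sum(a == b))
        # the last batch can be smaller than batch_size, so divide by its real size
        acc = num_correct / len(target)
        train_num_correct += num_correct
    print("\n train_acc: ", train_num_correct / n_data)
    return batch_log
```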
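For comparison, the same bookkeeping can be written as a small, framework-free sketch. `AccuracyMeter` and `argmax` are hypothetical helpers (not part of the HSIC-bottleneck code); they take plain Python lists in place of tensors and divide by the number of samples actually seen rather than by a configured batch size:

```python
from typing import Sequence


def argmax(row: Sequence[float]) -> int:
    """Index of the largest value in one row (hypothetical helper).

    Plays the role of the index returned by output.max(1) for a single sample.
    """
    return max(range(len(row)), key=lambda i: row[i])


class AccuracyMeter:
    """Hypothetical accumulator of correct predictions across batches."""

    def __init__(self) -> None:
        self.correct = 0  # total correct predictions so far
        self.seen = 0     # total samples so far (real sizes, not batch_size)

    def update(self, outputs: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
        """Add one batch and return that batch's accuracy."""
        preds = [argmax(row) for row in outputs]
        batch_correct = sum(int(p == t) for p, t in zip(preds, targets))
        self.correct += batch_correct
        self.seen += len(targets)  # a short final batch is counted correctly
        return batch_correct / len(targets)

    @property
    def accuracy(self) -> float:
        """Accuracy over everything seen so far (0.0 before any update)."""
        return self.correct / self.seen if self.seen else 0.0


meter = AccuracyMeter()
meter.update([[0.1, 0.9], [0.8, 0.2]], [1, 1])  # batch accuracy 0.5
meter.update([[0.3, 0.7]], [1])                 # batch accuracy 1.0
print(meter.accuracy)                           # 2 correct out of 3 samples
```

Dividing by `seen` instead of `batches * batch_size` keeps the result exact when the dataset size is not a multiple of the batch size.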
I ran train-hsicbt-api.py and got the following results:
```
Train Epoch: 0 [ 60032/60032 (100%)] H_hx:29.0176 H_hy:7.9819 acc:0.0000: 100%|█████| 469/469.0 [02:35<00:00, 3.28it/s]
 train_acc: 0.01890658315565032
Train Epoch: 1 [ 60032/60032 (100%)] H_hx:26.0663 H_hy:8.3982 acc:0.0078: 100%|█████| 469/469.0 [02:35<00:00, 3.25it/s]
 train_acc: 0.007762526652452025
Train Epoch: 2 [ 60032/60032 (100%)] H_hx:25.1033 H_hy:8.4767 acc:0.0000: 100%|█████| 469/469.0 [02:36<00:00, 3.15it/s]
 train_acc: 0.006163379530916844
Train Epoch: 3 [ 60032/60032 (100%)] H_hx:24.0530 H_hy:8.4954 acc:0.0000: 100%|█████| 469/469.0 [02:42<00:00, 3.23it/s]
 train_acc: 0.006463219616204691
Train Epoch: 4 [ 60032/60032 (100%)] H_hx:24.0890 H_hy:8.5384 acc:0.0000: 100%|█████| 469/469.0 [02:38<00:00, 3.21it/s]
 train_acc: 0.006396588486140725

Process finished with exit code 0
```
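The printed values can be turned back into counts of correct predictions. This sketch assumes a batch size of 128 (469 batches × 128 = 60032, the total shown in the progress bar); under that assumption the denominator `n_data` works out to 60032, and only a few hundred to about a thousand predictions per epoch matched the labels, far below the roughly 10% expected from random guessing if there are 10 classes:

```python
# train_acc values copied from the log above, one per epoch
printed = [0.01890658315565032, 0.007762526652452025, 0.006163379530916844,
           0.006463219616204691, 0.006396588486140725]

# Assumption: batch_size = 128, so 469 batches cover 60032 samples
n_data = 469 * 128

# Recover the integer number of correct predictions behind each printed ratio
correct = [round(acc * n_data) for acc in printed]
print(correct)  # [1135, 466, 370, 388, 384]

# Expected correct count for uniform random guessing over 10 classes (assumed)
chance = n_data // 10
print(chance)   # 6003
```

If the dataset is MNIST's 60,000-image training set, the last batch holds only 96 samples, so dividing by 60032 understates accuracy very slightly; that effect is tiny and does not explain the gap.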
Why is this accuracy so different from the one reported in the paper? Is there something wrong with my code?
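One possibility worth ruling out (an assumption, not a confirmed diagnosis): if the kernel applied to the outputs depends only on pairwise distances, as a Gaussian kernel does, then permuting the output units leaves every HSIC term unchanged, so nothing in that objective forces output unit k to correspond to class k. Argmax accuracy against the raw labels can then sit near or even below chance while the outputs still separate the classes. A quick check is to relabel each predicted unit with the true label it most often co-occurs with and recompute accuracy. `mapped_accuracy` is a hypothetical helper, and the score it returns is the clustering "purity" measure, not a test accuracy:

```python
from collections import Counter, defaultdict
from typing import Dict, Sequence, Tuple


def mapped_accuracy(preds: Sequence[int], targets: Sequence[int]) -> Tuple[float, Dict[int, int]]:
    """Accuracy after mapping each predicted unit to its most common true label.

    Hypothetical diagnostic: it ignores which index a unit has and only asks
    whether samples that share a unit also share a label. The mapping may be
    many-to-one, so the score is optimistic.
    """
    if not targets:
        return 0.0, {}
    # Count true labels separately for each predicted unit
    by_unit: Dict[int, Counter] = defaultdict(Counter)
    for p, t in zip(preds, targets):
        by_unit[p][t] += 1
    # Map every unit to the label it co-occurs with most often
    mapping = {unit: counts.most_common(1)[0][0] for unit, counts in by_unit.items()}
    # Samples whose label equals their unit's mapped label count as correct
    correct = sum(counts[mapping[unit]] for unit, counts in by_unit.items())
    return correct / len(targets), mapping


targets = [0, 0, 1, 1, 2, 2]
preds = [1, 1, 2, 2, 0, 0]  # every class lands on a consistently "wrong" unit
raw = sum(p == t for p, t in zip(preds, targets)) / len(targets)
acc, mapping = mapped_accuracy(preds, targets)
print(raw, acc, mapping)  # 0.0 1.0 {1: 0, 2: 1, 0: 2}
```

In the training loop, `pred` and `target` could be collected with `.tolist()` and passed in. If the mapped score is high while the raw score is near zero, the outputs carry label information but their unit order does not match the labels, and some mapping or readout step would be needed before plain argmax accuracy is meaningful.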