zhaoymn / mbeats


How to test the trained neural network for mdrone #2

Open mjk1376 opened 11 months ago

mjk1376 commented 11 months ago

Hello, I downloaded your data from Dropbox, unzipped it, and ran the preprocessing procedure with your code. Eventually I had 280 gigabytes of data, and I trained the proposed neural network with your code. I now have the trained model files and want to evaluate the results, but I did not find any testing code in your repository and I am not sure how to test it. Can you help me with that?

zhaoymn commented 11 months ago

Hello,

You can just follow the typical PyTorch evaluation pipeline.

  1. First, create the model and load the weights

    model = CNN_LSTM()
    model.load_state_dict(torch.load('path_to_saved_model.pth'))
  2. Set the model to evaluation mode

    model.eval()
  3. Run a loop similar to the validation loop in the training file

    with torch.no_grad():
        for i, (x_azimuth, x_elevation, labels) in enumerate(testloader):
            x_azimuth = x_azimuth.to(device)
            x_elevation = x_elevation.to(device)
            outputs = model(x_azimuth, x_elevation)

    You can then compare the outputs to the labels to see the error.

Alternatively, you can load the testing data using fft_dataset_abs.py as a reference.
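
Putting the steps together, a minimal evaluation sketch could look like the one below. The import paths and the FFT_DATASET_2D split index are assumptions (check fft_dataset_abs.py and the training script for the actual names), and the checkpoint path is a placeholder:

    import torch
    from fft_dataset_abs import FFT_DATASET_2D  # dataset class from this repo (module path assumed)
    # CNN_LSTM is the network used in training; the module it lives in is assumed here
    # from train import CNN_LSTM

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Build the test loader (the split argument is an assumption; see fft_dataset_abs.py)
    testset = FFT_DATASET_2D(2)
    testloader = torch.utils.data.DataLoader(testset, batch_size=32,
                                             shuffle=False, num_workers=4)

    # 1. Create the model and load the trained weights ('path_to_saved_model.pth' is a placeholder)
    model = CNN_LSTM().to(device)
    model.load_state_dict(torch.load('path_to_saved_model.pth', map_location=device))

    # 2. Switch to evaluation mode
    model.eval()

    # 3. Run the test loop without tracking gradients
    with torch.no_grad():
        for x_azimuth, x_elevation, labels in testloader:
            x_azimuth = x_azimuth.to(device)
            x_elevation = x_elevation.to(device)
            labels = labels.to(device)
            outputs = model(x_azimuth, x_elevation)
            # compare outputs to labels here to quantify the error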

Thanks.

MMWAVEMJK commented 11 months ago

I wrote the code below. I changed the validation sequence to the test sequence in fft_dataset_abs.

    if __name__ == '__main__':

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        testset = FFT_DATASET_2D(2)
        batch_size = 32
        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)

        model = CNN_LSTM()
        model = model.to(device)
        test_criterion = nn.MSELoss(reduction='mean').to(device)
        optimizer = optim.Adam(list(model.parameters()), lr=0.00003)

        checkpoint = torch.load('no_dropout_26.pth')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']

        model.eval()

        with torch.no_grad():
            test_loss = 0
            for i, (x_azimuth, x_elevation, labels) in enumerate(testloader):
                x_azimuth = x_azimuth.to(device)
                #print(x_azimuth)
                x_elevation = x_elevation.to(device)
                labels = labels.to(device)#, dtype=torch.float)
                #print(labels)
                outputs = model(x_azimuth, x_elevation)
                loss = torch.sqrt(test_criterion(outputs, labels)*3)
                test_loss += loss
            print(test_loss/i)

The result was 0.0987. Is this a valid result?

zhaoymn commented 11 months ago

No. The loss function I defined was an approximation of the translation error, so that I could get a general idea of how the network was performing during training. For an accurate testing result, you will have to compare the outputs to the ground truth (i.e., labels in your code). You do not need the optimizer or the loss during testing.
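
For example, a minimal sketch of such a comparison, reusing the model, device, and testloader from your script, and assuming outputs and labels are 3D positions of shape (batch, 3) so that the translation error is the Euclidean distance:

    import torch

    model.eval()
    with torch.no_grad():
        errors = []
        for x_azimuth, x_elevation, labels in testloader:
            x_azimuth = x_azimuth.to(device)
            x_elevation = x_elevation.to(device)
            labels = labels.to(device)
            outputs = model(x_azimuth, x_elevation)
            # L2 distance between prediction and ground truth for each sample,
            # moved to CPU so the whole test set can be aggregated
            errors.append(torch.linalg.norm(outputs - labels, dim=1).cpu())
        errors = torch.cat(errors)
        print('mean error:', errors.mean().item())
        print('median error:', errors.median().item())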

MMWAVEMJK commented 10 months ago

Hello, I tested with the code below and the result is shown in the image. Did I do it right?

    if __name__ == '__main__':

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        testset = FFT_DATASET_2D(2)
        batch_size = 32
        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)

        model = CNN_LSTM()
        model = model.to(device)
        model.load_state_dict(torch.load('cnn_state_dict_27.pth'))

        model.eval()

        with torch.no_grad():
            test_loss = 0
            for i, (x_azimuth, x_elevation, labels) in enumerate(testloader):
                #print(i)
                x_azimuth = x_azimuth.to(device)#, dtype=torch.float)
                x_elevation = x_elevation.to(device)#, dtype=torch.float)
                labels = labels.to(device)#, dtype=torch.float)
                outputs = model(x_azimuth, x_elevation)
                print(torch.mean(torch.abs(labels-outputs)))

mjk1376 commented 10 months ago

My error result is this: [screenshot of the printed error values]

zhaoymn commented 10 months ago

Yes, but the error is computed over a batch of data if you use a DataLoader. You can manually load a single sample to get the per-sample error. Also, you may want to convert the tensors to NumPy for further processing.
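
A minimal sketch of the single-sample case, reusing testset, model, and device from your script. The index 0 is arbitrary, and the unsqueeze assumes the dataset returns unbatched tensors:

    import numpy as np
    import torch

    # Pull one sample directly from the dataset and add a batch dimension
    x_azimuth, x_elevation, label = testset[0]
    x_azimuth = x_azimuth.unsqueeze(0).to(device)
    x_elevation = x_elevation.unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        output = model(x_azimuth, x_elevation)

    # Convert to NumPy for further processing or plotting
    pred = output.squeeze(0).cpu().numpy()
    gt = label.numpy() if torch.is_tensor(label) else np.asarray(label)
    print('per-sample error:', np.linalg.norm(pred - gt))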