gicsaw / ARAE_torch


questions about the AE part #3

Open jiaweih14 opened 3 years ago

jiaweih14 commented 3 years ago

Hi, there

Thanks for the great work on molecule generation. I have read __ARAE_train.py__ to better understand your method, and a few questions came up along the way that I would like to ask about.

Q1

In __ARAE_train.py__, lines 187 to 200:

```python
batch_x2 = batch_x[:, 1:]

optimizer_AE.zero_grad()
noise = torch.normal(mean=mean0, std=std).to(device)
out_decoding = model.AE(batch_x, batch_l, noise)
out2 = out_decoding[:, :-1]
loss_AE = criterion_AE(
    out2.reshape(-1, Nfea), batch_x2.reshape(-1))
loss_AE.backward(retain_graph=True)
optimizer_AE.step()
running_loss_AE += loss_AE.data

Z_real = model.Enc(batch_x, batch_l)
```

You first build batch_x2 by removing the first character, which makes sense because the first character is always "<", based on what you wrote in data_char.py. But what I don't understand is: for the output of the autoencoder, out_decoding, why do you remove the last element (out2 = out_decoding[:, :-1])?

Q2

Also, in __ARAE_train.py__, lines 230 to 232:

```python
_, out_num_AE = torch.max(out_decoding, 2)
acc, acc2 = accu(out_num_AE, batch_x2, batch_l)
print("reconstruction accuracy:", acc, acc2)
```

Why do you take the most likely character from out_decoding, but then compute the accuracy against batch_x2? Their sequence lengths are not the same (in your original settings, out_decoding has Nseq = 110, but batch_x2 has length 109, since you took batch_x2 = batch_x[:, 1:]).
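For reference, a quick shape check reproduces the mismatch I mean (the batch size and Nfea below are made up purely for illustration):

```python
import torch

batch, Nseq, Nfea = 4, 110, 36                   # illustrative sizes only
out_decoding = torch.randn(batch, Nseq, Nfea)    # decoder logits
batch_x = torch.randint(0, Nfea, (batch, Nseq))  # token indices
batch_x2 = batch_x[:, 1:]                        # labels with the first "<" dropped

_, out_num_AE = torch.max(out_decoding, 2)       # argmax over the character dimension
print(out_num_AE.shape)  # torch.Size([4, 110])
print(batch_x2.shape)    # torch.Size([4, 109])
```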

gicsaw commented 3 years ago

Sorry for the late response.

Q1. First, consider the case where the sentence "<ABCDE>" is given as data. The input fed to the decoder (batch_x) is "<ABCDE>", and the label to be predicted (batch_x2) is "ABCDE>". out_decoding corresponds to the probability of each character coming next; out_decoding[:,0] is the output corresponding to the input "<". If well trained:

argmax(out_decoding[:,0]): A
argmax(out_decoding[:,1]): B
argmax(out_decoding[:,2]): C
argmax(out_decoding[:,3]): D
argmax(out_decoding[:,4]): E
argmax(out_decoding[:,5]): >
argmax(out_decoding[:,6]): >

len(batch_x2) is 6, while len(out_decoding) is 7. The last prediction of out_decoding (the one made after the final ">") is meaningless, which is why it is dropped: out2 = out_decoding[:, :-1].
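To restate the alignment in runnable form, here is a toy sketch (the character handling is simplified and this is not the repo's actual data pipeline):

```python
# Toy example: "<" is the start token, ">" ends/pads the sequence.
batch_x  = list("<ABCDE>")   # decoder input, length 7
batch_x2 = list("ABCDE>")    # labels = batch_x[1:], length 6

# The decoder emits one prediction per input position, so it produces
# 7 outputs. Prediction i is made after reading batch_x[i] and should
# match batch_x[i + 1], i.e. batch_x2[i]:
for inp, target in zip(batch_x, batch_x2):
    print(f"input {inp!r} -> should predict {target!r}")

# The 7th prediction (made after the final '>') has no label at all,
# which is why it is sliced off: out2 = out_decoding[:, :-1].
```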

Q2. Your point is correct. It seems that it should have been compared to out2 instead of out_decoding. However, since batch_l contains the length of the actual sentence excluding padding (">>>>>"), the calculation probably worked as intended.
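For illustration, a length-masked accuracy along the following lines would behave that way; this is a hypothetical sketch of the idea, not the actual accu from the repo:

```python
import torch

def masked_accuracy(pred, target, lengths):
    """Compare predictions with targets only up to each sequence's real
    length; padding positions, and any extra trailing columns of pred,
    are never looked at."""
    correct, total = 0, 0
    for p, t, n in zip(pred, target, lengths):
        n = int(n)
        correct += (p[:n] == t[:n]).sum().item()
        total += n
    return correct / total

# pred (from out_decoding) may have one more column than target
# (batch_x2), but positions >= lengths[b] are ignored, so the
# off-by-one in Nseq does not change the result.
```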

jiaweih14 commented 3 years ago

Thanks for the response. For Q1, why don't we just treat the label as "<ABCDE>"? Since you add the "<" manually, what is the point of cutting it off when calculating the loss?

gicsaw commented 3 years ago

I do not get e-mail notifications for new issues, so sorry for the late reply.

The decoder only produces an output when some character is given as input. So, at the beginning, we feed "<" as the first input and predict the next character ("A"). Your suggestion is to predict "<" by feeding a zero vector as the first input? That is also possible.
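For illustration, a greedy decoding loop seeded with "<" could look like the sketch below; the step function here is a hypothetical stand-in for one decoder step, not the repo's actual API:

```python
import torch

def greedy_decode(step, start_idx, end_idx, state, max_len=110):
    """Greedy autoregressive decoding seeded with the start token '<'.

    `step(token_idx, state) -> (logits, state)` is assumed to run one
    decoder step; it is a placeholder, not a function from this repo.
    """
    token = start_idx            # the '<' token seeds the first step
    out = []
    for _ in range(max_len):
        logits, state = step(token, state)
        token = int(torch.argmax(logits))  # pick the most likely next character
        out.append(token)
        if token == end_idx:     # stop once '>' is produced
            break
    return out
```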