sjfleming closed this issue 9 months ago.
Results so far. The encoder and the classifier are:
FullyConnectedNetwork(
  (network): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=32, bias=True)
  )
)
FullyConnectedNetwork(
  (network): Sequential(
    (0): Linear(in_features=32, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=3, bias=True)
  )
)
The loss function measures only how well we can classify; here we use a CrossEntropyLoss.
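A minimal sketch of this setup, for anyone reproducing it. It assumes FullyConnectedNetwork is equivalent to a plain torch.nn.Sequential of Linear and ReLU layers (its implementation is not shown in this thread), and the batch is random placeholder data rather than the simulated dataset.

```python
import torch
from torch import nn

# Stand-ins for the two FullyConnectedNetwork modules printed above
# (assumption: they are plain Linear/ReLU stacks).
encoder = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 32))
classifier = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 3))

# CrossEntropyLoss expects raw logits and integer class labels.
loss_fn = nn.CrossEntropyLoss()

# Placeholder batch: 8 samples with 128 features, labels in {0, 1, 2}.
torch.manual_seed(0)
data = torch.randn(8, 128)
true_classes = torch.randint(0, 3, (8,))

z = encoder(data)               # latent space, shape (8, 32)
class_logits = classifier(z)    # one logit per class, shape (8, 3)
loss = loss_fn(class_logits, true_classes)
```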
We see that when we do NOT have a gradient reversal layer, the encoder pulls the classes apart in the latent space, making them easier to classify accurately:
And if, for the purposes of validation, we train a separate classifier on the resulting latent space, it can reach 100% accuracy.
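The validation step can be sketched roughly as follows. This is an illustration under assumptions, not the notebook code: the function name train_validation_classifier, the 50/50 train/test split (which assumes the rows are already shuffled), Adam, and the epoch count are all hypothetical choices.

```python
import torch
from torch import nn


def train_validation_classifier(encoder, data, labels, n_classes=3, epochs=200, lr=1e-2):
    """Train a fresh classifier on a frozen latent space; return held-out accuracy."""
    # Freeze the latent space: no gradients flow back into the encoder.
    with torch.no_grad():
        z = encoder(data)
    n_train = z.shape[0] // 2
    z_train, z_test = z[:n_train], z[n_train:]
    y_train, y_test = labels[:n_train], labels[n_train:]

    # Same architecture as the training-time classifier (latent -> 16 -> classes).
    clf = nn.Sequential(nn.Linear(z.shape[1], 16), nn.ReLU(), nn.Linear(16, n_classes))
    optimizer = torch.optim.Adam(clf.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(clf(z_train), y_train)
        loss.backward()
        optimizer.step()

    # Accuracy on the held-out half of the latent codes.
    with torch.no_grad():
        pred = clf(z_test).argmax(dim=1)
    return (pred == y_test).float().mean().item()
```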
When we DO use a gradient reversal layer, implemented like this:
# from https://github.com/janfreyberg/pytorch-revgrad/blob/master/src/pytorch_revgrad/module.py
import torch


class RevGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, alpha_):
        # identity in the forward pass
        ctx.save_for_backward(input_, alpha_)
        output = input_
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        # negate and scale the incoming gradient; alpha_ gets no gradient
        grad_input = None
        _, alpha_ = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_input = -grad_output * alpha_
        return grad_input, None


class GradientReversal(torch.nn.Module):
    def __init__(self, alpha=1., *args, **kwargs):
        """A gradient reversal layer.

        This layer has no parameters, and simply reverses the gradient
        in the backward pass.
        """
        super().__init__(*args, **kwargs)
        self._alpha = torch.tensor(alpha, requires_grad=False)

    def forward(self, input_):
        return RevGrad.apply(input_, self._alpha)
# set up gradient reversal layer
gradient_reversal = GradientReversal()
# encode data into latent space
z = encoder(data)
# perform gradient reversal
if do_gradient_reversal:
    z = gradient_reversal(z)
# classify data points based on latent space
class_logit_probs = classifier(z)
# loss function
loss = loss_fn(class_logit_probs, true_classes)
then we get a latent space where the labels are much better mixed, at least in a PCA visualization:
However, a separate classifier trained on this latent space for validation can still reach 97% accuracy.
This confuses me, since I would expect it to do much worse.
@ImXman, feel free to test this out, and let's try to figure out why the classification accuracy is still so high...
Maybe I have not trained to convergence: the learning curve is not at all smooth, which may indicate some other problem with how I'm doing things.
Instead of training a new neural-network-based classifier, can you validate with some conventional ML models, like LR, SVM, or RandomForest? I doubt an SVM classifier would accurately classify these 3 classes in the validation set.
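A minimal scikit-learn sketch of that suggestion, assuming z is a NumPy array of latent codes and y holds the integer class labels. The helper name validate_latent, the model hyperparameters, and 5-fold cross-validation are illustrative assumptions; scoring uses macro-F1 so the class imbalance does not inflate the numbers.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC


def validate_latent(z, y, cv=5):
    """Cross-validated macro-F1 of conventional classifiers on latent codes z."""
    models = {
        "LR": LogisticRegression(max_iter=1000),
        "SVM": SVC(kernel="rbf"),
        "RandomForest": RandomForestClassifier(n_estimators=100, random_state=0),
    }
    # Macro-F1 weights every class equally, so always predicting the
    # majority class scores poorly even on imbalanced data.
    return {
        name: cross_val_score(model, z, y, cv=cv, scoring="f1_macro").mean()
        for name, model in models.items()
    }
```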
It's possible that the loss function needs something besides the classification loss. There's currently no "reconstruction loss" or anything like that, so the classifier and encoder just fight each other the whole time, and no other term in the loss function drives the encoder toward anything in particular.
Yeah @ImXman, that's a good point; we could try other classifiers. But here the validation classifier has the same architecture as the one used for training. So I'd think that if, during training, we tell the model "you have to learn to do a bad job with this classifier", then the classifier should not be able to do a good job...
Maybe I should just look at validation accuracy using the classifier that was trained along with the encoder. Maybe it's not fair to train a new validation classifier on the fixed latent space. But it should be fair... maybe the classifier is just too powerful for this data.
Yeah, the classification accuracy of the classifier used during training is only 26%. But this doesn't make much sense either... it's much too low. If you just guessed "class 0" for the whole dataset, you'd get 75% accuracy because of the class imbalance built into the simulated dataset.
I also wonder whether we should use the macro-F1 score instead of overall accuracy for evaluation. You said the data are imbalanced, and it looks like class 0 is dominant.
Alright, if I use a less powerful classifier for validation, like a linear classifier (no hidden layers or nonlinearities), then I get only 72% accuracy with the validation classifier:
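To make the imbalance point concrete, here is a small sketch with made-up labels (75% class 0, matching the proportion mentioned above; the 15/10 split between classes 1 and 2 is an assumption). A majority-class baseline from scikit-learn's DummyClassifier reaches 75% accuracy, while its macro-F1 is only 2/7 ≈ 0.286 because classes 1 and 2 each get an F1 of 0.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, f1_score

# Hypothetical labels with 75% class 0 (the 15/10 split of the rest is made up).
y = np.array([0] * 75 + [1] * 15 + [2] * 10)
X = np.zeros((len(y), 1))  # features are irrelevant to a majority-class baseline

baseline = DummyClassifier(strategy="most_frequent").fit(X, y)
pred = baseline.predict(X)  # predicts class 0 for every sample

acc = accuracy_score(y, pred)                                    # 0.75
per_class_f1 = f1_score(y, pred, average=None, zero_division=0)  # [6/7, 0, 0]
macro_f1 = f1_score(y, pred, average="macro", zero_division=0)   # 2/7
```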
FullyConnectedNetwork(
  (network): Sequential(
    (0): Linear(in_features=32, out_features=3, bias=True)
  )
)
But without gradient reversal, that same kind of linear validation classifier can still get 100% accuracy.
Yeah @ImXman, maybe you can add the macro-F1 score for validation if you get a chance? I don't think accuracy is really the best thing to look at...
I got per-class F1 scores of [0.9168254, 0.64220183, 0.48979592]. Class 0 is the dominant class, so any classifier can assign every sample to class 0 and still get a high accuracy.
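from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

# latent codes from the trained encoder, as a NumPy array
z = get_latent(dataset=dataset, encoder=encoder).detach().cpu().numpy()
# fit and evaluate on the same data (y holds the true class labels)
classifier = LogisticRegression(random_state=0)
classifier.fit(z, y)
pred = classifier.predict(z)
f1_score(y, pred, average=None)  # one F1 score per class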
Well, to me it seems like this could still be working better...
For example, if I do a test where z is truly random, then I get
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

# random latent codes with the same shape as z: they carry no class information
z_random = torch.randn(z.shape).detach().numpy()
classifier = LogisticRegression(random_state=0)
classifier.fit(z_random, data_classes)
pred = classifier.predict(z_random)
f1_score(data_classes, pred, average=None)
array([0.86234357, 0. , 0. ])
Either gradient reversal is not working properly, the training has not converged, or something about the setup of this test is wrong... what do you think @ImXman ?
(We can write a more targeted unit test to ensure that gradient reversal is working. Then we can eliminate that hypothesis...)
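One possible shape for such a targeted test, sketched under assumptions: with alpha = 1, the encoder's parameter gradients should be exactly the negatives of those obtained without the layer, while the classifier's gradients should be unchanged (its forward pass is identical). RevGrad is restated compactly (it saves only alpha and returns input_.view_as(input_)) so the block is self-contained; the helper name grads and the tiny linear networks are hypothetical.

```python
import torch
from torch import nn


class RevGrad(torch.autograd.Function):
    """Identity forward; gradient multiplied by -alpha in the backward pass."""

    @staticmethod
    def forward(ctx, input_, alpha_):
        ctx.save_for_backward(alpha_)
        return input_.view_as(input_)

    @staticmethod
    def backward(ctx, grad_output):
        (alpha_,) = ctx.saved_tensors
        return -grad_output * alpha_, None


def grads(encoder, classifier, data, labels, reverse):
    """Gradients of the classification loss w.r.t. encoder and classifier parameters."""
    encoder.zero_grad()
    classifier.zero_grad()
    z = encoder(data)
    if reverse:
        z = RevGrad.apply(z, torch.tensor(1.0))
    loss = nn.functional.cross_entropy(classifier(z), labels)
    loss.backward()
    enc = [p.grad.clone() for p in encoder.parameters()]
    clf = [p.grad.clone() for p in classifier.parameters()]
    return enc, clf


torch.manual_seed(0)
encoder = nn.Sequential(nn.Linear(128, 32))
classifier = nn.Sequential(nn.Linear(32, 3))
data = torch.randn(16, 128)
labels = torch.randint(0, 3, (16,))

enc_plain, clf_plain = grads(encoder, classifier, data, labels, reverse=False)
enc_rev, clf_rev = grads(encoder, classifier, data, labels, reverse=True)

# Encoder gradients should flip sign; classifier gradients should not change.
encoder_flipped = all(torch.allclose(a, -b) for a, b in zip(enc_plain, enc_rev))
classifier_same = all(torch.allclose(a, b) for a, b in zip(clf_plain, clf_rev))
```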
If you use the truly random z to train a decoder that reconstructs the real data, how well can the decoder do? I think the encoder has a dual role here: not just fooling the classifier, but also preserving the intrinsic structure of the data so that the decoder can reconstruct it from the vector z. I do think gradient reversal may be ineffective.
Well, as you say, in the case of an autoencoder, the decoder would not be able to work if the class information were truly stripped from z.
BUT, for that first test, it is not actually an autoencoder. There is no decoder, and no reconstruction error in the loss function.
So I think something about that test might be suboptimal. In particular, I know that the optimal solution is actually just an encoder with encoder(x) = 0, since a constant latent space carries no class information for the classifier to exploit.
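One way to see why a constant encoder is optimal for the adversarial objective alone (a sketch, assuming a sufficiently flexible classifier q and class frequencies p_k): the classifier can always ignore z and predict the class prior, so the best cross-entropy it can reach is bounded by the entropy of the labels.

```latex
\min_{q}\ \mathbb{E}_{x,y}\!\left[-\log q\big(y \mid z(x)\big)\right]
  \;=\; H(Y \mid Z)
  \;\le\; H(Y) \;=\; -\sum_{k} p_k \log p_k
```

Equality holds if and only if Z is independent of Y, for example when z(x) = 0 for every x. Without a reconstruction term, nothing keeps the encoder from settling on this trivial solution.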
So here is a more complicated test case which is actually an autoencoder. There is a decoder and a reconstruction MSE loss in the loss function. And to avoid the problem you mentioned (where the reconstruction CANNOT work if class information is stripped from z), I have included an artificial "latent space algebra" step, where something different gets added to the latent space for each class.
# do latent space algebra
# this is just something to make it so that the distributions could, in theory,
# be decoded back to the original space even if the latent space removes class information
z_modified = z + torch.ones_like(z) * true_classes.float().unsqueeze(dim=-1)
# reconstruct data
reconstruction = decoder(z_modified)
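Putting the pieces together, one training step of this autoencoder test might look like the sketch below. It assumes unweighted addition of the reconstruction and classification losses, Adam with lr=1e-3, and small Linear/ReLU networks with the dimensions used in this thread; none of these choices is taken from the actual notebook. RevGrad is restated compactly (returning input_.view_as(input_)) so the block runs on its own, and training_step is a hypothetical name.

```python
import torch
from torch import nn


class RevGrad(torch.autograd.Function):
    """Identity forward; gradient multiplied by -alpha in the backward pass."""

    @staticmethod
    def forward(ctx, input_, alpha_):
        ctx.save_for_backward(alpha_)
        return input_.view_as(input_)

    @staticmethod
    def backward(ctx, grad_output):
        (alpha_,) = ctx.saved_tensors
        return -grad_output * alpha_, None


# Hypothetical networks matching the dimensions in this thread.
encoder = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 32))
decoder = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 128))
classifier = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 3))
params = [*encoder.parameters(), *decoder.parameters(), *classifier.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-3)
alpha = torch.tensor(1.0)


def training_step(data, true_classes):
    optimizer.zero_grad()
    z = encoder(data)
    # Adversarial branch: the classifier minimizes this loss, while the reversed
    # gradient pushes the encoder to remove class information from z.
    class_logits = classifier(RevGrad.apply(z, alpha))
    class_loss = nn.functional.cross_entropy(class_logits, true_classes)
    # Latent space algebra: add a class-specific offset before decoding.
    z_modified = z + torch.ones_like(z) * true_classes.float().unsqueeze(dim=-1)
    recon_loss = nn.functional.mse_loss(decoder(z_modified), data)
    loss = recon_loss + class_loss
    loss.backward()
    optimizer.step()
    return recon_loss.item(), class_loss.item()
```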
That way, the reconstruction could, in theory, still be very good, even if the class information is removed from z.
This actually seems to work. Now the validation classifier does about as well as it does on random data.
See here:
Now, using gradient reversal, z looks like this:
And the per-class F1 scores of the validation LogisticRegression classifier are
array([0.86283438, 0. , 0. ])
while totally random data gives
array([0.86267806, 0.01081081, 0. ])
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import train_test_split

# z: latent codes as a NumPy array; data_classes: integer class labels
fpr = dict()
tpr = dict()
roc_auc = dict()
data_classes = np.array(data_classes)
for c in [1, 2]:
    # keep only class 0 and class c, and relabel class c as 1 (binary problem)
    mask = np.logical_or(data_classes == 0, data_classes == c)
    X = z[mask, :]
    y = data_classes[mask]
    y[y == c] = 1
    random_state = np.random.RandomState(0)
    # shuffle and split training and test sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)
    # Learn to predict class c against class 0
    classifier = LogisticRegression(random_state=random_state)
    y_score = classifier.fit(X_train, y_train).decision_function(X_test)
    fpr[c], tpr[c], _ = roc_curve(y_test, y_score, pos_label=classifier.classes_[1])
    roc_auc[c] = auc(fpr[c], tpr[c])

colors = list(sns.color_palette("Paired")) + list(sns.color_palette("hls", 8))
plt.figure()
for c, color in zip([1, 2], colors):
    plt.plot(fpr[c], tpr[c], color=color, lw=1, label=str(c))
plt.grid(False)
plt.plot([0, 1], [0, 1], "k--", lw=1)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("")
legend = plt.legend(loc="lower right", prop={'size': 4.5})
legend.get_frame().set_facecolor('none')
plt.show()
We appear to have something working for now, so I will close this issue.
We want to include a test to ensure gradient reversal is working.
This should be a real test in pytest. For now, I have included a notebook called gradient_reversal_testing.ipynb in the sf_gradient_reversal_testing branch to play around with some ideas on how such a test might work. It is not yet complete or conclusive.
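As a possible starting point for that pytest test, here is a minimal sketch of a test module. It restates the layer compactly instead of importing it, because the package path is not given here; the test names and alpha values are hypothetical. pytest collects the test_* functions automatically, so no pytest import is needed.

```python
import torch


class RevGrad(torch.autograd.Function):
    """Identity forward; gradient multiplied by -alpha in the backward pass."""

    @staticmethod
    def forward(ctx, input_, alpha_):
        ctx.save_for_backward(alpha_)
        return input_.view_as(input_)

    @staticmethod
    def backward(ctx, grad_output):
        (alpha_,) = ctx.saved_tensors
        return -grad_output * alpha_, None


def test_forward_is_identity():
    x = torch.randn(4, 5)
    assert torch.equal(RevGrad.apply(x, torch.tensor(0.5)), x)


def test_backward_reverses_and_scales_gradient():
    for alpha in (0.0, 0.5, 1.0, 2.0):
        x = torch.randn(4, 5, requires_grad=True)
        RevGrad.apply(x, torch.tensor(alpha)).sum().backward()
        # d(sum)/dx is all ones, so the reversed gradient is -alpha everywhere.
        assert torch.allclose(x.grad, torch.full_like(x, -alpha))
```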