broadinstitute / CellCap

Interpret perturbation responses from scRNA-seq perturbation experiments
BSD 3-Clause "New" or "Revised" License

Test gradient reversal #5

Closed: sjfleming closed this issue 9 months ago

sjfleming commented 1 year ago

We want to include a test to ensure gradient reversal is working.

This should be a real test in pytest.

For now I have included a notebook called gradient_reversal_testing.ipynb in the sf_gradient_reversal_testing branch to play around with some ideas on how such a test might work.

It is not yet complete or conclusive.

sjfleming commented 1 year ago

Results so far:

[image not preserved]

FullyConnectedNetwork(
  (network): Sequential(
    (0): Linear(in_features=32, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=3, bias=True)
  )
)
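The printed module is a two-layer MLP: 32 latent inputs, 16 hidden units with ReLU, and 3 class logits. The notebook's `FullyConnectedNetwork` class is not shown in this thread, so `make_mlp` below is a hypothetical stand-in built from plain `torch.nn.Sequential`. It only reproduces the printed layer shapes, including the single-`Linear` validation classifier that appears later in the thread.

```python
import torch
from torch import nn


def make_mlp(in_features: int, out_features: int, hidden: list[int]) -> nn.Sequential:
    """Hypothetical stand-in for FullyConnectedNetwork: Linear+ReLU blocks, then a Linear head."""
    layers: list[nn.Module] = []
    sizes = [in_features, *hidden]
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        layers += [nn.Linear(n_in, n_out), nn.ReLU()]
    # final layer outputs raw class logits (no activation)
    layers.append(nn.Linear(sizes[-1], out_features))
    return nn.Sequential(*layers)


classifier = make_mlp(32, 3, hidden=[16])   # Linear(32->16), ReLU, Linear(16->3)
linear_probe = make_mlp(32, 3, hidden=[])   # a single Linear(32->3)

z = torch.randn(8, 32)                       # a batch of 8 latent vectors
logits = classifier(z)                       # shape (8, 3)
```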
sjfleming commented 1 year ago

The loss function is just how well we can classify: here we use a CrossEntropyLoss
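`torch.nn.CrossEntropyLoss` applies `log_softmax` internally, so the classifier should output raw, unnormalized logits (no softmax layer) paired with integer class labels. A small check of that equivalence on made-up tensors:

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
logits = torch.randn(5, 3)                # raw classifier outputs: 5 cells, 3 classes
labels = torch.tensor([0, 0, 1, 2, 0])    # integer class labels

loss = nn.CrossEntropyLoss()(logits, labels)

# CrossEntropyLoss == mean negative log-likelihood under log_softmax(logits)
manual = F.nll_loss(F.log_softmax(logits, dim=-1), labels)
by_hand = -F.log_softmax(logits, dim=-1)[torch.arange(5), labels].mean()
```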

sjfleming commented 1 year ago

We see that when we do NOT have a gradient reversal layer, the latent space pulls the classes apart to make them easier to classify accurately: [image not preserved]

And if, for validation, we train a separate classifier on the learned latent space, it reaches 100% accuracy.
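This validation step amounts to freezing the trained encoder, computing `z` once, fitting a fresh classifier on part of `z`, and scoring it on held-out cells. The sketch below makes assumptions: the simulated data and the randomly initialized `encoder` are hypothetical stand-ins for the notebook's, and scikit-learn's `LogisticRegression` plays the role of the validation classifier.

```python
import numpy as np
import torch
from torch import nn
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

torch.manual_seed(0)
rng = np.random.default_rng(0)

# hypothetical simulated data: 3 well-separated classes in 32 dimensions
n_per_class, n_features = 200, 32
means = 3.0 * rng.standard_normal((3, n_features))
labels = np.repeat(np.arange(3), n_per_class)
x = means[labels] + rng.standard_normal((labels.size, n_features))

encoder = nn.Linear(n_features, n_features)  # stand-in for the trained encoder


def validation_accuracy(encoder: nn.Module, x: np.ndarray, labels: np.ndarray) -> float:
    """Fit a new classifier on a frozen latent space and score it on held-out cells."""
    encoder.eval()
    with torch.no_grad():  # the encoder is frozen; no gradients are needed
        z = encoder(torch.as_tensor(x, dtype=torch.float32)).numpy()
    z_train, z_test, y_train, y_test = train_test_split(
        z, labels, test_size=0.5, stratify=labels, random_state=0
    )
    probe = LogisticRegression(max_iter=1000).fit(z_train, y_train)
    return probe.score(z_test, y_test)  # held-out accuracy


acc = validation_accuracy(encoder, x, labels)
```

Holding out half the cells matters: scoring the probe on the same cells it was fit on can overstate how much class information `z` carries.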

sjfleming commented 1 year ago

We see that when we DO use a gradient reversal layer, implemented like this

import torch

# from https://github.com/janfreyberg/pytorch-revgrad/blob/master/src/pytorch_revgrad/module.py

class RevGrad(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input_, alpha_):
        ctx.save_for_backward(input_, alpha_)
        output = input_
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        grad_input = None
        _, alpha_ = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_input = -grad_output * alpha_
        return grad_input, None

class GradientReversal(torch.nn.Module):

    def __init__(self, alpha=1., *args, **kwargs):
        """A gradient reversal layer.
        This layer has no parameters, and simply reverses the gradient
        in the backward pass.
        """
        super().__init__(*args, **kwargs)
        self._alpha = torch.tensor(alpha, requires_grad=False)

    def forward(self, input_):
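        return RevGrad.apply(input_, self._alpha)


# usage in the training loop (encoder, classifier, loss_fn, data, true_classes,
# and do_gradient_reversal are defined elsewhere in the notebook)

# set up gradient reversal layer
gradient_reversal = GradientReversal()

# encode data into latent space
z = encoder(data)

# perform gradient reversal
if do_gradient_reversal:
    z = gradient_reversal(z)

# classify data points based on latent space (outputs are raw logits)
class_logits = classifier(z)

# loss function: torch.nn.CrossEntropyLoss(), which applies log-softmax internally
loss = loss_fn(class_logits, true_classes)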

then we get a latent space where the labels are much better mixed, at least in a PCA visualization:

[Figure: PCA of the latent space with gradient reversal; image not preserved]
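A 2-D PCA view like this one projects `z` onto its two highest-variance directions. The sketch below assumes scikit-learn, with a random matrix standing in for the notebook's latent space. Because PCA shows only those two directions, classes that look mixed in the plot can still be separable in the remaining dimensions, which may help explain how a validation classifier can still tell them apart.

```python
import numpy as np
from sklearn.decomposition import PCA


def pca_2d(z: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Project latent vectors onto their first two principal components."""
    pca = PCA(n_components=2)
    coords = pca.fit_transform(z)          # shape (n_cells, 2)
    return coords, pca.explained_variance_ratio_


rng = np.random.default_rng(0)
z = rng.standard_normal((500, 32))         # stand-in for the encoder output
coords, var_ratio = pca_2d(z)
```

To draw the plot, something like `plt.scatter(coords[:, 0], coords[:, 1], c=labels, s=2)` colors each cell by its class. Checking `var_ratio` shows how much of the latent variance the plot actually covers.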

However, a separate classifier trained on this latent space for validation still seems to reach 97% accuracy.

This is confusing to me at this point, as I would think it should do much worse.

sjfleming commented 1 year ago

File is here https://github.com/broadinstitute/single-cell-compositional-perturbations/blob/sf_gradient_reversal_testing/gradient_reversal_testing.ipynb

@ImXman feel free to test this out and let's try to figure out why the classification accuracy is still so high...

sjfleming commented 1 year ago

Maybe I have not trained to convergence. The learning curve is not at all smooth. That may indicate some other problem with how I'm doing things.

ImXman commented 1 year ago

Instead of training a new neural-network-based classifier, can you validate with some conventional ML models, like LR, SVM, or RandomForest? I doubt an SVM classifier would accurately classify these 3 classes in the validation set.
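A sketch of that suggestion with scikit-learn: cross-validate a few conventional classifiers on the latent space and report macro-averaged F1, so the dominant class cannot inflate the score. The model choices and hyperparameters are illustrative defaults, and the random `z_random` and imbalanced `labels` are placeholders for the notebook's latent space and class labels.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.svm import SVC


def probe_latent_space(z: np.ndarray, labels: np.ndarray) -> dict[str, float]:
    """Mean 5-fold cross-validated macro-F1 for several off-the-shelf classifiers."""
    models = {
        "logistic_regression": LogisticRegression(max_iter=1000),
        "svm_rbf": SVC(kernel="rbf"),
        "random_forest": RandomForestClassifier(n_estimators=100, random_state=0),
    }
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
    return {
        name: cross_val_score(model, z, labels, cv=cv, scoring="f1_macro").mean()
        for name, model in models.items()
    }


rng = np.random.default_rng(0)
labels = rng.choice(3, size=600, p=[0.75, 0.15, 0.10])   # imbalanced, hypothetical proportions
z_random = rng.standard_normal((600, 32))                # contains no class information
scores = probe_latent_space(z_random, labels)
```

If gradient reversal has removed the class information, all three scores on the real `z` should land near the scores on a random latent space like this one.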

sjfleming commented 1 year ago

It's possible that the loss function needs something else... not just the classification term. There's no "reconstruction loss" or anything like that currently. So the classifier and encoder are just fighting each other all the time, and there's no other part of the loss function to drive it toward anything in particular.

sjfleming commented 1 year ago

Yeah @ImXman that's a good point, we could try other classifiers. But here the validation classifier is just the same architecture as the one used for training. So I'd think that if, during training, we tell the model "you have to learn to do a bad job with this classifier", then the classifier should not be able to do a good job...

Maybe I should just look at validation accuracy using the initial trained classifier. Maybe it's not fair to train a new validation classifier using the fixed latent space. But it should be fair....... maybe it's just too powerful for this data.

sjfleming commented 1 year ago

Yeah the classification accuracy from the classifier used during training is only 26%. But this also doesn't make much sense... it's much too low... if you just guessed "class 0" for the whole dataset, you'd get a 75% accuracy due to the class imbalance built into the simulated dataset.

ImXman commented 1 year ago

I also wonder whether we should evaluate with the macro-F1 score instead of overall accuracy. You said the data is imbalanced; it looks like class 0 is dominant.
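To see why accuracy misleads here, consider the trivial classifier that always predicts class 0. With class proportions of 75% / 15% / 10% (the 75% matches the figure mentioned above; the other two are hypothetical), it gets 75% accuracy. Its class-0 precision is p = 0.75 and its recall is 1, so its class-0 F1 is 2p/(1+p) ≈ 0.857, its F1 is 0 for the other classes, and its macro-F1 is only about 0.29. A per-class F1 vector of roughly [0.86, 0, 0] is therefore what a classifier that has learned nothing beyond the majority class looks like.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, f1_score

# hypothetical imbalanced labels: 75% class 0, 15% class 1, 10% class 2
labels = np.repeat([0, 1, 2], [750, 150, 100])
features = np.zeros((labels.size, 1))  # ignored by the dummy model

baseline = DummyClassifier(strategy="most_frequent").fit(features, labels)
pred = baseline.predict(features)       # every cell is predicted as class 0

acc = accuracy_score(labels, pred)
per_class_f1 = f1_score(labels, pred, average=None, zero_division=0)
macro_f1 = f1_score(labels, pred, average="macro", zero_division=0)
```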

sjfleming commented 1 year ago

Yeah, alright: if I use a less powerful classifier for validation, like a linear classifier (no hidden layers or nonlinearities), then I only get 72% accuracy with the validation classifier.

FullyConnectedNetwork(
  (network): Sequential(
    (0): Linear(in_features=32, out_features=3, bias=True)
  )
)

But without gradient reversal, that same kind of linear validation classifier can still get 100% accuracy.

sjfleming commented 1 year ago

Yeah @ImXman, maybe you can add the macro-F1 score for validation if you get a chance? I didn't think accuracy was really the best thing to look at...

ImXman commented 1 year ago

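I got per-class F1 scores of [0.9168254, 0.64220183, 0.48979592]. Class 0 is the dominant class, so a classifier can get high accuracy simply by assigning every sample to class 0.

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

# get_latent is a notebook helper; y holds the class label of each cell
z = get_latent(dataset=dataset, encoder=encoder).detach().cpu().numpy()
classifier = LogisticRegression(random_state=0)
classifier.fit(z, y)
pred = classifier.predict(z)  # note: scored on the same cells it was trained on
f1_score(y, pred, average=None)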
sjfleming commented 1 year ago

Well, to me it seems like this could still be working better... for example, if I do a test where z is truly random, then I get

import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

z_random = torch.randn(z.shape).detach().numpy()
classifier = LogisticRegression(random_state=0)
classifier.fit(z_random, data_classes)
pred = classifier.predict(z_random)
f1_score(data_classes, pred, average=None)
array([0.86234357, 0.        , 0.        ])
sjfleming commented 1 year ago

Either gradient reversal is not working properly, the training has not converged, or something about the setup of this test is wrong... what do you think @ImXman ?

sjfleming commented 1 year ago

(We can write a more targeted unit test to ensure that gradient reversal is working. Then we can eliminate that hypothesis...)
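A sketch of such a targeted pytest test file. It checks that the layer is the identity in the forward pass, that the backward pass multiplies the incoming gradient by `-alpha`, and that an upstream encoder receives exactly the opposite gradient. So that it runs on its own, it carries a compact copy of a gradient reversal layer (`_RevGrad`, returning `x.view_as(x)` from `forward`, a common pattern for custom `autograd.Function`s); in the repository the tests would import the project's `GradientReversal` instead. The test names are hypothetical.

```python
import torch


class _RevGrad(torch.autograd.Function):
    """Identity in the forward pass; gradient multiplied by -alpha in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # no gradient with respect to alpha
        return -ctx.alpha * grad_output, None


class GradientReversal(torch.nn.Module):
    def __init__(self, alpha: float = 1.0):
        super().__init__()
        self.alpha = alpha

    def forward(self, x):
        return _RevGrad.apply(x, self.alpha)


def test_forward_is_identity():
    x = torch.randn(4, 5)
    for alpha in (1.0, 0.5):
        assert torch.equal(GradientReversal(alpha)(x), x)


def test_backward_negates_and_scales_gradient():
    for alpha in (1.0, 0.5):
        x = torch.randn(4, 5, requires_grad=True)
        weights = torch.randn(4, 5)
        # d/dx sum(weights * x) = weights, so after reversal it must be -alpha * weights
        (GradientReversal(alpha)(x) * weights).sum().backward()
        assert torch.allclose(x.grad, -alpha * weights)


def test_encoder_receives_reversed_gradient():
    torch.manual_seed(0)
    encoder = torch.nn.Linear(3, 2)
    x = torch.randn(6, 3)
    grads = []
    for layer in (torch.nn.Identity(), GradientReversal(1.0)):
        encoder.zero_grad()
        layer(encoder(x)).pow(2).sum().backward()
        grads.append(encoder.weight.grad.clone())
    # same loss, opposite gradient for the upstream encoder weights
    assert torch.allclose(grads[1], -grads[0])
```

Passing these tests rules out a broken reversal layer, which leaves convergence or the test setup as the remaining explanations.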

ImXman commented 1 year ago

If you use the truly random z to train a decoder to reconstruct the real data, how well can the decoder do? I think the encoder has dual roles here: not just fooling the classifier, but also preserving the intrinsic structure so that the decoder can reconstruct the data from the vector z. I suspect gradient reversal may be ineffective here.

sjfleming commented 1 year ago

Well, so, as you say, the decoder would not be able to work if the class information were truly stripped from z in the case of an autoencoder.

BUT, for that first test, it is not actually an autoencoder. There is no decoder, and no reconstruction error in the loss function.

So I think something about that test might be kind of sub-optimal. In particular, I know that the optimal solution is actually just to have an encoder which does encoder(x) = 0.
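To spell out why: if the encoder maps every cell to the same point, the classifier sees identical inputs, so the best it can do is output the class frequencies p_k for every cell, and its cross-entropy equals the label entropy H = -sum_k p_k log p_k. For any other z, the best achievable cross-entropy is at most H, so a collapsed encoder already forces the largest possible classifier loss, and a classification-only objective gives no reason to keep anything else in z. A numerical check, with hypothetical class proportions of 75% / 15% / 10%:

```python
import torch
from torch import nn

# hypothetical class proportions, and labels that match them exactly
p = torch.tensor([0.75, 0.15, 0.10])
labels = torch.repeat_interleave(torch.arange(3), torch.tensor([750, 150, 100]))

# constant latent space -> the classifier outputs one logit vector for all cells;
# the optimal choice is log(p), so that softmax(logits) equals the class frequencies
logits = torch.log(p).expand(labels.numel(), 3)
ce_optimal = nn.CrossEntropyLoss()(logits, labels)

# any other constant output does worse, e.g. uniform logits give log(3)
ce_uniform = nn.CrossEntropyLoss()(torch.zeros(labels.numel(), 3), labels)

entropy = -(p * torch.log(p)).sum()   # H(Y) in nats
```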

So here is a more complicated test case which is actually an autoencoder. There is a decoder and a reconstruction MSE loss in the loss function. And to avoid the problem you mentioned (where the reconstruction CANNOT work if class information is stripped from z), I have included an artificial "latent space algebra" step, where something different gets added to the latent space for each class.

    # do latent space algebra
    # this is just something to make it so that the distributions could, in theory,
    # be decoded back to the original space even if the latent space removes class information
    z_modified = z + torch.ones_like(z) * true_classes.float().unsqueeze(dim=-1)

    # reconstruct data
    reconstruction = decoder(z_modified)

That way, the reconstruction could, in theory, still be very good, even if the class information is removed from z.

This actually seems to work. Now the validation classifier does no better on z than it does on a random latent space.

See here:

https://github.com/broadinstitute/single-cell-compositional-perturbations/blob/sf_gradient_reversal_testing/gradient_reversal_autoencoder_testing.ipynb
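The training step for this autoencoder test can be sketched as follows. The layer sizes, optimizer, learning rate, and simulated data are assumptions for illustration, and `_RevGrad` is a compact stand-in for the `GradientReversal` layer shown earlier. Only the structure follows the description: an encoder, gradient reversal before the classifier, a per-class shift added to `z` before decoding, and a reconstruction MSE plus classification loss. One optimizer updates all three networks: the classifier minimizes the cross-entropy, while the reversed gradient pushes the encoder to maximize it.

```python
import torch
from torch import nn


class _RevGrad(torch.autograd.Function):
    """Compact gradient reversal: identity forward, -alpha * grad backward."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.alpha * grad_output, None


torch.manual_seed(0)
n_genes, n_latent, n_classes = 32, 8, 3

# hypothetical simulated data: class-specific means plus noise
true_classes = torch.randint(0, n_classes, (512,))
data = 2.0 * torch.randn(n_classes, n_genes)[true_classes] + torch.randn(512, n_genes)

encoder = nn.Sequential(nn.Linear(n_genes, 16), nn.ReLU(), nn.Linear(16, n_latent))
decoder = nn.Sequential(nn.Linear(n_latent, 16), nn.ReLU(), nn.Linear(16, n_genes))
classifier = nn.Sequential(nn.Linear(n_latent, 16), nn.ReLU(), nn.Linear(16, n_classes))
params = [*encoder.parameters(), *decoder.parameters(), *classifier.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-3)


def training_step(alpha: float = 1.0) -> tuple[float, float]:
    optimizer.zero_grad()
    z = encoder(data)
    # classification through the reversal layer: the encoder is pushed to hide the class
    class_logits = classifier(_RevGrad.apply(z, alpha))
    class_loss = nn.functional.cross_entropy(class_logits, true_classes)
    # latent space algebra: class-dependent shift (same as the ones_like form above)
    z_modified = z + true_classes.float().unsqueeze(-1)
    recon_loss = nn.functional.mse_loss(decoder(z_modified), data)
    (recon_loss + class_loss).backward()
    optimizer.step()
    return recon_loss.item(), class_loss.item()


history = [training_step() for _ in range(200)]
```

After training, the validation probes from earlier in the thread can be run on `encoder(data)` to check whether class information has been removed.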

sjfleming commented 1 year ago

Now, using gradient reversal, z looks like this: [image not preserved]

And the per-class F1 scores of the validation LogisticRegression classifier are

array([0.86283438, 0.        , 0.        ])

while totally random data gives

array([0.86267806, 0.01081081, 0.        ])
ImXman commented 1 year ago


ROC curves for a logistic regression trained to separate each minority class (1 and 2) from class 0, fit on half of z and scored on the other half:

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import train_test_split

# z: latent space as a NumPy array (n_cells, n_latent); data_classes: class label per cell
fpr, tpr, roc_auc = {}, {}, {}
data_classes = np.array(data_classes)
for c in [1, 2]:
    # keep only cells from class 0 or class c, and relabel class c as 1
    mask = np.logical_or(data_classes == 0, data_classes == c)
    X = z[mask, :]
    y = (data_classes[mask] == c).astype(int)

    # shuffle and split training and test sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)

    # learn to predict class c against class 0
    classifier = LogisticRegression(random_state=0)
    y_score = classifier.fit(X_train, y_train).decision_function(X_test)

    fpr[c], tpr[c], _ = roc_curve(y_test, y_score, pos_label=classifier.classes_[1])
    roc_auc[c] = auc(fpr[c], tpr[c])

colors = list(sns.color_palette("Paired")) + list(sns.color_palette("hls", 8))

plt.figure()
for c, color in zip([1, 2], colors):
    plt.plot(fpr[c], tpr[c], color=color, lw=1, label=f"class {c} vs 0 (AUC = {roc_auc[c]:.2f})")
plt.grid(False)
plt.plot([0, 1], [0, 1], "k--", lw=1)  # chance diagonal
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
legend = plt.legend(loc="lower right", prop={"size": 8})
legend.get_frame().set_facecolor("none")
plt.show()
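As a single-number companion to these curves, scikit-learn's `roc_auc_score` can compute a held-out one-vs-rest AUC across all three classes; values near 0.5 mean a linear classifier finds no class information in `z`. A sketch, with random and class-shifted latent spaces and hypothetical imbalanced labels standing in for the real ones:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split


def heldout_ovr_auc(z: np.ndarray, labels: np.ndarray) -> float:
    """Macro one-vs-rest ROC AUC of a logistic regression on held-out cells."""
    z_train, z_test, y_train, y_test = train_test_split(
        z, labels, test_size=0.5, stratify=labels, random_state=0
    )
    clf = LogisticRegression(max_iter=1000).fit(z_train, y_train)
    probs = clf.predict_proba(z_test)  # shape (n_test, n_classes)
    return roc_auc_score(y_test, probs, multi_class="ovr", average="macro")


rng = np.random.default_rng(0)
labels = rng.choice(3, size=2000, p=[0.75, 0.15, 0.10])
# no class information at all
auc_random = heldout_ovr_auc(rng.standard_normal((2000, 32)), labels)
# class k shifts latent coordinate k by 3, so the classes are easy to tell apart
auc_signal = heldout_ovr_auc(rng.standard_normal((2000, 32)) + 3.0 * np.eye(3, 32)[labels], labels)
```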
sjfleming commented 9 months ago

We appear to have something working for now, so I will close this.