matsengrp / torchdms

Analyze deep mutational scanning data with PyTorch
https://matsengrp.github.io/torchdms/

No grad for regularization loss #101

Open · matsen opened this issue 3 years ago

matsen commented 3 years ago

I think our implementation of the regularization loss is broken!

Here's how it looks now:

    def regularization_loss(self):
        """L1 penalize single mutant effects, and pre-latent interaction
        weights."""
        penalty = 0.0
        if self.beta_l1_coefficient > 0.0:
            penalty += self.beta_l1_coefficient * self.latent_layer.weight[
                :, : self.input_size
            ].norm(1)
        if self.interaction_l1_coefficient > 0.0:
            for interaction_layer in self.layers[: self.latent_idx]:
                penalty += self.interaction_l1_coefficient * getattr(
                    self, interaction_layer
                ).weight.norm(1)
        return penalty

The thing is, penalty starts out as a plain Python float (0.0), so what we return is a float and we have no way to backprop through it!
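A minimal sketch that copies the control flow of the function above into a toy module, to show when the float comes back. ToyModel, interaction_0 and the layer sizes are hypothetical stand-ins, not torchdms classes. Because penalty is only replaced by a tensor inside the if branches, the return type depends on the coefficients: with both at 0.0 neither branch runs and a plain float is returned, while a nonzero coefficient gives a tensor with a grad_fn. If the configuration used by make test leaves both coefficients at zero (an assumption; the thread does not say), that would explain seeing a float in some runs and a tensor in others.

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical stand-in that mirrors regularization_loss above."""

    def __init__(self, input_size, beta_l1_coefficient=0.0, interaction_l1_coefficient=0.0):
        super().__init__()
        self.input_size = input_size
        self.beta_l1_coefficient = beta_l1_coefficient
        self.interaction_l1_coefficient = interaction_l1_coefficient
        # One hypothetical pre-latent interaction layer, looked up by name.
        self.interaction_0 = nn.Linear(input_size, input_size)
        self.layers = ["interaction_0", "latent_layer"]
        self.latent_idx = 1
        self.latent_layer = nn.Linear(input_size, 1)

    def regularization_loss(self):
        penalty = 0.0  # Python float: stays a float if no branch below runs
        if self.beta_l1_coefficient > 0.0:
            penalty += self.beta_l1_coefficient * self.latent_layer.weight[
                :, : self.input_size
            ].norm(1)
        if self.interaction_l1_coefficient > 0.0:
            for interaction_layer in self.layers[: self.latent_idx]:
                penalty += self.interaction_l1_coefficient * getattr(
                    self, interaction_layer
                ).weight.norm(1)
        return penalty


no_reg = ToyModel(10).regularization_loss()
with_reg = ToyModel(10, beta_l1_coefficient=1e-3).regularization_loss()
print(type(no_reg).__name__)                                  # float
print(type(with_reg).__name__, with_reg.grad_fn is not None)  # Tensor True
```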

I can check this by applying the following patch:

diff --git a/torchdms/analysis.py b/torchdms/analysis.py
index 767791a..044bb55 100644
--- a/torchdms/analysis.py
+++ b/torchdms/analysis.py
@@ -88,6 +88,9 @@ class Analysis:
                 range(targets.shape[1]), loss_decays
             )
         ]
+        qqq = sum(per_target_loss)
+        ppp = self.model.regularization_loss()
+        breakpoint()
         return sum(per_target_loss) + self.model.regularization_loss()

     def train(

If we print qqq, it's a tensor, but ppp is a float.
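For context on what the float does during training, a minimal sketch with hypothetical stand-in tensors (w and per_target_loss are not from torchdms): adding a float to the tensor sum still yields a tensor, so backward on the total loss runs; the float simply contributes no gradient. What fails is calling .backward() on the float by itself.

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)  # hypothetical parameter
# Hypothetical stand-ins for the per-target losses computed in analysis.py.
per_target_loss = [(w ** 2).sum(), w.abs().sum()]

qqq = sum(per_target_loss)  # builtin sum starts from 0; the result is a tensor
ppp = 0.0                   # a float, like the regularization loss in the debugger

total = qqq + ppp           # tensor + float -> tensor, so the graph is kept
total.backward()            # works: gradients come only from the data terms
print(type(ppp).__name__, type(total).__name__)  # float Tensor
print(hasattr(ppp, "backward"))                  # False
```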

matsen commented 3 years ago

I pushed this version of the regularization loss to the 101-no-regularization-grad branch:

    def regularization_loss(self):
        """L1 penalize single mutant effects, and pre-latent interaction
        weights."""
        penalty = self.beta_l1_coefficient * self.latent_layer.weight[
            :, : self.input_size
        ].norm(1)
        if self.interaction_l1_coefficient > 0.0:
            for interaction_layer in self.layers[: self.latent_idx]:
                # Add each layer's L1 norm directly; torch.sum() does not
                # accept a Python list.
                penalty += self.interaction_l1_coefficient * getattr(
                    self, interaction_layer
                ).weight.norm(1)
        return penalty
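A note on totaling per-layer penalties: torch.sum expects a Tensor as its input, not a Python list, so a call like torch.sum([norm]) raises a TypeError. Inside a branch guarded by interaction_l1_coefficient > 0.0, that error only shows up once the coefficient is nonzero. A sketch of two gradient-preserving alternatives, using hypothetical toy layers in place of getattr(self, interaction_layer):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical pre-latent interaction layers.
interaction_layers = [nn.Linear(4, 4), nn.Linear(4, 4)]
coefficient = 1e-2

# torch.sum wants a Tensor, so passing a Python list fails.
try:
    torch.sum([interaction_layers[0].weight.norm(1)])
    list_sum_failed = False
except TypeError:
    list_sum_failed = True

# Option 1: accumulate scalar tensors, starting from a tensor rather than 0.0.
penalty = torch.zeros(())
for layer in interaction_layers:
    penalty = penalty + coefficient * layer.weight.norm(1)

# Option 2: stack the per-layer norms into one tensor, then sum it.
penalty_stacked = coefficient * torch.stack(
    [layer.weight.norm(1) for layer in interaction_layers]
).sum()

print(list_sum_failed, torch.allclose(penalty, penalty_stacked))  # True True
```

Starting from torch.zeros(()) instead of 0.0 also guarantees a tensor comes back even if the loop body never runs.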

This version gives

(Pdb++) ppp
tensor(0., grad_fn=<MulBackward0>)

when I run make test.
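Why this version returns a tensor even with a zero coefficient: the beta term is computed unconditionally, and multiplying a tensor that requires grad by the float 0.0 still records a MulBackward0 node. Backpropagating through it gives all-zero gradients, so a zero coefficient is harmless. A minimal sketch on a hypothetical stand-in weight:

```python
import torch

torch.manual_seed(0)
weight = torch.randn(2, 5, requires_grad=True)  # hypothetical stand-in for latent_layer.weight
beta_l1_coefficient = 0.0

penalty = beta_l1_coefficient * weight.norm(1)
print(penalty)  # tensor(0., grad_fn=<MulBackward0>)

penalty.backward()
print(weight.grad.abs().sum().item())  # 0.0
```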

wsdewitt commented 3 years ago

@matsen Hmm, I can't reproduce the float issue:

>>> from torchdms.model import FullyConnected
>>> model = FullyConnected(10, [2], [None], [None], None, beta_l1_coefficient=1e-3)
>>> loss = model.regularization_loss()
>>> print(loss)
tensor([14.5608], grad_fn=<AddBackward0>)

matsen commented 3 years ago

That's strange.

Did you try dropping into the debugger as in my original report?

wsdewitt commented 3 years ago

Yes, the issue does surface in the debugger:

(Pdb++) print(ppp)
0.0
(Pdb++) ppp.backward()
*** AttributeError: 'float' object has no attribute 'backward'

matsen commented 3 years ago

Fascinating. And sorry if I sent you on a wild goose chase.

How do you propose moving forward?

wsdewitt commented 3 years ago

I still don't understand the behavior, so no proposal yet. I'll keep poking! 👨‍🏭