matsengrp / torchdms

Analyze deep mutational scanning data with PyTorch
https://matsengrp.github.io/torchdms/

Incorrect abstract base class usage #124

Closed wsdewitt closed 3 years ago

wsdewitt commented 3 years ago

All GGE models in torchdms/model.py subclass the abstract base class torchdms.model.TorchdmsModel: https://github.com/matsengrp/torchdms/blob/419dde3eb6566f3e80cd10bc2cb5990f3a7e922c/torchdms/model.py#L20-L21

There are two problems:

1. Concretization enforcement

Problem

Abstract methods and attributes declared in the base class (with the @abstractmethod decorator) should be concretized in any subclass, but this is not being enforced. For example, the following dummy subclass doesn't perform any concretization, so we expect this to raise a TypeError complaining about several methods that remain abstract, but it does not:

>>> import torchdms.model
>>> class DummyModel(torchdms.model.TorchdmsModel):
...     pass
... 
>>> DummyModel(1, [], '')
DummyModel()

Proposed solution

Correctly defining an abstract base class requires abc.ABCMeta as the metaclass. The simplest way to get this is multiple inheritance from abc.ABC, as follows:

class TorchdmsModel(abc.ABC, nn.Module):

After this change we get

>>> DummyModel(1, [], '')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: Can't instantiate abstract class DummyModel with abstract methods characteristics, fix_gauge, forward, regularization_loss, str_summary, to_latent
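
For illustration, here's a minimal self-contained sketch (not torchdms code; Base and Dummy are made-up names) of why the multiple inheritance works:

import abc
import torch.nn as nn

# abc.ABC supplies the ABCMeta metaclass; since ABCMeta is a subclass of type
# (nn.Module's metaclass), the two bases combine without a metaclass conflict
class Base(abc.ABC, nn.Module):
    @abc.abstractmethod
    def forward(self, x):
        ...

class Dummy(Base):
    pass

Dummy()  # raises TypeError: Can't instantiate abstract class Dummy ...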

2. Missing concretized property

Problem

The abstract base class defines the abstract property characteristics: https://github.com/matsengrp/torchdms/blob/419dde3eb6566f3e80cd10bc2cb5990f3a7e922c/torchdms/model.py#L37-L40 The derived concrete class torchdms.model.Independent has a concretized characteristics property: https://github.com/matsengrp/torchdms/blob/419dde3eb6566f3e80cd10bc2cb5990f3a7e922c/torchdms/model.py#L548-L553 However, attempting to access this attribute raises an error:

>>> import torchdms.model
>>> model = torchdms.model.Independent(3, [1, 2], [None, None], [None, None], None)
>>> model.characteristics
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/williamdewitt/Applications/miniconda2/envs/tdms/lib/python3.8/site-packages/torch/nn/modules/module.py", line 575, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'Independent' object has no attribute 'characteristics'

Interestingly, this only seems to affect model objects of type Independent (and its subclasses Conditional and ConditionalSequential). It does not affect other models that are subclasses of the abstract base class: Linear, FullyConnected, and Escape.

Proposed solution

🤷

jgallowa07 commented 3 years ago

@WSDeWitt How are you calling FullyConnected (or others) such that the characteristics attribute does exist? If I call FullyConnected the same way it gets called in Independent.__init__, given the example you provide above, I get the same error:

>>> import torchdms.model
>>> fc = torchdms.model.FullyConnected(3, [1, 2], [None, None], [None], None) 
>>> fc.characteristics
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/jared/miniconda3/envs/torchdms/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1135, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'FullyConnected' object has no attribute 'characteristics'

wsdewitt commented 3 years ago

Hmm, I can reproduce the attribute error above using the same instantiation, but it works fine with the following:

>>> fc = torchdms.model.FullyConnected(1, [], [], [], '')
>>> fc.characteristics
{'activations': '[]', 'monotonic': None, 'beta_l1_coefficient': 0.0, 'interaction_l1_coefficient': 0.0}

jgallowa07 commented 3 years ago

Right, Independent would also work as in your example above, but it has the assertion len(layers) > 0. Removing that, I can do:

>>> fc = torchdms.model.Independent(1, [], [], [[], []], '') 
>>> fc.characteristics
{'bind_activations': '[]', 'bind_monotonic': None, 'bind_beta_l1_coefficient': 0.0, 'bind_interaction_l1_coefficient': 0.0, 'stab_activations': '[]', 'stab_monotonic': None, 'stab_beta_l1_coefficient': 0.0, 'stab_interaction_l1_coefficient': 0.0}

jgallowa07 commented 3 years ago

So, this was my first indication that it wasn't something funky about how you were adding the modules and layers in Independent (my original thought), and that the use of the abstract base class is just fine. Here's the root of the problem:

Root of the problem #2

First, the problem can be re-created simply like so:

import torch.nn as nn

class CustomModule(nn.Module):
    def __init__(self):
        super().__init__()

    @property
    def property_a(self):
        return self.property_b

m = CustomModule()
print(m.property_a)

gives

AttributeError: 'CustomModule' object has no attribute 'property_a'
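
As a quick debugging check (a sketch, not part of torchdms), calling the property's getter directly bypasses the __getattr__ fallback and exposes the attribute that is really missing:

# skipping attribute lookup on the instance avoids nn.Module.__getattr__,
# so the underlying error about property_b surfaces instead
CustomModule.property_a.fget(m)
# AttributeError: 'CustomModule' object has no attribute 'property_b'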

What's happening is that when we instantiate the object like so (or any other "correct" way):

>>> fc = torchdms.model.FullyConnected(3, [1, 2], [None, None], [None], None)

fc's __getattr__ is now defined by nn.Module.__getattr__, which looks like this:

    def __getattr__(self, name: str) -> Union[Tensor, 'Module']:
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return _parameters[name]
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return _buffers[name]
        if '_modules' in self.__dict__:
            modules = self.__dict__['_modules']
            if name in modules:
                return modules[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))

As you can see, this throws an annoyingly inaccurate error, and only searches for the attribute within those three dict keys behind the conditionals. After a little searching about __getattr__, I found that this is a known issue here. Unfortunately, it won't be fixed because of performance concerns.

Possible Solution

One solution is to simply get rid of the nice @property declaration, and just declare the value of characteristics in the __init__ function of each subclass. This is simple and easy, but won't force the subclasses to concretize an implementation :shrug: but I don't see this as a big deal (unless you have a bunch of people contributing and breaking things). You could always just raise a NotImplementedError from the property in the ABC; then it will force people to set a value for characteristics in the subclasses (see the sketch below).
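
A rough sketch of that NotImplementedError variant (Base and Concrete are made-up names; the subclass overrides the property rather than assigning in __init__, since an assignment would collide with the base-class property):

import torch.nn as nn

class Base(nn.Module):
    @property
    def characteristics(self) -> dict:
        # NotImplementedError (unlike AttributeError) is not swallowed by
        # nn.Module.__getattr__, so a missing override fails loudly
        raise NotImplementedError("subclasses must define characteristics")

class Concrete(Base):
    @property
    def characteristics(self) -> dict:
        return {"monotonic": None, "beta_l1_coefficient": 0.0}

print(Concrete().characteristics)  # {'monotonic': None, 'beta_l1_coefficient': 0.0}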

A more complex solution: it seems that nn.Module provides a way to register attributes here. Using this, you'd have to register the attribute in the abstract base class after wrapping it in an nn.Parameter. I have never seen this used, although I'm sure there are some examples around. A generic sketch follows.
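
A generic sketch of that register-based approach (not torchdms code; it only works for tensor-valued attributes, so a dict like characteristics would need a different representation):

import torch
import torch.nn as nn

class RegisteredModule(nn.Module):
    def __init__(self):
        super().__init__()
        # registered names live in _parameters/_buffers, which is exactly
        # where nn.Module.__getattr__ looks, so attribute access succeeds
        self.register_parameter("beta", nn.Parameter(torch.zeros(3)))
        self.register_buffer("scale", torch.ones(1))

m = RegisteredModule()
print(m.beta.shape, m.scale)  # torch.Size([3]) tensor([1.])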

wsdewitt commented 3 years ago

I think I see what's going on. This error is a manifestation of a known pytorch bug: pytorch/pytorch#49726. Our actual attribute error is occurring at this line, because None has no __name__ attribute: https://github.com/matsengrp/torchdms/blob/419dde3eb6566f3e80cd10bc2cb5990f3a7e922c/torchdms/model.py#L374

If we replace the None activations with a callable (which has a __name__ attribute), we have no problem.

>>> fc = torchdms.model.FullyConnected(3, [1, 2], [sum, sum], [None], None) 
>>> fc.characteristics
{'activations': "['sum', 'sum']", 'monotonic': None, 'beta_l1_coefficient': 0.0, 'interaction_l1_coefficient': 0.0}
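
For reference, a guard like the following (a sketch, not the actual line in model.py) would make the name building tolerant of None activations:

# hypothetical sketch: using getattr avoids the AttributeError on None, which
# nn.Module.__getattr__ would otherwise mask as
# "'FullyConnected' object has no attribute 'characteristics'"
activations = [None, sum]
names = [getattr(a, "__name__", str(a)) for a in activations]  # ['None', 'sum']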

jgallowa07 commented 3 years ago

haha nice timing

wsdewitt commented 3 years ago

Oh haha jinx!

jgallowa07 commented 3 years ago

Yea, as long as you instantiate it in a valid way you should be just fine. But to avoid confusion I'd just ix-nay the decorator. Not sure this is a place I'd use it in any case.

wsdewitt commented 3 years ago

Thanks! I agree we could lose the @property decorators in several places, and have them be class methods. I think the motivation for this was to enforce attribute concretization via the abstract base class (whereas these attributes would otherwise be more naturally defined in __init__).