Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Cannot instantiate custom object from config file using Lightning CLI #17400

Closed: ggggg111 closed this issue 1 year ago

ggggg111 commented 1 year ago

Bug description

Lightning CLI cannot correctly instantiate my custom object from the configuration file. I'll focus on my StandardGraph object.

For example, in my config file:

model:
  class_path: model.STGCNpp
  init_args:
    criterion:
      class_path: torch.nn.CrossEntropyLoss
    graph:
      class_path: model.utils.StandardGraph
    in_channels: 2

Where my class is:

import numpy as np

class StandardGraph:
    def __init__(self):
        self.edges = [
            (0, 1),
            (0, 2),
            (1, 3),
            (2, 4),
            (5, 6),
            (5, 7),
            (5, 11),
            (6, 8),
            (6, 12),
            (7, 9),
            (8, 10),
            (11, 12),
            (11, 13),
            (12, 14),
            (13, 15),
            (14, 16),
        ]

        self.num_vertices = 17

        self.adj_matrix = self._get_adj_matrix(self.edges, self.num_vertices)

    def _get_adj_matrix(self, edges, num_vertices, self_connections=True):
        adj_matrix = np.zeros((num_vertices, num_vertices))

        adj_matrix[edges[:, 0], edges[:, 1]] = 1
        adj_matrix[edges[:, 1], edges[:, 0]] = 1

        if self_connections:
            np.fill_diagonal(adj_matrix, 1)

        return adj_matrix

I get the following error:

Traceback (most recent call last):
  File "/root/run.py", line 17, in <module>
    main()
  File "run.py", line 13, in main
    cli = LightningCLI()
  File "/opt/conda/envs/test/lib/python3.9/site-packages/lightning/pytorch/cli.py", line 350, in __init__
    self.instantiate_classes()
  File "/opt/conda/envs/test/lib/python3.9/site-packages/lightning/pytorch/cli.py", line 499, in instantiate_classes
    self.config_init = self.parser.instantiate_classes(self.config)
  File "/opt/conda/envs/test/lib/python3.9/site-packages/jsonargparse/deprecated.py", line 131, in patched_instantiate_classes
    cfg = self._unpatched_instantiate_classes(cfg, **kwargs)
  File "/opt/conda/envs/test/lib/python3.9/site-packages/jsonargparse/core.py", line 1144, in instantiate_classes
    cfg[subcommand] = subparser.instantiate_classes(cfg[subcommand], instantiate_groups=instantiate_groups)
  File "/opt/conda/envs/test/lib/python3.9/site-packages/jsonargparse/deprecated.py", line 131, in patched_instantiate_classes
    cfg = self._unpatched_instantiate_classes(cfg, **kwargs)
  File "/opt/conda/envs/test/lib/python3.9/site-packages/jsonargparse/core.py", line 1135, in instantiate_classes
    parent[key] = component.instantiate_classes(value)
  File "/opt/conda/envs/test/lib/python3.9/site-packages/jsonargparse/typehints.py", line 474, in instantiate_classes
    value[num] = adapt_typehints(val, self._typehint, instantiate_classes=True, sub_add_kwargs=sub_add_kwargs)
  File "/opt/conda/envs/test/lib/python3.9/site-packages/jsonargparse/typehints.py", line 785, in adapt_typehints
    val = adapt_class_type(val, serialize, instantiate_classes, sub_add_kwargs, prev_val=prev_val)
  File "/opt/conda/envs/test/lib/python3.9/site-packages/jsonargparse/typehints.py", line 996, in adapt_class_type
    return val_class(**{**init_args, **dict_kwargs})
TypeError: __init__() got an unexpected keyword argument 'graph.class_path'
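The message names 'graph.class_path' rather than the actual failure. One quick way to check whether the class itself is at fault is to resolve the class_path string from the config and call the class directly, outside the CLI. The helper below, instantiate_from_class_path, is a hypothetical debugging sketch (not a Lightning or jsonargparse API) that uses only importlib; with the original StandardGraph, calling it on "model.utils.StandardGraph" would be expected to raise the underlying TypeError instead of the CLI's message.

```python
import importlib
from typing import Any, Optional


def instantiate_from_class_path(class_path: str, init_args: Optional[dict] = None) -> Any:
    """Hypothetical debugging helper: import `class_path` and call it with `init_args`.

    Any exception raised by the class's __init__ propagates unchanged, so the
    real error is visible instead of a message produced by the config parser.
    """
    module_name, _, class_name = class_path.rpartition(".")
    if not module_name:
        raise ValueError(f"class_path must be a dotted path, got {class_path!r}")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**(init_args or {}))


if __name__ == "__main__":
    # Stand-in example using the standard library; in the issue's project this
    # would be instantiate_from_class_path("model.utils.StandardGraph").
    half = instantiate_from_class_path("fractions.Fraction", {"numerator": 1, "denominator": 2})
    print(half)  # 1/2
```

Running the class's constructor by hand like this is usually enough to tell a bug in user code apart from a problem in the CLI configuration.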

What is weird is that if I define the class as follows, without the adjacency matrix:

class StandardGraph:
    def __init__(self):
        self.edges = [
            (0, 1),
            (0, 2),
            (1, 3),
            (2, 4),
            (5, 6),
            (5, 7),
            (5, 11),
            (6, 8),
            (6, 12),
            (7, 9),
            (8, 10),
            (11, 12),
            (11, 13),
            (12, 14),
            (13, 15),
            (14, 16),
        ]

        self.num_vertices = 17

the error disappears...

I'm new to Lightning, and I hope I can get some help. Thank you for your hard work.

What version are you seeing the problem on?

2.0+

How to reproduce the bug

No response

Environment

Current environment
* CUDA: - GPU: - NVIDIA GeForce RTX 3090 - NVIDIA GeForce RTX 3090 - NVIDIA GeForce RTX 3090 - NVIDIA GeForce RTX 3090 - available: True - version: 11.8
* Lightning: - lightning: 2.0.1.post0 - lightning-api-access: 0.0.5 - lightning-cloud: 0.5.32 - lightning-fabric: 2.0.1 - lightning-utilities: 0.8.0 - pytorch-lightning: 2.0.0 - torch: 2.0.0+cu118 - torch-cluster: 1.6.1+pt20cu118 - torch-geometric: 2.3.0 - torch-scatter: 2.1.1+pt20cu118 - torch-sparse: 0.6.17+pt20cu118 - torch-spline-conv: 1.2.2+pt20cu118 - torchaudio: 2.0.1+cu118 - torchmetrics: 0.11.4 - torchvision: 0.15.1+cu118
* Packages: - absl-py: 1.4.0 - aiobotocore: 2.4.2 - aiohttp: 3.8.4 - aioitertools: 0.11.0 - aiosignal: 1.3.1 - altair: 4.2.2 - antlr4-python3-runtime: 4.9.3 - anyio: 3.6.2 - arrow: 1.2.3 - asttokens: 2.2.1 - async-timeout: 4.0.2 - attrs: 22.2.0 - beautifulsoup4: 4.12.0 - bleach: 6.0.0 - blessed: 1.20.0 - blinker: 1.6.1 - bokeh: 2.4.3 - botocore: 1.27.59 - cachetools: 5.3.0 - certifi: 2022.12.7 - charset-normalizer: 2.1.1 - click: 8.1.3 - cmake: 3.25.0 - colorama: 0.4.6 - contourpy: 1.0.7 - croniter: 1.3.8 - cycler: 0.11.0 - dateutils: 0.6.12 - decorator: 5.1.1 - deepdiff: 6.3.0 - dnspython: 2.3.0 - docker: 6.0.1 - docstring-parser: 0.15 - einops: 0.6.0 - email-validator: 1.3.1 - entrypoints: 0.4 - executing: 1.2.0 - fastapi: 0.88.0 - filelock: 3.9.0 - fonttools: 4.39.3 - frozenlist: 1.3.3 - fsspec: 2022.11.0 - gitdb: 4.0.10 - gitpython: 3.1.31 - google-auth: 2.16.3 - google-auth-oauthlib: 0.4.6 - grpcio: 1.51.3 - h11: 0.14.0 - httpcore: 0.16.3 - httptools: 0.5.0 - httpx: 0.23.3 - hydra-core: 1.3.2 - icecream: 2.1.3 - idna: 3.4 - importlib-metadata: 6.1.0 - importlib-resources: 5.12.0 - inquirer: 3.1.3 - itsdangerous: 2.1.2 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.2.0 - jsonargparse: 4.20.1 - jsonschema: 4.17.3 - kiwisolver: 1.4.4 - lightning: 2.0.1.post0 - lightning-api-access: 0.0.5 - lightning-cloud: 0.5.32 - lightning-fabric: 2.0.1 - lightning-utilities: 0.8.0 - lit: 15.0.7 - markdown: 3.4.3 - markdown-it-py: 2.2.0 - markupsafe: 2.1.2 - matplotlib: 3.7.1 - mdurl: 0.1.2 - mpmath: 1.2.1 - multidict: 6.0.4 - multimodalgcn: 0.1.0 - networkx: 3.0 - numpy: 1.24.1 - oauthlib: 3.2.2 - omegaconf: 2.3.0 - opencv-python-headless: 4.7.0.72 - ordered-set: 4.1.0 - orjson: 3.8.8 - packaging: 23.0 - pandas: 1.5.3 - panel: 0.14.4 - param: 1.13.0 - pillow: 9.3.0 - pip: 23.0.1 - protobuf: 3.20.3 - psutil: 5.9.4 - pyarrow: 11.0.0 - pyasn1: 0.4.8 - pyasn1-modules: 0.2.8 - pyct: 0.5.0 - pydantic: 1.10.7 - pydeck: 0.8.0 - pygments: 2.14.0 - pyjwt: 2.6.0 - pympler: 1.0.1 - pyparsing: 3.0.9 - pyrsistent: 0.19.3 - python-box: 7.0.1 - python-dateutil: 2.8.2 - python-dotenv: 1.0.0 - python-editor: 1.0.4 - python-multipart: 0.0.6 - pytorch-lightning: 2.0.0 - pytz: 2022.7.1 - pytz-deprecation-shim: 0.1.0.post0 - pyviz-comms: 2.2.1 - pyyaml: 6.0 - readchar: 4.0.5 - redis: 4.5.4 - requests: 2.28.1 - requests-oauthlib: 1.3.1 - rfc3986: 1.5.0 - rich: 13.3.2 - rsa: 4.9 - s3fs: 2022.11.0 - scikit-learn: 1.2.2 - scipy: 1.10.1 - setuptools: 65.6.3 - six: 1.16.0 - smmap: 5.0.0 - sniffio: 1.3.0 - soupsieve: 2.4 - starlette: 0.22.0 - starsessions: 1.3.0 - streamlit: 1.21.0 - sympy: 1.11.1 - tensorboard: 2.12.0 - tensorboard-data-server: 0.7.0 - tensorboard-plugin-wit: 1.8.1 - tensorboardx: 2.6 - threadpoolctl: 3.1.0 - toml: 0.10.2 - toolz: 0.12.0 - torch: 2.0.0+cu118 - torch-cluster: 1.6.1+pt20cu118 - torch-geometric: 2.3.0 - torch-scatter: 2.1.1+pt20cu118 - torch-sparse: 0.6.17+pt20cu118 - torch-spline-conv: 1.2.2+pt20cu118 - torchaudio: 2.0.1+cu118 - torchmetrics: 0.11.4 - torchvision: 0.15.1+cu118 - tornado: 6.2 - tqdm: 4.65.0 - traitlets: 5.9.0 - triton: 2.0.0 - typeshed-client: 2.2.0 - typing-extensions: 4.4.0 - tzdata: 2023.3 - tzlocal: 4.3 - ujson: 5.7.0 - urllib3: 1.26.13 - uvicorn: 0.21.1 - uvloop: 0.17.0 - validators: 0.20.0 - watchdog: 3.0.0 - watchfiles: 0.18.1 - wcwidth: 0.2.6 - webencodings: 0.5.1 - websocket-client: 1.5.1 - websockets: 10.4 - werkzeug: 2.2.3 - wheel: 0.38.4 - wrapt: 1.15.0 - yarl: 1.8.2 - zipp: 3.15.0
* System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.9.16 - version: #62~20.04.1-Ubuntu SMP Tue Nov 22 21:24:20 UTC 2022

More info

No response

tkella47 commented 1 year ago

I think there is a bug in your code. When I attempt to instantiate your StandardGraph as written, I get:


TypeError                                 Traceback (most recent call last)
Cell In[7], line 1
----> 1 graph = StandardGraph()

Cell In[6], line 24, in StandardGraph.__init__(self)
      3 self.edges = [
      4     (0, 1),
      5     (0, 2),
   (...)
     19     (14, 16),
     20 ]
     22 self.num_vertices = 17
---> 24 self.adj_matrix = self._get_adj_matrix()

Cell In[6], line 29, in StandardGraph._get_adj_matrix(self, self_connections)
     26 def _get_adj_matrix(self, self_connections=True):
     27     adj_matrix = np.zeros((self.num_vertices, self.num_vertices))
---> 29     adj_matrix[self.edges[:, 0], self.edges[:, 1]] = 1
     30     adj_matrix[self.edges[:, 1], self.edges[:, 0]] = 1
     32     if self_connections:

TypeError: list indices must be integers or slices, not tuple

This is not a bug in PyTorch Lightning.
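The root cause can be reproduced without Lightning at all: a Python list accepts only integer or slice indices, so the NumPy-style expression edges[:, 0] fails on a list of tuples, while the same expression works once the list is converted to a NumPy array. A minimal sketch, using a hypothetical three-edge subset of the graph:

```python
import numpy as np

edges = [(0, 1), (0, 2), (1, 3)]  # hypothetical subset of StandardGraph.edges

# A list cannot be indexed with a tuple such as (slice(None), 0).
try:
    edges[:, 0]
except TypeError as err:
    print(err)  # list indices must be integers or slices, not tuple

# After conversion, each row is an edge and each column is one endpoint.
edge_array = np.array(edges)
print(edge_array[:, 0])  # [0 0 1]
print(edge_array[:, 1])  # [1 2 3]
```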

tkella47 commented 1 year ago

The error occurs because self.edges is a list of tuples, and a list does not support NumPy-style indexing such as edges[:, 0]. To access an element, you index the list, then the tuple. I have corrected the code below:

class StandardGraph:
    def __init__(self):
        self.edges = [
            (0, 1),
            (0, 2),
            (1, 3),
            (2, 4),
            (5, 6),
            (5, 7),
            (5, 11),
            (6, 8),
            (6, 12),
            (7, 9),
            (8, 10),
            (11, 12),
            (11, 13),
            (12, 14),
            (13, 15),
            (14, 16),
        ]

        self.num_vertices = 17

        self.adj_matrix = self._get_adj_matrix()

    def _get_adj_matrix(self, edges, num_vertices, self_connections=True):
        adj_matrix = np.zeros((num_vertices, num_vertices))

        adj_matrix[edges[:][0], edges[:][1]] = 1
        adj_matrix[edges[:][1], edges[:][0]] = 1

        if self_connections:
            np.fill_diagonal(adj_matrix, 1)

        return adj_matrix

Additionally, does _get_adj_matrix need to be an instance method, or could it be a static method?

ggggg111 commented 1 year ago

Hello Thomas, thank you for your response. Indeed, there is a bug in my code that is unrelated to Lightning. I didn't notice it because the traceback made it look like a Lightning error.

Anyway, the correct code is a bit different from yours, because edges[:][0] equals edges[0] and edges[:][1] equals edges[1] (edges[:] is just a copy of the list). So I converted self.edges to a NumPy array and used NumPy's indexing, like so:

@staticmethod
def _get_adj_matrix(edges, num_vertices, self_connections=True):
    adj_matrix = np.zeros((num_vertices, num_vertices))

    edges = np.array(edges)

    adj_matrix[edges[:, 0], edges[:, 1]] = 1
    adj_matrix[edges[:, 1], edges[:, 0]] = 1

    if self_connections:
        np.fill_diagonal(adj_matrix, 1)

    return adj_matrix

In the end, I made _get_adj_matrix a static method.
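Putting the pieces together, the class can be sketched as below: the original edge list combined with the static _get_adj_matrix from the fix above. The prints at the end are illustrative additions: 16 undirected edges give 32 symmetric entries, plus 17 diagonal self-connections, for a total of 49. The last print also shows why edges[:][0] does not select a column: edges[:] is a shallow copy of the list, so indexing it returns the first tuple.

```python
import numpy as np


class StandardGraph:
    def __init__(self):
        # Undirected skeleton edges as (vertex, vertex) pairs.
        self.edges = [
            (0, 1), (0, 2), (1, 3), (2, 4),
            (5, 6), (5, 7), (5, 11), (6, 8),
            (6, 12), (7, 9), (8, 10), (11, 12),
            (11, 13), (12, 14), (13, 15), (14, 16),
        ]
        self.num_vertices = 17
        self.adj_matrix = self._get_adj_matrix(self.edges, self.num_vertices)

    @staticmethod
    def _get_adj_matrix(edges, num_vertices, self_connections=True):
        adj_matrix = np.zeros((num_vertices, num_vertices))
        # Convert to an (E, 2) array so that column indexing works.
        edges = np.array(edges)
        # Mark both directions, since the graph is undirected.
        adj_matrix[edges[:, 0], edges[:, 1]] = 1
        adj_matrix[edges[:, 1], edges[:, 0]] = 1
        if self_connections:
            np.fill_diagonal(adj_matrix, 1)
        return adj_matrix


if __name__ == "__main__":
    graph = StandardGraph()
    print(graph.adj_matrix.shape)       # (17, 17)
    print(int(graph.adj_matrix.sum()))  # 49
    # edges[:] copies the list, so [0] returns the first tuple, not a column.
    print(graph.edges[:][0])            # (0, 1)
```

Making the helper static keeps it a pure function of its arguments, so it can be tested on its own without building the whole graph.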

Appreciate your help! Thank you. I'm closing the issue.