microsoft / torchgeo

TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
https://www.osgeo.org/projects/torchgeo/
MIT License

class_weights cannot be passed via a config file because a tensor is expected #2060

Open robmarkcole opened 6 months ago

robmarkcole commented 6 months ago

Description

Using the Lightning CLI, we can train the SemanticSegmentationTask, but passing class_weights raises an error. The solution is to accept a list of ints in addition to a tensor.

Steps to reproduce

In the Lightning CLI YAML config:

model:
  class_path: SemanticSegmentationTask
  init_args:
    model: unet
    backbone: resnet50
    weights: null
    lr: 0.001
    in_channels: 6
    num_classes: 2
    class_weights:
      - 1
      - 50

This results in:

      Does not validate against any of the Union subtypes
      Subtypes: (<class 'torch.Tensor'>, <class 'NoneType'>)
      Errors:
        - Not a valid subclass of Tensor
          Subclass types expect one of:
          - a class path (str)
          - a dict with class_path entry
          - a dict without class_path but with init_args entry (class path given previously)
        - Expected a <class 'NoneType'>
      Given value type: <class 'list'>
      Given value: [1, 50]

Version

main

isaaccorley commented 6 months ago

Care to make a PR to accept a list and convert it to a tensor? If not, I can take it on this weekend.
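
A minimal sketch of what "accept a list and convert to a tensor" could look like, under assumptions: `to_class_weights` and the `ClassWeights` alias are hypothetical names (not TorchGeo's API) that a task's `__init__` could use, with the widened hint letting a config parser accept a plain YAML list. The cast to float is deliberate: a YAML list like [1, 50] would otherwise become an int64 tensor, while the loss weight should match the floating-point dtype of the logits.

```python
from collections.abc import Sequence
from typing import Optional, Union

import torch
from torch import Tensor

# Hypothetical widened type hint: a tensor, a plain list/tuple of numbers
# (as produced by a YAML config), or None.
ClassWeights = Optional[Union[Tensor, Sequence[float]]]


def to_class_weights(value: ClassWeights) -> Optional[Tensor]:
    """Normalize class weights to a float32 tensor (hypothetical helper)."""
    if value is None:
        return None
    # torch.as_tensor accepts both tensors and Python sequences; forcing
    # float32 turns integer YAML values like [1, 50] into [1.0, 50.0].
    return torch.as_tensor(value, dtype=torch.float32)


# Usage: weights as they would arrive from `class_weights: [1, 50]`
weights = to_class_weights([1, 50])
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)

# Segmentation-shaped dummy batch: logits (N, C, H, W), targets (N, H, W)
logits = torch.randn(2, 2, 4, 4)
target = torch.randint(0, 2, (2, 4, 4))
loss = loss_fn(logits, target)
```

Keeping the conversion in one helper means Python users can still pass a tensor directly, while config users get the same float tensor from a list.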

robmarkcole commented 6 months ago

You will get to it way before me!

adamjstewart commented 6 months ago

For a bit of history, I added this in #1221 and it initially only supported lists. In #1413, @ntw-au modified this to support lists, numpy arrays, and torch tensors. Then in #1541, I modified it to only accept torch tensors. I agree we need a way to support class_weights in a YAML file (and preferably also on the command line). If omegaconf supports this, we could also easily enable omegaconf as a parser: https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced_2.html#enable-variable-interpolation.

isaaccorley commented 6 months ago

If you want to use it with hydra.utils.instantiate and omegaconf, you would only need the following:

class_weights:
   _target_: torch.tensor
   data: [0.5, 0.5]
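
To make explicit what that config does, here is a deliberately simplified stand-in for hydra.utils.instantiate. The `resolve_target` function is hypothetical, not Hydra's API: it only imports the `_target_` dotted path and calls it with the remaining keys as keyword arguments, so the snippet above amounts to `torch.tensor(data=[0.5, 0.5])`. Real Hydra additionally handles recursion, `_convert_`, `_partial_` and more.

```python
import importlib
from typing import Any


def resolve_target(config: dict[str, Any]) -> Any:
    """Hypothetical, simplified stand-in for hydra.utils.instantiate.

    Handles only a flat dict with a `_target_` dotted path; every other
    key is passed through as a keyword argument.
    """
    kwargs = dict(config)  # copy so the caller's dict is not mutated
    module_path, _, attr_name = kwargs.pop("_target_").rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    return target(**kwargs)


# Same structure as the YAML snippet, written as a Python dict
class_weights = resolve_target({"_target_": "torch.tensor", "data": [0.5, 0.5]})
```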
isaaccorley commented 6 months ago

I haven't looked at the Lightning CLI in a while, but I wonder if it supports recursive instantiation like:

class_weights:
   class_path: torch.tensor
   init_args:
      data: [0.5, 0.5]