minerva-ml / steppy-toolkit

Curated set of transformers that make your work with steppy faster and more effective :telescope:
MIT License

UNet error #2

Open · kamil-kaczmarek opened this issue 6 years ago

kamil-kaczmarek commented 6 years ago

Currently, the UNet model's forward() method returns a single tensor, so len() of its output is the batch size. However, the assertion at https://github.com/minerva-ml/steppy-toolkit/blob/master/steppy_toolkit/pytorch/models.py#L94 checks this length. As a result, whenever batch_size > 1 the assertion fails, even though it should not.
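One possible fix is sketched below. It normalizes the forward() result to a list before counting, so that a bare tensor counts as one output instead of batch_size outputs. This sketch assumes the assertion at L94 compares the number of model outputs with the number of configured loss functions. That is an assumption, so check the linked line for the exact condition. The names `as_output_list`, `check_output_count` and `loss_names` are hypothetical and are not part of steppy-toolkit.

```python
from typing import Any, List, Sequence


def as_output_list(outputs: Any) -> List[Any]:
    """Normalize a model's forward() result to a list of output tensors.

    A bare tensor is wrapped in a one-element list. Without this step,
    len(tensor) returns the size of the first dimension (the batch size),
    not the number of outputs.
    """
    if isinstance(outputs, (list, tuple)):
        # Multi-output models already return a sequence of tensors.
        return list(outputs)
    # Single-output models (e.g. UNet) return one tensor.
    return [outputs]


def check_output_count(outputs: Any, loss_names: Sequence[str]) -> List[Any]:
    """Hypothetical replacement for the length assertion.

    Compares the number of model outputs (not the batch size) with the
    number of loss functions, and returns the normalized output list.
    """
    output_list = as_output_list(outputs)
    assert len(output_list) == len(loss_names), (
        f"model returned {len(output_list)} outputs, "
        f"expected {len(loss_names)}"
    )
    return output_list
```

For example, a UNet output of shape (4, 1, 256, 256) has len() == 4. After `as_output_list` wraps it, the count is 1, which matches a single loss function regardless of batch size.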