Closed JiaweiZhuang closed 4 years ago
Ideally there should be a parameterized test for the DARTS code to check all kinds of input shapes. This can be done more easily if the code runs on CPU (#8).
Here's what the test code would look like, at an abstract level:
import pytest
from darts import train_search # some sort of DARTS wrapper; TBD
# (channel, x, y)
shape_list = [
(3, 32, 32), # CIFAR
(1, 28, 28), # MNIST (grayscale, single channel)
(1, 3, 5), # coarse-grid graphene
(1, 30, 80) # fine-grid graphene
] # add more shapes here
@pytest.mark.parametrize('input_shape', shape_list)
def test_input_shape(input_shape):
    train_search(input_shape)  # should train on synthetic numpy data with this shape
Such a test can be particularly useful when applying DARTS to other scientific data. New datasets are basically just numpy arrays with a different shape. The actual data content doesn't matter much when the goal is only to get DARTS running.
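Since only the shape matters for this kind of test, the synthetic input could be generated along these lines. This is a minimal sketch, not part of the DARTS code: `make_synthetic_batch`, its parameters, and the `(n_samples, channel, x, y)` layout are hypothetical names and assumptions chosen for illustration.

```python
import numpy as np


def make_synthetic_batch(input_shape, n_samples=8, n_classes=10, seed=0):
    """Return (images, labels) filled with random values.

    Hypothetical helper: the layout (n_samples, channel, x, y) is an
    assumption about what a DARTS wrapper would accept.
    """
    rng = np.random.default_rng(seed)
    # Random float32 "images" with the requested (channel, x, y) shape.
    images = rng.standard_normal((n_samples, *input_shape)).astype(np.float32)
    # Random integer class labels in [0, n_classes).
    labels = rng.integers(0, n_classes, size=n_samples)
    return images, labels


images, labels = make_synthetic_batch((1, 30, 80))  # fine-grid graphene shape
print(images.shape, labels.shape)
```

A fixed seed keeps the test deterministic, so a failure can be attributed to the input shape rather than to the random data.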
@dylanrandle You should be able to run the test suite on GPU via #10.
The original DARTS implementation (https://github.com/quark0/darts) requires the input image to be a perfect n*n square; otherwise it crashes. This is fine for CIFAR (32x32) and MNIST (28x28), but most scientific images won't be a perfect square. The graphene data is 3x5 / 30x80, for example.
If the code is too hard to refactor, we can also resample all input images to a perfect square, which is less ideal.
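For reference, the resampling fallback could look roughly like the sketch below. It uses plain numpy index scaling (a nearest-neighbor style lookup) rather than any particular interpolation library; `resample_to_square` and the `(channel, x, y)` layout are assumptions made for illustration, not existing code in the repository.

```python
import numpy as np


def resample_to_square(image, size=None):
    """Resample a (channel, x, y) array to (channel, size, size).

    Hypothetical helper. If size is None, the longer spatial side is used.
    """
    _, nx, ny = image.shape
    if size is None:
        size = max(nx, ny)
    # Map each output index to a source index by integer scaling.
    ix = np.arange(size) * nx // size
    iy = np.arange(size) * ny // size
    # Broadcast the two index vectors into a (size, size) lookup grid.
    return image[:, ix[:, None], iy[None, :]]


coarse = np.arange(15, dtype=np.float32).reshape(1, 3, 5)  # coarse-grid graphene shape
print(resample_to_square(coarse).shape)
```

Note that this stretches the shorter axis, so the aspect ratio of the original image is not preserved.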