pytorch / data

A PyTorch repo for data loading and utilities to be shared by the PyTorch domain libraries.
BSD 3-Clause "New" or "Revised" License

Refactor test suite to be more readable? #175

Open pmeier opened 2 years ago

pmeier commented 2 years ago

While working on #174, I also worked on the test suite. In there we have some ginormous tests that are hard to parse because they do so many things at the same time:

https://github.com/pytorch/data/blob/c06066ae360fc6054fb826ae041b1cb0c09b2f3b/test/test_datapipe.py#L382-L426

I was wondering if there is a reason for that. Can't we split this into multiple smaller ones? With pytest, placing the following class in the test module is equivalent to the test above:

class TestLineReader:
    @pytest.fixture
    def text1(self):
        return "Line1\nLine2"

    @pytest.fixture
    def text2(self):
        return "Line2,1\nLine2,2\nLine2,3"

    def test_functional_read_lines_correctly(self, text1, text2):
        source_dp = IterableWrapper([("file1", io.StringIO(text1)), ("file2", io.StringIO(text2))])
        line_reader_dp = source_dp.readlines()
        expected_result = [("file1", line) for line in text1.split("\n")] + [
            ("file2", line) for line in text2.split("\n")
        ]
        assert expected_result == list(line_reader_dp)

    def test_functional_strip_new_lines_for_bytes(self, text1, text2):
        source_dp = IterableWrapper(
            [("file1", io.BytesIO(text1.encode("utf-8"))), ("file2", io.BytesIO(text2.encode("utf-8")))]
        )
        line_reader_dp = source_dp.readlines()
        expected_result_bytes = [("file1", line.encode("utf-8")) for line in text1.split("\n")] + [
            ("file2", line.encode("utf-8")) for line in text2.split("\n")
        ]
        assert expected_result_bytes == list(line_reader_dp)

    def test_functional_do_not_strip_newlines(self, text1, text2):
        source_dp = IterableWrapper([("file1", io.StringIO(text1)), ("file2", io.StringIO(text2))])
        line_reader_dp = source_dp.readlines(strip_newline=False)
        expected_result = [
            ("file1", "Line1\n"),
            ("file1", "Line2"),
            ("file2", "Line2,1\n"),
            ("file2", "Line2,2\n"),
            ("file2", "Line2,3"),
        ]
        assert expected_result == list(line_reader_dp)

    def test_reset(self, text1, text2):
        source_dp = IterableWrapper([("file1", io.StringIO(text1)), ("file2", io.StringIO(text2))])
        line_reader_dp = LineReader(source_dp, strip_newline=False)
        expected_result = [
            ("file1", "Line1\n"),
            ("file1", "Line2"),
            ("file2", "Line2,1\n"),
            ("file2", "Line2,2\n"),
            ("file2", "Line2,3"),
        ]

        n_elements_before_reset = 2
        res_before_reset, res_after_reset = reset_after_n_next_calls(line_reader_dp, n_elements_before_reset)
        assert expected_result[:n_elements_before_reset] == res_before_reset
        assert expected_result == res_after_reset

    def test_len(self, text1, text2):
        source_dp = IterableWrapper([("file1", io.StringIO(text1)), ("file2", io.StringIO(text2))])
        line_reader_dp = LineReader(source_dp, strip_newline=False)

        with pytest.raises(TypeError, match="has no len"):
            len(line_reader_dp)

This is a lot more readable, since we now actually have 5 separate test cases that can fail individually. Plus, while writing this I also found that test_reset and test_len were somewhat dependent on test_functional_do_not_strip_newlines, since they define neither line_reader_dp nor expected_result themselves.
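The test_reset case relies on the reset_after_n_next_calls helper, which is not shown in this thread. Here is a minimal sketch of what such a helper might do, assuming it pulls n items from a fresh iterator and then iterates the pipe again from the start. The ListPipe class is a hypothetical stand-in so the example runs without torchdata:

```python
def reset_after_n_next_calls(datapipe, n):
    """Consume n items, then restart iteration and collect everything.

    Assumed behavior, sketched for illustration only.
    """
    it = iter(datapipe)
    res_before_reset = [next(it) for _ in range(n)]
    # list() requests a new iterator from the pipe, which acts as the reset.
    res_after_reset = list(datapipe)
    return res_before_reset, res_after_reset


class ListPipe:
    """Hypothetical stand-in for a DataPipe: re-iterable over a fixed list."""

    def __init__(self, items):
        self.items = items

    def __iter__(self):
        return iter(self.items)


before, after = reset_after_n_next_calls(ListPipe(["a", "b", "c"]), 2)
```

With a helper like this, a reset test only has to compare the partial result against a prefix of the expected output and the full result against all of it.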

pmeier commented 2 years ago

Or even more readable:

class TestLineReader:
    @pytest.fixture
    def files_with_text(self):
        return [
            ("file1", "Line1\nLine2"),
            ("file2", "Line2,1\nLine2,2\nLine2,3"),
        ]

    def make_str_dp(self, files_with_text):
        return IterableWrapper([(file, io.StringIO(text)) for file, text in files_with_text])

    def make_bytes_dp(self, files_with_text):
        return IterableWrapper([(file, io.BytesIO(text.encode("utf-8"))) for file, text in files_with_text])

    def test_functional_read_lines_correctly(self, files_with_text):
        line_reader_dp = self.make_str_dp(files_with_text).readlines()

        expected = []
        for file, text in files_with_text:
            expected.extend((file, line) for line in text.splitlines())

        assert expected == list(line_reader_dp)

    def test_functional_strip_new_lines_for_bytes(self, files_with_text):
        line_reader_dp = self.make_bytes_dp(files_with_text).readlines()

        expected = []
        for file, text in files_with_text:
            expected.extend((file, line.encode("utf-8")) for line in text.splitlines())

        assert expected == list(line_reader_dp)

    def test_functional_do_not_strip_newlines(self, files_with_text):
        line_reader_dp = self.make_str_dp(files_with_text).readlines(strip_newline=False)

        expected = []
        for file, text in files_with_text:
            expected.extend((file, line) for line in text.splitlines(keepends=True))

        assert expected == list(line_reader_dp)

    def test_reset(self, files_with_text):
        line_reader_dp = LineReader(self.make_str_dp(files_with_text))

        expected = []
        for file, text in files_with_text:
            expected.extend((file, line) for line in text.splitlines())

        n_elements_before_reset = 2
        res_before_reset, res_after_reset = reset_after_n_next_calls(line_reader_dp, n_elements_before_reset)

        assert expected[:n_elements_before_reset] == res_before_reset
        assert expected == res_after_reset

    def test_len(self, files_with_text):
        line_reader_dp = LineReader(self.make_str_dp(files_with_text))

        with pytest.raises(TypeError, match="has no len"):
            len(line_reader_dp)

ejguan commented 2 years ago

I like this idea! cc: @NivekT Do you want to incorporate this into your PR https://github.com/pytorch/pytorch/pull/70215

pmeier commented 2 years ago

Ah, that might be an issue. In PyTorch core you cannot rely on pytest so if you want to have this there, you need to adapt what I proposed a little:

ejguan commented 2 years ago

@pytest.fixture's are not available. A workaround might be to store the files_with_text in a class constant and access it from there.

I believe we can use setUpClass for this case.
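A rough sketch of how the setUpClass idea could look with plain unittest. The read_lines function is a hypothetical stand-in for source_dp.readlines() so the example runs without torchdata, and the test names are shortened for illustration:

```python
import io
import unittest


def read_lines(files, strip_newline=True):
    """Hypothetical stand-in for readlines(): yields (name, line) pairs."""
    for name, stream in files:
        for line in stream:
            if strip_newline:
                line = line.rstrip("\n")
            yield name, line


class TestLineReader(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Shared, read-only input data takes the place of the pytest fixture.
        cls.files_with_text = [
            ("file1", "Line1\nLine2"),
            ("file2", "Line2,1\nLine2,2\nLine2,3"),
        ]

    def make_str_files(self):
        # Build fresh streams per test, since a StringIO is consumed on read.
        return [(name, io.StringIO(text)) for name, text in self.files_with_text]

    def test_read_lines_correctly(self):
        expected = []
        for name, text in self.files_with_text:
            expected.extend((name, line) for line in text.splitlines())
        self.assertEqual(expected, list(read_lines(self.make_str_files())))

    def test_do_not_strip_newlines(self):
        expected = []
        for name, text in self.files_with_text:
            expected.extend((name, line) for line in text.splitlines(keepends=True))
        result = list(read_lines(self.make_str_files(), strip_newline=False))
        self.assertEqual(expected, result)
```

Only the immutable text lives on the class. The streams are rebuilt in every test so that no test depends on another one having left them unread.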

NivekT commented 2 years ago

Thanks for the suggestion! I think this is cleaner than what we have. It will take quite a bit of manual refactoring of each DataPipe to get there.

I am wondering if we can do something even better: a standard template for testing DataPipes with less manual code (maybe just specifying the inputs), similar to what OpInfo does in PyTorch core.
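One way such a template might look, as a sketch only. DataPipeInfo, DATAPIPE_INFOS, and read_lines are hypothetical names invented for this example, and this borrows just the declare-the-inputs idea rather than anything from the actual PyTorch core machinery:

```python
import io
import unittest
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


def read_lines(files, strip_newline=True):
    """Hypothetical stand-in for a line-reading DataPipe."""
    for name, stream in files:
        for line in stream:
            yield name, (line.rstrip("\n") if strip_newline else line)


@dataclass
class DataPipeInfo:
    """Hypothetical declarative description of one test scenario."""

    name: str
    make_input: Callable[[], Any]  # builds fresh input for every run
    apply: Callable[..., Any]      # the callable under test
    expected: List[Any]
    kwargs: Dict[str, Any] = field(default_factory=dict)


def make_files():
    return [("file1", io.StringIO("Line1\nLine2"))]


DATAPIPE_INFOS = [
    DataPipeInfo("strip", make_files, read_lines,
                 [("file1", "Line1"), ("file1", "Line2")]),
    DataPipeInfo("keep_newlines", make_files, read_lines,
                 [("file1", "Line1\n"), ("file1", "Line2")],
                 kwargs={"strip_newline": False}),
]


class TestDataPipeTemplate(unittest.TestCase):
    def test_functional(self):
        # One generic check, run once per declared scenario.
        for info in DATAPIPE_INFOS:
            with self.subTest(info.name):
                result = list(info.apply(info.make_input(), **info.kwargs))
                self.assertEqual(info.expected, result)

    def test_repeatable_with_fresh_input(self):
        # Running the same scenario twice on fresh input gives equal output.
        for info in DATAPIPE_INFOS:
            with self.subTest(info.name):
                first = list(info.apply(info.make_input(), **info.kwargs))
                second = list(info.apply(info.make_input(), **info.kwargs))
                self.assertEqual(first, second)
```

Adding coverage for another scenario then means appending one entry to the list instead of writing a new test method.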

erip commented 2 years ago

FWIW, we've started something similar in torchtext. See here if you're interested.