tophat / syrupy

:pancakes: The sweeter pytest snapshot plugin
https://tophat.github.io/syrupy
Apache License 2.0

refactor: scaffolding to support custom context in extensions #816

Open noahnu opened 10 months ago

noahnu commented 10 months ago

NOTE: Since syrupy v4 migrated from instance methods to classmethods, this new context is not actually usable yet. It does, however, lay the groundwork for switching back to instance methods (if we continue along this path).

Related to https://github.com/tophat/syrupy/pull/814, this PR lays the groundwork to switch back to instance-based extensions (reverting an earlier decision to move to class methods for easier pytest-xdist compatibility).
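A simplified illustration of why per-assertion context needs instances. This is not syrupy's code: `ClassmethodComparator`, `InstanceComparator` and `_close` are hypothetical stand-ins. A classmethod can only read class-level state, so it cannot tell apart two assertions that want different options. An instance can carry its own context.

```python
from typing import Any, Dict, Optional, Sequence


def _close(a: Sequence[float], b: Sequence[float], atol: float) -> bool:
    # Elementwise absolute-tolerance comparison; lengths must match.
    return len(a) == len(b) and all(abs(x - y) <= atol for x, y in zip(a, b))


class ClassmethodComparator:
    """Hypothetical: comparison implemented as a classmethod."""

    atol = 0.0  # options can only live in class state shared by every caller

    @classmethod
    def matches(cls, *, serialized_data: Sequence[float], snapshot_data: Sequence[float]) -> bool:
        return _close(serialized_data, snapshot_data, cls.atol)


class InstanceComparator:
    """Hypothetical: comparison as an instance method with per-assertion context."""

    def __init__(self, context: Optional[Dict[str, Any]] = None) -> None:
        # Each assertion could construct its own instance with its own options.
        self.context = dict(context or {})

    def matches(self, *, serialized_data: Sequence[float], snapshot_data: Sequence[float]) -> bool:
        return _close(serialized_data, snapshot_data, self.context.get("atol", 0.0))


# Two assertions in the same test can use different tolerances:
strict = InstanceComparator()
loose = InstanceComparator({"atol": 0.1})
print(strict.matches(serialized_data=[1.0], snapshot_data=[1.05]))  # False
print(loose.matches(serialized_data=[1.0], snapshot_data=[1.05]))   # True
```

With the classmethod version, honoring a second tolerance means changing the shared `atol` or creating another class.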

atharva-2001 commented 10 months ago

Thank you @noahnu for this PR! Do you have a sense of when it might be merged? This feature is critical for some of our tests in TARDIS. If there is anything I can do to help, please let me know!

noahnu commented 10 months ago

@atharva-2001 I can't give an ETA. Could you describe what you're trying to do in your project (possibly with an example)? I may be able to recommend a workaround.

atharva-2001 commented 10 months ago

I see. Here is some example code. Since most of my code deals with NumPy arrays and pandas DataFrames, I want to pass additional assertion options, for example rtol or the assertion function itself. I don't want to write out multiple fixtures by hand. Is it possible to at least create them programmatically? Thanks for all the help!

from typing import Any

import numpy as np
import pytest

from syrupy.data import SnapshotCollection
from syrupy.extensions.single_file import SingleFileSnapshotExtension
from syrupy.types import SerializableData

class NumpySnapshotExtenstion(SingleFileSnapshotExtension):
    _file_extension = "dat"

    def matches(self, *, serialized_data, snapshot_data, **kwargs):
        # Ideally this would receive options such as rtol/atol, or a different
        # assertion method (e.g. assert_almost_equal), but syrupy does not pass
        # extra options here, so kwargs arrives empty.
        try:
            # assert_allclose returns None on success and raises AssertionError
            # on mismatch, so the exception is the only signal to check.
            np.testing.assert_allclose(
                np.array(snapshot_data), np.array(serialized_data), **kwargs
            )
        except AssertionError:
            return False
        return True

    def _read_snapshot_data_from_location(
        self, *, snapshot_location: str, snapshot_name: str, session_id: str
    ):
        # see https://github.com/tophat/syrupy/blob/f4bc8453466af2cfa75cdda1d50d67bc8c4396c3/src/syrupy/extensions/base.py#L139
        try:
            return np.loadtxt(snapshot_location).tolist()
        except OSError:
            return None

    @classmethod
    def _write_snapshot_collection(
        cls, *, snapshot_collection: SnapshotCollection
    ) -> None:
        # see https://github.com/tophat/syrupy/blob/f4bc8453466af2cfa75cdda1d50d67bc8c4396c3/src/syrupy/extensions/base.py#L161

        filepath, data = (
            snapshot_collection.location,
            next(iter(snapshot_collection)).data,
        )
        np.savetxt(filepath, data)

    def serialize(self, data: SerializableData, **kwargs: Any) -> SerializableData:
        # Pass the data through unchanged; np.savetxt does the actual writing.
        return data

@pytest.fixture
def snapshot_numpy(snapshot):
    # Ideally, options such as rtol=1, atol=0 could be passed from here to matches().
    return snapshot.with_defaults(extension_class=NumpySnapshotExtenstion)

def test_np(snapshot_numpy):
    x = [1e-5, 1e-3, 1e-1]
    # ideally-
    # from numpy.testing import assert_allclose
    # assert snapshot_numpy(matcher = assert_allclose, rtol=1e6...)
    assert snapshot_numpy == x
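
The `# ideally-` comment above asks for a pluggable assertion function plus options. Here is a minimal sketch of that idea, independent of syrupy and NumPy. `make_matcher` and `allclose_assert` are hypothetical helpers. `allclose_assert` only mimics the elementwise rule documented for `numpy.testing.assert_allclose`, `|actual - desired| <= atol + rtol * |desired|`, with the same default `rtol=1e-7, atol=0`. It omits NaN/inf handling.

```python
from typing import Any, Callable, Sequence


def allclose_assert(
    actual: Sequence[float], desired: Sequence[float], *, rtol: float = 1e-7, atol: float = 0.0
) -> None:
    # Pure-Python stand-in for assert_allclose's tolerance check (NaN/inf omitted).
    if len(actual) != len(desired):
        raise AssertionError(f"length mismatch: {len(actual)} != {len(desired)}")
    for i, (a, d) in enumerate(zip(actual, desired)):
        if abs(a - d) > atol + rtol * abs(d):
            raise AssertionError(f"element {i}: {a!r} != {d!r}")


def make_matcher(assert_fn: Callable[..., None] = allclose_assert, **options: Any) -> Callable[..., bool]:
    # Bind an assertion function and its options into a matches()-style callable.
    def matches(*, serialized_data: Sequence[float], snapshot_data: Sequence[float]) -> bool:
        try:
            # Same argument order as the example: stored snapshot first, new data second.
            assert_fn(snapshot_data, serialized_data, **options)
        except AssertionError:
            return False
        return True

    return matches


loose = make_matcher(rtol=1, atol=0)
strict = make_matcher()  # uses the stand-in's defaults, rtol=1e-7, atol=0
print(loose(serialized_data=[2e-5, 1e-3], snapshot_data=[1e-5, 1e-3]))   # True
print(strict(serialized_data=[2e-5, 1e-3], snapshot_data=[1e-5, 1e-3]))  # False
```

Because the assertion function is just a parameter, the real `np.testing.assert_allclose` or `assert_almost_equal` could be bound the same way inside an extension.
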

noahnu commented 9 months ago

@atharva-2001 Does something like this work until syrupy has built-in support?

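import pytest

from syrupy.extensions.single_file import SingleFileSnapshotExtension

class NumpySnapshotExtension(SingleFileSnapshotExtension):
    _file_extension = "dat"

    # Default tolerances; per-test values are set on a subclass via with_kwargs().
    rtol = 0
    atol = 0

    # matches(), serialize() and the read/write overrides go here, as in the
    # example above, with matches() reading self.rtol / self.atol.

    @classmethod
    def with_kwargs(cls, **kwargs):
        # Return a subclass whose class attributes hold the requested tolerances.
        class MyCopy(cls):
            rtol = kwargs.get("rtol", cls.rtol)
            atol = kwargs.get("atol", cls.atol)
        return MyCopy

@pytest.fixture
def snapshot(snapshot):
    # Overrides syrupy's snapshot fixture with a factory: snapshot(rtol=..., atol=...)
    def factory(**kwargs):
        _class = NumpySnapshotExtension.with_kwargs(**kwargs)
        return snapshot.with_defaults(extension_class=_class)
    return factory

def test_np(snapshot):
    assert snapshot(rtol=1, atol=0) == [1e-5, 1e-3]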

(not tested)
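The class-factory idea can be checked on its own, without syrupy or pytest. This sketch assumes a pure-Python stand-in, `BaseNumpyExtension`, and a hypothetical `with_tolerances` helper, neither of which is part of syrupy. It also caches the generated subclass with `functools.lru_cache`, so that repeated calls with the same tolerances reuse one class.

```python
from functools import lru_cache
from typing import Sequence


class BaseNumpyExtension:
    """Hypothetical stand-in for the extension above; no syrupy import."""

    rtol = 0.0
    atol = 0.0

    def matches(self, *, serialized_data: Sequence[float], snapshot_data: Sequence[float]) -> bool:
        if len(serialized_data) != len(snapshot_data):
            return False
        # Same elementwise rule as numpy's assert_allclose, read from class attributes.
        return all(
            abs(s - d) <= self.atol + self.rtol * abs(d)
            for s, d in zip(snapshot_data, serialized_data)
        )


@lru_cache(maxsize=None)
def with_tolerances(base: type, rtol: float = 0.0, atol: float = 0.0) -> type:
    # type() creates a subclass overriding rtol/atol and leaves the shared base
    # class untouched; lru_cache returns the same subclass for the same arguments.
    return type(f"{base.__name__}Tol", (base,), {"rtol": rtol, "atol": atol})


Loose = with_tolerances(BaseNumpyExtension, rtol=1, atol=0)
print(Loose is with_tolerances(BaseNumpyExtension, rtol=1, atol=0))  # True
print(BaseNumpyExtension.rtol)  # 0.0
print(Loose().matches(serialized_data=[2e-5, 1e-3], snapshot_data=[1e-5, 1e-3]))  # True
```

Subclassing, rather than assigning `NumpySnapshotExtension.rtol = ...` directly, matters because class attributes are shared by every test that uses the class. Mutating the base would leak one test's tolerances into the next.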