MagicStack / immutables

A high-performance immutable mapping type for Python.
Other
1.14k stars 57 forks source link

Do you want my frozenset implementation using this lib? #82

Open uriva opened 2 years ago

uriva commented 2 years ago

If you do, here it is:

from typing import Iterable

import immutables

# Design choices:
# - A dedicated class (rather than a bare Map) is used so the type can appear
#   in annotations.
# - The class has no public methods, so all logic lives in the functions below.
# - The wrapped map is kept private.
# - To prevent subtle mistakes, `__eq__` is overridden to raise an error.
# Corollaries:
# - Operators such as `in`, `==` and `len` do not work out of the box.

class ImmutableSet:
    """Opaque immutable set backed by an ``immutables.Map``.

    Elements are stored as the keys of the wrapped map; every value is
    ``None``. All set logic lives in the module-level functions. The class
    only holds the map and blocks accidental use of ``==`` / ``!=``.
    """

    # Defining __eq__ already makes Python set __hash__ to None implicitly.
    # Spell it out so unhashability reads as a deliberate choice rather than
    # a side effect of the __eq__ override.
    __hash__ = None

    def __init__(self, inner):
        # The wrapped map; private to this module's functions.
        self._inner = inner

    def __repr__(self):
        # Show the elements, so debugging and test-failure output is readable.
        return f"{type(self).__name__}({list(self._inner)!r})"

    def __eq__(self, _):
        # The default identity comparison would report two sets with the same
        # elements as different. Refuse instead, and point callers to
        # `equals`. Python's default __ne__ delegates here too, so `!=`
        # raises as well.
        raise NotImplementedError(
            "Use the functions in this module instead of operators.",
        )

def create(iterable: Iterable) -> ImmutableSet:
    """Build an ImmutableSet holding every element of *iterable*.

    Each element becomes a key of the underlying ``immutables.Map`` with
    the value ``None``, so duplicate elements collapse into one.
    """
    pairs = ((element, None) for element in iterable)
    return ImmutableSet(immutables.Map(pairs))

# Shared empty set. No function in this module modifies an ImmutableSet in
# place (they all return new instances), so one instance can be reused freely.
EMPTY: ImmutableSet = create(())

def equals(s1: ImmutableSet, s2: ImmutableSet) -> bool:
    """Return True when *s1* and *s2* contain exactly the same elements.

    This is the supported replacement for ``==``, which ImmutableSet
    deliberately refuses to implement.
    """
    left, right = s1._inner, s2._inner  # noqa: SF01
    return left == right

def length(set: ImmutableSet) -> int:
    """Return how many elements *set* holds (the stand-in for ``len``)."""
    # `set` shadows the builtin inside this function; the parameter name is
    # kept so that keyword-argument callers keep working.
    inner = set._inner  # noqa: SF01
    return len(inner)

def add(set: ImmutableSet, element) -> ImmutableSet:
    """Return a new ImmutableSet that also contains *element*.

    *set* itself is not modified. Adding an element that is already present
    yields a set with the same contents.
    """
    extended = set._inner.set(element, None)  # noqa: SF01
    return ImmutableSet(extended)

def remove(set: ImmutableSet, element) -> ImmutableSet:
    """Return a new ImmutableSet without *element*; *set* is not modified.

    Missing elements are not silently ignored: ``immutables.Map.delete``
    raises ``KeyError`` when the key is absent, mirroring ``set.remove``.
    """
    shrunk = set._inner.delete(element)  # noqa: SF01
    return ImmutableSet(shrunk)

def contains(set: ImmutableSet, element) -> bool:
    """Return True if *element* is a member of *set* (the stand-in for ``in``)."""
    inner = set._inner  # noqa: SF01
    return element in inner

def union(set1: ImmutableSet, set2: ImmutableSet) -> ImmutableSet:
    """Return a new ImmutableSet holding every element of *set1* and *set2*.

    The smaller set is merged into the larger one, so ``Map.update`` only
    has to insert the smaller set's elements. On a size tie, *set1* is
    merged into *set2*.
    """
    if length(set2) < length(set1):
        smaller, larger = set2, set1
    else:
        smaller, larger = set1, set2
    merged = larger._inner.update(smaller._inner)  # noqa: SF01
    return ImmutableSet(merged)

def intersection(set1: ImmutableSet, set2: ImmutableSet) -> ImmutableSet:
    """Return a new ImmutableSet of the elements present in both inputs.

    Only the smaller input is scanned; each of its elements is looked up in
    the larger one. On a size tie, *set1* is the one scanned. Surviving
    elements are the key objects taken from the scanned set.
    """
    if length(set2) < length(set1):
        smaller, larger = set2, set1
    else:
        smaller, larger = set1, set2
    smaller_map = smaller._inner  # noqa: SF01
    larger_map = larger._inner  # noqa: SF01
    # Drop the non-shared elements through one MapMutation instead of calling
    # `remove` per element. Each persistent `delete` allocates a fresh Map,
    # whereas a mutation works on a single working copy that `finish()`
    # freezes once. Iterating `smaller_map` stays safe because that Map is
    # never modified.
    with smaller_map.mutate() as mutation:
        for element in smaller_map:
            if element not in larger_map:
                del mutation[element]
        result = mutation.finish()
    return ImmutableSet(result)

And here are the tests:

import time

def test_add():
    """Adding a new element gives a set equal to one created with it."""
    base = immutable_set.create([1, 2, 3])
    expected = immutable_set.create([1, 2, 3, 4])
    grown = immutable_set.add(base, 4)
    assert immutable_set.equals(grown, expected)

def test_remove():
    """Removing a present element leaves exactly the other elements."""
    base = immutable_set.create([1, 2, 3])
    expected = immutable_set.create([1, 3])
    shrunk = immutable_set.remove(base, 2)
    assert immutable_set.equals(shrunk, expected)

def test_contains():
    """An element passed to `create` is reported as a member."""
    numbers = immutable_set.create([1, 2, 3])
    assert immutable_set.contains(numbers, 3)

def test_not_contains():
    """An element never passed to `create` is not reported as a member."""
    numbers = immutable_set.create([1, 2, 3])
    assert not immutable_set.contains(numbers, 4)

def test_union():
    """The union of a set with one of its subsets equals the superset."""
    superset = immutable_set.create([1, 2, 3, 4])
    subset = immutable_set.create([1, 2, 3])
    expected = immutable_set.create([1, 2, 3, 4])
    combined = immutable_set.union(superset, subset)
    assert immutable_set.equals(combined, expected)

def _is_o_of_1(f, arg1, arg2):
    start = time.perf_counter()
    f(arg1, arg2)
    return time.perf_counter() - start < 0.0001

_large_number = 9999

def test_intersection():
    """Only the elements shared by both inputs survive."""
    left = immutable_set.create([1, 2])
    right = immutable_set.create([2])
    expected = immutable_set.create([2])
    shared = immutable_set.intersection(left, right)
    assert immutable_set.equals(shared, expected)

def test_performance_sanity():
    """Merging two full-size sets must exceed the timing budget.

    This checks that `_is_o_of_1` is not too lenient on this machine. If
    even a union that has to merge every element counts as "fast", the
    other performance tests prove nothing.
    """
    full_range = range(_large_number)
    first = immutable_set.create(full_range)
    second = immutable_set.create(full_range)
    assert not _is_o_of_1(immutable_set.union, first, second)

def test_union_performance():
    """Union of a large set with a small one stays within the timing budget."""
    large = immutable_set.create(range(_large_number))
    small = immutable_set.create(
        range(_large_number // 64, _large_number // 32),
    )
    assert _is_o_of_1(immutable_set.union, large, small)

def test_intersection_performance():
    """Intersecting a large set with a one-element set stays within budget."""
    large = immutable_set.create(range(_large_number))
    single = immutable_set.create(range(1))
    assert _is_o_of_1(immutable_set.intersection, large, single)