Open — uriva opened this issue 2 years ago.
If you do, here it is:
from typing import Iterable

import immutables

# Design choices:
# - Using a class to allow for typing.
# - The class has no logic methods, to ensure all logic is in the functions
#   below; it only defines dunders that guard against misuse or aid debugging.
# - The wrapped map is kept private.
# - To prevent the user from making subtle mistakes, we override `__eq__` to
#   raise an error.
# Corollaries:
# - Will not work with operators out of the box, e.g. `in`, `==` or `len`.


class ImmutableSet:
    """Opaque immutable set backed by an `immutables.Map` whose values are `None`.

    Work with instances only through the module-level functions (`create`,
    `add`, `contains`, `union`, ...); operators are intentionally unsupported.
    """

    def __init__(self, inner):
        # `inner` is an immutables.Map; its keys are the set's elements.
        self._inner = inner

    def __eq__(self, _):
        # Defining __eq__ without __hash__ also makes instances unhashable,
        # and `!=` raises as well because the default __ne__ delegates here.
        raise NotImplementedError(
            "Use the functions in this module instead of operators.",
        )

    def __repr__(self):
        # Debugging aid only: the element order is the underlying map's
        # iteration order and carries no meaning.
        return f"{type(self).__name__}({list(self._inner)!r})"


def create(iterable: Iterable) -> ImmutableSet:
    """Return a set holding the distinct elements of *iterable*.

    Elements must be hashable, since they become keys of the backing map.
    """
    return ImmutableSet(immutables.Map(map(lambda x: (x, None), iterable)))


# Shared empty set; safe to reuse because instances are never mutated.
EMPTY: ImmutableSet = create([])


def equals(s1: ImmutableSet, s2: ImmutableSet) -> bool:
    """Return True when *s1* and *s2* contain exactly the same elements."""
    return s1._inner == s2._inner  # noqa: SF01


def length(set: ImmutableSet) -> int:
    """Return the number of elements in *set*."""
    return len(set._inner)  # noqa: SF01


def add(set: ImmutableSet, element) -> ImmutableSet:
    """Return a new set equal to *set* plus *element*; *set* is unchanged."""
    return ImmutableSet(set._inner.set(element, None))  # noqa: SF01


def remove(set: ImmutableSet, element) -> ImmutableSet:
    """Return a new set equal to *set* without *element*; *set* is unchanged.

    Raises:
        KeyError: if *element* is not in *set* (from `Map.delete`).
    """
    return ImmutableSet(set._inner.delete(element))  # noqa: SF01


def contains(set: ImmutableSet, element) -> bool:
    """Return True if *element* is a member of *set*."""
    return element in set._inner  # noqa: SF01


def union(set1: ImmutableSet, set2: ImmutableSet) -> ImmutableSet:
    """Return a set with every element of *set1* and *set2*.

    The smaller set is merged into the larger one, so only the smaller set's
    elements are inserted.
    """
    smaller, larger = sorted([set1, set2], key=length)
    return ImmutableSet(larger._inner.update(smaller._inner))  # noqa: SF01


def intersection(set1: ImmutableSet, set2: ImmutableSet) -> ImmutableSet:
    """Return a set with the elements present in both *set1* and *set2*.

    Only the smaller set is scanned, with one membership check per element
    against the larger set. The result is built in a single pass rather than
    deleting non-shared elements one at a time, which would allocate a new
    persistent map for every removal.
    """
    smaller, larger = sorted([set1, set2], key=length)
    return create(
        element
        for element in smaller._inner  # noqa: SF01
        if contains(larger, element)
    )
And the tests:
import time

# NOTE(review): the original snippet used `immutable_set` without importing
# it; adjust this import to the module's real path in your project.
import immutable_set


def test_add():
    assert immutable_set.equals(
        immutable_set.add(
            immutable_set.create([1, 2, 3]),
            4,
        ),
        immutable_set.create(
            [1, 2, 3, 4],
        ),
    )


def test_remove():
    assert immutable_set.equals(
        immutable_set.remove(
            immutable_set.create([1, 2, 3]),
            2,
        ),
        immutable_set.create([1, 3]),
    )


def test_contains():
    assert immutable_set.contains(immutable_set.create([1, 2, 3]), 3)


def test_not_contains():
    assert not immutable_set.contains(immutable_set.create([1, 2, 3]), 4)


def test_union():
    assert immutable_set.equals(
        immutable_set.union(
            immutable_set.create([1, 2, 3, 4]),
            immutable_set.create([1, 2, 3]),
        ),
        immutable_set.create([1, 2, 3, 4]),
    )


def _is_o_of_1(f, arg1, arg2, *, threshold=0.0001, repeat=5):
    """Return True if the fastest of *repeat* calls ``f(arg1, arg2)`` beats *threshold*.

    This is a coarse wall-clock proxy for "cheap", not a true asymptotic
    check. Taking the best of several runs filters out one-off scheduler and
    GC pauses, which otherwise make single-shot timing assertions flaky.

    Args:
        f: Two-argument callable to time.
        arg1, arg2: Arguments passed to *f*. They are built by the caller,
            outside the timed region.
        threshold: Seconds the fastest run must stay under.
        repeat: Number of timed runs. Values below 1 are treated as 1.
    """
    timings = []
    for _ in range(max(repeat, 1)):
        start = time.perf_counter()
        f(arg1, arg2)
        timings.append(time.perf_counter() - start)
    return min(timings) < threshold


_large_number = 9999


def test_intersection():
    assert immutable_set.equals(
        immutable_set.intersection(
            immutable_set.create([1, 2]),
            immutable_set.create([2]),
        ),
        immutable_set.create([2]),
    )


def test_performance_sanity():
    # Guards the timing helper itself: a union of two large sets must not
    # look "cheap".
    assert not _is_o_of_1(
        immutable_set.union,
        immutable_set.create(range(_large_number)),
        immutable_set.create(range(_large_number)),
    )


def test_union_performance():
    # A union with a small set should only pay for the small set's elements.
    assert _is_o_of_1(
        immutable_set.union,
        immutable_set.create(range(_large_number)),
        immutable_set.create(range(_large_number // 64, _large_number // 32)),
    )


def test_intersection_performance():
    # An intersection should only scan the smaller operand.
    assert _is_o_of_1(
        immutable_set.intersection,
        immutable_set.create(range(_large_number)),
        immutable_set.create(range(1)),
    )
If you do, here it is:
And the tests: