HPAC / matchpy

A library for pattern matching on symbolic expressions in Python.
MIT License
164 stars 25 forks source link

Unable to replace matched pattern in `ManyToOneReplacer` #21

Closed arihantparsoya closed 7 years ago

arihantparsoya commented 7 years ago
import matchpy
Pattern, ReplacementRule, ManyToOneReplacer = matchpy.Pattern, matchpy.ReplacementRule, matchpy.ManyToOneReplacer

from matchpy import replace_all, is_match, Wildcard
from sympy.integrals import Integral
from sympy import Symbol, Pow, cacheit, Basic, S
from matchpy.expressions.functions import register_operation_iterator, register_operation_factory
from matchpy import Operation, CommutativeOperation, AssociativeOperation, OneIdentityOperation, match

class WC(Wildcard, Symbol):
    def __init__(self, min_length, fixed_size, variable_name=None, default=None, **assumptions):
        Wildcard.__init__(self, min_length, fixed_size, variable_name, default)

    def __new__(cls, min_length, fixed_size, variable_name=None, default=None, **assumptions):
        cls._sanitize(assumptions, cls)
        return WC.__xnew__(cls, min_length, fixed_size, variable_name, default, **assumptions)

    def __getnewargs__(self):
        return (self.min_length, self.fixed_size, self.variable_name, self.optional)

    @staticmethod
    def __xnew__(cls, min_length, fixed_size, variable_name=None, default=None, **assumptions):
        obj = Symbol.__xnew__(cls, variable_name, **assumptions)
        return obj

    def _hashable_content(self):
        return super()._hashable_content() + (self.min_count, self.fixed_size, self.variable_name, self.optional)

Operation.register(Integral)
register_operation_iterator(Integral, lambda a: (a._args[0],) + a._args[1], lambda a: len(a._args))

Operation.register(Pow)
OneIdentityOperation.register(Pow)
register_operation_iterator(Pow, lambda a: a._args, lambda a: len(a._args))

def sympy_op_factory(old_operation, new_operands, variable_name):
     return type(old_operation)(*new_operands)

register_operation_factory(Basic, sympy_op_factory)

m_ = WC(1, True, 'm', 1)
x_ = WC(1, True, 'x')
x = Symbol('x')

subject = Integral(x, x)

pattern1 = Pattern(Integral(x_**m_, x_))
rule1 = ReplacementRule(pattern1, lambda x, m: x**(m + 1)/(m + 1))

rubi = ManyToOneReplacer()
rubi.add(rule1)

print(is_match(subject, pattern1))
print(next(match(subject, pattern1)))
print(rubi.replace(subject))

Output:

True
{m ↦ 1, x ↦ x}
Integral(x, x)
wheerd commented 7 years ago

This was an oversight on my part when providing the first implementation of WC. Because the __copy__ of Wildcard actually creates the new copy with an optional keyword argument and that is covered by the assumptions, there is no error message. But the optional information was not copied correctly. Renaming the default paramter to optional fixes this. If you prefer the other name, you could also override __copy__.

Here is a corrected version of the WC class:

class WC(Wildcard, Symbol):
    def __init__(self, min_length, fixed_size, variable_name=None, optional=None, **assumptions):
        Wildcard.__init__(self, min_length, fixed_size, variable_name, optional)

    def __new__(cls, min_length, fixed_size, variable_name=None, optional=None, **assumptions):
        cls._sanitize(assumptions, cls)
        return WC.__xnew__(cls, min_length, fixed_size, variable_name, optional, **assumptions)

    def __getnewargs__(self):
        return (self.min_length, self.fixed_size, self.variable_name, self.optional)

    @staticmethod
    def __xnew__(cls, min_length, fixed_size, variable_name=None, optional=None, **assumptions):
        obj = Symbol.__xnew__(cls, variable_name, **assumptions)
        return obj

    def _hashable_content(self):
        return super()._hashable_content() + (self.min_count, self.fixed_size, self.variable_name, self.optional)
arihantparsoya commented 7 years ago

Thanks