firedrakeproject / firedrake

Firedrake is an automated system for the portable solution of partial differential equations using the finite element method (FEM)
https://firedrakeproject.org

Assign erroneously caches index relabelling between assignments #1855

Closed: stephankramer closed this 4 years ago

stephankramer commented 4 years ago

Consider the following code:

from firedrake import *
mesh=UnitSquareMesh(1,1)
V=VectorFunctionSpace(mesh, "CG", 1)
W=TensorFunctionSpace(mesh, "CG", 1)
u=Function(V)
w=Function(W)

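# i and j are UFL's predefined free indices, re-exported by firedrake
u.assign(as_vector(2*u[i], i))
u.assign(as_vector(2*u[j], j))
# fails: the cached relabellings map both i and j to the same index
w.assign(as_tensor(2*w[i,j], (i,j)))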

This fails somewhere deep in UFL. The reason is that Assign stores an IndexRelabeller as a class attribute, so a single relabeller is shared by every assignment. The relabeller is reset by calling _reset() in the slow_key() method of Assign, but _reset() only restarts the index counter: the index_cache stored on the IndexRelabeller keeps its old mappings, which are now inconsistent with the new numbering. In the code above, the first vector assignment relabels i=Index(0) to Index(0). In the second assignment the counter has been restarted, so j=Index(1) is also relabelled to Index(0). Because both relabellings are still in index_cache, the final tensor assignment maps both i and j to Index(0), which upsets UFL.
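To make the failure mode concrete, here is a pure-Python toy model of the pre-fix relabeller. It is a sketch, not Firedrake code: ToyIndex, BuggyRelabeller and relabel are hypothetical names, and ToyIndex imitates only the property of ufl.Index that matters here (we assume two indices with the same count compare equal). As in the original __init__, the default factory reads self.count at call time, so it always draws from whichever counter _reset() installed last, while entries cached by earlier assignments survive.

```python
import itertools
from collections import defaultdict


class ToyIndex:
    """Hypothetical stand-in for ufl.Index: identified only by its count."""

    def __init__(self, count):
        self.count = count

    def __eq__(self, other):
        return isinstance(other, ToyIndex) and self.count == other.count

    def __hash__(self):
        return hash(("ToyIndex", self.count))

    def __repr__(self):
        return "Index(%d)" % self.count


class BuggyRelabeller:
    """Mimics the pre-fix behaviour: _reset() restarts the counter but
    leaves index_cache, and its stale mappings, in place."""

    def __init__(self):
        self._reset()
        # The lambda looks up self.count when called, so after a reset it
        # numbers new indices from 0 again while old entries remain cached.
        self.index_cache = defaultdict(lambda: ToyIndex(next(self.count)))

    def _reset(self):
        self.count = itertools.count()

    def relabel(self, free_indices):
        return tuple(self.index_cache[idx] for idx in free_indices)


# One shared instance, playing the role of the class attribute on Assign.
relabeller = BuggyRelabeller()
i, j = ToyIndex(0), ToyIndex(1)

relabeller._reset()
first = relabeller.relabel((i,))      # i -> Index(0)
relabeller._reset()
second = relabeller.relabel((j,))     # j -> Index(0): the counter restarted
relabeller._reset()
third = relabeller.relabel((i, j))    # both come from stale cache entries
print(first, second, third)           # (Index(0),) (Index(0),) (Index(0), Index(0))
```

The last line reproduces the collision described above: the tensor assignment sees one index where it needs two.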

Should the index_cache also be reset in IndexRelabeller._reset()? And why does the IndexRelabeller have to be a class attribute? I probably don't understand the MultiFunction.reuse_if_untouched business well enough.

wence- commented 4 years ago

Oh yeah, probably this?

diff --git a/firedrake/assemble_expressions.py b/firedrake/assemble_expressions.py
index 564b15ba..74987717 100644
--- a/firedrake/assemble_expressions.py
+++ b/firedrake/assemble_expressions.py
@@ -108,10 +108,10 @@ class IndexRelabeller(MultiFunction):
     def __init__(self):
         super().__init__()
         self._reset()
-        self.index_cache = defaultdict(lambda: Index(next(self.count)))

     def _reset(self):
-        self.count = itertools.count()
+        count = itertools.count()
+        self.index_cache = defaultdict(lambda: Index(next(count)))

     expr = MultiFunction.reuse_if_untouched

The idea is that UFL expressions should be invariant under renaming indices (so the first two examples should not generate different kernels). But yes, the index cache should be reset. It's a class attribute because MultiFunction.__init__ is expensive and I was micro-optimising.
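The patched _reset() can be modelled the same way. In this sketch (again with hypothetical names, FixedRelabeller and canonical_key, and with plain strings and integers standing in for ufl.Index objects), the counter and the cache are created together, so each expression's free indices are numbered by order of first appearance. That gives the renaming invariance described above: the two vector assignments get the same key, while the tensor assignment gets two distinct indices.

```python
import itertools
from collections import defaultdict


class FixedRelabeller:
    """Toy model of the patched IndexRelabeller: _reset() replaces the
    counter and the cache together, so no mapping survives a reset."""

    def __init__(self):
        self._reset()

    def _reset(self):
        count = itertools.count()
        # The lambda closes over the local `count`, tying this cache to
        # this counter; the next _reset() discards both.
        self.index_cache = defaultdict(lambda: next(count))

    def canonical_key(self, free_indices):
        """Hypothetical helper: relabel indices in order of first appearance."""
        self._reset()
        return tuple(self.index_cache[idx] for idx in free_indices)


relabeller = FixedRelabeller()  # still shared, like the class attribute
key_a = relabeller.canonical_key(("i",))       # as_vector(2*u[i], i)
key_b = relabeller.canonical_key(("j",))       # as_vector(2*u[j], j)
key_c = relabeller.canonical_key(("i", "j"))   # as_tensor(2*w[i,j], (i,j))
print(key_a, key_b, key_c)  # (0,) (0,) (0, 1)
```

Because `count` is a local captured by the lambda rather than an attribute, a fresh counter can never be paired with a stale cache, which is what allows the relabeller to stay a shared class attribute.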

stephankramer commented 4 years ago

OK, I'll PR this. Do you want a test?

wence- commented 4 years ago

Yes please, thanks.

Can you also confirm that the caching continues to work? Something like:

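from firedrake import *
from firedrake.assemble_expressions import Assign, evaluate_expression

mesh = UnitSquareMesh(1, 1)
V = VectorFunctionSpace(mesh, "CG", 1)
W = TensorFunctionSpace(mesh, "CG", 1)
u = Function(V)
v = Function(V)
w = Function(W)

i, j = indices(2)
# exprA and exprB differ only in the name of the free index
exprA = Assign(u, as_vector(2*u[i], i))
exprB = Assign(u, as_vector(2*u[j], j))

assert len(u._expression_cache) == 0

evaluate_expression(exprA)

# evaluating A caches the result under both of its keys
assert exprA.fast_key in u._expression_cache
assert exprA.slow_key in u._expression_cache
# B has a different fast key but relabels to the same slow key
assert exprB.fast_key not in u._expression_cache
assert exprB.slow_key in u._expression_cache

evaluate_expression(exprB)
# the slow-key hit for B is now also recorded under B's fast key
assert exprB.fast_key in u._expression_cache
assert exprA.fast_key in u._expression_cache

assert exprB.slow_key == exprA.slow_key

# fast key of A, fast key of B, and the shared slow key
assert len(u._expression_cache) == 3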

This needs a slightly expanded patch:

diff --git a/firedrake/assemble_expressions.py b/firedrake/assemble_expressions.py
index 564b15ba..8f846ac0 100644
--- a/firedrake/assemble_expressions.py
+++ b/firedrake/assemble_expressions.py
@@ -108,10 +108,10 @@ class IndexRelabeller(MultiFunction):
     def __init__(self):
         super().__init__()
         self._reset()
-        self.index_cache = defaultdict(lambda: Index(next(self.count)))

     def _reset(self):
-        self.count = itertools.count()
+        count = itertools.count()
+        self.index_cache = defaultdict(lambda: Index(next(count)))

     expr = MultiFunction.reuse_if_untouched

@@ -466,6 +466,7 @@ def evaluate_expression(expr, subset=None):
             slow_key = expr.slow_key
             try:
                 arguments = cache[slow_key]
+                cache[fast_key] = arguments
             except KeyError:
                 arguments = None
         if arguments is not None:
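The second hunk is about the two-level cache: a cheap fast_key that depends on index names, and an expensive slow_key computed after relabelling. Below is a minimal sketch of that lookup pattern, assuming a cache miss stores the result under both keys (as the suggested test expects). It does not reproduce evaluate_expression's actual code: ToyExpr and evaluate are hypothetical, and the "kernel" is a placeholder string. It shows why the added `cache[fast_key] = arguments` line matters. Without it, every evaluation of exprB would recompute its slow key.

```python
class ToyExpr:
    """Hypothetical stand-in for Assign, modelling 2*u[idx] only."""

    slow_key_computations = 0  # counts trips down the expensive path

    def __init__(self, index_name):
        self.index_name = index_name

    @property
    def fast_key(self):
        # Cheap to build, but sensitive to the free index's name.
        return ("fast", "2*u[%s]" % self.index_name)

    @property
    def slow_key(self):
        # Stands in for the expensive relabelled key: every index name
        # maps to the same canonical form.
        ToyExpr.slow_key_computations += 1
        return ("slow", "2*u[idx0]")


def evaluate(expr, cache):
    """Two-level lookup: fast key first, then slow key, else 'compile'."""
    fast_key = expr.fast_key
    arguments = cache.get(fast_key)
    if arguments is None:
        slow_key = expr.slow_key
        arguments = cache.get(slow_key)
        if arguments is None:
            arguments = "kernel for %r" % (slow_key,)  # placeholder for compilation
            cache[slow_key] = arguments
        # Patched behaviour: also record the result under the fast key,
        # so the next evaluation of this expression skips slow_key.
        cache[fast_key] = arguments
    return arguments


cache = {}
a, b = ToyExpr("i"), ToyExpr("j")
evaluate(a, cache)  # full miss: computes the slow key, stores both keys
evaluate(b, cache)  # fast miss, slow hit: stores b's fast key
evaluate(b, cache)  # fast hit: the slow key is not recomputed
print(len(cache), ToyExpr.slow_key_computations)  # 3 2
```

The final cache size of 3 mirrors the `len(u._expression_cache) == 3` assertion in the suggested test.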