firedrakeproject / tsfc

Two-stage form compiler

gem: make ComponentTensor broadcast expression for missing indices #280

Closed ksagiyam closed 2 years ago

ksagiyam commented 2 years ago

Attempt to fix the Firedrake test failures: https://github.com/firedrakeproject/firedrake/actions/runs/2767636395

Firedrake tests: https://github.com/firedrakeproject/firedrake/pull/2511

wence- commented 2 years ago

This is usually the wrong approach for broadcasting in gem. Can you pull apart why this happened?

ksagiyam commented 2 years ago

So when we pass Zero to ComponentTensor, the constructor just returns Zero of the right shape, but, if we pass Literal(array_of_zeros), this check https://github.com/firedrakeproject/tsfc/blob/3aff5744614afa0cb3dee0436db11044b24a98dd/gem/gem.py#L664 fails. So I naively added another path to handle broadcasting.
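The difference between the two paths can be sketched with a toy model. This is not gem itself: `Index`, `Zero`, `Indexed` and `component_tensor` below are hypothetical stand-ins that only mimic the behaviour described in the comment above (zero folding first, then the free-index check), using the index extents reported later in this thread.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Index:
    """Toy free index with a fixed extent (hypothetical, not gem.Index)."""
    name: str
    extent: int


@dataclass(frozen=True)
class Zero:
    """Toy zero tensor: carries a shape and never has free indices."""
    shape: tuple = ()
    free_indices: tuple = ()


@dataclass(frozen=True)
class Indexed:
    """Toy scalar expression that only records which indices are free."""
    free_indices: tuple
    shape: tuple = ()


def component_tensor(expression, multiindex):
    """Toy constructor: turn the free indices in `multiindex` into shape."""
    shape = tuple(index.extent for index in multiindex)
    # Zero folding: a zero expression just becomes a zero of the new shape.
    if isinstance(expression, Zero):
        return Zero(shape)
    # Mirrors the check that fails in the traceback: every index being
    # converted to shape must be free in the expression.
    if not set(multiindex) <= set(expression.free_indices):
        raise AssertionError("multiindex contains indices not free in expression")
    return ("ComponentTensor", expression, multiindex)


beta = (Index("i86", 1), Index("i87", 2))
zeta = (Index("i88", 2),)

# Path 1: a Zero expression is folded, so the check is never reached.
folded = component_tensor(Zero(), beta + zeta)

# Path 2: an indexed array of zeros only has zeta free, so beta is missing.
try:
    component_tensor(Indexed(free_indices=zeta), beta + zeta)
    raised = False
except AssertionError:
    raised = True
```

In this sketch the first call returns a zero of shape `(1, 2, 2)`, while the second raises because `beta` is not among the free indices of the expression.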

ksagiyam commented 2 years ago

Just noticed that there was a typo in the package name: https://github.com/firedrakeproject/firedrake/pull/2406/files, so we overlooked this failure.

wence- commented 2 years ago

So when we pass Zero to ComponentTensor, the constructor just returns Zero of the right shape, but, if we pass Literal(array_of_zeros), this check

So why is that literal array with non-zero shape not indexed with the free indices that we're trying to turn into shape with ComponentTensor?

ksagiyam commented 2 years ago

I think a Firedrake MFE is something like this:

from firedrake import *

base = UnitIntervalMesh(1)
mesh = ExtrudedMesh(base, 1)
hel = FiniteElement("DG", "interval", 0)
vel = FiniteElement("CG", "interval", 1)
prod = HDiv(TensorProductElement(hel, vel))
V = FunctionSpace(mesh, prod)
v = TestFunction(V)
assemble(inner(as_tensor([[1, 1],[1, 1]]), grad(v))*dx)

This example fails with:

  File "/home/ksagiyam/current/firedrake/src/tsfc/tsfc/fem.py", line 629, in translate_argument
    table = ctx.entity_selector(callback, mt.restriction)
  File "/home/ksagiyam/current/firedrake/src/tsfc/tsfc/fem.py", line 100, in entity_selector
    return callback(self.entity_ids[0])
  File "/home/ksagiyam/current/firedrake/src/tsfc/tsfc/fem.py", line 617, in callback
    finat_dict = ctx.basis_evaluation(element, mt, entity_id)
  File "/home/ksagiyam/current/firedrake/src/tsfc/tsfc/fem.py", line 289, in basis_evaluation
    return finat_element.basis_evaluation(mt.local_derivatives,
  File "/home/ksagiyam/current/firedrake/src/FInAT/finat/hdivcurl.py", line 71, in basis_evaluation
    return self._transform_evaluation(core_eval)
  File "/home/ksagiyam/current/firedrake/src/FInAT/finat/hdivcurl.py", line 66, in _transform_evaluation
    return {alpha: promote(table)
  File "/home/ksagiyam/current/firedrake/src/FInAT/finat/hdivcurl.py", line 66, in <dictcomp>
    return {alpha: promote(table)
  File "/home/ksagiyam/current/firedrake/src/FInAT/finat/hdivcurl.py", line 64, in promote
    return gem.ComponentTensor(gem.Indexed(u, zeta), beta + zeta)
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/gem.py", line 48, in __call__
    obj = super(NodeMeta, self).__call__(*args, **kwargs)
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/gem.py", line 664, in __new__
    assert set(multiindex) <= set(expression.free_indices)
AssertionError

The values of objects around https://github.com/FInAT/FInAT/blob/a5cb72f607627140b519d02a63213979c2b53c46/finat/hdivcurl.py#L61 right before ComponentTensor complains are:

beta                           : (Index(86), Index(87))
beta[0].extent                 : 1
beta[1].extent                 : 2
zeta                           : (Index(88),)
zeta[0].extent                 : 2
table                          : Zero((1, 2))
v                              : Zero(())
v.shape                        : ()
v.free_indices                 : ()
self.transform(v)              : [Zero(()), Zero(())]
u                              : Literal(array([0., 0.]))
u.shape                        : (2,)
u.free_indices                 : ()
Indexed(u, zeta)               : Indexed(Literal(array([0., 0.])), (Index(88),))

So, when we do v = gem.partial_indexed(table, beta), we lose the indices beta, and ComponentTensor eventually complains about them being missing.
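The index bookkeeping behind this can be written out with plain sets. The sketch below is not gem code: the helper `free_after_indexing` is hypothetical, the index names are taken from the dump above, and the rule "indexing a zero tensor yields a scalar zero with no free indices" is an assumption inferred from the reported `v.free_indices : ()`.

```python
beta = ("i86", "i87")
zeta = ("i88",)


def free_after_indexing(aggregate_is_zero, aggregate_free, multiindex):
    """Free indices of a fully indexed aggregate in this toy model.

    Assumption: indexing a zero tensor folds to a scalar zero, which
    drops the indices it was indexed with.
    """
    if aggregate_is_zero:
        return frozenset()
    return frozenset(aggregate_free) | frozenset(multiindex)


def missing_indices(table_is_zero):
    # v = partial_indexed(table, beta)
    v_free = free_after_indexing(table_is_zero, (), beta)
    # u is built from transform(v); it has shape, and inherits v's free indices.
    u_free = v_free
    # Indexed(u, zeta)
    expr_free = free_after_indexing(False, u_free, zeta)
    # ComponentTensor(expr, beta + zeta) needs all of beta + zeta to be free.
    wanted = frozenset(beta + zeta)
    return wanted - expr_free


lost_with_zero_table = missing_indices(table_is_zero=True)
lost_with_nonzero_table = missing_indices(table_is_zero=False)
```

With a zero table, both entries of `beta` are missing by the time the tensor is constructed; with a non-zero table nothing is missing, which matches the observation that the failure only appears for zero tables.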

wence- commented 2 years ago

So I thought I had changed things such that the basis evaluation call wouldn't return a literal Zero, so that select_expression would work. Did I constant fold the zeros too early? If table were a gem.Literal this would probably work?

ksagiyam commented 2 years ago

If we make the table gem.Literal (maybe we need to change this: https://github.com/FInAT/FInAT/blob/a5cb72f607627140b519d02a63213979c2b53c46/finat/fiat_elements.py#L139), the example runs longer, but we then hit:

  File "/home/ksagiyam/current/firedrake/src/tsfc/tsfc/fem.py", line 629, in translate_argument
    table = ctx.entity_selector(callback, mt.restriction)
  File "/home/ksagiyam/current/firedrake/src/tsfc/tsfc/fem.py", line 100, in entity_selector
    return callback(self.entity_ids[0])
  File "/home/ksagiyam/current/firedrake/src/tsfc/tsfc/fem.py", line 624, in callback
    square = fiat_to_ufl(filtered_dict, mt.local_derivatives)
  File "/home/ksagiyam/current/firedrake/src/tsfc/tsfc/fem.py", line 606, in fiat_to_ufl
    tensor, = constant_fold_zero([tensor])
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/optimise.py", line 184, in constant_fold_zero
    return [mapper(e) for e in exprs]
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/optimise.py", line 184, in <listcomp>
    return [mapper(e) for e in exprs]
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 204, in __call__
    result = self.function(node, self)
  File "/usr/lib/python3.8/functools.py", line 875, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 235, in reuse_if_untouched
    new_children = list(map(self, node.children))
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 204, in __call__
    result = self.function(node, self)
  File "/usr/lib/python3.8/functools.py", line 875, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 235, in reuse_if_untouched
    new_children = list(map(self, node.children))
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 204, in __call__
    result = self.function(node, self)
  File "/usr/lib/python3.8/functools.py", line 875, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 235, in reuse_if_untouched
    new_children = list(map(self, node.children))
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 204, in __call__
    result = self.function(node, self)
  File "/usr/lib/python3.8/functools.py", line 875, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 239, in reuse_if_untouched
    return node.reconstruct(*new_children)
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 53, in reconstruct
    return type(self)(*self._cons_args(args))
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/gem.py", line 48, in __call__
    obj = super(NodeMeta, self).__call__(*args, **kwargs)
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/gem.py", line 664, in __new__
    assert set(multiindex) <= set(expression.free_indices)
AssertionError

I think I need to look at the details.

Maybe this is not directly related to this error, but translate_argument calls ctx.entity_selector, which, in some cases, calls gem.select_expression; but it seems that the callback defined inside translate_argument is actually called before gem.select_expression is called and this callback calls fiat_to_ufl, which calls constant_fold_zero.
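The ordering described above follows from Python evaluating arguments eagerly. A minimal sketch, assuming the selector has roughly the shape below; `callback`, `select_expression` and `entity_selector` here are hypothetical stand-ins, not the tsfc or gem functions:

```python
calls = []  # records the order in which the pieces run


def callback(entity_id):
    # Stand-in for the per-entity work described above (basis evaluation,
    # fiat_to_ufl, constant_fold_zero).
    calls.append(f"callback({entity_id})")
    return f"table[{entity_id}]"


def select_expression(expressions, index):
    # Stand-in for selecting one expression out of several.
    calls.append("select_expression")
    return expressions[index]


def entity_selector(callback, entity_ids, index):
    # With a single entity there is nothing to select.
    if len(entity_ids) == 1:
        return callback(entity_ids[0])
    # The list of callback results is built before select_expression runs.
    return select_expression([callback(e) for e in entity_ids], index)


result = entity_selector(callback, [0, 1], index=1)
```

Every callback has already run, including any folding it does internally, by the time the selection sees the expressions. Deferring that work would mean moving it out of the callback rather than reordering the calls.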

wence- commented 2 years ago

OK, I guess this makes sense. If the constant folding can be pushed later, I guess that is fine (I tried to keep the change localised, but obviously this is an example problematic case). The other approach is to make select_expression smarter, but I ran out of brains to figure that out at the time.

wence- commented 2 years ago

I think that #275 should probably be reverted until this issue is fixed (so that firedrake is not broken).

ksagiyam commented 2 years ago

Yes. constant_fold_zero() being called before select_expression() seems independent of the error posted above, but I cannot figure out why.

How do we properly revert specific commits? Could you do that?

ksagiyam commented 2 years ago

https://github.com/firedrakeproject/tsfc/pull/281

ksagiyam commented 2 years ago

Another observation:

If we merely remove constant_fold_zero() in fiat_to_ufl(), tests/extrusion/test_cylinder.py::test_betti2_cylinder hangs in complex mode.

wence- commented 2 years ago

If we merely remove constant_fold_zero() in fiat_to_ufl(), tests/extrusion/test_cylinder.py::test_betti2_cylinder hangs in complex mode.

I bet that is because no zero-simplification happens and the kernel takes forever in loopy?
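To illustrate the kind of simplification being guessed at here, the following toy folder removes zero terms from a small expression tree. It is only a sketch under stated assumptions: the tuple-based expressions and the `fold_zero` and `size` helpers are hypothetical and unrelated to gem's actual constant_fold_zero implementation.

```python
def fold_zero(expr):
    """Recursively drop zero terms from ("sum" | "prod", lhs, rhs) trees."""
    if not isinstance(expr, tuple):
        return expr  # leaf: a symbol name or the literal 0
    op, lhs, rhs = expr
    lhs, rhs = fold_zero(lhs), fold_zero(rhs)
    if op == "prod" and (lhs == 0 or rhs == 0):
        return 0  # anything times zero is zero
    if op == "sum":
        if lhs == 0:
            return rhs  # 0 + x == x
        if rhs == 0:
            return lhs  # x + 0 == x
    return (op, lhs, rhs)


def size(expr):
    """Number of nodes in the expression tree."""
    if not isinstance(expr, tuple):
        return 1
    return 1 + size(expr[1]) + size(expr[2])


# 0 * (a + b) + c * (0 + d)
expr = ("sum", ("prod", 0, ("sum", "a", "b")), ("prod", "c", ("sum", 0, "d")))
folded = fold_zero(expr)
```

Here the tree shrinks from 11 nodes to 3. Without such folding, every zero term would be passed on to later stages, which is the effect suspected above.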

ksagiyam commented 2 years ago

I think I fixed it; please see https://github.com/firedrakeproject/tsfc/pull/282.