ksagiyam closed this 2 years ago.
This is usually the wrong approach for broadcasting in gem. Can you pull apart why this happened?
So when we pass `Zero` to `ComponentTensor`, the constructor just returns `Zero` of the right shape, but if we pass `Literal(array_of_zeros)`, this check https://github.com/firedrakeproject/tsfc/blob/3aff5744614afa0cb3dee0436db11044b24a98dd/gem/gem.py#L664 fails. So I naively added another path to handle broadcasting.
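To make the failing check concrete, here is a toy model in plain Python. The classes are hypothetical stand-ins, not gem's real `Zero`, `Literal`, `Indexed` or `ComponentTensor`: they track only shapes and free indices, and `component_tensor` mimics just the two constructor behaviours described above (the `Zero` short-circuit and the `set(multiindex) <= set(expression.free_indices)` assertion). The index numbers and extents are taken from the debugging output later in this thread.

```python
# Toy model of the ComponentTensor constructor check. These classes are
# hypothetical stand-ins that only track shape and free indices; they are
# not gem's real node types.


class Index:
    def __init__(self, name, extent):
        self.name = name
        self.extent = extent

    def __repr__(self):
        return f"Index({self.name})"


class Zero:
    def __init__(self, shape=()):
        self.shape = shape
        self.free_indices = ()  # a symbolic zero carries no free indices


class Literal:
    def __init__(self, shape):
        self.shape = shape  # stands for an array of zeros with this shape
        self.free_indices = ()


class Indexed:
    def __init__(self, aggregate, multiindex):
        assert len(multiindex) == len(aggregate.shape)
        self.shape = ()
        # Indexing with free indices makes them free indices of the result.
        self.free_indices = aggregate.free_indices + tuple(multiindex)


def component_tensor(expression, multiindex):
    """Turn the free indices `multiindex` of a scalar expression into shape."""
    shape = tuple(i.extent for i in multiindex)
    if isinstance(expression, Zero):
        # Zero short-circuit: return a zero of the right shape, no check.
        return Zero(shape)
    # The check that fails in the report.
    assert set(multiindex) <= set(expression.free_indices)
    return ("ComponentTensor", shape)


def raises_assertion(fn, *args):
    """Return True if calling fn(*args) raises AssertionError."""
    try:
        fn(*args)
    except AssertionError:
        return True
    return False


beta = (Index(86, 1), Index(87, 2))
zeta = (Index(88, 2),)

# Zero passes silently; Indexed(Literal(zeros), zeta) only has zeta free.
print(component_tensor(Zero(), beta + zeta).shape)  # (1, 2, 2)
print(raises_assertion(component_tensor,
                       Indexed(Literal((2,)), zeta), beta + zeta))  # True
```

The point of the toy is that the `Zero` path never looks at free indices, so a missing `beta` only surfaces once a non-`Zero` node reaches the assertion.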
Just noticed that there was a typo in the package name (https://github.com/firedrakeproject/firedrake/pull/2406/files), so we overlooked this.
So why is that literal array with non-zero shape not indexed with the free indices that we're trying to turn into shape with `ComponentTensor`?
I think a Firedrake MFE is something like this:
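```python
from firedrake import *

# One-cell interval mesh, extruded by one layer
base = UnitIntervalMesh(1)
mesh = ExtrudedMesh(base, 1)

# HDiv-ified tensor product of DG0 (horizontal) and CG1 (vertical)
hel = FiniteElement("DG", "interval", 0)
vel = FiniteElement("CG", "interval", 1)
prod = HDiv(TensorProductElement(hel, vel))
V = FunctionSpace(mesh, prod)
v = TestFunction(V)

# Assembling against grad(v) triggers the failure below
assemble(inner(as_tensor([[1, 1], [1, 1]]), grad(v))*dx)
```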
This example fails with:
```
  File "/home/ksagiyam/current/firedrake/src/tsfc/tsfc/fem.py", line 629, in translate_argument
    table = ctx.entity_selector(callback, mt.restriction)
  File "/home/ksagiyam/current/firedrake/src/tsfc/tsfc/fem.py", line 100, in entity_selector
    return callback(self.entity_ids[0])
  File "/home/ksagiyam/current/firedrake/src/tsfc/tsfc/fem.py", line 617, in callback
    finat_dict = ctx.basis_evaluation(element, mt, entity_id)
  File "/home/ksagiyam/current/firedrake/src/tsfc/tsfc/fem.py", line 289, in basis_evaluation
    return finat_element.basis_evaluation(mt.local_derivatives,
  File "/home/ksagiyam/current/firedrake/src/FInAT/finat/hdivcurl.py", line 71, in basis_evaluation
    return self._transform_evaluation(core_eval)
  File "/home/ksagiyam/current/firedrake/src/FInAT/finat/hdivcurl.py", line 66, in _transform_evaluation
    return {alpha: promote(table)
  File "/home/ksagiyam/current/firedrake/src/FInAT/finat/hdivcurl.py", line 66, in <dictcomp>
    return {alpha: promote(table)
  File "/home/ksagiyam/current/firedrake/src/FInAT/finat/hdivcurl.py", line 64, in promote
    return gem.ComponentTensor(gem.Indexed(u, zeta), beta + zeta)
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/gem.py", line 48, in __call__
    obj = super(NodeMeta, self).__call__(*args, **kwargs)
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/gem.py", line 664, in __new__
    assert set(multiindex) <= set(expression.free_indices)
AssertionError
```
The values of objects around https://github.com/FInAT/FInAT/blob/a5cb72f607627140b519d02a63213979c2b53c46/finat/hdivcurl.py#L61, right before `ComponentTensor` complains, are:
```
beta : (Index(86), Index(87))
beta[0].extent : 1
beta[1].extent : 2
zeta : (Index(88),)
zeta[0].extent : 2
table : Zero((1, 2))
v : Zero(())
v.shape : ()
v.free_indices : ()
self.transform(v) : [Zero(()), Zero(())]
u : Literal(array([0., 0.]))
u.shape : (2,)
u.free_indices : ()
Indexed(u, zeta) : Indexed(Literal(array([0., 0.])), (Index(88),))
```
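So, when we do `v = gem.partial_indexed(table, beta)`, we lose the indices `beta`: `table` is `Zero((1, 2))`, so the result is just `Zero(())` with no free indices, and that missing `beta` is what `ComponentTensor` eventually complains about.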
So I thought I had changed things such that the basis evaluation call wouldn't return a literal `Zero`, so that `select_expression` would work. Did I constant fold the zeros too early? If `table` were a `gem.Literal`, this would probably work?
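The index loss can be reproduced in the same toy style. These are again hypothetical stand-ins rather than gem's API; the one folding rule assumed is the one visible in the dump above, namely that indexing a symbolic `Zero` gives a scalar `Zero(())` with no free indices. Since `beta` has extents (1, 2), matching `table.shape`, `partial_indexed(table, beta)` indexes every axis here, so the toy only models full indexing. It also assumes that `u = ListTensor(self.transform(v))` inherits the free indices of `v`. A `Literal` table, by contrast, keeps `beta` free, which is consistent with the guess that a `gem.Literal` table would get past this point.

```python
# Toy model of why `beta` disappears. Hypothetical stand-ins, not gem.


class Index:
    def __init__(self, name, extent):
        self.name = name
        self.extent = extent


class Zero:
    def __init__(self, shape=()):
        self.shape = shape
        self.free_indices = ()


class Literal:
    def __init__(self, shape):
        self.shape = shape
        self.free_indices = ()


class Indexed:
    def __init__(self, aggregate, multiindex):
        self.shape = ()
        self.free_indices = aggregate.free_indices + tuple(multiindex)


def index_all_axes(tensor, multiindex):
    """Index every axis of `tensor`, as partial_indexed(table, beta) does here."""
    assert tuple(i.extent for i in multiindex) == tensor.shape
    if isinstance(tensor, Zero):
        # Assumed folding rule: indexing a zero gives a scalar zero, and the
        # indices are dropped along with the tensor.
        return Zero(())
    return Indexed(tensor, multiindex)


beta = (Index(86, 1), Index(87, 2))
zeta = (Index(88, 2),)

v_zero = index_all_axes(Zero((1, 2)), beta)        # like the dump: Zero(())
v_literal = index_all_axes(Literal((1, 2)), beta)  # a Literal table instead

# Free indices of Indexed(u, zeta) are v's free indices plus zeta (assuming
# u inherits v's free indices); ComponentTensor needs all of beta + zeta.
wanted = set(beta + zeta)
print(wanted <= set(v_zero.free_indices + zeta))     # False
print(wanted <= set(v_literal.free_indices + zeta))  # True
```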
If we make `table` a `gem.Literal` (maybe we need to change this: https://github.com/FInAT/FInAT/blob/a5cb72f607627140b519d02a63213979c2b53c46/finat/fiat_elements.py#L139), the example runs longer, but we then hit:
```
  File "/home/ksagiyam/current/firedrake/src/tsfc/tsfc/fem.py", line 629, in translate_argument
    table = ctx.entity_selector(callback, mt.restriction)
  File "/home/ksagiyam/current/firedrake/src/tsfc/tsfc/fem.py", line 100, in entity_selector
    return callback(self.entity_ids[0])
  File "/home/ksagiyam/current/firedrake/src/tsfc/tsfc/fem.py", line 624, in callback
    square = fiat_to_ufl(filtered_dict, mt.local_derivatives)
  File "/home/ksagiyam/current/firedrake/src/tsfc/tsfc/fem.py", line 606, in fiat_to_ufl
    tensor, = constant_fold_zero([tensor])
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/optimise.py", line 184, in constant_fold_zero
    return [mapper(e) for e in exprs]
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/optimise.py", line 184, in <listcomp>
    return [mapper(e) for e in exprs]
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 204, in __call__
    result = self.function(node, self)
  File "/usr/lib/python3.8/functools.py", line 875, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 235, in reuse_if_untouched
    new_children = list(map(self, node.children))
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 204, in __call__
    result = self.function(node, self)
  File "/usr/lib/python3.8/functools.py", line 875, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 235, in reuse_if_untouched
    new_children = list(map(self, node.children))
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 204, in __call__
    result = self.function(node, self)
  File "/usr/lib/python3.8/functools.py", line 875, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 235, in reuse_if_untouched
    new_children = list(map(self, node.children))
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 204, in __call__
    result = self.function(node, self)
  File "/usr/lib/python3.8/functools.py", line 875, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 239, in reuse_if_untouched
    return node.reconstruct(*new_children)
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/node.py", line 53, in reconstruct
    return type(self)(*self._cons_args(args))
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/gem.py", line 48, in __call__
    obj = super(NodeMeta, self).__call__(*args, **kwargs)
  File "/home/ksagiyam/current/firedrake/src/tsfc/gem/gem.py", line 664, in __new__
    assert set(multiindex) <= set(expression.free_indices)
AssertionError
```
I think I need to look at the details.
Maybe this is not directly related to this error, but `translate_argument` calls `ctx.entity_selector`, which in some cases calls `gem.select_expression`. However, it seems that the `callback` defined inside `translate_argument` is actually called before `gem.select_expression` is called, and this `callback` calls `fiat_to_ufl`, which calls `constant_fold_zero`.
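The ordering follows from Python's evaluation rules: the arguments of a call such as `gem.select_expression(...)` are evaluated before the call itself, so if the multi-entity path maps `callback` over the entities, every callback (and hence every `constant_fold_zero`) has already run when selection starts. Below is a simplified, hypothetical stand-in for that control flow; the function bodies are placeholders rather than tsfc's code, and the multi-entity branch is an assumption, since the traceback only shows the single-entity `return callback(self.entity_ids[0])` path.

```python
# Simplified stand-ins that only record the order in which they run.
trace = []


def constant_fold_zero(table):
    trace.append("constant_fold_zero")
    return table


def callback(entity_id):
    # Stands for the callback inside translate_argument, which goes
    # through fiat_to_ufl and therefore folds zeros per entity.
    trace.append(f"callback({entity_id})")
    return constant_fold_zero(f"table_{entity_id}")


def select_expression(expressions, index):
    trace.append("select_expression")
    return ("select", tuple(expressions), index)


def entity_selector(callback, entity_ids, index):
    if len(entity_ids) == 1:
        return callback(entity_ids[0])
    # The list is built (all callbacks run) before select_expression is entered.
    return select_expression(list(map(callback, entity_ids)), index)


entity_selector(callback, [0, 1], "facet")
print(trace)
# ['callback(0)', 'constant_fold_zero', 'callback(1)', 'constant_fold_zero', 'select_expression']
```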
OK, I guess this makes sense. If the constant folding can be pushed later, I guess that is fine (I tried to keep the change localised, but obviously this is an example of a problematic case). The other approach is to make `select_expression` smarter, but I ran out of brains to figure that out at the time.
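One way to read "push the constant folding later" is to let the callbacks return unfolded tables and fold once, after selection. The toy below contrasts what the selection step receives in each ordering. It rests on explicit assumptions: tables are plain tuples, `fold` mimics only the all-zeros-to-`Zero` part of constant folding, `select` just builds a tuple node, and the idea that selection copes better with uniform inputs is taken from the earlier comment, not verified here. None of these names are tsfc or gem functions.

```python
# Hypothetical toy: compare eager folding (per entity, before selection)
# with deferred folding (once, after selection).

ZERO = "Zero"  # marker for a symbolic zero


def fold(table):
    """Toy constant folding: an all-zero table becomes the ZERO marker."""
    return ZERO if all(x == 0 for x in table) else table


def select(expressions):
    """Toy stand-in for selection: build a symbolic 'select' node."""
    return ("select", tuple(expressions))


def fold_select(node):
    """Deferred folding: a select whose branches are all zero is zero."""
    if all(fold(e) == ZERO for e in node[1]):
        return ZERO
    return node


def node_kinds(node):
    """Report which kinds of node the selection was handed."""
    return sorted({"Zero" if e == ZERO else "Literal" for e in node[1]})


tables = [(0.0, 0.0), (0.0, 1.0)]  # one all-zero entity, one non-zero

eager = select([fold(t) for t in tables])   # fold first, then select
deferred = fold_select(select(tables))      # select first, then fold
print(node_kinds(eager))     # ['Literal', 'Zero']: mixed node kinds
print(node_kinds(deferred))  # ['Literal']: uniform input
```

Deferring does not give up the simplification entirely: when every entity's table is zero, `fold_select` still collapses the whole selection to a single zero.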
I think that #275 should probably be reverted until this issue is fixed (so that Firedrake is not broken).
Yes. `constant_fold_zero()` being called before `select_expression()` seems independent of the error posted above, but I cannot figure out why.
How do we properly revert specific commits? Could you do that?
Another observation:
If we merely remove `constant_fold_zero()` in `fiat_to_ufl()`, `tests/extrusion/test_cylinder.py::test_betti2_cylinder` hangs in complex mode.
I bet that is because no zero-simplification happens and the kernel takes forever in loopy?
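A toy illustration of why skipping zero simplification can make a kernel much bigger: with folding, products that contain a known zero vanish and sums drop zero terms, so only the surviving terms reach code generation. This is a hypothetical mini-simplifier on nested tuples, not gem's `constant_fold_zero`, and whether this is what makes Loopy hang remains the guess stated above.

```python
# Expressions are nested tuples: ("+", a, b), ("*", a, b), a symbol name,
# or the number 0 for a known zero.


def fold_zero(expr):
    """Propagate known zeros bottom-up through sums and products."""
    if not isinstance(expr, tuple):
        return expr
    op, a, b = expr
    a, b = fold_zero(a), fold_zero(b)
    if op == "*" and (a == 0 or b == 0):
        return 0
    if op == "+":
        if a == 0:
            return b
        if b == 0:
            return a
    return (op, a, b)


def size(expr):
    """Count nodes, as a rough proxy for generated-code size."""
    if not isinstance(expr, tuple):
        return 1
    return 1 + size(expr[1]) + size(expr[2])


expr = ("+", ("*", "u0", 0), ("+", ("*", "u1", "w"), ("*", 0, "u2")))
print(size(expr), fold_zero(expr))  # 11 ('*', 'u1', 'w')
```

With many basis functions and quadrature points, the unsimplified tree grows with every zero entry that is kept, which is the kind of blow-up the comment above suspects.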
I think I fixed it; please see https://github.com/firedrakeproject/tsfc/pull/282.
Attempt to fix Firedrake tests failure https://github.com/firedrakeproject/firedrake/actions/runs/2767636395
Firedrake tests: https://github.com/firedrakeproject/firedrake/pull/2511