stephankramer opened this issue 1 year ago
Ah yes, good catch. I agree that the more generic approach would probably be most suitable.
Hi @acse-ej321, I see this has been assigned to you. Have you maybe started working on it? :)
This is relevant to my problem, where I have a time-dependent constant a and where I also need to take care of @stephankramer's comment https://github.com/pyroteus/goalie/issues/55#issuecomment-1702280655. To solve this, I need 1) a way to pass the simulation time to get_form, and 2) a way to indicate that we are now inside indicate_errors, so that I can make sure the solution at the correct timestep is passed.
I have a working solution locally which looks something like this:
```python
def get_form(self):
    def form(index, sols, **kwargs):
        if "err_ind_time" in kwargs:  # i.e. we are inside indicate_errors
            # Update s, which depends on sols
            s = get_s(sols)
            # Update a, which depends on time
            time = kwargs["err_ind_time"]
            a = get_a(time)
        else:  # not inside indicate_errors: the user passes the "s" and "a" Functions from get_solver
            s = kwargs["s"]
            a = kwargs["a"]
        solver_fields = {"u": sols["u"], "h": sols["h"][0], "h_": sols["h"][1], "s": s, "a": a}
        F_h = get_form_for_h(**solver_fields)
        if "err_ind_time" in kwargs:
            solver_fields["h"] = sols["h"][1]  # This takes care of Stephan's comment (#55)
        F_u = get_form_for_u(**solver_fields)
        return {"u": F_u, "h": F_h}

    return form
```
In indicate_errors, I then did:
```python
tp = self.time_partition
# Simulation time for step j of subinterval i
err_ind_time = tp.subintervals[i][0] + j * tp.timesteps[i]
forms = mesh_seq_e.form(i, mapping, err_ind_time=err_ind_time)
```
and I have to call this at every exported timestep, i.e. I rebuild the forms each time (since "auxiliary fields" are not transferred).
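To make the control flow of the kwargs-based form callback concrete without Firedrake, here is a minimal pure-Python sketch of the same keyword dispatch. get_s and get_a are hypothetical stand-ins working on plain floats, and the sketch assumes one export per timestep, so that exported step j sits at t_start + j * dt, as in the time computation above.

```python
def get_s(sols):
    # Hypothetical stand-in: s is derived from the current solution
    return 2.0 * sols["h"]


def get_a(time):
    # Hypothetical stand-in: a is a prescribed function of time
    return 10.0 * time


def form(index, sols, **kwargs):
    if "err_ind_time" in kwargs:
        # Inside the error indicator: rebuild s and a ourselves
        s = get_s(sols)
        a = get_a(kwargs["err_ind_time"])
    else:
        # Forward solve: the caller supplies s and a from get_solver
        s = kwargs["s"]
        a = kwargs["a"]
    return {"s": s, "a": a}


def indicator_times(t_start, dt, num_exports):
    # Time of each exported step, assuming one export per timestep
    return [t_start + j * dt for j in range(num_exports)]


times = indicator_times(t_start=0.0, dt=0.25, num_exports=4)
fields = [form(0, {"h": 1.0}, err_ind_time=t) for t in times]
print([f["a"] for f in fields])  # [0.0, 2.5, 5.0, 7.5]
```

The same callback serves both callers: the forward solver passes s and a explicitly, while the indicator only passes a time and lets the form rebuild everything else.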
What do you think? I think it's a good solution in general, so I would like to get it up on the main branch before I make a release and cite goalie. But I'm happy to keep the issue open and explore the other solutions that Stephan mentioned!
Edit: @jwallwork23 I actually see now that I can make a release for a branch other than main too. Is that a good idea?
@ddundo I have a local edit that follows the first bullet point in @stephankramer's comment https://github.com/pyroteus/goalie/issues/55#issuecomment-1702280655, doing a similar hack in a slightly different way.
I define t in my get_solver and update it in the time stepper:
```python
t = fd.Constant(t_start)
...
while t_ < t_end - 1.0e-05:
    fd.solve(F == 0, u, bcs=bcs, ad_block_tag="u")
    if self.qoi_type == "time_integrated":
        self.J += qoi(t)
    u_.assign(u)
    t_ += dt
    t.assign(t_)
return solution_map
```
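The 1.0e-05 tolerance in the loop condition guards against floating-point drift in t_ += dt, and the mechanism relies on the form closing over the same mutable time object that the stepper updates. A minimal pure-Python sketch of both points follows; Clock is a hypothetical stand-in for fd.Constant (a mutable scalar with an assign method), and velocity stands in for a UFL expression depending on t.

```python
class Clock:
    """Hypothetical stand-in for fd.Constant: a mutable scalar with assign()."""

    def __init__(self, value):
        self.value = value

    def assign(self, value):
        self.value = value


def run(t_start, t_end, dt, tol):
    t = Clock(t_start)

    def velocity():
        # Closes over t, so it sees every t.assign(), like a UFL Constant would
        return 2.0 * t.value

    t_ = t_start
    seen = []
    while t_ < t_end - tol:
        seen.append(velocity())  # stands in for solving at the current time
        t_ += dt
        t.assign(t_)
    return seen


print(len(run(0.0, 1.0, 0.1, tol=1.0e-05)))  # 10
print(len(run(0.0, 1.0, 0.1, tol=0.0)))  # 11
```

Without the tolerance the loop takes an extra step, because 0.1 summed ten times gives 0.9999999999999999, which is still less than 1.0.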
And in my get_form, I pass t as a Constant from get_solver:
```python
def get_form(self):
    def form(index, solutions, t, field="u"):
        ...  # rest of the form definition omitted
```
And like you, in go_mesh_seq.py I did something similar in indicate_errors:
```python
import inspect  # needed for the signature check below

for f, fs_e in enriched_spaces.items():
    u[f] = Function(fs_e)
    u_[f] = Function(fs_e)
    mapping[f] = (u[f], u_[f])
    u_star[f] = Function(fs_e)
    u_star_next[f] = Function(fs_e)
    u_star_e[f] = Function(fs_e)

# Get forms for each equation in enriched space
t_ = Constant(1.0)  # ej321
t0_ = mesh_seq_e.time_partition.subintervals[i][0]  # ej321
dt_ = mesh_seq_e.time_partition.timesteps[i]  # ej321

# ej321 - hack for time-dependent forms: pass t_ only if form() takes a third argument
def num_of_args(func):
    # Count positional parameters without defaults (len(args) on func itself is always 1)
    return sum(
        p.default is inspect.Parameter.empty
        and p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        for p in inspect.signature(func).parameters.values()
    )

if num_of_args(mesh_seq_e.form) == 3:
    forms = mesh_seq_e.form(i, mapping, t_)  # ej321
else:
    forms = mesh_seq_e.form(i, mapping)  # ej321 - the original
if not isinstance(forms, dict):
    raise TypeError(
        "The function defined by get_form should return a dictionary"
        f", not type '{type(forms)}'."
    )

# Loop over each strongly coupled field
for f in self.fields:
    # Loop over each timestep
    for j in range(len(self.solutions[f]["forward"][i])):
        t_.assign(t0_ + j * dt_)  # ej321: set the time for this step
        # Update fields
        transfer(self.solutions[f][FWD][i][j], u[f])
```
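A standalone check of signature-based dispatch like the num_of_args helper. Note that len(args) with a single positional argument always returns 1, so counting a callback's parameters needs inspect.signature. The two form functions and the call_form helper below are hypothetical stand-ins mirroring the signatures used in this thread.

```python
import inspect


def required_positional(func):
    """Count positional parameters that have no default value."""
    kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    return sum(
        p.default is inspect.Parameter.empty and p.kind in kinds
        for p in inspect.signature(func).parameters.values()
    )


def form_with_time(index, solutions, t, field="u"):
    return {"u": (index, t)}


def form_without_time(index, solutions):
    return {"u": index}


def form_with_kwargs(index, sols, **kwargs):
    return {"u": index}


def call_form(form, index, mapping, t):
    # Pass the time object only to callbacks that declare a third required argument
    if required_positional(form) == 3:
        return form(index, mapping, t)
    return form(index, mapping)


print(required_positional(form_with_time))  # 3
print(required_positional(form_without_time))  # 2
print(call_form(form_with_time, 0, {}, 0.5))  # {'u': (0, 0.5)}
```

Restricting the count to positional parameters keeps a **kwargs-style callback (like the one earlier in this thread) from being mistaken for a time-aware one.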
I think your workaround is closer to the more general case Stephan was suggesting. I would be happy with whatever solution you and @jwallwork23 come up with; please feel free to switch ownership.
Thanks @acse-ej321! Nice idea to avoid changing the interface! I think we essentially do the same thing in get_solver too, but I define a NonlinearVariationalSolver outside the for loop, so I don't need to pass time to get_form unless we are in indicate_errors.
I will push this later and get an example up so we can discuss :)
> I actually see now that I can make a release for a branch other than main too. Is that a good idea?
Yeah, it is fairly common practice to create releases of branches, although try not to do this for branches that have diverged very far from main.
Another idea: give your Functions names when you set up the form/solver:
```python
f = Function(space, name="forcing1")
```
Then stash their update rules in a dict, keyed by name:
```python
def update_forcing1(t):
    ...

{"forcing1": update_forcing1}
```
This dict could then be used by Goalie so that it knows how to update each Function in time. Similarly for constants, by choosing space = FunctionSpace(mesh, "R", 0).
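A rough sketch of how such a name-keyed dict could be consumed. NamedField is a hypothetical stand-in exposing name() and assign() like a Firedrake Function, and update_time_dependent is a hypothetical helper, not an existing Goalie function. In real code an updater would more likely interpolate into the Function than return a float.

```python
class NamedField:
    """Hypothetical stand-in for a Firedrake Function with name() and assign()."""

    def __init__(self, name, value=0.0):
        self._name = name
        self.value = value

    def name(self):
        return self._name

    def assign(self, value):
        self.value = value


def update_forcing1(t):
    # Hypothetical prescribed forcing as a function of time
    return 3.0 * t


updaters = {"forcing1": update_forcing1}


def update_time_dependent(fields, updaters, t):
    # What a library could do at each exported timestep: look up each field
    # by name and, if the user registered an updater, evaluate it at time t
    for f in fields:
        if f.name() in updaters:
            f.assign(updaters[f.name()](t))


forcing = NamedField("forcing1")
other = NamedField("viscosity", value=1.0)
update_time_dependent([forcing, other], updaters, t=0.5)
print(forcing.value, other.value)  # 1.5 1.0
```

Fields without a registered updater are left untouched, so only the prescribed, time-dependent ones are refreshed.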
In https://github.com/mesh-adaptation/goalie/pull/137#issuecomment-2254171102 we agreed to take this approach:
Here is the idea I had for how we could do this automatically, without requiring extra effort from the user.
We could identify these changing fields by extracting them with the coefficients method of the variational form. I made a small gist to help demonstrate: https://gist.github.com/ddundo/92fc7d9fd24471a37c5a903ddd035554 (I labelled the important part towards the end). Here it's all inside the "user code", but it could all be done really nicely in the goalie code; I have a clear idea of how to do this :)

So we could extract and save a copy of these fields at each exported timestep, before then using them in indicate_errors. This is potentially memory-intensive, so I also suggest the following. At the moment, goalie is set up so that we solve the forward and adjoint problems over all subintervals, and only then call indicate_errors. I think that we should instead call indicate_errors after each individual subinterval. That way:
- we only need to store these changing fields for the export timesteps in the current subinterval, rather than over the whole time interval
- similarly for the forward_old, adjoint and adjoint_next fields, which we are now storing over the whole time interval: unless we want to return them (we could add a kwarg for that), we could drop them after each call to indicate_errors and free up memory after each subinterval
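A toy model of the memory argument, counting how many field snapshots are held at once. Each snapshot is represented by a (subinterval, export) tuple rather than a real Function, and compute_indicators is a hypothetical placeholder for the error-indicator step; the point is only the loop ordering.

```python
def compute_indicators(snapshots):
    # Hypothetical placeholder for the error-indicator computation
    return len(snapshots)


def peak_snapshots(num_subintervals, exports_per_subinterval, per_subinterval=True):
    """Return the largest number of snapshots held in memory at any time."""
    stored = []
    peak = 0
    for i in range(num_subintervals):
        # Forward/adjoint solve on subinterval i, stashing a copy of each
        # changing field at every export (stand-in: a (subinterval, export) tag)
        for j in range(exports_per_subinterval):
            stored.append((i, j))
            peak = max(peak, len(stored))
        if per_subinterval:
            compute_indicators(stored)
            stored.clear()  # drop snapshots once this subinterval is done
    if not per_subinterval:
        compute_indicators(stored)
    return peak


print(peak_snapshots(4, 5, per_subinterval=False))  # 20
print(peak_snapshots(4, 5, per_subinterval=True))  # 5
```

Calling the indicator after each subinterval bounds storage by the exports of one subinterval instead of the whole time interval.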
This came up looking at the "bubble shear" case, which has a time-dependent advective velocity implemented through a symbolic expression that includes a Constant t, which is updated in the forward model every timestep. However, when doing goal-oriented enriched DWR, this Constant is not being updated, so the residual is evaluated at the wrong time level.

Discussing this with @acse-ej321, we came up with various possible solutions:
- add an argument t to form(), to which you would provide the Constant t that will be used in the form. When pyroteus then evaluates the indicator expressions, it would provide its own instance of a Constant t and set it to the right value in the inner loop. The downside is that this changes the interface: all form() callbacks in existing demos would need to take an argument t
- keep t as a Constant on the mesh sequence or time partition, where the convention would be that the form() callback should use it for time, and pyroteus would take care of setting it to the right value (although it would still be up to the user to increment it in the solver() callback)

The latter option is an attempt to deal with the more general case where entire fields might be time-dependently prescribed (and not in the form of a UFL-symbolic expression dependent on t), e.g. a prescribed (boundary) forcing field that is interpolated every timestep.
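A minimal sketch of the wrong-time-level problem and of holding t as a Constant on the mesh sequence. TimeConstant and ToyMeshSeq are hypothetical pure-Python stand-ins, not pyroteus/goalie classes, and velocity = 1 + t is an arbitrary placeholder for the bubble-shear expression.

```python
class TimeConstant:
    """Hypothetical stand-in for a Firedrake Constant holding the time."""

    def __init__(self, value=0.0):
        self.value = value

    def assign(self, value):
        self.value = value


class ToyMeshSeq:
    """Hypothetical mesh sequence owning the time Constant used by form()."""

    def __init__(self, t_start):
        self.t = TimeConstant(t_start)

    def form(self):
        def velocity():
            # Time-dependent advective velocity, reading the shared Constant
            return 1.0 + self.t.value

        return velocity


def indicator_velocities(mesh_seq, t0, dt, num_steps, update_time):
    velocity = mesh_seq.form()
    values = []
    for j in range(num_steps):
        if update_time:
            # The library, not the user, sets t for each step of the inner loop
            mesh_seq.t.assign(t0 + j * dt)
        values.append(velocity())
    return values


seq = ToyMeshSeq(t_start=0.0)
seq.t.assign(2.0)  # the forward run leaves t at its final value
print(indicator_velocities(seq, 0.0, 0.5, 3, update_time=False))  # [3.0, 3.0, 3.0]
print(indicator_velocities(seq, 0.0, 0.5, 3, update_time=True))  # [1.0, 1.5, 2.0]
```

Without the update, every indicator step sees the final forward time, which is exactly the stale-Constant symptom described above.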