mesh-adaptation / goalie

Goal-oriented error estimation and mesh adaptation for finite element problems solved using Firedrake

How to deal with time-dependent Constants and prescribed fields in forms #73

Open stephankramer opened 1 year ago

stephankramer commented 1 year ago

This came up while looking at the "bubble shear" case, which has a time-dependent advective velocity implemented through a symbolic expression containing a Constant t that the forward model updates every timestep. However, when doing goal-oriented enriched DWR this Constant is not updated, so the residual is evaluated at the wrong time level.
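To make the failure mode concrete, here is a toy sketch in plain Python, so it runs without Firedrake. `Constant` is a hypothetical stand-in for Firedrake's mutable `Constant`, and `make_velocity` stands in for a UFL expression containing `t`. Neither is taken from the bubble shear code.

```python
class Constant:
    """Hypothetical stand-in for a mutable Firedrake Constant."""

    def __init__(self, value):
        self.value = float(value)

    def assign(self, value):
        self.value = float(value)


def make_velocity(t):
    """Velocity depending on the time Constant t.

    The closure reads t.value when evaluated, mimicking how an expression
    containing a Constant picks up the Constant's current value.
    """
    def velocity(x):
        return x * (1.0 + t.value)
    return velocity


dt, num_steps = 0.1, 5

# Forward model: t is reassigned every timestep, so the velocity is correct
t_forward = Constant(0.0)
u_forward = make_velocity(t_forward)
forward_values = []
for j in range(1, num_steps + 1):
    t_forward.assign(j * dt)
    forward_values.append(u_forward(1.0))

# Enriched copy used for error estimation: a fresh Constant nobody updates
t_enriched = Constant(0.0)
u_enriched = make_velocity(t_enriched)
stale_values = [u_enriched(1.0) for _ in range(num_steps)]

# Fix: reassign the enriched time Constant before each evaluation
fixed_values = []
for j in range(1, num_steps + 1):
    t_enriched.assign(j * dt)
    fixed_values.append(u_enriched(1.0))
```

Every entry of `stale_values` corresponds to the initial time, which is exactly the "wrong time level" symptom described above.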

Discussing this with @acse-ej321 we came up with various possible solutions

The latter option is an attempt to deal with the more general case where entire fields might be prescribed as functions of time (rather than as a UFL symbolic expression depending on t), e.g. a prescribed (boundary) forcing field that is interpolated every timestep.

jwallwork23 commented 1 year ago

Ah yes, good catch. I agree that the more generic approach would probably be most suitable.

ddundo commented 6 months ago

Hi @acse-ej321, I see this has been assigned to you - have you started working on it maybe? :)

It's relevant for my problem, where I have a time-dependent constant a and where I also need to take care of @stephankramer's https://github.com/pyroteus/goalie/issues/55#issuecomment-1702280655. To solve this, I need 1) a way to pass the simulation time to get_form and 2) a way to indicate that we are inside indicate_errors, so that I can make sure the solution at the correct timestep is passed.

I have a working solution locally which looks something like this:

def get_form(self):
    def form(index, sols, **kwargs):
        if "err_ind_time" in kwargs:  # i.e. we are inside indicate_errors
            # Recompute s, which depends on sols
            s = get_s(sols)

            # Recompute a, which depends on time
            time = kwargs["err_ind_time"]
            a = get_a(time)

        else:  # not inside indicate_errors: the user passes the "s" and "a" Functions from get_solver
            s = kwargs["s"]
            a = kwargs["a"]

        solver_fields = {
            "u": sols["u"],
            "h": sols["h"][0],
            "h_": sols["h"][1],
            "s": s,
            "a": a,
        }

        F_h = get_form_for_h(**solver_fields)

        if "err_ind_time" in kwargs:
            # Use the lagged h in the u form (this takes care of Stephan's comment on #55)
            solver_fields["h"] = sols["h"][1]
        F_u = get_form_for_u(**solver_fields)

        return {"u": F_u, "h": F_h}

    return form

In indicate_errors I then did

tp = self.time_partition
err_ind_time = tp.subintervals[i][0] + j * tp.timesteps[i]
forms = mesh_seq_e.form(i, mapping, err_ind_time=err_ind_time)

I have to call this for every exported timestep, so I rebuild the forms each time (since "auxiliary fields" are not transferred).
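The time bookkeeping can be isolated in a small helper. This is a minimal sketch, assuming solutions are exported at evenly spaced timesteps and that export j = 0 sits at the start of the subinterval, as in the snippet above. `export_times`, `timesteps_per_export` and the dummy `form` are hypothetical names for illustration, not Goalie API.

```python
def export_times(t_start, dt, num_exports, timesteps_per_export=1):
    """Simulation time associated with each exported solution on one subinterval.

    Follows err_ind_time = t_start + j * dt, generalised by a hypothetical
    timesteps_per_export factor for runs that do not export every timestep.
    """
    return [t_start + j * timesteps_per_export * dt for j in range(num_exports)]


def form(index, sols, **kwargs):
    """Hypothetical stand-in for mesh_seq_e.form: records the time it was built for."""
    return {"u": ("F_u", kwargs.get("err_ind_time"))}


# Rebuild the forms at every exported time
rebuilt = [form(0, {}, err_ind_time=t) for t in export_times(0.0, 0.5, 3)]
```

If the run does not export at every timestep, using plain `j * dt` would silently evaluate the forms at the wrong times, which is why the factor is kept explicit here.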

What do you think? I think it's a good solution in general, so I would like to get it up on the main branch before I make a release and cite goalie. But I'm happy to keep the issue open and explore other solutions that Stephan mentions!

Edit: @jwallwork23 I actually see now that I can make a release for a branch other than main too. Is that a good idea?

acse-ej321 commented 6 months ago

@ddundo I have a local edit following the first bullet point in @stephankramer's https://github.com/pyroteus/goalie/issues/55#issuecomment-1702280655 doing a similar hack in a slightly different way.

I define t in my get_solver and update it in the time stepper:

t = fd.Constant(t_start)
...
while t_ < t_end - 1.0e-05:
    fd.solve(F == 0, u, bcs=bcs, ad_block_tag="u")
    if self.qoi_type == "time_integrated":
        self.J += qoi(t)
    u_.assign(u)
    t_ += dt
    t.assign(t_)
return solution_map

And in my get_form I pass t, the Constant defined in get_solver:

def get_form(self):
    def form(index, solutions, t, field="u"):

And, like you, I did something similar in indicate_errors in go_mesh_seq.py:

for f, fs_e in enriched_spaces.items():
    u[f] = Function(fs_e)
    u_[f] = Function(fs_e)
    mapping[f] = (u[f], u_[f])
    u_star[f] = Function(fs_e)
    u_star_next[f] = Function(fs_e)
    u_star_e[f] = Function(fs_e)

# Get forms for each equation in enriched space
t_ = Constant(1.0)  # ej321
t0_ = mesh_seq_e.time_partition.subintervals[i][0]  # ej321
dt_ = mesh_seq_e.time_partition.timesteps[i]  # ej321

# ej321 - hack for time-dependent forms: pass t_ only if the form accepts a t argument
# (requires `import inspect`; counting the arguments passed to a helper such as
# def num_of_args(*args): return len(args) always gives 1, so it cannot detect this)
if "t" in inspect.signature(mesh_seq_e.form).parameters:
    forms = mesh_seq_e.form(i, mapping, t_)  # ej321
else:
    forms = mesh_seq_e.form(i, mapping)  # ej321 - the original

if not isinstance(forms, dict):
    raise TypeError(
        "The function defined by get_form should return a dictionary"
        f", not type '{type(forms)}'."
    )

# Loop over each strongly coupled field
for f in self.fields:
    # Loop over each timestep
    for j in range(len(self.solutions[f]["forward"][i])):
        t_.assign(t0_ + j * dt_)  # ej321
        # Update fields
        transfer(self.solutions[f][FWD][i][j], u[f])
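Detecting whether a user's `form` accepts a time argument is easiest with the standard library's `inspect.signature`, which reports the parameters of the function itself. A minimal sketch, assuming the convention that the time Constant is a parameter named `t`. `call_form` and the two example forms are hypothetical.

```python
import inspect


def call_form(form, index, mapping, t):
    """Call a user-supplied form, passing the time argument only if the form accepts it.

    Assumes (hypothetically) that time-aware forms name their time parameter "t".
    """
    if "t" in inspect.signature(form).parameters:
        return form(index, mapping, t)
    return form(index, mapping)


def form_without_time(index, solutions):
    return ("original", index)


def form_with_time(index, solutions, t, field="u"):
    return ("time-dependent", index, t)


old_result = call_form(form_without_time, 0, {}, 1.5)
new_result = call_form(form_with_time, 0, {}, 1.5)
```

Checking for a named parameter is more robust than counting parameters, since a defaulted argument such as `field="u"` changes the count.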

I think your workaround is closer to the more general case Stephan was suggesting. I would be happy with whatever solution you and @jwallwork23 come up with - please feel free to switch ownership.

ddundo commented 6 months ago

Thanks @acse-ej321! Nice idea to avoid changing the interface! I think we essentially do the same thing in get_solver too, but I define a NonlinearVariationalSolver outside the for loop, so I don't need to pass time to get_form unless we are in indicate_errors.

I will push this later and get an example up so we can discuss :)

jwallwork23 commented 6 months ago

I actually see now that I can make a release for a branch other than main too. Is that a good idea?

Yeah it is fairly common practice to create releases of branches, although try not to do this for branches that are very far from main.

jwallwork23 commented 1 month ago

Another idea:

Give your Functions names when you set up the form/solver:

f = Function(space, name="forcing1")

Then stash them in a dict:

def update_forcing1(t):
    ...

{"forcing1": update_forcing1}

which can then be used by Goalie so it knows how to update the Function in time. Similarly for constants, choosing space = FunctionSpace(mesh, "R", 0).
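One way Goalie could consume such a registry is sketched below in plain Python. `Function` is a hypothetical stand-in that mimics Firedrake's `Function.name()` and `assign()`, and `update_prescribed` is an illustrative name, not an existing Goalie function.

```python
import math


class Function:
    """Hypothetical stand-in for a named Firedrake Function holding one value."""

    def __init__(self, value=0.0, name=None):
        self.value = value
        self._name = name

    def name(self):
        return self._name

    def assign(self, value):
        self.value = value


def update_forcing1(t):
    """User-supplied rule giving the prescribed forcing at time t (illustrative)."""
    return math.sin(t)


# User side: map Function names to update rules
updaters = {"forcing1": update_forcing1}


def update_prescribed(fields, updaters, t):
    """Framework side: reassign every registered Function for time t; skip the rest."""
    for f in fields:
        rule = updaters.get(f.name())
        if rule is not None:
            f.assign(rule(t))


forcing1 = Function(name="forcing1")
viscosity = Function(1.0, name="viscosity")  # not registered, so left untouched
update_prescribed([forcing1, viscosity], updaters, 0.5)
```

Keying on names means the framework only needs the user to name the Functions consistently; it never has to know how each field is computed.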

ddundo commented 2 weeks ago

In https://github.com/mesh-adaptation/goalie/pull/137#issuecomment-2254171102 we agreed to take this approach:

Here is the idea that I had how we could do this automatically, without requiring extra effort from the user.

We could identify these changing fields and extract them from the coefficients method of the variational form. I made a small gist to help demonstrate: https://gist.github.com/ddundo/92fc7d9fd24471a37c5a903ddd035554 - I labelled the important part towards the end. Here it's all inside the "user code", but it could all be done really nicely in the goalie code. I have a clear idea how to do this nicely :)
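A plain-Python sketch of the extraction step: keep every coefficient of the form that is not a solution field, and snapshot it at each export. `Form` and `Function` are hypothetical stand-ins. UFL forms do provide a `coefficients()` method, and in Firedrake the copy would presumably use `Function.copy(deepcopy=True)`. The helper names are illustrative and not taken from the gist.

```python
import copy


class Function:
    """Hypothetical stand-in for a named Firedrake Function."""

    def __init__(self, value, name):
        self.value = value
        self._name = name

    def name(self):
        return self._name


class Form:
    """Hypothetical stand-in for a UFL Form, exposing coefficients() as UFL does."""

    def __init__(self, *coefficients):
        self._coefficients = coefficients

    def coefficients(self):
        return self._coefficients


def auxiliary_coefficients(form, solution_names):
    """Coefficients of the form that are not solution fields."""
    return [c for c in form.coefficients() if c.name() not in solution_names]


def snapshot(coefficients):
    """Independent copies of the coefficients, keyed by name."""
    return {c.name(): copy.deepcopy(c) for c in coefficients}


u, u_old = Function(0.0, "u"), Function(0.0, "u_old")
forcing = Function(0.0, "forcing")
F = Form(u, u_old, forcing)

aux = auxiliary_coefficients(F, solution_names={"u", "u_old"})
snapshots = []
for j in range(3):  # three export times
    forcing.value = float(j)  # the forward model overwrites the field in place
    snapshots.append(snapshot(aux))
```

Without the copies, every stored entry would alias the same object and only the final value would survive.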

So we could extract and save a copy of these fields at each exported timestep, and then use them in indicate_errors. This is potentially memory-intensive, so I also suggest the following. At the moment, goalie is set up so that we solve the forward and adjoint problems over all subintervals, and only then call indicate_errors. I think that we should instead call indicate_errors after each individual subinterval. That way:

  • we only need to store these changing fields for the export timesteps in the subinterval, rather than over the whole time interval
  • similarly for forward_old, adjoint and adjoint_next fields which we are now storing over the whole time interval: unless we want to return them (could add a kwarg for it), we could drop them after each call to indicate_errors and free up memory after each subinterval
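A minimal sketch of the proposed ordering, in plain Python. It assumes, as is usual for adjoint methods, that the adjoint sweep runs backwards over the subintervals, so indicate_errors for subinterval i would be called right after that subinterval's adjoint solve. All function names here are hypothetical.

```python
def adjoint_sweep_with_indicators(num_subintervals, solve_adjoint_on, indicate_errors_on):
    """Estimate errors subinterval by subinterval during the backward adjoint sweep.

    solve_adjoint_on(i) returns the fields stored for the exports of subinterval i
    (e.g. forward, adjoint and auxiliary fields); indicate_errors_on(i, stored) turns
    them into an indicator. Storage is dropped before moving to the next subinterval.
    """
    indicators = [None] * num_subintervals
    max_held = 0
    for i in reversed(range(num_subintervals)):  # adjoint runs backwards in time
        stored = solve_adjoint_on(i)
        max_held = max(max_held, sum(len(v) for v in stored.values()))
        indicators[i] = indicate_errors_on(i, stored)
        del stored  # drop this subinterval's storage before moving on
    return indicators, max_held


call_order = []


def solve_adjoint_on(i):
    call_order.append(i)
    # Three exports per subinterval, two stored fields each (illustrative values)
    return {"forcing": [i] * 3, "adjoint": [10 * i] * 3}


def indicate_errors_on(i, stored):
    return sum(stored["forcing"]) + sum(stored["adjoint"])


indicators, max_held = adjoint_sweep_with_indicators(4, solve_adjoint_on, indicate_errors_on)
```

With 4 subintervals, at most 6 stored entries are held at once, rather than 24 if the fields for every subinterval were kept until the end.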