aai-institute / pyDVL

pyDVL is a library of stable implementations of algorithms for data valuation and influence function computation
https://pydvl.org
GNU Lesser General Public License v3.0

Use `pytree` operations to simplify code in influence #592

Open janosg opened 3 weeks ago

janosg commented 3 weeks ago

What are pytrees?

While reviewing #582, I got the feeling that we have quite a bit of code that re-implements the typical pytree operations as implemented in jax, optree or my instructional library pybaum.

The pytree operations solve the following problem in a general way: in math notation we often need one-dimensional vectors, but in code we want to represent things in richer data formats (e.g. dictionaries of arrays of arbitrary dimension). A prime example is the parameters of a neural network. Pytrees are not a specific type themselves. For our purposes, any (nested) container of tensors or numbers is a pytree, and the pytree operations are defined on it.
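To make the idea concrete, here is a minimal pure-Python sketch of flattening a nested container into a flat list of leaves, rebuilding it, and mapping a function over the leaves. It is not the API of jax, optree or pybaum: the function names mirror theirs, but the implementation, the treedef representation and the `params` example are assumptions made for illustration, and only dicts, lists and tuples are treated as containers.

```python
def tree_flatten(tree):
    """Return (leaves, treedef) for nested dicts, lists and tuples."""
    if isinstance(tree, dict):
        keys = sorted(tree)  # fixed key order so flattening is deterministic
        children = [tree[k] for k in keys]
        rebuild = lambda cs: dict(zip(keys, cs))
    elif isinstance(tree, (list, tuple)):
        children = list(tree)
        rebuild = type(tree)
    else:
        return [tree], None  # anything else is a leaf

    leaves, child_defs = [], []
    for child in children:
        child_leaves, child_def = tree_flatten(child)
        leaves.extend(child_leaves)
        child_defs.append((len(child_leaves), child_def))
    return leaves, (rebuild, child_defs)


def tree_unflatten(treedef, leaves):
    """Inverse of tree_flatten: rebuild the container from a flat list."""
    if treedef is None:
        (leaf,) = leaves
        return leaf
    rebuild, child_defs = treedef
    children, pos = [], 0
    for n_leaves, child_def in child_defs:
        children.append(tree_unflatten(child_def, leaves[pos:pos + n_leaves]))
        pos += n_leaves
    return rebuild(children)


def tree_map(fn, tree):
    """Apply fn to every leaf and keep the container structure."""
    leaves, treedef = tree_flatten(tree)
    return tree_unflatten(treedef, [fn(leaf) for leaf in leaves])


# Hypothetical "network parameters" using plain numbers instead of tensors.
params = {"layer1": {"weight": [1.0, 2.0], "bias": 0.5}, "layer2": (3.0,)}

flat, treedef = tree_flatten(params)       # one-dimensional view for the math
restored = tree_unflatten(treedef, flat)   # back to the rich representation
doubled = tree_map(lambda x: 2 * x, params)
```

The flat list plays the role of the one-dimensional vector from the math notation, while `treedef` carries everything needed to get back to the nested representation.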

The most important operations are:

A full list is here

What is considered a leaf depends on the registry of containers, which can be extended by users. For example, if torch tensors are not registered as containers, tree_flatten would convert a nested dict of tensors into a list of tensors. If torch tensors are registered, it would flatten a nested dict of tensors into a list of numbers.
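The effect of the registry can be sketched without any dependencies. In the snippet below, `FakeTensor` is a hypothetical stand-in for a torch tensor, and `REGISTRY` and `flatten_leaves` are illustrative names, not part of jax, optree, pybaum or pydvl. The real libraries expose their own registration functions, which are not reproduced here.

```python
class FakeTensor:
    """Hypothetical stand-in for a torch tensor holding a flat list of numbers."""

    def __init__(self, values):
        self.values = list(values)


# Maps a container type to a function that returns its children.
REGISTRY = {
    dict: lambda d: [d[k] for k in sorted(d)],
    list: list,
    tuple: list,
}


def flatten_leaves(tree, registry):
    """Collect leaves; any type missing from the registry is a leaf."""
    to_children = registry.get(type(tree))
    if to_children is None:
        return [tree]
    leaves = []
    for child in to_children(tree):
        leaves.extend(flatten_leaves(child, registry))
    return leaves


params = {"w": FakeTensor([1.0, 2.0]), "b": FakeTensor([0.5])}

# Tensors are not registered: the leaves are the tensor objects themselves.
as_tensors = flatten_leaves(params, REGISTRY)

# Tensors are registered as containers: the leaves are the individual numbers.
extended = {**REGISTRY, FakeTensor: lambda t: t.values}
as_numbers = flatten_leaves(params, extended)
```

With the default registry the result is a list of two `FakeTensor` objects. With the extended registry it is a list of three floats.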

pydvl code that could be removed or simplified

Advantage of using pytree operations

Drawbacks of using pytree operations