bessagroup / f3dasm

Framework for Data-Driven Design & Analysis of Structures & Materials (F3DASM)
https://f3dasm.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

Creating variants of the _Data object to allow gradient tracking and avoid unnecessary casting to numpy #208

Open. mpvanderschelling opened this issue 1 year ago.

mpvanderschelling commented 1 year ago

The problem

At the moment, the ExperimentData object consists of input_data, output_data, jobs, and domain. These are all custom objects, and all of them except the Domain object are private.

Focusing on input_data: any data given to ExperimentData (e.g. a pd.DataFrame, a numpy array, or a csv file) is converted to a _Data object. The _Data object uses pandas as its back-end, so internally the data is cast to a type compatible with pandas data storage, i.e. numpy arrays.

For automatic differentiation tools this can be problematic, because the gradient needs to be 'tracked' through every operation. Any cast to numpy breaks this chain.
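A minimal illustration of the problem, assuming JAX is installed (JAX is used here only as an example autodiff library, not something f3dasm necessarily does internally): differentiating through `jax.numpy` works, while the same computation routed through a numpy cast fails, because numpy cannot hold the tracer objects JAX uses to record the computation.

```python
import jax
import jax.numpy as jnp
import numpy as np


def loss_jax(x):
    # Stays inside jax.numpy, so JAX can trace and differentiate it
    return jnp.sum(x ** 2)


def loss_via_numpy(x):
    # Casting to a numpy array mimics what a pandas-backed store does;
    # under jax.grad, x is a tracer and cannot be converted
    x_np = np.asarray(x)
    return np.sum(x_np ** 2)


x = jnp.array([1.0, 2.0, 3.0])

grad_ok = jax.grad(loss_jax)(x)  # equals 2 * x

try:
    jax.grad(loss_via_numpy)(x)
    chain_broken = False
except jax.errors.TracerArrayConversionError:
    # The cast to numpy breaks the gradient chain
    chain_broken = True

print(grad_ok, chain_broken)
```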

In v1.4.3, we use autograd.numpy to track these gradients, and for tensorflow optimizers a conversion function provides a 'custom gradient' so that they still work despite the cast to numpy.

Additionally, optimized libraries incur overhead from converting back and forth between e.g. jax arrays and numpy arrays.

Proposal

Because the ExperimentData object depends only on _Data and not directly on a pandas DataFrame, we can create a variant of the _Data object for any underlying datatype (e.g. a dictionary of tensorflow tensors). Each variant needs to implement all the methods of the _Data object for its particular datatype.
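A sketch of what such a family of variants could look like. The class and method names below (`DataBackend`, `n_samples`, `column`, `append`, `PandasData`, `ArrayDictData`) are hypothetical and are not the actual `_Data` API in f3dasm; the point is that if ExperimentData only calls the abstract methods, each backend can keep its data in its native type.

```python
from abc import ABC, abstractmethod

import numpy as np
import pandas as pd


class DataBackend(ABC):
    """Hypothetical minimal interface that ExperimentData could rely on."""

    @abstractmethod
    def n_samples(self) -> int: ...

    @abstractmethod
    def column(self, name: str): ...

    @abstractmethod
    def append(self, other: "DataBackend") -> "DataBackend": ...


class PandasData(DataBackend):
    """'Normal' backend: everything lives in a pandas DataFrame."""

    def __init__(self, df: pd.DataFrame):
        self.df = df

    def n_samples(self) -> int:
        return len(self.df)

    def column(self, name: str) -> np.ndarray:
        return self.df[name].to_numpy()

    def append(self, other: "PandasData") -> "PandasData":
        return PandasData(pd.concat([self.df, other.df], ignore_index=True))


class ArrayDictData(DataBackend):
    """Backend that keeps a dict of arrays from any array library.

    `concat` is injected (e.g. numpy.concatenate or jax.numpy.concatenate)
    so the backend itself never converts the arrays to numpy.
    """

    def __init__(self, columns: dict, concat):
        self.columns = columns
        self.concat = concat

    def n_samples(self) -> int:
        first = next(iter(self.columns.values()))
        return first.shape[0]

    def column(self, name: str):
        return self.columns[name]

    def append(self, other: "ArrayDictData") -> "ArrayDictData":
        merged = {k: self.concat([v, other.columns[k]]) for k, v in self.columns.items()}
        return ArrayDictData(merged, self.concat)


# Usage: both backends answer the same questions through the same interface
a = PandasData(pd.DataFrame({"x0": [0.1, 0.2]}))
b = ArrayDictData({"x0": np.array([0.1, 0.2])}, np.concatenate)
print(a.append(a).n_samples(), b.append(b).n_samples())  # 4 4
```

Injecting the concatenation function is one possible design choice; a real variant would more likely be a dedicated subclass per library.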

Then, when creating the ExperimentData object, the user can choose between the 'normal' backend (e.g. pandas/numpy) and a specialized backend (e.g. tensorflow, pytorch, jax).

The backend could also be inferred automatically from the type of the initial input_data.
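Automatic inference could be a simple dispatch on the type of the initial input_data. A sketch using `functools.singledispatch`; the function name `infer_backend` and the returned labels `"pandas"` and `"jax"` are hypothetical placeholders, and the JAX branch is only registered when JAX is importable.

```python
from functools import singledispatch

import numpy as np
import pandas as pd


@singledispatch
def infer_backend(data) -> str:
    # Fallback for input types that have no registered backend
    raise TypeError(f"No data backend registered for {type(data).__name__}")


@infer_backend.register
def _(data: pd.DataFrame) -> str:
    return "pandas"


@infer_backend.register
def _(data: np.ndarray) -> str:
    # Plain numpy input maps to the 'normal' pandas/numpy backend
    return "pandas"


try:
    import jax

    @infer_backend.register
    def _(data: jax.Array) -> str:
        # Keep JAX arrays in a JAX-native backend to preserve tracing
        return "jax"
except ImportError:
    pass
```

An explicit backend argument on ExperimentData could still override the inferred choice.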

First steps

This issue will investigate whether we can implement this, starting with a _Data variant that works with a jax data format.
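A first, deliberately small sketch of a JAX-backed variant, assuming JAX is installed. `JaxData`, `append`, and `to_matrix` are hypothetical names chosen for illustration, not existing f3dasm code. The toy objective at the end checks that a gradient flows from a value computed on the stored data back to the input array, which works because the store never casts to numpy.

```python
import jax
import jax.numpy as jnp


class JaxData:
    """Hypothetical _Data variant that stores columns as jax arrays."""

    def __init__(self, columns: dict):
        self.columns = columns  # name -> jax array of shape (n_samples,)

    def append(self, other: "JaxData") -> "JaxData":
        # jnp.concatenate keeps the result traceable (no numpy cast)
        return JaxData(
            {k: jnp.concatenate([v, other.columns[k]]) for k, v in self.columns.items()}
        )

    def to_matrix(self) -> jax.Array:
        # Stack columns into an (n_samples, n_columns) array, still in JAX
        return jnp.stack(list(self.columns.values()), axis=1)


def objective(x0):
    # Build a store from the design variable and evaluate a toy objective on it
    data = JaxData({"x0": x0, "x1": 2.0 * x0})
    data = data.append(data)
    return jnp.sum(data.to_matrix() ** 2)


x0 = jnp.array([1.0, 2.0])
grad = jax.grad(objective)(x0)
print(grad)
```

Here the objective equals 10 times the sum of squares of `x0`, so the gradient is `20 * x0`.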

mpvanderschelling commented 1 year ago

@SNMS95, I created an issue that might be relevant to your application of f3dasm. Feel free to add anything here that might help address it!