ksachdeva closed this 3 years ago.
Hi, thanks for your comment and feedback,
Let me start by highlighting that we believe TFP is a brilliant library that serves its users very well, and whose API has been well thought through and is rightfully the industry standard for working with probability distributions. So why did we build distrax? I hope the following helps explain our motivations.
First, TFP supports both JAX and TF as backends, but this comes at a cost: TFP components are based on a single "source of truth" written in TF, from which the JAX code is automatically generated by a script. This was a good decision from a software engineering perspective, as it minimises code maintenance and ensures consistent behaviour across backends. Indeed, we agree that if, as is often the case, a user is only interested in using or combining existing components, this solution has no noticeable downsides.
Our researchers, however, often need to modify existing components to explore new research ideas, and to iterate quickly they need to be able to do so in the language they are most familiar with (JAX). In this setting, the TFP approach to supporting both TF and JAX backends requires researchers to read, understand and write TensorFlow code.
Our solution has been to embrace most of the TFP API and strive for interoperability, while enabling researchers to work with and modify JAX-native distributions. We hope this strikes a good trade-off between the specific needs of our researchers and avoiding needless fragmentation of a community that is already well served by TFP in many other respects.
Hi,
I sincerely hope that my comments here will be seen purely in the spirit of constructive feedback.
I no longer comment (not anymore!) when I see fragmentation and duplicated implementations, but I felt (based on your README) that, unlike many engineers who systematically reinvent the wheel, you have been cognizant of this (by creating interop with TFP and clearly stating that distrax is not a replacement for TFP), and I appreciate that.
However, here is what I infer from your README:
One of the primary goals of this library is to make it easy to create custom distributions and bijectors. Indirectly, that implies it is difficult to do so in TFP.
Based on my own experience with TFP, I am somewhat inclined to agree with you. Still, wouldn't it be a better approach to make this easier within TFP itself, by modifying or augmenting the TFP APIs and/or improving its documentation and guides?
The TFP developers set out to support both JAX and TF as backends. Of course, supporting multiple backends creates its own challenges, but so far they seem to have managed it. From a pure software engineering perspective, having multiple backends, or designing layers that target different systems, is nothing new.
As much as I appreciate JAX, it is very frustrating to see the community that believes in this technology repeatedly choosing to replicate work (flax, Trax, haiku, neural-tangents, objax... and many others) instead of collaborating and finding common directions.
My apologies if this comes across as harsh in any way.
Regards,
Kapil