slimgroup / InvertibleNetworks.jl

A Julia framework for invertible neural networks
MIT License

Release type constraints #34

Closed sethaxen closed 2 years ago

sethaxen commented 2 years ago

Throughout the package, it seems numerical arguments are restricted to Float32. This prevents usage of the code with custom real numeric types. Can this constraint be relaxed to <:Real, or even no constraint at all? Likewise, in a number of places array types are constrained to Array, so the code cannot be used with custom array types. Can that likewise be relaxed? Fewer type constraints means greater potential for composability with other packages.

mloubout commented 2 years ago

Strict typing, i.e. Array{Float32}, makes Julia compile much more efficient code than generic types, since it knows at compile time what input type to expect. In practice, Float32 is the standard in ML, which is why we went with it; Float16 isn't really fully supported across the Julia codebase yet, and Float64 is usually quite overkill. A <:Real or untyped signature would be very inefficient. What type are you trying to use? We could add a dispatch for that type if users need it.

Concerning Array, most places use AbstractArray to allow for generic and custom types. We may have missed some instances of it; we will fix those. There may also be places that require Array because they depend on an external function that doesn't support AbstractArray.

sethaxen commented 2 years ago

Strict typing, i.e. Array{Float32}, makes Julia compile much more efficient code than generic types, since it knows at compile time what input type to expect. In practice, Float32 is the standard in ML, which is why we went with it; Float16 isn't really fully supported across the Julia codebase yet, and Float64 is usually quite overkill. A <:Real or untyped signature would be very inefficient.

The usual way to handle this is to ensure that your functions preserve the types of the inputs, i.e. if a user passes in a Float32 or Float16, it should not get promoted to a Float64 anywhere, but not to constrain the types. So if a user provides a Float32, they get the same compiled, efficient code, but they are not coerced into doing so. The type constraints don't make the code more efficient. Then you can document the Float32 performance recommendation. This is, for example, how Flux does it.
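To illustrate the point, here is a minimal sketch (the `scale` function is made up for this example, not from InvertibleNetworks): Julia specializes an unconstrained method for each concrete argument type it is called with, so Float32 callers still get Float32-specialized code.

```julia
# Hypothetical example: no type annotations, yet Julia compiles a separate
# specialized method instance per concrete input type, and the output eltype
# follows the input eltype.
scale(x, a) = a .* x

x32 = rand(Float32, 4)
eltype(scale(x32, 2f0))   # Float32: the input precision is preserved

x64 = rand(Float64, 4)
eltype(scale(x64, 2.0))   # Float64: no coercion in either direction
```

Writing the scalar as `2f0` (a Float32 literal) matters here; mixing a Float64 scalar with a Float32 array would promote the result under Julia's usual promotion rules.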

This is a great write-up on why types in Julia should be used only for dispatch, not for constraining arguments.

What type are you trying to use? We could add a dispatch for that type if users need it.

I'm not thinking of a specific type here, but it's not uncommon for a user to implement and use some custom numeric type like those in Unitful.jl or to try differentiating something using ForwardDiff.Dual, and if I'm writing a generic package that uses InvertibleNetworks under the hood, then even if I think they should only use Float32s, I don't want to unnecessarily constrain their types.
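As a hedged sketch of the difference (both function names here are hypothetical, not from the package): a Float32-constrained signature rejects any other element type with a MethodError, while an unconstrained one accepts custom Real subtypes without giving up specialization.

```julia
# Hypothetical comparison of a constrained vs. an unconstrained signature.
constrained_sumsq(x::AbstractArray{Float32}) = sum(abs2, x)
generic_sumsq(x::AbstractArray) = sum(abs2, x)

x = rand(Float64, 3)
# constrained_sumsq(x)  # MethodError: no method for Array{Float64}
generic_sumsq(x)        # works, and still compiles specialized Float64 code
```

The same `generic_sumsq` would accept a ForwardDiff.Dual or Unitful array element type, which is exactly the composability being asked for.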

Concerning Array, most places use AbstractArray to allow for generic and custom types. We may have missed some instances of it; we will fix those.

Great!

mloubout commented 2 years ago

The usual way to handle this is to ensure that your functions preserve the types of the inputs,

Well, that's not trivial when working with other packages and supporting multiple Julia versions. For example, on Julia 1.4, the compatible Zygote version doesn't always return the loss input type for backpropagation. So unless we enforce the type, the whole AD doesn't work for Julia <1.6.
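One possible middle ground, sketched here with a hypothetical helper (this is not what the package currently does), is to coerce results back to the caller's eltype internally instead of constraining the public signature:

```julia
# Hypothetical helper: if an AD backend promoted the gradient (e.g. a Zygote
# version returning Float64 for a Float32 loss), convert it back to the input
# eltype rather than restricting the user-facing signature to Float32.
restore_eltype(g::AbstractArray, x::AbstractArray) =
    eltype(g) === eltype(x) ? g : convert(AbstractArray{eltype(x)}, g)

x = rand(Float32, 3)
g = Float64.(2 .* x)           # stand-in for a promoted gradient
eltype(restore_eltype(g, x))   # Float32
```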

Now, we understand that it may be a bit too strong, and we will look into making it better, but at the time, due to some dependencies and support for older Julia versions, we had to go that way.

I had a check for Array. Where did it cause you problems? There are very few instances of it, most being for Parameters or for dispatch against CuArray.

sethaxen commented 2 years ago

I had a check for Array. Where did it cause you problems? There are very few instances of it, most being for Parameters or for dispatch against CuArray.

I didn't hit any errors from this, but I noticed a few while looking through the code. For example, this definition: https://github.com/slimgroup/InvertibleNetworks.jl/blob/e22195dfb251794d1a686205dc20a20f13672281/src/utils/objective_functions.jl#L26

Well, that's not trivial when working with other packages and supporting multiple Julia versions. For example, on Julia 1.4, the compatible Zygote version doesn't always return the loss input type for backpropagation. So unless we enforce the type, the whole AD doesn't work for Julia <1.6.

Now, we understand that it may be a bit too strong, and we will look into making it better, but at the time, due to some dependencies and support for older Julia versions, we had to go that way.

I understand that. Arbitrary type promotion is a problem in Zygote. ChainRules v1 introduced machinery that causes cotangents of Float32 values to always be converted back to Float32. When Zygote uses that (see https://github.com/FluxML/Zygote.jl/pull/1044), perhaps this will no longer be a concern.

mloubout commented 2 years ago

Likewise, in a number of places array types are constrained to Array,

This should be fixed after #37.

For Float32, we will see when we can switch to more type-stable layers but will have to stick to Float32 for now. We'll let you know when we've made progress on that front.

mloubout commented 2 years ago

#34 should fix this issue, hope it helps.

EDIT: once done with bug fixes

mloubout commented 2 years ago

@sethaxen should be all fixed now, hope it helps your software.