ACEsuit / Polynomials4ML.jl

Polynomials for ML: fast evaluation, batching, differentiation
MIT License

WIP: CachedArray in Linear layer and migrating to `unwrap` from `parent` #58

Closed CheukHinHoJerry closed 1 year ago

CheukHinHoJerry commented 1 year ago

Idea is to release the input (possibly cached array) with something like this:

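function (l::LinearLayer{false})(x::AbstractMatrix, ps, st)
   # acquire the output buffer (a CachedArray) from the layer's pool
   out = acquire!(st.pool, :bA, (size(x, 1), l.out_dim), eltype(x))
   mul!(out, parent(x), transpose(ps.W))
   # the input is no longer needed here, so release it
   release!(x)
   return out, st
end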

It also acquires a CachedArray for the output that can be released later on. We also have the use_cache option, as in PolyLuxLayer, for the user to determine whether to use tmp or cache.
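For readers unfamiliar with the pattern, here is a minimal self-contained sketch of the acquire / multiply / release idea. Everything in it (`ToyPool`, `ToyCached`, `toy_linear`, and these particular `acquire!`, `release!` and `unwrap` methods) is a hypothetical stand-in written for illustration; it is not the ObjectPools.jl implementation, and the `acquire!` signature here drops the `:bA` key used in the snippet above.

```julia
using LinearAlgebra: mul!

# Hypothetical toy pool (NOT the ObjectPools.jl implementation):
# free buffers are stored per (eltype, size) key.
struct ToyPool
    free::Dict{Any, Vector{Any}}
end
ToyPool() = ToyPool(Dict{Any, Vector{Any}}())

# Wrapper that remembers which pool its buffer came from.
struct ToyCached{T} <: AbstractMatrix{T}
    data::Matrix{T}
    pool::ToyPool
end
Base.size(A::ToyCached) = size(A.data)
Base.getindex(A::ToyCached, i::Int, j::Int) = A.data[i, j]
Base.setindex!(A::ToyCached, v, i::Int, j::Int) = (A.data[i, j] = v)

# Explicit accessor for the underlying buffer, instead of overloading Base.parent.
unwrap(A::ToyCached) = A.data
unwrap(A::AbstractArray) = A   # plain arrays pass through unchanged

# Reuse a free buffer of the right type and size if one exists, else allocate.
function acquire!(pool::ToyPool, sz::NTuple{2,Int}, ::Type{T}) where {T}
    stack = get!(pool.free, (T, sz), Any[])
    data = isempty(stack) ? Matrix{T}(undef, sz) : pop!(stack)
    return ToyCached{T}(data, pool)
end

# Hand the buffer back to the pool it came from.
function release!(A::ToyCached{T}) where {T}
    push!(get!(A.pool.free, (T, size(A.data)), Any[]), A.data)
    return nothing
end
release!(::AbstractArray) = nothing   # nothing to return for ordinary arrays

# Sample-first linear map: acquire the output, multiply, release the input.
function toy_linear(pool::ToyPool, W::AbstractMatrix, x::AbstractMatrix)
    out = acquire!(pool, (size(x, 1), size(W, 1)), eltype(x))
    mul!(unwrap(out), unwrap(x), transpose(W))
    release!(x)
    return out
end

pool = ToyPool()
W = randn(5, 10)
x = acquire!(pool, (7, 10), Float64)
unwrap(x) .= 1.0
y = toy_linear(pool, W, x)   # x must not be used after this call
```

The caller gives up ownership of `x` once the layer has consumed it, so the next layer that needs a buffer of the same type and size can reuse that memory instead of allocating.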

cortner commented 1 year ago

what's the status of this PR?

CheukHinHoJerry commented 1 year ago

I think this is basically done. I just want to finalise it after we have the unwrap function in ObjectPools.jl since D told me that he had some problem with the Adjoint. Please give me some extra hours tonight to finish it off.

cortner commented 1 year ago

Then you can take care of #59 at the same time? I wonder though whether this is a matter for a separate PR, and then we merge the current one afterward?

CheukHinHoJerry commented 1 year ago

Yes - thank you for raising that. I will take care of that too.

CheukHinHoJerry commented 1 year ago

Hmm - not sure why only the TmpArray interface CI test on Julia nightly is failing; the LinearLayer tests even emit the deprecation warning for parent. Otherwise I think it should be ready to merge. Will take a look later today.

CheukHinHoJerry commented 1 year ago

Still not sure why - @cortner could you please take a look when you have time?

cortner commented 1 year ago

will try now.

cortner commented 1 year ago

odd - everything passes for 1.10-alpha1 on my machine i.e. I can't reproduce it locally. This basically means I don't want to waste our time on it :)

edit : on rerunning the CI the tests passed ok. odd.

cortner commented 1 year ago

@CheukHinHoJerry -- I've tested, reviewed and am ready to merge it. Are you comfortable with me registering this as a patch release, since it only changes the internals but not the interface? But if you strongly prefer, we can treat it as breaking.

CheukHinHoJerry commented 1 year ago

Yes, I think a patch release is good enough. Thank you.

Update: taking another look at the CI, it still shows the warning for using parent. But I would still prefer to merge it anyway.

cortner commented 1 year ago

@CheukHinHoJerry -- I can now reproduce those warnings on J1.10-beta1, but they are gone on nightly. So probably an intermediate bug. I'll have to explore this a bit more before merging and registering.

CheukHinHoJerry commented 1 year ago

Thanks a lot for letting me know. I think this is a bug in ObjectPools.jl, since the tests also failed for me locally there, as I mentioned to you last time.

cortner commented 1 year ago

thanks, I'll test this. Will also test on the latest backport 1.10 branch

cortner commented 1 year ago

I actually cannot reproduce any failures on any version of 1.9.x and I wonder whether there is any point losing sleep over this in an unreleased version of Julia. Could be a bug in Pkg for all we know.

CheukHinHoJerry commented 1 year ago

I only get an error in ObjectPools.jl when I run ‘]test’, but it works fine when I run the script directly. Maybe that also implies it is a problem with Pkg?

Given that CI passes consistently on the stable releases J1.8 and J1.9, I have no problem if you prefer to merge.

cortner commented 1 year ago

I found the issue:

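using Polynomials4ML, LuxCore, Random
P4ML = Polynomials4ML

in_d = 10; out_d = 5; N = 7
l = P4ML.LinearLayer(in_d, out_d; feature_first = false)
ps, st = LuxCore.setup(MersenneTwister(1234), l)
X1 = randn(N, in_d)
Y1_ = [ l(X1[i, :], ps, st)[1] for i = 1:N ]
hcat(Y1_...)   # this call triggers the `parent` deprecation warnings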

It appears that hcat now calls parent whereas before it did not.
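One way to check this kind of behaviour independently of P4ML is to count calls to `parent` on a throwaway wrapper type. The sketch below is only an illustration: `Wrapped`, `PARENT_CALLS` and `count_parent_calls` are hypothetical names, not part of ObjectPools.jl, and the printed count depends on the Julia version, so no expected output is given.

```julia
# Hypothetical wrapper type used only for this check;
# not the ObjectPools.jl FlexCachedArray.
struct Wrapped{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(A::Wrapped) = size(A.data)
Base.getindex(A::Wrapped, i::Int) = A.data[i]

# Count every call to `parent` on the wrapper.
const PARENT_CALLS = Ref(0)
function Base.parent(A::Wrapped)
    PARENT_CALLS[] += 1
    return A.data
end

# Run `f` and report how many times it hit our `parent` method.
function count_parent_calls(f)
    PARENT_CALLS[] = 0
    f()
    return PARENT_CALLS[]
end

vs = [Wrapped(randn(5)) for _ in 1:7]
n = count_parent_calls(() -> hcat(vs...))
println("parent was called ", n, " time(s) on Julia ", VERSION)
```

Running the same script under several Julia versions shows whether generic Base code such as `hcat` started reaching for `parent`. Using a separate accessor like `unwrap` avoids the question, because Base never calls it.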

cortner commented 1 year ago

I don't think this is a bug on our end, but it is fortuitous that we are changing to unwrap so we avoid this completely.

My suggestion : even though we did document the parent usage in P4ML, it was only for the first 0.3.0 release. Let's hold our nose for five minutes to do the following:

Are you happy with that? I think it is a very small violation of semver and we have no users of the latest versions of ObjectPools and P4ML yet so it should be minimally disruptive IF AT ALL

cortner commented 1 year ago

> I only get an error in ObjectPools.jl when I run ‘]test’, but it works fine when I run the script directly. Maybe that also implies it is a problem with Pkg?

I still can't reproduce that. Can you open a separate issue in ObjectPools with more details?

CheukHinHoJerry commented 1 year ago

> I still can't reproduce that. Can you open a separate issue in ObjectPools with more details?

Yes, I will do that now. Also I am happy with what you suggest. Removing parent should be fine (I think). If it causes other issues in other packages it wouldn't be a difficult fix.

CheukHinHoJerry commented 1 year ago

> I only get an error in ObjectPools.jl when I run ‘]test’, but it works fine when I run the script directly. Maybe that also implies it is a problem with Pkg?

> I still can't reproduce that. Can you open a separate issue in ObjectPools with more details?

I tried that with J1.9.2 and it works fine, but it is broken on J1.9.0. Maybe we can just keep that in mind? I don't think the problem is on our end.

Ref https://github.com/ACEsuit/ObjectPools.jl/issues/13#issue-1827147837

cortner commented 1 year ago

actually please don't do it yet. I had another idea.

cortner commented 1 year ago

here is what I added - will register ObjectPools 0.3.1 in a moment:

# for some reason @deprecate doesn't work here. Getting 
# error messages we don't understand. This here will do until 
# we tag the next version. 
# In Julia 1.10-beta1 it seems that `parent` is used in ways we didn't 
# expect, causing warnings all over the place, hence we keep this only for 
# older versions of Julia. 
if VERSION < v"1.10-"
   function Base.parent(pA::FlexCachedArray)
      @warn("Use of `parent` to obtain the PtrArray of a FlexCachedArray is deprecated. Use `unwrap` instead.")
      return unwrap(pA)
   end
end 
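As a side note on the `VERSION < v"1.10-"` guard: the trailing `-` makes the bound sort below every 1.10 pre-release, so the shim is skipped on alphas and betas as well as on the final 1.10 release. A small check, where `keeps_parent_shim` is a hypothetical helper introduced only for this illustration:

```julia
# Hypothetical helper mirroring the `VERSION < v"1.10-"` guard.
keeps_parent_shim(v::VersionNumber) = v < v"1.10-"

for v in (v"1.9.0", v"1.9.2", v"1.10.0-alpha1", v"1.10.0-beta1", v"1.10.0")
    println(rpad(string(v), 14), " shim defined: ", keeps_parent_shim(v))
end
```

A guard written as `VERSION < v"1.10"` would behave differently: pre-releases such as `v"1.10.0-beta1"` compare as lower than `v"1.10"`, so the warning-emitting method would still be defined on exactly the version where the warnings showed up.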

cortner commented 1 year ago

Once the new ObjectPools is registered I will add a version bound here, test, merge and register.

CheukHinHoJerry commented 1 year ago

Thank you for your help - much messier than I expected.