rejuvyesh / PyCallChainRules.jl

Differentiate python calls from Julia
MIT License

Use `dlpack` for array interop #10

Closed: rejuvyesh closed this issue 2 years ago

rejuvyesh commented 2 years ago

Depends on https://github.com/pabloferz/DLPack.jl/issues/7
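For context, the array interop this issue proposes looks roughly like the following; a minimal sketch assuming DLPack.jl's `wrap`/`share` API together with PyCall (exact signatures may have differed in the DLPack.jl version of that era):

```julia
using DLPack
using PyCall

torch  = pyimport("torch")
dlpack = pyimport("torch.utils.dlpack")

# Python -> Julia: view a torch tensor as a Julia array, no copy.
pt = torch.arange(1, 11)
jl = DLPack.wrap(pt, o -> @pycall dlpack.to_dlpack(o)::PyObject)

# Julia -> Python: share a Julia array with torch, no copy.
x  = rand(Float32, 3, 2)
py = DLPack.share(x, PyObject, dlpack.from_dlpack)
```

Both directions are zero-copy: the two sides view the same memory, which is what makes lifetime management (see the GC discussion below) matter.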

rejuvyesh commented 2 years ago
```julia
julia> judge(median(results), median(no_dlpack[1]))
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "pytorchhub" => 4-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "bs=16" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: []
                  "forward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(+18.40% => regression)
                          "functorch" => TrialJudgement(+15.16% => regression)
                          "jl" => TrialJudgement(+585.51% => regression)
                  "backward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(+10.65% => regression)
                          "functorch" => TrialJudgement(+244.91% => regression)
                          "jl" => TrialJudgement(+426.57% => regression)
          "bs=32" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: []
                  "forward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(+136.70% => regression)
                          "functorch" => TrialJudgement(+147.69% => regression)
                          "jl" => TrialJudgement(+841.88% => regression)
                  "backward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(+122.57% => regression)
                          "functorch" => TrialJudgement(+132.32% => regression)
                          "jl" => TrialJudgement(+213.17% => regression)
          "bs=8" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: []
                  "forward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(+400.93% => regression)
                          "functorch" => TrialJudgement(+391.76% => regression)
                          "jl" => TrialJudgement(+699.24% => regression)
                  "backward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(+427.55% => regression)
                          "functorch" => TrialJudgement(+403.57% => regression)
                          "jl" => TrialJudgement(+471.85% => regression)
          "bs=1" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: []
                  "forward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(+1582.03% => regression)
                          "functorch" => TrialJudgement(+1324.41% => regression)
                          "jl" => TrialJudgement(+1247.94% => regression)
                  "backward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(+7.27% => regression)
                          "functorch" => TrialJudgement(+7.67% => regression)
                          "jl" => TrialJudgement(-14.63% => improvement)
  "pytorchmlp" => 4-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "bs=16" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: []
                  "forward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(-12.30% => improvement)
                          "functorch" => TrialJudgement(-7.23% => improvement)
                          "jl" => TrialJudgement(+19.95% => regression)
                  "backward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(-1.89% => invariant)
                          "functorch" => TrialJudgement(-2.81% => invariant)
                          "jl" => TrialJudgement(+13.12% => regression)
          "bs=32" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: []
                  "forward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(-17.17% => improvement)
                          "functorch" => TrialJudgement(-8.52% => improvement)
                          "jl" => TrialJudgement(+27.38% => regression)
                  "backward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(-2.17% => invariant)
                          "functorch" => TrialJudgement(-4.80% => invariant)
                          "jl" => TrialJudgement(+19.82% => regression)
          "bs=8" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: []
                  "forward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(-15.03% => improvement)
                          "functorch" => TrialJudgement(-6.87% => improvement)
                          "jl" => TrialJudgement(+9.58% => regression)
                  "backward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(-0.59% => invariant)
                          "functorch" => TrialJudgement(-2.66% => invariant)
                          "jl" => TrialJudgement(+6.15% => regression)
          "bs=1" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: []
                  "forward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(-14.15% => improvement)
                          "functorch" => TrialJudgement(-7.62% => improvement)
                          "jl" => TrialJudgement(+2.07% => invariant)
                  "backward" => 3-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                          "torch" => TrialJudgement(-4.48% => invariant)
                          "functorch" => TrialJudgement(-4.81% => invariant)
                          "jl" => TrialJudgement(-0.07% => invariant)

I might be doing this wrong, but DLPack does not look beneficial on CPU. Need a better (GPU) machine to evaluate properly.
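For reference, the `TrialJudgement` entries above come from BenchmarkTools' `judge`, which compares two time estimates against a tolerance (5% by default); positive time ratios beyond it are regressions, negative ones improvements. A minimal sketch with hypothetical workloads:

```julia
using BenchmarkTools

# Two hypothetical trials to compare, e.g. with and without dlpack.
with_dlpack = @benchmark sum(rand(1000))
baseline    = @benchmark sum(rand(1000))

# Compare the median estimates: the result is a TrialJudgement of
# :regression, :improvement, or :invariant, as in the table above.
j = judge(median(with_dlpack), median(baseline))
```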

rejuvyesh commented 2 years ago

Likely some issue with `GC.@preserve`. Need to figure out an MWE.
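A candidate shape for that MWE (only a sketch of the suspected failure mode, reusing the hypothetical sketch above): if nothing on the Julia side keeps the shared array rooted, the GC can free and reuse its memory while torch still views it, which would match the garbage values (denormals like `1.938506f-39`) in the failing runs below.

```julia
using DLPack, PyCall

torch  = pyimport("torch")
dlpack = pyimport("torch.utils.dlpack")

function leaky_share()
    x = rand(Float32, 2, 3)
    # `GC.@preserve x ...` would only protect `x` inside this scope;
    # nothing roots it once the Python view escapes the function.
    return DLPack.share(x, PyObject, dlpack.from_dlpack)
end

pt = leaky_share()
GC.gc()       # force a collection; `x`'s memory may now be reused
println(pt)   # may print garbage instead of the original samples
```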

```julia
julia> for i in 1:10; TestEnv.activate() do; include("test/test_pytorch.jl"); end; end
Precompiling project...
  1 dependency successfully precompiled in 2 seconds (16 already precompiled)
Test Summary: | Pass  Total
dlpack        |    4      4
┌ Warning: `vendor()` is deprecated, use `BLAS.get_config()` and inspect the output instead
│   caller = npyinitialize() at numpy.jl:67
└ @ PyCall ~/.julia/packages/PyCall/L0fLP/src/numpy.jl:67
linear: Test Failed at /home/jagupt/.julia/dev/PyCallChainRules/test/test_pytorch.jl:68
  Expression: isapprox(torchparams[i], linwrap.params[i], atol = 0.0001, rtol = 0.0001)
   Evaluated: isapprox(Float32[0.009436185 0.0996897 0.20193027; 0.38426346 -0.28908443 0.08106785], Float32[1.938506f-39 0.0 5.01105f-33; 0.0 4.5915f-41 0.0]; atol = 0.0001, rtol = 0.0001)
Stacktrace:
 [1] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:445 [inlined]
 [2] macro expansion
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:68 [inlined]
 [3] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:1283 [inlined]
 [4] top-level scope
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:59
linear: Test Failed at /home/jagupt/.julia/dev/PyCallChainRules/test/test_pytorch.jl:68
  Expression: isapprox(torchparams[i], linwrap.params[i], atol = 0.0001, rtol = 0.0001)
   Evaluated: isapprox(Float32[0.096376784, -0.57178324], Float32[8.28208f-40, 0.0]; atol = 0.0001, rtol = 0.0001)
Stacktrace:
 [1] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:445 [inlined]
 [2] macro expansion
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:68 [inlined]
 [3] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:1283 [inlined]
 [4] top-level scope
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:59
linear: Test Failed at /home/jagupt/.julia/dev/PyCallChainRules/test/test_pytorch.jl:73
  Expression: isapprox(torchparams[i], linwrap.params[i], atol = 0.0001, rtol = 0.0001)
   Evaluated: isapprox(Float32[0.009436185 0.0996897 0.20193027; 0.38426346 -0.28908443 0.08106785], Float32[1.938506f-39 0.0 5.01105f-33; 0.0 4.5915f-41 0.0]; atol = 0.0001, rtol = 0.0001)
Stacktrace:
 [1] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:445 [inlined]
 [2] macro expansion
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:73 [inlined]
 [3] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:1283 [inlined]
 [4] top-level scope
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:59
linear: Test Failed at /home/jagupt/.julia/dev/PyCallChainRules/test/test_pytorch.jl:73
  Expression: isapprox(torchparams[i], linwrap.params[i], atol = 0.0001, rtol = 0.0001)
   Evaluated: isapprox(Float32[0.096376784, -0.57178324], Float32[8.28208f-40, 0.0]; atol = 0.0001, rtol = 0.0001)
Stacktrace:
 [1] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:445 [inlined]
 [2] macro expansion
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:73 [inlined]
 [3] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:1283 [inlined]
 [4] top-level scope
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:59
linear: Test Failed at /home/jagupt/.julia/dev/PyCallChainRules/test/test_pytorch.jl:77
  Expression: isapprox(torchparams[i], linwrap.params[i], atol = 0.0001, rtol = 0.0001)
   Evaluated: isapprox(Float32[0.009436185 0.0996897 0.20193027; 0.38426346 -0.28908443 0.08106785], Float32[1.938506f-39 0.0 5.01105f-33; 0.0 4.5915f-41 0.0]; atol = 0.0001, rtol = 0.0001)
Stacktrace:
 [1] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:445 [inlined]
 [2] macro expansion
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:77 [inlined]
 [3] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:1283 [inlined]
 [4] top-level scope
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:59
linear: Test Failed at /home/jagupt/.julia/dev/PyCallChainRules/test/test_pytorch.jl:77
  Expression: isapprox(torchparams[i], linwrap.params[i], atol = 0.0001, rtol = 0.0001)
   Evaluated: isapprox(Float32[0.096376784, -0.57178324], Float32[8.28208f-40, 0.0]; atol = 0.0001, rtol = 0.0001)
Stacktrace:
 [1] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:445 [inlined]
 [2] macro expansion
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:77 [inlined]
 [3] macro expansion
   @ ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/stdlib/v1.7/Test/src/Test.jl:1283 [inlined]
 [4] top-level scope
   @ ~/.julia/dev/PyCallChainRules/test/test_pytorch.jl:59
Test Summary: | Pass  Fail  Total
linear        |   14     6     20
ERROR: LoadError: Some tests did not pass: 14 passed, 6 failed, 0 errored, 0 broken.
in expression starting at /home/jagupt/.julia/dev/PyCallChainRules/test/test_pytorch.jl:58
```

versus all tests passing when the GC is disabled around each run:

```julia
julia> for i in 1:10; GC.enable(false); TestEnv.activate() do; include("test/test_pytorch.jl"); end; GC.enable(true) end
```
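Disabling the GC means the shared buffers are never collected, which supports the rooting hypothesis above. One possible direction for a fix (an assumption, not necessarily what this PR implements): keep the Julia array reachable for as long as the Python object viewing it is alive, for example:

```julia
using DLPack, PyCall

const dlpack = pyimport("torch.utils.dlpack")

# Hypothetical rooting sketch: hold the Julia array strongly for as long
# as the Python object viewing its memory is reachable.
const SHARED = WeakKeyDict{PyObject,AbstractArray}()

function share_rooted(x::AbstractArray)
    py = DLPack.share(x, PyObject, dlpack.from_dlpack)
    SHARED[py] = x   # `x` stays rooted while `py` is reachable
    return py
end
```

A `WeakKeyDict` holds its keys weakly and its values strongly, so the root on `x` disappears automatically once the Python side drops the tensor.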