pymc-devs / pytensor

PyTensor allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays.
https://pytensor.readthedocs.io

Pytorch backend slow with pymc model #1110

Open Ch0ronomato opened 4 days ago

Ch0ronomato commented 4 days ago

Description

@ricardoV94 made a nice perf improvement in https://github.com/pymc-devs/pymc/pull/7578 to try to speed up the jitted backends. I tried out the torch backend as well. The model performed quite slowly.

| mode | t_sampling (seconds) | manual measure (seconds) |
| --- | --- | --- |
| NUMBA | 2.483 | 11.346 |
| PYTORCH (COMPILED) | 206.503 | 270.188 |
| PYTORCH (EAGER) | 60.607 | 64.140 |

We need to investigate why

  1. Torch is so slow
  2. Torch compile is slower than eager mode

When doing perf evaluations, keep in mind that torch does a lot of caching. If you want a truly cache-less eval, you can either call torch.compiler.reset() or set the env variable that disables the dynamo cache (google it).
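
A sketch of what such a cache-less measurement could look like (assuming torch >= 2.1, where torch.compiler.reset() clears the in-process dynamo/inductor caches):

import time
import torch

def time_without_cache(fn, *args, repeats=5):
    # Time fn while forcing a fresh torch.compile pass on every repeat.
    times = []
    for _ in range(repeats):
        torch.compiler.reset()  # drop compiled artifacts so each call recompiles
        start = time.perf_counter()
        fn(*args)
        times.append(time.perf_counter() - start)
    return times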

Ch0ronomato commented 4 days ago

The script I used for reference

import arviz as az
import numpy as np
import multiprocessing
import pandas as pd
import pymc as pm
import pytensor as pt
import pytensor.tensor.random as ptr
import time

def main():
    # Load the radon dataset
    data = pd.read_csv(pm.get_data("radon.csv"))
    data["log_radon"] = data["log_radon"].astype(np.float64)
    county_idx, counties = pd.factorize(data.county)
    coords = {"county": counties, "obs_id": np.arange(len(county_idx))}

    # Create a simple hierarchical model for the radon dataset
    with pm.Model(coords=coords, check_bounds=False) as model:
        intercept = pm.Normal("intercept", sigma=10)

        # County effects
        raw = pm.ZeroSumNormal("county_raw", dims="county")
        sd = pm.HalfNormal("county_sd")
        county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")

        # Global floor effect
        floor_effect = pm.Normal("floor_effect", sigma=2)

        # County:floor interaction
        raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
        sd = pm.HalfNormal("county_floor_sd")
        county_floor_effect = pm.Deterministic(
            "county_floor_effect", raw * sd, dims="county"
        )

        mu = (
            intercept
            + county_effect[county_idx]
            + floor_effect * data.floor.values
            + county_floor_effect[county_idx] * data.floor.values
        )

        sigma = pm.HalfNormal("sigma", sigma=1.5)
        pm.Normal(
            "log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id"
        )

    from pymc.model.transform.optimization import freeze_dims_and_data
    model = freeze_dims_and_data(model)
    for mode in ("NUMBA", "PYTORCH"):
        start = time.perf_counter()
        trace = pm.sample(
            model=model, 
            cores=1,
            chains=1,
            tune=500, 
            draws=500, 
            progressbar=False, 
            compute_convergence_checks=False, 
            return_inferencedata=False,
            compile_kwargs=dict(mode=mode)
        )
        end = time.perf_counter()
        idata = pm.to_inference_data(trace, model=model)
        print(az.summary(idata, kind="diagnostics"))
        print(mode, trace._report.t_sampling, end - start)

if __name__ == "__main__":
    multiprocessing.freeze_support()
    main()
Ch0ronomato commented 4 days ago

I also called compile_logp and compile_dlogp separately to narrow down where the time goes

| mode | method | time (seconds) |
| --- | --- | --- |
| PYTORCH | logp | 5.891 |
| PYTORCH | dlogp | 4.054 |
| NUMBA | logp | 2.376 |
| NUMBA | dlogp | 2.244 |
ricardoV94 commented 4 days ago

You have to call it once (and perhaps assert that the outputs match) to trigger the jit compilation, and only then time it
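
For example, something along these lines (a sketch reusing `model` and an initial point as in the script above; the tolerance is arbitrary):

import numpy as np
from timeit import timeit

val = model.initial_point(123)
logp_torch = model.compile_logp(mode="PYTORCH")
logp_numba = model.compile_logp(mode="NUMBA")

# The first call of each triggers the jit compilation and lets us check the backends agree.
np.testing.assert_allclose(logp_torch(val), logp_numba(val), rtol=1e-6)

# Only time the warm, already-compiled functions.
print("torch:", timeit(lambda: logp_torch(val), number=100) / 100)
print("numba:", timeit(lambda: logp_numba(val), number=100) / 100)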

Ch0ronomato commented 3 days ago

@ricardoV94 can you assign this issue to me by chance?

I profiled a bit more. The logp and dlogp pytensor functions don't take long to generate, but executing them is slower. Both numba and torch get faster when the function is executed multiple times, but numba is much faster

| mode | method | time (seconds) |
| --- | --- | --- |
| PYTORCH | `logp.__call__` | 4.8469 |
| PYTORCH | `logp.__call__` | 0.0013 |
| PYTORCH | `logp.__call__` | 0.0008 |
| PYTORCH | `dlogp.__call__` | 9.9691 |
| PYTORCH | `dlogp.__call__` | 7.2030 |
| PYTORCH | `dlogp.__call__` | 7.9464 |
| NUMBA | `logp.__call__` | 2.8590 |
| NUMBA | `logp.__call__` | 0.0001 |
| NUMBA | `logp.__call__` | 0.0001 |
| NUMBA | `dlogp.__call__` | 3.7824 |
| NUMBA | `dlogp.__call__` | 0.0059 |
| NUMBA | `dlogp.__call__` | 0.0001 |

I'm seeing some interesting data in the .explain call: it looks like join and alloc create a few different graph breaks, and there is a lot of data-dependent stuff happening as well. Seeing that the runtime drops when logp is run multiple times makes me think there is a bit more recompilation happening. Attached is the full dump of the torch compile logs, but here is the header that outlines what happened to the graph

Graph Count: 8
Graph Break Count: 7
Op Count: 198
Break Reasons:
  Break Reason 1:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpxndm7s5l, line 35 in pytorch_funcified_fgraph>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 129 in join>
  Break Reason 2:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 129 in join>
  Break Reason 3:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpxndm7s5l, line 75 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_35>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 129 in join>
  Break Reason 4:
    Reason: Dynamic slicing on data-dependent value is not supported
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpxndm7s5l, line 133 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_75>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/subtensor.py, line 77 in inc_subtensor>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 142 in indices_from_subtensor>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 142 in <genexpr>>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 131 in convert_indices>
  Break Reason 5:
    Reason: Dynamic slicing on data-dependent value is not supported
    User Stack:
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/subtensor.py, line 81 in torch_dynamo_resume_in_inc_subtensor_at_78>
  Break Reason 6:
    Reason: Dynamic slicing on data-dependent value is not supported
    User Stack:
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/subtensor.py, line 81 in torch_dynamo_resume_in_inc_subtensor_at_81>
  Break Reason 7:
    Reason: Dynamic slicing on data-dependent value is not supported
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpxndm7s5l, line 193 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_133>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/subtensor.py, line 77 in inc_subtensor>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 142 in indices_from_subtensor>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 142 in <genexpr>>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/tensor/subtensor.py, line 131 in convert_indices>

explaination.txt

The first thing I'm going to do is clean up some of the warnings. These timings are only possible because the warning logs are sent to /dev/null; if you don't do that, the timings balloon a bit.

ricardoV94 commented 3 days ago

The first time is the compilation; it's not as relevant since it's a one-time thing. If we're recompiling multiple times, that's a different story.

Also are you using %timeit after the first call? A single call is going to be noisy.

Did you confirm the outputs match?

ricardoV94 commented 3 days ago

Those breaks are interesting but most are not data dependent? Like the slice is constant in this model. Can you enable that capture scalar outputs option?
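
For reference, that option is the flag named in the break reasons above; enabling it is a one-liner that has to run before the PYTORCH-mode function is compiled / first called (sketch):

import torch

# Let dynamo trace scalar reads (aten._local_scalar_dense) instead of breaking the graph on them.
torch._dynamo.config.capture_scalar_outputs = True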

Also are you freezing the model data and dims like in the original example? When we have static shapes we could forward those to the dispatch

Ch0ronomato commented 3 days ago

Yeah my bad; this is how the profiling looks

    from pymc.model.transform.optimization import freeze_dims_and_data
    model = freeze_dims_and_data(model)
    val = model.initial_point(123)
    for mode in ("PYTORCH", "NUMBA"):
        fn = model.compile_logp(mode=mode)
        for _ in range(3):
            start = time.perf_counter()
            fn(val)
            end = time.perf_counter()
            print("| ", " | ".join((mode, "logp.__call__", "{:.4f}".format(end - start))), " |")
Ch0ronomato commented 3 days ago

Btw, I only showed logp, here is dlogp, which is potentially more problematic

Graph Count: 16
Graph Break Count: 15
Op Count: 163
Break Reasons:
  Break Reason 1:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpo4p1bvf1, line 33 in pytorch_funcified_fgraph>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 129 in join>
  Break Reason 2:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 129 in join>
  Break Reason 3:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpo4p1bvf1, line 73 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_33>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 129 in join>
  Break Reason 4:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpo4p1bvf1, line 111 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_73>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/shape.py, line 10 in reshape>
  Break Reason 5:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/shape.py, line 10 in reshape>
  Break Reason 6:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpo4p1bvf1, line 119 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_111>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 106 in alloc>
  Break Reason 7:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpo4p1bvf1, line 139 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_119>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/shape.py, line 10 in reshape>
  Break Reason 8:
    Reason: Tensor.item
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpo4p1bvf1, line 147 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_139>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 193 in inner_fn>
  Break Reason 9:
    Reason: Tensor.item
    User Stack:
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 193 in inner_fn>
  Break Reason 10:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpo4p1bvf1, line 173 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_149>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/shape.py, line 10 in reshape>
  Break Reason 11:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpo4p1bvf1, line 179 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_173>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 106 in alloc>
  Break Reason 12:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpo4p1bvf1, line 199 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_179>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/shape.py, line 10 in reshape>
  Break Reason 13:
    Reason: Tensor.item
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpo4p1bvf1, line 207 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_199>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 193 in inner_fn>
  Break Reason 14:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /var/folders/2z/y4z6q41j7l93ysr2y2k_gsnh0000gn/T/tmpo4p1bvf1, line 229 in torch_dynamo_resume_in_pytorch_funcified_fgraph_at_209>
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/shape.py, line 10 in reshape>
  Break Reason 15:
    Reason: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
    User Stack:
      <FrameSummary file /Users/ch0ronomato/dev/pytensor/pytensor/link/pytorch/dispatch/basic.py, line 129 in join>

dlogp seems to be a bit more problematic. I reran timing this morning and got really poor numbers (8s), but last night it was "better". Not sure what's going on there.

import os
import arviz as az
import numpy as np
import multiprocessing
import pandas as pd
import pymc as pm
import pytensor as pt
import pytensor.tensor.random as ptr
import time

def main():
    # Load the radon dataset
    data = pd.read_csv(pm.get_data("radon.csv"))
    data["log_radon"] = data["log_radon"].astype(np.float64)
    county_idx, counties = pd.factorize(data.county)
    coords = {"county": counties, "obs_id": np.arange(len(county_idx))}

    # Create a simple hierarchical model for the radon dataset
    with pm.Model(coords=coords, check_bounds=False) as model:
        intercept = pm.Normal("intercept", sigma=10)

        # County effects
        raw = pm.ZeroSumNormal("county_raw", dims="county")
        sd = pm.HalfNormal("county_sd")
        county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")

        # Global floor effect
        floor_effect = pm.Normal("floor_effect", sigma=2)

        # County:floor interaction
        raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
        sd = pm.HalfNormal("county_floor_sd")
        county_floor_effect = pm.Deterministic(
            "county_floor_effect", raw * sd, dims="county"
        )

        mu = (
            intercept
            + county_effect[county_idx]
            + floor_effect * data.floor.values
            + county_floor_effect[county_idx] * data.floor.values
        )

        sigma = pm.HalfNormal("sigma", sigma=1.5)
        pm.Normal(
            "log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id"
        )

    from pymc.model.transform.optimization import freeze_dims_and_data
    model = freeze_dims_and_data(model)
    val = model.initial_point(123)
    os.environ["PYTORCH_RECORD_EXPLAIN"] = "yes"
    for mode in ("PYTORCH", "NUMBA"):
        fn = model.compile_dlogp(mode=mode)
        for _ in range(3):
            start = time.perf_counter()
            fn(val)
            end = time.perf_counter()
            print("| ", " |".join((mode, "dlogp.__call__", "{:.4f}".format(end - start))), " |")

if __name__ == "__main__":
    multiprocessing.freeze_support()
    main()

explaination.txt

Ch0ronomato commented 3 days ago

Addressing the questions:

Also are you freezing the model data and dims like in the original example? When we have static shapes we could forward those to the dispatch

Yeah, it does; I pasted the profiling code above, my bad.

Those breaks are interesting but most are not data dependent? Like the slice is constant in this model.

I find this a bit strange too; I think what might be happening is that we have tensors holding a single value (like what we had in #1031) and that's causing a lot of "data dependent" operators? I need to understand the plumbing a bit more to really know.
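
A toy reproduction of that kind of break (hypothetical code, not the actual dispatch): reading a Python scalar out of a tensor is exactly what shows up as aten._local_scalar_dense in the explain dumps.

import torch

def f(x):
    n = int(x.sum())  # scalar read from a tensor -> aten._local_scalar_dense
    return x[:n]      # followed by a slice on that data-dependent value

# By default dynamo falls back to eager around the scalar read, producing graph
# breaks like the ones above; with torch.compile(f, fullgraph=True) it errors out instead.
compiled = torch.compile(f)
print(compiled(torch.ones(5)))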

Can you enable that capture scalar outputs option?

Unbelievably, this crashes

  File "/opt/anaconda3/envs/pytensor-dev/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5205, in evaluate_expr
    raise self._make_data_dependent_error(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: u0)

Did you confirm the outputs match?

Yes! Just using np.testing.assert_allclose between the numba and torch backends

Also are you using %timeit after the first call?

I time 3 calls. I could time more but I mostly just cared to see that the first one was big and the rest were "smaller"

ricardoV94 commented 3 days ago

I time 3 calls. I could time more but I mostly just cared to see that the first one was big and the rest were "smaller"

The first call being slow and the next ones fast is really not surprising; compilation happens on the first call.

3 calls is usually not enough to measure without noise. I believe torch is slower, of course, but you should average the eval time over at least ~100 calls. You can use the %timeit magic to do this automatically in an IPython (or Jupyter) environment.

ricardoV94 commented 3 days ago

What's the compiled function? Curious why it doesn't think those slice indices are constant. You can do fn.dprint(), where fn is the compiled pytensor function.

Ch0ronomato commented 3 days ago

Yeah, agreed that the pattern isn't surprising, but I was worried each call was getting a recompile (like dlogp seems to do, perhaps intermittently). I'll do a more robust measurement.

What would a dynamic slice look like?

ricardoV94 commented 3 days ago
import pytensor.tensor as pt  # assuming the usual alias for pytensor.tensor

a = pt.scalar("a", dtype=int)  # the slice bound is symbolic, only known at runtime
x = pt.vector("x")
out = x[:a]  # dynamic slice
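
For contrast, a constant slice (sketch, reusing the pt alias above) involves no data-dependent value:

x = pt.vector("x")
out_static = x[:5]  # the bound is a graph constant, nothing to guard on at runtime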
Ch0ronomato commented 3 days ago

Here are the graphs; I dumped them into a text file, logp first, then dlogp, separated by ---. It's quite large; I'll go through it later. out.txt

ricardoV94 commented 3 days ago

Looking at the times, the problem is definitely the dlogp graph. How much better does it get if we go through torch autodiff?
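
Roughly what that would mean (a sketch only; torch_logp below is a hypothetical stand-in for a torch implementation of the model's log-density, not something PyTensor currently emits):

import math
import torch

def torch_logp(params: torch.Tensor) -> torch.Tensor:
    # hypothetical stand-in log-density: iid standard normals
    return (-0.5 * params**2 - 0.5 * math.log(2 * math.pi)).sum()

dlogp = torch.func.grad(torch_logp)  # reverse-mode gradient of the scalar-valued logp
params = torch.zeros(10, dtype=torch.float64)
print(dlogp(params))  # gradient at params; the wrapped fn could also be torch.compile'd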

Ch0ronomato commented 2 days ago

Okay, just to drive the point home with numbers, I did this in IPython.

fn = model.compile_logp(mode="PYTORCH")
fn(val)
%timeit -n 100 -r 100 fn(val)

2.18 ms ± 411 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
fn = model.compile_logp(mode="NUMBA")
fn(val)
%timeit -n 100 -r 100 fn(val)

50.5 μs ± 4.29 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)

That's for logp. For dlogp

fn = model.compile_dlogp(mode="PYTORCH")
fn(val)
%timeit -n 100 -r 100 fn(val)

2.47 ms ± 430 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
fn = model.compile_dlogp(mode="NUMBA")
fn(val)
%timeit -n 100 -r 100 fn(val)

62.3 μs ± 5.34 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)

I haven't looked at going through the torch autodiff engine for dlogp; I can put that on the list of things to try.

Ch0ronomato commented 2 days ago

I can't quite tell from the dprint whether the slices are constant or not. I do see Split and AdvancedSubtensor throughout the graphs, but I imagine their presence alone doesn't mean there is dynamic indexing?

 │  ├─ True_div [id EG] 79
 │  │  ├─ Sub [id EH] 78
 │  │  │  ├─ SpecifyShape [id EI] 73
 │  │  │  │  ├─ Split{2}.1 [id DY] 72
 │  │  │  │  │  └─ ···
 │  │  │  │  └─ 1 [id DA]
 │  │  │  └─ ExpandDims{axis=0} [id EJ] 77
 │  │  │     └─ Sum{axes=None} [id EK] 76
 │  │  │        └─ Sub [id DZ] 71
 │  │  │           └─ ···
 │  ├─ Split{2}.0 [id DY] 72
 │  │  ├─ Sub [id DZ] 71
 │  │  │  ├─ Mul [id EA] 70
 │  │  │  │  ├─ AdvancedIncSubtensor1{no_inplace,inc} [id EB] 59
 │  │  │  │  │  ├─ Alloc [id EC] 58
 │  │  │  │  │  │  ├─ [0.] [id CQ]
 │  │  │  │  │  │  └─ Add [id ED] 57
 │  │  │  │  │  │     ├─ Shape_i{0} [id CC] 4
 │  │  │  │  │  │     │  └─ ···
 │  │  │  │  │  │     └─ 1 [id CS]
 │  │  │  │  │  ├─ True_div [id EE] 56
 │  │  │  │  │  │  ├─ Mul [id DW] 55
 │  │  │  │  │  │  │  └─ ···
 │  │  │  │  │  │  └─ ExpandDims{axis=0} [id CV] 42
 │  │  │  │  │  │     └─ ···
 │  │  │  │  │  └─ [ 0  0  0 ... 83 84 84] [id BI]
Ch0ronomato commented 2 days ago

There are two warnings I want to dig into.

  1. There is a warning on calls to torch.tensor about an extra copy, so I'm going to poke at that (sketch below).
  2. Warnings about data-dependent ops. I wonder if it's because everything comes in as a torch.tensor, so torch thinks it's a variable it needs to track. We might be able to avoid that... somehow.
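
For the first one, a sketch of the usual workaround (these are general facts about the torch constructors; whether it fits where the pytensor dispatch builds tensors is still to be checked):

import numpy as np
import torch

x_np = np.zeros(3)
x_t = torch.zeros(3)

torch.tensor(x_np)     # always copies the data
torch.tensor(x_t)      # also copies, and emits the "copy construct" warning
torch.as_tensor(x_np)  # shares memory with a compatible numpy array, no copy
torch.as_tensor(x_t)   # returns the same tensor, no copy, no warning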
ricardoV94 commented 2 days ago

@Ch0ronomato you need to eval the function once before doing the timeit. The first call takes care of the jit compilation, which shouldn't be mixed with the eval cost. That's why you're getting those warnings that the slowest eval took x longer than the fastest.

Ch0ronomato commented 2 days ago

you need to eval the function once before doing the timeit.

Cleaned up

ricardoV94 commented 2 days ago

So it's 4-5x slower than numba. Can you show what you get on the default C backend and JAX on your machine as well?
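
For reference, the same warm-call pattern covers those two backends (a sketch, reusing model and val from above; leaving out the mode argument gives the default C-backed FAST_RUN mode):

fn = model.compile_logp()  # default mode, C-backed ops
fn(val)                    # warm-up
%timeit -n 100 -r 100 fn(val)

fn = model.compile_logp(mode="JAX")
fn(val)                    # first call triggers the jax.jit compilation
%timeit -n 100 -r 100 fn(val)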