openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

[oneDNN] XLA:CPU performance regression with oneDNN Dot op integration #5875

Open penpornk opened 10 months ago

penpornk commented 10 months ago

Reproduction instructions:

# With regression
$ pip install "jax[cpu]==0.4.16" equinox
$ python repro.py

# Before regression
$ pip install "jax[cpu]==0.4.14" equinox
$ python repro.py

repro.py:

import jax
import jax.numpy as jnp
import jax.random as jr
import timeit

# Increasing this value seems to increase the performance drop.
SIZE = 1

@jax.jit
@jax.vmap
@jax.grad
def f(x):
    out = 0
    for i in range(SIZE):
        y = x @ jnp.array([i, i + 1], dtype=jnp.float32)
        out = out + y * x[0]
    return out

x = jr.normal(jr.PRNGKey(0), (2, 2))
f(x)  # compile
# timeit.timeit returns the total time in seconds for all 10,000 calls.
speed = timeit.timeit(lambda: f(x).block_until_ready(), number=10_000)
print(speed)
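
As a reading aid for the repro above, here is a plain-Python sketch of what f computes for a single row of x. It assumes the usual JAX semantics: jax.vmap maps over the leading axis, so each row has shape (2,), and @ between two 1-D arrays is a dot product that yields a scalar. The names f_single and grad_f_single are hypothetical helpers for illustration and are not part of the original report.

```python
# Plain-Python reference for one row of the repro (no JAX required).
# f_single and grad_f_single are illustrative names, not from the issue.

def f_single(x, size=1):
    """Value of the un-transformed function for one row x = (x0, x1)."""
    out = 0.0
    for i in range(size):
        w = (float(i), float(i + 1))       # jnp.array([i, i + 1])
        y = x[0] * w[0] + x[1] * w[1]      # x @ w: 1-D dot product, a scalar
        out += y * x[0]
    return out


def grad_f_single(x, size=1):
    """Hand-derived gradient, i.e. what jax.grad would return for one row."""
    # Each term is (i*x0 + (i+1)*x1) * x0, so:
    #   d/dx0 = 2*i*x0 + (i+1)*x1
    #   d/dx1 = (i+1)*x0
    g0 = sum(2 * i * x[0] + (i + 1) * x[1] for i in range(size))
    g1 = sum((i + 1) * x[0] for i in range(size))
    return (g0, g1)


if __name__ == "__main__":
    row = (3.0, 5.0)
    print(f_single(row, size=1), grad_f_single(row, size=1))
```

With SIZE = 1 the only weight vector is [0, 1], so the function reduces to x[1] * x[0] and the per-row gradient is (x[1], x[0]). Under jax.vmap over the (2, 2) input, each dot product in the loop is batched across the rows.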

Increasing SIZE seems to increase the performance drop. The reporter got the following times on their machine:

penpornk commented 10 months ago

@TensorFlow-MKL @mdfaijul @agramesh1 Could you please help take a look? Thank you very much! :)

mdfaijul commented 10 months ago

@penpornk Thanks for the reproducer. Could you please tell us (1) which platform this is (AVX2/AVX512/AMX), and (2) what the semantics of the expression x @ jnp.array([i, i + 1], dtype=jnp.float32) are?

mdfaijul commented 9 months ago

@penpornk This PR fixes the regression: https://github.com/tensorflow/tensorflow/pull/62299