tensorflow / tfjs

A WebGL accelerated JavaScript library for training and deploying ML models.
https://js.tensorflow.org
Apache License 2.0

Prediction zeroed out at precise tensor size on JAX converted code on WebGL backend #7448

Open benoitparis opened 1 year ago

benoitparis commented 1 year ago

Might be related to https://github.com/tensorflow/tfjs/issues/7430

How to reproduce:

import jax.numpy as jnp
import tensorflowjs as tfjs

params = {}

def main_ok(params):
    lin = jnp.linspace(0.0, 1.0, num=159, endpoint=False)
    return jnp.outer(lin, lin)

def main_fail(params):
    lin = jnp.linspace(0.0, 1.0, num=160, endpoint=False)
    return jnp.outer(lin, lin)

tfjs.converters.convert_jax(
    main_ok,
    params,
    input_signatures=[],
    model_dir='./main_ok/'
)

tfjs.converters.convert_jax(
    main_fail,
    params,
    input_signatures=[],
    model_dir='./main_fail/'
)

Displayed with Chrome Version 110.0.5481.178 (Official Build) (64-bit):

<html lang="en">
  <head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.2.0"> </script>
  </head>
  <body>
    <canvas id="cv"></canvas>
    <script type="text/javascript">
      async function getOutput() {
        // await tf.setBackend('cpu'); // Makes it ok
        // const model = await tf.loadGraphModel('/main_ok/model.json'); // Makes it ok
        const model = await tf.loadGraphModel('/main_fail/model.json'); // Fail
        let result = model.predict();
        console.log(result.dataSync());
        const canvas = document.getElementById("cv");
        await tf.browser.toPixels(result, canvas);
      }
      getOutput();
    </script>
  </body>
</html>

Pipfile:

[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"

[packages]
numpy = "==1.23.5"
jaxlib = "==0.4.4"
jax = "==0.4.4"
tensorflow = "==2.11.0"
tensorflowjs = "==4.2.0"

[dev-packages]

[requires]
python_version = "3.10"

Describe the current behavior

Outputs a black canvas; tensor content is zeroes

Describe the expected behavior

Should output a canvas showing a gradient, just as it does at size 159. Setting the backend to CPU also avoids the problem.
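For anyone comparing the values logged by result.dataSync() against what the model should return, here is a rough pure-Python sketch of the expected output. It assumes that jnp.linspace(0.0, 1.0, num=n, endpoint=False) yields i / n for i in range(n), up to float32 rounding, and that dataSync() returns the values in row-major order. The helper names (expected_outer, is_all_zero, max_abs_diff) are made up for illustration and are not part of JAX or TensorFlow.js.

```python
# Reference values for the repro, computed without JAX or TensorFlow.js.
# Assumption: jnp.linspace(0.0, 1.0, num=n, endpoint=False) yields i / n
# for i in range(n), up to float32 rounding.

def expected_outer(n):
    """Return the n x n outer product of the linspace as a flat row-major list."""
    lin = [i / n for i in range(n)]
    return [a * b for a in lin for b in lin]


def is_all_zero(values):
    """True when every value is exactly zero (the failure described above)."""
    return all(v == 0.0 for v in values)


def max_abs_diff(actual, expected):
    """Largest absolute difference between two flat lists of equal length."""
    if len(actual) != len(expected):
        raise ValueError("length mismatch: %d vs %d" % (len(actual), len(expected)))
    return max(abs(a - e) for a, e in zip(actual, expected))


ok_reference = expected_outer(159)
fail_reference = expected_outer(160)

# Both references contain non-zero entries, so an all-zero tensor at
# size 160 cannot be the correct result.
print(len(ok_reference), is_all_zero(ok_reference))
print(len(fail_reference), is_all_zero(fail_reference))
```

Pasting the logged array into a Python list and passing it to is_all_zero or max_abs_diff (with a tolerance suited to float32) should show whether the WebGL result is merely imprecise or entirely zeroed.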

benoitparis commented 1 year ago

Might be related to the fact that:

159*159 = 25281
160*160 = 25600

Now, a plain power of 2 would be an obvious suspect, but a power of 2 times 100 (25600 = 2^8 * 100) is quite intriguing. ¯\_(ツ)_/¯
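The arithmetic behind that observation can be checked with a few lines of Python. This is only a sketch for exploring the numbers; split_power_of_two is a hypothetical helper, and nothing here says which size threshold, if any, the WebGL backend actually uses.

```python
# Element counts of the two outputs, and how much of each is a power of two.

def split_power_of_two(value):
    """Return (exponent, odd_part) such that value == 2**exponent * odd_part."""
    exponent = 0
    while value % 2 == 0:
        value //= 2
        exponent += 1
    return exponent, value


for side in (159, 160):
    count = side * side
    exponent, odd_part = split_power_of_two(count)
    print(f"{side}*{side} = {count} = 2^{exponent} * {odd_part}")
```

The count for size 159 is odd, while the count for size 160 is 2^10 * 25, which is the same number as 2^8 * 100.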

benoitparis commented 1 year ago

Following the instructions in https://github.com/tensorflow/tfjs/issues/7430 and https://github.com/tensorflow/tfjs/issues/1936, I tried the following:

    tf.env().set('WEBGL_USE_SHAPES_UNIFORMS', false);
    tf.env().set('WEBGL_PACK', false);
    tf.env().set('WEBGL_CONV_IM2COL', false);

The bug remained.

gaikwadrahul8 commented 1 year ago

Hi, @benoitparis

Apologies for the delayed response. I tried to replicate the issue on my end using the code above with @tensorflow/tfjs@4.10.0, and I get the same result with tf.setBackend('cpu') as with the two lines of code below, which run the JAX-converted models on the WebGL backend. Could you please try again on your end with @tensorflow/tfjs@4.10.0 and let us know whether it works as expected?

If the issue persists, please let us know. Or am I missing something here? Thank you!

const model = await tf.loadGraphModel('/main_ok/model.json'); // Makes it ok
const model = await tf.loadGraphModel('/main_fail/model.json'); 

I'm getting the output below in all three cases mentioned in the code snippet above:

image