Unity-Technologies / barracuda-release

UpSampling Layers Do Not Use Correct Interpolation #317

Open hayden-donnelly opened 1 year ago

When a tensor is passed through an up-sampling layer, Barracuda always applies nearest-neighbor interpolation and ignores the interpolation mode specified by the layer (here, bilinear). The following Python and C# code reproduces the issue.

Create model in Python/Keras

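```python
import keras
from keras import layers
import tf2onnx

# Define a model that only up-samples a 20x20 single-channel (NHWC) input
# to 40x40 using bilinear interpolation
inputs = layers.Input(shape=(20, 20, 1))
outputs = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(inputs)
model = keras.Model(inputs, outputs)

# Convert to ONNX so the model can be imported into Unity/Barracuda
tf2onnx.convert.from_keras(model, output_path='upsampler.onnx')
```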
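To compare Barracuda's output against what the layer is supposed to compute, a CPU reference can help. The following is a minimal NumPy sketch of 2x up-sampling of a single-channel (H, W) array in both modes. It assumes the half-pixel-center convention used by TensorFlow 2's bilinear resize (output pixel `i` samples input coordinate `(i + 0.5) / scale - 0.5`, clamped to the input); `upsample_nearest` and `upsample_bilinear` are hypothetical helper names, not part of Keras or Barracuda.

```python
import numpy as np


def upsample_nearest(x: np.ndarray, scale: int = 2) -> np.ndarray:
    """Nearest neighbor: repeat each pixel `scale` times along height and width."""
    return x.repeat(scale, axis=0).repeat(scale, axis=1)


def _linear_weights(in_size: int, scale: int):
    """Source indices and weights for 1-D linear interpolation (half-pixel centers, assumed)."""
    out = np.arange(in_size * scale)
    src = (out + 0.5) / scale - 0.5      # map output pixel centers to input coordinates
    src = np.clip(src, 0, in_size - 1)   # clamp at the borders
    lo = np.floor(src).astype(int)
    hi = np.minimum(lo + 1, in_size - 1)
    frac = src - lo                      # weight of the upper neighbor
    return lo, hi, frac


def upsample_bilinear(x: np.ndarray, scale: int = 2) -> np.ndarray:
    """Bilinear up-sampling of an (H, W) array, done separably: rows first, then columns."""
    h, w = x.shape
    lo, hi, f = _linear_weights(h, scale)
    rows = x[lo, :] * (1 - f)[:, None] + x[hi, :] * f[:, None]
    lo, hi, f = _linear_weights(w, scale)
    return rows[:, lo] * (1 - f)[None, :] + rows[:, hi] * f[None, :]
```

Feeding the same 20x20 input to these functions and to the Barracuda worker makes the difference visible: per the behavior reported in this issue, Barracuda's result would match `upsample_nearest` rather than `upsample_bilinear`.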

Test in Unity

```csharp
using UnityEngine;
using Unity.Barracuda;

public class UpSamplerTest : MonoBehaviour
{
    // The imported upsampler.onnx asset, assigned in the Inspector
    public NNModel modelAsset;

    // Build a 1 x height x width x 1 (NHWC) tensor filled with random values in [0, 1)
    private Tensor RandomTensor(int width, int height)
    {
        System.Random random = new System.Random();
        Tensor temp = new Tensor(1, height, width, 1);
        for (int y = 0; y < height; y++)
        {
            for (int x = 0; x < width; x++)
            {
                temp[0, y, x, 0] = (float)random.NextDouble();
            }
        }
        return temp;
    }

    private void Start()
    {
        Model runtimeModel = ModelLoader.Load(modelAsset);
        using (var worker = WorkerFactory.CreateWorker(WorkerFactory.Type.ComputePrecompiled, runtimeModel))
        {
            Tensor input = RandomTensor(20, 20);
            worker.Execute(input);

            // PeekOutput returns a tensor owned by the worker, so it is not disposed here
            Tensor output = worker.PeekOutput();

            // Log the first 10 values of the 1x40x40x1 output, i.e. the first 10 pixels of row 0
            for (int i = 0; i < 10; i++)
            {
                Debug.Log(output[i]);
            }

            input.Dispose();
        }
    }
}
```

The test logs 10 output values, but only 5 of them are unique: each input value is repeated twice along the width, which is exactly what nearest-neighbor up-sampling by a factor of 2 produces. If Barracuda performed bilinear interpolation as the layer specifies, the 10 logged values would all be distinct.
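The counting argument can be checked without Unity. This NumPy sketch (an illustration, not Barracuda code) takes one random input row of 20 values, up-samples it by 2 in both modes, and counts the unique values among the first 10 outputs, which correspond to `output[0]` through `output[9]` in the C# test. It assumes the half-pixel mapping of TensorFlow's bilinear resize, under which output row 0 depends only on input row 0; `np.interp` clamps to the end values outside the input range.

```python
import numpy as np

rng = np.random.default_rng(0)
row = rng.random(20)  # first input row, random like the C# test

# Nearest neighbor x2: every input value is duplicated (a, a, b, b, ...)
nearest = np.repeat(row, 2)

# Bilinear x2 along the row, assuming half-pixel centers:
# output pixel i samples input coordinate (i + 0.5) / 2 - 0.5
src = (np.arange(40) + 0.5) / 2 - 0.5
bilinear = np.interp(src, np.arange(20), row)  # clamps to row[0] / row[-1] at the edges

print(len(np.unique(nearest[:10])))   # 5
print(len(np.unique(bilinear[:10])))  # 10
```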