google-coral / project-teachable

Example Project: Teachable Machine
https://coral.withgoogle.com/projects/teachable-machine/
Apache License 2.0

embedding vectors are always the same #7

Closed: maiermic closed this issue 4 years ago

maiermic commented 4 years ago

Even though I’m using totally different images, I always get the same embedding vectors. To reproduce:

git clone https://github.com/google-coral/project-teachable.git
cd project-teachable
git checkout 84b725fae87ca7080c227284dfc9800df0ab46fe

Install the project's requirements. I use:

edgetpu==2.11.1
numpy==1.17.2
Pillow==6.1.0
pycairo==1.18.1
PyGObject==3.34.0
pyparsing==2.4.2
svgwrite==1.3.1

Download the images:

wget https://coral.withgoogle.com/static/projects/images/teachable-machine/teachable-machine-web-067.jpg
wget https://coral.withgoogle.com/static/projects/images/teachable-machine/teachable-machine-web-096.jpg

Create a file embedding_example.py with the following content:

from PIL import Image

from embedding import EmbeddingEngine

engine = EmbeddingEngine('models/mobilenet_quant_v1_224_headless_edgetpu.tflite')
image_067 = Image.open('teachable-machine-web-067.jpg')
image_096 = Image.open('teachable-machine-web-096.jpg')
emb_1 = engine.DetectWithImage(image_067)
emb_2 = engine.DetectWithImage(image_096)
print(emb_1)
print(emb_2)
print(emb_1 != emb_2)
assert any(emb_1 != emb_2), 'Embeddings of different images should not be equal'

Run the example:

$ python embedding_example.py
INFO: Initialized TensorFlow Lite runtime.
W :67] Minimum runtime version required by package (2) is lower than expected (10).
W :67] Minimum runtime version required by package (2) is lower than expected (10).
[0.         0.11764239 0.         ... 1.035253   0.         0.47056955]
[0.         0.11764239 0.         ... 1.035253   0.         0.47056955]
[False False False ... False False False]
Traceback (most recent call last):
  File "embedding_example.py", line 13, in <module>
    assert any(emb_1 != emb_2), 'Embeddings of different images should not be equal'
AssertionError: Embeddings of different images should not be equal
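
Identical printouts can mean two different things: the arrays genuinely hold equal values, or they are two views of one shared buffer. The following is a minimal NumPy sketch for telling these cases apart. It assumes the engine returns NumPy arrays (as the printed output suggests); `describe_pair` is a hypothetical helper, and the arrays are stand-ins, not real embeddings.

```python
import numpy as np


def describe_pair(a, b):
    """Report whether two arrays are the same object, share memory, or merely hold equal values."""
    return {
        "same_object": a is b,
        "shares_memory": bool(np.shares_memory(a, b)),
        "equal_values": bool(np.array_equal(a, b)),
    }


# Stand-in data: `view` aliases `buf`, while `independent` is a separate copy.
buf = np.array([0.0, 0.11764239, 1.035253], dtype=np.float32)
view = buf[:]              # a new array object backed by the same memory
independent = buf.copy()   # same values, separate memory

print(describe_pair(buf, view))
# {'same_object': False, 'shares_memory': True, 'equal_values': True}
print(describe_pair(buf, independent))
# {'same_object': False, 'shares_memory': False, 'equal_values': True}
```

Calling `describe_pair(emb_1, emb_2)` at the end of the script above would show which case applies to the real engine. That result is not shown here.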

Note: I'm using a Coral USB Accelerator on Linux Mint 19.2, but I expect you would get the same result on the Coral Dev Board.


Can anyone confirm this issue?

Namburger commented 4 years ago

@maiermic Possibly this is because the embedding engine returns the result by reference instead of by value? That way, the result of the first call is the same object as the result of the second (think pointers in C). For example, take a look at this code snippet:

from PIL import Image

from embedding import EmbeddingEngine

engine = EmbeddingEngine('models/mobilenet_quant_v1_224_headless_edgetpu.tflite')
emb_1 = engine.DetectWithImage(Image.open('teachable-machine-web-067.jpg'))
print('emb_1 at first:\n', emb_1)
emb_2 = engine.DetectWithImage(Image.open('teachable-machine-web-096.jpg'))
print('emb_2:\n', emb_2)
print('emb_1 after running inference on second image:\n', emb_1)

yields this result:

emb_1 at first:
 [1.7411073  0.47056955 0.         ... 0.09411391 0.3293987  0.3293987 ]
emb_2:
 [0.         0.11764239 0.         ... 1.0117245  0.         0.47056955]
emb_1 after running inference on second image:
 [0.         0.11764239 0.         ... 1.0117245  0.         0.47056955]

Weird, huh? lol. The gist is that it returns a reference to an object, so when that object changes as you run inference on another image, your emb_1 reflects the change too (I hope this isn't too confusing). To save the exact state of the output for future use, you'll need to save a copy instead of a reference, like so:

import copy

from PIL import Image

from embedding import EmbeddingEngine

engine = EmbeddingEngine('models/mobilenet_quant_v1_224_headless_edgetpu.tflite')

# copy.copy on a NumPy array copies its data, so later inference cannot overwrite it
emb_1 = copy.copy(engine.DetectWithImage(Image.open('teachable-machine-web-067.jpg')))
print('emb_1:\n', emb_1)
emb_2 = copy.copy(engine.DetectWithImage(Image.open('teachable-machine-web-096.jpg')))
print('emb_2:\n', emb_2)
assert any(emb_1 != emb_2), 'Embeddings of different images should not be equal'
print(emb_1 != emb_2)

which gives the following output:

emb_1:
 [1.7411073  0.47056955 0.         ... 0.09411391 0.3293987  0.3293987 ]
emb_2:
 [0.         0.11764239 0.         ... 1.0117245  0.         0.47056955]
[ True  True False ...  True  True  True]
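
The aliasing and the copy fix can both be reproduced without any Coral hardware. This minimal sketch assumes, based on the behavior observed in this thread, that the engine writes every result into one internal buffer and returns a view of it. `BufferReusingEngine` and `detect` are hypothetical names, not the project's `EmbeddingEngine` API.

```python
import copy

import numpy as np


class BufferReusingEngine:
    """Hypothetical stand-in for an engine that writes every result into one
    internal buffer and returns a view of it (not the real EmbeddingEngine)."""

    def __init__(self, size):
        self._output = np.zeros(size, dtype=np.float32)

    def detect(self, features):
        # Overwrite the shared buffer in place, then hand back a view of it.
        self._output[:] = features
        return self._output[:]


engine = BufferReusingEngine(3)
emb_1 = engine.detect([1.0, 2.0, 3.0])
emb_2 = engine.detect([4.0, 5.0, 6.0])
print(emb_1)  # [4. 5. 6.]: emb_1 now shows the second result

fresh = BufferReusingEngine(3)
safe_1 = copy.copy(fresh.detect([1.0, 2.0, 3.0]))  # copies the data out of the buffer
safe_2 = copy.copy(fresh.detect([4.0, 5.0, 6.0]))
print(safe_1, safe_2)  # [1. 2. 3.] [4. 5. 6.]
```

For NumPy arrays, `copy.copy(a)`, `a.copy()` and `np.copy(a)` all copy the data, so any of them would detach a result from a reused buffer.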

Since Python isn't very explicit about this, the topic can get quite confusing, but I don't see any foul in the code base at the moment. Here is a related article on passing parameters by reference vs. by value: https://robertheaton.com/2014/02/09/pythons-pass-by-object-reference-as-explained-by-philip-k-dick/
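
The core idea of pass-by-object-reference fits in a few lines of plain Python. Assignment binds a name to an object rather than copying it. Mutating the object is therefore visible through every name bound to it, whereas rebinding a name affects only that name. This is a generic illustration, not code from the project.

```python
# Names bind to objects; assignment never copies.
first = [1, 2, 3]
alias = first          # both names refer to the same list object
alias.append(4)        # mutation is visible through either name
print(first)           # [1, 2, 3, 4]
print(alias is first)  # True

alias = [9, 9]         # rebinding only changes what `alias` refers to
print(first)           # [1, 2, 3, 4]
print(alias is first)  # False
```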

Hope this helps!

maiermic commented 4 years ago

@Namburger Thank you very much. That helps a lot :smile: The foul in the code base is that it isn't documented that the returned reference is always the same. This is also not apparent from the method's signature in C++.

Now that the source code has been published, you can figure out what is going on: the wrapper internally passes an output pointer that is set, inside the call, to the data of the vector stored in the engine. Hence, you get two different object references if you use two engines:

from PIL import Image

from embedding import EmbeddingEngine

engine_1 = EmbeddingEngine(
    'models/mobilenet_quant_v1_224_headless_edgetpu.tflite')
engine_2 = EmbeddingEngine(
    'models/mobilenet_quant_v1_224_headless_edgetpu.tflite')

emb_1 = engine_1.DetectWithImage(Image.open('teachable-machine-web-067.jpg'))
print('emb_1:\n', emb_1)
emb_2 = engine_2.DetectWithImage(Image.open('teachable-machine-web-096.jpg'))
print('emb_2:\n', emb_2)
assert any(emb_1 != emb_2), 'Embeddings of different images should not be equal'
print(emb_1 != emb_2)
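
If each engine stores its own output vector, as described above, two engines give independent results. However, each engine still reuses its own buffer. This minimal sketch works under that assumption. `BufferReusingEngine` is a hypothetical stand-in, not the project's `EmbeddingEngine`.

```python
import numpy as np


class BufferReusingEngine:
    """Hypothetical stand-in: each instance owns one output buffer that every
    call overwrites (an assumption based on the behavior in this thread)."""

    def __init__(self, size):
        self._output = np.zeros(size, dtype=np.float32)

    def detect(self, features):
        self._output[:] = features
        return self._output[:]


engine_1 = BufferReusingEngine(2)
engine_2 = BufferReusingEngine(2)

emb_1 = engine_1.detect([1.0, 1.0])
emb_2 = engine_2.detect([2.0, 2.0])
print((emb_1 != emb_2).all())  # True: separate engines, separate buffers

engine_1.detect([3.0, 3.0])
print(emb_1)  # [3. 3.]: a second call on engine_1 still overwrites emb_1
```

In this model, two engines keep only two results apart. Keeping many embeddings this way would need one engine per result.
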

Namburger commented 4 years ago

@maiermic You are correct. I was going to suggest using 2 engines at first, but I didn't think that was efficient :P

maiermic commented 4 years ago

@mtyka This issue is not fixed. The documentation still lacks a description of this behavior.
