haeusser / learning_by_association

This repository contains code for the paper Learning by Association - A versatile semi-supervised training method for neural networks (CVPR 2017) and the follow-up work Associative Domain Adaptation (ICCV 2017).
https://vision.in.tum.de/members/haeusser
Apache License 2.0

Tensorflow Version Syntax Mismatched #7

Closed: mlearning closed this issue 7 years ago.

mlearning commented 7 years ago

In backend.py, line 70, tf.concat is called with the new (TensorFlow >= 1.0) signature:

  tf.concat(batch_images, 0), tf.concat(batch_labels, 0)

In train.py, lines 301 and 309, it is called with the old signature:

  t_sup_emb = tf.concat(0, [
                    t_sup_emb, semisup.create_virt_emb(FLAGS.virtual_embeddings,
                                                       FLAGS.emb_size)
                ])
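In TensorFlow >= 1.0, the fix is to swap the arguments so that the list of tensors comes first and the axis second. As a framework-free illustration that runs without TensorFlow, NumPy's `np.concatenate` uses the same `(values, axis)` order. The sizes below are made-up stand-ins for `FLAGS.emb_size` and `FLAGS.virtual_embeddings`, and the zero array is a hypothetical placeholder for the output of `create_virt_emb`.

```python
import numpy as np

emb_size = 4       # hypothetical stand-in for FLAGS.emb_size
num_virtual = 2    # hypothetical stand-in for FLAGS.virtual_embeddings

sup_emb = np.ones((3, emb_size), dtype=np.float32)              # 3 supervised embeddings
virt_emb = np.zeros((num_virtual, emb_size), dtype=np.float32)  # placeholder virtual embeddings

# TF >= 1.0: tf.concat([t_sup_emb, virt_emb], 0)  -> values first, axis second
# TF <  1.0: tf.concat(0, [t_sup_emb, virt_emb])  -> axis first, values second
combined = np.concatenate([sup_emb, virt_emb], axis=0)  # stack rows along axis 0
print(combined.shape)  # (5, 4)
```

The virtual embeddings are appended as extra rows, so only the first dimension grows.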

Is this code meant for TensorFlow 1.0 or above? Could you please list the dependencies in the README, such as the Python, TensorFlow, and NumPy versions?

When I ran it with Python 3.5 and TensorFlow 1.1, I hit an error while reading the STL-10 binary file:

  File "/root/DL/computer-vision/learning_by_association/semisup/tools/stl10.py", line 67, in extract_images
    imgs = np.fromstring(f.read(), np.uint8)
  UnicodeDecodeError: 'utf-8' codec can't decode byte 0x92 in position 0: invalid start byte
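The traceback shows the decode failing inside `f.read()`. Under Python 3 that happens when a file is opened in text mode, because the raw pixel bytes are then decoded as UTF-8; Python 2 returns raw bytes either way. A minimal sketch of a binary-safe reader follows. `read_stl10_images` is a hypothetical helper, not the repository's `extract_images`, and the channel-first, column-major layout (hence the `(0, 3, 2, 1)` transpose) is an assumption based on the STL-10 reference reader.

```python
import numpy as np


def read_stl10_images(path, size=96, channels=3):
    """Read an STL-10 style *_X.bin file into an (N, H, W, C) uint8 array."""
    # "rb" yields bytes on both Python 2 and 3; text mode would make Python 3
    # try to decode the pixel data as UTF-8 and raise UnicodeDecodeError.
    with open(path, "rb") as f:
        # np.frombuffer replaces the deprecated binary use of np.fromstring.
        data = np.frombuffer(f.read(), dtype=np.uint8)
    # Assumed layout: each image stored channel-first, column-major.
    images = data.reshape(-1, channels, size, size)
    return np.transpose(images, (0, 3, 2, 1))
```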

So I switched to Python 2.7 and TensorFlow 0.12.1, changed the tf.concat call above back to the old signature, and encountered another error in backend.py, line 179, in add_semisup_loss:

    loss_aba = tf.losses.softmax_cross_entropy(
  AttributeError: 'module' object has no attribute 'losses'

This indicates that backend.py requires TensorFlow 1.0 or higher.

However, other parts of the code suggest that an older version is being used. Any clarification or fix would be much appreciated!

mlearning commented 7 years ago

I figured out that the code runs with Python 2.7 and TensorFlow 1.0 or above. I can now start the training process.

The tf.concat calls in train.py with the old signature do not raise an error because virtual embeddings are not enabled, so that code path never runs.
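Because the version requirement only surfaces as an AttributeError deep inside backend.py, a startup guard could make it explicit. This is a minimal sketch: `parse_version` and `require_tf_1` are hypothetical helpers, not part of the repository, and they avoid type hints so they also run under Python 2.7.

```python
import re


def parse_version(version):
    """Return the leading numeric components, e.g. '1.0.0-rc2' -> (1, 0, 0)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first component that does not start with a digit
        parts.append(int(match.group()))
    return tuple(parts)


def require_tf_1(version):
    """Raise early if the installed TensorFlow predates 1.0.

    tf.losses and the tf.concat(values, axis) argument order both need TF >= 1.0;
    on 0.12.x the code otherwise fails later with
    AttributeError: 'module' object has no attribute 'losses'.
    """
    if parse_version(version) < (1, 0):
        raise RuntimeError("TensorFlow >= 1.0 is required, found %s" % version)
```

A possible call site is near the top of the training script, e.g. `import tensorflow as tf` followed by `require_tf_1(tf.__version__)`.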

haeusser commented 7 years ago

I think the issue you're running into is not related to TF but to reading the STL-10 data. Did you download and unzip the dataset and adjust the path in semisup/tools/data_dirs.py?
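One quick way to confirm that the configured path points at the extracted dataset is to check for the expected files. `missing_stl10_files` is a hypothetical helper, not part of the repository; the file names are those of the STL-10 binary archive (`stl10_binary/`), and which of them the repository's loader actually opens is an assumption.

```python
import os

# Files shipped in the STL-10 binary archive (stl10_binary/).
STL10_FILES = ("train_X.bin", "train_y.bin", "test_X.bin", "test_y.bin",
               "unlabeled_X.bin")


def missing_stl10_files(data_dir):
    """Return the expected STL-10 files that are not present in data_dir."""
    return [name for name in STL10_FILES
            if not os.path.isfile(os.path.join(data_dir, name))]
```

An empty result means every expected file was found at that path.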

I should fix the tf.concat calls.

haeusser commented 7 years ago

Fixed in commit bfe53989f9883610f7a6ec2322b2c6177dc1524b.