henrysky / astroNN

Deep Learning for Astronomers with Keras
http://astronn.readthedocs.io/
MIT License

Keras's fit_generator fails when use_multiprocessing=True, on Windows only #2

Closed. henrysky closed this issue 6 years ago.

henrysky commented 6 years ago

System information

Describe the problem

astroNN's generator is already thread-safe.
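Here "thread-safe" means several worker threads can call `next()` on the same generator without racing each other. Below is a minimal sketch of one common way to achieve this, assuming a lock around the underlying generator. The names `ThreadSafeIterator`, `thread_safe` and `batch_indices` are hypothetical and are not astroNN's actual implementation.

```python
import threading


class ThreadSafeIterator:
    """Wrap an iterator so that concurrent next() calls are serialized by a lock."""

    def __init__(self, iterator):
        self._iterator = iterator
        self._lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        # Only one thread at a time may advance the underlying generator;
        # without the lock, a second thread could hit
        # "ValueError: generator already executing".
        with self._lock:
            return next(self._iterator)


def thread_safe(generator_function):
    """Decorator: calling the decorated function returns a ThreadSafeIterator."""
    def wrapper(*args, **kwargs):
        return ThreadSafeIterator(generator_function(*args, **kwargs))
    return wrapper


@thread_safe
def batch_indices(num_samples, batch_size):
    # Hypothetical batch generator: yields index batches forever,
    # the way Keras' fit_generator expects an endless generator.
    while True:
        for start in range(0, num_samples, batch_size):
            yield list(range(start, min(start + batch_size, num_samples)))
```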

This is a known issue on Windows, caused by Python itself rather than by astroNN. It will probably work on Linux/macOS.
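One likely explanation, offered as a sketch rather than a trace of the Keras code path: on Windows, Python's `multiprocessing` only supports the `spawn` start method, which must pickle whatever it sends to a child process, and generator objects cannot be pickled. Linux and macOS also offer `fork`, which copies the parent process instead. The helpers `make_batches` and `can_pickle` below are hypothetical and only illustrate the pickling limitation.

```python
import multiprocessing
import pickle


def make_batches():
    # A plain Python generator, like the data generators passed to fit_generator.
    yield from range(3)


def can_pickle(obj):
    """Return True if obj survives pickle.dumps, which the 'spawn' start method needs."""
    try:
        pickle.dumps(obj)
    except TypeError:
        return False
    return True


# On Windows this lists only 'spawn'; Linux and macOS also list 'fork'.
print(multiprocessing.get_all_start_methods())
print(can_pickle(make_batches()))  # False: generator objects cannot be pickled
print(can_pickle([0, 1, 2]))       # True: a plain list pickles fine
```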

So far, the only drawback is that, without multiprocessing, the CPU cannot generate data fast enough to keep up with a fast GPU (GTX 970 or above, paired with a CPU that has at least 4 threads).

Multiprocessing is only necessary when training the BCNN on a GPU.

Link: https://github.com/matterport/Mask_RCNN/issues/13
Link: https://github.com/keras-team/keras/issues/6582

Source code / logs

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-2-17f261cd711f> in <module>()
      2 bcnn = Apogee_BCNN()
      3 bcnn.max_epochs = 75
----> 4 bcnn.train(x,y,x_err,y_err)

d:\university\ast425\astronn\astroNN\models\Apogee_BCNN.py in train(self, input_data, labels, inputs_err, labels_err)
    111                                        validation_steps=self.val_num // self.batch_size,
    112                                        epochs=self.max_epochs, verbose=2, workers=os.cpu_count(),
--> 113                                        callbacks=[reduce_lr, csv_logger], use_multiprocessing=True)
    114 
    115         # Call the post training checklist to save parameters

~\Anaconda3\lib\site-packages\keras\legacy\interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

~\Anaconda3\lib\site-packages\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   2097                             val_enqueuer = GeneratorEnqueuer(validation_data,
   2098                                                              use_multiprocessing=use_multiprocessing,
-> 2099                                                              wait_time=wait_time)
   2100                         val_enqueuer.start(workers=workers, max_queue_size=max_queue_size)
   2101                         validation_generator = val_enqueuer.get()

Suggestion

Detect the user's OS and enable multiprocessing in fit_generator only on macOS and Linux.
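A minimal sketch of this suggestion, assuming the decision is made from `platform.system()` (the helper name `fit_generator_kwargs` is hypothetical, not part of astroNN or Keras). The `workers` value mirrors the `workers=os.cpu_count()` argument visible in the traceback above.

```python
import os
import platform


def fit_generator_kwargs(system=None):
    """Return hypothetical worker settings for Keras' fit_generator.

    Multiprocessing is enabled only on Linux and macOS ('Darwin');
    on Windows (and any other OS) workers fall back to threads.
    """
    if system is None:
        system = platform.system()  # e.g. 'Windows', 'Linux', 'Darwin'
    return {
        "workers": os.cpu_count() or 1,  # os.cpu_count() may return None
        "use_multiprocessing": system in ("Linux", "Darwin"),
    }


print(fit_generator_kwargs())
```

In the training call this would be passed as `model.fit_generator(..., **fit_generator_kwargs())`. With `use_multiprocessing=False`, Keras runs the workers as threads, which is why a thread-safe generator still matters on Windows.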

henrysky commented 6 years ago

Closed by commit e98e62c.