awslabs / keras-apache-mxnet

[DEPRECATED] Amazon Deep Learning's Keras with Apache MXNet support
https://github.com/awslabs/keras-apache-mxnet/wiki
Other
290 stars 65 forks

Fix axis error in normalization layer when loading model from tf backend saved h5 #258

Closed leondgarse closed 4 years ago

leondgarse commented 4 years ago


Summary

When loading a model from an h5 file saved by Keras with the TensorFlow backend, I got this error:

/opt/anaconda3/lib/python3.7/site-packages/keras/layers/normalization.py in build(self, input_shape)
     98
     99     def build(self, input_shape):
--> 100         dim = input_shape[self.axis]
    101         print(input_shape, self.axis, dim)
    102         if dim is None
TypeError: tuple indices must be integers or slices, not list
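The TypeError is ordinary Python behavior: a tuple such as input_shape accepts an int or a slice as an index, but not a list. A minimal sketch isolating that one line of build(), assuming (as the PR overview suggests) that axis arrives as the list [1] when the config comes from TensorFlow-backend Keras; lookup_dim is a hypothetical helper, not part of Keras.

```python
def lookup_dim(input_shape, axis):
    """Hypothetical helper mimicking `dim = input_shape[self.axis]` in build().

    Returns the dimension, or the TypeError message if indexing fails.
    """
    try:
        return input_shape[axis]
    except TypeError as err:
        return "TypeError: {}".format(err)


# Shape taken from the debug output in this thread; axis in two forms:
# an int (MXNet-backend Keras) and a list (assumed TensorFlow-backend config).
print(lookup_dim((1, 10), 1))    # 10
print(lookup_dim((1, 10), [1]))  # TypeError: tuple indices must be integers or slices, not list
```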

I wrote a small demo to reproduce it:

PR Overview

It seems that in TensorFlow-backend Keras, the axis of BatchNormalization is stored as a list, so I added an isinstance check where self.axis is initialized. With that change, load_model succeeds.
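A minimal sketch of that kind of isinstance normalization, written as a standalone function rather than as the actual change to keras/layers/normalization.py, which is not shown in this thread. The name normalize_axis and the choice to reject multi-element lists are assumptions, and the JSON string is a hypothetical fragment of a saved layer config, not one read from a real h5 file.

```python
import json


def normalize_axis(axis):
    """Hypothetical helper: accept an int or a single-element list/tuple.

    TensorFlow-backend Keras may serialize BatchNormalization's axis as a
    list such as [1]; code that later does input_shape[axis] needs an int.
    """
    if isinstance(axis, (list, tuple)):
        if len(axis) != 1:
            # Assumption: only a single normalization axis is handled here.
            raise ValueError("expected a single axis, got {!r}".format(axis))
        axis = axis[0]
    return int(axis)


# Hypothetical config fragment, as it might be read back from a saved model.
config = json.loads('{"class_name": "BatchNormalization", "config": {"axis": [1]}}')
axis = normalize_axis(config["config"]["axis"])
input_shape = (1, 10)
print(input_shape, axis, input_shape[axis])  # (1, 10) 1 10
```

Unwrapping to a plain int leaves the rest of build() untouched; whether the merged change also deals with multi-element lists is not stated in the PR.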

roywei commented 4 years ago

Hi @leondgarse , thank you so much for your contribution.

The build error you see in CI may not be due to your code change; I will double-check and get back to you. We are having some problems with the CI system.

In the meantime, could you add a unit test under https://github.com/awslabs/keras-apache-mxnet/tree/master/tests/keras/backend? You can turn your reproduction code into a test in a new file, something like mxnet_tf_model_test.py.

Thanks!

leondgarse commented 4 years ago

Hi @roywei, I added a unit test file, tests/keras/backend/mxnet_tf_model_test.py, and it passes locally:

import mxnet_tf_model_test
aa = mxnet_tf_model_test.TestMXNetTfModel()
aa.test_batchnorm_layer_reload()
# Using MXNet backend
# axis =  [1]
# (1, 10) 1 10
# axis =  -1
# (1, 10) 1 10
# axis =  1
# (1, 10) 1 10

It loads a model saved with the TensorFlow backend and then a model saved with the MXNet backend, to make sure both work.
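The three printed cases ([1], -1 and 1, each resolving to dimension 10 on shape (1, 10)) lend themselves to a table-driven check. A sketch using the standard unittest module that covers only the axis handling, not a real save/load round trip through h5; the class name TestAxisNormalization and the normalize_axis helper (redefined here so the block is self-contained) are hypothetical and are not the contents of mxnet_tf_model_test.py.

```python
import unittest


def normalize_axis(axis):
    """Hypothetical helper: unwrap a single-element list/tuple axis to an int."""
    if isinstance(axis, (list, tuple)):
        if len(axis) != 1:
            raise ValueError("expected a single axis, got {!r}".format(axis))
        axis = axis[0]
    return int(axis)


class TestAxisNormalization(unittest.TestCase):
    """Mirrors the three cases printed above: axis = [1], -1 and 1."""

    def test_axis_forms_resolve_to_same_dim(self):
        input_shape = (1, 10)
        for axis in ([1], -1, 1):
            with self.subTest(axis=axis):
                self.assertEqual(input_shape[normalize_axis(axis)], 10)

    def test_multi_axis_list_is_rejected(self):
        with self.assertRaises(ValueError):
            normalize_axis([1, 2])
```

Saved to a file, this can be run with python -m unittest followed by the module name; a real test for this PR would additionally need both backends available to save and reload an actual model.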

leondgarse commented 4 years ago

Thanks for your work. Am I triggering CI the right way? What's the failure this time?

roywei commented 4 years ago

@leondgarse could you try pushing an empty commit to trigger CI?

git commit --allow-empty -m "trigger ci"

It seems a multi-threaded test failed in our test environment (Docker), and the failure looks random, so re-triggering should work. Our nightly tests have been passing for a few days.


=================================== FAILURES ===================================
________________________________ test_warnings _________________________________
[gw1] linux -- Python 3.7.6 /root/.pyenv/versions/3.7.6/bin/python3.7

    @pytest.mark.skipif(sys.version_info < (3,),
                        reason='Cannot catch warnings in python 2')
    def test_warnings():
        a = Input(shape=(3,), name='input_a')
        b = Input(shape=(3,), name='input_b')

        a_2 = Dense(4, name='dense_1')(a)
        dp = Dropout(0.5, name='dropout')
        b_2 = dp(b)

        model = Model([a, b], [a_2, b_2])

        optimizer = 'rmsprop'
        loss = 'mse'
        loss_weights = [1., 0.5]
        model.compile(optimizer, loss, metrics=[], loss_weights=loss_weights,
                      sample_weight_mode=None)

        @threadsafe_generator
        def gen_data(batch_sz):
            while True:
                yield ([np.random.random((batch_sz, 3)),
                        np.random.random((batch_sz, 3))],
                       [np.random.random((batch_sz, 4)),
                        np.random.random((batch_sz, 3))])

        with pytest.warns(Warning) as w:
            out = model.fit_generator(gen_data(4),
                                      steps_per_epoch=10,
                                      use_multiprocessing=True,
>                                     workers=2)

tests/keras/engine/test_training.py:607:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
keras/legacy/interfaces.py:91: in wrapper
    return func(*args, **kwargs)
keras/engine/training.py:1433: in fit_generator
    initial_epoch=initial_epoch)
keras/engine/training_generator.py:181: in fit_generator
    generator_output = next(output_generator)
keras/utils/data_utils.py:695: in get
    inputs = self.queue.get(block=True).get()
/root/.pyenv/versions/3.7.6/lib/python3.7/multiprocessing/pool.py:651: in get
    self.wait(timeout)
/root/.pyenv/versions/3.7.6/lib/python3.7/multiprocessing/pool.py:648: in wait
    self._event.wait(timeout)
/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py:552: in wait
    signaled = self._cond.wait(timeout)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <Condition(<unlocked _thread.lock object at 0x7fba4402c990>, 0)>
timeout = None

    def wait(self, timeout=None):
        """Wait until notified or until a timeout occurs.

        If the calling thread has not acquired the lock when this method is
        called, a RuntimeError is raised.

        This method releases the underlying lock, and then blocks until it is
        awakened by a notify() or notify_all() call for the same condition
        variable in another thread, or until the optional timeout occurs. Once
        awakened or timed out, it re-acquires the lock and returns.

        When the timeout argument is present and not None, it should be a
        floating point number specifying a timeout for the operation in seconds
        (or fractions thereof).

        When the underlying lock is an RLock, it is not released using its
        release() method, since this may not actually unlock the lock when it
        was acquired multiple times recursively. Instead, an internal interface
        of the RLock class is used, which really unlocks it even when it has
        been recursively acquired several times. Another internal interface is
        then used to restore the recursion level when the lock is reacquired.

        """
        if not self._is_owned():
            raise RuntimeError("cannot wait on un-acquired lock")
        waiter = _allocate_lock()
        waiter.acquire()
        self._waiters.append(waiter)
        saved_state = self._release_save()
        gotit = False
        try:    # restore state no matter what (e.g., KeyboardInterrupt)
            if timeout is None:
>               waiter.acquire()
E               Failed: Timeout >1200.0s

/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py:296: Failed
----------------------------- Captured stdout call -----------------------------
Epoch 1/1
----------------------------- Captured stderr call -----------------------------

+++++++++++++++++++++++++++++++++++ Timeout ++++++++++++++++++++++++++++++++++++

~~~~~~~~~~~~~~~~~~~~ Stack of Thread-272 (140435986855680) ~~~~~~~~~~~~~~~~~~~~~
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 890, in _bootstrap
    self._bootstrap_inner()
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/multiprocessing/pool.py", line 470, in _handle_results
    task = get()
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/multiprocessing/connection.py", line 250, in recv
    buf = self._recv_bytes()
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)

~~~~~~~~~~~~~~~~~~~~ Stack of Thread-271 (140436114863872) ~~~~~~~~~~~~~~~~~~~~~
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 890, in _bootstrap
    self._bootstrap_inner()
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/multiprocessing/pool.py", line 422, in _handle_tasks
    for taskseq, set_length in iter(taskqueue.get, None):

~~~~~~~~~~~~~~~~~~~~ Stack of Thread-270 (140438099777280) ~~~~~~~~~~~~~~~~~~~~~
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 890, in _bootstrap
    self._bootstrap_inner()
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/multiprocessing/pool.py", line 413, in _handle_workers
    time.sleep(0.1)

~~~~~~~~~~~~~~~~~~~~ Stack of Thread-269 (140435410818816) ~~~~~~~~~~~~~~~~~~~~~
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 890, in _bootstrap
    self._bootstrap_inner()
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/codebuild/output/src242250753/src/github.com/awslabs/keras-apache-mxnet/keras/utils/data_utils.py", line 681, in _run
    executor.apply_async(next_sample, (self.uid,)), block=True)
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/queue.py", line 139, in put
    self.not_full.wait()
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/threading.py", line 296, in wait
    waiter.acquire()

~~~~~~~~~~~~~~~~~~~~~ Stack of <unknown> (140440431216384) ~~~~~~~~~~~~~~~~~~~~~
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/site-packages/execnet/gateway_base.py", line 285, in _perform_spawn
    reply.run()
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/site-packages/execnet/gateway_base.py", line 220, in run
    self._result = func(*args, **kwargs)
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/site-packages/execnet/gateway_base.py", line 967, in _thread_receiver
    msg = Message.from_io(io)
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/site-packages/execnet/gateway_base.py", line 432, in from_io
    header = io.read(9)  # type 1, channel 4, payload 4
  File "/root/.pyenv/versions/3.7.6/lib/python3.7/site-packages/execnet/gateway_base.py", line 400, in read
    data = self._read(numbytes - len(buf))

+++++++++++++++++++++++++++++++++++ Timeout ++++++++++++++++++++++++++++++++++++

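The stacks show where the run stalls: keras/utils/data_utils.py:695 calls self.queue.get(block=True).get(), and the inner AsyncResult get() waits with timeout=None, so nothing ends the wait until pytest's 1200 s timeout fires. As a standard-library illustration only (not a proposed change to Keras), here is a minimal sketch of how passing a timeout to AsyncResult.get() turns such a hang into a catchable multiprocessing.TimeoutError. ThreadPool stands in for a process pool to keep the example light, and stalled_task and fetch are hypothetical names for an artificially stuck worker.

```python
import multiprocessing
import threading
from multiprocessing.pool import ThreadPool


def stalled_task(release):
    """Artificial task that blocks until `release` is set, like a stuck worker."""
    release.wait()
    return "done"


def fetch(timeout):
    """Submit the stalled task and wait for its result for at most `timeout` seconds."""
    release = threading.Event()
    pool = ThreadPool(processes=1)
    try:
        result = pool.apply_async(stalled_task, (release,))
        try:
            # With timeout=None this call would block indefinitely, as in the CI log.
            return result.get(timeout=timeout)
        except multiprocessing.TimeoutError:
            return "timed out after {}s".format(timeout)
    finally:
        release.set()  # let the worker finish so the pool can shut down cleanly
        pool.close()
        pool.join()


print(fetch(0.1))  # timed out after 0.1s
```
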
roywei commented 4 years ago

@leondgarse It seems your PR keeps failing the same multi-threaded test above. I'm not sure why: all nightly tests passed, and your PR does not affect that test.

Could you try resetting to commit 6e230a9 and doing a git pull --rebase instead of a merge? Ideally your PR should not contain my changes.

I will also double-check why the test keeps failing in your case.

Sorry for the inconvenience caused.

leondgarse commented 4 years ago

Yes! This is much more like it. Here are my commands:

git reset --hard 6e230a9
git pull upstream master --rebase
git push --force

roywei commented 4 years ago

@leondgarse Awesome! Merging now. Thanks for your contribution!