bytedance / byteps

A high performance and generic framework for distributed DNN training

tensorflow.push_pull CANNOT calculate AVERAGE, it always gives SUM #323

Open HugoZHL opened 3 years ago

HugoZHL commented 3 years ago

Describe the bug: In the TensorFlow push_pull function, op can never equal Average, because op is a member of an Enum class while Average is just a str. As a result, push_pull in TensorFlow always calculates the SUM of the tensors instead of the AVERAGE.
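The comparison failure can be reproduced without TensorFlow or BytePS. The sketch below uses hypothetical stand-ins for ReduceOps and the Average constant (the member names and values are assumptions, not copied from BytePS). It shows that a plain Enum member never compares equal to a str, even when that str matches the member's value.

```python
from enum import Enum

# Hypothetical stand-ins; the real BytePS definitions may differ.
Average = "Average"  # module-level constant is a plain str


class ReduceOps(Enum):
    Average = "Average"
    Sum = "Sum"


op = ReduceOps.Average

# A plain Enum member is not equal to any str, even its own value.
print(op == Average)        # False
print(op.value == Average)  # True
print(type(op), type(Average))
```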

To Reproduce: Steps to reproduce the behavior:

  1. Add a log statement print('Type of op and Average: ', type(op), type(Average)) after the line op = handle_average_backwards_compatibility(op, average) in the tensorflow.push_pull function.
  2. Run any program that uses tensorflow.push_pull, e.g. code that uses DistributedOptimizer.
  3. Check the types of op and Average, which are printed while the TensorFlow graph is being set up.
  4. The log shows: Type of op and Average: <enum 'ReduceOps'> <class 'str'>

Expected behavior: An Enum member cannot be meaningfully compared with a str; op == Average is always False. There are two possible solutions to this problem:

  1. Use op.value instead of op when comparing with Average in the tensorflow.push_pull function.
  2. Change the ReduceOps class from an Enum class to a normal class.
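Both proposed fixes can be sketched in plain Python. The names here (ReduceOpsEnum, ReduceOpsPlain, is_average_fix1, is_average_fix2) are hypothetical stand-ins rather than BytePS code, and the sketch assumes the Enum member's value equals the Average string. Fix 1 keeps the Enum and compares op.value; fix 2 turns ReduceOps into a plain class whose attributes are str constants, so op == Average works directly.

```python
from enum import Enum

# Hypothetical constant; assumed to match the Enum member's value.
Average = "Average"


# Fix 1: keep ReduceOps as an Enum, but unwrap it before comparing.
class ReduceOpsEnum(Enum):
    Average = "Average"
    Sum = "Sum"


def is_average_fix1(op):
    # op.value is a plain str, so it can equal the Average constant.
    return op.value == Average


# Fix 2: make ReduceOps a normal class holding plain str attributes.
class ReduceOpsPlain:
    Average = "Average"
    Sum = "Sum"


def is_average_fix2(op):
    # op is already a str here, so direct comparison works.
    return op == Average


print(is_average_fix1(ReduceOpsEnum.Average))   # True
print(is_average_fix2(ReduceOpsPlain.Average))  # True
```

Fix 1 keeps the type safety of the Enum at the call sites, while fix 2 makes every existing op == Average check work unchanged at the cost of accepting arbitrary strings.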

Screenshots: In tensorflow.ops, ReduceOps inherits from Enum.

In tensorflow.__init__, the program directly compares op (which is ReduceOps.Average) with Average (which is a str).

The function uses op == Average to decide whether to average or sum the tensors, so at present it can only return the sum.
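To see the numerical effect, the sketch below simulates the reduction in pure Python instead of calling BytePS. The functions simulated_push_pull and simulated_push_pull_fixed are hypothetical stand-ins: each sums one scalar per worker and divides by the number of workers only when the op is judged to be Average. With an Enum member passed as op, the buggy branch is never taken, so the result is the sum rather than the mean.

```python
from enum import Enum

# Hypothetical stand-ins for the BytePS constant and enum.
Average = "Average"


class ReduceOps(Enum):
    Average = "Average"
    Sum = "Sum"


def simulated_push_pull(worker_values, op):
    """Reduce one scalar per worker, using the buggy comparison."""
    total = sum(worker_values)
    if op == Average:  # never True when op is an Enum member
        return total / len(worker_values)
    return total


def simulated_push_pull_fixed(worker_values, op):
    """Same reduction, comparing the Enum's value instead."""
    total = sum(worker_values)
    if op.value == Average:
        return total / len(worker_values)
    return total


grads = [1.0, 2.0, 3.0, 6.0]
print(simulated_push_pull(grads, ReduceOps.Average))        # 12.0 (sum, not mean)
print(simulated_push_pull_fixed(grads, ReduceOps.Average))  # 3.0
```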

Environment: Any (the bug does not depend on the environment).


bobzhuyb commented 3 years ago

@pleasantrabbit I think we can support Average easily

pleasantrabbit commented 3 years ago

Yes, we'll add Average support.

HugoZHL commented 3 years ago

Here is the pull request to fix this bug: https://github.com/bytedance/byteps/pull/324