OATML / bdl-benchmarks

Bayesian Deep Learning Benchmarks
Apache License 2.0

TF 2.0 full release breaks image preprocessing #4

Open jarrodhaas opened 5 years ago

jarrodhaas commented 5 years ago

Hi OATML,

(My apologies in advance as preview mode was indicating some formatting issues that I'm not able to fix.)

The full TF 2.0 release breaks image preprocessing for bdl-benchmarks. The problem appears to be that, since TF 2.0 Beta, TF has changed how it interacts with Python's local symbol table.

As a result, transforms.Compose can no longer compose a transformation for TF's dataset.map function, because the information it needs is no longer present in the local symbol table:

TF 2.0 Beta provides:

output of locals(): { 'nargs': 1, 'f': <bdlb.diabetic_retinopathy_diagnosis.benchmark.DiabeticRetinopathyDiagnosisBecnhmark._preprocessors.<locals>.Parse object at 0x7f6a6314bcf8>, 'inspect': <module 'inspect' from '/usr/local/lib/python3.6/inspect.py'>, 'x': {'image': <tf.Tensor 'args_0:0' shape=(None, 256, 256, 3) dtype=uint8>, 'label': <tf.Tensor 'args_1:0' shape=(None,) dtype=int64>, 'name': <tf.Tensor 'args_2:0' shape=(None,) dtype=string>}, 'self': <bdlb.core.transforms.Compose object at 0x7f6a5c1c7240>}

While TF 2.0 provides only:

output of locals(): { 'caller_fn_scope': <tensorflow.python.autograph.core.function_wrappers.FunctionScope object at 0x7efaf43c0080>, 'kwargs': None, 'args': (), 'options': <tensorflow.python.autograph.core.converter.ConversionOptions object at 0x7efaf43c02b0>, 'f': }
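The keys in the TF 2.0 output (f, args, kwargs, caller_fn_scope, options) look like the arguments of a wrapper function rather than the locals of the Compose call method. One possible explanation, which is an assumption and not something confirmed in this issue, is that locals() ends up being invoked indirectly through a wrapper and therefore reports the wrapper's frame. The plain-Python sketch below reproduces that effect without TensorFlow; call_indirectly is a hypothetical stand-in, not a TensorFlow function.

```python
def call_indirectly(f, args, kwargs):
    # Hypothetical wrapper: forwards a call on behalf of its caller.
    return f(*args, **(kwargs or {}))


def direct(x):
    # locals() is called directly, so it sees this function's own names.
    return sorted(locals())


def indirect(x):
    # locals() is called from inside the wrapper, so it sees the
    # wrapper's names (f, args, kwargs) and not x.
    return sorted(call_indirectly(locals, (), None))


print(direct(1))
print(indirect(1))
```

If this is what happens, any logic that inspects locals() inside a function passed to dataset.map would see the wrapper's variables rather than its own.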


A simple fix to get things working again is to stop relying on locals() to discover information about the object being passed to the composition, and instead to give the transform classes explicit call signatures, so that only Python's own introspection (the inspect module) is needed. Note, of course, that this solution is not robust to changes in the signatures of the preprocessing functions being composed.
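As a minimal sketch of the idea, inspect.signature can report how many parameters a callable object accepts and what they are named, without touching locals(). The three classes below are hypothetical stand-ins for preprocessing transforms and are not taken from the repository.

```python
import inspect


class OneArg:
    def __call__(self, x):
        return x


class TwoArgs:
    def __call__(self, x, y):
        return x, y


class CastLike:
    # Second parameter is deliberately not named "y".
    def __call__(self, x, y_nochange):
        return x, y_nochange


def describe(f):
    """Return (number of parameters, whether one of them is named 'y')."""
    params = inspect.signature(f).parameters  # 'self' is excluded for instances
    return len(params), "y" in params


for transform in (OneArg(), TwoArgs(), CastLike()):
    print(type(transform).__name__, describe(transform))
```

The parameter count and the presence of a parameter named y are the only two facts the proposed dispatch needs.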

First, we create a unique function signature for the CastX() corner case in bdl-benchmarks/bdlb/diabetic_retinopathy_diagnosis/benchmark.py:

```python
def __call__(self, x, y_nochange):
    # Cast only x; the second argument is passed through unchanged.
    return tf.cast(x, self.dtype), y_nochange
```

Then, in bdl-benchmarks/bdlb/core/transforms.py, we compose by dispatching on each transform's signature instead of on locals():

```python
def __call__(self, x):
  import inspect

  # Initialise up front so that an empty composition, or a leading
  # (x, y) transform, cannot hit an unbound name.
  y, last_y = None, None

  for f in self.trans:
    last_y = None
    params = inspect.signature(f).parameters
    nargs = len(params)

    if nargs == 2 and "y" in params:
      # Transform takes (x, y) explicitly.
      x, y = f(x, y)
      last_y = 1
    elif nargs == 1:
      # Transform takes a single argument.
      x = f(x)
    else:
      # Corner case (e.g. CastX): x holds a pair that must be unpacked.
      x, y = f(x[0], x[1])
      last_y = 1

  # If the last function in the composition returns two values, return both;
  # otherwise return only one.
  if last_y is not None:
    return x, y
  return x
```
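To check the dispatch logic in isolation, here is a self-contained sketch that runs without TensorFlow. It assumes transforms are callable objects; Parse, CastX and Scale here are simplified, hypothetical stand-ins that operate on plain Python numbers rather than tensors, and this Compose is an illustration of the proposal, not the repository's implementation.

```python
import inspect


class Compose:
    def __init__(self, trans):
        self.trans = trans

    def __call__(self, x):
        y = None
        returns_pair = False
        for f in self.trans:
            params = inspect.signature(f).parameters
            if len(params) == 2 and "y" in params:
                x, y = f(x, y)          # transform takes (x, y)
                returns_pair = True
            elif len(params) == 1:
                x = f(x)                # transform takes a single argument
                returns_pair = False
            else:
                x, y = f(x[0], x[1])    # unpack a pair held in x
                returns_pair = True
        return (x, y) if returns_pair else x


class Parse:
    def __call__(self, x):
        return x["image"], x["label"]   # dict -> (image, label) pair


class CastX:
    def __init__(self, dtype):
        self.dtype = dtype

    def __call__(self, x, y_nochange):
        return self.dtype(x), y_nochange


class Scale:
    def __call__(self, x, y):
        return x / 255.0, y


pipeline = Compose([Parse(), CastX(float), Scale()])
result = pipeline({"image": 255, "label": 1})
print(result)
```

Because the dispatch depends on parameter names, renaming the second parameter of a transform from y to something else silently changes which branch it takes, which is the robustness caveat mentioned above.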