tensorflow / kfac

An implementation of KFAC for TensorFlow
Apache License 2.0
197 stars 41 forks

Wrong incompatible versions + worse than Adam performance. #45

Open ghost opened 3 years ago

ghost commented 3 years ago

According to the docs, this optimizer is supposed to converge much faster (>3.5x) and in fewer iterations (>14x) than SGD with momentum. While trying to run the K-FAC vs Adam example, I ran into several problems caused by incompatibilities between tfds and tf. I kept reinstalling and ended up with the following versions:

tensorflow==1.15 tensorflow-datasets==3.0.0 tensorflow-probability==0.7.0

Here's a modified version of the Jupyter notebook. I get an error in the second-to-last cell:

AssertionError                            Traceback (most recent call last)
<ipython-input-11-e1357b45c3db> in <module>()
      5                                 seed=SEED,
      6                                 repeat_validation=True,
----> 7                                 use_augmentation=True)
      9 model = resnet_v2(input_shape=info['input_shape'],

15 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_datasets/core/tfrecords_reader.py in _str_to_relative_instruction(spec)
    355   res = _SUB_SPEC_RE.match(spec)
    356   if not res:
--> 357     raise AssertionError('Unrecognized instruction format: %s' % spec)
    358   unit = '%' if res.group('from_pct') or res.group('to_pct') else 'abs'
    359   return ReadInstruction(

AssertionError: Unrecognized instruction format: NamedSplit('train')(tfds.percent[:80])

For the sake of the example, I worked around this by hard-coding res = _SUB_SPEC_RE.match('test').
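The AssertionError above comes from passing a legacy NamedSplit subsplit object to a reader that expects a string spec. As a minimal sketch (not the notebook's actual fix), the same 80/20 carve-out can be written with string specs. The helper names and the 40000/10000 sizes are hypothetical, and this has not been verified against the exact package versions in this thread. Per the reader code in the traceback, a bare number in the brackets is an absolute example count and only a trailing % marks a percentage.

```python
def make_split_specs(training_size, validation_size):
    """Build string split specs that carve a validation set out of 'train'.

    A bare number inside the brackets is an absolute example count;
    a trailing '%' marks a percentage of the split.
    """
    training_pct = int(100.0 * training_size / (training_size + validation_size))
    train_spec = 'train[:%d%%]' % training_pct
    validation_spec = 'train[%d%%:]' % training_pct
    return train_spec, validation_spec, 'test'


def load_cifar10_splits(training_size, validation_size):
    """Load the three splits. Needs tensorflow-datasets and a dataset download."""
    import tensorflow_datasets as tfds  # imported lazily: optional dependency

    specs = make_split_specs(training_size, validation_size)
    (train_data, val_data, test_data), info = tfds.load(
        'cifar10', split=list(specs), with_info=True)
    return train_data, val_data, test_data, info


# Hypothetical sizes, chosen only for illustration.
print(make_split_specs(40000, 10000))
# → ('train[:80%]', 'train[80%:]', 'test')
```

Keeping the spec construction separate from the loading call makes it possible to check the strings without downloading anything.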

I then did two runs, one using Adam and the other using K-FAC. I was expecting performance close to what was mentioned in the docs. However, here's a summary of what I got:

Note: The following stats are for 60 epochs.

execution time: Adam 18 minutes, K-FAC 45 minutes
best seen validation accuracy: Adam 84%, K-FAC 91%
best seen training accuracy: Adam 98%, K-FAC 98%
best seen loss: Adam 0.2511, K-FAC 0.2524

And you'll find the output of both runs in the notebook I included above.

james-martens commented 2 years ago

Thanks for your persistence trying to get the code to work. Unfortunately I don't really know much about Python library versions and this code isn't really being maintained anymore.

Regarding performance, you probably won't see a large performance gain when optimizing ResNet architectures as in the Colab, unless you use very large batch sizes (see https://arxiv.org/abs/1907.04164). That "14x" figure applies only to a certain architecture (deep autoencoders), and isn't meant to be universal. However, I can see from the README that the phrasing suggests otherwise, and so I've removed it. So far, the most compelling applications of K-FAC that I'm aware of are to deep autoencoders and vanilla networks using DKS/TAT. See https://arxiv.org/abs/2110.01765

james-martens commented 2 years ago

I've updated the package documentation and hopefully fixed the issues with the installation.

tranvansang commented 2 years ago


Here's a modified version of the jupyter notebook. I get an error by the second to last cell:

@unsignedrant I have exactly the same error, and it took me a long time to track it down. Could you share your modified version again? I can no longer access the link you provided.

Thanks a lot!

tranvansang commented 2 years ago

I kept reinstalling and ended up with the following versions:

tensorflow==1.15 tensorflow-datasets==3.0.0 tensorflow-probability==0.7.0

I followed your specified versions but encountered an error:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
kfac 0.2.4 requires tensorflow-probability==0.8, but you have tensorflow-probability 0.7.0 which is incompatible.
dm-sonnet 1.36 requires tensorflow-probability<0.9.0,>=0.8.0, but you have tensorflow-probability 0.7.0 which is incompatible.
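The two constraints in the pip message can be checked by hand. The sketch below is a simplified, purely numeric comparison written for illustration (it is not pip's resolver and does not implement all of PEP 440). The constraint strings are copied from the message above; the function names are hypothetical.

```python
def parse_version(text):
    """Turn a dotted version such as '0.8.0' into a tuple of ints."""
    return tuple(int(part) for part in text.split('.'))


def pad(version, length):
    """Right-pad a version tuple with zeros so '0.8' compares like '0.8.0'."""
    return version + (0,) * (length - len(version))


def satisfies(installed, constraint):
    """Check one constraint such as '==0.8', '>=0.8.0' or '<0.9.0'."""
    # Two-character operators are tried first so '>=' is not read as '>'.
    for op in ('==', '>=', '<=', '<', '>'):
        if constraint.startswith(op):
            wanted = parse_version(constraint[len(op):])
            break
    else:
        raise ValueError('unsupported constraint: %s' % constraint)
    have = parse_version(installed)
    width = max(len(have), len(wanted))
    have, wanted = pad(have, width), pad(wanted, width)
    return {'==': have == wanted, '>=': have >= wanted, '<=': have <= wanted,
            '<': have < wanted, '>': have > wanted}[op]


# Constraints on tensorflow-probability, copied from the pip message.
CONSTRAINTS = {
    'kfac 0.2.4': ['==0.8'],
    'dm-sonnet 1.36': ['>=0.8.0', '<0.9.0'],
}


def conflicts(installed):
    """Return the packages whose constraints the installed version violates."""
    return [name for name, specs in CONSTRAINTS.items()
            if not all(satisfies(installed, spec) for spec in specs)]


print(conflicts('0.7.0'))  # → ['kfac 0.2.4', 'dm-sonnet 1.36']
print(conflicts('0.8.0'))  # → []
```

In practice a dedicated version-parsing library would be a safer choice than this hand-rolled comparison, which ignores pre-release and local version suffixes.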

Setting tensorflow-probability to 0.8.0 solved the installation issue, but I still get the error:

AssertionError: Unrecognized instruction format: NamedSplit('train')(tfds.percent[:80])

Following here, I changed

  training_pct = int(100.0 * TRAINING_SIZE / (TRAINING_SIZE + VALIDATION_SIZE))
  train_split = tfds.Split.TRAIN.subsplit(tfds.percent[:training_pct])
  validation_split = tfds.Split.TRAIN.subsplit(tfds.percent[training_pct:])

  train_data, info = tfds.load('cifar10:3.*.*', with_info=True, split=train_split)
  val_data = tfds.load('cifar10:3.*.*', split=validation_split)
  test_data = tfds.load('cifar10:3.*.*', split='test')

to

  (train_data, val_data, test_data), info = tfds.load(
    split=('train[:80]', 'train[:10]', 'train[:10]'),

Then another error occurred:

TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_163779/2518032390.py in <module>
      5                                 seed=SEED,
      6                                 repeat_validation=True,
----> 7                                 use_augmentation=True)
      9 model = resnet_v2(input_shape=info['input_shape'],

/tmp/ipykernel_163779/3970928241.py in get_input_pipeline(batch_size, use_augmentation, seed, crop_amount, drop_remainder, repeat_validation)
     77     batch_size = max(TRAINING_SIZE, VALIDATION_SIZE, TEST_SIZE)
---> 79   train_data = train_data.map(_parse_fn).shuffle(8192, seed=seed).repeat()
     80   if use_augmentation:
     81     train_data = train_data.map(

~/anaconda3/envs/kfac/lib/python3.7/site-packages/tensorflow_core/python/data/ops/dataset_ops.py in map(self, map_func, num_parallel_calls)
   1218     """
   1219     if num_parallel_calls is None:
-> 1220       return MapDataset(self, map_func, preserve_cardinality=True)
   1221     else:
   1222       return ParallelMapDataset(

~/anaconda3/envs/kfac/lib/python3.7/site-packages/tensorflow_core/python/data/ops/dataset_ops.py in __init__(self, input_dataset, map_func, use_inter_op_parallelism, preserve_cardinality, use_legacy_function)
   3432         self._transformation_name(),
   3433         dataset=input_dataset,
-> 3434         use_legacy_function=use_legacy_function)
   3435     variant_tensor = gen_dataset_ops.map_dataset(
   3436         input_dataset._variant_tensor,  # pylint: disable=protected-access

~/anaconda3/envs/kfac/lib/python3.7/site-packages/tensorflow_core/python/data/ops/dataset_ops.py in __init__(self, func, transformation_name, dataset, input_classes, input_shapes, input_types, input_structure, add_to_graph, use_legacy_function, defun_kwargs)
   2711       resource_tracker = tracking.ResourceTracker()
   2712       with tracking.resource_tracker_scope(resource_tracker):
-> 2713         self._function = wrapper_fn._get_concrete_function_internal()
   2714         if add_to_graph:
   2715           self._function.add_to_graph(ops.get_default_graph())

~/anaconda3/envs/kfac/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _get_concrete_function_internal(self, *args, **kwargs)
   1851     """Bypasses error checking when getting a graph function."""
   1852     graph_function = self._get_concrete_function_internal_garbage_collected(
-> 1853         *args, **kwargs)
   1854     # We're returning this concrete function to someone, and they may keep a
   1855     # reference to the FuncGraph without keeping a reference to the

~/anaconda3/envs/kfac/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   1845     if self.input_signature:
   1846       args, kwargs = None, None
-> 1847     graph_function, _, _ = self._maybe_define_function(args, kwargs)
   1848     return graph_function

~/anaconda3/envs/kfac/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2145         graph_function = self._function_cache.primary.get(cache_key, None)
   2146         if graph_function is None:
-> 2147           graph_function = self._create_graph_function(args, kwargs)
   2148           self._function_cache.primary[cache_key] = graph_function
   2149         return graph_function, args, kwargs

~/anaconda3/envs/kfac/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2036             arg_names=arg_names,
   2037             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2038             capture_by_value=self._capture_by_value),
   2039         self._function_attributes,
   2040         # Tell the ConcreteFunction to clean up its graph once it goes out of

~/anaconda3/envs/kfac/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    913                                           converted_func)
--> 915       func_outputs = python_func(*func_args, **func_kwargs)
    917       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/anaconda3/envs/kfac/lib/python3.7/site-packages/tensorflow_core/python/data/ops/dataset_ops.py in wrapper_fn(*args)
   2705           attributes=defun_kwargs)
   2706       def wrapper_fn(*args):  # pylint: disable=missing-docstring
-> 2707         ret = _wrapper_helper(*args)
   2708         ret = structure.to_tensor_list(self._output_structure, ret)
   2709         return [ops.convert_to_tensor(t) for t in ret]

~/anaconda3/envs/kfac/lib/python3.7/site-packages/tensorflow_core/python/data/ops/dataset_ops.py in _wrapper_helper(*args)
   2650         nested_args = (nested_args,)
-> 2652       ret = autograph.tf_convert(func, ag_ctx)(*nested_args)
   2653       # If `func` returns a list of tensors, `nest.flatten()` and
   2654       # `ops.convert_to_tensor()` would conspire to attempt to stack

~/anaconda3/envs/kfac/lib/python3.7/site-packages/tensorflow_core/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    235       except Exception as e:  # pylint:disable=broad-except
    236         if hasattr(e, 'ag_error_metadata'):
--> 237           raise e.ag_error_metadata.to_exception(e)
    238         else:
    239           raise

TypeError: in converted code:

    TypeError: tf___parse_fn() takes 1 positional argument but 2 were given
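One possible reading of this TypeError is that the dataset elements were (image, label) tuples rather than feature dictionaries, since tf.data's Dataset.map unpacks a tuple element into separate positional arguments. The truncated tfds.load call above does not show whether that was the case (for example via as_supervised=True), so this is an assumption. Under that assumption, a parse function can be sketched to accept both shapes. The function name parse_example is hypothetical and stands in for the notebook's _parse_fn, whose body is not shown here.

```python
def parse_example(*args):
    """Accept either one feature dict or an (image, label) pair.

    A map function written for a single dict argument fails when the
    dataset yields tuples, because each tuple member arrives as its own
    positional argument.
    """
    if len(args) == 1:
        features = args[0]
        image, label = features['image'], features['label']
    elif len(args) == 2:
        image, label = args
    else:
        raise TypeError('expected 1 or 2 arguments, got %d' % len(args))
    return image, label


# Plain Python values stand in for tensors in this illustration.
print(parse_example({'image': 'img', 'label': 3}))  # → ('img', 3)
print(parse_example('img', 3))                      # → ('img', 3)
```

The example is exercised with plain Python values only; whether it traces cleanly under the TensorFlow version used in this thread has not been checked.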


Finally, I changed the code to:

  # train_data, info = tfds.load('cifar10:3.*.*', with_info=True, split=train_split)
  # val_data = tfds.load('cifar10:3.*.*', split=validation_split)
  # test_data = tfds.load('cifar10:3.*.*', split='test')
  (train_data, val_data, test_data), info = tfds.load(
    split=('train[:80]', 'train[:10]', 'train[:10]'),


!pip install tensorflow-datasets==1.3.0

Then the ipynb file worked.

tranvansang commented 5 months ago

@james-martens could you share the source code of the experiments conducted in the paper "Distributed Second-Order Optimization using Kronecker-Factored Approximations"?


The paper shows that K-FAC outperforms SGD+momentum across medium to large batch sizes (256, 512, 2048). I am curious about the hyperparameter settings used in those experiments.


