magenta / ddsp

DDSP: Differentiable Digital Signal Processing
https://magenta.tensorflow.org/ddsp
Apache License 2.0

Adding a PretrainedCREPEEmbeddingLoss to training #229

Open JCBrouwer opened 4 years ago

JCBrouwer commented 4 years ago

Hello, I've trained a model for a while using the solo_instrument config at 48 kHz, but the audio is still fairly noisy even after 117k steps (spectral loss is ~9 on average).

I'd like to continue training with PretrainedCREPEEmbeddingLoss() enabled as well, to encourage more natural, perceptually realistic synthesis.

I've tried simply adding the loss to the ae.gin file, but I get the following error, which I don't really understand:

Traceback (most recent call last):
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/hans/code/maua-ddsp/ddsp/training/ddsp_run.py", line 231, in <module>
    console_entry_point()
  File "/home/hans/code/maua-ddsp/ddsp/training/ddsp_run.py", line 227, in console_entry_point
    app.run(main)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/home/hans/code/maua-ddsp/ddsp/training/ddsp_run.py", line 205, in main
    report_loss_to_hypertune=FLAGS.hypertune,
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/gin/config.py", line 1078, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/gin/utils.py", line 49, in augment_exception_message_and_reraise
    six.raise_from(proxy.with_traceback(exception.__traceback__), None)
  File "<string>", line 3, in raise_from
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/gin/config.py", line 1055, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/hans/code/maua-ddsp/ddsp/training/train_util.py", line 185, in train
    trainer.build(next(dataset_iter))
  File "/home/hans/code/maua-ddsp/ddsp/training/trainers.py", line 134, in build
    _ = self.run(tf.function(self.model.__call__), batch)
  File "/home/hans/code/maua-ddsp/ddsp/training/trainers.py", line 129, in run
    return self.strategy.run(fn, args=args, kwargs=kwargs)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py", line 1211, in run
    return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py", line 2585, in call_for_each_replica
    return self._call_for_each_replica(fn, args, kwargs)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/distribute/mirrored_strategy.py", line 585, in _call_for_each_replica
    self._container_strategy(), fn, args, kwargs)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/distribute/mirrored_run.py", line 78, in call_for_each_replica
    return wrapped(args, kwargs)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 904, in _call
    return function_lib.defun(fn_with_cond)(*canon_args, **canon_kwds)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2828, in __call__
    graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
    raise e.ag_error_metadata.to_exception(e)
AssertionError: in user code:

    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:896 fn_with_cond  *
        functools.partial(self._concrete_stateful_fn._filtered_call,  # pylint: disable=protected-access
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201 wrapper  **
        return target(*args, **kwargs)
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:507 new_func
        return func(*args, **kwargs)
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py:1180 cond
        return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/ops/cond_v2.py:92 cond_v2
        op_return_value=pred)
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py:986 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/function.py:1848 _filtered_call
        cancellation_manager=cancellation_manager)
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/eager/function.py:1877 _call_flat
        for v in self._func_graph.variables:
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py:489 variables
        return tuple(deref(v) for v in self._weak_variables)
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py:489 <genexpr>
        return tuple(deref(v) for v in self._weak_variables)
    /home/hans/.conda/envs/ddsp/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py:482 deref
        "Called a function referencing variables which have been deleted. "

    AssertionError: Called a function referencing variables which have been deleted. This likely means that function-local variables were created and not referenced elsewhere in the program. This is generally a mistake; consider storing variables in an object attribute on first call.

  In call to configurable 'train' (<function train at 0x7f995ef00268>)

How can I train with this loss enabled?
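The last line of the error advises "storing variables in an object attribute on first call". Independent of TensorFlow, that is a create-once pattern: build the expensive object on the first call, keep a reference on `self`, and reuse it afterwards so that nothing created inside the call is left unreferenced. A plain-Python sketch, where `CachedEmbeddingLoss` and the `build_model` factory are hypothetical stand-ins for a loss that wraps a pretrained network (in TensorFlow, the cached object would own the variables). It only illustrates what the error message recommends; the thread does not establish it as the fix for this case.

```python
from typing import Callable, List, Optional, Sequence

Model = Callable[[Sequence[float]], List[float]]


class CachedEmbeddingLoss:
    """Hypothetical loss that builds its embedding model once and keeps it on self."""

    def __init__(self, build_model: Callable[[], Model]):
        self._build_model = build_model
        self._model: Optional[Model] = None  # created lazily on the first call
        self.build_count = 0                 # number of times the model was built

    def __call__(self, audio: Sequence[float], target: Sequence[float]) -> float:
        if self._model is None:
            # First call: build once and store a reference so it stays alive.
            self._model = self._build_model()
            self.build_count += 1
        emb = self._model(audio)
        target_emb = self._model(target)
        # Mean absolute difference between the two embeddings.
        return sum(abs(a - b) for a, b in zip(emb, target_emb)) / len(emb)


def build_model() -> Model:
    """Stand-in 'pretrained network' (hypothetical): scales the input by 2."""
    return lambda x: [2.0 * v for v in x]


loss = CachedEmbeddingLoss(build_model)
print(loss([1.0, 2.0], [1.0, 1.0]))  # 1.0
print(loss([0.0, 0.0], [0.0, 0.0]))  # 0.0
print(loss.build_count)              # 1
```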

jesseengel commented 4 years ago

Can you give more details on your initial config/run command and the one used for restarting the job? Are you warmstarting from the pretrained checkpoint but adding a new loss?

JCBrouwer commented 4 years ago

Yes, I want to warmstart from the pretrained checkpoint, although I get the same error when training from scratch with the CREPE embedding loss added to ae.gin.

My original training command:

python -m ddsp.training.ddsp_run \
  --mode=train \
  --alsologtostderr \
  --save_dir="/home/hans/modelzoo/neuro-bass-ddsp-48kHz/" \
  --gin_file=models/solo_instrument.gin \
  --gin_file=datasets/tfrecord.gin \
  --gin_param="TFRecordProvider.file_pattern='/home/hans/datasets/neuro-bass-ddsp/48kHz/train.tfrecord*'" \
  --gin_param="batch_size=16" \
  --gin_param="train_util.train.num_steps=300000" \
  --gin_param="train_util.train.steps_per_save=3000" \
  --gin_param="trainers.Trainer.checkpoints_to_keep=10" \
  --gin_param="TFRecordProvider.example_secs=4" \
  --gin_param="TFRecordProvider.sample_rate=48000" \
  --gin_param="TFRecordProvider.frame_rate=250" \
  --gin_param="Additive.n_samples=192000" \
  --gin_param="Additive.sample_rate=48000" \
  --gin_param="FilteredNoise.n_samples=192000"
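The size overrides in this command have to agree with each other: the synthesizers' `n_samples` should equal the number of audio samples in one example (sample rate times example length), and the frame rate fixes how many conditioning frames each example has. A quick consistency check in plain Python using the values passed above; the `check_sizes` helper is hypothetical, and requiring the sample rate to be an integer multiple of the frame rate is an assumption of this sketch.

```python
def check_sizes(sample_rate: int, example_secs: int, frame_rate: int, n_samples: int) -> dict:
    """Derive per-example sizes and verify a synthesizer n_samples override."""
    expected_samples = sample_rate * example_secs  # audio samples per example
    n_frames = frame_rate * example_secs           # conditioning frames per example
    # Assumption of this sketch: a whole number of audio samples per frame.
    if sample_rate % frame_rate != 0:
        raise ValueError("sample_rate should be an integer multiple of frame_rate")
    hop_size = sample_rate // frame_rate           # audio samples per frame
    if n_samples != expected_samples:
        raise ValueError(f"n_samples={n_samples}, expected {expected_samples}")
    return {"n_samples": expected_samples, "n_frames": n_frames, "hop_size": hop_size}


sizes = check_sizes(sample_rate=48000, example_secs=4, frame_rate=250, n_samples=192000)
print(sizes)  # {'n_samples': 192000, 'n_frames': 1000, 'hop_size': 192}
```

Leaving `n_samples` at a 16 kHz value (for example 64000 for 4 seconds) while raising the sample rate to 48000 would make this check fail.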

After training overnight, I added PretrainedCREPEEmbeddingLoss() to ae.gin (which solo_instrument.gin includes):

Autoencoder.losses = [
    @losses.SpectralLoss(),
    @losses.PretrainedCREPEEmbeddingLoss(),
]
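Listing a second loss in `Autoencoder.losses` adds another term to the training objective. A framework-free sketch, assuming the model evaluates each configured loss and sums the (optionally weighted) results; the loss names, the `mean_abs_diff` stand-in and the 0.1 weight are hypothetical, not DDSP's implementation. Because a pretrained-embedding loss may sit on a different scale from the spectral loss, the relative weight of each term can matter.

```python
from typing import Callable, Dict, Sequence, Tuple

LossFn = Callable[[Sequence[float], Sequence[float]], float]


def mean_abs_diff(audio: Sequence[float], target: Sequence[float]) -> float:
    """Stand-in loss (hypothetical): mean absolute difference between two signals."""
    return sum(abs(a - t) for a, t in zip(audio, target)) / len(audio)


def total_loss(
    losses: Dict[str, LossFn],
    audio: Sequence[float],
    target: Sequence[float],
    weights: Dict[str, float],
) -> Tuple[float, Dict[str, float]]:
    """Evaluate every configured loss, scale it by its weight (default 1.0), and sum."""
    terms = {name: weights.get(name, 1.0) * fn(audio, target) for name, fn in losses.items()}
    return sum(terms.values()), terms


# Two configured losses, mirroring the gin list above (both use the stand-in here).
losses = {"spectral_loss": mean_abs_diff, "crepe_embedding_loss": mean_abs_diff}
total, terms = total_loss(losses, [0.0, 0.5, 1.0], [0.0, 0.0, 0.0],
                          weights={"crepe_embedding_loss": 0.1})
print(terms, total)
```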

I then run the following command and get the error (it is the same with or without --restore_dir):

python -m ddsp.training.ddsp_run \
  --mode=train \
  --alsologtostderr \
  --save_dir="/home/hans/modelzoo/neuro-bass-ddsp-48kHz-crepe/" \
  --restore_dir="/home/hans/modelzoo/neuro-bass-ddsp-48kHz/" \
  --gin_file=models/solo_instrument.gin \
  --gin_file=datasets/tfrecord.gin \
  --gin_param="TFRecordProvider.file_pattern='/home/hans/datasets/neuro-bass-ddsp/48kHz/train.tfrecord*'" \
  --gin_param="batch_size=16" \
  --gin_param="train_util.train.num_steps=300000" \
  --gin_param="train_util.train.steps_per_save=3000" \
  --gin_param="trainers.Trainer.checkpoints_to_keep=10" \
  --gin_param="TFRecordProvider.example_secs=4" \
  --gin_param="TFRecordProvider.sample_rate=48000" \
  --gin_param="TFRecordProvider.frame_rate=250" \
  --gin_param="Additive.n_samples=192000" \
  --gin_param="Additive.sample_rate=48000" \
  --gin_param="FilteredNoise.n_samples=192000"

JCBrouwer commented 4 years ago

Update: I've found that training with PretrainedCREPEEmbeddingLoss works when running on a single GPU (via CUDA_VISIBLE_DEVICES=0).
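The single-GPU workaround relies on `CUDA_VISIBLE_DEVICES`, which has to be set before TensorFlow initializes CUDA. On the command line that means prefixing the training command (`CUDA_VISIBLE_DEVICES=0 python -m ddsp.training.ddsp_run ...`). From a Python launcher script, a minimal sketch is to set it in the process environment before importing TensorFlow; the `restrict_to_gpus` helper is hypothetical.

```python
import os


def restrict_to_gpus(gpu_ids) -> str:
    """Expose only the given GPU indices to CUDA; call before importing tensorflow."""
    value = ",".join(str(i) for i in gpu_ids)
    os.environ["CUDA_VISIBLE_DEVICES"] = value
    return value


restrict_to_gpus([0])
# import tensorflow as tf  # import only after the variable has been set
```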

Is there a way to make PretrainedCREPEEmbeddingLoss work with multi-GPU training?