Closed: zacman2400 closed this issue 3 years ago
I haven't done anything with complex numbers in tf, so no promises, but I don't see why it wouldn't work. I'd recommend trying a small model which you can run with and without checkpointing to make sure you get the same gradient values.
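The suggested sanity check could look something like the sketch below (my own illustration, not code from this issue): wrap a small block with `tf.recompute_grad` and compare its gradients against the unwrapped version on float32 inputs, where checkpointing is known to work.

```python
import tensorflow as tf

def block(x):
    # A tiny, arbitrary differentiable "layer" for the comparison.
    return tf.sin(tf.matmul(x, x, transpose_b=True))

# Checkpointed variant: activations are recomputed during backprop.
checkpointed_block = tf.recompute_grad(block)

x = tf.constant([[0.1, 0.2], [0.3, 0.4]], tf.float32)

with tf.GradientTape() as t1:
    t1.watch(x)
    loss1 = tf.reduce_sum(block(x))
g1 = t1.gradient(loss1, x)

with tf.GradientTape() as t2:
    t2.watch(x)
    loss2 = tf.reduce_sum(checkpointed_block(x))
g2 = t2.gradient(loss2, x)

# Both paths should produce the same gradient values.
```

The same comparison, run with complex64 inputs, would show directly whether checkpointing changes (or breaks) the gradients in the complex case.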
Suppose I want to checkpoint repeated iterations of an FFT, for instance, where complex64 data is required. Would it be possible to relax the requirement in:

```python
flat_inputs = nest.flatten(args) + nest.flatten(list(kwargs.values()))
flat_inputs = [x for x in flat_inputs if tf.is_tensor(x)]
flat_inputs = [x for x in flat_inputs if x.dtype == tf.float32]
```
to allow for checkpointing of complex tensors with gradient computation, or is there more to it? As far as I understand (I could be mistaken), TF does allow gradient computation with complex64 types.
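For what it's worth, TF does define gradients through `tf.signal.fft` for complex64 tensors; the small check below (my own sketch, not from this issue) computes the gradient of a real-valued loss through an FFT and verifies it against what Parseval's theorem predicts for the magnitude.

```python
import tensorflow as tf

n = 8
# tf.random.normal defaults to float32, so x is complex64.
x = tf.complex(tf.random.normal([n]), tf.random.normal([n]))

with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.signal.fft(x)
    loss = tf.reduce_sum(tf.abs(y) ** 2)  # real-valued scalar

g = tape.gradient(loss, x)

# By Parseval, loss == n * sum(|x|^2) for the unnormalized DFT, so the
# gradient magnitude should be 2 * n * |x| regardless of which
# conjugation convention TF uses for complex gradients.
```

If that holds, relaxing the dtype filter to something like `x.dtype in (tf.float32, tf.complex64)` may be all the code change needed, though whether the rest of the checkpointing path handles complex tensors correctly would still need the with/without comparison suggested above.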