davisyoshida / tf2-gradient-checkpointing

Simple gradient checkpointing for eager mode execution
MIT License

checkpointing for tensor types other than float32? #3

Closed by zacman2400 3 years ago

zacman2400 commented 3 years ago

Suppose I want to checkpoint repeated iterations of an FFT, for instance, where complex64 data is required. Would it be possible to relax the requirement in:

```python
flat_inputs = nest.flatten(args) + nest.flatten(list(kwargs.values()))
flat_inputs = [x for x in flat_inputs if tf.is_tensor(x)]
flat_inputs = [x for x in flat_inputs if x.dtype == tf.float32]
```

to allow checkpointing of complex tensors with gradient computation, or is there more to it? As far as I understand (I could be mistaken), TF does allow gradient computation with complex64 types.
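One possible relaxation (a sketch, not the library's actual code): instead of comparing against `tf.float32`, keep any tensor whose dtype is differentiable. `tf.DType` exposes `is_floating` and `is_complex` flags that cover this. The helper name `differentiable_tensors` below is hypothetical:

```python
import tensorflow as tf

def differentiable_tensors(flat_inputs):
    """Hypothetical relaxed filter: keep tensors TF can differentiate through,
    rather than only float32."""
    flat_inputs = [x for x in flat_inputs if tf.is_tensor(x)]
    # tf.DType provides is_floating / is_complex, so this keeps float16/32/64
    # as well as complex64/complex128, while dropping integer tensors.
    return [x for x in flat_inputs
            if x.dtype.is_floating or x.dtype.is_complex]

tensors = [tf.constant(1.0),                              # float32: kept
           tf.constant(1 + 2j, dtype=tf.complex64),       # complex64: kept
           tf.constant(3)]                                # int32: dropped
kept = differentiable_tensors(tensors)
```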

davisyoshida commented 3 years ago

I haven't done anything with complex numbers in tf, so no promises, but I don't see why it wouldn't work. I'd recommend trying a small model which you can run with and without checkpointing to make sure you get the same gradient values.
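A sketch of that comparison, using `tf.recompute_grad` as a stand-in checkpointing wrapper (not this repo's decorator) on a small complex64 FFT chain. The assumption being tested is that recomputing the forward pass yields the same gradients as the plain tape:

```python
import numpy as np
import tensorflow as tf

def fft_chain(x):
    # Repeated FFT iterations on a complex64 tensor, reduced to a real scalar
    # so tape.gradient has a real-valued loss to differentiate.
    for _ in range(4):
        x = tf.signal.fft(x)
    return tf.reduce_sum(tf.abs(x))

x = tf.complex(tf.random.normal([16]), tf.random.normal([16]))

# Baseline: gradient without any checkpointing.
with tf.GradientTape() as tape:
    tape.watch(x)
    loss_plain = fft_chain(x)
grad_plain = tape.gradient(loss_plain, x)

# Same function wrapped so intermediates are recomputed in the backward pass.
checkpointed = tf.recompute_grad(fft_chain)
with tf.GradientTape() as tape:
    tape.watch(x)
    loss_ckpt = checkpointed(x)
grad_ckpt = tape.gradient(loss_ckpt, x)

grads_match = np.allclose(grad_plain.numpy(), grad_ckpt.numpy(),
                          rtol=1e-4, atol=1e-5)
```

If `grads_match` holds for a small model, that's decent evidence the dtype restriction can be relaxed for that use case.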