Closed jvrsgsty closed 5 years ago
Ugh... Basically had to roll back to almost how this was before. The reason we are doing it this way is that the TPUEstimator API wants the dataset func to strictly adhere to this signature: it must have a positional argument named params.
So what I propose instead is inferring `self.on_tpu` depending on which `dataset_func` is called.
Sadly, we do have to adhere to two different call signatures, for now.
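
A minimal sketch of the idea above, not the code in this PR: the provider exposes two entry points, and `self.on_tpu` is set by whichever one gets called instead of being passed to the constructor. The class name `DataProvider`, the method names `dataset_func` and `tpu_dataset_func`, the `_build` helper, and the `is_train` key inside `params` are all hypothetical; the only thing taken from the discussion is that the TPU path must accept an argument named `params`.

```python
class DataProvider:
    """Hypothetical unified data provider (illustrative names only)."""

    def __init__(self):
        # Not supplied by the caller; inferred from the entry point used.
        self.on_tpu = None

    def _build(self, batch_size, is_train):
        # Stand-in for the real dataset construction logic.
        return {
            "batch_size": batch_size,
            "is_train": is_train,
            "on_tpu": self.on_tpu,
        }

    def dataset_func(self, batch_size, is_train=True, **kwargs):
        # GPU/CPU path: free-form signature, options arrive as kwargs.
        self.on_tpu = False
        return self._build(batch_size, is_train)

    def tpu_dataset_func(self, params):
        # TPU path: the signature is restricted to an argument named params.
        # Reading is_train from params is an assumption of this sketch.
        self.on_tpu = True
        return self._build(params["batch_size"], params.get("is_train", True))


provider = DataProvider()
tpu_result = provider.tpu_dataset_func({"batch_size": 8, "is_train": False})
gpu_result = provider.dataset_func(batch_size=4, is_train=True)
```

The two call signatures remain, but the caller only states the target device once, by choosing which function to hand to the estimator.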
This makes sense to me, merging now.
I found it weird that to use this unified data provider on TPU you have to specify that you are running on TPU twice: once with the `on_tpu` arg in the constructor, and once when you choose which function to pass in the data params dictionary.
I also thought it would be best to keep things consistent between GPU and TPU, so I removed `is_train` from the constructor args and now pass it via the kwargs in the dataset func.
With this, usage would go from
to this, which would only differ from a GPU call in the `on_tpu` argument.