liutianlin0121 closed this 1 year ago.
Two small changes:
1. Use `optax.inject_hyperparams` to anneal the learning rate, so that the current learning rate can be read directly from the optimizer state.
2. Use `numpy_collate` in the dataloader, so that batches no longer need to be converted from PyTorch tensors to NumPy arrays.
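A sketch of the `numpy_collate` change. `numpy_collate` is not a PyTorch built-in. The version below is a hypothetical helper, similar to the pattern in JAX's PyTorch data-loading tutorial and not necessarily identical to the one in this PR. It assumes each dataset item is a NumPy array, a Python scalar, or a tuple/list of those (e.g. `(image, label)`), so the whole batch stays in NumPy and can be fed to JAX without a tensor-to-array conversion.

```python
import numpy as np

def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays instead of torch tensors.

    Hypothetical helper: assumes each sample is a NumPy array, a Python
    scalar, or a tuple/list of those.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        # Stack per-sample arrays along a new leading batch axis.
        return np.stack(batch)
    if isinstance(first, (tuple, list)):
        # Transpose [(x0, y0), (x1, y1), ...] into per-field groups
        # and collate each field recursively.
        return [numpy_collate(list(field)) for field in zip(*batch)]
    # Scalars (e.g. integer labels) become a 1-D array.
    return np.array(batch)

# With PyTorch (not run here):
#   loader = torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=numpy_collate)

samples = [(np.ones((2, 2), dtype=np.float32) * i, i) for i in range(4)]
images, labels = numpy_collate(samples)
print(images.shape, labels)
```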