Hi, thanks for your wonderful work. According to the paper, the batch normalization statistics are updated on the combined batch of source and target images, and kept fixed when the source images are passed through the model a second time. However, at https://github.com/google-research/adamatch/blob/main/domain_adaptation/adamatch.py#L63 and L67, the model seems to stay in training mode (with state updates) for both calls. I am not familiar with JAX. Would you mind giving some explanation?
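For context, here is how I understand the two passes, written as a minimal pure-JAX sketch. It is an illustration under my own assumptions, not the repository's code: `batch_norm` and `two_pass` are hypothetical helpers, and I have not verified that adamatch.py works this way. In JAX the batch-norm state is explicit, so a call in training mode always normalizes with the current batch's statistics and *returns* new running statistics; whether those statistics are "kept fixed" then depends only on which returned state the caller keeps.

```python
import jax
import jax.numpy as jnp


def batch_norm(x, state, momentum=0.9, eps=1e-5):
    """Training-mode batch norm (hypothetical helper).

    Normalizes x with its own batch statistics and returns the updated
    running statistics; the caller decides whether to keep them.
    """
    mean = jnp.mean(x, axis=0)
    var = jnp.var(x, axis=0)
    y = (x - mean) / jnp.sqrt(var + eps)
    new_state = {
        "mean": momentum * state["mean"] + (1 - momentum) * mean,
        "var": momentum * state["var"] + (1 - momentum) * var,
    }
    return y, new_state


def two_pass(source, target, state):
    """Hypothetical two-pass forward, as I read the paper's description."""
    # Pass 1: combined source + target batch; keep the updated running stats.
    combined = jnp.concatenate([source, target], axis=0)
    y_combined, state = batch_norm(combined, state)
    # Pass 2: source-only batch, still in training mode, but the returned
    # running stats are discarded, so the state from pass 1 stays fixed.
    y_source, _discarded = batch_norm(source, state)
    return y_combined, y_source, state
```

If this reading is right, seeing training mode at both lines would not by itself contradict the paper, as long as the state returned by the second call is thrown away. Is that what happens in the code?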