Closed — LeoXinhaoLee closed this issue 1 year ago.
We used `value_range(-1, 1)` for most of our experiments, so you should use that same preprocessing when you use the models. If you use pre-trained models, you should check the original configuration (e.g. `big_vision/configs/vit_i1k.py`). If you train from scratch, it shouldn't make a difference whether you use `value_range(-1, 1)` or `vgg_value_range`.
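For intuition, here is a minimal NumPy sketch of the kind of rescaling a `value_range(-1, 1)` op performs (this is an illustrative re-implementation, not the actual big_vision op, which operates on TF tensors inside the preprocessing pipeline):

```python
import numpy as np

def value_range(image, vmin=-1.0, vmax=1.0, in_min=0.0, in_max=255.0):
    # Linearly rescale pixel values from [in_min, in_max] to [vmin, vmax].
    image = (np.asarray(image, np.float32) - in_min) / (in_max - in_min)
    return image * (vmax - vmin) + vmin

img = np.array([[0, 128, 255]], dtype=np.uint8)
out = value_range(img)  # 0 -> -1.0, 255 -> 1.0
```

A VGG-style op instead subtracts per-channel dataset statistics; the point of the answer above is that for training from scratch either normalization works, but pre-trained checkpoints must be fed the normalization they were trained with.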
Thank you very much for your kind reply. We've successfully reproduced the vit_s_i1k result after 90 epochs! If you don't mind, one more thing that is troubling us is the speed of loading data with TFDS from a GCS bucket. Our TFRecords are organized differently, and we are loading them in an old-fashioned way that seems quite slow: on a v3-64, 90 epochs take 2 h instead of the expected 6.5 / 8 ≈ 0.8 h. Would you mind telling us what makes our approach so much slower than yours? Huge thanks!
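The expected wall time quoted above assumes near-linear scaling from a v3-8 baseline (the 6.5 h figure is taken from the message; the linear-scaling assumption is ours). A quick check of that arithmetic:

```python
# Back-of-the-envelope check of the expected 90-epoch wall time when scaling
# from a v3-8 (assumed 6.5 h baseline) to a v3-64 (8x the chips), assuming
# near-linear scaling with chip count.
baseline_hours = 6.5       # v3-8 baseline, from the thread
scale = 64 // 8            # a v3-64 has 8x the chips of a v3-8
expected_hours = baseline_hours / scale   # ~0.81 h
observed_hours = 2.0
slowdown = observed_hours / expected_hours  # ~2.5x slower than expected
```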
```python
import functools
import math
import os

import jax
import tensorflow as tf

from big_vision.pp import builder as pp_builder


def training(config):
    data_dir = os.environ['TFDS_DATA_DIR']
    file_pattern = f"{data_dir}/train-in1k-*-of-*"
    num_records = 1281167
    file_list = tf.io.gfile.glob(file_pattern)
    dataset = tf.data.TFRecordDataset(
        file_list, num_parallel_reads=tf.data.AUTOTUNE)
    if jax.process_index() == 0:
        print(f'Total number of samples: {num_records}')

    # Partition the dataset across hosts: each host skips the records owned by
    # earlier hosts and takes its own slice; the last host takes the remainder.
    partition_size = num_records // jax.process_count()
    if jax.process_index() < jax.process_count() - 1:
        dataset = dataset.skip(
            jax.process_index() * partition_size).take(partition_size)
    else:
        dataset = dataset.skip(jax.process_index() * partition_size).take(
            num_records - jax.process_index() * partition_size)

    def parse_record(record):
        # Parse one serialized tf.Example into its encoded image and label.
        features = {
            'image/encoded': tf.io.FixedLenFeature([], tf.string),
            'image/class/label': tf.io.FixedLenFeature([], tf.int64),
        }
        example = tf.io.parse_single_example(record, features)
        return {'image': example['image/encoded'],
                'labels': example['image/class/label']}

    def parse_transform(record, transform):
        # Apply the big_vision preprocessing pipeline and one-hot the label.
        sample = transform({'image': record['image']})
        label = tf.one_hot(record['labels'], 1000)
        return {'_mask': tf.constant(1),
                'image': sample['image'],
                'labels': label}

    global_bs = config.input.batch_size
    local_bs = global_bs // jax.process_count()
    num_ex_per_process = [partition_size] * (jax.process_count() - 1)
    num_ex_per_process.append(
        num_records - (jax.process_count() - 1) * partition_size)
    num_batches = math.ceil(max(num_ex_per_process) / local_bs)

    dataset = dataset.map(parse_record, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.cache()
    dataset = dataset.repeat(None)  # Each host repeats its partition forever.
    dataset = dataset.shuffle(config.input.shuffle_buffer_size)
    transform = pp_builder.get_preprocess_fn(config.input.pp, log_data=False)
    dataset = dataset.map(
        functools.partial(parse_transform, transform=transform),
        num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(local_bs)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset, num_records
```
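One property of the pipeline above worth noting (a general observation, not necessarily the diagnosis reached in this thread): `skip()` cannot seek, so every host must read and discard all records owned by earlier hosts before reaching its own slice. A common alternative is to partition at the file level so each host only ever opens its own shard files. A minimal sketch of that assignment logic, with illustrative file names:

```python
# Hypothetical sketch: instead of skip()/take() over a record stream, assign
# whole TFRecord shard files to hosts round-robin, so each host only opens
# and reads the files it owns.
def shard_files(file_list, host_index, host_count):
    # Sort first so every host agrees on the same assignment.
    return sorted(file_list)[host_index::host_count]

files = [f"train-in1k-{i:05d}-of-00016" for i in range(16)]
my_files = shard_files(files, host_index=0, host_count=4)  # 4 files per host
```

Each host would then build its `tf.data.TFRecordDataset` from `my_files` only; `tf.data.Dataset.shard()` applied to a dataset of file names achieves the same thing.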
I would try the following two experiments to find the root cause of the slowness:
Thank you so much for your kind advice! We found that it is indeed cause 2. that was slowing us down, and we will improve it.
Another thing we noticed: in the vit_s16_i1k experiment, we additionally return the `grads` and `updates` of `update_fn()` (no other changes), and find that after the 1st and 2nd iterations, `grads` is all zero except for the head parameters (kernel, bias), while from the 3rd iteration onward `grads` is non-zero for all parameters. Is this phenomenon expected? Thank you so much for your help!
That's probably due to the `head_zeroinit = True` default.
Oh, I see! The learning rate at the 1st iteration is 0 (warmup), so no parameter is updated. At the 2nd iteration, the kernel weights of the head layer are still 0, so no gradient is back-propagated to the earlier layers.
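A toy NumPy sketch of why a zero-initialized head blocks gradient flow to the backbone: in a linear model `y = W2 (W1 x)`, the backbone gradient passes through `W2.T`, which is all zeros, while the head gradient itself is generally non-zero (this is a hand-rolled illustration, not big_vision code):

```python
import numpy as np

# Toy two-layer linear model: h = W1 x, y = W2 h, loss = 0.5 * ||y - t||^2.
# With W2 zero-initialized (as with head_zeroinit), the head gradient dW2 is
# non-zero, but the gradient flowing back through W2.T vanishes, so dW1 == 0.
rng = np.random.default_rng(0)
x = rng.normal(size=(4,))
t = rng.normal(size=(3,))
W1 = rng.normal(size=(5, 4))
W2 = np.zeros((3, 5))          # zero-initialized head kernel

h = W1 @ x
y = W2 @ h                     # all zeros, since W2 == 0
dy = y - t                     # dL/dy, non-zero
dW2 = np.outer(dy, h)          # head gradient: generally non-zero
dh = W2.T @ dy                 # all zeros, since W2 == 0
dW1 = np.outer(dh, x)          # backbone gradient: all zeros
```

Once the optimizer makes `W2` non-zero (at the first step with a non-zero learning rate), `dh` becomes non-zero and gradients reach the backbone, matching the behavior seen from the 3rd iteration onward.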
Thank you so much for helping us out! Really appreciate it!
Hi, thank you so much for releasing the code for these inspiring works. I noticed that the config file uses `value_range(-1, 1)` instead of `vgg_value_range`. Is (-1, 1) necessary for reproducing results on a standard ImageNet dataset? Thank you very much for your time and help.