google-research / big_vision

Official codebase used to develop Vision Transformer, SigLIP, MLP-Mixer, LiT and more.

Question regarding value range (-1,1) #37

Closed. LeoXinhaoLee closed this issue 1 year ago.

LeoXinhaoLee commented 1 year ago

Hi, thank you so much for releasing the code for these inspiring works. I notice that the config file uses value_range(-1, 1) instead of vgg_value_range. Is (-1, 1) necessary for reproducing results on a standard ImageNet dataset?

Thank you very much for your time and help.

andsteing commented 1 year ago

We used value_range(-1,1) for most of our experiments, so you should use that same preprocessing when you use the models. If you use pre-trained models, you should check the original configuration (e.g. big_vision/configs/vit_i1k.py). If you train from scratch, it shouldn't make a difference whether you use value_range(-1,1) or vgg_value_range.
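
For reference, a minimal sketch (not the actual big_vision ops; the exact defaults live in big_vision/pp/ops_image.py) of what the two preprocessing options roughly do: value_range(-1, 1) rescales uint8 pixels to [-1, 1], while vgg_value_range standardizes with per-channel ImageNet statistics.

import tensorflow as tf

def value_range_minus1_1(image):
  # Roughly value_range(-1, 1): map [0, 255] pixels to [-1, 1].
  image = tf.cast(image, tf.float32) / 255.0
  return image * 2.0 - 1.0

def vgg_value_range(image,
                    mean=(0.485 * 255, 0.456 * 255, 0.406 * 255),
                    std=(0.229 * 255, 0.224 * 255, 0.225 * 255)):
  # Roughly vgg_value_range: per-channel standardization with (assumed)
  # ImageNet mean/std; check ops_image.py for the exact defaults.
  image = tf.cast(image, tf.float32)
  return (image - mean) / std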

LeoXinhaoLee commented 1 year ago

Thank you very much for your kind reply. We've successfully reproduced the vit_s_i1k result after 90 epochs! If you don't mind, one more thing that is troubling us is the speed of loading data from a gs bucket. Our TFRecords are organized differently (not in TFDS format), and we load them in an old-fashioned way (code below), which seems quite slow: on a v3-64, 90 epochs take 2 h instead of 6.5 / 8 ≈ 0.8 h. Would you mind telling us what makes our approach so much slower than yours? Huge thanks!

import functools
import math
import os

import jax
import tensorflow as tf

# Assuming pp_builder refers to big_vision's preprocessing builder.
from big_vision.pp import builder as pp_builder


def training(config):
  data_dir = os.environ['TFDS_DATA_DIR']
  file_pattern = f"{data_dir}/train-in1k-*-of-*"
  num_records = 1281167

  file_list = tf.io.gfile.glob(file_pattern)
  dataset = tf.data.TFRecordDataset(file_list, num_parallel_reads=tf.data.AUTOTUNE)
  if jax.process_index() == 0:
    print(f'Total Number of Samples: {num_records}')

  # Partition the dataset across processes; the last process takes the remainder.
  partition_size = num_records // jax.process_count()
  if jax.process_index() < jax.process_count() - 1:
    dataset = dataset.skip(jax.process_index() * partition_size).take(partition_size)
  else:
    dataset = dataset.skip(jax.process_index() * partition_size).take(num_records - jax.process_index() * partition_size)

  # Define a function to parse each record
  def parse_record(record):
    features = {
      'image/encoded': tf.io.FixedLenFeature([], tf.string),
      'image/class/label': tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(record, features)
    image = example['image/encoded']
    label = example['image/class/label']
    return {'image': image, 'labels': label}

  # Apply the preprocessing pipeline and one-hot encode the label.
  def parse_transform(record, transform):
    sample = transform({'image':record['image']})
    label = tf.one_hot(record['labels'], 1000)
    return {'_mask': tf.constant(1), 'image': sample['image'], 'labels': label}

  # Per-process batch size and the number of steps needed to cover the
  # largest per-process partition.
  global_bs = config.input.batch_size
  local_bs = global_bs // jax.process_count()
  num_ex_per_process = [partition_size for _ in range(jax.process_count() - 1)]
  num_ex_per_process.append(num_records - (jax.process_count() - 1) * partition_size)
  num_batches = math.ceil(1. * max(num_ex_per_process) / local_bs)

  dataset = dataset.map(parse_record, num_parallel_calls=tf.data.AUTOTUNE)
  dataset = dataset.cache()  # Cache parsed records (encoded image bytes) in host RAM
  dataset = dataset.repeat(None)  # Each node repeats its partition indefinitely
  dataset = dataset.shuffle(config.input.shuffle_buffer_size)

  transform = pp_builder.get_preprocess_fn(config.input.pp, log_data=False)
  dataset = dataset.map(functools.partial(parse_transform, transform=transform), num_parallel_calls=tf.data.AUTOTUNE)

  dataset = dataset.batch(local_bs)
  dataset = dataset.prefetch(tf.data.AUTOTUNE)

  return dataset, num_records

andsteing commented 1 year ago

I would try the following experiments to find the root cause of the slowness:

  1. Do an epoch of training with fake data (e.g. black images, see the sketch below) – this should be much faster; if it isn't, the slowness has nothing to do with the data loading.
  2. Store the data locally and run another epoch of training. If this is faster, then the slowness is due to loading data from cloud storage. You could then try to better colocate the data with your training.
  3. If loading locally is also slow, then I would indeed try to optimize the loading code above. I don't see anything obviously wrong with that code, so maybe 2. will solve the issue at hand.
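
A minimal sketch of experiment 1 (fake data), reusing the dict keys from the loader above; image_size and num_classes are placeholders, not values taken from your config:

import tensorflow as tf

def fake_data(local_batch_size, image_size=224, num_classes=1000):
  # Constant black images: step time then excludes all real I/O and decoding.
  example = {
      '_mask': tf.constant(1),
      'image': tf.zeros([image_size, image_size, 3], tf.float32),
      'labels': tf.one_hot(0, num_classes),
  }
  ds = tf.data.Dataset.from_tensors(example).repeat()
  return ds.batch(local_batch_size).prefetch(tf.data.AUTOTUNE)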

LeoXinhaoLee commented 1 year ago

Thank you so much for your kind advice! We found that it is indeed 2. (loading data from cloud storage) that is slowing us down, and we will improve it.

Another thing we noticed is that in the vit_s16_i1k experiment, we additionally return the grads and updates from update_fn() (no other changes), and found that in the 1st and 2nd iterations the grads are all 0 except for the head parameters (kernel, bias), while from the 3rd iteration on the grads are non-zero for all parameters. Is this phenomenon expected? Thank you so much for your help!

andsteing commented 1 year ago

That's probably due to the head_zeroinit=True default:

https://github.com/google-research/big_vision/blob/47ac2fd075fcb66cadc0e39bd959c78a6080070d/big_vision/models/vit.py#L162
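
A minimal standalone sketch (plain JAX, not big_vision code) of why a zero-initialized head kernel makes the backbone gradients vanish: the gradient flowing into the backbone is dL/dh = dL/dlogits @ head.T, which is exactly zero while head == 0.

import jax
import jax.numpy as jnp

def loss(params, x, y):
  h = jnp.tanh(x @ params['backbone'])   # stand-in for the backbone
  logits = h @ params['head']            # head kernel, zero-initialized
  return jnp.mean((logits - y) ** 2)

params = {
    'backbone': 0.1 * jnp.ones((4, 8)),
    'head': jnp.zeros((8, 2)),           # analogue of head_zeroinit=True
}
x, y = jnp.ones((3, 4)), jnp.ones((3, 2))
grads = jax.grad(loss)(params, x, y)
# grads['head'] is non-zero, grads['backbone'] is all zeros.
print(jax.tree_util.tree_map(lambda g: float(jnp.abs(g).sum()), grads))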

LeoXinhaoLee commented 1 year ago

Oh, I see! The learning rate at the 1st iteration is 0, so no parameter is updated. In the 2nd iteration the kernel weights of the head layer are therefore still 0, and thus no gradient is back-propagated to the earlier layers.
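
A small sketch of that timing (hypothetical optax schedule and numbers, not the actual vit_s16_i1k config): with a warmup schedule that starts at 0, the learning rate for the very first step is 0, so the head stays zero through step 2, and the earlier layers only receive non-zero gradients from step 3 on.

import optax

# Hypothetical values; only lr == 0 at the first step matters for the argument.
sched = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-3, warmup_steps=10_000, decay_steps=90_000)
print(sched(0), sched(1))  # 0.0 at the first step, then a tiny non-zero lr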

Thank you so much for helping us out! Really appreciate it!