google / ml-compiler-opt

Infrastructure for Machine Learning Guided Optimization (MLGO) in LLVM.
Apache License 2.0

Why is the length of the reward limited to 3 or more? #328

Closed: jun-shibata closed this issue 1 month ago

jun-shibata commented 10 months ago

The following filter appears to discard the data for any object that has two or fewer observations:

  def _file_dataset_fn(data_path):
    dataset = (
        tf.data.Dataset.list_files(data_path).shuffle(
            files_buffer_size).interleave(
                input_dataset, cycle_length=num_readers, block_length=1)
        # Due to a bug in collection, we sometimes get empty rows.
        .filter(lambda string: tf.strings.length(string) > 0).apply(
            tf.data.experimental.shuffle_and_repeat(shuffle_buffer_size)).map(
                parser_fn, num_parallel_calls=num_map_threads)
        # Only keep sequences of length 2 or more.
        .filter(lambda traj: tf.size(traj.reward) > 2))  # <- HERE
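To make the boundary concrete, here is a minimal plain-Python model of that predicate. It does not use TensorFlow, and `Trajectory` is a hypothetical stand-in for the parsed trajectory with only the `reward` field. Because the comparison is strictly `> 2`, trajectories with one or two rewards are dropped and only those with three or more survive, even though the code comment says "2 or more":

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Trajectory:
    # Hypothetical stand-in for a parsed trajectory; only `reward` matters here.
    reward: List[float] = field(default_factory=list)


def passes_filter(traj: Trajectory) -> bool:
    # Same comparison as `tf.size(traj.reward) > 2`: strictly greater than 2,
    # so a trajectory needs at least 3 rewards to be kept.
    return len(traj.reward) > 2


def kept_lengths(trajs: List[Trajectory]) -> List[int]:
    # Lengths of the trajectories that survive the filter.
    return [len(t.reward) for t in trajs if passes_filter(t)]


mixed = [Trajectory([0.1]), Trajectory([0.1, 0.2]), Trajectory([0.1, 0.2, 0.3])]
short_only = [Trajectory([0.1]), Trajectory([0.1, 0.2])]

print(kept_lengths(mixed))       # [3]
print(kept_lengths(short_only))  # []
```

If every object yields at most two decisions, the second case is what the whole corpus looks like: nothing passes.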

In our experiments, when only one or two optimization decisions are made for each object, all of the data can be filtered out, which causes subsequent processing to hang. Do you know why this filter is there? Is it OK to set the filter condition arbitrarily?
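A likely reason this shows up as a hang rather than an error: in the pipeline above, the length filter runs after `shuffle_and_repeat`, which is called without a `count` and therefore repeats the input indefinitely. If every trajectory is rejected, the filter keeps pulling from an infinite stream and never produces an element. The sketch below models that in plain Python with `itertools.cycle`; `first_kept` and its `max_pulls` cap are hypothetical and exist only so the demonstration terminates (the real pipeline has no such cap).

```python
import itertools
from typing import Callable, Iterable, List


def first_kept(stream: Iterable[List[float]],
               predicate: Callable[[List[float]], bool],
               max_pulls: int) -> List[float]:
    # Pull from the stream until an element passes the predicate.
    # Without the `max_pulls` cap this loop never ends on an infinite
    # stream whose elements are all rejected.
    for pulls, rewards in enumerate(stream, start=1):
        if predicate(rewards):
            return rewards
        if pulls >= max_pulls:
            raise RuntimeError(
                f"no element passed the filter after {max_pulls} pulls")
    raise RuntimeError("stream ended without a matching element")


def long_enough(rewards: List[float]) -> bool:
    # Same rule as the dataset filter: keep 3 or more rewards.
    return len(rewards) > 2


# `itertools.cycle` plays the role of repeating the dataset with no count.
short_stream = itertools.cycle([[0.1], [0.1, 0.2]])
try:
    first_kept(short_stream, long_enough, max_pulls=1000)
except RuntimeError as err:
    print(err)  # no element passed the filter after 1000 pulls

mixed_stream = itertools.cycle([[0.1], [0.1, 0.2, 0.3]])
print(first_kept(mixed_stream, long_enough, max_pulls=1000))  # [0.1, 0.2, 0.3]
```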

mtrofin commented 10 months ago

I assume this is here (could you please use the permalink feature like I did? It's easier to follow.)

Do you mean that, for every module in a particular corpus sample, there are at most 2 decisions? If there are hangs, I believe the right fix is to handle that gracefully: there could well be modules with no decisions at all, so we should skip over them and retry (i.e. resample). That probably belongs in local_data_collector.py.
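A minimal sketch of the skip-and-resample idea, in plain Python. This is not the actual local_data_collector.py code: `collect_with_resample`, `collect_fn`, `min_len`, and `max_attempts` are hypothetical names, and `collect_fn` stands in for whatever compiles a module and returns its per-decision rewards. Modules whose trajectories are too short (including ones with no decisions) are skipped, a fresh sample is drawn if nothing usable came back, and after a bounded number of attempts it fails loudly instead of letting training wait forever.

```python
import random
from typing import Callable, Dict, List, Optional, Sequence


def collect_with_resample(
    modules: Sequence[str],
    collect_fn: Callable[[str], List[float]],
    sample_size: int,
    min_len: int = 3,
    max_attempts: int = 5,
    rng: Optional[random.Random] = None,
) -> Dict[str, List[float]]:
    """Hypothetical collector: skip short trajectories and resample."""
    rng = rng or random.Random(0)
    k = min(sample_size, len(modules))
    for _ in range(max_attempts):
        # Draw a fresh corpus sample on every attempt.
        sample = rng.sample(list(modules), k)
        # Keep only modules whose trajectory is long enough to be used.
        usable = {}
        for name in sample:
            rewards = collect_fn(name)
            if len(rewards) >= min_len:
                usable[name] = rewards
        if usable:
            return usable
    # Surface the problem instead of handing an empty dataset downstream.
    raise RuntimeError(
        f"no module produced a trajectory of length >= {min_len} "
        f"in {max_attempts} attempts")


fake_rewards = {"a.o": [0.5], "b.o": [0.5, 0.1], "c.o": [0.5, 0.1, 0.2, 0.3]}
print(collect_with_resample(list(fake_rewards), fake_rewards.__getitem__, 3))
# {'c.o': [0.5, 0.1, 0.2, 0.3]}
```

Here `min_len=3` with `>=` is the same cutoff as the dataset's `> 2`.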

The value "2" there was, IIRC, an optimization: there was little to gain from such short trajectories.

jun-shibata commented 10 months ago

Thank you! As you say, this happens when there are at most two decisions per object. At the behavioral cloning stage, if every trajectory in the default trace is shorter than 3, the dataset is empty, and as a result the process hangs here. That said, I understand that such a small dataset is not practical. If you don't mind, I'll leave this issue open, since there may be something I can do to fix it.
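One way to turn the hang into an immediate error is to check, before training starts, that at least one trajectory would survive the length filter. A minimal sketch, assuming the per-trajectory reward lengths can be gathered up front (for example from the default trace); `check_trajectories` and `min_len` are hypothetical names, and the threshold is a parameter rather than the hard-coded `2`:

```python
from collections import Counter
from typing import Iterable


def check_trajectories(lengths: Iterable[int], min_len: int = 3) -> int:
    """Return how many trajectories have at least `min_len` rewards.

    Raises ValueError if none do, so an empty dataset is reported
    up front instead of surfacing later as a hang.
    """
    histogram = Counter(lengths)
    kept = sum(n for length, n in histogram.items() if length >= min_len)
    if kept == 0:
        raise ValueError(
            f"no trajectory has >= {min_len} rewards; "
            f"length histogram: {dict(sorted(histogram.items()))}")
    return kept


print(check_trajectories([1, 2, 2, 5, 7]))             # 2
print(check_trajectories([1, 2, 2, 5, 7], min_len=2))  # 4
```

The histogram in the error message also shows how much data a lower `min_len` would recover; as recalled above, the original cutoff was an efficiency choice rather than a correctness requirement.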

mtrofin commented 10 months ago

Sounds good, and patches are very welcome!