tensorflow / transform

Input pipeline framework
Apache License 2.0

preprocessing_fn executed multiple times in pipeline #276

Closed pritamdodeja closed 1 year ago

pritamdodeja commented 2 years ago

I am trying to analyze the execution flow of my program and have instrumented my functions as follows:

import time
from functools import wraps

def inspect_function_execution(func):
    @wraps(func)
    def _(*args, **kwargs):
        if INSTRUMENT_EXECUTION:
            print(f"From wrapper function: Executing function named: {func.__name__}, with arguments: {args}, and keyword arguments: {kwargs}.")
            start_time = time.time()
            return_value = func(*args, **kwargs)
            end_time = time.time()
            elapsed_time = end_time - start_time
            print(f"From wrapper function: Execution of {func.__name__} took {elapsed_time} seconds.")
            return return_value
        else:
            return func(*args, **kwargs)
    return _
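For context, here is a self-contained sketch of how the decorator is meant to behave (the module-level INSTRUMENT_EXECUTION flag is assumed by the decorator; the trivial add function is illustrative, not from the original program): each call to a decorated function prints one entry line and one timing line, so the printed lines count calls.

    import time
    from functools import wraps

    INSTRUMENT_EXECUTION = True  # module-level flag checked by the wrapper

    def inspect_function_execution(func):
        @wraps(func)
        def _(*args, **kwargs):
            if INSTRUMENT_EXECUTION:
                print(f"Executing {func.__name__} with {args}, {kwargs}")
                start_time = time.time()
                return_value = func(*args, **kwargs)
                print(f"{func.__name__} took {time.time() - start_time} seconds")
                return return_value
            return func(*args, **kwargs)
        return _

    @inspect_function_execution
    def add(a, b):
        return a + b

    result = add(2, 3)  # prints one entry line and one timing line

Because of @wraps, the wrapped function keeps its original __name__, which is what makes the printed function names in the output below meaningful.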

My pipeline function is as follows:

@inspect_function_execution
def my_tfrecord_writer(*args, **kwargs):
    return beam.io.WriteToTFRecord(*args, **kwargs)

# {{{ pipeline function
from apache_beam.options.pipeline_options import PipelineOptions
pipeline_options = PipelineOptions(runner='DirectRunner', direct_num_workers=1)

@inspect_function_execution
def pipeline_function(prefix_string, preprocessing_fn):
    with beam.Pipeline(options=pipeline_options) as pipeline:
        with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
            # Create a TFXIO to read the CSV data with the schema. To do this
            # we need to list all columns in order, since the schema doesn't
            # specify the order of columns in the CSV.
            # We first read the CSV files and use BeamRecordCsvTFXIO, whose
            # .BeamSource() accepts a PCollection[bytes], in case the records
            # need to be patched first. Otherwise, tfxio.CsvTFXIO can be used
            # to both read the CSV files and parse them to TFT inputs:
            #   csv_tfxio = tfxio.CsvTFXIO(...)
            #   raw_data = (pipeline | 'ToRecordBatches' >> csv_tfxio.BeamSource())
            csv_tfxio = tfxio.BeamRecordCsvTFXIO(
                physical_format='text',
                column_names=CSV_COLUMNS,
                schema=_SCHEMA)

            # Read in raw data and convert it using the CSV TFXIO. Note that
            # any plain Beam transformations applied here will not be encoded
            # in the TF graph, since we don't do them from within
            # tf.Transform's methods (AnalyzeDataset, TransformDataset, etc.).
            raw_data = (
                pipeline
                | 'ReadTrainData' >> beam.io.ReadFromText(
                    file_pattern=train_file_path,
                    coder=beam.coders.BytesCoder(),
                    skip_header_lines=1)
                | 'DecodeTrainData' >> csv_tfxio.BeamSource())
            raw_dataset = (raw_data, csv_tfxio.TensorAdapterConfig())

            transformed_dataset, transform_fn = (
                raw_dataset | tft_beam.AnalyzeAndTransformDataset(
                    preprocessing_fn, output_record_batches=True))

            # Transformed metadata is not necessary for encoding.
            transformed_data, _ = transformed_dataset

            # Extract transformed RecordBatches, encode them, and write them
            # to the given directory.
            tfrecord_directory = os.path.join(WORKING_DIRECTORY, prefix_string)
            if os.path.exists(tfrecord_directory) and os.path.isdir(
                    tfrecord_directory):
                shutil.rmtree(tfrecord_directory)
            transform_fn_output = os.path.join(tfrecord_directory,
                                               'transform_output')
            tfrecord_file_path_prefix = os.path.join(tfrecord_directory,
                                                     prefix_string)
            data_written = (
                transformed_data
                | 'EncodeTrainData' >> beam.FlatMapTuple(
                    lambda batch, x: RecordBatchToExamples(batch))
                # 'WriteTrainData' >> beam.io.WriteToTFRecord(
                | 'WriteTrainData' >> my_tfrecord_writer(
                    tfrecord_file_path_prefix))
            _ = (
                transform_fn
                | "WriteTransformFn" >>
                tft_beam.WriteTransformFn(transform_fn_output))
    return True if data_written else False

When I run this pipeline, I see that preprocessing_fn is executed four times when it is the identity function, and five times when there is something to analyze. Is this the expected behavior? The data is only written once, as expected. The detailed output is below:

pipeline_function(prefix_string=PREFIX_STRING, preprocessing_fn=preprocessing_fn)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
From wrapper function: Executing function named: pipeline_function, with arguments: (), and keyword arguments: {'prefix_string': 'transformed_tfrecords', 'preprocessing_fn': <function preprocessing_fn at 0x7fee60cc14c0>}.                                                                                                                                                        
From wrapper function: Executing function named: preprocessing_fn, with arguments: ({'dropoff_latitude': <tf.Tensor 'inputs_copy:0' shape=(None, 1) dtype=float32>, 'dropoff_longitude': <tf.Tensor 'inputs_1_copy:0' shape=(None, 1) dtype=float32>, 'fare_amount': <tf.Tensor 'inputs_2_copy:0' shape=(None, 1) dtype=float32>, 'key': <tf.Tensor 'inputs_3_copy:0' shape=(None, 1) dtype=string>, 'passenger_count': <tf.Tensor 'inputs_4_copy:0' shape=(None, 1) dtype=float32>, 'pickup_datetime': <tf.Tensor 'inputs_5_copy:0' shape=(None, 1) dtype=string>, 'pickup_latitude': <tf.Tensor 'inputs_6_copy:0' shape=(None, 1) dtype=float32>, 'pickup_longitude': <tf.Tensor 'inputs_7_copy:0' shape=(None, 1) dtype=float32>},), and keyword arguments: {}.        
From wrapper function: Execution of preprocessing_fn took 0.16660404205322266 seconds.                                                                                                                                                                                                                                                                                               
From wrapper function: Executing function named: preprocessing_fn, with arguments: ({'dropoff_latitude': <tf.Tensor 'PlaceholderWithDefault:0' shape=(None, 1) dtype=float32>, 'dropoff_longitude': <tf.Tensor 'PlaceholderWithDefault_1:0' shape=(None, 1) dtype=float32>, 'fare_amount': <tf.Tensor 'PlaceholderWithDefault_2:0' shape=(None, 1) dtype=float32>, 'key': <tf.Tensor 'PlaceholderWithDefault_3:0' shape=(None, 1) dtype=string>, 'passenger_count': <tf.Tensor 'PlaceholderWithDefault_4:0' shape=(None, 1) dtype=float32>, 'pickup_datetime': <tf.Tensor 'PlaceholderWithDefault_5:0' shape=(None, 1) dtype=string>, 'pickup_latitude': <tf.Tensor 'PlaceholderWithDefault_6:0' shape=(None, 1) dtype=float32>, 'pickup_longitude': <tf.Tensor 'PlaceholderWithDefault_7:0' shape=(None, 1) dtype=float32>},), and keyword arguments: {}.                                                                                                                                                                                                                                                                                                     
From wrapper function: Execution of preprocessing_fn took 0.10583901405334473 seconds.                                                                                                                                                                                                                                                                                               
From wrapper function: Executing function named: my_tfrecord_writer, with arguments: ('./working_area/transformed_tfrecords/transformed_tfrecords',), and keyword arguments: {}.                                                                                                                                                                                                     
From wrapper function: Execution of my_tfrecord_writer took 6.246566772460938e-05 seconds.                                                                                                                                                                                                                                                                                           
WARNING:root:Make sure that locally built Python SDK docker image has Python 3.8 interpreter.                                                                                                                                                                                                                                                                                        
From wrapper function: Executing function named: preprocessing_fn, with arguments: ({'dropoff_latitude': <tf.Tensor 'inputs_copy:0' shape=(None, 1) dtype=float32>, 'dropoff_longitude': <tf.Tensor 'inputs_1_copy:0' shape=(None, 1) dtype=float32>, 'fare_amount': <tf.Tensor 'inputs_2_copy:0' shape=(None, 1) dtype=float32>, 'key': <tf.Tensor 'inputs_3_copy:0' shape=(None, 1) dtype=string>, 'passenger_count': <tf.Tensor 'inputs_4_copy:0' shape=(None, 1) dtype=float32>, 'pickup_datetime': <tf.Tensor 'inputs_5_copy:0' shape=(None, 1) dtype=string>, 'pickup_latitude': <tf.Tensor 'inputs_6_copy:0' shape=(None, 1) dtype=float32>, 'pickup_longitude': <tf.Tensor 'inputs_7_copy:0' shape=(None, 1) dtype=float32>},), and keyword arguments: {}.        
From wrapper function: Execution of preprocessing_fn took 0.10547733306884766 seconds.                                                                                                                                                                                                                                                                                               
INFO:tensorflow:Assets written to: /tmp/tmp2v_wm341/tftransform_tmp/f6bdf64d965642e380b72cff98377346/assets                                                                                                                                                                                                                                                                          
INFO:tensorflow:Assets written to: /tmp/tmp2v_wm341/tftransform_tmp/f6bdf64d965642e380b72cff98377346/assets                                                                                                                                                                                                                                                                          
From wrapper function: Executing function named: preprocessing_fn, with arguments: ({'dropoff_latitude': <tf.Tensor 'inputs_copy:0' shape=(None, 1) dtype=float32>, 'dropoff_longitude': <tf.Tensor 'inputs_1_copy:0' shape=(None, 1) dtype=float32>, 'fare_amount': <tf.Tensor 'inputs_2_copy:0' shape=(None, 1) dtype=float32>, 'key': <tf.Tensor 'inputs_3_copy:0' shape=(None, 1) dtype=string>, 'passenger_count': <tf.Tensor 'inputs_4_copy:0' shape=(None, 1) dtype=float32>, 'pickup_datetime': <tf.Tensor 'inputs_5_copy:0' shape=(None, 1) dtype=string>, 'pickup_latitude': <tf.Tensor 'inputs_6_copy:0' shape=(None, 1) dtype=float32>, 'pickup_longitude': <tf.Tensor 'inputs_7_copy:0' shape=(None, 1) dtype=float32>},), and keyword arguments: {}.        
From wrapper function: Execution of preprocessing_fn took 0.10962820053100586 seconds.                                                                                                                                                                                                                                                                                               
INFO:tensorflow:Assets written to: /tmp/tmp2v_wm341/tftransform_tmp/48267742bcab42519f74be7d14a49fbe/assets                                                                                                                                                                                                                                                                          
INFO:tensorflow:Assets written to: /tmp/tmp2v_wm341/tftransform_tmp/48267742bcab42519f74be7d14a49fbe/assets                                                                                                                                                                                                                                                                          
From wrapper function: Executing function named: preprocessing_fn, with arguments: ({'dropoff_latitude': <tf.Tensor 'PlaceholderWithDefault:0' shape=(None, 1) dtype=float32>, 'dropoff_longitude': <tf.Tensor 'PlaceholderWithDefault_1:0' shape=(None, 1) dtype=float32>, 'fare_amount': <tf.Tensor 'PlaceholderWithDefault_2:0' shape=(None, 1) dtype=float32>, 'key': <tf.Tensor 'PlaceholderWithDefault_3:0' shape=(None, 1) dtype=string>, 'passenger_count': <tf.Tensor 'PlaceholderWithDefault_4:0' shape=(None, 1) dtype=float32>, 'pickup_datetime': <tf.Tensor 'PlaceholderWithDefault_5:0' shape=(None, 1) dtype=string>, 'pickup_latitude': <tf.Tensor 'PlaceholderWithDefault_6:0' shape=(None, 1) dtype=float32>, 'pickup_longitude': <tf.Tensor 'PlaceholderWithDefault_7:0' shape=(None, 1) dtype=float32>},), and keyword arguments: {}.                                                                                                                                                                                                                                                                                                     
From wrapper function: Execution of preprocessing_fn took 0.10708332061767578 seconds.                                                                                                                                                                                                                                                                                               
From wrapper function: Execution of pipeline_function took 3.1847333908081055 seconds.             
zoyahav commented 2 years ago

Yes, this is expected: the preprocessing_fn will get traced several times, as you saw. This is an implementation detail of how this library performs the data analysis.
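
The tracing behavior can be illustrated without TensorFlow. Graph-building frameworks re-invoke the user's Python function once per trace (e.g. once for the analysis graph, once for the transform graph, once for SavedModel export), so any Python side effects in a wrapper, such as the `print` in `inspect_function_execution`, fire on every trace rather than once per pipeline run. A minimal sketch (the pass-through `preprocessing_fn` body and the `call_count` attribute are illustrative, not part of tf.Transform):

```python
from functools import wraps

def inspect_function_execution(func):
    """Count how many times the wrapped function's Python body runs."""
    @wraps(func)
    def _(*args, **kwargs):
        _.call_count += 1
        return func(*args, **kwargs)
    _.call_count = 0
    return _

@inspect_function_execution
def preprocessing_fn(inputs):
    # In tf.Transform this would apply tft analyzers/mappers;
    # here it simply passes inputs through.
    return dict(inputs)

# A tracing framework calls the Python function each time it builds a
# new graph, so the body executes once per trace, not once overall.
for _trace in range(3):
    preprocessing_fn({"fare_amount": [1.0]})

print(preprocessing_fn.call_count)  # 3
```

This is why the logs above show `preprocessing_fn` running several times with differently named placeholder tensors: each invocation corresponds to a separate trace, not to reprocessing of the data itself.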

singhniraj08 commented 1 year ago

@pritamdodeja,

Multiple executions of preprocessing_fn are expected behavior. Kindly let us know if this issue can be closed. Thank you!

pritamdodeja commented 1 year ago

You can close it. Thank you!
