tensorflow / transform

Input pipeline framework
Apache License 2.0

memory leak when using tft.vocabulary with labels argument #181

Open AnnKatrinBecker opened 4 years ago

AnnKatrinBecker commented 4 years ago

Environment: Ubuntu 18.04, TensorFlow 2.2.0, TFX 0.21.4

I am generating a vocabulary from my own dataset: 28 GB of TFRecords containing short description strings (up to 20 words) and integer labels from 1 to 100.

Generating the vocabulary without labels works fine. But as soon as the `labels` argument is passed to `tft.vocabulary`, memory usage grows dramatically, beyond 100 GB, until the process is killed for running out of memory.
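A simplified, hypothetical model of why `labels` can inflate memory: a frequency-only vocabulary needs one running count per distinct token, while ranking tokens against a label needs co-occurrence counts per (token, label) pair. This sketch is not TFT's actual combiner, and the function names are illustrative. Under this model, each token can occupy up to one entry per label value, which with labels 1 to 100 means up to roughly 100 times the state of the unlabeled case.

```python
from collections import Counter


def plain_vocab_state(examples):
    """Frequency-only vocabulary: one running count per distinct token."""
    counts = Counter()
    for tokens, _label in examples:
        counts.update(tokens)
    return counts


def label_vocab_state(examples):
    """Label-aware vocabulary (simplified model): one running count per
    distinct (token, label) pair, so each token is counted per label value."""
    counts = Counter()
    for tokens, label in examples:
        counts.update((tok, label) for tok in tokens)
    return counts


if __name__ == "__main__":
    # Toy data: (n-gram tokens of one description, integer label).
    examples = [
        (["red", "shoe", "red shoe"], 1),
        (["red", "hat", "red hat"], 2),
        (["red", "shoe", "red shoe"], 3),
    ]
    print(len(plain_vocab_state(examples)))  # 5 distinct tokens
    print(len(label_vocab_state(examples)))  # 9 distinct (token, label) pairs
```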

```python
import tensorflow as tf
import tensorflow_transform as tft


def preprocessing_fn(inputs):
    label = inputs['label']
    desc = inputs['description']
    desc = tf.strings.lower(desc.values)

    # Remove all numbers and punctuation
    desc = tf.strings.regex_replace(desc, "[^a-zA-Z¿]+", " ")
    tokens = tf.strings.split(desc)
    ngrams = tf.strings.ngrams(tokens, [1, 2])
    ngrams = ngrams.to_sparse()

    # Passing labels= here is what triggers the memory growth
    tft.vocabulary(ngrams, top_k=100000, labels=tf.sparse.to_dense(label),
                   vocab_filename='ngrams_100k_labels')
    return {'description': desc, 'label': label}
```
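To estimate how many vocabulary candidates each description contributes, the tokenization steps in `preprocessing_fn` can be approximated in plain Python. `ngram_tokens` is a hypothetical helper that mimics `tf.strings.lower`, `tf.strings.regex_replace`, `tf.strings.split` and `tf.strings.ngrams(tokens, [1, 2])`; it is a sketch for reasoning about volume, and its output ordering is not guaranteed to match TensorFlow's.

```python
import re


def ngram_tokens(description, widths=(1, 2), separator=" "):
    """Plain-Python approximation of the lower/regex_replace/split/ngrams steps."""
    text = description.lower()
    # Replace every run of characters outside the allowed set with one space.
    text = re.sub(r"[^a-zA-Z¿]+", " ", text)
    tokens = text.split()  # splits on whitespace and drops empty strings
    grams = []
    for n in widths:
        # All contiguous windows of n tokens, joined with the separator.
        for i in range(len(tokens) - n + 1):
            grams.append(separator.join(tokens[i:i + n]))
    return grams


print(ngram_tokens("Red Shoe, size 42!"))
# ['red', 'shoe', 'size', 'red shoe', 'shoe size']
```

A 20-word description yields 20 unigrams plus 19 bigrams, i.e. 39 candidate entries, each of which becomes a key in the vocabulary analyzer's counts.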

```python
import absl.logging
from tfx.components import ImportExampleGen, Transform
from tfx.components.common_nodes.importer_node import ImporterNode
from tfx.orchestration import metadata, pipeline
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
from tfx.types.standard_artifacts import Schema
from tfx.utils.dsl_utils import external_input


def main():
    # Brings data into the pipeline
    examples = external_input('directory with tfrecords')
    example_gen = ImportExampleGen(input=examples)
    examples = example_gen.outputs['examples']

    # Import schema
    schema_importer = ImporterNode(
        instance_name='imported_schema',
        source_uri='pipelines/test/SchemaGen/schema/3',
        artifact_type=Schema)

    # Perform transformation
    transform = Transform(examples=example_gen.outputs['examples'],
                          schema=schema_importer.outputs['result'],
                          module_file='preprocessing.py')

    pipe = pipeline.Pipeline(
        pipeline_name='test',
        pipeline_root='pipelines/test',
        components=[example_gen, schema_importer, transform],
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            'test/metadata.db'),
        enable_cache=True,
        beam_pipeline_args=['--direct_num_workers=0'])

    absl.logging.set_verbosity(absl.logging.INFO)
    BeamDagRunner().run(pipe)


if __name__ == '__main__':
    main()
```

mrcolo commented 3 years ago

Any update on this? I'm getting a similar problem when importing a 38 GB dataset with ImportExampleGen.

arghyaganguly commented 3 years ago

For large-scale datasets, it is recommended to use DataflowRunner (on GCP), or FlinkRunner or SparkRunner for on-premises execution.
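Switching runners in the pipeline above mostly means changing `beam_pipeline_args`. The helper below is a hypothetical sketch that builds those argument lists; the flag names (`--runner`, `--project`, `--region`, `--temp_location`, `--flink_master`, `--spark_master_url`) follow recent Apache Beam Python SDK releases and should be checked against the installed Beam version. All project, region, bucket and host values shown are placeholders.

```python
def beam_pipeline_args(runner, project=None, region=None, temp_location=None,
                       flink_master=None, spark_master_url=None):
    """Return a beam_pipeline_args list for the chosen Beam runner (sketch)."""
    if runner == "DataflowRunner":
        required = {"project": project, "region": region,
                    "temp_location": temp_location}
        missing = [name for name, value in required.items() if not value]
        if missing:
            raise ValueError(f"DataflowRunner needs: {', '.join(missing)}")
        return [
            "--runner=DataflowRunner",
            f"--project={project}",
            f"--region={region}",
            f"--temp_location={temp_location}",  # a gs:// staging path
        ]
    if runner == "FlinkRunner":
        if not flink_master:
            raise ValueError("FlinkRunner needs flink_master (host:port)")
        return ["--runner=FlinkRunner", f"--flink_master={flink_master}"]
    if runner == "SparkRunner":
        if not spark_master_url:
            raise ValueError("SparkRunner needs spark_master_url")
        return ["--runner=SparkRunner", f"--spark_master_url={spark_master_url}"]
    raise ValueError(f"unsupported runner: {runner}")


# Hypothetical values; replace with your own project, region and bucket.
args = beam_pipeline_args("DataflowRunner", project="my-project",
                          region="us-central1",
                          temp_location="gs://my-bucket/tmp")
print(args)
```

The resulting list would be passed as `beam_pipeline_args=args` to `pipeline.Pipeline(...)` in place of `['--direct_num_workers=0']`. On Dataflow, the workers typically also need the pipeline's Python dependencies, for example via `--setup_file` or `--requirements_file`.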