triton-inference-server / server

The Triton Inference Server provides an optimized cloud and edge inferencing solution.
https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html
BSD 3-Clause "New" or "Revised" License

Tensorflow models using a lot of memory - loading twice #6339

Open amey-matroid opened 1 year ago

amey-matroid commented 1 year ago

Description

I recently upgraded the Triton version from 22.04 to 23.07 and noticed that TensorFlow models (i.e., platform: tensorflow_savedmodel) take up twice the memory for the same model. For instance, a VGG face implementation now uses roughly twice the memory it did before the upgrade.

We also observed a steady increase in memory footprint for the same model across versions.

We noticed that prior to 22.07, models were only loaded once during initialization; in all builds 22.07 and later, the model is loaded twice during initialization despite setting MAX_SESSION_SHARE_COUNT equal to the instance_group count.

This is the config file that we use:

name: "VGG_FACE"
platform: "tensorflow_savedmodel"
max_batch_size: 1000
input {
  name: "VGG_FACE"
  data_type: TYPE_FP32
  dims: 224
  dims: 224
  dims: 3
}
output {
  name: "VGG_FACE"
  data_type: TYPE_FP32
  dims: 4096
}
instance_group {
  count: 2
}
dynamic_batching {
  preferred_batch_size: 2
  default_queue_policy {
    max_queue_size: 200
  }
}
optimization {
  execution_accelerators {
    gpu_execution_accelerator {
      name: "auto_mixed_precision"
    }
  }
}
parameters {
  key: "MAX_SESSION_SHARE_COUNT"
  value {
    string_value: "2"
  }
}

Triton Information: 23.07

Are you using the Triton container or did you build it yourself? Triton container

To Reproduce: I think any TensorFlow model can be used to reproduce this.

Expected behavior: Expect the model to not take twice the memory.

dyastremsky commented 1 year ago

Thank you for submitting this bug with detailed reproduction information! We have filed a ticket to investigate.

Ticket reference: DLIS-5721.

amey-matroid commented 1 year ago

We further investigated this and found that the issue is not necessarily a Triton issue but a TensorFlow bug. When loading multiple different models in the off-the-shelf TensorFlow 2.8 and 2.9 containers, we consistently saw a spike in memory consumption (almost twice). The CUDA version was 11.2 for both and the drivers were the same, so we don't think this is necessarily a CUDA or driver issue either.
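For reference, this is a minimal sketch of the kind of standalone check described above, run inside a stock TensorFlow container rather than Triton. The model path and the GPU:0 device name are placeholders, and the two measurements only cover TensorFlow's own device allocations plus the process resident set size, so they are a rough proxy for the footprint we saw:

# Hypothetical standalone check, run inside an off-the-shelf TensorFlow
# container (e.g. 2.8 vs 2.9) to compare the memory footprint of loading
# the same SavedModel outside of Triton.
import resource

import tensorflow as tf

MODEL_DIR = "/models/VGG_FACE/1/model.savedmodel"  # placeholder path

model = tf.saved_model.load(MODEL_DIR)

# GPU memory currently held by TensorFlow allocations on the first device.
gpu = tf.config.experimental.get_memory_info("GPU:0")
# Peak resident set size of the whole process (reported in KiB on Linux).
rss_kib = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss

print(f"TF GPU memory: current={gpu['current'] / 2**20:.1f} MiB, "
      f"peak={gpu['peak'] / 2**20:.1f} MiB")
print(f"Process peak RSS: {rss_kib / 2**10:.1f} MiB")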

We looked into building our own TensorFlow backend with TensorFlow 2.8 for the latest Triton container, but since it is not trivial, we decided against it. If this is something that could be provided by Triton, I think it would help with TensorFlow models.

As a temporary fix, we found that setting the environment variables CUDA_MODULE_LOADING=LAZY and TF_GPU_ALLOCATOR=cuda_malloc_async (available since CUDA 11.6) reduces the system memory and GPU memory footprint to acceptable levels (almost half).
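As an illustration (not the exact commands we used), here is a minimal sketch of applying that workaround in a standalone TensorFlow script, assuming both variables are set before TensorFlow initializes the GPU; when running the Triton container they would instead be set on the container environment (for example with docker run -e). The model path is a placeholder.

# Minimal sketch of the workaround: both variables must be set before
# TensorFlow (and CUDA) touch the GPU, so they go into the environment
# before the tensorflow import. In the Triton container they would be set
# on the container environment instead.
import os

os.environ["CUDA_MODULE_LOADING"] = "LAZY"            # defer loading of CUDA modules
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"  # use CUDA's asynchronous allocator

import tensorflow as tf  # imported only after the environment is configured

model = tf.saved_model.load("/models/VGG_FACE/1/model.savedmodel")  # placeholder path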

Hope this additional information helps you narrow the issue. Let me know if you need anything else from me.

dyastremsky commented 1 year ago

Thanks for investigating and sharing your findings! This will be quite helpful. The directions for building a custom TensorFlow backend may be worth a try, if you have not already.

We will investigate. If this is a TensorFlow issue that exists outside of Triton, hopefully TensorFlow is working on a fix as well.

CC: @tanmayv25