noamgat / lm-format-enforcer

Enforce the output format (JSON Schema, Regex etc) of a language model
MIT License

RAM UTILISATION IS INCREASING RAPIDLY #145

Open UTSAV-44 opened 4 weeks ago

UTSAV-44 commented 4 weeks ago

To enforce JSON-formatted responses from the model, I am using the ExLlamaV2TokenEnforcerFilter and ExLlamaV2PrefixFilter classes, appending them to a filters list, and passing that list when generating output from the model. Since my use cases are limited in number, I decided to cache both filter instances in a dict and reuse them. With this approach, I observed that system RAM utilization keeps increasing, and after a few iterations it leads to an out-of-memory condition: the process usually needs 10-15 GB of system RAM, but over time usage grows past 128 GB and causes an OOM. I tried to isolate the class responsible and found that ExLlamaV2TokenEnforcerFilter is not releasing some captured memory, which is causing the problem.

We tried reinitializing certain variables as below, but it did not reclaim any memory.

    self.universal_filter_map[use_case_id][0].token_sequence = []
    self.universal_filter_map[use_case_id][1].current_prefixes = set()
    self.universal_filter_map[use_case_id][1].current_str = ""
    self.universal_filter_map[use_case_id][1].prefix_strings = ["{", " {"]

I have also logged this issue on ExLlamaV2: https://github.com/turboderp/exllamav2/issues/639

I am sharing the code snippet for the complete implementation:


import json
from typing import List, Tuple


def run_mihup_llm_inference(self, call_transcript: str, prompt_tuples: List[Tuple]) -> List[dict]:
      self.cache.reset()
      common_transcript = format_transcript_text(call_transcript)
      prompts = []
      filters = []
      use_case_ids = []
      for upper_tuple in prompt_tuples:
          use_case_id = upper_tuple[1]
          use_case_ids.append(use_case_id)
          p = upper_tuple[0]
          prompt_str = p[0]
          prompt_question_combined = format_llama3_prompt(mihup_system_prompt, common_transcript + prompt_str)
          prompts.append(prompt_question_combined)
          filter_schema_parser = p[1]

          print_memory_usage()

          if use_case_id not in self.universal_filter_map:
              print("Not found in the cache memory")

              self.universal_filter_map[use_case_id] = [
                  ExLlamaV2TokenEnforcerFilter(filter_schema_parser, self.tokenizer),
                  ExLlamaV2PrefixFilter(self.model, self.tokenizer, ["{", " {"])
              ]
          else:
              self.universal_filter_map[use_case_id][0].token_sequence = []
              self.universal_filter_map[use_case_id][1].current_prefixes = set()
              self.universal_filter_map[use_case_id][1].current_str = ""
              self.universal_filter_map[use_case_id][1].prefix_strings = ["{", " {"]
              print("Found in the cache memory")

          print("length of map : ", len(self.universal_filter_map[use_case_id]))
          # Create fresh instances each time
          filters.append(self.universal_filter_map[use_case_id])

      # print(prompts)

      outputs = self.generator.generate(
          prompt=prompts,
          filters=filters,
          filter_prefer_eos=True,
          max_new_tokens=1536,
          add_bos=True,
          stop_conditions=get_llama3_stop_conditions(self.tokenizer),
          completion_only=True,
          encode_special_tokens=True,
      )

      final_output = []
      use_case_id_key = "use_case_id"
      for i in range(len(outputs)):
          try:
              output_json = json.loads(outputs[i])
          except ValueError:
              print("error: ", outputs[i])
              continue
          # Attach the use case id by the original output index, so the
          # mapping stays correct even when some outputs fail to parse
          output_json[use_case_id_key] = use_case_ids[i]
          final_output.append(output_json)

      # gc.collect()
      print_memory_usage()

      return final_output
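
For reference, print_memory_usage in the snippet is a small helper whose implementation is not shown above; a minimal sketch, assuming psutil is available, would be:

    import os
    import psutil

    def print_memory_usage():
        # Report the resident set size (RSS) of the current process in GB
        rss = psutil.Process(os.getpid()).memory_info().rss
        print(f"RSS: {rss / 1024 ** 3:.2f} GB")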
noamgat commented 4 weeks ago

Hi, LMFE caches all encountered prefixes by default. The prefix cache cannot be emptied while there are in-flight requests, but you can clear it from time to time. If you want to clear it without modifying any library code, you can do something like:

    filter = ExLlamaV2TokenEnforcerFilter(filter_schema_parser, self.tokenizer)
    for i in range(10000):
        # use the filter here
        if i % 100 == 0:
            filter.token_enforcer.prefix_states = {}   # this is the important line
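
Adapted to the cached-filter map in your snippet, the same reset could run periodically over every cached enforcer (a sketch, assuming your universal_filter_map layout of [token_enforcer_filter, prefix_filter] pairs):

    def clear_prefix_caches(self):
        # Drop the prefix-state cache of every cached token enforcer filter;
        # only do this when no generation request is in flight
        for token_filter, _prefix_filter in self.universal_filter_map.values():
            token_filter.token_enforcer.prefix_states = {}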

Let me know if this helps you solve the problem.

UTSAV-44 commented 3 weeks ago

Hello, I tried the suggested solution, but the RAM usage is still increasing, although at a slower rate. Something is still being cached.
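
One way to narrow down what is still being held is to compare tracemalloc snapshots between iterations (a rough sketch; note that tracemalloc only sees Python-level allocations, not native ones):

    import tracemalloc

    tracemalloc.start()
    snapshot_before = tracemalloc.take_snapshot()
    # ... run a few inference iterations here ...
    snapshot_after = tracemalloc.take_snapshot()
    # Print the allocation sites whose retained memory grew the most
    for stat in snapshot_after.compare_to(snapshot_before, "lineno")[:10]:
        print(stat)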