vojtsek / to-llm-bot


MultiWoz Loading Ground Truth State and Domain Incorrectly #3

Open carlguo2 opened 5 months ago

carlguo2 commented 5 months ago

When running the run.py script with the ground-truth state, I noticed that the JGA was very low (~0.5) instead of 1.0. Looking into load_mwoz() in loaders.py, I saw that gt_state and gt_domain do not seem to be recorded in a way that matches the evaluation in metrics.py and utils.py.
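For reference, joint goal accuracy (JGA) counts a turn as correct only if the whole dialogue state matches the gold state exactly, so feeding the gold state back in should give 1.0. Below is a minimal sketch of that standard definition. It assumes states are nested dicts of the form {domain: {slot: value}}. The function name joint_goal_accuracy is hypothetical, and this is not necessarily how metrics.py implements it. The toy data shows how turn-level states that miss earlier slots pull the score down.

```python
from typing import Dict, List

State = Dict[str, Dict[str, str]]  # {domain: {slot: value}}


def joint_goal_accuracy(predicted: List[State], gold: List[State]) -> float:
    """Fraction of turns whose full predicted state equals the gold state."""
    if len(predicted) != len(gold):
        raise ValueError("predicted and gold must have one state per turn")
    if not gold:
        return 0.0
    hits = sum(1 for p, g in zip(predicted, gold) if p == g)
    return hits / len(gold)


# Toy gold states (invented): the state grows over the dialogue.
gold = [
    {"hotel": {"area": "north"}},
    {"hotel": {"area": "north", "stars": "4"}},
]
# Turn-level (non-accumulated) states miss slots set in earlier turns.
turn_level = [
    {"hotel": {"area": "north"}},
    {"hotel": {"stars": "4"}},
]
print(joint_goal_accuracy(gold, gold))        # 1.0
print(joint_goal_accuracy(turn_level, gold))  # 0.5
```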

For the ground-truth state, only the first active_intent part of each turn's ground-truth state gets recorded (link):

def load_mwoz():
  ...
  state = dialog['turns']['frames'][tn]['state']            
  if len(state) == 0:
    state = {}
  else:
    state = [state[i]['slots_values'] for i in range(len(state))]
  ...
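To make the format mismatch concrete, here is a toy turn shaped like the nested state that the loader indexes. Each frame's `slots_values` holds parallel `slots_values_name` / `slots_values_list` lists, which is the layout the code above reads. The values are invented, and the helper `flatten_frames` is hypothetical. It shows the conversion to per-domain slot dicts that the proposed fix performs. Splitting on the first hyphen only is a small safety choice of this sketch.

```python
from typing import Dict, List

# Toy turn shaped like dialog['turns']['frames'][tn]['state']; values are invented.
turn_state = [
    {
        "active_intent": "find_hotel",
        "requested_slots": [],
        "slots_values": {
            "slots_values_name": ["hotel-area", "hotel-stars"],
            "slots_values_list": [["north"], ["4"]],
        },
    },
    {
        "active_intent": "find_train",
        "requested_slots": [],
        "slots_values": {
            "slots_values_name": ["train-day"],
            "slots_values_list": [["monday"]],
        },
    },
]


def flatten_frames(state: List[dict]) -> Dict[str, Dict[str, str]]:
    """Turn every frame's parallel name/value lists into {domain: {slot: value}}."""
    out: Dict[str, Dict[str, str]] = {}
    for frame in state:
        sv = frame["slots_values"]
        for name, values in zip(sv["slots_values_name"], sv["slots_values_list"]):
            domain, slot = name.split("-", 1)
            out.setdefault(domain, {})[slot] = values[0]  # keep the first candidate value
    return out


# What the current loader line keeps: raw per-frame structures, not slot dicts.
raw = [turn_state[i]["slots_values"] for i in range(len(turn_state))]
print(raw[0]["slots_values_name"])  # ['hotel-area', 'hotel-stars']
print(flatten_frames(turn_state))
# {'hotel': {'area': 'north', 'stars': '4'}, 'train': {'day': 'monday'}}
```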

It also does not accumulate state from previous turns, so it cannot fully match gold_states in evaluation. I changed how gt_state is loaded and assembled. Running with gt_state now gives a JGA close to 1 and better overall Success and Inform results.

def load_mwoz():
  ...
  state = dialog['turns']['frames'][tn]['state']
  if len(state) == 0:
    state = {}
  else:
    state = [state[i]['slots_values'] for i in range(len(state))]
    # one {"domain-slot": first_value} dict per frame that has any slots
    state = [{k: v[0] for k, v in zip(state[i]['slots_values_name'], state[i]['slots_values_list'])}
             for i in range(len(state)) if len(state[i]['slots_values_name']) > 0]
    # copy the previous turn's state so the gt_state stored for earlier turns is not mutated
    new_state = {d: dict(slots) for d, slots in last_state.items()}
    for i in range(len(state)):
      for sl, val in state[i].items():
        domain, name = sl.split('-')
        slots_per_domain[domain].add(name)
        if domain not in new_state:
          new_state[domain] = {name: val}
        else:
          new_state[domain][name] = val
  ...
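The accumulation step can be checked on its own. Below is a minimal sketch, assuming each turn's update has already been flattened to {domain: {slot: value}}. The helper `accumulate_states` is hypothetical and not part of loaders.py. It copies the previous turn's state before updating it. Updating a shared dict in place would silently change the gt_state already stored for earlier turns.

```python
import copy
from typing import Dict, List

State = Dict[str, Dict[str, str]]  # {domain: {slot: value}}


def accumulate_states(turn_updates: List[State]) -> List[State]:
    """Return one cumulative state per turn from per-turn slot updates."""
    snapshots: List[State] = []
    last_state: State = {}
    for update in turn_updates:
        new_state = copy.deepcopy(last_state)  # never mutate an earlier turn's state
        for domain, slots in update.items():
            new_state.setdefault(domain, {}).update(slots)
        snapshots.append(new_state)
        last_state = new_state
    return snapshots


# Toy per-turn updates (invented).
updates = [
    {"hotel": {"area": "north"}},
    {"hotel": {"stars": "4"}},
    {"train": {"day": "monday"}},
]
for state in accumulate_states(updates):
    print(state)
# {'hotel': {'area': 'north'}}
# {'hotel': {'area': 'north', 'stars': '4'}}
# {'hotel': {'area': 'north', 'stars': '4'}, 'train': {'day': 'monday'}}
```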

I also noticed that the way load_mwoz() loads gt_domain differs from the way get_domain_estimates_from_state() computes the domain.

Specifically, when loading gt_domain, load_mwoz() takes the first element of the "services" field of the MultiWOZ dialogue data (link):

def load_mwoz():
  ...
  if len(dialog['services']) > 0:
    domain_gt = dialog['services'][0]
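Because "services" is a dialogue-level list, this assigns one label to every turn of a multi-domain dialogue, while a per-turn estimate can switch domains. Below is a toy illustration with invented data. The `turn_domains` field is hypothetical and stands in for whichever domain is actually active in each turn.

```python
# Toy multi-domain dialogue; "turn_domains" is a hypothetical per-turn field.
dialog = {
    "services": ["hotel", "train"],
    "turn_domains": ["hotel", "hotel", "train", "train"],
}

# Dialogue-level label: one value reused for every turn.
domain_gt = dialog["services"][0] if len(dialog["services"]) > 0 else ""
per_turn_from_services = [domain_gt for _ in dialog["turn_domains"]]

# Turn-level label: the domain active in each turn.
per_turn_actual = list(dialog["turn_domains"])

matches = sum(a == b for a, b in zip(per_turn_from_services, per_turn_actual))
print(per_turn_from_services)          # ['hotel', 'hotel', 'hotel', 'hotel']
print(matches / len(per_turn_actual))  # 0.5
```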

I changed how gt_domain is computed to follow the method in get_domain_estimates_from_state(), and domain accuracy improved from 0.4 to 0.7:

def load_mwoz():
  ...
  for dialog in data:
    ...
    for tn in range(0, len(dialog['turns']['utterance']), 2):
      ...
      if len(dialog['services']) > 1:
        changed_domains = list(state_update.keys())
        if len(changed_domains) == 0:
          # nothing changed this turn: keep domain_gt, or switch to another
          # domain that changed in the previous turn
          if domain_gt != '' and len(old_changed_domains) > 1:
            candidates = [x for x in old_changed_domains if x in new_state and x != domain_gt]
            if len(candidates) > 0:
              domain_gt = candidates[0]   # cond 3
        else:
          # the changed domain with the most filled slots
          domain_gt = max(changed_domains, key=lambda x: len(new_state[x]))  # cond 2
        old_changed_domains = changed_domains
      elif len(dialog['services']) == 1:
        domain_gt = dialog['services'][0]
      else:
        domain_gt = ''
  ...
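To check the rule on toy data, here is the same decision logic from the snippet above pulled out into a standalone function. It is a sketch under stated assumptions. `estimate_turn_domains` and its arguments are hypothetical. `updates[t]` holds only the slots changed at turn t (the role of state_update), and `states[t]` is the cumulative state after turn t (the role of new_state).

```python
from typing import Dict, List, Sequence

State = Dict[str, Dict[str, str]]  # {domain: {slot: value}}


def estimate_turn_domains(
    services: Sequence[str],
    updates: List[State],
    states: List[State],
) -> List[str]:
    """Per-turn domain labels following the cond 2 / cond 3 rules."""
    if len(services) == 1:
        return [services[0]] * len(updates)
    if len(services) == 0:
        return [""] * len(updates)

    labels: List[str] = []
    domain_gt = ""
    old_changed: List[str] = []
    for update, state in zip(updates, states):
        changed = list(update.keys())
        if changed:
            # cond 2: the changed domain with the most filled slots
            domain_gt = max(changed, key=lambda d: len(state[d]))
        elif domain_gt != "" and len(old_changed) > 1:
            # cond 3: fall back to another domain changed in the previous turn
            candidates = [d for d in old_changed if d in state and d != domain_gt]
            if candidates:
                domain_gt = candidates[0]
        labels.append(domain_gt)
        old_changed = changed
    return labels


# Toy dialogue (invented): per-turn updates and the cumulative states they produce.
updates = [
    {"hotel": {"area": "north"}},
    {},
    {"hotel": {"stars": "4"}, "train": {"day": "monday"}},
    {},
]
full = {"hotel": {"area": "north", "stars": "4"}, "train": {"day": "monday"}}
states = [{"hotel": {"area": "north"}}, {"hotel": {"area": "north"}}, full, full]
print(estimate_turn_domains(["hotel", "train"], updates, states))
# ['hotel', 'hotel', 'hotel', 'train']
```

In the last turn nothing changes, so cond 3 moves the label to the other domain updated in the previous turn.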