hassonlab / 247-pickling

Contains code to create pickles from raw/processed data
1 star, 9 forks, source link

Refactor code #112

Closed hvgazula closed 1 year ago

hvgazula commented 1 year ago

https://github.com/hassonlab/247-pickling/blob/b8253c6a5088c5b2f8e75e6a5cf7f3440b82cbf5/scripts/tfsemb_download.py#L137-L148

def get_models_and_class(model_name):
    """Resolve a model category or a single model name to ``(models, model_class)``.

    ``model_name`` is either a category key ("causal", "seq2seq", "mlm"), in
    which case that category's full model list and its Auto* class are
    returned, or an individual model name. In the second case the category
    lists are searched in order, and the first list containing the name
    yields ``([model_name], <class for that category>)``.

    Returns ``(None, None)`` when ``model_name`` matches nothing.
    """
    category_to_class = {
        "causal": AutoModelForCausalLM,
        "seq2seq": AutoModelForSeq2SeqLM,
        "mlm": AutoModelForMaskedLM,
    }
    category_to_models = {
        "causal": CAUSAL_MODELS,
        "seq2seq": SEQ2SEQ_MODELS,
        "mlm": MLM_MODELS,
    }

    requested_class = category_to_class.get(model_name, None)
    requested_models = category_to_models.get(model_name, None)

    # A direct category hit always carries a (truthy) class, so return it as-is.
    if requested_models or requested_class:
        return requested_models, requested_class

    # Otherwise look for the individual model name inside each category list.
    hits = (
        ([model_name], category_to_class[category])
        for category, members in category_to_models.items()
        if model_name in members
    )
    return next(hits, (requested_models, requested_class))

# NOTE(review): get_models_and_class takes a single ``model_name`` argument,
# so this two-argument call raises TypeError as written. Presumably the
# requested model/category name should be passed here instead; TODO confirm
# against the caller.
MODELS, model_class = get_models_and_class(models, model_class)
hvgazula commented 1 year ago
def get_models_and_class(model_name):
    """Resolve a model category or a single model name to ``(models, model_class)``.

    ``model_name`` is either a category key ("causal", "seq2seq", "mlm"), in
    which case that category's full model list is returned, or an individual
    model name contained in one of the category lists, in which case a
    one-element list ``[model_name]`` is returned. Categories are checked in
    order, and the first match wins.

    Returns:
        tuple: ``(models, model_class)``. Both are ``None`` when ``model_name``
        matches neither a category nor a listed model. A message is printed
        whenever either result is falsy.
    """
    model_class_map = {
        "causal": AutoModelForCausalLM,
        "seq2seq": AutoModelForSeq2SeqLM,
        "mlm": AutoModelForMaskedLM,
    }

    model_list_map = {
        "causal": CAUSAL_MODELS,
        "seq2seq": SEQ2SEQ_MODELS,
        "mlm": MLM_MODELS,
    }

    # Bind both results up front. Previously ``model_class`` was assigned only
    # inside the loop, so an unmatched name raised UnboundLocalError at the
    # check below instead of reaching the print.
    models, model_class = None, None
    for model_key, model_list in model_list_map.items():
        if model_name == model_key or model_name in model_list:
            models = model_list if model_name == model_key else [model_name]
            model_class = model_class_map.get(model_key, None)
            break

    if not models or not model_class:
        print("Invalid Model List or Model Class")

    return models, model_class

# NOTE(review): get_models_and_class takes a single ``model_name`` argument,
# so this two-argument call raises TypeError as written. Presumably the
# requested model/category name should be passed here instead; TODO confirm
# against the caller.
MODELS, model_class = get_models_and_class(models, model_class)
hvgazula commented 1 year ago

@zkokaja @VeritasJoker Which one do you prefer, the original one or the most recent one☝️? Or please suggest an improvement if you can think of one.

hvgazula commented 1 year ago
def get_models_and_class(model_name):
    """Resolve a model category or a single model name to ``(models, mod_class)``.

    ``model_name`` is either a category key ("causal", "seq2seq", "mlm"), in
    which case that category's full model list is returned, or an individual
    model name contained in one of the category lists, in which case a
    one-element list ``[model_name]`` is returned. Categories are checked in
    order, and the first match wins.

    Returns:
        tuple: ``(models, mod_class)``. Both are ``None`` when ``model_name``
        matches nothing. A message is printed whenever either result is falsy.
    """
    model_class_map = {
        "causal": (CAUSAL_MODELS, AutoModelForCausalLM),
        "seq2seq": (SEQ2SEQ_MODELS, AutoModelForSeq2SeqLM),
        "mlm": (MLM_MODELS, AutoModelForMaskedLM),
    }

    models, mod_class = None, None
    # ``model_cls`` is deliberately not named ``model_class``, so it cannot be
    # confused with the module-level result name used by callers.
    for model_key, (model_list, model_cls) in model_class_map.items():
        if model_name == model_key:
            models, mod_class = model_list, model_cls
            break
        if model_name in model_list:
            models, mod_class = [model_name], model_cls
            break
        # Stop at the first match. The earlier version kept scanning, so a
        # later category silently overrode an earlier hit.

    if not models or not mod_class:
        print("Invalid Model List or Model Class")

    return models, mod_class

# NOTE(review): get_models_and_class takes a single ``model_name`` argument,
# so this two-argument call raises TypeError. ``mod_class`` is also only a
# local name inside the function, so it is undefined here unless bound
# elsewhere. Presumably only the requested model/category name should be
# passed; TODO confirm against the caller.
MODELS, model_class = get_models_and_class(models, mod_class)
hvgazula commented 1 year ago

Okay, I am spending too much time here 😛