Closed: hvgazula closed this 1 year ago.
def get_models_and_class(model_name):
    """Resolve a model family key or a single model name to its models and class.

    ``model_name`` may be:
      * a family key ("causal", "seq2seq" or "mlm"), in which case the whole
        model list for that family is returned; or
      * an individual model name contained in one of the family lists, in
        which case a one-element list holding that name is returned.

    Families are checked in the order "causal", "seq2seq", "mlm"; the first
    match wins.

    Args:
        model_name: Family key or individual model name.

    Returns:
        A ``(models, model_class)`` tuple, where ``model_class`` is the
        ``AutoModelFor*`` class mapped to the matched family. If nothing
        matches (or the matched family list is empty), a message is printed
        and the unmatched values are returned as ``None``.
    """
    model_class_map = {
        "causal": AutoModelForCausalLM,
        "seq2seq": AutoModelForSeq2SeqLM,
        "mlm": AutoModelForMaskedLM,
    }
    model_list_map = {
        "causal": CAUSAL_MODELS,
        "seq2seq": SEQ2SEQ_MODELS,
        "mlm": MLM_MODELS,
    }
    # Initialise BOTH outputs: previously only ``models`` was set here, so an
    # unmatched name reached ``return`` with ``model_class`` unbound and
    # raised UnboundLocalError instead of returning (None, None).
    models, model_class = None, None
    for model_key, model_list in model_list_map.items():
        if model_name == model_key or model_name in model_list:
            # A family key selects the whole list; a single name selects itself.
            models = model_list if model_name == model_key else [model_name]
            model_class = model_class_map.get(model_key)
            break
    if not models or not model_class:
        print("Invalid Model List or Model Class")
    return models, model_class
# get_models_and_class takes exactly one argument (a family key or a model
# name); passing a second positional argument raised TypeError.
# NOTE(review): confirm ``models`` holds the requested family key / model name here.
MODELS, model_class = get_models_and_class(models)
@zkokaja @VeritasJoker Which one do you prefer: the original version or the most recent one ☝️? Or please suggest an improvement if you can think of one.
def get_models_and_class(model_name):
    """Resolve a model family key or a single model name to its models and class.

    ``model_name`` may be:
      * a family key ("causal", "seq2seq" or "mlm"), in which case the whole
        model list for that family is returned; or
      * an individual model name contained in one of the family lists, in
        which case a one-element list holding that name is returned.

    Families are checked in the order "causal", "seq2seq", "mlm"; the first
    match wins.

    Args:
        model_name: Family key or individual model name.

    Returns:
        A ``(models, mod_class)`` tuple, where ``mod_class`` is the
        ``AutoModelFor*`` class paired with the matched family. If nothing
        matches (or the matched family list is empty), a message is printed
        and the unmatched values are returned as ``None``.
    """
    model_class_map = {
        "causal": (CAUSAL_MODELS, AutoModelForCausalLM),
        "seq2seq": (SEQ2SEQ_MODELS, AutoModelForSeq2SeqLM),
        "mlm": (MLM_MODELS, AutoModelForMaskedLM),
    }
    models, mod_class = None, None
    for model_key, (model_list, model_class) in model_class_map.items():
        if model_name == model_key:
            models, mod_class = model_list, model_class
            break
        if model_name in model_list:
            models, mod_class = [model_name], model_class
            # Stop at the first family containing the name; without this the
            # scan continued and a name listed in several families silently
            # resolved to the LAST one.
            break
    if not models or not mod_class:
        print("Invalid Model List or Model Class")
    return models, mod_class
# get_models_and_class takes exactly one argument (a family key or a model
# name). ``mod_class`` is only a local inside that function, so passing it
# here raised NameError at module level (and a 2nd argument raises TypeError).
# NOTE(review): confirm ``models`` holds the requested family key / model name here.
MODELS, model_class = get_models_and_class(models)
Okay, I am spending too much time here 😛
https://github.com/hassonlab/247-pickling/blob/b8253c6a5088c5b2f8e75e6a5cf7f3440b82cbf5/scripts/tfsemb_download.py#L137-L148