signebedi / gptty

ChatGPT wrapper in your TTY
MIT License
50 stars 7 forks source link

[model] add model validation #32

Closed signebedi closed 1 year ago

signebedi commented 1 year ago

We should validate that the model name passed in (via the `model` config setting) is an actual OpenAI model before using it.

def get_available_models():
    """Return the IDs of every model reported by ``openai.Model.list()``.

    The listing is fetched fresh on each call (nothing is cached here), so
    callers that need it repeatedly should hold on to the result.
    """
    listing = openai.Model.list()
    entries = listing['data']
    return [entry.id for entry in entries]

def is_valid_model(model_name):
    """Return True when *model_name* is one of the IDs from ``get_available_models``.

    Each call re-fetches the model listing via ``get_available_models``.
    """
    return model_name in get_available_models()

def validate_model_type(model_name):
    """Classify *model_name* by API family after confirming it exists.

    The family is inferred from substrings of the name:

    * names containing ``'davinci'`` or ``'curie'`` -> ``'completion'``
    * otherwise, names containing ``'gpt'``         -> ``'chat completion'``

    The name must also appear in the list returned by ``is_valid_model``'s
    lookup of available models.

    Returns:
        ``'completion'`` or ``'chat completion'``.

    Raises:
        ValueError: if the name matches no known family, or is not among the
            available models. ``ValueError`` subclasses ``Exception``, so
            callers that catch ``Exception`` (or use a bare ``except``) are
            unaffected.
    """
    # Classify by name first: this is free, and an unrecognized name is
    # rejected without touching the network.
    if 'davinci' in model_name or 'curie' in model_name:
        model_type = 'completion'
    elif 'gpt' in model_name:
        model_type = 'chat completion'
    else:
        raise ValueError(f"Unrecognized model family for model '{model_name}'")

    # Validate exactly once. The previous form could call is_valid_model (and
    # hence the models API) twice for names such as 'gpt-davinci-x'.
    if not is_valid_model(model_name):
        raise ValueError(f"Model '{model_name}' is not among the available models")
    return model_type

....

    # Set the parameters for the OpenAI completion API
    model_engine = configs['model'].rstrip('\n')

    try:
        model_type = validate_model_type(model_engine)
    except:
        click.echo(f"{RED}FAILED to validate the model name '{model_engine}'. Are you sure this is a valid OpenAI model? Check the available models at <https://platform.openai.com/docs/models/overview> and try again.{RESET}")
        return