jasonrig / address-net

A package to structure Australian addresses
MIT License
195 stars 86 forks source link

Retrain model script #18

Open taratuncho opened 2 years ago

taratuncho commented 2 years ago

Hello! First of all, thank you for the opportunity to use the code you wrote.

I'm trying to train a new model, but the results I get from it are very wrong:

{'street_name': '168A SEPARATION STREET NO', 'locality_name': 'COTE, VIC 3070'}

Below is the code I am using. Could you share your own training code, or point out where I might be going wrong?

Thank you so much.


import argparse
import datetime
import os

import tensorflow as tf

import addressnet.dataset as dataset
from addressnet.model import model_fn

def _get_estimator(model_fn, model_dir):
    """Build a ``tf.estimator.Estimator`` wrapping ``model_fn`` that checkpoints into ``model_dir``.

    The run configuration:
    - fixes the TensorFlow random seed to 17;
    - keeps at most five checkpoints;
    - logs step counts and saves a checkpoint every 2000 steps.
    """
    every_n_steps = 2000
    run_config = tf.estimator.RunConfig(
        tf_random_seed=17,
        keep_checkpoint_max=5,
        log_step_count_steps=every_n_steps,
        save_checkpoints_steps=every_n_steps,
    )
    return tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir, config=run_config)

def train(tfrecord_input_file: str, model_output_file: str):
    """Train an AddressNet estimator on a TFRecord file, then evaluate it.

    Checkpoints are written to ``<model_output_file>/<basename of tfrecord_input_file>``.
    Despite its name, ``model_output_file`` is therefore used as a parent directory,
    not as a single output file.

    :param tfrecord_input_file: path to a TFRecord file (e.g. one produced by generate_tf_records.py)
    :param model_output_file: parent directory under which the model directory is created
    :return: the metrics dict returned by ``Estimator.evaluate`` (ignored by existing callers)
    """
    # normpath strips trailing separators, which would otherwise make basename() return ''
    # and collapse the model directory onto the parent directory.
    input_file_only = os.path.basename(os.path.normpath(tfrecord_input_file))
    model_output_file_path = os.path.join(model_output_file, input_file_only)

    address_net_estimator = _get_estimator(model_fn, model_output_file_path)

    # NOTE(review): assumes dataset.dataset(...) returns an input_fn callable, which is what
    # Estimator.train/evaluate expect; confirm against addressnet/dataset.py.
    tfdataset = dataset.dataset(tfrecord_input_file)

    start = datetime.datetime.now()
    # Estimator.train returns the estimator itself, so `model` is the trained estimator.
    model = address_net_estimator.train(tfdataset)
    end = datetime.datetime.now()
    # `end - start` is a timedelta (formatted like "0:01:23.4"); convert it so the "sec" label is true.
    elapsed_seconds = (end - start).total_seconds()

    print('Evaluate model...')
    # NOTE(review): this evaluates on the same data used for training. The metrics show how
    # well the model fits that data, not how it generalises; use a held-out file for that.
    evaluation = model.evaluate(tfdataset)
    print(f'evaluation={evaluation}')

    print(f'Finished training in {elapsed_seconds:.1f} sec on file {input_file_only}. '
          f'Model saved to {model_output_file_path}')
    return evaluation

if __name__ == "__main__":
    # Command-line entry point. Both paths are mandatory: without them, train() would
    # receive None and fail later with an opaque TypeError from os.path.basename(None).
    parser = argparse.ArgumentParser()
    parser.add_argument("--tfrecord_input_file", required=True,
                        help="Tfrecord input file from generate_tf_records.py")
    parser.add_argument("--model_output_file", required=True,
                        help="Parent directory for the model; checkpoints go in a "
                             "sub-directory named after the input file")
    args = parser.parse_args()

    train(args.tfrecord_input_file, args.model_output_file)
dylanhogg commented 2 years ago

@taratuncho, what address text did you use as input to get that result? Without the input, it is hard to diagnose the output.

I ran the input text "168A SEPARATION STREET NO, COTE, VIC 3070" through a live demo of the model (using the trained model supplied by @jasonrig in this repo) at https://address-app.infocruncher.com/, which returned:


```json
{
"number_first": "168",
"number_first_suffix": "A",
"street_name": "SEPARATION",
"street_type": "STREET",
"street_suffix": "NORTH",
"locality_name": ", COTE",
"state": "VICTORIA",
"postcode": "3070"
}
```