VowpalWabbit / py-vowpal-wabbit-next

Experimental new Python bindings for the VowpalWabbit library
https://vowpal-wabbit-next.readthedocs.io/en/latest/
BSD 3-Clause "New" or "Revised" License

Memory error when online training a loaded oaa model #80

Closed: FabianKaiser closed this issue 1 year ago

FabianKaiser commented 1 year ago

When trying out the script in #4510 and switching to an oaa model (with far fewer bits, since it is really slow), an access violation occurs when trying to train the model online.

Replace the prediction and evaluation code with:

testing_data.apply(lambda x: model.learn_one(text_parser.parse_line(to_vw_format(x.text, mapping[x.target]))), axis=1)
print('Hello World')

`Hello World` is never printed. Under pytest the process dies with a fatal access violation; when the script is run normally it simply stops.
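The line above uses `DataFrame.apply` purely for its side effect of calling `learn_one` on each row. An explicit loop expresses the same online-training step more directly, avoids building an unused result Series, and makes it easy to log which example was being learned when the process died. The sketch below is a minimal illustration, not part of the original repro: `online_train` and `learn_line` are hypothetical names, and the learner is passed in as a callable so the loop can be exercised without VW installed.

```python
from __future__ import annotations

from typing import Callable, Hashable, Iterable, Mapping, Optional, Tuple


def to_vw_format(text: str, label: Optional[int] = None) -> str:
    # Same VW text format as the repro: "<label> |text <words>".
    prefix = "" if label is None else str(label)
    return f"{prefix} |text {text}"


def online_train(
    rows: Iterable[Tuple[str, Hashable]],
    mapping: Mapping[Hashable, int],
    learn_line: Callable[[str], None],
) -> int:
    """Feed examples to the learner one at a time; return how many were learned."""
    learned = 0
    for text, target in rows:
        # Convert the raw target to its integer class id, then learn this example.
        learn_line(to_vw_format(text, mapping[target]))
        learned += 1
    return learned
```

With the objects from the repro, `learn_line` would be `lambda line: model.learn_one(text_parser.parse_line(line))` and `rows` would be `zip(testing_data.text, testing_data.target)`.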

jackgerrits commented 1 year ago

Can you please provide the full repro? What are the arguments you are using? How did you train the model you are testing?

FabianKaiser commented 1 year ago

Sure:

import vowpal_wabbit_next as vw
import pandas as pd
from sklearn.metrics import precision_score, recall_score, f1_score
import os

def to_vw_format(text: str, label=None) -> str:
    if label is None:
        label = ''
    return f'{label} |text {text}'

data = pd.read_csv('reddit_data_sample_local.csv')

target_var = 'target'
training_var = 'text'

cleaned_targets = data[target_var].dropna()
unique_labels = cleaned_targets.unique().tolist()
num_classes = len(unique_labels)
numbers = list(range(num_classes))
mapping = dict(zip(unique_labels, numbers))

training_data = data.sample(frac=0.9, random_state=25)
testing_data = data.drop(training_data.index)

training_data = training_data.dropna(subset=[training_var]).sample(frac=1).reset_index(drop=True)

vw_training_file_name = 'train.vw'

with open(vw_training_file_name, "wb") as f:
    for text, label in zip(training_data[training_var], training_data[target_var]):
        vw_label = mapping[label]
        example = to_vw_format(
            text, vw_label) + ' \n'
        f.write(example.encode())

os.makedirs('model', exist_ok=True)

params = {
    'loss_function': 'logistic',
    'data': vw_training_file_name,
    'cache': True,
    'kill_cache': True,
    'final_regressor': 'model/model.vw',
    'compressed': True,
    'oaa': num_classes,
    'bit_precision': 20,
    'passes': 1,
    'example_queue_limit': 256,
    'learning_rate': 0.5,
}
cmdline = [f"--{k}" if isinstance(v, bool) else f"--{k}={v}" for k, v in params.items()]
print(vw.run_cli_driver(cmdline)[0])

with open('model/model.vw', 'rb') as model_file:
    learned_model = model_file.read()

params = {
    'loss_function': 'logistic',
    'probabilities': True,
}
cmdline = [f"--{k}" if isinstance(v, bool) else f"--{k}={v}" for k, v in params.items()]
model = vw.Workspace(cmdline, model_data=learned_model)
text_parser = vw.TextFormatParser(model)
testing_data.apply(lambda x: model.learn_one(text_parser.parse_line(to_vw_format(x.text, mapping[x.target]))), axis=1)
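Both `cmdline` comprehensions in the repro turn any boolean value into a bare flag, so an entry such as `'cache': False` would still emit `--cache`. A small helper that only emits enabled flags avoids that pitfall when the params dict is edited. This is a sketch; `params_to_cmdline` is a hypothetical helper, not part of `vowpal_wabbit_next`. Its output is a plain list of strings, as passed to `vw.run_cli_driver` and `vw.Workspace` in the repro.

```python
from __future__ import annotations

from typing import Any, Dict, List


def params_to_cmdline(params: Dict[str, Any]) -> List[str]:
    """Turn {"oaa": 3, "cache": True} into ["--oaa=3", "--cache"]."""
    args: List[str] = []
    for key, value in params.items():
        if isinstance(value, bool):
            # Boolean options are bare flags: emit them only when enabled.
            if value:
                args.append(f"--{key}")
        else:
            # Everything else is passed as --key=value.
            args.append(f"--{key}={value}")
    return args
```

Note that the `bool` check must come first: `True` is also an `int` in Python, so testing for numbers before booleans would turn `'cache': True` into `--cache=True`.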

FabianKaiser commented 1 year ago

reddit_data_sample_local.csv

jackgerrits commented 1 year ago

Thanks for the repro and for finding this bug! Working on a fix now.

FabianKaiser commented 1 year ago

Thanks for the fix! Can you estimate when it will be released? And will you also switch to VW 9.8 then?

jackgerrits commented 1 year ago

Working on releasing it now, should be done in a few hours.

Yes, it will target VW 9.8, including the PLT loss/training fix.

jackgerrits commented 1 year ago

@FabianKaiser 0.4.0 is available on PyPI.
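To fail fast on an environment that still has an older build, a script can check the installed version before online-training a loaded oaa model. This is a minimal sketch assuming, as the thread suggests, that the fix first shipped in 0.4.0, and assuming the PyPI distribution name is `vowpal-wabbit-next`. The helper names are hypothetical, and pre-release suffixes such as `rc1` are simply ignored.

```python
from __future__ import annotations

import re
from importlib import metadata
from typing import Optional, Tuple

# Assumption: first release containing the fix for this issue.
MIN_FIXED_VERSION = (0, 4, 0)


def parse_release(version: str) -> Tuple[int, ...]:
    # Keep only the leading numeric segment, e.g. "0.4.0rc1" -> (0, 4, 0).
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    parts = tuple(int(p) for p in match.group(0).split("."))
    # Pad short versions so "0.4" compares equal to (0, 4, 0).
    return parts + (0,) * (len(MIN_FIXED_VERSION) - len(parts))


def has_oaa_fix(version: str) -> bool:
    return parse_release(version) >= MIN_FIXED_VERSION


def installed_version(dist_name: str = "vowpal-wabbit-next") -> Optional[str]:
    # Returns None when the distribution is not installed.
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None
```

Usage: `v = installed_version()`, then raise an error if `v is not None and not has_oaa_fix(v)`.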

FabianKaiser commented 1 year ago

Thanks!