lancedb / lance

Modern columnar data format for ML and LLMs implemented in Rust. Convert from Parquet in two lines of code for 100x faster random access, a vector index, and data versioning. Compatible with Pandas, DuckDB, Polars, and PyArrow, with more integrations coming.
https://lancedb.github.io/lance/
Apache License 2.0

v2 format - weird issue with categorical columns #3005

Open benmayersohn opened 5 days ago

benmayersohn commented 5 days ago

I noticed a strange issue when trying to use take to load a subset of rows from a Lance v2 dataset. The dataset has categorical columns. In polars these are represented as an int32 -> large_string mapping, but since large string dictionaries aren't currently supported in v2 (https://github.com/lancedb/lance/issues/2828), I convert the columns to int32 -> string after converting the polars dataframe to an arrow table. Then I save it as a Lance v2 dataset.

When I load one of these categorical columns in its entirety via ds.to_table and convert to polars via pl.from_arrow, the categorical column looks fine. But for certain subsets of rows, I end up with an incorrect number of levels. Here is a reproducible example:

import polars as pl
import lance
import numpy as np
import pyarrow as pa
from tqdm import tqdm

num_categories = 1000
num_rows = 100_000_000
categories = [str(x) for x in range(num_categories)]
df = pl.DataFrame({'a': np.random.randint(0, num_categories, num_rows)})
df = df.with_columns(pl.col('a').cast(pl.String).cast(pl.Enum(categories)))
out_filename = 'test.lance'
df = df.to_arrow()
df = pa.table([df['a'].cast(pa.dictionary(pa.int32(), pa.string()))], names=['a'])
ds = lance.write_dataset(df, out_filename)
# ... prints a long run of "Before"/"After" lines, which seem to come from
# AlreadyDictionaryEncoder in dictionary.rs ...

# Looks fine
entry = ds.to_table()
entry = pl.from_arrow(entry)
categories = entry['a'].cat.get_categories()
len(categories)
# 1000

# Break up row indices from 0 to num_rows in batches of 1024
ds = lance.dataset(out_filename)
batch_size = 1024
batches = [list(range(i, min(i + batch_size, num_rows))) for i in range(0, num_rows, batch_size)]
with tqdm(batches) as pbar:
    for batch in pbar:
        entry = ds.take(batch)
        entry = pl.from_arrow(entry)
        categories = entry['a'].cat.get_categories()
        if len(categories) != num_categories:
            print(f'Error: {len(categories)} categories found in batch from {batch[0]} to {batch[-1]}')
            print(categories.to_list())
            break
        else:
            pbar.set_description(f'Num categories: {len(categories)}')

# Incorrect number of categories
Error: 685 categories found in batch from 87040 to 88063
['4', '5', '20', '21', '22', '24', '40', '48', '52', '57', '58', '66', '68', '75', '78', '80', '81', '88', '91', '95', '105', '108', '115', '121', '122',
 '126', '135', '137', '139', '143', '148', '149', '150', '168', '169', '173', '177', '185', '188', '194', '197', '201', '202', '204', '213', '215', '217',
 '224', '226', '230', '231', '232', '236', '237', '245', '252', '253', '254', '256', '261', '262', '265', '269', '277', '288', '289', '293', '295', '299',
 '303', '304', '308', '313', '319', '325', '327', '330', '334', '337', '338', '341', '342', '346', '350', '352', '353', '356', '358', '361', '362', '364',
 '368', '369', '370', '372', '374', '375', '376', '378', '380', '389', '391', '395', '407', '408', '411', '414', '422', '424', '427', '428', '436', '437',
 '438', '440', '445', '446', '447', '450', '454', '459', '464', '465', '467', '473', '476', '477', '478', '489', '500', '502', '504', '509', '510', '511',
 '513', '516', '522', '537', '538', '546', '548', '560', '564', '565', '566', '576', '577', '579', '583', '585', '589', '596', '604', '609', '615', '616',
 '617', '627', '631', '636', '637', '644', '645', '648', '653', '656', '665', '666', '667', '672', '674', '679', '685', '687', '690', '691', '693', '700',
 '703', '707', '708', '709', '711', '712', '719', '724', '726', '728', '732', '733', '737', '741', '757', '765', '766', '769', '771', '772', '774', '775',
 '778', '779', '789', '795', '799', '800', '801', '807', '809', '811', '812', '814', '816', '819', '820', '823', '824', '826', '829', '830', '834', '837',
 '841', '844', '848', '851', '855', '856', '858', '861', '865', '868', '869', '875', '876', '879', '888', '890', '891', '893', '899', '900', '902', '904',
 '907', '909', '911', '914', '915', '917', '919', '921', '922', '927', '928', '930', '931', '938', '940', '941', '944', '949', '953', '954', '959', '960',
 '962', '964', '970', '972', '973', '977', '981', '982', '987', '988', '989', '991', '994', '0', '2', '6', '7', '9', '14', '16', '17', '18', '19', '20',
 '24', '30', '31', '32', '33', '35', '38', '39', '42', '44', '45', '46', '47', '49', '50', '51', '53', '54', '56', '60', '62', '63', '64', '67', '69',
 '70', '71', '72', '77', '81', '82', '85', '87', '89', '92', '96', '99', '104', '111', '113', '114', '118', '119', '123', '124', '125', '127', '128',
 '129', '131', '132', '133', '134', '136', '138', '140', '142', '145', '146', '147', '148', '151', '153', '154', '156', '158', '159', '161', '162',
 '164', '165', '166', '167', '171', '174', '179', '181', '182', '190', '191', '192', '193', '195', '196', '198', '199', '200', '203', '206', '207',
 '212', '218', '221', '223', '228', '231', '234', '235', '239', '240', '244', '247', '249', '252', '257', '258', '259', '263', '267', '268', '271',
 '272', '273', '274', '275', '276', '279', '281', '283', '289', '290', '291', '294', '296', '297', '298', '301', '309', '311', '315', '317', '318',
 '320', '321', '323', '324', '329', '331', '332', '334', '335', '336', '345', '348', '357', '360', '363', '366', '371', '373', '375', '388', '389',
 '390', '395', '396', '398', '399', '401', '404', '406', '407', '409', '410', '417', '418', '419', '421', '423', '430', '432', '439', '441', '443',
 '448', '451', '452', '455', '461', '462', '463', '468', '469', '472', '479', '481', '484', '485', '486', '488', '489', '490', '491', '495', '496',
 '497', '498', '502', '503', '504', '508', '509', '512', '514', '515', '516', '518', '520', '521', '524', '527', '529', '530', '531', '539', '540',
 '541', '544', '547', '549', '551', '553', '554', '560', '561', '563', '567', '568', '569', '572', '575', '578', '580', '584', '585', '586', '591',
 '593', '594', '598', '601', '603', '607', '608', '610', '611', '614', '617', '618', '620', '622', '623', '626', '628', '629', '630', '635', '638',
 '640', '641', '642', '643', '647', '654', '655', '658', '659', '662', '664', '669', '670', '671', '673', '676', '679', '680', '681', '682', '684',
 '688', '689', '692', '695', '697', '701', '704', '705', '706', '715', '716', '718', '719', '720', '722', '723', '725', '727', '729', '731', '737',
 '743', '744', '745', '747', '752', '753', '754', '756', '758', '760', '761', '762', '763', '772', '777', '790', '791', '792', '794', '798', '800',
 '802', '803', '804', '806', '807', '808', '809', '810', '813', '815', '818', '821', '826', '828', '830', '833', '834', '839', '842', '844', '847',
 '850', '857', '866', '868', '870', '872', '873', '874', '885', '889', '893', '896', '901', '903', '905', '907', '908', '910', '913', '915', '916', '920',
 '923', '926', '933', '935', '941', '942', '943', '946', '947', '948', '951', '952', '954', '956', '961', '962', '963', '964', '971', '976', '979', '984',
 '985', '986', '988', '989', '992', '995', '997', '999']

I don't encounter this issue when I save the dataset in v1 format:

import shutil

shutil.rmtree(out_filename)

lance.write_dataset(df, out_filename, data_storage_version='legacy')
# <lance.dataset.LanceDataset object at 0x7f3874f51180>

ds = lance.dataset(out_filename)
ds.data_storage_version
# 0.1

with tqdm(batches) as pbar:
    for batch in pbar:
        entry = ds.take(batch)
        entry = pl.from_arrow(entry)
        categories = entry['a'].cat.get_categories()
        if len(categories) != num_categories:
            print(f'Error: {len(categories)} categories found in batch from {batch[0]} to {batch[-1]}')
            print(categories.to_list())
            break
        else:
            pbar.set_description(f'Num categories: {len(categories)}')

# Num categories: 1000: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97657/97657 [01:15<00:00, 1301.87it/s]

Any help would be appreciated. Thanks!

westonpace commented 4 days ago

I think what you are describing may be enum rather than categorical?

Enum considers the categories "fixed" and "part of the data" (e.g. the list of categories is important and constant)

Categorical considers the categories "flexible" and "just a tool to compress data". As a result, if a batch doesn't need some categories (because they are never referenced) then we don't store them. Also, it's possible two batches have a different order to the categories.

We don't have support for Enum in Lance but it wouldn't be too difficult to add I think.

Unfortunately, Arrow does not make any distinction between Categorical and Enum, and Polars does not wrap Enum as "an extension type on top of categorical". As a result, the two arrays look identical when they are converted to Arrow and passed to us:

import polars as pl
import lance
import numpy as np
import pyarrow as pa

num_categories = 1000
num_rows = 10_000
categories = [str(x) for x in range(num_categories)]
df = pl.DataFrame({'a': np.random.randint(0, num_categories, num_rows)})
df = df.with_columns(pl.col('a').cast(pl.String).cast(pl.Enum(categories)))
print(df.to_arrow().schema.field(0))
# pyarrow.Field<a: dictionary<values=large_string, indices=uint32, ordered=0>>

df = df.with_columns(pl.col('a').cast(pl.String).cast(pl.Categorical()))
print(df.to_arrow().schema.field(0))
# pyarrow.Field<a: dictionary<values=large_string, indices=uint32, ordered=0>>

For now, I think we could offer a top-level flag in write_dataset which controls whether we use "enum style" or "categorical style" for storing arrays. Would that work for you? Or are you using both category and enum in your application and need them preserved?

Also, I don't actually know what polars will do on the conversion back. I.e., when converting from arrow to polars, is there some way to flag that dictionary data should be considered "enum" vs "categorical"? Maybe you can provide a schema when converting from arrow to polars?

Are you using Lance's to_polars? Or are you converting to polars yourself?

benmayersohn commented 4 days ago

Thanks for the response! I'm only using Enum, not Categorical (I meant categorical in a general sense - sorry for the ambiguity).

A flag sounds good! I didn't know Lance had a to_polars method - I use pl.from_arrow after calling ds.take on the Lance dataset. For now I'll try providing an explicit schema when converting to polars and see if that works.

westonpace commented 4 days ago

Great. Just to be clear, there's still work we'll need to do on our side (adding the flag and making sure we don't throw away levels) in addition to providing a schema. I'll try and find some time to get to it this week.

benmayersohn commented 4 days ago

Great - thanks so much for your hard work on this excellent project!