NVIDIA-Merlin / core

Core Utilities for NVIDIA Merlin
Apache License 2.0

[BUG] `repartition` on `Dataset` removes tags from schema #179

Open radekosmulski opened 1 year ago

radekosmulski commented 1 year ago


Reproducer code:

import cudf
import nvtabular as nvt
from merlin.schema.tags import Tags

purchases = cudf.DataFrame(
    data={
        'user_id': [0, 1, 2, 2],
        'price': [125.04, 23.07, 101.2, 2.34],
        'color': ['blue', 'blue', 'red', 'yellow'],
        'model': ['deluxe', 'compact', 'regular', 'regular'],
    }
)

# Attach tags to each column so the output schema carries them
out = ['price'] >> nvt.ops.AddMetadata(tags=[Tags.TARGET])
out += ['price'] >> nvt.ops.AddTags(tags=[Tags.CONTINUOUS])
out += ['user_id'] >> nvt.ops.TagAsUserID()
out += ['color', 'model'] >> nvt.ops.TagAsItemFeatures()
out += ['color', 'model'] >> nvt.ops.AddTags(tags=[Tags.CATEGORICAL])

ds = nvt.Dataset(purchases)
wf = nvt.Workflow(out)

ds_out = wf.fit_transform(ds)
print(ds_out.schema)  # tags are present after fit_transform

ds_out = ds_out.repartition(5)
print(ds_out.schema)  # tags are missing after repartition
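To make the regression concrete, one can diff the per-column tags before and after `repartition`. The sketch below is a hypothetical helper (`missing_tags` is not part of Merlin) that assumes each schema has already been reduced to a plain dict of column name → set of tag names; the tag strings used here are illustrative, not the exact values Merlin stores.

```python
from typing import Dict, Set


def missing_tags(before: Dict[str, Set[str]], after: Dict[str, Set[str]]) -> Dict[str, Set[str]]:
    """Hypothetical helper: per column, the tags present in `before` but absent from `after`."""
    lost = {}
    for column, tags in before.items():
        # Columns missing entirely from `after` count as having lost all their tags
        dropped = set(tags) - set(after.get(column, set()))
        if dropped:
            lost[column] = dropped
    return lost


# Illustrative tag names only, loosely following the reproducer's Tags.* values
before = {"price": {"target", "continuous"}, "color": {"categorical"}}
after = {"price": set(), "color": set()}
print(missing_tags(before, after))
```

An empty result means no tags were dropped, which is what a fixed `repartition` should give.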
karlhigley commented 1 year ago

I think this is because `repartition` creates a brand-new `Dataset` object, which then infers a schema from the raw data all over again. It shouldn't be too hard to keep the existing schema in this case.

Does it work if you pass `schema=self.schema` as a `Dataset` constructor argument in the definition of `repartition`?
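The mechanism described above can be illustrated with a toy model. This is a minimal sketch under stated assumptions: `ToyDataset`, `infer_schema`, `repartition_buggy` and `repartition_fixed` are hypothetical names, a "schema" is just a dict of column name → tag names, and none of this reproduces Merlin's actual `Dataset` or the change in #192. It only shows why re-inferring a schema from raw values loses user-assigned tags, and why forwarding `schema=self.schema` keeps them.

```python
from typing import Dict, FrozenSet, List, Optional

# Hypothetical schema model: column name -> set of tag names
Schema = Dict[str, FrozenSet[str]]
Partition = Dict[str, list]


def infer_schema(partition: Partition) -> Schema:
    # Inference only sees raw values, so user-assigned tags cannot be recovered
    return {name: frozenset() for name in partition}


class ToyDataset:
    """Hypothetical stand-in for a Dataset: column-dict partitions plus a schema."""

    def __init__(self, partitions: List[Partition], schema: Optional[Schema] = None):
        self.partitions = partitions
        # Infer a schema only when none is supplied, as described in the thread
        self.schema = schema if schema is not None else infer_schema(partitions[0])

    def _split(self, npartitions: int) -> List[Partition]:
        # Concatenate all partitions column by column, then cut into row chunks
        columns = list(self.partitions[0])
        merged = {c: [v for p in self.partitions for v in p[c]] for c in columns}
        nrows = len(merged[columns[0]])
        size = max(1, -(-nrows // npartitions))  # ceiling division, at least one row
        chunks = [{c: merged[c][i:i + size] for c in columns} for i in range(0, nrows, size)]
        return chunks or [{c: [] for c in columns}]  # keep one empty partition if no rows

    def repartition_buggy(self, npartitions: int) -> "ToyDataset":
        # New object built without a schema: tags are re-inferred, i.e. lost
        return ToyDataset(self._split(npartitions))

    def repartition_fixed(self, npartitions: int) -> "ToyDataset":
        # Forward the existing schema, as proposed above
        return ToyDataset(self._split(npartitions), schema=self.schema)


ds = ToyDataset(
    [{"user_id": [0, 1, 2, 2], "price": [125.04, 23.07, 101.2, 2.34]}],
    schema={"user_id": frozenset({"user_id"}), "price": frozenset({"target", "continuous"})},
)
print(ds.repartition_buggy(5).schema)  # every column now has an empty tag set
print(ds.repartition_fixed(5).schema == ds.schema)  # True
```

Forwarding the schema is safe here because repartitioning only moves rows between partitions; it never adds, drops, or retypes columns.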

rnyak commented 1 year ago

@radekosmulski any update based on Karl's comment above? Thanks.

karlhigley commented 1 year ago

I think @sararb fixed this issue in #192.

rnyak commented 1 year ago

@radekosmulski can you please test this again after pulling the latest branches and check whether the issue is fixed? Sara made a fix, but I'm not sure it solves your issue as well.