mlr-org / mlr3

mlr3: Machine Learning in R - next generation
https://mlr3.mlr-org.com
GNU Lesser General Public License v3.0

`task$truth()` doesn't respect row order after filtering task #896

Closed bblodfon closed 1 week ago

bblodfon commented 1 year ago

Is the following expected behavior?

```r
library(mlr3)

spam = tsk('spam')
all(spam$truth() == spam$truth(rows = 1:spam$nrow)) # ok
#> [1] TRUE

spam$filter(rows = sample(1:spam$nrow, 42))
all(spam$truth() == spam$truth(rows = 1:spam$nrow)) # not ok!
#> [1] FALSE
```

Created on 2023-02-03 with reprex v2.0.2

be-marc commented 1 year ago

My first comment was wrong. Looks like a bug. I'll take a look. Thanks for letting us know.

be-marc commented 1 year ago

Okay, no bug. After calling `spam$filter(rows = sample(1:spam$nrow, 42))`, `spam$nrow` is 42, so `spam$truth(rows = 1:spam$nrow)` still retrieves from the complete data set, i.e., the rows with ids 1 to 42. `spam$truth()` retrieves the rows set in `$row_roles$use`.

```r
spam = tsk('spam')

spam$filter(rows = sample(1:spam$nrow, 42))
spam$row_roles$use
#>  [1]  152  154  323  352  384  467  706  915  925 1191 1226 1294 1448 1492 1613 1707 1881 1962 2005 2060 2108 2211 2304 2357 2455 2656 3213 3362 3378 3438 3675 3713 3743 3823 3890 3942 3984 4065 4186 4383 4472 4516

all(spam$truth() == spam$truth(rows = spam$row_roles$use))
#> [1] TRUE
```

The behavior is correct.

be-marc commented 1 year ago

The assumption that `rows = 1:spam$nrow` refers to the filtered data set is wrong. `rows` always refers to row ids of the whole data set, even after filtering.
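For anyone who wants "the first n rows of the filtered task", a minimal sketch is to index into `task$row_ids` (the ids currently in use) rather than passing positions directly. The seed and the variable names `ids`, `first_five` and `truth_first_five` are only for illustration and do not come from the thread above.

```r
library(mlr3)

set.seed(1) # arbitrary seed, only so the sketch is repeatable
task = tsk("spam")
task$filter(rows = sample(task$row_ids, 42))

# Row ids that survived the filter; they still index the complete data set
ids = task$row_ids

# Translate positions 1..5 of the filtered task into row ids
first_five = ids[1:5]
truth_first_five = task$truth(rows = first_five)

# Same values as the first five entries of the call without `rows`
stopifnot(all(truth_first_five == task$truth()[1:5]))
```

The point of the design is that positions and row ids only coincide on an unfiltered task, so going through `task$row_ids` keeps the two from being mixed up.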

bblodfon commented 1 year ago

I see, thanks for the explanation! The same is true (`rows` and `row_ids` refer to the whole dataset) when doing something along the lines of: `learner$train(task, row_ids = 1:42)$predict(task, row_ids = 42:50)`.
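A rough sketch of what the train/predict case could look like on a filtered task, assuming the split should be taken from the rows that are actually in use. The learner (`classif.rpart`), the subset size of 200, the 150/50 split and the seed are arbitrary choices for illustration.

```r
library(mlr3)

set.seed(1) # arbitrary seed, only so the sketch is repeatable
task = tsk("spam")
task$filter(rows = sample(task$row_ids, 200))

# Build the split from the ids in use, not from 1:n
ids = task$row_ids
train_ids = ids[1:150]
test_ids = ids[151:200]

learner = lrn("classif.rpart")
learner$train(task, row_ids = train_ids)
prediction = learner$predict(task, row_ids = test_ids)

# The prediction refers to the same row ids that were requested
stopifnot(setequal(prediction$row_ids, test_ids))
```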

Maybe this needs to be clarified in the documentation somewhere? (I got really confused about it)

be-marc commented 1 week ago

Done