tidyfun / tf

S3 classes and methods for tidy functional data
https://tidyfun.github.io/tf/
GNU Affero General Public License v3.0

refactor: migration to `vec_arith` API #133

m-muecke closed this 2 weeks ago

m-muecke commented 1 month ago

Note: I haven't included the dispatch case for when one operand is logical, since I don't believe that ever makes sense. You had previously implemented it, though.

fabian-s commented 1 month ago

Thanks, this already mostly works, but I see at least 3 major issues (some of them may have been present before as well...)

library(tf)
#> 
#> Attaching package: 'tf'
#> The following objects are masked from 'package:stats':
#> 
#>     sd, var

x <- tf_rgp(3)
x_i <- tf_sparsify(x)
y <- tfb(x)
#> Percentage of input data variability preserved in basis representation
#> (per functional observation, approximate):
#>    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
#>   99.90   99.90   99.90   99.93   99.95  100.00
y_pc <- tfb_fpc(x)

# 1. tfd-tfb ops don't seem to dispatch to the correct method / don't generate the right exception
x + y
#> Error in UseMethod("vec_arith.tfd", y): no applicable method for 'vec_arith.tfd' applied to an object of class "c('tfb_spline', 'tfb', 'tf', 'vctrs_vctr', 'list')"
tf:::vec_arith.tfd("+", x, y)
#> Error in UseMethod("vec_arith.tfd", y): no applicable method for 'vec_arith.tfd' applied to an object of class "c('tfb_spline', 'tfb', 'tf', 'vctrs_vctr', 'list')"
# why does the above not dispatch to this below ?!?
tf:::vec_arith.tfd.default("+", x, y)
#> Error in `tf:::vec_arith.tfd.default()`:
#> ! <tfd_reg> + <tfb_spline> is not permitted

# 2. numeric-fun ops should not cast tfb_fpc to tfb_spline
class(y_pc + y_pc)
#> [1] "tfb_fpc"    "tfb"        "tf"         "vctrs_vctr" "list"
class(2 * y_pc) #!!!
#> [1] "tfb_spline" "tfb"        "tf"         "vctrs_vctr" "list"

# 3. ops for tfb should preserve attributes of the original tfb
#    (ALWAYS for numeric-fun-ops, at least)
tf:::compare_tf_attribs(y, 2 * y)
#>       domain        basis  basis_label   basis_args basis_matrix          arg 
#>         TRUE         TRUE        FALSE        FALSE         TRUE         TRUE 
#>       family        class 
#>         TRUE         TRUE
tf:::compare_tf_attribs(y / y, y)
#> Warning: Fit captures <50% of input data variability for at least one function
#> -- consider increasing no. of basis functions 'k' or decreasing penalization.
#> Percentage of input data variability preserved in basis representation
#> (per functional observation, approximate):
#>    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
#>    -Inf    -Inf    -Inf    -Inf    -Inf    -Inf
#>       domain        basis  basis_label   basis_args basis_matrix          arg 
#>         TRUE         TRUE        FALSE        FALSE         TRUE         TRUE 
#>       family        class 
#>         TRUE         TRUE
y + 2 * y
#> Error in vec_arith.tfb.tfb("+", e1, e2): all(compare_tf_attribs(x, y)) is not TRUE

# minor: better error messages for things like this
x + x_i
#> Error in fun_op(op, x, y): isTRUE(all.equal(tf_arg(x), tf_arg(y), check.attributes = FALSE)) is not TRUE

x + x == 2 * x
#>    1    2    3 
#> TRUE TRUE TRUE
x * x == x^2
#>    1    2    3 
#> TRUE TRUE TRUE
all( (tf_evaluations(x / x) |> unlist()) == 1)
#> [1] TRUE
tf:::compare_tf_attribs(x / x, x)
#>            arg         domain      evaluator evaluator_name          class 
#>           TRUE           TRUE           TRUE           TRUE           TRUE

x_i + x_i == 2 * x_i
#>    1    2    3 
#> TRUE TRUE TRUE
x_i * x_i == x_i^2
#>    1    2    3 
#> TRUE TRUE TRUE
all( (tf_evaluations(x_i / x_i) |> unlist()) == 1)
#> [1] TRUE
tf:::compare_tf_attribs(x_i / x_i, x_i)
#>            arg         domain      evaluator evaluator_name          class 
#>           TRUE           TRUE           TRUE           TRUE           TRUE

#-----------------------------------

y * y == y^2
#>    1    2    3 
#> TRUE TRUE TRUE
all.equal((tf_evaluations(y / y) |> unlist()),
          rep(1, length(y)*length(tf_arg(y))), check.attributes = FALSE)
#> Warning: Fit captures <50% of input data variability for at least one function
#> -- consider increasing no. of basis functions 'k' or decreasing penalization.
#> Percentage of input data variability preserved in basis representation
#> (per functional observation, approximate):
#>    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
#>    -Inf    -Inf    -Inf    -Inf    -Inf    -Inf
#> [1] TRUE
tf:::compare_tf_attribs(y / y, y)  #!!
#> Warning: Fit captures <50% of input data variability for at least one function
#> -- consider increasing no. of basis functions 'k' or decreasing penalization.
#> Percentage of input data variability preserved in basis representation
#> (per functional observation, approximate):
#>    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
#>    -Inf    -Inf    -Inf    -Inf    -Inf    -Inf
#>       domain        basis  basis_label   basis_args basis_matrix          arg 
#>         TRUE         TRUE        FALSE        FALSE         TRUE         TRUE 
#>       family        class 
#>         TRUE         TRUE

#--------------------------

y_pc + y_pc == 2 * y_pc #!
#> [1] FALSE FALSE FALSE
all.equal(tf_evaluations(y_pc + y_pc), tf_evaluations(2 * y_pc)) # urks
#> [1] "Component \"1\": Mean relative difference: 0.01724136"
#> [2] "Component \"2\": Mean relative difference: 0.01293853"
#> [3] "Component \"3\": Mean relative difference: 0.01199792"

Next steps should be adding tests for all combinations of ops and classes, and tackling the issues above.

fabian-s commented 1 month ago

re:

I haven't included the dispatch case for when one operand is logical, since I don't believe that ever makes sense.

Good call, agreed.

fabian-s commented 1 month ago

Edit: issue 1 above was just due to missing `@method` directives.
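For reference, a minimal sketch of the vctrs double-dispatch pattern with the roxygen2 `@method` directives spelled out, following the pattern shown in the vctrs "S3 vectors" vignette. The class `my_meter` and its constructor `new_meter()` are hypothetical stand-ins, not tf code. Inside a package, the `@method vec_arith.my_meter <class>` lines are what register the inner methods so that `UseMethod("vec_arith.my_meter", y)` can find them; in a plain script the roxygen comments are inert and the methods are simply found in the global environment.

```r
library(vctrs)

# Hypothetical toy class standing in for tf's tfd/tfb classes;
# it only illustrates the dispatch pattern, not tf's actual code.
new_meter <- function(x = double()) {
  stopifnot(is.double(x))
  new_vctr(x, class = "my_meter")
}

# Outer method: vctrs calls vec_arith(op, x, y) with x of class my_meter,
# then we dispatch a second time on the class of y via an inner generic.
#' @export
#' @method vec_arith my_meter
vec_arith.my_meter <- function(op, x, y, ...) {
  UseMethod("vec_arith.my_meter", y)
}

# Fallback for any y without its own method (e.g. character, logical).
#' @export
#' @method vec_arith.my_meter default
vec_arith.my_meter.default <- function(op, x, y, ...) {
  stop_incompatible_op(op, x, y)
}

# my_meter (op) my_meter: only + and - keep the class, everything else errors.
#' @export
#' @method vec_arith.my_meter my_meter
vec_arith.my_meter.my_meter <- function(op, x, y, ...) {
  switch(op,
    "+" = ,
    "-" = new_meter(vec_arith_base(op, x, y)),
    stop_incompatible_op(op, x, y)
  )
}

# my_meter (op) numeric: the result keeps the my_meter class.
#' @export
#' @method vec_arith.my_meter numeric
vec_arith.my_meter.numeric <- function(op, x, y, ...) {
  new_meter(vec_arith_base(op, x, y))
}
```

Because the `default` method catches every class of `y` that has no dedicated method, a logical operand simply errors, which is consistent with not adding a logical case.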

m-muecke commented 1 month ago

Good catch. This bug was already there before:

library(tf)
#> 
#> Attaching package: 'tf'
#> The following objects are masked from 'package:stats':
#> 
#>     sd, var
set.seed(1234)
x <- tf_rgp(3) |> tfb()
#> Percentage of input data variability preserved in basis representation
#> (per functional observation, approximate):
#>    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
#>   99.90   99.90   99.90   99.93   99.95  100.00
isTRUE(all.equal(x + x, x * 2))
#> [1] FALSE

Created on 2024-10-11 with reprex v2.1.1

fabian-s commented 1 month ago

Thx A LOT! I'll need some more time to review this, Wednesday at the latest.

fabian-s commented 1 month ago

I fixed some things & broke some more: some weird behavior is visible in attic/dev-vec-arith.R, and the currently failing tests for fwise need to be tackled next.

m-muecke commented 1 month ago

Very nice. Was the glue package added just for the WIP, to be replaced with sprintf later (since glue is only used sparingly), or do you want to add it as a dependency?

fabian-s commented 1 month ago

glue is used by vctrs anyway, so it's not really an additional dependency, and this way we create errors that look like vctrs errors.
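A minimal sketch of what a glue-based, vctrs-looking error could be, mimicking the `<class> op <class> is not permitted` format seen in the issue-1 error above. `abort_incompatible_op()` is a hypothetical helper, not tf's actual code, and it uses the plain first class rather than vctrs' full type description.

```r
library(glue)

# Hypothetical helper (not tf's actual code): builds a vctrs-style message
# like "<tfd_reg> + <tfb_spline> is not permitted" and signals an error.
abort_incompatible_op <- function(op, x, y) {
  stop(glue("<{class(x)[1]}> {op} <{class(y)[1]}> is not permitted"), call. = FALSE)
}
```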

fabian-s commented 1 month ago

@m-muecke can you check that this still makes sense and LMK if you come across anything sus? I'm at home with COVID and not entirely sure I was thinking straight...

fabian-s commented 1 month ago

Is the thumb "yes, I will take a look" or "looked at it, seems fine"? ;)

m-muecke commented 1 month ago

Is the thumb "yes, I will take a look" or "looked at it, seems fine"? ;)

I will take a look this week

m-muecke commented 2 weeks ago

@m-muecke can you check that this still makes sense and LMK if you come across anything sus? I'm at home with COVID and not entirely sure I was thinking straight...

@fabian-s Looks good, left a couple of comments. Since you've adopted the glue package, one could consider using cli for warnings/errors, since it's also a vctrs dependency (https://github.com/r-lib/vctrs/blob/main/DESCRIPTION#L22). This would replace the `stop(glue(...))` pattern with `cli::cli_abort(...)`; see also https://github.com/r-lib/vctrs/blob/78d9f2b0b24131b5ce2230eb3c2c9f93620b10d9/R/type-misc.R#L33
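A minimal sketch of the suggested `cli::cli_abort()` pattern, which could also address the "better error messages" item raised for `x + x_i` earlier in the thread. It assumes cli (and rlang, which `cli_abort()` builds on) is installed; `check_same_arg()` is a hypothetical helper, not part of tf.

```r
# Hypothetical helper (not tf's actual code): replaces a bare stopifnot()
# failure with a structured cli error when two argument grids differ.
check_same_arg <- function(arg_x, arg_y) {
  if (!isTRUE(all.equal(arg_x, arg_y, check.attributes = FALSE))) {
    cli::cli_abort(c(
      "Functions must be evaluated on the same {.arg arg} grid.",
      "i" = "Left-hand side has {length(arg_x)} points, right-hand side has {length(arg_y)}."
    ))
  }
  invisible(TRUE)
}
```

Compared with `stop(glue(...))`, `cli_abort()` gives bulleted, styled messages and rlang error classes with no extra formatting code.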