JuliaSurv / NetSurvival.jl

A pure-Julia take on standard net survival routines
https://juliasurv.github.io/NetSurvival.jl/
MIT License

[Tests] Improve tests #44

Closed lrnv closed 4 months ago

lrnv commented 5 months ago

Fixes #43. Also improves the testing situation.

lrnv commented 4 months ago

@rimhajal these reformulations of the tests seem good to go from now on, BUT by adding a few more tests on GraffeoTest I uncovered an issue (the test is not exact)... See next message.

lrnv commented 4 months ago

So when running the different tests directly in the REPL, I get:

julia>     v1 = fit(GraffeoTest, @formula(Surv(time,status)~stage), colrec, slopop)
Grafféo's log-rank-type-test
1×3 DataFrame
 Row │ test_statistic  degrees_of_freedom  p_value      
     │ Float64         Int64               Float64      
─────┼──────────────────────────────────────────────────
   1 │        658.779                   3  1.81926e-142

julia>     v2 = GraffeoTest(colrec.time, colrec.status, colrec.age, colrec.year, colrec.sex, ones(length(colrec.age)), colrec.stage, slopop)     
Grafféo's log-rank-type-test
1×3 DataFrame
 Row │ test_statistic  degrees_of_freedom  p_value      
     │ Float64         Int64               Float64      
─────┼──────────────────────────────────────────────────
   1 │        658.779                   3  1.81926e-142

julia> 

julia>     v1_strat = fit(GraffeoTest, @formula(Surv(time,status)~stage+Strata(sex)), colrec, frpop)
Grafféo's log-rank-type-test
1×3 DataFrame
 Row │ test_statistic  degrees_of_freedom  p_value      
     │ Float64         Int64               Float64      
─────┼──────────────────────────────────────────────────
   1 │        871.499                   3  1.34567e-188

julia>     v2_strat = GraffeoTest(colrec.time, colrec.status, colrec.age, colrec.year, colrec.sex, colrec.sex, colrec.stage, slopop)
Grafféo's log-rank-type-test
1×3 DataFrame
 Row │ test_statistic  degrees_of_freedom  p_value      
     │ Float64         Int64               Float64      
─────┼──────────────────────────────────────────────────
   1 │        678.514                   3  9.56849e-147

julia>     R"""
           rez = relsurv::rs.diff(survival::Surv(time, stat) ~ stage, rmap=list(age = age, sex = sex, year = diag), data = relsurv::colrec, ratetable = relsurv::slopop)
           rez_strat = relsurv::rs.diff(survival::Surv(time, stat) ~ stage+survival::strata(sex), rmap=list(age = age, sex = sex, year = diag), data = relsurv::colrec, ratetable = relsurv::slopop)
           """
RObject{VecSxp}
Value of test statistic: 788.7504 
Degrees of freedom: 7
P value: 0

julia>     vR = @rget rez
OrderedCollections.OrderedDict{Symbol, Any} with 12 entries:
  :n         => [889, 393, 1361, 3328]
  :time      => [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0  …  8139.0, 8140.0, 8141.0, 8142.0, 8143.0, 8144.0, 8145.0, 8146.0, 8147.0, 8…
  :n_risk    => [889.0, 887.0, 885.0, 883.0, 878.0, 876.0, 872.0, 871.0, 871.0, 870.0  …  2.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
  :n_event   => [2.0, 2.0, 2.0, 5.0, 2.0, 4.0, 1.0, 0.0, 1.0, 2.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
  :n_censor  => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, 0.0, 0.0  …  -0.0, -0.0, -0.0, 1.0, -0.0, -0.0, -0.0, -0.0, -0.0, 1.0]
  :groups    => [8148, 8148, 8148, 8148]
  :call      => :((relsurv :: var"rs.diff")($(Expr(:(=), :formula, :((survival :: Surv)(time, stat) ~ stage))), $(Expr(:(=), :data, :(relsurv ::…
  :zh        => [-202.936 238.582 939.881]
  :covmat    => [2.95096e5 18772.0 12542.6; 18772.0 2730.18 891.519; 12542.6 891.519 1949.53]
  :test_stat => 658.602
  :p_value   => 0.0
  :df        => 3.0

julia> vR_strat = @rget rez_strat
OrderedCollections.OrderedDict{Symbol, Any} with 12 entries:
  :n         => [474, 415, 188, 205, 803, 558, 1824, 1504]
  :time      => [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0  …  8139.0, 8140.0, 8141.0, 8142.0, 8143.0, 8144.0, 8145.0, 8146.0, 8147.0, 8…
  :n_risk    => [474.0, 474.0, 474.0, 473.0, 469.0, 468.0, 466.0, 465.0, 465.0, 465.0  …  2.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
  :n_event   => [0.0, 0.0, 1.0, 4.0, 1.0, 2.0, 1.0, 0.0, 0.0, 1.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
  :n_censor  => [-0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, 0.0  …  -0.0, -0.0, -0.0, 1.0, -0.0, -0.0, -0.0, -0.0, -0.0, 1.0]
  :groups    => [8148, 8148, 8148, 8148, 8148, 8148, 8148, 8148]
  :call      => :((relsurv :: var"rs.diff")($(Expr(:(=), :formula, :((survival :: Surv)(time, stat) ~ stage + (survival :: strata)(sex)))), $(Ex…
  :zh        => [6.39563 -209.331 … 389.624 187.284]
  :covmat    => [1.25111e5 59465.8 … 2859.03 1.60839e5; 59465.8 51053.8 … 2323.36 1.28799e5; … ; 2859.03 2323.36 … 672.592 6061.77; 1.60839e5 1.…  
  :test_stat => 788.75
  :p_value   => 0.0
  :df        => 7.0

This shows that the unstratified version is OK: both Julia interfaces give a test statistic of 658.779 and R gives 658.602. On the stratified version, however, I get 871.499 and 678.514 (the two interfaces do not even return the same value), while the R version yields 788.75 (neither interface has the right value).

So there seem to be two issues: 1) The fit(GraffeoTest, ...) and GraffeoTest(...) syntaxes do not return the same thing, and I am too tired to find out why. 2) The way we understood and implemented the stratification of the test is probably wrong; we have to investigate a bit more.
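
To make the size of these gaps concrete, here is a minimal sketch that compares the statistics printed above with a relative tolerance. The helper name `agrees` and the `1e-3` threshold are assumptions chosen for illustration; they are not taken from the package's test suite.

```julia
# Test statistics copied from the REPL output above.
julia_unstrat, r_unstrat = 658.779, 658.602
julia_strat_fit, julia_strat_direct, r_strat = 871.499, 678.514, 788.75

# Hypothetical helper: relative-tolerance comparison (the threshold is an assumption).
agrees(a, b; rtol = 1e-3) = isapprox(a, b; rtol = rtol)

unstrat_ok = agrees(julia_unstrat, r_unstrat)                 # relative gap is about 2.7e-4
interfaces_ok = agrees(julia_strat_fit, julia_strat_direct)   # gap of roughly 193
strat_ok = agrees(julia_strat_direct, r_strat)                # gap of roughly 110

println((unstrat_ok, interfaces_ok, strat_ok))
```

Under that assumed tolerance, only the unstratified comparison passes; the two stratified comparisons fail by a wide margin, so they are not rounding noise.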

rimhajal commented 4 months ago

v1_strat = fit(GraffeoTest, @formula(Surv(time,status)~stage+Strata(sex)), colrec, frpop)

This one uses frpop and not slopop like the others. It is still failing...

lrnv commented 4 months ago

Indeed, I missed that... So now at least check_equal(v1_strat,v2_strat) is OK, but compare_with_R(v1_strat, vR_strat) is not, which means that there is something in our stratification that is not exactly the same as what R does. I have:

julia> vR_strat
OrderedCollections.OrderedDict{Symbol, Any} with 12 entries:      
  :n         => [474, 415, 188, 205, 803, 558, 1824, 1504]        
  :time      => [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.…  
  :n_risk    => [474.0, 474.0, 474.0, 473.0, 469.0, 468.0, 466.0,…  
  :n_event   => [0.0, 0.0, 1.0, 4.0, 1.0, 2.0, 1.0, 0.0, 0.0, 1.0…  
  :n_censor  => [-0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0,…  
  :groups    => [8148, 8148, 8148, 8148, 8148, 8148, 8148, 8148]  
  :call      => :((relsurv :: var"rs.diff")($(Expr(:(=), :formula…  
  :zh        => [6.39563 -209.331 … 389.624 187.284]
  :covmat    => [1.25111e5 59465.8 … 2859.03 1.60839e5; 59465.8 5…  
  :test_stat => 788.75
  :p_value   => 0.0
  :df        => 7.0

julia> v1_strat
Grafféo's log-rank-type-test
1×3 DataFrame
 Row │ test_statistic  degrees_of_freedom  p_value      
     │ Float64         Int64               Float64      
─────┼──────────────────────────────────────────────────
   1 │        678.514                   3  9.56849e-147

julia> 

In particular, we can extract the following from the test:

julia> Z = vR_strat[:zh]
1×7 Matrix{Float64}:
 6.39563  -209.331  110.493  128.089  550.257  389.624  187.284   

julia> V = vR_strat[:covmat]
7×7 Matrix{Float64}:
     1.25111e5  59465.8        …  2859.03       1.60839e5
 59465.8        51053.8           2323.36       1.28799e5
  6359.93        5096.28           228.628  13521.8
  4001.61        3314.19           140.109   8655.32
  4081.84        3278.41            84.955   8699.31
  2859.03        2323.36       …   672.592   6061.77
     1.60839e5      1.28799e5     6061.77       3.49675e5
julia> Z*inv(V)*(Z')
1×1 Matrix{Float64}:
 788.7504408090592
julia> 

while on the Julia side:

julia> Zjl =  dropdims(sum(v1_strat.∂Z, dims=(1,3)), dims=(1,3))  
4-element Vector{Float64}:
 -335.64878404730246
  934.6384524081541
  230.71202352424885
 -829.7016918850945

julia> Vjl = dropdims(sum(v1_strat.∂VZ, dims=(1,4)), dims=(1,4))  
4×4 Matrix{Float64}:
    1.05322e5   3177.89    4825.44      -1.13326e5
 3177.89        1481.09     189.151  -4848.13
 4825.44         189.151   1638.06   -6652.65
   -1.13326e5  -4848.13   -6652.65       1.24826e5

julia> Zjl'inv(Vjl) * Zjl
678.6491902935958

julia>

What really puzzles me is that R has 7 dimensions while Julia has only 4!
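
A side note on the Julia-side computation above: the entries of Zjl sum to zero (up to rounding) and the rows of Vjl sum to roughly zero, so the full 4×4 matrix is close to singular and inverting it directly is numerically fragile. Below is a minimal sketch of one common workaround, dropping one group before solving, on made-up toy numbers. The values of `Z` and `V` and all variable names are illustrative assumptions, not the package's internals.

```julia
using LinearAlgebra

# Toy score vector over k = 3 groups; its components sum to zero.
Z = [2.0, -0.5, -1.5]

# Toy covariance whose rows sum to zero, hence singular with rank k - 1.
V = [ 2.0 -1.0 -1.0;
     -1.0  2.0 -1.0;
     -1.0 -1.0  2.0]

# Drop the last group to obtain a full-rank (k - 1)-dimensional system.
Zr = Z[1:end-1]
Vr = V[1:end-1, 1:end-1]

# Quadratic form Zr' * inv(Vr) * Zr, computed with a linear solve instead of an explicit inverse.
stat = dot(Zr, Vr \ Zr)
dof = length(Zr)

# Cross-check against the pseudo-inverse of the full singular matrix.
stat_pinv = dot(Z, pinv(V) * Z)

println((stat, dof, stat_pinv))
```

With k groups this leaves k - 1 free components, which matches the 3 degrees of freedom reported for the 4 stages in the outputs above.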

rimhajal commented 4 months ago

In the older version of the code, we used the function join() and I am not sure the simpler version does this:

function StatsBase.fit(::Type{E}, formula::FormulaTerm, df::DataFrame, rt::RateTables.AbstractRateTable) where {E<:GraffeoTest}
    rate_predictors = String.([RateTables.predictors(rt)...])

    expected_columns = [rate_predictors...,"age","year"]
    missing_columns = filter(name -> !(name in names(df)), expected_columns)
    if !isempty(missing_columns)
        throw(ArgumentError("Missing columns in data: $missing_columns"))
    end

    strata = ones(nrow(df))
    group = ones(nrow(df))
    strata_terms = []
    group_terms = []

    if typeof(formula.rhs) == Term
        group = select(df, StatsModels.termvars(formula.rhs))
        group = [join(row, " ") for row in eachrow(group)]
    elseif typeof(formula.rhs) <: FunctionTerm{typeof(Strata)}
        strata = select(df, StatsModels.termvars(formula.rhs))
        strata = [join(row, " ") for row in eachrow(strata)]
    else
        for myterm in formula.rhs
            is_strata = typeof(myterm) <: FunctionTerm{typeof(Strata)}
            if is_strata
                append!(strata_terms, StatsModels.termvars(myterm))
            else
                push!(group_terms, Symbol(myterm))
            end
        end
    end

    if !isempty(group_terms)
        group = select(df, group_terms)
        group = [join(row, " ") for row in eachrow(group)]
    end

    if !isempty(strata_terms)
        strata = select(df, strata_terms)
        strata = [join(row, " ") for row in eachrow(strata)]
    end

    formula = apply_schema(formula,schema(df))
    resp = modelcols(formula.lhs,df)

    return GraffeoTest(resp[:,1], resp[:,2], df.age, df.year, select(df,rate_predictors), strata, group, rt)
end

lrnv commented 4 months ago

Looks like R considers groups to be "grouping variables × strata variables" while we did not. I just corrected it. I had to increase the tolerance a bit, but I think it is OK. Let's wait for the online tests to pass and then we can merge.

Sorry for the issue; indeed, it had to do with the rewriting of the join() calls that I did: I had forgotten a part of it. Glad we caught it :)
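
For readers following along, here is a minimal sketch of what crossing the grouping variable with the strata variable could look like, in the spirit of the join()-based labels quoted earlier. The toy vectors and variable names are assumptions for illustration and do not reproduce the actual fix.

```julia
# Toy data: every combination of 4 hypothetical stage codes and 2 hypothetical sex codes.
rows = vec([(stage, sex) for stage in 1:4, sex in 1:2])
stage = [r[1] for r in rows]
sex = [r[2] for r in rows]

# Grouping on stage alone gives one label per stage.
stage_only = string.(stage)

# Crossing stage with the stratum gives one label per (stage, sex) pair.
crossed = [join((g, s), " ") for (g, s) in zip(stage, sex)]

n_stage_only = length(unique(stage_only))
n_crossed = length(unique(crossed))

println((n_stage_only, n_crossed))
```

Under this reading, 4 stages crossed with 2 sexes give 8 groups, consistent with the 8 entries of :n and the 7 degrees of freedom in the R output above.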

rimhajal commented 4 months ago

I'm sorry, I should've added a test for the stratified version before; we would've caught this issue earlier on.

lrnv commented 4 months ago

One more reason to rework the testing suite!