wbuchanan / crossvalidate

Stata package to implement cross-validation methods for statistical models
https://wbuchanan.github.io/crossvalidate

Stata Crossvalidation

The crossvalidate package includes several commands and a Mata library that provide a range of cross-validation techniques for use with any Stata estimation command that returns results in e(). For most users and use cases, the prefix commands (see xv and xvloo) should meet your needs. If you need something beyond the generic use cases, the package also includes lower-level commands that save you from coding the entire cross-validation process yourself. These commands are named after the four steps common to all cross-validation work: splitit, fitit, predictit, and validateit. A few utility commands handle the metaprogramming needed to apply these commands to the correct fold/split of the data.
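
The four lower-level commands can be chained by hand. The sketch below runs one train/test split step by step, using only the options listed under Phase Specific Commands. The names splitvar, ttres, and pred are illustrative, and our readings that split() names the variable holding the split assignments and that fitit takes the estimation command in double quotes are assumptions; check each command's help file before relying on them.

```stata
// Hypothetical sketch of the four steps for a single train/test split.
// Assumed: split() names the split-assignment variable, results() names the
// stored estimates, and pstub() is the stub for the prediction variables.
sysuse auto.dta, clear

// 1. splitit: send 80% of observations to training and 20% to test
splitit 0.8, split(splitvar)

// 2. fitit: fit the model on the training split
//    (the command is quoted here, which is our assumption about fitit)
fitit "reg price mpg length", split(splitvar) results(ttres)

// 3. predictit: create out-of-sample predictions using the stub pred
predictit, pstub(pred) split(splitvar)

// 4. validateit: compute the mean squared error on the test split
validateit, metric(mse) pstub(pred) split(splitvar)
```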

Lastly, our Mata library implements most of the validation metrics found in the R package yardstick, so you don't have to. If you need a metric we don't provide, you can write your own as a Mata function that follows the signature described under Method Signature below. You can then use it with the existing tools by passing the function's name to the metric or monitors option of either prefix command or of validateit; the package handles the rest.

Examples:

// Load example dataset
sysuse auto.dta, clear

// Simple train/test (TT) split
xv 0.8, pstub(ttpred) metric(mse): reg price mpg length

// Simple train/validation/test (TVT) split
xv 0.6 0.2, pstub(tvtpred) metric(mse) monitors(mape smape): reg price mpg length

// Leave-One-Out cross-validation with a train/test split
xvloo 0.8, pstub(ttloopred) metric(mse): reg price mpg length

// LOO TVT split
xvloo 0.6 0.2, pstub(tvtloopred) metric(mse): reg price mpg length

// K-Fold TT split
xv 0.8, pstub(ttkfpred) metric(mae) kfold(5): reg price mpg length if !mi(rep78)

// K-Fold TVT split
xv 0.6 0.2, pstub(tvtkfpred) metric(mbe) kfold(3): reg price mpg length, vce(rob)

// Clustered K-Fold TT Split
xv 0.8, metric(phl) uid(rep78) kfold(4) display retain: reg price mpg length, vce(rob)

TODO

libxv

Metrics/Monitors

Method Signature

The package allows users to define their own metrics/monitors that are not included in libxv. To do this, users must write a Mata function with the following signature:

real scalar metric(string scalar pred, string scalar obs, 
                   string scalar touse, | transmorphic matrix opts)

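The function must return a real scalar. It takes three required arguments and one optional argument. The required arguments pred, obs, and touse are the names of the variables holding the predicted values, the observed values, and the indicator marking which observations to use; together they give the function access to the data needed to compute the metric/monitor. The optional argument opts passes additional arguments to the underlying function when the metric supports them.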
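
To show how the optional fourth argument can be handled, here is a hypothetical metric (mean pseudo-Huber loss) written to the signature above. The function name mypseudohuber is illustrative, and we assume opts arrives as a real scalar holding the tuning constant delta; when opts is not supplied, args() returns 3 and delta falls back to 1.

```stata
mata:
// Hypothetical custom metric: mean pseudo-Huber loss.
// Assumed: opts, when passed, is a real scalar giving delta (default 1).
real scalar mypseudohuber(string scalar pred, string scalar obs,
                          string scalar touse, | transmorphic matrix opts)
{
    real colvector yhat, y
    real scalar delta

    // Use the default tuning constant when opts was not passed
    if (args() < 4) delta = 1
    else delta = opts

    // Read predicted and observed values for the flagged observations
    yhat = st_data(., pred, touse)
    y = st_data(., obs, touse)

    // Average of delta^2 * (sqrt(1 + ((y - yhat) / delta)^2) - 1)
    return(mean(delta^2 :* (sqrt(1 :+ ((y - yhat) :/ delta):^2) :- 1)))
}
end
```

Checking args() before touching opts keeps the function safe to call with only the three required arguments.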

Data access

Within the function body, we recommend using the following pattern to access the data needed to compute any metrics/monitors:

real colvector yhat, y
yhat = st_data(., pred, touse)
y = st_data(., obs, touse)

The crossvalidate commands construct these variables and pass their names to the function you specify in the metric or monitors option.
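
Putting the signature and the data-access pattern together, the sketch below defines a hypothetical maximum absolute error metric and then passes its name to xv, as described in the introduction. The function name maxabserr and the stub maxpred are illustrative; the function must be defined in Mata before xv is called.

```stata
mata:
// Hypothetical custom metric: maximum absolute error.
// opts is accepted to match the required signature but is not used.
real scalar maxabserr(string scalar pred, string scalar obs,
                      string scalar touse, | transmorphic matrix opts)
{
    real colvector yhat, y

    // Pull predictions and observed outcomes for the flagged observations
    yhat = st_data(., pred, touse)
    y = st_data(., obs, touse)

    // Largest absolute prediction error in the sample
    return(max(abs(y - yhat)))
}
end

// Pass the function name to metric() just like a built-in metric
sysuse auto.dta, clear
xv 0.8, pstub(maxpred) metric(maxabserr) monitors(mse): reg price mpg length
```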

Building the library

To build the Mata library, run the following in Stata 15:

// Clear everything out of Mata
mata: mata clear 

// Define all of the Mata functions in memory
run crossvalidate.mata

// Build libxv.mlib in the current working directory (replace overwrites an existing copy)
lmbuild libxv, replace dir(`"`c(pwd)'"')
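
After building, a quick sanity check with standard Mata commands (these are part of Stata, not the crossvalidate package) confirms that Stata can find the new library and lists the functions it contains:

```stata
// Refresh Mata's library index so the new libxv.mlib is picked up
mata: mata mlib index

// List the functions stored in the freshly built library
mata: mata describe using libxv
```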

Prefix Commands

xv

xv # [#], MEtric(string asis) [seed(integer) Uid(varlist) TPoint(string asis)
             SPLit(string asis) KFold(integer) RESults(string asis) 
             fitnm(string asis) Classes(integer) PStub(string asis) noall 
             MOnitors(string asis) DISplay RETain valnm(string asis)
             PMethod(string asis) POpts(string asis)] : 
             estimation command ...

Syntax and options
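
As an illustration of options not used in the examples above, the call below fixes seed() so that the split is reproducible across runs and adds the display option. We read xv 0.6 0.2 as a 60% train / 20% validation / 20% test split, following the TVT example earlier; the seed value and the stub seedpred are arbitrary.

```stata
// Illustrative call: reproducible TVT split via seed(), MSE as the
// validation metric, and MAE and MAPE reported as monitors
sysuse auto.dta, clear
xv 0.6 0.2, seed(7779311) pstub(seedpred) metric(mse) monitors(mae mape) display: reg price mpg length
```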

xvloo

xvloo # [#], MEtric(string asis) [ seed(integer) Uid(varlist) TPoint(string asis)
             SPLit(string asis) RESults(string asis) fitnm(string asis) 
             Classes(integer) PStub(string asis) noall MOnitors(string asis) 
             DISplay RETain valnm(string asis) PMethod(string asis)
             POpts(string asis) ] : 
             estimation command ...

Syntax and options
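
The same options carry over to xvloo. This illustrative call runs leave-one-out cross-validation with a train/test split, fixes the seed, and reports monitors alongside the metric; the seed value and the stub loopred are arbitrary.

```stata
// Illustrative call: LOO with an 80/20 train/test split, reproducible via seed()
sysuse auto.dta, clear
xvloo 0.8, seed(7779311) pstub(loopred) metric(mae) monitors(mse mape): reg price mpg length
```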

Phase Specific Commands

splitit

splitit # [#] [if] [in] [, Uid(varlist) TPoint(string asis) KFold(integer 1) 
                           SPLit(string asis) loo ]

Syntax and options
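
A minimal sketch of calling splitit on its own and inspecting the result. We assume split() names the variable that stores each observation's assignment and that, as with the K-Fold TT example for xv, 0.8 with kfold(5) holds out 20% for testing and divides the rest into five folds; the variable name fivefold is illustrative.

```stata
// Illustrative: 80/20 train/test split with 5 folds in the training data;
// split() is assumed to name the variable that records the assignments
sysuse auto.dta, clear
splitit 0.8, kfold(5) split(fivefold)

// Check how many observations landed in each fold/split
tabulate fivefold, missing
```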

fitit

fitit anything(name = cmd), SPLit(passthru) RESults(string asis) 
            [ KFold(integer 1) noall DISplay NAme(string asis) ]

Syntax and options

predictit

predictit [anything(name = cmd)], PStub(string asis) 
            [ SPLit(passthru) Classes(integer 0) 
              KFold(integer 1) THReshold(passthru) 
              MODifin(string asis) KFIfin(string asis) noall 
              PMethod(string asis) POpts(string asis) ]

Syntax and options

validateit

validateit, MEtric(string asis) PStub(string asis) SPLit(string asis) 
          [ Obs(string asis) MOnitors(string asis) DISplay KFold(integer 1) 
            noall loo NAme(string asis) ]

Syntax and options

Utility commands

classify

classify # [if], PStub(string asis) [ THReshold(real 0.5) PMethod(string asis) 
                                      POpts(string asis) ]

Syntax and options

state

state

Syntax and options

No options

cmdmod

cmdmod anything(name = cmd id = "estimation command"), 
        SPLit(varlist min = 1 max = 1) [ KFold(integer 1) ]

Syntax and options

libxv

libxv [, DISplay ]

Syntax and options