Genentech / jmpost

https://genentech.github.io/jmpost/

Improve Stan code parser #97

Closed: gowerc closed this issue 11 months ago

gowerc commented 1 year ago

Currently the StanModule class parses Stan code and splits it out into a list of blocks, i.e. list(data="", parameters="", model=""). This parsing is done via the as_stan_fragments() function; however, it is very fragile and makes strong assumptions, such as that the start of each block must be on its own line. For example, the following would be invalid:

model{ x ~ normal(1, 2) }

as it is all on one line.

It would be good to update this function to make it more flexible and robust (extensive unit testing will be required to make sure it handles edge cases well).
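To make the list(data = "", parameters = "", model = "") representation concrete, here is a minimal sketch of the reverse direction: turning such a named list back into a Stan program. render_stan_blocks() is a hypothetical name, not jmpost's API. It assumes each element is a character vector of lines and emits blocks in the order Stan requires; a parser could then be unit tested by checking that parsing and re-rendering round-trips.

```r
# Hypothetical helper (not part of jmpost): render a named list of block
# contents back into Stan code. Blocks are emitted in Stan's required order
# and blocks that are missing or contain only whitespace are skipped.
render_stan_blocks <- function(blocks) {
    block_order <- c(
        "functions", "data", "transformed data", "parameters",
        "transformed parameters", "model", "generated quantities"
    )
    out <- character()
    for (name in block_order) {
        code <- blocks[[name]]
        if (is.null(code) || all(trimws(code) == "")) next
        out <- c(out, paste0(name, " {"), code, "}", "")
    }
    paste(out, collapse = "\n")
}

cat(render_stan_blocks(list(data = "    int n;", model = "    x ~ normal(0, 1);")))
```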

gowerc commented 1 year ago

Removed the backlog label, as this keeps causing subtle issues. It is worth double-checking whether any existing packages provide a parser.

gowerc commented 1 year ago

ChatGPT suggested the following alternative:

#' Convert a Stan Program to a Named List of Code Blocks
#'
#' This function reads a Stan program file and returns a named list of code blocks,
#' where each block is represented as a character vector of lines.
#'
#' @param stan_program A character string specifying the path to a Stan program file.
#'
#' @return A named list with each element corresponding to a code block in the
#' Stan program. The list names correspond to the block names, such as "functions",
#' "data", "transformed data", "parameters", "transformed parameters", "model",
#' and "generated quantities". Each element is a character vector containing the
#' lines of code in the respective block.
#'
#' @examples
#' \dontrun{
#' # Assuming you have a Stan program file called "example.stan"
#' stan_program_path <- "path/to/your/example.stan"
#' named_list <- stan_to_named_list(stan_program_path)
#' print(named_list)
#' }
#'
#' @export
stan_to_named_list <- function(stan_program) {
    stan_lines <- readLines(stan_program)
    stan_blocks <- c(
        "functions", "data", "transformed data", "parameters",
        "transformed parameters", "model", "generated quantities"
    )
    # A block header is a block name at the start of a line followed by "{"
    # or the end of the line (so e.g. "datasets" is not mistaken for "data")
    block_pattern <- paste0("^(", paste(stan_blocks, collapse = "|"), ")\\s*(\\{|$)")

    block_code <- list()
    current_block <- NULL

    for (line in stan_lines) {
        if (grepl(block_pattern, line)) {
            # The block name is everything before the opening brace
            current_block <- trimws(sub("\\{.*", "", line))
            block_code[[current_block]] <- character()

            # Check if the block starts and ends on the same line
            if (grepl("}", line, fixed = TRUE)) {
                code_line <- sub("^.*\\{\\s*", "", line)
                code_line <- sub("\\s*}.*$", "", code_line)
                block_code[[current_block]] <- append(block_code[[current_block]], code_line)
                current_block <- NULL
            }
        } else if (sub("\\s+$", "", line) == "}") {
            # An unindented closing brace (trailing whitespace allowed) ends the block
            current_block <- NULL
        } else if (!is.null(current_block)) {
            block_code[[current_block]] <- append(block_code[[current_block]], line)
        }
    }

    return(block_code)
}

I have not tested it yet.

gowerc commented 12 months ago

@danielinteractive - Frustratingly, there is already a fully developed Stan code parser provided by CmdStan, e.g.:

model.stan:

parameters { real mu_x; real<lower=0.00000001> sigma_x; }

/Users/gowerc/.cmdstan/cmdstan-2.33.1/bin/stanc --print-canonical ./model.stan

parameters {
  real mu_x;
  real<lower=0.00000001> sigma_x;
}

The canonical output is then trivial to parse. Unfortunately, though, stanc checks for correct syntax and won't accept our unresolved Jinja template files, e.g.:

/Users/gowerc/.cmdstan/cmdstan-2.33.1/bin/stanc --print-canonical ./inst/stan/base/base.stan
   -------------------------------------------------
     5:      //
     6:
     7:  {% if link_none %}
         ^
     8:      // If user has requested link_none then provide a dummy link_contribution function
     9:      // that does nothing
   -------------------------------------------------

We could potentially rework the code to ensure all Jinja template sections are fully resolved before attempting to parse the files. Skimming the code, I think there are only a handful of places that don't already resolve the Jinja code before handing off to StanModule().

danielinteractive commented 12 months ago

Sounds like a good idea!

gowerc commented 12 months ago

:( This might be a dead end. stanc also throws errors for undefined variables (e.g. variables defined in the base file which haven't yet been included), and I can't see an obvious way to suppress this. I'll post a question on the Stan forum and see if there is any guidance, but I think this use case might be too tangential; we might just have to resort to a "by-hand" parser with some notes on its limitations.

EDIT - Link to Stan forum question https://discourse.mc-stan.org/t/is-there-a-r-stan-code-parser/33244

gowerc commented 12 months ago

@danielinteractive - To recap the issue here: we currently have a basic working parser; however, for it to work, it assumes that each block is formatted as:

<name> {
    <code>
}

In particular, it assumes that the "data {" is on a separate line from the actual code that goes into the block. That is to say, a one-line block (or even multiple blocks on one line) won't be parsed even though they are valid Stan programs, e.g. the following fails:

data { int n;  array[n] real x; }
parameters { real mu; real sigma; } model { x ~ normal(mu, sigma); }

The issue is that parsing this purely by regex is incredibly challenging, as there are many pitfalls such as commented-out code and nested blocks. I'm struggling to work out how to do this without building an AST or going character by character, which would be a major hassle.
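For a sense of what "going character by character" would involve, here is a minimal sketch, assuming well-formed input. split_stan_blocks() is a hypothetical name. It tracks brace depth and ignores braces inside // and /* */ comments and "..." string literals (assumed to contain no escaped quotes). It does not validate syntax or handle deprecated # comments, so it is a sketch rather than a real parser.

```r
# Hypothetical sketch: split Stan code into top-level blocks by scanning
# characters and tracking brace depth. Braces in comments and strings are
# skipped; block bodies are returned verbatim (comments included), trimmed.
split_stan_blocks <- function(code) {
    chars <- strsplit(paste(code, collapse = "\n"), "")[[1]]
    n <- length(chars)
    blocks <- list()
    depth <- 0
    header <- character()  # non-comment characters seen at depth 0
    body_start <- NA_integer_
    i <- 1
    while (i <= n) {
        ch <- chars[i]
        nxt <- if (i < n) chars[i + 1] else ""
        if (ch == "/" && nxt == "/") {
            # Line comment: jump to the newline that ends it
            while (i <= n && chars[i] != "\n") i <- i + 1
            next
        }
        if (ch == "/" && nxt == "*") {
            # Block comment: jump past the closing "*/"
            i <- i + 2
            while (i < n && !(chars[i] == "*" && chars[i + 1] == "/")) i <- i + 1
            i <- i + 2
            next
        }
        if (ch == "\"") {
            # String literal (e.g. in print()): jump past the closing quote
            i <- i + 1
            while (i <= n && chars[i] != "\"") i <- i + 1
            i <- i + 1
            next
        }
        if (ch == "{") {
            if (depth == 0) {
                # Normalise whitespace so e.g. "transformed\ndata" still matches
                name <- trimws(gsub("\\s+", " ", paste(header, collapse = "")))
                body_start <- i + 1
                header <- character()
            }
            depth <- depth + 1
        } else if (ch == "}") {
            depth <- depth - 1
            if (depth == 0) {
                body <- if (i > body_start) paste(chars[body_start:(i - 1)], collapse = "") else ""
                blocks[[name]] <- trimws(body)
            }
        } else if (depth == 0) {
            header <- c(header, ch)
        }
        i <- i + 1
    }
    blocks
}

split_stan_blocks(c(
    "data { int n;  array[n] real x; }",
    "parameters { real mu; real sigma; } model { x ~ normal(mu, sigma); }"
))
```

Even this simplified version needs three separate skip rules, which illustrates why a hand-rolled scanner is a maintenance burden compared with reusing stanc.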

As far as I can see, we have 3 options:

gowerc commented 12 months ago

I guess for reference:

rstanarm solves this issue by splitting their file fragments per block, e.g. all code related to the data block is stored under the data folder.

brms sidesteps this issue by generating nearly all of its code within R rather than storing it in .stan files, the only exception being some reused functions that are stored in the inst directory.

gowerc commented 12 months ago

I must admit part of me likes keeping multiple blocks within a single file, as it's often highly related logic that would be a pain to separate and flick back and forth between across many files. I am tempted to say the easiest option would be some mini file-fragment syntax like:


// [data]
    int n;
    array [n] real x;

// [parameters]
    // this is then a regular comment
    real mu;
    real sigma;

// [model]
    x ~ normal(mu, sigma);

EDIT - The only annoyance with this is that ideally you would want to give it a different file extension, like .stanf, to signify that it's a fragment and not a proper Stan file, but by doing so you would lose auto-complete, syntax highlighting, etc. from IDEs.
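Part of the appeal of this syntax is that the parser becomes a simple line scan. A minimal sketch, assuming markers always take the form // [block name] on their own line (parse_stan_fragment() is a hypothetical name): each line after a marker is assigned to that block until the next marker, ordinary // comments pass through as code, and lines before the first marker are dropped.

```r
# Hypothetical parser for the proposed "// [block]" fragment syntax.
# A marker line starts a block; following lines belong to it until the next
# marker. Repeated markers for the same block append to it.
parse_stan_fragment <- function(lines) {
    marker <- "^\\s*//\\s*\\[([a-z ]+)\\]\\s*$"
    blocks <- list()
    current <- NULL
    for (line in lines) {
        if (grepl(marker, line)) {
            current <- sub(marker, "\\1", line)
            if (is.null(blocks[[current]])) blocks[[current]] <- character()
        } else if (!is.null(current)) {
            blocks[[current]] <- c(blocks[[current]], line)
        }
    }
    blocks
}
```

Because no braces are involved, one-line blocks, nested braces and commented-out braces stop being a concern, at the cost of the file no longer being a valid standalone Stan program.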

danielinteractive commented 12 months ago

Hm, I would go with the option "Just put in our docs that we don't support 1-line blocks and add some more error handling to flag these cases." I would also say that, intuitively, I would never normally use a 1-line block.
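The "more error handling" route could look something like the following minimal sketch. check_block_headers() is a hypothetical name, not an existing jmpost function. It assumes the line-based parser needs each header alone on its line as "<name> {", and it flags any line where a block name is followed by "{" but the line contains anything else. It will also flag a header mentioned inside a comment, so it errs on the side of false positives.

```r
# Hypothetical validation for the line-based parser: error on lines where a
# Stan block header ("<name> {") shares its line with other code, e.g.
# "data { int n; }" or "} model {". Reports the offending line numbers.
check_block_headers <- function(lines) {
    blocks <- c(
        "functions", "data", "transformed data", "parameters",
        "transformed parameters", "model", "generated quantities"
    )
    names_re <- paste(blocks, collapse = "|")
    any_header <- paste0("\\b(", names_re, ")\\s*\\{")
    valid_header <- paste0("^\\s*(", names_re, ")\\s*\\{\\s*$")
    bad <- which(grepl(any_header, lines, perl = TRUE) &
                 !grepl(valid_header, lines, perl = TRUE))
    if (length(bad) > 0) {
        stop(
            "Stan blocks must start on their own line, e.g. 'model {'. ",
            "Problem line(s): ", paste(bad, collapse = ", "),
            call. = FALSE
        )
    }
    invisible(TRUE)
}
```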

gowerc commented 12 months ago

For reference, I am just dumping my code that used the stanc parser in case for some reason we ever want to come back to this, but for now I am going to revert the branch back to main (😢) and then just add a few more tests to try to catch and throw errors on the known problem edge cases.

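as_canonical_stan <- function(x) {
    # Accept either a path to a .stan file or Stan code as a character vector;
    # code is written to a temporary file (removed on exit) for stanc to read
    if (!(length(x) == 1 && is_file(x))) {
        fi <- tempfile(fileext = ".stan")
        writeLines(x, fi)
        on.exit(unlink(fi), add = TRUE)
    } else {
        fi <- x
    }
    stanc_exe <- file.path(cmdstanr::cmdstan_path(), "bin", "stanc")
    # system2() warns on a non-zero exit status; the status is checked below
    suppressWarnings({
        result <- system2(
            command = stanc_exe,
            args = c("--print-canonical", shQuote(fi)),
            stderr = TRUE,
            stdout = TRUE
        )
    })
    # A "status" attribute is only attached when stanc exits with an error
    if (!is.null(attr(result, "status"))) {
        msg <- sprintf(
            "Unable to parse stan file. Got error message:\n\n%s",
            paste(result, collapse = "\n")
        )
        stop(msg)
    }
    result
}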
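For completeness, the "trivial to parse" step on stanc's canonical output could be a sketch like the one below. split_canonical_stan() is a hypothetical name. It assumes the layout seen in the --print-canonical example above (each block opens with "<name> {" at the start of a line and closes with an unindented "}"); that assumption has not been checked against every stanc version.

```r
# Hypothetical helper: split canonical stanc output (e.g. the character vector
# returned by as_canonical_stan()) into a named list of blocks, relying on the
# canonical layout of "<name> {" ... "}" at the start of a line.
split_canonical_stan <- function(lines) {
    blocks <- list()
    current <- NULL
    for (line in lines) {
        if (is.null(current) && grepl("^[a-z ]+ \\{$", line)) {
            current <- sub(" \\{$", "", line)
            blocks[[current]] <- character()
        } else if (!is.null(current) && line == "}") {
            current <- NULL
        } else if (!is.null(current)) {
            blocks[[current]] <- c(blocks[[current]], line)
        }
    }
    blocks
}

split_canonical_stan(c(
    "parameters {",
    "  real mu_x;",
    "  real<lower=0.00000001> sigma_x;",
    "}"
))
```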