StanJulia / StanSample.jl

WIP: Wrapper package for the sample method in Stan's cmdstan executable.
MIT License

Generic interface for read_samples and read_csv_files #39

Closed: itsdfish closed this issue 3 years ago

itsdfish commented 3 years ago

Hi Rob,

I wanted to use this package to read Stan output files for the LOO and WAIC tests. Unfortunately, this required a model object. It might be useful to have two methods: one based on the model object, for convenience, and a generic method that does not require any special object.

I tried to split the methods up, but ran into some issues that I could not resolve easily. Here is what I started on. The function convert_a3d threw an error because I passed start=1001.

using Unicode  # needed for Unicode.normalize below

function read_csv_files(model::SampleModel, output_format=:namedtuple;
  include_internals=false, kwargs...)
  # File path components of sample files (missing the "_$(i).csv" part)
  output_base = model.output_base
  name_base ="_chain"

  # How many samples?
  if model.method.save_warmup
      n_samples = floor(Int,
      (model.method.num_samples+model.method.num_warmup)/model.method.thin)
  else
      n_samples = floor(Int, model.method.num_samples/model.method.thin)
  end

  # How many chains?
  n_chains = model.n_chains[1]
  file_names = map(c->output_base*name_base*"_$(c).csv", 1:n_chains)
  return read_csv_files(file_names, n_samples, output_format; include_internals, kwargs...)
end

function read_csv_files(file_names, n_samples, output_format=:namedtuple;
      include_internals=false, kwargs...)

    local a3d, monitors, index, idx, indvec, ftype, noofsamples
    n_chains = length(file_names)
    # First samples number returned
    start = (:start in keys(kwargs)) ? values(kwargs).start : 1
    # Read .csv files and return a3d[n_samples, parameters, n_chains]
    for (i,file_name) in enumerate(file_names)
        if isfile(file_name)
            instream = open(file_name)

            # Skip initial set of commented lines, e.g. containing cmdstan version info, etc.      
            skipchars(isspace, instream, linecomment='#')

            # First non-comment line contains names of variables
            line = Unicode.normalize(readline(instream), newline2lf=true)
            idx = split(strip(line), ",")
            index = [idx[k] for k in 1:length(idx)]      
            indvec = 1:length(index)
            n_parameters = length(indvec)

            # Allocate a3d as we now know number of parameters
            if i == 1
                a3d = fill(0.0, n_samples, n_parameters, n_chains)
            end

            skipchars(isspace, instream, linecomment='#')
            for j in 1:n_samples
                skipchars(isspace, instream, linecomment='#')
                line = Unicode.normalize(readline(instream), newline2lf=true)
                if eof(instream) && length(line) < 2
                    break
                else
                    flds = parse.(Float64, split(strip(line), ","))
                    flds = reshape(flds[indvec], 1, length(indvec))
                    a3d[j,:,i] = flds
                end
            end   # read in samples
            close(instream)   # always release the file handle, not only at eof
        end   # read in next file if it exists
    end   # read in file for each chain

    cnames = convert.(String, idx[indvec])
    if include_internals
        snames = [Symbol(cnames[i]) for i in 1:length(cnames)]
        indices = 1:length(cnames)
    else
        # Stan's internal columns end in "__" (e.g. lp__); avoid shadowing Base.pi
        internals = filter(p -> length(p) > 2 && endswith(p, "__"), cnames)
        snames = filter(p -> !(p in internals), cnames)
        indices = Vector{Int}(indexin(snames, cnames))
    end 

    res = convert_a3d(a3d[:, indices, :], snames, Val(output_format); kwargs...)

    (res, snames) 
end

goedman commented 3 years ago

I'll have a look. Checking the start argument has been on my todo list for a long time. Would you like to read multiple csv files?
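
As a rough sketch of what checking the start argument could involve (this is not how convert_a3d is implemented, and `drop_initial_draws` is a hypothetical helper), the draws before `start` can be sliced off the first dimension of the samples × parameters × chains array, with an explicit error when `start` exceeds the number of stored draws, for example start=1001 against 1000 saved samples:

```julia
# Hypothetical helper, not part of StanSample.jl: keep draws start:end
# from an array laid out as a3d[n_samples, n_parameters, n_chains].
function drop_initial_draws(a3d::AbstractArray{<:Real,3}, start::Integer=1)
    n_samples = size(a3d, 1)
    if !(1 <= start <= n_samples)
        throw(ArgumentError("start=$start is outside 1:$n_samples"))
    end
    return a3d[start:end, :, :]
end

a3d = reshape(collect(1.0:24.0), 4, 3, 2)   # 4 draws, 3 parameters, 2 chains
kept = drop_initial_draws(a3d, 3)           # keeps draws 3 and 4 of each chain
```

Failing early with a clear message is a design choice here; silently returning an empty array would be the other option.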

itsdfish commented 3 years ago

I needed to read multiple files for the WAIC/LOO tests. However, it might be nice to have the option to pass either an array of file names or a single file name. I think the easiest way to support both use cases would be to dispatch on a String and wrap it in a vector:

read_csv_files(file_name::String, n_samples, output_format=:namedtuple;
      include_internals=false, kwargs...) = read_csv_files([file_name], n_samples, output_format;
      include_internals, kwargs...)

One advantage of this approach is that it would not require modifying your original function to accommodate the single-file case.

Personally, I wish Stan would save its metadata in a separate file so that we could load chains easily. I'm sure you have a similar opinion.

goedman commented 3 years ago

Yes, that is a wish that's been around for a long, long time. I haven't really looked at the JSON format.

Do you mean n_samples or n_chains in the above signature? :namedtuple is the default output_format, and it collects indexed parameters such as a.1, a.2, etc. into arrays.
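
To illustrate what that grouping means, here is a minimal sketch. It is not the package's convert_a3d code: `group_columns` is a hypothetical helper, it only handles one-dimensional indices such as a.1, a.2, and the array layout it returns is an assumption rather than the layout StanSample.jl uses.

```julia
# Hypothetical helper: group columns named "a.1", "a.2", ... under the key :a.
# draws is n_samples × n_columns; colnames holds one Stan column name per column.
function group_columns(draws::AbstractMatrix, colnames::Vector{String})
    bases = unique(first.(split.(colnames, ".")))
    vals = map(bases) do base
        cols = findall(n -> first(split(n, ".")) == base, colnames)
        # Scalars become a vector, indexed parameters an n_samples × k matrix.
        length(cols) == 1 ? draws[:, cols[1]] : draws[:, cols]
    end
    return NamedTuple{Tuple(Symbol.(bases))}(Tuple(vals))
end

draws = [1.0 2.0 0.5; 3.0 4.0 0.7]          # 2 draws of a.1, a.2, sigma
nt = group_columns(draws, ["a.1", "a.2", "sigma"])
```

With this toy input, `nt.a` holds both a columns and `nt.sigma` is a plain vector of draws.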

itsdfish commented 3 years ago

I passed n_samples to determine the number of rows and used the length of file_names to infer the number of chains.

goedman commented 3 years ago

I've added a function (not exported), StanSample.read_csv(), to StanSample.jl. Basically you provide a base path, n_chains, n_samples (for each chain) and the output_format (by default :namedtuple). But there are many roads that lead to Rome; yours in the test script also works.

itsdfish commented 3 years ago

Thanks, Rob. I think this will be useful for others as well. My approach to reading the files was very inelegant: I manually removed Stan's metadata and used CSV to read in the files. Although it works, it is not ideal if we need to change the test. I will update the test accordingly later.
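
For reference, the manual step of removing Stan's metadata can be sketched in plain Julia. This assumes, as in the reader earlier in this thread, that metadata lines start with '#', that the first remaining line is the header, and that all values are numeric. `read_stan_csv` is a hypothetical name, not a StanSample.jl function.

```julia
# Hypothetical helper: read one Stan CSV file, ignoring '#' comment lines
# wherever they appear (top of file, after warmup, or at the end).
function read_stan_csv(path::AbstractString)
    lines = filter(l -> !isempty(l) && !startswith(l, "#"), strip.(readlines(path)))
    colnames = String.(split(lines[1], ","))
    draws = zeros(length(lines) - 1, length(colnames))
    for (i, line) in enumerate(lines[2:end])
        draws[i, :] = parse.(Float64, split(line, ","))
    end
    return colnames, draws
end

# Tiny made-up file to exercise the helper.
path = tempname()
write(path, "# model = demo\nlp__,theta\n# Adaptation terminated\n-1.5,0.2\n-1.2,0.3\n# Elapsed Time\n")
colnames, draws = read_stan_csv(path)
```

Because the comment lines are filtered at read time, the original output files stay untouched and the test does not depend on hand-edited copies.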

goedman commented 3 years ago

Hi Chris, I definitely need to do more testing and also check whether the start argument works. Once convinced, I would prefer to use this code in the current read_csv_files as well. I also have an option on my list to read selected chains, e.g. 1, 3 and 4.
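
Selecting chains could be as simple as building the file list from the requested chain numbers. The sketch below reuses the output_base * "_chain" * "_$(c).csv" naming from the code earlier in this thread. `chain_files` is a hypothetical helper, not an existing StanSample.jl function, and the base path is made up.

```julia
# Hypothetical helper: file names for selected chains, e.g. chains 1, 3 and 4.
function chain_files(output_base::AbstractString, chains; name_base="_chain")
    return [output_base * name_base * "_$(c).csv" for c in chains]
end

files = chain_files("tmp/bernoulli", [1, 3, 4])   # made-up base path
```

The resulting vector could then be passed to a file-name based reader such as the generic read_csv_files method proposed above.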

Will also take a closer look at what tests we have in place now and how your test results compare with R's loo.