MeasureTransport / MParT

Monotone Parameterization Toolkit (MParT): A core library for constructing and using transport maps.
https://measuretransport.github.io/MParT/
BSD 3-Clause "New" or "Revised" License
13 stars 4 forks source link

MParT Terminates Julia sessions for some dimension sizes #331

Closed dannys4 closed 1 year ago

dannys4 commented 1 year ago

I'm attaching the closest thing I have to an MWE (minimal working example).

using MParT, Random, Statistics, CxxWrap

# Fixed-seed RNG, passed explicitly to banana_ND below so the sample set is reproducible
rng = Xoshiro(2039482)

"""
    banana_ND(rng, dim, N_samples)

Draw `N_samples` samples of a `dim`-dimensional "banana" distribution, returned
as a `dim × N_samples` matrix (one sample per column).

Reference samples `Z ~ N(0, I)` are drawn from `rng` in a single `randn` call.
The output rows are `X_1 = Z_1` and `X_j = Z_j + Z_{j-1}^2` for `j ≥ 2`. Note that
row `j` uses the *reference* row `j-1`, not the transformed row.
"""
function banana_ND(rng::AbstractRNG, dim::Int, N_samples::Int)
    Z = randn(rng, dim, N_samples)
    X = similar(Z)
    # First coordinate is passed through unchanged
    X[1, :] .= view(Z, 1, :)
    # Every later coordinate is shifted by the square of the previous reference coordinate
    for row in 2:dim
        X[row, :] .= view(Z, row, :) .+ view(Z, row - 1, :) .^ 2
    end
    return X
end

"""
    centering(samples::AbstractMatrix; dims::Int=2)

Standardize `samples` along `dims`: subtract the mean and divide by the standard
deviation computed along that dimension. `Statistics.std` is used with its default
`corrected=true`, i.e. the `n - 1` normalization. With the default `dims=2`, each
row is standardized across its columns.

Any matrix type is accepted (views, transposes, etc.), not just `Matrix`.
A row/column with no spread divides by a (near-)zero standard deviation and may
produce non-finite values.
"""
function centering(samples::AbstractMatrix; dims::Int=2)
    samp_m = mean(samples; dims=dims)
    # Reuse the precomputed mean rather than having `std` recompute it
    samp_s = std(samples; dims=dims, mean=samp_m)
    return (samples .- samp_m) ./ samp_s
end

"""
    banana_ND_msets(dim::Int)

Build one `MultiIndexSet` per component of a `dim`-dimensional banana map.

Each set is built from an integer matrix whose rows are multi-indices and whose
columns correspond to the component's inputs:
- component 1 (one column) has a constant and a linear term;
- component `j ≥ 2` has `j` columns, all zero except the last two (`x_{j-1}`, `x_j`),
  which follow `general_mset` below.

Throws `ArgumentError` if `dim < 1`.
"""
function banana_ND_msets(dim::Int)
    # Without this check, dim < 1 fails later with an opaque BoundsError on msets[1]
    dim >= 1 || throw(ArgumentError("dim must be at least 1, got $dim"))
    msets = Vector{MultiIndexSet}(undef, dim)
    first_mset = [0;1;;] # 2×1 matrix: constant and linear terms

    # Pattern for the last two inputs (x_{j-1}, x_j) of every component j ≥ 2
    general_mset = [0 0; # constant
                    0 1; # linear in x_j
                    1 0; # linear in x_{j-1}
                    2 0] # quadratic in x_{j-1}
    mset_len = size(general_mset, 1) # number of multi-indices per component (j ≥ 2)
    msets[1] = MultiIndexSet(first_mset)
    for j in 2:dim
        # Zero exponents for x_1 .. x_{j-2}, then the general pattern
        msets[j] = MultiIndexSet([zeros(Int, mset_len, j - 2) general_mset])
    end
    return msets
end

# Set the dimension of the banana and number of training samples.
# Per the accompanying report, dimensions <= 5 train fine while 6, 7, 8 usually abort.
banana_dim, N_samples = 8, 20_000
# Collect samples from the banana and center them; `samps` is
# banana_dim × N_samples (one sample per column), standardized row-wise (dims=2).
samps = centering(banana_ND(rng,banana_dim,N_samples))
# Construct map options: Hermite-function basis with basisLB=-3, basisUB=3
map_opts = MapOptions(basisType="HermiteFunctions",basisLB=-3.,basisUB=3.)
# Construct msets: one multi-index set per map component
msets = banana_ND_msets(banana_dim)
# Create the components for the map
# StdVector necessary or else error is thrown (don't know why).
# NOTE(review): presumably the TriangularMap binding expects a CxxWrap StdVector
# rather than a Julia Vector of components — confirm against the MParT.jl wrapper.
comps = StdVector([CreateComponent(Fix(mset,true), map_opts) for mset in msets])
# Create map and objective
trimap = TriangularMap(comps)
obj = CreateGaussianKLObjective(samps)
## Train the map (verbose=true prints the optimizer settings seen in the report)
train_err = TrainMap(trimap, obj, TrainOptions(verbose=true))

If `banana_dim` is 5 or less, this runs fine. With 6, 7, or 8 it usually aborts and crashes without any useful information; the output looks something like this:

TrainMap: Initializing map coeffs to 1.
Optimization Settings:
Algorithm: Sequential Quadratic Programming (SQP) (local, derivative)
Optimization dimension: 26
Optimization stopval: -inf
Max f evaluations: 1000
Maximum time: inf
Relative x Tolerance: 0.0001
Relative f Tolerance: 0.001
Absolute f Tolerance: 0.001
terminate called after throwing an instance of 'std::runtime_error'
terminate called recursively

signal (6): Aborted
in expression starting at /home/dannys4/misc/mpart_examples/multidim_banana.jl:57
Allocations: 10025392 (Pool: 10020279; Big: 5113); GC: 9
Aborted (core dumped)

Every now and then I instead hit the familiar never-ending stream of "nans encountered in monotone integrand" errors. This is odd, because I usually only see those when I'm trying to invert the map.

dannys4 commented 1 year ago

Interestingly, the "same" script seems to work just fine in Python:

import os
# Set before importing mpart; presumably Kokkos reads this when the library
# initializes, so setting it after the import would have no effect — confirm.
os.environ['KOKKOS_NUM_THREADS'] = '2'
import mpart as mt
import numpy as np

# Seed NumPy's global RNG (used by banana_ND via np.random.randn) for reproducibility.
np.random.seed(1029394)

def banana_ND(dim, N_samples):
    """Sample a `dim`-dimensional banana distribution.

    Draws reference samples Z ~ N(0, I) of shape (dim, N_samples) from NumPy's
    global RNG and returns X of the same shape (one sample per column), where
    X_0 = Z_0 and X_j = Z_j + Z_{j-1}**2 for j >= 1. Row j uses the *reference*
    row j-1, not the transformed one.
    """
    z = np.random.randn(dim, N_samples)
    x = np.empty_like(z)
    x[0] = z[0]
    # Vectorized over rows: each row j >= 1 depends only on reference rows j and j-1
    x[1:] = z[1:] + z[:-1] ** 2
    return x

def centering(samples, axis=1):
    """Standardize `samples` along `axis`: subtract the mean, divide by the std.

    Args:
        samples: array of samples; with the default axis=1 each row is
            standardized across its columns.
        axis: axis along which the mean and standard deviation are computed.

    Returns:
        Array of the same shape as `samples`.

    Note: np.std defaults to ddof=0 (population std), unlike Julia's
    Statistics.std, which uses the n-1 correction by default.
    """
    # keepdims=True keeps the reduced axis so the statistics broadcast back
    # against `samples` for any `axis`; the previous `[:, None]` indexing was
    # only correct for axis=1.
    samp_m = np.mean(samples, axis=axis, keepdims=True)
    samp_s = np.std(samples, axis=axis, keepdims=True)
    return (samples - samp_m) / samp_s

def banana_ND_msets(dim):
    """Build one MultiIndexSet per component of a `dim`-dimensional banana map.

    Each set is built from an integer matrix whose rows are multi-indices and
    whose columns correspond to the component's inputs. Component 0 has a
    constant and a linear term. Component j >= 1 has j+1 columns, all zero
    except the last two (x_{j-1}, x_j), which follow `general_mset`.

    Raises:
        ValueError: if dim < 1.
    """
    if dim < 1:
        raise ValueError(f"dim must be at least 1, got {dim}")
    msets = []
    first_mset = np.array([[0], [1]])  # constant and linear terms

    general_mset = np.array([[0, 0],   # constant
                             [0, 1],   # linear in x_j
                             [1, 0],   # linear in x_{j-1}
                             [2, 0]])  # quadratic in x_{j-1}
    mset_len = general_mset.shape[0]  # number of multi-indices per component
    msets.append(mt.MultiIndexSet(first_mset))
    for j in range(1, dim):
        # dtype=int: np.zeros defaults to float64, which would upcast the whole
        # stacked multi-index matrix to floats.
        leading_zeros = np.zeros((mset_len, j - 1), dtype=int)
        mset_j = np.column_stack((leading_zeros, general_mset))
        msets.append(mt.MultiIndexSet(mset_j))
    return msets

if __name__ == '__main__':
    # Set the dimension of the banana and number of training samples
    banana_dim, N_samples = 8, 20000
    # Collect samples from the banana and standardize each dimension (row)
    samps = centering(banana_ND(banana_dim, N_samples))
    # Map options mirroring the Julia script: Hermite functions, basisLB=-3, basisUB=3.
    # They must be set on map_opts itself; assigning to mt.<name> only creates
    # unrelated module attributes and leaves map_opts at its defaults.
    map_opts = mt.MapOptions()
    map_opts.basisType = mt.BasisTypes.HermiteFunctions
    map_opts.basisLB = -3.
    map_opts.basisUB = 3.
    # One multi-index set per map component
    msets = banana_ND_msets(banana_dim)
    # Create all map components, the triangular map, and the objective
    comps = [mt.CreateComponent(mset.fix(True), map_opts) for mset in msets]
    trimap = mt.TriangularMap(comps)
    # NOTE(review): asfortranarray presumably matches the column-major layout the
    # bindings expect (cf. the outer-stride issue #332) — confirm.
    obj = mt.CreateGaussianKLObjective(np.asfortranarray(samps))
    # Start training from all-zero coefficients
    trimap.SetCoeffs(np.zeros(trimap.numCoeffs))
    ## Train the map
    train_opts = mt.TrainOptions()
    train_opts.verbose = True
    train_err = mt.TrainMap(trimap, obj, train_opts)
    print(f'Training error = {train_err}')
mparno commented 1 year ago

I'm guessing this is related to the segfaults caused by the outer-stride issue in #332.

dannys4 commented 1 year ago

I had a chance to test, and this seems to be resolved by the same fix as #332, so I'm going to link it to the same PR.