rmcelreath / rethinking

Statistical Rethinking course and book package

How to construct a spline model using ulam() #415

Open JuneXiLiang opened 7 months ago

JuneXiLiang commented 7 months ago

Hi, I am trying to construct a spline model using ulam(). Here is my R code:

library(rethinking)
library(splines)
data(cherry_blossoms)
d <- cherry_blossoms
d2 <- d[complete.cases(d$doy),]

B <- bs(d2$year,df=13,degree=3,intercept=TRUE)
m4.7 <- ulam(
    alist(
        D ~ dnorm(mu,sigma),
        mu <- a + B %*% w,
        a ~ dnorm(100,10),
        vector[n]:w ~ dnorm(0,10),
        sigma ~ dexp(1)
    ),
    data = list(D=d2$doy,B=B)
)

It does not run properly. I think this is caused by an incorrect translation to Stan code, which is:

data{
     int n;
     vector[827] D;
     matrix[827,13] B;
}
parameters{
     real a;
     vector[n] w;
     real<lower=0> sigma;
}
model{
     vector[10751] mu;
    sigma ~ exponential( 1 );
    w ~ normal( 0 , 10 );
    a ~ normal( 100 , 10 );
    for ( i in 1:10751 ) {
        mu[i] = a + B[i] * w;
    }
    D ~ normal( mu , sigma );
}

ulam() does not seem to translate %*% correctly. mu should be an 827x1 vector, but in the Stan code it is declared as a 10751x1 vector (10751 = 827 x 13, the total number of elements in B).
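To make the size mismatch concrete, here is a base-R sketch of the shapes involved. It assumes 827 simulated years in place of d2$year (loading cherry_blossoms requires the rethinking package), so the data are hypothetical; only the dimensions matter here.

```r
# Shape check for a cubic B-spline basis with 13 columns.
# Assumption: 827 simulated years stand in for d2$year.
library(splines)

set.seed(1)
year <- sort(runif(827, min = 812, max = 2015))  # 827 rows, like d2

B <- bs(year, df = 13, degree = 3, intercept = TRUE)
w <- rnorm(ncol(B))   # one weight per basis function
a <- 100

dim(B)                # 827 13
length(B)             # 10751 elements in total (827 * 13)

mu <- a + B %*% w     # (827 x 13) %*% (13) gives 827 x 1
dim(mu)               # 827 1
```

So the expected length of mu is nrow(B), while 10751 is length(B).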

How can I fix this?

Thanks for your help.

wesleyburr commented 7 months ago

vector[n]:w ~ dnorm(0,10)

I don't see a definition of n here. You want vector[13]:w ~ dnorm(0, 10), right?

JuneXiLiang commented 7 months ago

Yes, I forgot to define n, but that is not the core problem.

I changed n to 13, and it still throws an exception saying an element is out of range.

I think the problem is that ulam() does not translate %*% correctly: mu should be a vector[827] rather than a vector[10751].
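The out-of-range exception is consistent with the generated loop: in Stan, B[i] on a matrix selects row i, so a loop running to 10751 requests rows that do not exist once i passes 827. A base-R emulation follows, with a random 827 x 13 matrix standing in for the spline basis (hypothetical data) and B[i, ] playing the role of Stan's B[i].

```r
# Base-R emulation of the generated Stan loop (a sketch with stand-in data).
# Stan's B[i] on a matrix returns row i; the R equivalent is B[i, ].
set.seed(2)
B <- matrix(runif(827 * 13), nrow = 827, ncol = 13)  # hypothetical basis matrix
w <- rnorm(13)
a <- 100

# Correct bound: one mu per row of B, i.e. 827 iterations.
mu_rows <- numeric(nrow(B))
for (i in seq_len(nrow(B))) {
  mu_rows[i] <- a + sum(B[i, ] * w)  # row times weight vector = dot product
}
all.equal(mu_rows, as.vector(a + B %*% w))  # TRUE

# Bound used in the generated code: length(B) = 10751 iterations.
# Row 828 does not exist, so the loop stops with an error.
res <- tryCatch({
  mu_bad <- numeric(length(B))
  for (i in seq_len(length(B))) {
    mu_bad[i] <- a + sum(B[i, ] * w)
  }
  "no error"
}, error = function(e) conditionMessage(e))
res  # "subscript out of bounds"
```

R reports "subscript out of bounds" where Stan reports an index out of range, but the cause is the same loop bound.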

wesleyburr commented 7 months ago

This was the point of my previous ticket for Richard. ulam does recognize %*%, but the issue is that the inputs aren't initialized in a way that is going to make sense. If B is [827, 13] and your input has B %*% w, then you need to ensure that w is [13, 1] so that the matrix product drops a dimension as expected and ends up with 827 elements. Right now, the non-initialized version (at least your first bit of code) is:

  • not initializing n, so it's just a floating variable; unless you pass it in as part of your data = list(...) argument, it's not going to do anything
  • thus not actually getting a properly defined w
  • thus the matrix product is getting translated to a loop over 827 * 13 = 10751 elements instead of only 827

You could try the previous code with vector[13]:w ~ dnorm(0, 10), and also add n = 13 as an input in your data = list(...) argument.

JuneXiLiang commented 7 months ago

I understand what you mean. However, even after defining n, ulam() still does not translate %*% correctly.

Here is my new R code:

library(rethinking)
library(splines)
data(cherry_blossoms)
d <- cherry_blossoms
d2 <- d[complete.cases(d$doy),]
B <- bs(d2$year,df=13,degree=3,intercept=TRUE)
m4.7 <- ulam(
    alist(
        D ~ dnorm(mu,sigma),
        mu <- a + B %*% w,
        a ~ dnorm(100,10),
        vector[13]:w ~ dnorm(0,10),
        sigma ~ dexp(1)
    ),
    data = list(D=d2$doy,B=B)
)

The corresponding Stan code is:

data{
     vector[827] D;
     matrix[827,13] B;
}
parameters{
     real a;
     vector[13] w;
     real<lower=0> sigma;
}
model{
     vector[10751] mu;
    sigma ~ exponential( 1 );
    w ~ normal( 0 , 10 );
    a ~ normal( 100 , 10 );
    for ( i in 1:10751 ) {
        mu[i] = a + B[i] * w;
    }
    D ~ normal( mu , sigma );
}

I also tried matrix[13,1]:w ~ dnorm(0,10) and start = list(w=rep(0,13)), without success.
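In plain R, a length-13 vector and a 13 x 1 matrix give the same result in B %*% w, so neither change to w would be expected to alter the size of mu. The sketch below checks this with a random stand-in matrix (hypothetical values). Together with the identical 10751 loop bound in both generated programs, it suggests that ulam sizes mu from the element count of B rather than from nrow(B). That is an inference from the generated code, not something confirmed from ulam's source.

```r
# Does the shape of w change B %*% w in plain R? (stand-in data, hypothetical)
set.seed(3)
B <- matrix(runif(827 * 13), nrow = 827, ncol = 13)
w <- rnorm(13)

w_vec <- w                               # like vector[13]:w
w_mat <- matrix(w, nrow = 13, ncol = 1)  # like matrix[13,1]:w

p1 <- B %*% w_vec
p2 <- B %*% w_mat
dim(p1)            # 827 1
dim(p2)            # 827 1
all.equal(p1, p2)  # TRUE
```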