library(dplyr)
library(tidyr)
library(ggplot2)
library(rstap)
library(plot3D)
# Simulation setup ----
# Fixed seed so every random draw below is reproducible; the order of the
# sample() calls matters for the RNG stream.
set.seed(25124)
num_subj_init <- 1.5E2  # number of subjects
num_bef_init <- 50      # number of coffee shops (built-environment features)

# Each subject's initial date is drawn uniformly (with replacement) from
# 1970-1999.
possible_dates <- seq(from = as.Date("1970/01/01"), as.Date("1999/12/31"), by = "day")
initial_dates <- sample(possible_dates, size = num_subj_init, replace = TRUE)

# Candidate date windows for the four measurement occasions.
pos_dates_1 <- seq(from = as.Date("2000/01/01"), as.Date("2000/12/31"), by = "day")
pos_dates_2 <- seq(from = as.Date("2002/01/01"), as.Date("2003/12/31"), by = "day")
pos_dates_3 <- seq(from = as.Date("2005/01/01"), as.Date("2006/12/31"), by = "day")
pos_dates_4 <- seq(from = as.Date("2008/01/01"), as.Date("2009/12/31"), by = "day")

# One measurement date per subject per occasion.
dates_1 <- sample(pos_dates_1, size = num_subj_init, replace = TRUE)
dates_2 <- sample(pos_dates_2, size = num_subj_init, replace = TRUE)
dates_3 <- sample(pos_dates_3, size = num_subj_init, replace = TRUE)
dates_4 <- sample(pos_dates_4, size = num_subj_init, replace = TRUE)

# Elapsed time (difftime, days) from each subject's initial date.
# NOTE(review): time_1 is identically zero (it does not use dates_1), and
# none of these four objects is used later in the visible script.
time_1 <- (initial_dates - initial_dates)
times_2 <- (dates_2 - initial_dates)
times_3 <- (dates_3 - initial_dates)
times_4 <- (dates_4 - initial_dates)

# Subjects: uniform location on [-1, 1] x [-1, 1] plus their initial date.
# tibble() replaces the deprecated data_frame(); columns are evaluated in
# order, so the runif() draws happen in the same sequence as before.
subj_data <- tibble(x = runif(n = num_subj_init, min = -1, max = 1),
                    y = runif(n = num_subj_init, min = -1, max = 1),
                    date = initial_dates,
                    subj_ID = seq_len(num_subj_init),
                    class = "Subject")

# Coffee shops: uniform location, an opening date sampled without
# replacement from the same 1970-1999 window, and no closing date (NA).
bef_data <- tibble(x = runif(n = num_bef_init, min = -1, max = 1),
                   y = runif(n = num_bef_init, min = -1, max = 1),
                   date_open = sample(possible_dates, size = num_bef_init),
                   date_close = as.Date(NA),
                   bef_ID = seq_len(num_bef_init),
                   class = "Coffee_Shop")



# Subject-level covariates, drawn once per subject: date of birth and a
# Bernoulli(0.5) sex indicator.
DOBS <- sample(seq(from = as.Date("1952/01/01"), as.Date("1969/12/31"), by = "day"),
               size = num_subj_init, replace = TRUE)
subj_sex <- rbinom(n = num_subj_init, size = 1, prob = .5)

# Long measurement table: one row per subject per occasion, stacked occasion
# by occasion (all subjects for occasion 1, then 2, ...). Subject covariates
# repeat on every row. measure_ID is kept as a double (1, 2, 3, 4).
subj_fdata <- bind_rows(
    Map(function(measure_dates, occasion) {
            tibble(subj_ID = seq_len(num_subj_init),
                   sex = subj_sex,
                   DOB = DOBS,
                   measure_date = measure_dates,
                   measure_ID = occasion)
        },
        list(dates_1, dates_2, dates_3, dates_4),
        c(1, 2, 3, 4))
)

# Subject location and initial date, one row per (subject, occasion).
subj_mdata <- subj_data %>%
    left_join(subj_fdata, by = "subj_ID") %>%
    select(x, y, class, subj_ID, date, measure_ID)

# Every coffee shop crossed with every (subject, occasion) row. After the
# join, the .x columns come from bef_data (shop) and the .y columns from
# subj_mdata (subject).
td_data <- bef_data %>%
    crossing(subj_fdata) %>%
    left_join(subj_mdata, by = c("subj_ID", "measure_ID")) %>%
    mutate(dist = sqrt((x.x - x.y)^2 + (y.x - y.y)^2),
           # Exposure window starts at the later of shop opening and the
           # subject's initial date.
           time_open = (measure_date - pmax(date_open, date)),
           time_close = date_close - pmax(date_open, date),
           # Days / 365 -> years; a missing close date is ignored.
           time = pmin(time_open, time_close, na.rm = TRUE) / 365,
           # Negative durations (shop not yet open) are floored at zero.
           time = as.numeric(time) * (time > 0),
           class = class.x) %>%
    select(subj_ID, measure_ID, bef_ID, measure_date,
           date_open, date_close, date, class, dist, time)

# True parameter values used to generate the outcome ----
alpha <- 25                 # fixed intercept
delta <- c(sex = .8)        # sex effect
beta <- c(Coffee_Shop = 2)  # NOTE(review): reassigned to 3 before it is used
theta_s <- .8               # spatial scale (divides distance in erfc())
theta_t <- 34               # temporal scale (divides time in erf())
sigma <- .5                 # residual standard deviation
tau <- 1.2                  # random-intercept standard deviation
tau_2 <- .3                 # random-slope standard deviation
rho <- .1                   # intercept/slope correlation
cov <- tau * tau_2 * rho
# Covariance of (random intercept, random slope). The covariance term above
# treats tau and tau_2 as standard deviations, so the diagonal must hold
# their squares. With the unsquared diagonal, the implied correlation was
# 0.06 rather than rho.
Sigma <- matrix(c(tau^2, cov, cov, tau_2^2), nrow = 2)
# Grids spanning the observed distances and (years of) exposure time, used
# only to visualize the exposure weight functions.
d <- seq(from = 0, to = max(td_data$dist), by = 0.01)
t <- seq(from = 0, to = as.numeric(max(td_data$time)), by = 0.05)
w_s <- pracma::erfc(d / theta_s)  # spatial weight, decreasing in distance
w_t <- pracma::erf(t / theta_t)   # temporal weight, increasing in time
par(mfrow = c(1, 2))
plot(d, w_s, type = "l", main = "Spatial Decay", xlab = "Distance", ylab = "Exposure")
plot(t, w_t, type = "l", main = "Temporal Accumulation", xlab = "years", ylab = "Exposure")
# Joint weight = spatial weight x temporal weight. outer() evaluates the
# product on the full grid (rows = distance, columns = time), replacing the
# previous 1:length() double loop.
st_exp <- outer(w_s, w_t)
persp3D(d, t, st_exp, phi = 15, theta = 45, xlab = "Distance",
        ylab = "Time (years)", zlab = "Exposure", ticktype = "detailed",
        main = "Spatial Temporal Exposure Across Space and Time")
# Total coffee-shop exposure per subject and occasion: sum over shops of
# spatial weight x temporal weight. Summing directly over
# (subj_ID, measure_ID) equals the old two-stage sum over date, and drops
# the gather()/filter() round trip that kept every row unchanged.
Xs <- td_data %>%
    group_by(subj_ID, measure_ID) %>%
    summarise(Exposure = sum(pracma::erfc(dist / theta_s) *
                                 pracma::erf(as.numeric(time) / theta_t))) %>%
    ungroup()

# Standardize with the mean and standard deviation pooled over all subjects
# and occasions. global_sd previously used mean(), which is not a scale
# estimate.
global_mn <- mean(Xs$Exposure)
global_sd <- sd(Xs$Exposure)
Xs <- Xs %>%
    spread(measure_ID, Exposure) %>%
    transmute(subj_ID = subj_ID,
              j_1 = (`1` - global_mn) / global_sd,
              j_2 = (`2` - global_mn) / global_sd,
              j_3 = (`3` - global_mn) / global_sd,
              j_4 = (`4` - global_mn) / global_sd)

# Back to long form: one row per subject and occasion, with the standardized
# exposure in a Coffee_Shop column and the occasion number parsed out of the
# j_<k> column names as an integer.
Xs_prep <- gather(Xs, key = "measure_ID", value = "Coffee_Shop",
                  num_range("j_", 1:4))
Xs_prep$measure_ID <- as.integer(sub("j_", "", Xs_prep$measure_ID, fixed = TRUE))

# Attach the exposure to the measurement table.
subj_fdata <- left_join(subj_fdata, Xs_prep, by = c("subj_ID", "measure_ID"))

# Disabled alternative kept for reference: independent random intercepts
# for 10 simulated replicates.
# b_s <- sapply(1:10, function(x) rnorm(n = num_subj_init,mean = 0,sd = tau))
# tmp <- as_data_frame(cbind(data_frame(subj_ID = 1:num_subj_init,
#                                    first_measure = dates_1),
#                         b_s))

# Correlated subject-level random intercepts (column 1) and slopes
# (column 2), drawn from a bivariate normal with mean 0 and covariance Sigma.
b_s <- MASS::mvrnorm(n = num_subj_init, mu = c(0, 0), Sigma = Sigma)
tmp <- tibble(subj_ID = seq_len(num_subj_init),
              ran_int = b_s[, 1],
              ran_slope = b_s[, 2],
              first_measure = initial_dates)

# Attach random effects to each measurement row. time_measure is the time
# since the subject's initial date in years (days / 365), further divided
# by 100. NOTE(review): the /100 looks like a rescaling choice; confirm it
# is intended.
subj_fdata <- subj_fdata %>%
    left_join(tmp, by = "subj_ID") %>%
    mutate(time_measure = (as.numeric((measure_date - first_measure)) / 365) / 100)

# Visual check: density of the standardized exposure at each occasion.
print(
    ggplot(gather(Xs, key = "measurement", value = "exposure",
                  num_range(prefix = "j_", 1:4)),
           aes(x = exposure, fill = measurement)) +
        geom_density(alpha = 0.3) +
        theme_bw()
)

# Build (without fitting) the mixed-model structures. Only lme_inf$fr (the
# model frame) is used afterwards, to assemble the design matrix.
# measure_ID appears to be a placeholder response, because the real outcome
# has not been simulated yet.
lme_inf <- lme4::glFormula(
    formula = measure_ID ~ + sex + Coffee_Shop + (1 + time_measure | subj_ID),
    data = subj_fdata,
    family = gaussian(link = "identity")
)

# True Coffee_Shop (stap) effect used to generate the outcome.
beta <- 3
# Design columns: intercept, sex, Coffee_Shop, random-intercept indicator
# (1), random-slope covariate. NOTE(review): assumes lme_inf$fr columns are
# (response, sex, Coffee_Shop, time_measure, subj_ID); confirm.
design_mat <- as.matrix(cbind(1, lme_inf$fr[, 2:3], 1, lme_inf$fr[, 4]))
#for multiple sims
# delta_beta_b <- lapply(1:10, function(x) cbind(alpha,delta,beta,pull(subj_fdata[,7+x])))
# y <- sapply(1:10, function(x) diag(tcrossprod(design_mat,as.matrix(delta_beta_b[[x]])))  +
#                 rnorm(n = nrow(design_mat),mean = 0,sd = sigma))

# Row-specific coefficients aligned with design_mat's columns: fixed effects
# are recycled, random effects vary by subject.
delta_beta_b <- cbind(alpha, delta, beta, subj_fdata$ran_int, subj_fdata$ran_slope)
# Outcome = row-wise linear predictor + residual noise. Previously the noise
# was added inside tcrossprod(), which perturbed the coefficients instead of
# the outcome. rowSums() also avoids building the n x n tcrossprod matrix,
# and it errors if the two matrices have different row counts rather than
# silently misaligning them.
y <- rowSums(design_mat * delta_beta_b) +
    rnorm(n = nrow(design_mat), mean = 0, sd = sigma)
hist(y)
# Attach the simulated outcome and build the distance/time inputs for
# stap_glmer. BMI, d_data and t_data were previously referenced without ever
# being defined in this script.
# NOTE(review): d_data / t_data are derived from td_data here; confirm the
# expected column layout against the installed rstap version.
stopifnot(length(y) == nrow(subj_fdata))
subj_fdata$BMI <- y
d_data <- td_data %>% select(subj_ID, measure_ID, class, dist)
t_data <- td_data %>% select(subj_ID, measure_ID, class, time)

# Fit the spatial-temporal aggregated predictor model with a random
# intercept and a random time_measure slope per subject.
fit <- stap_glmer(BMI ~ sex + stap(Coffee_Shop) + (1 + time_measure|subj_ID),
                  family = gaussian(link = 'identity'),
                  subject_data = subj_fdata,
                  distance_data = d_data,
                  subject_ID = 'subj_ID',
                  measure_ID = 'measure_ID',
                  time_data = t_data,
                  max_distance = max(d_data$dist),
                  max_time = max(t_data$time),
                  prior = normal(location = 0, scale = 4),
                  prior_intercept = normal(location = 25, scale = 5),
                  prior_stap = normal(location = 0, scale = 4),
                  prior_theta = list(Coffee_Shop = list(spatial = log_normal(1, 1),
                                                        temporal = log_normal(2, 1))),
                  prior_aux = cauchy(location = 0, scale = 5),
                  chains = 3,
                  cores = 3, iter = 6E2)
fit