####################################################################################
#                                                                                  # 
#  Validity checks as presented in the section 'Verification of model assumptions  #
#  on the basis of real data'                                                      #
#                                                                                  # 
####################################################################################


# Load the dataset 'BreastCancerConcatenation'.
# NOTE: The Rda-file loaded here is NOT included in the electronic appendix. To obtain it,
# one has to run the R-script 'BreastCancerConcatenation_preparationinfos.R' found in the folder
# 'FAbatchPaper/Datasets/PreparationScripts'.
# The path is relative, so the working directory has to be the folder that contains 'FAbatchPaper'.
# The code further below uses the objects 'X', 'y' and 'batch', which are never created in this
# script itself; presumably they are supplied by this Rda-file - TODO confirm.

load("./FAbatchPaper/Datasets/ProcessedData/BreastCancerConcatenation.Rda")


#############
#  FAbatch  #
#############

# Set seed to enable reproduction:
# (the scatter plots below draw random subsamples with sample(); the model fit
# presumably also consumes random numbers through cv.glmnet - TODO confirm)
set.seed(1234)


# Load 'bapred'-package:
# (the fabatch() copy defined below calls nbfactors() and emfahighdim(), which are not
# defined in this script and are presumably provided by 'bapred' - TODO confirm)
library("bapred")


# 'fabatch'-function from the bapred-package which was altered to return
#  additional information necessary for calculating the fitted values:

# Modified copy of fabatch() (per the script comments taken from the 'bapred'
# package) that additionally returns the quantities needed to compute fitted
# values: the batch means, the batch standard deviations, the per-batch factor
# fits and the preliminary class signal 'proba'.
#
# Arguments (as enforced by the checks at the top of the body):
#   x               numeric matrix without missing values; one row per
#                   observation, one column per variable.
#   y               factor with levels '1' and '2', one entry per row of 'x'.
#   batch           factor with levels '1','2',..., one entry per row of 'x'.
#   nbf             number of factors used in every batch; NULL (default) lets
#                   nbfactors() choose it separately for each batch.
#   minerr, maxiter passed on to nbfactors() and emfahighdim().
#   probcrossbatch  if TRUE and there is more than one batch, 'proba' for a
#                   batch is predicted from a model fitted to the other batches;
#                   otherwise pre-validated fits from cv.glmnet() are used.
#   maxnbf          upper bound for the number of factors searched per batch.
#
# Value: a list of class "fabatch" with elements xadj, m1, m2, b0, b,
#   pooledsds, meanoverall, minerr, nbfinput, badvariables, nbatches, batch,
#   nbfvec, meanbatches, sdbatches, falist and proba.
#   'meanbatches' has one row per batch IN THE ORDER OF levels(batch);
#   'meanbatches', 'sdbatches', 'm1', 'm2' and 'falist' only cover the
#   variables that are not listed in 'badvariables'.
fabatch <- function(x, y, batch, nbf=NULL, minerr=1e-06, probcrossbatch=TRUE, maxiter=100, maxnbf=12) {

  # ---- Input checks (messages unchanged) ----
  if(any(is.na(x)))
    stop("Data contains missing values.")
  if(!(is.factor(batch) && all(levels(batch)==(1:length(levels(batch))))))
    stop("'batch' has to be of class 'factor' with levels '1','2',...")
  if(!is.matrix(x))
    stop("'x' has to be of class 'matrix'.")
  if(!(is.factor(y) && all(levels(y)==(1:2))))
    stop("'y' has to be of class 'factor' with levels '1' and '2'.")
  if(max(table(batch)) >= ncol(x))
    stop("The sample size within each batch has to be smaller than the number of variables.")
  if((length(batch)!=length(y)) || (length(y)!=nrow(x)))
    stop("At least one pair out of 'x', 'y' and 'batch' is incompatible.")

  batches <- levels(batch)
  nbatches <- length(batches)
  nvars <- ncol(x)
  ylevels <- levels(y)

  # If a variable is constant in a batch, infinite values
  # or missing values will arise, when scaling by batch.
  # To avoid this, the following is done:
  # Variables, for which this phenomenon occurs, are left
  # out during scaling and factor estimation and only
  # adjusted with respect to their mean.

  # Calculate batch-specific standard deviations:
  sdb <- vector("list", nbatches)
  for (i in seq_len(nbatches)) {
    sdb[[i]] <- apply(x[batch==batches[i],],2,sd)
  }

  badvariableslist <- lapply(sdb, function(s) which(s==0))
  badvariables <- sort(unique(unlist(badvariableslist)))
  goodvariables <- setdiff(seq_len(ncol(x)), badvariables)

  if(length(badvariables) > 0) {
    xsafe <- x
    x <- x[,goodvariables, drop=FALSE]
    sdb <- lapply(sdb, function(s) s[goodvariables])
  }

  # Overall variable means:
  meanoverall <- colMeans(x)

  # First center and scale the variables per batch to fit the penalised
  # logistic regression that provides the preliminary class signal.

  # Remove batch specific means and keep them (one row per batch) for the
  # return value:
  if(nbatches > 1) {
    adjustmentmod <- lm(x~batch)
    design <- model.matrix(~batch)
    adjustmentcoef <- coef(adjustmentmod)
    xb <- x-design%*%adjustmentcoef
    # Pick one design row per batch in the order of levels(batch), so that row
    # i of 'meanbatches' always belongs to batch i. (unique(design) would
    # instead follow the order in which the batches first appear in the data.)
    firstofbatch <- match(batches, as.character(batch))
    meanbatches <- design[firstofbatch,,drop=FALSE]%*%adjustmentcoef
  } else {
    xb <- scale(x, scale=FALSE)
    # With a single batch the batch mean is the overall mean.
    meanbatches <- matrix(meanoverall, nrow=1, dimnames=list(NULL, colnames(x)))
  }

  # Pooled standard deviations (of the batch-centered data):
  pooledsds <- apply(xb, 2, sd)

  # 'scaledxb' is X centered and scaled per batch:
  scaledxb <- xb
  for (i in seq_len(nbatches)) {
    scaledxb[batch==batches[i],] <- scale(xb[batch==batches[i],],center=rep(0,ncol(x)),scale=sdb[[i]])
  }

  # A penalised logistic regression is trained on the batch-centered and
  # batch-scaled X. Note that alpha = 0 is passed to glmnet, which selects the
  # ridge penalty (the lasso would be alpha = 1). 'proba' is the preliminary
  # class signal per observation. To limit overfitting it is not taken from the
  # model fitted to all data, but obtained across batches or by
  # cross-validation, see below.
  cvmod <- glmnet::cv.glmnet(x=as.matrix(scaledxb),y=y,family="binomial",type.measure="deviance", alpha = 0)
  lambda.min <- cvmod$lambda.min
  mod <- glmnet::glmnet(x=as.matrix(scaledxb),y=y,family="binomial",lambda=lambda.min, alpha = 0)

  if(probcrossbatch && nbatches > 1) {
    # For each batch: fit on all other batches, predict the left-out batch.
    proba <- rep(NA, nrow(x))
    for(i in seq_len(nbatches)) {
      modtemp <- glmnet::glmnet(x=as.matrix(scaledxb[batch!=batches[i],]),y=y[batch!=batches[i]],family="binomial",lambda=lambda.min, alpha = 0)
      b0temp <- modtemp$a0
      btemp <- modtemp$beta
      proba[batch==batches[i]] <- as.vector(1/(1+exp(-b0temp-as.matrix(scaledxb[batch==batches[i],])%*%btemp)))
    }
  } else {
    if(probcrossbatch)
      warning("Number of Training batches equal to one. Using ordinary cross-validation for preliminary class probability estimation.")
    # Pre-validated fits from ordinary cross-validation at (essentially)
    # lambda.min; the second, slightly larger lambda is only there so that a
    # lambda sequence of length two is passed.
    cvmod <- glmnet::cv.glmnet(x=as.matrix(scaledxb),y=y,family="binomial",type.measure="deviance", keep=TRUE, lambda=c(lambda.min, lambda.min+(1/1000)*lambda.min), alpha = 0)
    proba <- cvmod$fit.preval[,1]
  }

  # 'b0' is the intercept and 'b' the coefficient vector of the model fitted
  # to all observations:
  b0 <- mod$a0
  b <- mod$beta

  # Calculate the class-specific mean vectors 'm1' and 'm2':
  m1 <- apply(scaledxb[y==ylevels[1],],2,mean)
  m2 <- apply(scaledxb[y==ylevels[2],],2,mean)

  nbfvec <- rep(NA, nbatches)

  # Use the supplied number of factors for every batch, if given.
  # NOTE(review): as in the original code, 'nbfinput' ends up NULL in both
  # branches (the else-branch assigns 'nbf', which is NULL there). This looks
  # like it might be inverted; behaviour is kept unchanged - TODO confirm.
  if (!is.null(nbf)) {
    nbfvec <- rep(nbf, times=nbatches)
    nbfinput <- NULL
  }
  else
    nbfinput <- nbf

  # Estimate the factors per batch on the batch-centered-scaled data from
  # which the class signal was removed, then subtract the factor influences:
  falist <- vector("list", nbatches)
  scaledxbfa <- scaledxb

  for (i in seq_len(nbatches)) {

    inbatch <- batch==batches[i]

    # Class-removed data of batch i: each observation is centered by m1 and m2,
    # weighted with 1-proba and proba, respectively.
    classremoved <- sweep(scale(scaledxb[inbatch,],center=m1,scale=FALSE), 1, 1-proba[inbatch], "*") +
      sweep(scale(scaledxb[inbatch,],center=m2,scale=FALSE), 1, proba[inbatch], "*")

    # Determine number of factors if not given:
    if(is.null(nbf)) {
      maxnbf2 <- min(c(maxnbf, floor(sum(inbatch)/2)))
      nbfobj <- nbfactors(scale(classremoved), maxnbfactors=maxnbf2, minerr=minerr, maxiter=maxiter)
      nbfvec[i] <- nbfobj$optimalnbfactors
    }

    if(is.na(nbfvec[i])) {
      warning("There occured an issue in the factor estimation. Number of factors set to zero.")
      nbfvec[i] <- 0
    }

    if(nbfvec[i] > 0) {
      fa <- emfahighdim(classremoved, nbf=nbfvec[i], minerr=minerr, maxiter=maxiter)

      # Remove the factor influences:
      scaledxbfa[inbatch,] <- scaledxb[inbatch,] - fa$Factors%*%t(fa$B)

      falist[[i]] <- fa
    }
    # With zero factors the rows of 'scaledxbfa' keep their values from
    # 'scaledxb' and falist[[i]] stays NULL.

  }

  means2batch <- sd2batch <- matrix(nrow=length(levels(batch)), ncol=ncol(scaledxbfa))

  # Scale again within batches:
  for (i in seq_len(nbatches)) {
    inbatch <- batch==batches[i]
    means2batch[i,] <- colMeans(scaledxbfa[inbatch,])
    sd2batch[i,] <- apply(scaledxbfa[inbatch,], 2, sd)
    scaledxbfa[inbatch,] <- scale(scaledxbfa[inbatch,], center=means2batch[i,], scale=sd2batch[i,])
  }

  # Back to the original scale using pooled sds and overall means:
  xfa <- sweep(sweep(scaledxbfa, 2, pooledsds, "*"), 2, meanoverall, "+")

  if(length(badvariables)>0) {
    xbadvar <- xsafe[,badvariables, drop=FALSE]

    # drop=FALSE keeps a matrix also when there is exactly one such variable.
    sdbbad <- vector("list", nbatches)
    for (i in seq_len(nbatches)) {
      sdbbad[[i]] <- apply(xbadvar[batch==batches[i],,drop=FALSE],2,sd)
    }
    # Batches in rows, variables in columns - also for a single variable
    # (t(sapply(...)) would give the transposed shape in that case).
    whichzerosdmat <- do.call(rbind, lapply(sdbbad, function(s) s==0))

    sdbbad0to1 <- lapply(sdbbad, function(s) {
      s[s==0] <- 1
      s
    })

    # Per variable: sd over the observations of those batches in which the
    # variable is not constant; 1 if it is constant in every batch.
    pooledsdsbad <- mapply(function(v, zerosd) {
      if(!all(zerosd)) {
        return(sd(v[batch %in% batches[!zerosd]]))
      }
      else
        return(1)
    }, as.data.frame(xbadvar), as.data.frame(whichzerosdmat))

    mubadvar <- colMeans(xbadvar)
    xbadvaradj <- matrix(nrow=nrow(xsafe), ncol=length(badvariables), data=mubadvar, byrow=TRUE)
    for(i in seq_len(nbatches)) {
      inbatch <- batch==batches[i]
      xbadvaradj[inbatch,] <- xbadvaradj[inbatch,,drop=FALSE] + sweep(scale(xbadvar[inbatch,,drop=FALSE], center=TRUE, scale=sdbbad0to1[[i]]), 2, pooledsdsbad, "*")
    }
    xfanew <- matrix(nrow=nrow(xsafe), ncol=ncol(xsafe))
    xfanew[,goodvariables] <- xfa
    xfanew[,badvariables] <- xbadvaradj
    xfa <- xfanew
  }

  # Means and sds for all variables, in the original column order:
  meanvectorwhole <- rep(NA, nvars)
  meanvectorwhole[goodvariables] <- meanoverall
  if(length(badvariables)>0)
    meanvectorwhole[badvariables] <- mubadvar

  sdvectorwhole <- rep(NA, nvars)
  sdvectorwhole[goodvariables] <- pooledsds
  if(length(badvariables)>0)
    sdvectorwhole[badvariables] <- 1

  params <- list(xadj=xfa, m1=m1,m2=m2,b0=b0,b=b, pooledsds=sdvectorwhole, meanoverall=meanvectorwhole, minerr=minerr,
    nbfinput=nbfinput, badvariables=badvariables, nbatches=nbatches, batch=batch, nbfvec=nbfvec,
    meanbatches=meanbatches, sdbatches=sdb, falist=falist, proba=proba)

  class(params) <- "fabatch"

  return(params)

}


# Fit the altered fabatch-function to the dataset:

params <- fabatch(x=X, y=y, batch=batch)




# Fitted values (given the factors), assembled batch by batch as
#   batch mean + sd_batch * (class signal + factor influences).

batchun <- levels(batch)
Xpred <- matrix(nrow=nrow(X), ncol=ncol(X), data=0)

# "Cross-validated class signal" on the standardized scale:
# m1 + proba * (m2 - m1) for every observation.
classsignal <- sweep(params$proba%*%t(params$m2 - params$m1), 2, params$m1, "+")

for(i in seq_along(batchun)) {

  rows <- which(batch==batchun[i])
  sdbatch <- params$sdbatches[[i]]
  fa <- params$falist[[i]]

  # Batch-specific means:
  fitted <- sweep(Xpred[rows,], 2, params$meanbatches[i,], "+")

  # Add the class signal, rescaled with the batch standard deviations:
  fitted <- fitted + sweep(classsignal[rows,], 2, sdbatch, "*")

  # Add factor influences (no entry in 'falist' means zero factors):
  if(!is.null(fa))
    fitted <- fitted + sweep(fa$Factors%*%t(fa$B), 2, sdbatch, "*")

  Xpred[rows,] <- fitted

}




# Plot of real vs. fitted values:
#################################

library("MASS")
library("ggplot2")

# Long-format table with one row per matrix entry: fitted value (x), observed
# value (y) and batch. Built without growing vectors in a loop and without
# overwriting the global class labels 'y'.
fittedbybatch <- lapply(batchun, function(b) as.vector(Xpred[batch==b,]))
observedbybatch <- lapply(batchun, function(b) as.vector(X[batch==b,]))
data <- data.frame(x=unlist(fittedbybatch), y=unlist(observedbybatch),
  batch=factor(rep(batchun, times=lengths(fittedbybatch)), levels=batchun))
rm(fittedbybatch, observedbybatch)

# Random subsample of 1000 points per batch for the scatter layer:
ranind <- unlist(lapply(batchun, function(b) sort(sample(which(data$batch==b), size=1000))))
datasub <- data[ranind,]

# Loess curve per batch (evaluated at 'nsmooth' points), based on all values:
nsmooth <- 50
dataline <- data.frame(x=rep(NA, nsmooth*length(batchun)), y=rep(NA, nsmooth*length(batchun)),
  batch=factor(rep(batchun, each=nsmooth), levels=batchun))

for(b in batchun) {
  inbatch <- data$batch==b
  smoothed <- loess.smooth(data$x[inbatch], data$y[inbatch], evaluation=nsmooth)
  dataline$x[dataline$batch==b] <- smoothed$x
  dataline$y[dataline$batch==b] <- smoothed$y
}

# Two-dimensional kernel density estimate per batch on an ngrid x ngrid grid
# that extends one standard deviation beyond the data range in each direction:
ngrid <- 50

densdfall <- data.frame(x=rep(NA, (ngrid^2)*length(batchun)), y=rep(NA, (ngrid^2)*length(batchun)),
  z=rep(NA, (ngrid^2)*length(batchun)), batch=factor(rep(batchun, each=(ngrid^2)), levels=batchun))

for(b in batchun) {

  xvals <- data$x[data$batch==b]
  yvals <- data$y[data$batch==b]

  lims <- c(min(xvals)-sd(xvals), max(xvals)+sd(xvals), min(yvals)-sd(yvals), max(yvals)+sd(yvals))
  dens2 <- kde2d(xvals, yvals, lims=lims, n = ngrid)

  # expand.grid varies x fastest, which matches the column-major order of dens2$z.
  densgrid <- expand.grid(x=dens2$x, y=dens2$y)

  target <- densdfall$batch==b
  densdfall$x[target] <- densgrid$x
  densdfall$y[target] <- densgrid$y
  densdfall$z[target] <- as.vector(dens2$z)

}

# Panel titles:
levels(datasub$batch) <- paste("batch", levels(datasub$batch))
levels(dataline$batch) <- paste("batch", levels(dataline$batch))
levels(densdfall$batch) <- paste("batch", levels(densdfall$batch))

# Grey points: subsample; red line: loess fit; dashed line: identity;
# contours: density estimate. The panels are laid out on a 3 x 2 grid.
p <- ggplot(data=datasub, aes(x=x, y=y)) + geom_point(col="grey") +
  geom_line(data=dataline, aes(x=x, y=y), col="red") + geom_abline(intercept = 0, slope = 1, linetype="dashed") +
  stat_contour(data=densdfall, aes(x=x, y=y, z=z), bins=10, size=0.1) +
  labs(x = "fitted values (given factors)", y = "data values") + facet_wrap(~ batch, nrow=3, ncol=2) +
  theme(axis.title=element_text(size=14), axis.text=element_text(size=12), strip.text.x=element_text(size=13))
p

# Name the plot explicitly instead of relying on last_plot():
ggsave(filename="./FAbatchPaper/Results/SupplementaryFigure1.pdf", plot=p, width=8, height=9.1)





# Plot of residuals vs. fitted values:
######################################

library("MASS")
library("ggplot2")

# Long-format table with one row per matrix entry: fitted value (x), residual
# (y) and batch. The residual matrix is computed once, and neither vectors are
# grown in a loop nor are the global class labels 'y' overwritten.
residmat <- X-Xpred
fittedbybatch <- lapply(batchun, function(b) as.vector(Xpred[batch==b,]))
residbybatch <- lapply(batchun, function(b) as.vector(residmat[batch==b,]))
data <- data.frame(x=unlist(fittedbybatch), y=unlist(residbybatch),
  batch=factor(rep(batchun, times=lengths(fittedbybatch)), levels=batchun))
rm(residmat, fittedbybatch, residbybatch)

# Random subsample of 1000 points per batch for the scatter layer:
ranind <- unlist(lapply(batchun, function(b) sort(sample(which(data$batch==b), size=1000))))
datasub <- data[ranind,]

# Loess curve per batch (evaluated at 'nsmooth' points), based on all values:
nsmooth <- 50
dataline <- data.frame(x=rep(NA, nsmooth*length(batchun)), y=rep(NA, nsmooth*length(batchun)),
  batch=factor(rep(batchun, each=nsmooth), levels=batchun))

for(b in batchun) {
  inbatch <- data$batch==b
  smoothed <- loess.smooth(data$x[inbatch], data$y[inbatch], evaluation=nsmooth)
  dataline$x[dataline$batch==b] <- smoothed$x
  dataline$y[dataline$batch==b] <- smoothed$y
}

# Two-dimensional kernel density estimate per batch on an ngrid x ngrid grid
# that extends one standard deviation beyond the data range in each direction:
ngrid <- 50

densdfall <- data.frame(x=rep(NA, (ngrid^2)*length(batchun)), y=rep(NA, (ngrid^2)*length(batchun)),
  z=rep(NA, (ngrid^2)*length(batchun)), batch=factor(rep(batchun, each=(ngrid^2)), levels=batchun))

for(b in batchun) {

  xvals <- data$x[data$batch==b]
  yvals <- data$y[data$batch==b]

  lims <- c(min(xvals)-sd(xvals), max(xvals)+sd(xvals), min(yvals)-sd(yvals), max(yvals)+sd(yvals))
  dens2 <- kde2d(xvals, yvals, lims=lims, n = ngrid)

  # expand.grid varies x fastest, which matches the column-major order of dens2$z.
  densgrid <- expand.grid(x=dens2$x, y=dens2$y)

  target <- densdfall$batch==b
  densdfall$x[target] <- densgrid$x
  densdfall$y[target] <- densgrid$y
  densdfall$z[target] <- as.vector(dens2$z)

}

# Panel titles:
levels(datasub$batch) <- paste("batch", levels(datasub$batch))
levels(dataline$batch) <- paste("batch", levels(dataline$batch))
levels(densdfall$batch) <- paste("batch", levels(densdfall$batch))

# Grey points: subsample; red line: loess fit; dashed line: zero residual;
# contours: density estimate. The panels are laid out on a 3 x 2 grid.
p <- ggplot(data=datasub, aes(x=x, y=y)) + geom_point(col="grey") +
  geom_line(data=dataline, aes(x=x, y=y), col="red") + geom_hline(aes(yintercept=0), linetype="dashed") +
  stat_contour(data=densdfall, aes(x=x, y=y, z=z), bins=10, size=0.1) +
  labs(x = "fitted values (given factors)", y = "residuals") + facet_wrap(~ batch, nrow=3, ncol=2) +
  theme(axis.title=element_text(size=14), axis.text=element_text(size=12), strip.text.x=element_text(size=13))
p

# Name the plot explicitly instead of relying on last_plot():
ggsave(filename="./FAbatchPaper/Results/SupplementaryFigure3.pdf", plot=p, width=8, height=9.1)






# Plot of the density estimates of the residuals:
#################################################

# Residuals per batch. Within each batch every variable (column) is divided by
# its root mean square (scale() with center=FALSE, scale=TRUE); no centering is
# done. Only the residuals and the batch label are needed for this figure, so
# the fitted values are not collected, and the global class labels 'y' are not
# overwritten.
residmat <- X-Xpred
scaledresid <- lapply(batchun, function(b)
  as.vector(scale(residmat[batch==b,], center=FALSE, scale=TRUE)))
data <- data.frame(y=unlist(scaledresid),
  batch=factor(rep(batchun, times=lengths(scaledresid)), levels=batchun))
rm(residmat, scaledresid)

# Panel titles:
levels(data$batch) <- paste("batch", levels(data$batch))

# One density curve per batch on a 3 x 2 grid; the dashed vertical line marks zero.
p <- ggplot(data=data, aes(x = y)) + geom_line(stat="density") +
  facet_wrap(~ batch, nrow=3, ncol=2) + geom_vline(xintercept=0, linetype="dashed", size=0.3) +
  theme(axis.title.x=element_blank(), axis.title=element_text(size=14),
  axis.text=element_text(size=12), strip.text.x=element_text(size=13)) + xlim(-9, 9) + ylim(0, 0.45)
p

# Name the plot explicitly instead of relying on last_plot():
ggsave(filename="./FAbatchPaper/Results/SupplementaryFigure5.pdf", plot=p, width=8, height=9.1)







############
#  ComBat  #
############


# Set seed to enable reproduction:
# (re-seeding here makes the random subsampling with sample() in the ComBat
# plots below independent of how many random numbers the FAbatch part consumed)
set.seed(1234)


# 'combatba'-function from the bapred-package which was altered to return
#  additional information necessary for calculating the fitted values:

# Modified copy of combatba() (per the script comments taken from the 'bapred'
# package) that additionally returns the adjusted batch location parameters
# 'gammastar', which are needed to compute fitted values.
#
# Arguments (as enforced by the checks at the top of the body):
#   x      numeric matrix without missing values; one row per observation,
#          one column per variable.
#   batch  factor with levels '1','2',....
#
# Value: a list of class "combat" with elements
#   xadj         the adjusted data, same orientation as 'x',
#   meanoverall  vector with one overall mean per variable,
#   var.pooled   one-column matrix with one pooled variance per variable,
#   gammastar    matrix with one row per batch and one column per variable
#                (all zero if there is only one batch),
#   batch, nbatches.
#
# The helpers design.mat(), list.batch(), aprior(), bprior() and it.sol() are
# not defined in this script; presumably they come from the loaded 'bapred'
# package - TODO confirm.
combatba <- function(x, batch) {

  if(any(is.na(x)))
    stop("Data contains missing values.")
  if(!(is.factor(batch) && all(levels(batch)==(1:length(levels(batch))))))
    stop("'batch' has to be of class 'factor' with levels '1','2',...")
  if(!is.matrix(x))
    stop("'x' has to be of class 'matrix'.")

  # No covariates besides the batch are used:
  mod <- NULL
  numCovs <- NULL

  # Only one batch: nothing to adjust. 'gammastar' is returned as a row of
  # zeros (no batch shift), so that the result has the same elements as in the
  # multi-batch case.
  if(length(unique(batch))==1) {
    params <- list(xadj=x, meanoverall=colMeans(x),
      var.pooled=matrix(nrow=ncol(x), ncol=1, data=apply(x, 2, var)*(1 - 1/nrow(x))),
      gammastar=matrix(0, nrow=1, ncol=ncol(x)))
    params$batch <- batch
    params$nbatches <- length(unique(batch))

    class(params) <- "combat"

    return(params)
  }

  mod <- cbind(mod,batch)

  # Check for intercept, and drop if present:
  check <- apply(mod, 2, function(x) all(x == 1))
  mod <- as.matrix(mod[,!check])

  colnames(mod)[ncol(mod)] <- "Batch"

  if(sum(check) > 0 && !is.null(numCovs)) numCovs <- numCovs-1

  design <- design.mat(mod,numCov = numCovs)

  # 'batches' is used below as a list of column indices of 's.data' (i.e. of
  # observations), one entry per batch:
  batches <- list.batch(mod)
  n.batch <- length(batches)
  n.batches <- sapply(batches, length)
  n.array <- sum(n.batches)

  # Missing values were already rejected above, so only the complete-data
  # formulas are needed from here on.

  ## Standardize data across variables
  B.hat <- solve(t(design)%*%design)%*%t(design)%*%as.matrix(x)
  grand.mean <- t(n.batches/n.array)%*%B.hat[1:n.batch,]
  var.pooled <- ((t(x)-t(design%*%B.hat))^2)%*%rep(1/n.array,n.array)

  meanoverall <- t(grand.mean)%*%t(rep(1,n.array))
  if(!is.null(design)) {
    # Contribution of the non-batch columns of the design (if any):
    tmp <- design
    tmp[,c(1:n.batch)] <- 0
    meanoverall <- meanoverall+t(tmp%*%B.hat)
  }
  # Variables in rows, observations in columns from here on:
  s.data <- (t(x)-meanoverall)/(sqrt(var.pooled)%*%t(rep(1,n.array)))

  ## Get regression batch effect parameters
  batch.design <- design[,1:n.batch]
  gamma.hat <- solve(t(batch.design)%*%batch.design)%*%t(batch.design)%*%t(as.matrix(s.data))

  delta.hat <- NULL
  for (i in batches) {
    delta.hat <- rbind(delta.hat,apply(s.data[,i], 1, var,na.rm=TRUE))
  }

  ## Find priors
  gamma.bar <- apply(gamma.hat, 1, mean)
  t2 <- apply(gamma.hat, 1, var)
  a.prior <- apply(delta.hat, 1, aprior)
  b.prior <- apply(delta.hat, 1, bprior)

  ## Find EB batch adjustments (parametric priors; the optional prior plots
  ## and the non-parametric variant were permanently switched off in this
  ## copy and are therefore left out)
  gamma.star <- delta.star <- NULL
  for (i in 1:n.batch) {
    temp <- it.sol(s.data[,batches[[i]]],gamma.hat[i,],
      delta.hat[i,],gamma.bar[i],t2[i],a.prior[i],b.prior[i])
    gamma.star <- rbind(gamma.star,temp[1,])
    delta.star <- rbind(delta.star,temp[2,])
  }

  ### Normalize the data ###
  bayesdata <- s.data
  for (i in 1:length(batches)) {
    id <- batches[[i]]
    bayesdata[,id] <- (bayesdata[,id]-t(batch.design[id,]%*%gamma.star))/(sqrt(delta.star[i,])%*%t(rep(1,n.batches[i])))
  }

  # Back to the original scale:
  bayesdata <- (bayesdata*(sqrt(var.pooled)%*%t(rep(1,n.array))))+meanoverall

  params <- list(xadj=t(bayesdata), meanoverall=meanoverall[,1], var.pooled=var.pooled,
    gammastar=gamma.star)
  params$batch <- batch
  params$nbatches <- length(unique(batch))

  class(params) <- "combat"

  return(params)

}


# Fit the altered combatba-function to the dataset:

params <- combatba(x=X, batch=batch)




# Fitted values: overall mean of each variable plus the batch-specific shift
# 'gammastar', brought back to the data scale with the pooled standard
# deviation.

batchun <- levels(batch)
pooledsd <- sqrt(as.numeric(params$var.pooled))

# Overall means:
Xpred <- sweep(matrix(nrow=nrow(X), ncol=ncol(X), data=0), 2, params$meanoverall, "+")

# Batch-specific shifts:
for(i in seq_along(batchun)) {
  rows <- which(batch==batchun[i])
  Xpred[rows,] <- sweep(Xpred[rows,], 2, params$gammastar[i,]*pooledsd, "+")
}




# Plot of real vs. fitted values:
#################################

library("MASS")
library("ggplot2")

# Long-format table with one row per matrix entry: fitted value (x), observed
# value (y) and batch. Built without growing vectors in a loop and without
# creating global vectors named 'x' and 'y'.
fittedbybatch <- lapply(batchun, function(b) as.vector(Xpred[batch==b,]))
observedbybatch <- lapply(batchun, function(b) as.vector(X[batch==b,]))
data <- data.frame(x=unlist(fittedbybatch), y=unlist(observedbybatch),
  batch=factor(rep(batchun, times=lengths(fittedbybatch)), levels=batchun))
rm(fittedbybatch, observedbybatch)

# Random subsample of 1000 points per batch for the scatter layer:
ranind <- unlist(lapply(batchun, function(b) sort(sample(which(data$batch==b), size=1000))))
datasub <- data[ranind,]

# Loess curve per batch (evaluated at 'nsmooth' points), based on all values:
nsmooth <- 50
dataline <- data.frame(x=rep(NA, nsmooth*length(batchun)), y=rep(NA, nsmooth*length(batchun)),
  batch=factor(rep(batchun, each=nsmooth), levels=batchun))

for(b in batchun) {
  inbatch <- data$batch==b
  smoothed <- loess.smooth(data$x[inbatch], data$y[inbatch], evaluation=nsmooth)
  dataline$x[dataline$batch==b] <- smoothed$x
  dataline$y[dataline$batch==b] <- smoothed$y
}

# Two-dimensional kernel density estimate per batch on an ngrid x ngrid grid
# that extends one standard deviation beyond the data range in each direction:
ngrid <- 50

densdfall <- data.frame(x=rep(NA, (ngrid^2)*length(batchun)), y=rep(NA, (ngrid^2)*length(batchun)),
  z=rep(NA, (ngrid^2)*length(batchun)), batch=factor(rep(batchun, each=(ngrid^2)), levels=batchun))

for(b in batchun) {

  xvals <- data$x[data$batch==b]
  yvals <- data$y[data$batch==b]

  lims <- c(min(xvals)-sd(xvals), max(xvals)+sd(xvals), min(yvals)-sd(yvals), max(yvals)+sd(yvals))
  dens2 <- kde2d(xvals, yvals, lims=lims, n = ngrid)

  # expand.grid varies x fastest, which matches the column-major order of dens2$z.
  densgrid <- expand.grid(x=dens2$x, y=dens2$y)

  target <- densdfall$batch==b
  densdfall$x[target] <- densgrid$x
  densdfall$y[target] <- densgrid$y
  densdfall$z[target] <- as.vector(dens2$z)

}

# Panel titles:
levels(datasub$batch) <- paste("batch", levels(datasub$batch))
levels(dataline$batch) <- paste("batch", levels(dataline$batch))
levels(densdfall$batch) <- paste("batch", levels(densdfall$batch))

# Grey points: subsample; red line: loess fit; dashed line: identity;
# contours: density estimate. The panels are laid out on a 3 x 2 grid.
p <- ggplot(data=datasub, aes(x=x, y=y)) + geom_point(col="grey") +
  geom_line(data=dataline, aes(x=x, y=y), col="red") + geom_abline(intercept = 0, slope = 1, linetype="dashed") +
  stat_contour(data=densdfall, aes(x=x, y=y, z=z), bins=10, size=0.1) +
  labs(x = "fitted values", y = "data values") + facet_wrap(~ batch, nrow=3, ncol=2) +
  theme(axis.title=element_text(size=14), axis.text=element_text(size=12), strip.text.x=element_text(size=13))
p

# Name the plot explicitly instead of relying on last_plot():
ggsave(filename="./FAbatchPaper/Results/SupplementaryFigure2.pdf", plot=p, width=8, height=9.1)





# Plot of residuals vs. fitted values:
######################################

library("MASS")
library("ggplot2")

# Long-format table with one row per matrix entry: fitted value (x), residual
# (y) and batch. The residual matrix is computed once, and neither vectors are
# grown in a loop nor are global vectors named 'x' and 'y' created.
residmat <- X-Xpred
fittedbybatch <- lapply(batchun, function(b) as.vector(Xpred[batch==b,]))
residbybatch <- lapply(batchun, function(b) as.vector(residmat[batch==b,]))
data <- data.frame(x=unlist(fittedbybatch), y=unlist(residbybatch),
  batch=factor(rep(batchun, times=lengths(fittedbybatch)), levels=batchun))
rm(residmat, fittedbybatch, residbybatch)

# Random subsample of 1000 points per batch for the scatter layer:
ranind <- unlist(lapply(batchun, function(b) sort(sample(which(data$batch==b), size=1000))))
datasub <- data[ranind,]

# Loess curve per batch (evaluated at 'nsmooth' points), based on all values:
nsmooth <- 50
dataline <- data.frame(x=rep(NA, nsmooth*length(batchun)), y=rep(NA, nsmooth*length(batchun)),
  batch=factor(rep(batchun, each=nsmooth), levels=batchun))

for(b in batchun) {
  inbatch <- data$batch==b
  smoothed <- loess.smooth(data$x[inbatch], data$y[inbatch], evaluation=nsmooth)
  dataline$x[dataline$batch==b] <- smoothed$x
  dataline$y[dataline$batch==b] <- smoothed$y
}

# Two-dimensional kernel density estimate per batch on an ngrid x ngrid grid
# that extends one standard deviation beyond the data range in each direction:
ngrid <- 50

densdfall <- data.frame(x=rep(NA, (ngrid^2)*length(batchun)), y=rep(NA, (ngrid^2)*length(batchun)),
  z=rep(NA, (ngrid^2)*length(batchun)), batch=factor(rep(batchun, each=(ngrid^2)), levels=batchun))

for(b in batchun) {

  xvals <- data$x[data$batch==b]
  yvals <- data$y[data$batch==b]

  lims <- c(min(xvals)-sd(xvals), max(xvals)+sd(xvals), min(yvals)-sd(yvals), max(yvals)+sd(yvals))
  dens2 <- kde2d(xvals, yvals, lims=lims, n = ngrid)

  # expand.grid varies x fastest, which matches the column-major order of dens2$z.
  densgrid <- expand.grid(x=dens2$x, y=dens2$y)

  target <- densdfall$batch==b
  densdfall$x[target] <- densgrid$x
  densdfall$y[target] <- densgrid$y
  densdfall$z[target] <- as.vector(dens2$z)

}

# Panel titles:
levels(datasub$batch) <- paste("batch", levels(datasub$batch))
levels(dataline$batch) <- paste("batch", levels(dataline$batch))
levels(densdfall$batch) <- paste("batch", levels(densdfall$batch))

# Grey points: subsample; red line: loess fit; dashed line: zero residual;
# contours: density estimate. The panels are laid out on a 3 x 2 grid.
p <- ggplot(data=datasub, aes(x=x, y=y)) + geom_point(col="grey") +
  geom_line(data=dataline, aes(x=x, y=y), col="red") + geom_hline(aes(yintercept=0), linetype="dashed") +
  stat_contour(data=densdfall, aes(x=x, y=y, z=z), bins=10, size=0.1) +
  labs(x = "fitted values", y = "residuals") + facet_wrap(~ batch, nrow=3, ncol=2) +
  theme(axis.title=element_text(size=14), axis.text=element_text(size=12), strip.text.x=element_text(size=13))
p

# Name the plot explicitly instead of relying on last_plot():
ggsave(filename="./FAbatchPaper/Results/SupplementaryFigure4.pdf", plot=p, width=8, height=9.1)






# Plot of the density estimates of the residuals:
#################################################

# Residuals per batch. Within each batch every variable (column) is divided by
# its root mean square (scale() with center=FALSE, scale=TRUE); no centering is
# done. Only the residuals and the batch label are needed for this figure, so
# the fitted values are not collected, and no global vectors named 'x' and 'y'
# are created.
residmat <- X-Xpred
scaledresid <- lapply(batchun, function(b)
  as.vector(scale(residmat[batch==b,], center=FALSE, scale=TRUE)))
data <- data.frame(y=unlist(scaledresid),
  batch=factor(rep(batchun, times=lengths(scaledresid)), levels=batchun))
rm(residmat, scaledresid)

# Panel titles:
levels(data$batch) <- paste("batch", levels(data$batch))

# One density curve per batch on a 3 x 2 grid; the dashed vertical line marks zero.
p <- ggplot(data=data, aes(x = y)) + geom_line(stat="density") +
  facet_wrap(~ batch, nrow=3, ncol=2) + geom_vline(xintercept=0, linetype="dashed", size=0.3) +
  theme(axis.title.x=element_blank(), axis.title=element_text(size=14),
  axis.text=element_text(size=12), strip.text.x=element_text(size=13)) + xlim(-9, 9) + ylim(0, 0.45)
p

# Name the plot explicitly instead of relying on last_plot():
ggsave(filename="./FAbatchPaper/Results/SupplementaryFigure6.pdf", plot=p, width=8, height=9.1)
