#!/usr/bin/env Rscript
# Copyright (c) 2011-2012, Universite de Versailles St-Quentin-en-Yvelines
#
# This file is part of ASK.  ASK is free software: you can redistribute
# it and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation, version 2.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.


require(inline)
suppressPackageStartupMessages(require(gbm, quietly=T))
# We want the inline package for inlining C uniform_sampling code
suppressPackageStartupMessages(require(inline, quietly=T))
# We want the lhs method inside tgp
suppressPackageStartupMessages(require(tgp, quietly=T))

#### Define some utility functions
# Pick `nsamples` rows of `available` that are spread out in factor space
# while favouring the rows listed first in `variance`.
#
# Arguments:
#   variance   1-based row indices into `available`, in priority order (the
#              first entry is always selected). Must list every row once so
#              that its length equals nrow(available).
#   nsamples   number of rows to select.
#   available  candidate points; every column except the last one is used as
#              a coordinate for the Euclidean distance between points.
#
# Returns a numeric vector of `nsamples` 1-based row indices into
# `available` (an empty vector when `nsamples` is below 1).
#
# Method: a greedy pass (`try_delta`) walks the points in priority order and
# keeps each point whose distance to all the points kept so far is at least
# `delta`. A dichotomy on `delta` looks for the largest threshold (up to
# RESOLUTION) for which the greedy pass still yields `nsamples` points.
uniform_sampling <- function(variance, nsamples, available) {
  npoints <- length(variance)
  nsamples <- as.integer(nsamples)

  # Guard the cases the C code cannot handle: it unconditionally writes
  # samples[0] and can only succeed when enough points are available.
  if (length(nsamples) != 1 || is.na(nsamples)) {
    stop("nsamples must be a single integer value")
  }
  if (nsamples < 1L) {
    return(numeric(0))
  }
  if (nsamples > npoints) {
    stop("cannot select ", nsamples, " samples out of ", npoints,
         " available points")
  }
  # The C code uses length(variance) as the stride of the distance matrix
  if (nrow(available) != npoints) {
    stop("variance must contain one index per row of available")
  }

  samples <- rep(-1L, nsamples)

  # Compute the distance between all the pairs of available points, leaving
  # out the last column of `available`.
  coordinates <- available[, seq_len(ncol(available) - 1), drop = FALSE]
  distances <- as.matrix(dist(coordinates, diag = TRUE))

  # Inline the uniform sampling scheme that is written in C,
  # the R version was too slow.
  sig_uniform_sampling <- signature(vsize = "integer",
                                    ssize = "integer",
                                    variance = "integer",
                                    samples = "integer",
                                    dist = "double",
                                    max_distance = "double")
  code_uniform_sampling <- "
        int vs = vsize[0];
        int ss = ssize[0];
        double max_delta = max_distance[0];
        double min_delta = 0;

        /* Find optimal delta by dichotomy */
        while (1) {
            double delta = (max_delta + min_delta) / 2.0;
            if ((max_delta - min_delta) < RESOLUTION) {
                /* Converged: min_delta is the largest threshold known to
                 * yield enough points, use it for the final selection */
                int decision =
                    try_delta(min_delta, vs, ss, variance, samples, dist);
                assert(decision >= 0);
                return;
            }
            int decision = try_delta(delta, vs, ss, variance, samples, dist);
            if (decision == 0) {            /* All done */
                break;
            } else if (decision == -1) {    /* decrease delta */
                max_delta = delta;
            } else if (decision == 1) {     /* increase delta */
                min_delta = delta;
            }
        }
    "

  code_try_delta <- "
        /* Greedy selection for a given distance threshold.
         * Returns  0 when exactly ss points were found using the whole list,
         *          1 when ss points were found before the end of the list
         *            (delta can be increased),
         *         -1 when fewer than ss points were found
         *            (delta must be decreased). */
        int
        try_delta(double delta, int vs, int ss, int *variance,
                  int *samples, double *dist)
        {
            int i, j;
            /* Always select the first sample (largest variance) */
            samples[0] = variance[0] - 1;
            int count = 1;

            /* A single requested sample is satisfied by the first point
             * alone; without this check the loop below would write past
             * the end of samples */
            if (count >= ss)
                return 0;

            for (i = 1; i < vs; i++) {
                int next = variance[i] - 1;
                /* Check that the next point distance to all other sampled
                 * points is above the threshold */
                int ok = 1;
                for (j = 0; j < count; j++) {
                    if (dist[next + samples[j] * vs] < delta) {
                        ok = 0;
                        break;
                    }
                }

                /* If that is not the case continue to next point */
                if (!ok)
                    continue;

                /* If that is the case, add the point to the selected
                 * samples list */
                samples[count] = next;
                count++;

                /* Did we select all the needed samples ? */
                if (count == ss) {
                    if (i == vs - 1)
                        return 0;
                    else
                        return 1;
                    /* We reached ss points, we need to increase delta */
                }
            }

            /* We did not reach enough points, we need to decrease delta */
            return -1;
        }
    "

  fns <- cfunction(list(uniform_sampling = sig_uniform_sampling),
                   list(code_uniform_sampling),
                   convention = ".C",
                   includes = c("#define RESOLUTION 0.00000000001",
                                "#include <assert.h>"),
                   otherdefs = c(code_try_delta))

  out <- fns[["uniform_sampling"]](
    vsize = as.integer(npoints),
    ssize = nsamples,
    variance = as.integer(variance),
    samples = samples,
    dist = as.double(distances),
    max_distance = as.double(max(distances)))

  # The C code works with 0-based indices, R with 1-based ones
  out$samples + 1
}

# Adaptive mart method from YU,LI,FANG,PENG 09
# Fit `NRS` gradient boosted models on `sampled`, each with a fresh random
# seed, predict every row of `available` with each of them, and return the
# `ssample` rows of `available` chosen by uniform_sampling() when the rows
# are prioritised by decreasing coefficient of variation of the predictions.
#
# Arguments:
#   sampled    data frame of already measured points used to fit the models.
#   available  data frame of candidate points to choose from.
#   formula    regression formula passed to gbm().
#   ssample    number of new points to return.
#   ntree      number of trees given to gbm() (n.trees).
#   NRS        number of models (random seeds) to fit.
#
# Returns the selected rows of `available`.
Amart <- function(sampled, available, formula, ssample, ntree, NRS) {
  pred <- vector("list", NRS)
  print("Building the trees ...")
  for (see in seq_len(NRS)) {
    print(paste("Tree ", see, "/", NRS))
    # select a new random seed
    set.seed(floor(runif(1, 1, 2^16)))

    # build a model
    gbm1 <- gbm(formula, data = sampled, distribution = "gaussian",
                n.trees = ntree,
                interaction.depth = 8, shrinkage = 0.01, verbose = FALSE)
    best.iter <- gbm.perf(gbm1, method = "OOB", plot.it = FALSE)
    # generate predictions with the model on the non sampled points
    pred[[see]] <- predict(gbm1, available, n.trees = best.iter)
  }

  # One row per model, one column per available point
  print("Compute prediction CoV")
  m <- do.call(rbind, pred)
  sd_list <- apply(m, 2, sd)
  mean_list <- apply(m, 2, mean)
  # Coefficient of variation of the predictions for each available point.
  # NOTE(review): a zero or negative mean prediction gives a NaN or negative
  # value here, which pushes the point to the end of the ordering.
  cov_list <- sd_list / mean_list

  # Draw new samples using the adaptive sampling method: rows of `available`
  # are handed over ordered by decreasing CoV.
  print("Adaptive sampling")
  priority <- order(cov_list, decreasing = TRUE)
  smp_indexes <- uniform_sampling(priority, ssample, available)

  # uniform_sampling() already returns row indices of `available`; they must
  # not be mapped through the CoV ordering a second time.
  available[smp_indexes, ]
}

# Command line: <configuration.conf> <input_file> <output_file>
args <- commandArgs(trailingOnly = TRUE)
if (length(args) != 3) {
  stop("Usage: amart <configuration.conf> <input_file> <output_file>")
}

# Load the shared configuration code from the ASK installation directory
# (presumably this is what defines conf() used below -- confirm).
source(file.path(Sys.getenv("ASKHOME"), "common/configuration.R"))

# Load the input table and the number of new samples to draw (10 unless
# the configuration says otherwise).
data <- read.table(args[2])
nsamples <- conf("modules.sampler.params.n", 10)

# Build `available`, the set of candidate points offered to the sampler.
# Whichever branch runs, the result is a data frame with one column per
# configured factor, named V1, V2, ... like the input, followed by a
# zero-filled column carrying the name of the last input column.
factors <- conf("factors")
response_name <- paste0("V", ncol(data))

if (conf("modules.sampler.params.predict_on_all", FALSE)) {
  # Generate all available points and set up categorical variables
  # as factors
  lf <- vector("list", length(factors))
  for (i in seq_along(factors)) {
    f <- factors[[i]]
    if (f$type == "categorical") {
      data[, i] <- as.factor(data[, i])
      r <- f$values
    } else {
      r <- f$range$min:f$range$max
    }
    # Assign through a one-element list so that a NULL value cannot
    # silently delete the slot
    lf[i] <- list(r)
  }
  available <- do.call(expand.grid, lf)
  # expand.grid names unnamed inputs Var1, Var2, ...
  generated_prefix <- "Var"
} else {
  # Choose a 20*nsamples lhs set of points
  mins <- numeric(length(factors))
  maxs <- numeric(length(factors))
  for (i in seq_along(factors)) {
    f <- factors[[i]]
    if (f$type == "categorical") {
      mins[i] <- 0
      maxs[i] <- length(f$values)
    } else {
      mins[i] <- f$range$min
      maxs[i] <- f$range$max + 1
    }
  }

  available <- lhs(20 * nsamples, cbind(mins, maxs))

  # Every factor that is not declared as float is truncated to an integer
  for (i in seq_along(factors)) {
    if (factors[[i]]$type != "float") {
      available[, i] <- as.integer(available[, i])
    }
  }
  available <- data.frame(available)
  # data.frame names the columns of an unnamed matrix X1, X2, ...
  generated_prefix <- "X"
}

# Append the zero-filled placeholder column named after the last input column
available[[response_name]] <- rep(0, nrow(available))

# Rename the generated columns (Vari or Xi) to Vi so they are consistent
# with the input
hit <- match(names(available), paste0(generated_prefix, seq_along(factors)))
names(available)[!is.na(hit)] <- paste0("V", hit[!is.na(hit)])

# With a single factor the first column may not have received its V1 name
# above, so force it here.
names(available)[1] <- "V1"

# Regression formula: the last input column explained by all the others
formula <- formula(paste0("V", ncol(data), "~."))

# Select the new points (3000 trees and 20 seeds unless configured otherwise)
newsamples <- Amart(data, available, formula,
                    ssample = nsamples,
                    ntree = conf("modules.sampler.params.trees", 3000),
                    NRS = conf("modules.sampler.params.seeds", 20))

# Write the selected points without their last (placeholder) column
factor_columns <- seq_len(ncol(newsamples) - 1)
write.table(newsamples[, factor_columns],
            file = args[3],
            row.names = FALSE,
            col.names = FALSE)
