#!/usr/bin/env Rscript
# Copyright (c) 2011-2012, Universite de Versailles St-Quentin-en-Yvelines
#
# This file is part of ASK.  ASK is free software: you can redistribute
# it and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation, version 2.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.


# Load tgp with library() so that a missing package aborts the script right
# here; require() would merely return FALSE and the failure would only show
# up much later, at the predict() call.
suppressPackageStartupMessages(library(tgp, quietly = TRUE))

# Command line: <configuration.conf> <model> <requested_file> <output_file>
args <- commandArgs(trailingOnly = TRUE)

if (length(args) != 4) {
    stop("Usage: tgp_predict <configuration.conf> <model> <requested_file> <output_file>")
}

# The shared ASK helpers live under $ASKHOME; fail with a clear message when
# the variable is unset instead of trying to source "/common/configuration.R".
ask_home <- Sys.getenv("ASKHOME")
if (!nzchar(ask_home)) {
    stop("The ASKHOME environment variable is not set")
}
source(file.path(ask_home, "common/configuration.R"))

# Read the requested points
req <- read.table(args[3])

# Parse the experiment configuration. The result is stored in the global
# `configuration` before conf() is called (presumably conf() reads it --
# verify against common/configuration.R).
configuration <- fromJSON(paste(readLines(args[1]), collapse = ""))

# Number of factors declared in the configuration. Categorical factors get no
# special treatment in this module (they are not converted to R factors).
nfactors <- length(conf("factors"))

if (nfactors < 1 || nfactors > ncol(req)) {
    stop(paste("Configuration declares", nfactors, "factors but the requested",
               "file has", ncol(req), "columns"))
}

output <- args[4]

# Load the fitted tgp model. The prediction step uses an object named `tgp1`,
# so the model file is expected to provide it.
load(args[2])
if (!exists("tgp1")) {
    stop(paste("No object named 'tgp1' is available after loading", args[2]))
}

# Keep only the factor columns. drop = FALSE keeps a data.frame even when
# there is a single factor, so nrow() and row indexing keep working.
requested <- req[, seq_len(nfactors), drop = FALSE]

# The tgp prediction must work in chunks:
# if passed a data.frame too large, the tgp package complains.
chunk_size <- 1000

# A single-column selection may have been collapsed to a plain vector by
# data.frame indexing; restore a one-column data.frame so that nrow() and
# row slicing behave.
if (is.null(dim(requested))) {
    requested <- data.frame(requested)
}

n_rows <- nrow(requested)
# Zero rows yields zero chunks, so nothing is predicted or written.
n_chunks <- ceiling(n_rows / chunk_size)
# Highest 10% boundary already reported.
last_decile <- 0

for (chunk in seq_len(n_chunks)) {
    from <- (chunk - 1) * chunk_size + 1
    to <- min(chunk * chunk_size, n_rows)
    # drop = FALSE: keep a data.frame even when there is only one column.
    r <- requested[from:to, , drop = FALSE]

    pred <- predict(tgp1, r, pred.n = FALSE, zcov = FALSE, krige = FALSE)
    # Each output row is the requested point followed by its predicted mean.
    out_data <- cbind(r, pred$ZZ.mean)
    # Rows are always appended: an already existing output file is extended,
    # not truncated.
    write.table(out_data, file = output, row.names = FALSE, col.names = FALSE,
                append = TRUE)

    # Report progress whenever a new 10% boundary has been crossed. Nothing
    # is printed after the final chunk.
    if (chunk < n_chunks) {
        decile <- (to * 10) %/% n_rows
        if (decile > last_decile) {
            print(paste("Progress", decile * 10, "%"))
            last_decile <- decile
        }
    }
}
