#! /usr/bin/env python

"""
%prog mcfile.aida [outputfile.aida]
%prog --no-rivet-refs datafile.aida mcfile.aida [outputfile.aida]

Remove bins in Rivet-generated MC AIDA files which are actually binning gaps in
the reference histogram. Rivet's histogramming system currently has no way to
leave gaps between bins, hence this clean-up script.

If the output file is not specified, the input MC file will be overwritten.
"""

import sys

## Refuse to run on interpreters older than Python 2.4.0.
if sys.version_info[:3] < (2, 4, 0):
    sys.stdout.write("rivet scripts require Python version >= 2.4.0... exiting\n")
    sys.exit(1)


import os, tempfile

class Inputdata:
    """Collection of histograms parsed from one or more flat ``.dat`` files.

    Attributes:
        histos: dict mapping each histogram title (the text that follows
            ``BEGIN HISTOGRAM`` on a ``#`` line) to its parsed ``Histogram``.
        description: dict whose ``'DrawOnly'`` entry lists the histogram
            titles in the order in which they were read.
    """

    def __init__(self, filenames):
        """Read every histogram found in the given flat file(s).

        ``filenames`` is either a single path or an iterable of paths, each
        given WITHOUT its ``.dat`` extension (it is appended here).  Errors
        raised by ``open()`` for a missing or unreadable file propagate to
        the caller.
        """
        self.histos = {}
        self.description = {}
        self.description['DrawOnly'] = []
        ## A lone path must be wrapped in a list.  The explicit str test is
        ## needed because Python 3 strings have __iter__, so the hasattr
        ## check alone would iterate over the path character by character.
        if isinstance(filenames, str) or not hasattr(filenames, "__iter__"):
            filenames = [filenames]
        for filename in filenames:
            f = open(filename + '.dat', 'r')
            ## try/finally so the handle is released even if parsing fails.
            try:
                for line in f:
                    if line.startswith('#') and 'BEGIN HISTOGRAM' in line:
                        title = line.split('BEGIN HISTOGRAM', 1)[1].strip()
                        self.description['DrawOnly'].append(title)
                        ## Histogram consumes lines from the same file
                        ## object up to its END HISTOGRAM marker.
                        self.histos[title] = Histogram(f)
            finally:
                f.close()


class Histogram:
    """One histogram read from a flat ``.dat`` file.

    Attributes:
        description: dict of the ``key=value`` header lines of the histogram.
        data: list of bins, each a dict with the keys ``'LowEdge'``,
            ``'UpEdge'``, ``'Content'`` and ``'Error'``; the error is a
            two-element list ``[minus, plus]``.
    """

    def __init__(self, f):
        """Parse one histogram from the line iterator ``f``."""
        self.read_input(f)

    def read_input(self, f):
        """Fill ``description`` and ``data`` from the lines of ``f``.

        Reading stops at the first ``#`` line containing ``END HISTOGRAM``;
        other ``#`` lines are ignored.  Lines containing ``=`` are stored as
        description entries.  Remaining lines are split on tabs: four columns
        are low edge, high edge, content and a symmetric error, five columns
        carry separate minus and plus errors.  Lines with any other column
        count are skipped.
        """
        self.description = {}
        self.data = []
        for line in f:
            if line.startswith('#'):
                if 'END HISTOGRAM' in line:
                    break
                continue
            line = line.rstrip()
            if '=' in line:
                key, value = line.split('=', 1)
                self.description[key] = value
                continue
            fields = line.split('\t')
            if len(fields) not in (4, 5):
                continue
            values = [float(field) for field in fields]
            ## With four columns values[-1] is values[3] again, which gives
            ## the symmetric error; with five it is the separate plus error.
            self.data.append({'LowEdge': values[0],
                              'UpEdge':  values[1],
                              'Content': values[2],
                              'Error':   [values[3], values[-1]]})

    def _xml_escape(self, text):
        """Return ``text`` made safe for use inside an XML attribute value."""
        ## The ampersand must be replaced first, otherwise the entities
        ## produced by the later replacements would be escaped a second time.
        return text.replace('&', '&amp;').replace('>', '&gt;') \
                   .replace('<', '&lt;').replace('"', '&quot;')

    def write_datapoint(self, f, xval, xerr, yval, yerr):
        """Write one ``<dataPoint>`` element to ``f``.

        ``xerr`` is used for both x errors; ``yerr`` is a ``[minus, plus]``
        pair.
        """
        f.write('    <dataPoint>\n')
        f.write('      <measurement errorPlus="%e" value="%e" errorMinus="%e"/>\n' %(xerr, xval, xerr))
        f.write('      <measurement errorPlus="%e" value="%e" errorMinus="%e"/>\n' %(yerr[1], yval, yerr[0]))
        f.write('    </dataPoint>\n')

    def write_datapointset_header(self, f):
        """Write the opening ``<dataPointSet>`` tag, annotation and dimensions.

        Title and axis labels default to empty strings.  The ``FullPath``
        annotation reads the module-level ``filename`` variable set by the
        script.

        Raises:
            ValueError: if the histogram has no ``AidaPath`` entry.
        """
        title = self._xml_escape(self.description.setdefault('Title', ''))
        xlabel = self._xml_escape(self.description.setdefault('XLabel', ''))
        ylabel = self._xml_escape(self.description.setdefault('YLabel', ''))
        path = self.description.setdefault('AidaPath', None)
        if path is None:
            ## Without a path neither the name nor the directory of the data
            ## point set can be written; say so instead of failing on None.
            raise ValueError("histogram has no AidaPath, cannot write its dataPointSet")
        path = self._xml_escape(path)
        name = path.split('/')[-1]
        ## Cut only the trailing name off the path.  A plain str.replace would
        ## also remove the same text from earlier path components.
        directory = os.path.abspath(path[:len(path) - len(name)])
        f.write('  <dataPointSet name="%s" dimension="2"\n' % name)
        f.write('    path="%s" title="%s">\n' % (directory, title))
        f.write('    <annotation>\n')
        f.write('      <item key="Title" value="%s" sticky="true"/>\n' % title)
        f.write('      <item key="XLabel" value="%s" sticky="true"/>\n' % xlabel)
        f.write('      <item key="YLabel" value="%s" sticky="true"/>\n' % ylabel)
        f.write('      <item key="AidaPath" value="%s" sticky="true"/>\n' % path)
        f.write('      <item key="FullPath" value="/%s.aida%s" sticky="true"/>\n'
                % (self._xml_escape(filename.split('/')[-1]), path))
        f.write('    </annotation>\n')
        f.write('    <dimension dim="0" title="%s" />\n' % xlabel)
        f.write('    <dimension dim="1" title="%s" />\n' % ylabel)

    def write_datapointset_footer(self, f):
        """Write the closing ``</dataPointSet>`` tag to ``f``."""
        f.write('  </dataPointSet>\n')

    def write_datapointset(self, f):
        """Write the complete histogram as one ``<dataPointSet>`` element.

        Each bin becomes a data point at the bin centre whose x error is half
        the bin width.
        """
        self.write_datapointset_header(f)
        for bindata in self.data:
            xval = 0.5*(bindata['UpEdge'] + bindata['LowEdge'])
            xerr = 0.5*(bindata['UpEdge'] - bindata['LowEdge'])
            self.write_datapoint(f, xval, xerr, bindata['Content'], bindata['Error'])
        self.write_datapointset_footer(f)

    def _strip_unmatched_bins(self, refbins):
        """Return the bins of ``self.data`` whose edges match ``refbins`` in order.

        The bins are walked in step with the reference bins.  A bin whose
        edges differ from the next unmatched reference bin is dropped, and
        everything after the last reference bin has been matched is dropped
        as well.  Neither ``self.data`` nor ``refbins`` is modified.
        """
        kept = []
        for i, mcbin in enumerate(self.data):
            if len(kept) == len(refbins):  ## MC has additional bins above the upper edge of the refdata
                break
            refbin = refbins[len(kept)]
            if mcbin['LowEdge'] == refbin['LowEdge'] and \
               mcbin['UpEdge'] == refbin['UpEdge']:
                kept.append(mcbin)
            else:
                logging.debug('Deleted bin %d' % i)
        return kept

    def remove_gaps(self):
        """Drop the bins that have no counterpart in the reference histogram.

        The reference is looked up as ``/REF<AidaPath>`` in the module-level
        ``refdata`` object set by the script.  Histograms without an
        ``AidaPath`` or without a reference are left untouched, as are
        histograms with the same number of bins as their reference.
        """
        ## Only look at histograms which are present in the reference file:
        try:
            aidapath = self.description['AidaPath']
        except KeyError:
            return
        try:
            refhist = refdata.histos['/REF%s' % aidapath]
        except KeyError:
            return

        ## Check for differences in the binning and remove superfluous MC bins:
        if len(refhist.data) == len(self.data):
            return
        numrm = abs(len(self.data) - len(refhist.data))
        if numrm != 1:
            plural = "s"
        else:
            plural = ""
        logging.info("Stripping %d bin%s from %s" % (numrm, plural, aidapath))
        self.data = self._strip_unmatched_bins(refhist.data)


## Command line parsing
import logging
from optparse import OptionParser

parser = OptionParser(usage=__doc__)

## Both reference-data switches store into USE_RIVETREFS, which stays True
## unless --no-rivet-refs is given.
parser.add_option("-R", "--rivet-refs",
                  action="store_true", dest="USE_RIVETREFS", default=True,
                  help="use the Rivet reference data files for comparison (default)")
parser.add_option("--no-rivet-refs",
                  action="store_false", dest="USE_RIVETREFS", default=True,
                  help="don't use the Rivet reference data files for comparison")

## The verbosity switches store a logging level into LOGLEVEL (INFO by default).
for _short, _long, _level, _help in (
        ("-v", "--verbose", logging.DEBUG, "print debug (very verbose) messages"),
        ("-q", "--quiet", logging.WARNING, "be very quiet")):
    parser.add_option(_short, _long, action="store_const", const=_level,
                      dest="LOGLEVEL", default=logging.INFO, help=_help)

(opts, args) = parser.parse_args()

## Configure logging
logging.basicConfig(format="%(message)s", level=opts.LOGLEVEL)


## When the Rivet reference data is requested, make sure up front that the
## 'rivet' Python module can be imported and asked for its ref paths.
if opts.USE_RIVETREFS:
    try:
        import rivet
        rivet.getAnalysisRefPaths()
    except Exception:
        ## 'except Exception' rather than a bare 'except' so that this handler
        ## does not also claim interpreter-exit requests.  The underlying
        ## reason is kept available at debug level.
        logging.error("Could not find the Rivet ref paths because the 'rivet' Python module could not be loaded")
        logging.debug("Reason: %s" % sys.exc_info()[1])
        sys.exit(1)


## Work out the reference, input and output files from the positional
## arguments.  The output file is optional and defaults to the MC input file,
## which is then overwritten.
if opts.USE_RIVETREFS:
    ## Usage: mcfile.aida [outputfile.aida]
    if len(args) < 1:
        logging.error("Must specify at least the MC input file")
        sys.exit(1)
    if len(args) > 2:
        logging.error("Maximum of two arguments with the -R argument flag!")
        sys.exit(1)
    REFFILES = []
    INFILE = args[0]
    OUTFILE = args[0]
    if len(args) == 2:
        OUTFILE = args[1]
else:
    ## Usage: datafile.aida mcfile.aida [outputfile.aida]
    if len(args) < 2:
        logging.error("Must specify at least the reference file and the MC input file (or use the -R option)")
        sys.exit(1)
    if len(args) > 3:
        ## This branch is only reached with --no-rivet-refs, so name that flag.
        logging.error("Maximum of three arguments with the --no-rivet-refs flag!")
        sys.exit(1)
    REFFILES = [args[0]]
    INFILE = args[1]
    OUTFILE = args[1]
    if len(args) == 3:
        OUTFILE = args[2]


## Convert the aida input files to flat files we can parse:
import subprocess


def _strip_aida_suffix(path):
    """Return ``path`` without a trailing ".aida" extension, if it has one."""
    ## Only the suffix is removed; ".aida" elsewhere in the path is kept.
    if path.endswith(".aida"):
        return path[:-len(".aida")]
    return path


def _convert_to_flat(aidabase, outdir):
    """Run aida2flat on ``aidabase + ".aida"`` and store the result in ``outdir``.

    The converter is the ``aida2flat`` program located next to this script.
    Returns the path of the written flat file without its ".dat" extension,
    which is the form ``Inputdata`` expects.  Logs an error and exits if the
    converter cannot be started or reports failure.
    """
    converter = os.path.join(os.path.dirname(os.path.realpath(sys.argv[0])), "aida2flat")
    datbase = os.path.join(outdir, os.path.basename(aidabase))
    out = open(datbase + ".dat", "w")
    try:
        try:
            ## Argument list instead of a shell string, so that file names
            ## with spaces or shell metacharacters are passed through intact.
            status = subprocess.call([converter, aidabase + ".aida"], stdout=out)
        except OSError:
            logging.error("Could not run %s: %s" % (converter, sys.exc_info()[1]))
            sys.exit(1)
    finally:
        out.close()
    if status != 0:
        ## Carrying on would parse an empty or partial flat file and could end
        ## up overwriting the MC input with an empty AIDA file.
        logging.error("%s failed on %s.aida (exit status %d)" % (converter, aidabase, status))
        sys.exit(1)
    return datbase


tempdir = tempfile.mkdtemp('.gap_removal')
try:
    ## 'filename' must keep naming the MC file: it is read again when the
    ## FullPath annotation of each histogram is written.
    filename = _strip_aida_suffix(INFILE)
    mcdata = Inputdata(_convert_to_flat(filename, tempdir))

    ## Build list of Rivet ref files
    if opts.USE_RIVETREFS:
        anas = set()
        for hpath in mcdata.description['DrawOnly']:
            ## The analysis name is the first component of the histogram path.
            anas.add(hpath[1:].split("/")[0])
        for a in anas:
            apath = rivet.findAnalysisRefFile(a+".aida")
            if apath:
                ## No 'break' here: every analysis in the MC file gets its
                ## reference file, not just the first one found.
                REFFILES.append(apath)

    ## Run over ref files
    refdatfiles = []
    for rf in REFFILES:
        refdatfiles.append(_convert_to_flat(_strip_aida_suffix(rf), tempdir))
    refdata = Inputdata(refdatfiles)
finally:
    ## Clean up, also when one of the steps above failed
    for i in os.listdir(tempdir):
        os.unlink(os.path.join(tempdir, i))
    os.rmdir(tempdir)


## Remove gap bins
for hpath in mcdata.description['DrawOnly']:
    mcdata.histos[hpath].remove_gaps()


## Write the new aida file with removed gap bins:
f = open(OUTFILE, 'w')
try:
    f.write('<?xml version="1.0" encoding="ISO-8859-1" ?>\n')
    f.write('<!DOCTYPE aida SYSTEM "http://aida.freehep.org/schemas/3.3/aida.dtd">\n')
    f.write('<aida version="3.3">\n')
    f.write('  <implementation version="1.1" package="FreeHEP"/>\n')
    for hpath in mcdata.description['DrawOnly']:
        mcdata.histos[hpath].write_datapointset(f)
    f.write('</aida>\n')
finally:
    ## Actually call close(): a bare 'f.close' only names the method and
    ## leaves the file open.
    f.close()
