#! /usr/bin/env python

"""\
%prog - generate comparison plots

USAGE:
 %prog [options] aidafile1[:'PlotOption1=Value':'PlotOption2=Value'] [path/to/aidafile2 ...]

where the plot options are described in the make-plots manual in the HISTOGRAM
section.

TODO:
 * two ref datas shouldn't be possible... but it does happen!
 * regex selector
 * ask/force overwrite modes
"""

import sys
## Refuse to run on interpreters older than Python 2.4.
if sys.version_info[:3] < (2, 4, 0):
    ## This is a fatal start-up error, so report it on stderr like the other
    ## fatal exits in this script rather than mixing it into normal output
    sys.stderr.write("rivet scripts require Python version >= 2.4.0... exiting\n")
    sys.exit(1)

def sanitiseString(s):
    """Return a copy of s with every '#' and '%' preceded by a backslash.

    The characters '_', '^' and '$' are passed through unchanged.
    NOTE(review): presumably this protects characters that are special to
    the LaTeX-based plotting back end -- confirm against make-plots.
    """
    for special in ('#', '%'):
        s = s.replace(special, '\\' + special)
    return s

from lighthisto import Histo, PlotParser


## Try to load faster but non-standard cElementTree module
try:
    import xml.etree.cElementTree as ET
except ImportError:
    try:
        import cElementTree as ET
    except ImportError:
        try:
            import xml.etree.ElementTree as ET
        except ImportError:
            ## Only a missing module means "not installed": any other error
            ## (including Ctrl-C) is allowed to propagate instead of being
            ## misreported as a missing parser
            sys.stderr.write("Can't load the ElementTree XML parser: please install it!\n")
            sys.exit(1)


## Function to make output dirs
def mkoutdir(outdir):
    """Make sure that outdir exists as a writable directory.

    A missing directory is created with os.makedirs, including any missing
    parent directories.

    Raises Exception, after logging the same message as an error, if the
    directory cannot be created (for example because the path already exists
    as a regular file) or if it is not writable.
    """
    if not os.path.isdir(outdir):
        try:
            os.makedirs(outdir)
        except OSError:
            ## Something else may have created the directory between the check
            ## and the makedirs call: only fail if it is still not there
            if not os.path.isdir(outdir):
                msg = "Can't make output directory '%s'" % outdir
                logging.error(msg)
                raise Exception(msg)
    if not os.access(outdir, os.W_OK):
        msg = "Can't write to output directory '%s'" % outdir
        logging.error(msg)
        raise Exception(msg)

class AIDAreadException(Exception):
    """Raised by getHistos when an input file cannot be used as an AIDA file."""
    
def getHistos(aidafile):
    '''Get the histograms stored in an AIDA file, indexed by name.

    Returns a tuple (histos, titles, xlabels, ylabels) of dictionaries, all
    keyed by histogram path:
     * a path starting with "/REF" (in any letter case) marks a reference
       data histogram, and that leading "/REF" is removed from the key;
     * any other path containing more than two slashes has its first path
       component removed.
    Each Histo object gets two extra attributes: isdata (True for reference
    data) and expt (the part of the key before the first underscore, without
    the leading slash).

    Title and axis labels of a non-reference histogram always replace earlier
    entries for the same key; those of a reference histogram are only used if
    the key has no entry yet.

    Raises AIDAreadException, after logging an error, if the file name does
    not end in ".aida", the file is not readable, or it cannot be parsed.
    '''
    if not re.match(r'.*\.aida$', aidafile):
        logging.error("Error: input file '%s' is not an AIDA file" % aidafile)
        raise AIDAreadException
    aidafilepath = os.path.abspath(aidafile)
    if not os.access(aidafilepath, os.R_OK):
        logging.error("Error: cannot read from %s" % aidafile)
        raise AIDAreadException

    histos, titles, xlabels, ylabels = {}, {}, {}, {}
    try:
        tree = ET.parse(aidafilepath)
    except Exception:
        ## Any parsing or I/O failure is reported uniformly, but interpreter
        ## exits and keyboard interrupts are no longer swallowed
        logging.error("Error: Cannot parse input file '%s' as AIDA file" % aidafile)
        raise AIDAreadException

    refprefix = "/REF"
    for dps in tree.findall("dataPointSet"):
        ## Get this histogram's path name
        dpsname = os.path.join(dps.get("path"), dps.get("name"))
        h = Histo.fromDPS(dps)
        ## Is it a data histo?
        h.isdata = dpsname.upper().startswith(refprefix)
        if h.isdata:
            ## Remove exactly the leading prefix that was just detected: the
            ## test above ignores case, and a later "/REF" inside the path or
            ## histogram name must not be removed as well
            dpsname = dpsname[len(refprefix):]
            if dpsname not in titles:
                titles[dpsname] = h.title
                xlabels[dpsname] = h.xlabel
                ylabels[dpsname] = h.ylabel
        else:
            ## Drop the first path component of deeper MC paths
            if dpsname.count('/') > 2:
                dpsname = '/' + dpsname.split('/', 2)[-1]
            titles[dpsname] = h.title
            xlabels[dpsname] = h.xlabel
            ylabels[dpsname] = h.ylabel
        h.expt = dpsname.split("_")[0][1:]
        ## Hard-coded cosmetic handling for the D0 experiment name!
        if h.expt == "D0":
            ## Backslash written explicitly: "\O" is not a valid string escape
            h.expt = "D\\O{}"
        histos[dpsname] = h
    return histos, titles, xlabels, ylabels


def getCommandLineOptions():
    """Build and return the optparse parser for this script.

    The parser uses the module docstring as its usage text and defines the
    general options plus three option groups: "Plot style", "Selective
    plotting" and "Verbosity control". The caller is expected to call
    parse_args() on the returned parser.
    """
    from optparse import OptionParser, OptionGroup
    parser = OptionParser(usage=__doc__)

    ## General options
    parser.add_option("-R", "--rivet-refs", action="store_true", dest="RIVETREFS", default=True,
                      help="use Rivet reference data files (default)")
    parser.add_option("--no-rivet-refs", action="store_false", dest="RIVETREFS", default=True,
                      help="don't use Rivet reference data files")
    parser.add_option("-o", "--outdir", dest="OUTDIR", default=".",
                      help="write data files into this directory")
    parser.add_option("--hier-out", action="store_true", dest="HIER_OUTPUT", default=False,
                      help="write output dat files into a directory hierarchy which matches the analysis paths")
    parser.add_option("--plotinfodir", action="append", dest="PLOTINFODIR", default=["."],
                      help="directory which may contain plot header information (in addition "
                      "to standard Rivet search paths)")
    parser.add_option("--no-rmgapbins", action="store_false", dest="RMBINS", default=True,
                      help="disable attempting to remove 'gap' bins from MC histos when they "
                      "don't appear in the ref file")

    ## Options controlling how the plots look
    style_opts = OptionGroup(parser, "Plot style")
    style_opts.add_option("--refid", dest="REF_ID", default="REF",
                          help="ID of reference data set (file path for non-REF data)")
    style_opts.add_option("--linear", action="store_true", dest="LINEAR", default=False,
                          help="plot with linear scale")
    style_opts.add_option("--logarithmic", action="store_false", dest="LINEAR", default=False,
                          help="plot with logarithmic scale (default behaviour)")
    style_opts.add_option("--mc-errs", action="store_true", dest="MC_ERRS", default=False,
                          help="show vertical error bars on the MC lines")
    style_opts.add_option("--no-ratio", action="store_true", dest="NORATIO", default=False,
                          help="disable the ratio plot")
    style_opts.add_option("--rel-ratio", action="store_true", dest="RATIO_DEVIATION", default=False,
                          help="show the ratio plots scaled to the ref error")
    style_opts.add_option("--abs-ratio", action="store_false", dest="RATIO_DEVIATION", default=False,
                          help="show the ratio plots with an absolute scale")
    style_opts.add_option("--no-plottitle", action="store_true", dest="NOPLOTTITLE", default=False,
                          help="don't show the plot title on the plot "
                          "(useful when the plot description should only be given in a caption)")
    style_opts.add_option("--style", dest="STYLE", default="default",
                          help="change plotting style: default|bw|talk")
    style_opts.add_option("-c", "--config", action="append", dest="CONFIGFILES", default=["~/.make-plots"],
                          help="additional plot config file(s). Settings will be included in the output configuration.")
    parser.add_option_group(style_opts)

    ## Options controlling which histograms get written
    select_opts = OptionGroup(parser, "Selective plotting")
    select_opts.add_option("--show-single", dest="SHOW_SINGLE", default="mc",
                           choices=("no", "ref", "mc", "all"),
                           help="control if a plot file is made if there is only one dataset to be plotted "
                           "[default=%default]. If the value is 'no', single plots are always skipped, for 'ref' and 'mc', "
                           "the plot will be written only if the single plot is a reference plot or an MC "
                           "plot respectively, and 'all' will always create single plot files.\n The 'ref' and 'all' values "
                           "should be used with great care, as they will also write out plot files for all reference "
                           "histograms without MC traces: combined with the -R/--rivet-refs flag, this is a great way to "
                           "write out several thousand irrelevant reference data histograms!")
    select_opts.add_option("--show-mc-only", "--all", action="store_true", dest="SHOW_IF_MC_ONLY", default=False,
                           help="make a plot file even if there is only one dataset to be plotted and "
                           "it is an MC one. Deprecated and will be removed: use --show-single instead, which overrides this.")
    # select_opts.add_option("-l", "--histogram-list", dest="HISTOGRAMLIST",
    #                        default=None, help="specify a file containing a list of histograms to plot, in the format "
    #                        "/ANALYSIS_ID/histoname, one per line, e.g. '/DELPHI_1996_S3430090/d01-x01-y01'.")
    select_opts.add_option("-m", "--match", action="append", dest="PATHPATTERNS",
                           help="Only write out histograms whose $path/$name string matches these regexes. The argument "
                           "may also be a text file.")
    select_opts.add_option("-M", "--unmatch", action="append", dest="PATHUNPATTERNS",
                           help="Exclude histograms whose $path/$name string matches these regexes")
    parser.add_option_group(select_opts)

    ## Both verbosity flags write to the same LOGLEVEL destination
    verbosity_opts = OptionGroup(parser, "Verbosity control")
    verbosity_opts.add_option("-q", "--quiet", action="store_const", dest="LOGLEVEL",
                              const=logging.WARNING, default=logging.INFO,
                              help="Suppress normal messages")
    verbosity_opts.add_option("-v", "--verbose", action="store_const", dest="LOGLEVEL",
                              const=logging.DEBUG, default=logging.INFO,
                              help="Add extra debug messages")
    parser.add_option_group(verbosity_opts)
    return parser


##################################################################

if __name__ == "__main__":
    import os, re, logging

    PROGPATH = sys.argv[0]
    PROGNAME = os.path.basename(PROGPATH)

    ## Try to rename the process on Linux
    try:
        import ctypes
        libc = ctypes.cdll.LoadLibrary('libc.so.6')
        libc.prctl(15, 'compare-histos', 0, 0, 0)
    except Exception:
        pass

    ## Try to use Psyco optimiser
    try:
        import psyco
        psyco.full()
    except ImportError:
        pass

    ## Get Rivet data dir
    rivet_data_dirs = None
    try:
        import rivet
        rivet_data_dirs = rivet.getAnalysisRefPaths()
    except Exception, e:
        sys.stderr.write(PROGNAME + " requires the 'rivet' Python module\n")
        logging.debug(str(e))
        sys.exit(1)

    parser = getCommandLineOptions()
    opts, args = parser.parse_args()


    ## Initialise regex list variables
    import re
    if opts.PATHPATTERNS is None:
        opts.PATHPATTERNS = []
    ## See if a pattern option is an observable file and append its contents to
    ## the pattern list. The patterns read from files are collected separately
    ## and added afterwards: extending the list while looping over it would
    ## re-examine every line just read, and would never terminate if a
    ## pattern file listed its own name.
    import os
    filepatterns = []
    for m in opts.PATHPATTERNS:
        ## Only regular files are read, so that a regex which happens to name
        ## a directory (e.g. ".") is not opened
        if os.path.isfile(m):
            f = open(m, "r")
            try:
                for line in f:
                    line = line.strip()
                    ## Skip comments and blank lines: an empty regex would
                    ## match every histogram path
                    if line and not line.startswith("#"):
                        filepatterns.append(line)
            finally:
                f.close()
    ## The command-line arguments themselves are kept as patterns too, so a
    ## regex that coincides with an existing file name still applies
    opts.PATHPATTERNS = [re.compile(r) for r in opts.PATHPATTERNS + filepatterns]
    if opts.PATHUNPATTERNS is None:
        opts.PATHUNPATTERNS = []
    opts.PATHUNPATTERNS = [re.compile(r) for r in opts.PATHUNPATTERNS]


    ## Translate the SHOW_SINGLE choice into the two flags tested when deciding
    ## whether to write a plot that has only one dataset. Any value given via
    ## the deprecated --show-mc-only flag is overwritten here.
    opts.SHOW_IF_MC_ONLY = opts.SHOW_SINGLE in ("all", "mc")
    opts.SHOW_IF_REF_ONLY = opts.SHOW_SINGLE in ("all", "ref")


    ## Extend the PLOTINFO search paths with the standard Rivet locations and
    ## with the directory of every input file (skipping ones already listed)
    opts.PLOTINFODIR += rivet.getAnalysisPlotPaths()
    for infile in args:
        indir = os.path.abspath(os.path.dirname(infile))
        if indir in opts.PLOTINFODIR:
            continue
        opts.PLOTINFODIR.append(indir)


    ## Set up logging at the verbosity chosen on the command line
    logging.basicConfig(format="%(message)s", level=opts.LOGLEVEL)


    ## Build the list of (colour, line style) pairs which the MC histograms
    ## cycle through, plus any style directives for the plot and histo sections
    HISTSTYLES = ''
    PLOTSTYLES = ''
    COLORS = ('red', 'blue', '{[rgb]{0.12,0.57,0.14}}', 'magenta')
    if opts.MC_ERRS:
        ## If using MC errors, dashed lines often aren't visible, so we put them
        ## to the back of the style list
        LINESTYLES = ('solid', 'dotted', 'dashdotted', 'dashed')
    else:
        LINESTYLES = ('solid', 'dashed', 'dashdotted', 'dotted')
    ## Default ordering: run through all colours before changing line style
    STYLES = [(c, ls) for ls in LINESTYLES for c in COLORS]
    if opts.STYLE == 'talk':
        PLOTSTYLES += 'PlotSize=8,6\n'
        HISTSTYLES += 'LineWidth=1pt\n'
    elif opts.STYLE == 'bw':
        PLOTSTYLES += 'RatioPlotErrorBandColor=black!10\n'
        COLORS = ('black!90', 'black!50', 'black!30')
        ## Greyscale ordering: run through all line styles before changing shade
        STYLES = [(c, ls) for c in COLORS for ls in LINESTYLES]


    ## Get file names and labels
    import glob
    FILES = []
    REFFILES = []
    FILEOPTIONS = {}
    ## Collect every AIDA file from the Rivet reference data directories
    if opts.RIVETREFS and rivet_data_dirs:
        for d in rivet_data_dirs:
            REFFILES += glob.glob(os.path.join(d, "*.aida"))
    ## Each argument has the form path[:option[:option...]]
    for a in args:
        asplit = a.split(":")
        path = asplit[0]
        FILES.append(path)
        fileopts = asplit[1:]
        if fileopts:
            FILEOPTIONS[path] = []
        for fileopt in fileopts:
            ## Add "Title" if there is no = sign before math mode
            if "=" not in fileopt or ("$" in fileopt and fileopt.index("$") < fileopt.index("=")):
                fileopt = "Title=%s" % fileopt
            FILEOPTIONS[path].append(fileopt)

    ## Ignore duplicates. Sorting keeps the order of the reference files, and
    ## hence which one is used first, the same from run to run.
    REFFILES = sorted(set(REFFILES))

    ## Check that the requested files are sensible
    if not FILES:
        logging.error(parser.get_usage())
        ## sys.exit rather than the bare exit(), which is only provided by
        ## the site module
        sys.exit(2)

    ## Handle a request for a reference dataset other than REF
    if opts.REF_ID != "REF":
        if not os.access(os.path.abspath(opts.REF_ID), os.R_OK):
            logging.error("Error: cannot read reference file %s" % opts.REF_ID)
            sys.exit(2)

    ## Load every input and reference file. HISTOS and LABELS are keyed by file
    ## and then by histogram name; TITLES, XLABELS and YLABELS are keyed by
    ## histogram name only, with later files replacing earlier entries.
    HISTOS, TITLES, XLABELS, YLABELS, LABELS = {}, {}, {}, {}, {}
    NAMES = set()
    MCNAMES = set()
    allfiles = FILES + REFFILES
    ## Every file gets an (initially empty) entry, even if reading it fails
    for f in allfiles:
        HISTOS[f] = {}
        LABELS[f] = {}
    for f in allfiles:
        try:
            histos, titles, xlabels, ylabels = getHistos(f)
        except AIDAreadException:
            ## getHistos has already logged the problem: skip this file
            continue
        for n, h in histos.iteritems():
            NAMES.add(n)
            HISTOS[f][n] = h
            if h.isdata:
                ## Reference data is labelled by experiment name where known
                if h.expt:
                    legend = "%s data" % h.expt
                else:
                    legend = "data"
            else:
                ## MC is labelled by the file name without its .aida extension
                MCNAMES.add(n)
                stem = re.sub(r'(.*)\.aida$', r'\1', os.path.basename(f))
                legend = "MC (%s)" % stem
            LABELS[f][n] = legend
        TITLES.update(titles)
        XLABELS.update(xlabels)
        YLABELS.update(ylabels)

    # ## Choose histos - use all histos with MC data, or restrict with a list read from file
    # if opts.HISTOGRAMLIST is not None:
    #     newnames = []
    #     try:
    #         f = open(opts.HISTOGRAMLIST, 'r')
    #     except:
    #         logging.error("Cannot open histo list file %s" % opts.HISTOGRAMLIST)
    #         exit(2)
    #     hnames = set()
    #     for line in f:
    #         stripped = line.strip()
    #         if len(stripped) == 0 or stripped.startswith("#"):
    #             continue
    #         hnames.add(stripped.split()[0])
    #     f.close()
    #     NAMES = NAMES.intersection(hnames)
    #     MCNAMES = MCNAMES.intersection(hnames)


    ## Apply the -m/--match and -M/--unmatch patterns: a histogram is kept if it
    ## matches at least one match pattern (when any are given) and none of the
    ## unmatch patterns
    acceptednames = set()
    for path in NAMES.union(MCNAMES):
        if opts.PATHPATTERNS:
            for regex in opts.PATHPATTERNS:
                if regex.search(path):
                    break
            else:
                ## No match pattern fitted this path
                continue
        vetoed = False
        for regex in opts.PATHUNPATTERNS:
            if regex.search(path):
                vetoed = True
                break
        if not vetoed:
            acceptednames.add(path)
    NAMES = NAMES.intersection(acceptednames)
    MCNAMES = MCNAMES.intersection(acceptednames)

    ## Pre-emptively reduce number of files to iterate through
    ## (assuming, reasonably, that there is only one ref file per histo):
    ## unless reference-only plots are wanted, only names with MC are considered
    if opts.SHOW_IF_REF_ONLY:
        activenames = NAMES
    else:
        activenames = MCNAMES

    ## Write out histos
    num_written = 0
    plotparser = PlotParser(opts.PLOTINFODIR, opts.CONFIGFILES)
    for name in sorted(activenames):
        logging.debug("Writing histos for plot '%s'" % name)

        ## Determine the title
        try:
            title = TITLES[name]
        except:
            title = name
        title = sanitiseString(title)
        xlabel = XLABELS[name]
        ylabel = YLABELS[name]

        ## Identify contributing data files for this histo
        activemcfiles = []
        activereffiles = []
        for f in REFFILES:
            if HISTOS.has_key(f):
                d = HISTOS[f]
                if d.has_key(name):
                    if d[name].isdata:
                        activereffiles.append(f)
        for f in FILES:
            if HISTOS.has_key(f):
                d = HISTOS[f]
                if d.has_key(name):
                    if d[name].isdata:
                        activereffiles.append(f)
                    else:
                        activemcfiles.append(f)
        activefiles = activereffiles + activemcfiles
        #print activereffiles
        #print activemcfiles
        #print activefiles
        if len(activefiles) == 0:
            logging.warning("Something's wrong... somehow there's no data for histogram '%s'!" % name)
            continue

        if len(activefiles) < 2:
            if len(activereffiles) == 0 and not opts.SHOW_IF_MC_ONLY:
                if not opts.RIVETREFS:
                    logging.warning("Skipping histo '%s' since only one (MC) plot is present" % name)
                continue
            if len(activemcfiles) == 0 and not opts.SHOW_IF_REF_ONLY:
                if not opts.RIVETREFS:
                    logging.warning("Skipping histo '%s' since only one (reference) plot is present" % name)
                continue

        ## Identify reference file for this histo
        ref = opts.REF_ID
        if ref == "REF" and activereffiles:
            ref = activereffiles[0]
        if not ref in activefiles:
            ref = activefiles[0]


        ## Header
        try:
            headers = plotparser.getHeaders(name)
        except ValueError, err:
            logging.debug("Could not get plot headers: %s" % err)
            headers = {}

        is2dim = False
        if HISTOS[activefiles[0]][name].is2dim():
            is2dim = True
        if is2dim and len(activefiles)>1:
            logging.warning("More than one histogram file for 2-dim histogram '%s', only using first file" % name)
            activefiles = [activefiles[0]]

        drawonlystr = ""
        for hfile in activefiles:
            drawonlystr += hfile.replace(' ','_') + HISTOS[hfile][name].fullPath().replace(' ','_') + " "
        paramdefaults = {}
        if not is2dim:
            paramdefaults = {"Title" : title,
                             "XLabel" : xlabel,
                             "YLabel" : ylabel,
                             "Legend" : "1",
                             "LogY" : "%d" % int((len(HISTOS[ref][name].getBins()) > 1) and not opts.LINEAR),
                             "DrawOnly" : drawonlystr,
                             "RatioPlot" : "%d" % int(not opts.NORATIO),
                             "XTwosidedTicks" : "1",
                             "YTwosidedTicks" : "1",
                             "RatioPlotReference" : "%s%s" % (ref.replace(' ','_'), HISTOS[ref][name].fullPath().replace(' ','_'))}
        else:
            paramdefaults = {"Title" : title,
                             "XLabel" : xlabel,
                             "YLabel" : ylabel,
                             "Legend" : "1",
                             "XTwosidedTicks" : "1",
                             "YTwosidedTicks" : "1"}


        if opts.RATIO_DEVIATION:
            paramdefaults["RatioPlotMode"] = "deviation"

        ## Behaviour changes if "ref" data is just another MC file
        if not HISTOS[ref][name].isdata:
            ## Don't attempt to remove "gap" bins
            opts.RMBINS = False
            ## Different ratio plot label (not MC/Data)
            paramdefaults["RatioPlotYLabel"] = "Ratio"
            if opts.RATIO_DEVIATION:
                paramdefaults["RatioPlotYLabel"] = "Deviation"


        headstr  = "# BEGIN PLOT\n"
        headstr += PLOTSTYLES
        for param, default in paramdefaults.iteritems():
            if param not in headers:
                headers[param] = default
        for key, val in headers.iteritems():
            if key != "Title" or not opts.NOPLOTTITLE:
                directive = "%s=%s\n" % (key, val)
                headstr += directive
        headstr += "# END PLOT\n"

        ## Special
        try:
            special = plotparser.getSpecial(name)
        except ValueError, err:
            logging.error("Could not get histo specials: %s" % err)
            special = {}
        if special:
            headstr += "\n"
            headstr += "# BEGIN SPECIAL %s\n" % name
            headstr += special
            headstr += "# END SPECIAL\n"

        ## Write histos
        try:
            histopts = plotparser.getHistogramOptions(name)
        except ValueError, err:
            logging.error("Could not get histo options: %s" % err)
            histopts = {}
        histstrs = []
        i = 0
        # logging.debug("Active files: %s" % activefiles)

        for hfile in activefiles:
            histstr = '# BEGIN HISTOGRAM %s%s\n' % (hfile.replace(' ','_'), HISTOS[hfile][name].fullPath().replace(' ','_'))
            if HISTOS[hfile][name].isdata:
                histstr += HISTSTYLES
                histstr += "ErrorBars=1\n"
                histstr += "PolyMarker=*\n"
                histstr += "Title=%s\n" % LABELS[hfile][name]
            else:
                histstr += HISTSTYLES
                color, style = STYLES[i % len(STYLES)]
                if opts.MC_ERRS:
                    histstr += "ErrorBars=1\n"
                if not is2dim:
                    histstr += 'LineColor=%s\n' % color
                    histstr += 'LineStyle=%s\n' % style
                else:
                    histstr += 'LineStyle=none\n'
                    histstr += 'LineWidth=0pt\n'
                histstr += 'Title=%s\n' % LABELS[hfile][name]
                for key, val in histopts.iteritems():
                    #if key == 'ErrorBars' and opts.MC_ERRS:
                    #    continue
                    histstr += "%s=%s\n" % (key, val)
                i += 1
            if hfile in FILEOPTIONS:
                for option in FILEOPTIONS[hfile]:
                    histstr += '%s\n' % option
            def eq(a, b):
                if a == 0 and b == 0:
                    return True
                return abs((b-a)/(b+a)) < 10e-3
            numskipped = 0
            #print hfile, name, HISTOS[hfile][name].numBins(), HISTOS[ref][name].numBins()
            for ibin, bin in enumerate(HISTOS[hfile][name].getBins()):
                ## Skip writing this MC bin if the bin edges don't match, and the MC histo has too many bins
                if opts.RMBINS:
                    xmin, xmax = bin.getXRange()
                    #print HISTOS[hfile][name].numBins(), HISTOS[ref][name].numBins()
                    if hfile != ref and HISTOS[hfile][name].numBins() > HISTOS[ref][name].numBins():
                        rxmin, rxmax = HISTOS[ref][name].getBin(ibin-numskipped).getXRange()
                        if not eq(rxmin, xmin) or not eq(rxmax, xmax):
                            numskipped += 1
                            #print numskipped, (HISTOS[hfile][name].numBins() - HISTOS[ref][name].numBins())
                            assert(numskipped <= abs(HISTOS[hfile][name].numBins() - HISTOS[ref][name].numBins()))
                            continue
                histstr += '%s\n' % (bin.asFlat())
            histstr += "# END HISTOGRAM\n"
            histstrs.append(histstr)

        ## Choose output file name and dir
        if opts.HIER_OUTPUT:
            outdir = os.path.dirname(os.path.join(opts.OUTDIR, name[1:]))
            outfilename = '%s.dat' % os.path.basename(name)
        else:
            outdir = opts.OUTDIR
            outfilename = '%s.dat' % name.replace('/', "_")[1:]

        ## Write file
        mkoutdir(outdir)
        outfilepath = os.path.join(outdir, outfilename)
        logging.debug("Writing histo '%s' to %s" % (name, outfilepath))
        f = open(outfilepath, 'w')
        f.write(headstr + "\n" + "\n".join(histstrs))
        f.close()
        num_written += 1

    logging.info("Wrote %d histo files" % num_written)
