#! /usr/bin/env python
# ---------------------------------------------------------------
# Programmer(s): Steven Smith @ LLNL
#                David J. Gardner @ LLNL
# ---------------------------------------------------------------
# Acknowledgement: Based on code from Chris White @ LLNL
# ---------------------------------------------------------------
# SUNDIALS Copyright Start
# Copyright (c) 2002-2021, Lawrence Livermore National Security
# and Southern Methodist University.
# All rights reserved.
#
# See the top-level LICENSE and NOTICE files for details.
#
# SPDX-License-Identifier: BSD-3-Clause
# SUNDIALS Copyright End
# ---------------------------------------------------------------
# Sundials test runner for CTest. Note examples must be enabled
# to use the test runner (e.g. EXAMPLES_ENABLE_C=ON). 
#
# All output from the test runner will be collected in
# <build directory>/Testing/Temporary/LastTest.log
# ---------------------------------------------------------------

from __future__ import print_function

import os
import subprocess
import sys
import argparse
import re

def main():
    """Run one Sundials example under CTest and report pass/fail.

    Parses the command line, optionally configures and builds the example,
    runs the executable once for every matching entry returned by getInfo
    and, unless --nodiff is given, compares the captured output with the
    answer file.

    The function never returns normally: it calls sys.exit(0) when every
    step succeeded and sys.exit(1) otherwise.
    """
    parser = argparse.ArgumentParser(description='Sundials test runner for use in ctest.')

    # Required inputs
    parser.add_argument('--testname', dest='testName', type=str, required=True,
                        help='name of the test')
    parser.add_argument('--executablename', dest='executableName', type=str, required=True,
                        help='executable, including path')
    parser.add_argument('--outputdir', dest='outputDir', type=str, required=True,
                        help='path to output directory')

    # Optional inputs
    parser.add_argument('--answerdir', dest='answerDir', type=str,
                        help='path to answers directory')
    parser.add_argument('--answerfile', dest='answerFile', type=str,
                        help='answer file, filename only specify path using --answerdir')

    parser.add_argument('--runcommand', dest='runCommand', type=str,
                        help='command used to run test executable, eg. mpirun')
    parser.add_argument('--runargs', dest='runArgs', type=str,
                        help='command line arguments passed to test executable')

    parser.add_argument('--floatprecision', dest='floatPrecision', type=int,
                        help='precision for floating point failure comparison (number of digits)')
    parser.add_argument('--integerpercentage', dest='integerPercentage', type=int,
                        help='integer percentage difference for failure comparison')

    parser.add_argument('--nodiff', dest='CheckDiff', action='store_false',
                        help='do not diff the output file against answer file, only check the test return value')

    parser.add_argument('--builddir', dest='buildDir', type=str,
                        help='path to the example CMakeLists.txt file')
    parser.add_argument('--buildcmd', dest='buildCmd', type=str,
                        help='CMake command used to build the example')

    parser.add_argument('-v', '--verbose', dest='verbose', action='store_true',
                        help='run verbosely')

    args = parser.parse_args()

    # the configure command is built from both values, so one without the
    # other cannot work; report it as a usage error instead of crashing later
    if args.buildCmd and not args.buildDir:
        parser.error("--buildcmd requires --builddir")

    if args.verbose:
        print("testName=%s"%(args.testName))
        print("executableName=%s"%(args.executableName))
        print("answerDir=%s"%(args.answerDir))
        print("answerFile=%s"%(args.answerFile))
        print("outputDir=%s"%(args.outputDir))
        print("runCommand=%s"%(args.runCommand))
        print("runArgs=%s"%(args.runArgs))
        print("buildDir=%s"%(args.buildDir))
        print("buildCmd=%s"%(args.buildCmd))

    # regression test name and output directory
    testName = args.testName
    outDir = args.outputDir

    # path to Sundials test directory (the directory holding this script)
    testDir = os.path.dirname(sys.argv[0])

    # path to Sundials answer directory
    if args.answerDir:
        answerDir = args.answerDir
    else:
        answerDir = os.path.join(testDir, "answers")

    if args.verbose:
        print("testName %s"%testName)
        print("outDir %s"%outDir)
        print("testDir %s"%testDir)
        print("answerDir %s"%answerDir)

    # the output directory is not created here, it must already exist
    if not os.path.exists(outDir):
        error("Output directory does not exist, it must be created.", outDir)
        sys.exit(1)
    elif not os.path.isdir(outDir):
        error("Output directory exists but is not a directory, it must be deleted.", outDir)
        sys.exit(1)

    # initialize testing result to success
    success = True

    # configure/build the example if necessary
    if args.buildCmd:

        # configure/build output file
        buildLogPath = os.path.join(outDir, args.testName+".build")

        # configure/build commands
        cmakeCommand = [args.buildCmd, args.buildDir]
        makeCommand = ["make", args.testName]

        with open(buildLogPath, 'w') as buildLog:
            # configure the example
            returnCode = subprocess.call(cmakeCommand, stdout=buildLog, stderr=buildLog)
            if returnCode != 0:
                error("cmake exited with error code " + str(returnCode) + ".",
                      args.buildDir)
                success = False

            # build the example, but only if the configure step succeeded
            if success:
                returnCode = subprocess.call(makeCommand, stdout=buildLog, stderr=buildLog)
                if returnCode != 0:
                    error("make exited with error code " + str(returnCode) + ".",
                          args.testName)
                    success = False

    # exit if configure/build failed
    if not success:
        print("Test failed.")
        sys.exit(1)

    # get testing information
    testInfos = getInfo(testName, testDir)

    for testInfo in testInfos:

        if args.verbose:
            print("testInfo.arguments = ", testInfo.arguments)

        # If running under another program (e.g. mpirun)
        if args.runCommand:
            testCommand = args.runCommand.replace("\"", "").split(" ")
        else:
            testCommand = []

        testCommand = testCommand + [args.executableName]

        if args.runArgs:
            # Remove first and last quote
            userArgs = re.sub(r'^"|"$', '', args.runArgs)
            testCommand = testCommand + userArgs.split(" ")
            answerFileName = testName
        elif testInfo.arguments != "":
            testCommand = testCommand + testInfo.arguments.split(" ")
            answerFileName = testName + "_" + testInfo.arguments.replace(" ", "_")
        else:
            answerFileName = testName

        # if user supplies filename override default choices
        if args.answerFile:
            answerFileName = args.answerFile

        # every run of this test writes to the same output file
        outPath = os.path.join(outDir, args.testName+".out")

        answerPath = os.path.join(answerDir, answerFileName)

        # if user supplies precision info override the default choices
        if args.floatPrecision:
            testInfo.floatPrecision = args.floatPrecision

        if args.integerPercentage:
            testInfo.integerPercentage = args.integerPercentage

        if args.verbose:
            print("Starting test...")
            print("  Floating Point Precision = " + str(testInfo.floatPrecision))
            print("  Integer Precision = " + str(testInfo.integerPercentage))
            print("  Command " + ' '.join(testCommand))
            print("  Answer File = " + answerPath)
            print("  Output File = " + outPath + "\n")

        # run the test with stdout and stderr captured in the output file
        with open(outPath, 'w') as outFile:
            if args.verbose:
                print("Running command : " + ' '.join(testCommand))
            returnCode = subprocess.call(testCommand, stdout=outFile, stderr=outFile)

        # check test return value
        if returnCode != 0:
            error("Exited with error code " + str(returnCode) + ".", testName)
            success = False

        # compare test output against the answer file output
        if args.CheckDiff:
            if args.verbose:
                print("Comparing test output to answer file.")

            # only attempt the comparison when the answer file is usable,
            # otherwise opening it would raise instead of failing the test
            if not os.path.exists(answerPath):
                error("Answer file did not exist.", answerPath)
                success = False
            elif not os.path.isfile(answerPath):
                error("Answer file existed but was not a file.", answerPath)
                success = False
            elif not compare(outPath, answerPath, testInfo):
                success = False

    if not success:
        print("Test failed.")
        sys.exit(1)
    print("Test passed.")
    sys.exit(0)

def compare(outPath, answerPath, testInfo):
    """Compare a test output file with its answer file.

    Both files are read and passed through stripLines (which drops blank
    lines and surrounding whitespace) and are then checked line by line with
    compareLines, using the tolerances stored in testInfo.

    Returns True when the files have the same number of remaining lines and
    every pair of lines matches, False otherwise. Each differing pair of
    lines is printed.
    """
    with open(outPath, 'r') as outHandle:
        producedLines = stripLines(outHandle.readlines())
    with open(answerPath, 'r') as answerHandle:
        expectedLines = stripLines(answerHandle.readlines())

    producedCount = len(producedLines)
    expectedCount = len(expectedLines)

    # a differing number of lines is reported once, without a line by line diff
    if producedCount != expectedCount:
        error("Line count is not equal (blank lines ignored):\n"
              + str(producedCount) + ":" + outPath + "\n"
              + str(expectedCount) + ":" + answerPath)
        return False

    mismatches = 0
    for producedLine, expectedLine in zip(producedLines, expectedLines):
        if compareLines(producedLine, expectedLine, testInfo):
            continue
        mismatches += 1
        print("Difference:\n  Output: " + producedLine + "\n  Answer: " + expectedLine + "\n\n")

    if mismatches == 0:
        return True

    error(str(mismatches) + " line differences found.", outPath)
    return False

def compareLines(outLine, answerLine, testInfo):
    """Compare one output line with the matching answer line.

    Both lines are split with tokenizeLine and compared token by token:
      * two integer tokens are compared with intCompare using
        testInfo.integerPercentage,
      * otherwise two floating point tokens are compared with floatCompare
        using testInfo.floatPrecision,
      * any other pair of tokens must be identical strings.

    A numeric token paired with a non-numeric token therefore counts as a
    difference rather than being skipped.

    Returns True if the lines match and False otherwise; diagnostic
    "DEBUG" output is printed for the first mismatch.
    """
    outTokens = tokenizeLine(outLine)
    answerTokens = tokenizeLine(answerLine)
    if len(outTokens) != len(answerTokens):
        print("DEBUG: number of tokens differs")
        print(answerTokens)
        print("----------")
        print(outTokens)
        print("----------")
        return False

    for outValue, answerValue in zip(outTokens, answerTokens):
        # integer tokens: allow a percentage difference
        outIsInt, outInt = isInt(outValue)
        answerIsInt, answerInt = isInt(answerValue)
        if outIsInt and answerIsInt:
            if not intCompare(answerInt, outInt, testInfo.integerPercentage):
                print("DEBUG: int")
                return False
            continue

        # floating point tokens (including an integer paired with a float):
        # compare to the requested number of digits
        outIsFloat, outFloat = isFloat(outValue)
        answerIsFloat, answerFloat = isFloat(answerValue)
        if outIsFloat and answerIsFloat:
            if not floatCompare(answerFloat, outFloat, testInfo.floatPrecision):
                print("DEBUG: float")
                return False
            continue

        # everything else, including a number paired with text, must be an
        # exact string match
        if outValue != answerValue:
            print("DEBUG: str")
            print("outValue <%s>"%(outValue))
            print("answerValue <%s>"%(answerValue))
            return False

    return True

def isInt(value):
    """Try to convert value to an int.

    Returns the pair (True, converted value) when int() accepts the input
    and (False, 0) when int() raises ValueError.
    """
    try:
        parsed = int(value)
    except ValueError:
        return False, 0
    return True, parsed

def intCompare(answer, value, percentage):
    """Check whether value is within the allowed percentage of answer.

    answer     -- expected integer
    value      -- integer produced by the test
    percentage -- allowed difference in percent; 0 demands an exact match

    Returns True if the values are equal or if the difference relative to
    answer, truncated to a whole percent, does not exceed percentage.
    Otherwise an error message is printed and False is returned.
    """
    if answer == value:
        return True
    # A relative difference cannot be computed against a zero answer, so a
    # non-zero value is reported as a difference instead of dividing by zero.
    if percentage != 0 and answer != 0:
        percentageOff = abs(int(100 - (float(value)/float(answer) * 100)))
        if percentageOff <= percentage:
            return True
    print("Error: Integer difference found:\n  Value = " + str(value) + ", Answer = " + str(answer) + ", allowance = " + str(percentage) + "%")
    return False

def isFloat(value):
    """Try to convert value to a float.

    Returns the pair (True, converted value) when float() accepts the input
    and (False, 0) when float() raises ValueError.
    """
    try:
        parsed = float(value)
    except ValueError:
        return False, 0
    return True, parsed

def floatCompare(answer, value, precision):
    """Check whether two floats agree to the given number of decimals.

    With a precision of 0 the two numbers must be exactly equal; otherwise
    both are rounded to precision decimals before being compared. Prints an
    error message and returns False when they differ, returns True when
    they agree.
    """
    if precision == 0:
        matches = (answer == value)
    else:
        matches = (round(answer, precision) == round(value, precision))

    if matches:
        return True

    print("Error: Floating point difference found:\n Value = " + str(value) + ", Answer = " + str(answer) + ", allowance = " + str(precision) + " decimals")
    return False

def tokenizeLine(line):
    """Split a line into tokens for comparison.

    Tokens are separated by a single ';', ',', ':', '|', '(' or ')', by a
    run of whitespace, or by a run of '=' characters. Adjacent separators
    produce empty string tokens, exactly as re.split does.
    """
    return tokenizeLine.pattern.split(line)

# Precompile a pattern since it will be used many times. The single
# character separators are listed in a character class so that '|' and '('
# are each recognized on their own.
tokenizeLine.pattern = re.compile(r'[;,:|()]|\s+|=+')


def error(message, path=""):
    """Print an error message, prefixed with path when one is given."""
    prefix = "Error: "
    if path:
        prefix = prefix + path + ": "
    print(prefix + message)

def stripLines(lines):
    """Return the lines with surrounding whitespace removed.

    Lines that are empty after stripping are dropped; the order of the
    remaining lines is kept.
    """
    trimmed = (line.strip() for line in lines)
    return [text for text in trimmed if text != ""]

class TestInfo(object):
    """Settings for a single run of a test.

    floatPrecision    -- value stored for the floating point comparison
    integerPercentage -- value stored for the integer comparison
    arguments         -- command line argument string for the run
    """

    def __init__(self, floatPrecision, integerPercentage, arguments):
        # how to run the test
        self.arguments = arguments
        # how to compare its output
        self.integerPercentage = integerPercentage
        self.floatPrecision = floatPrecision

def getInfo(testName, testDir):
    """Collect the TestInfo entries for testName from the testInfo file.

    The file <testDir>/testInfo is read if it exists. Lines starting with
    '#' are skipped; every other line is split on ':' into at most four
    fields: test name, float precision, integer percentage and arguments.
    Test names are matched case-insensitively.

    Returns a list with one TestInfo per matching line. If nothing matches
    (or the file does not exist) the list holds a single default entry with
    float precision 4, integer percentage 10 and no arguments.
    """
    infoPath = os.path.join(testDir, "testInfo")
    wantedName = testName.lower()
    matches = []

    if os.path.exists(infoPath):
        with open(infoPath, 'r') as infoFile:
            entries = stripLines(infoFile.readlines())

        for entry in entries:
            if entry.startswith('#'):
                continue
            fields = entry.split(':', 3)
            if fields[0].strip().lower() != wantedName:
                continue
            floatPrecision = int(fields[1].strip())
            integerPercentage = int(fields[2].strip())
            matches.append(TestInfo(floatPrecision, integerPercentage,
                                    fields[3].strip()))

    if not matches:
        # Nothing was found in the testInfo file: run the test once without
        # arguments and with the default comparison values.
        matches.append(TestInfo(4, 10, ""))
    return matches

# Run the test runner only when this file is executed as a script, not when
# it is imported; main() ends the process through sys.exit.
if __name__ == "__main__":
    main()
