#!/usr/bin/env python3

import argparse
import logging
import sys
import os
import tempfile
import resource
import shlex
import subprocess
from contextlib import contextmanager


# Logger for messages emitted by the runner itself; each test case creates
# its own logger named after the test.
log = logging.getLogger("runtest")


class Fail(Exception):
    """
    Error whose message is meant to be shown to the user.

    The script entry point catches it, logs the message and exits with
    status 1.
    """


def set_resource_limits():
    """
    Cap the virtual memory of the current process at 1GiB.

    Both the soft and the hard limit are set. RLIMIT_VMEM is used when the
    resource module provides it, otherwise RLIMIT_AS. Any error is logged
    and otherwise ignored, so a failure to set the limit is never fatal.
    """
    try:
        one_gib = 1024 * 1024 * 1024
        vmem = getattr(resource, "RLIMIT_VMEM", None)
        which = resource.RLIMIT_AS if vmem is None else vmem
        resource.setrlimit(which, (one_gib, one_gib))
    except Exception:
        log.exception("setrlimit failed")


class Environment:
    """
    Settings shared by all test cases.

    Attributes:
        top_srcdir: resolved path of the parent of the directory that
            contains this script.
        python: Python interpreter used for ``.py`` tests, taken from the
            PYTHON environment variable, defaulting to sys.executable.
        env: copy of os.environ with the variables needed by the tests
            added; it is passed as the environment of every test process.
    """

    def __init__(self, libeatmydata="/usr/lib/libeatmydata/libeatmydata.so"):
        """
        Build the test environment.

        Args:
            libeatmydata: path of the eatmydata shared library. If the file
                exists it is added to LD_PRELOAD for the test processes.
        """
        self.top_srcdir = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
        self.python = os.environ.get("PYTHON", sys.executable)
        self.env = dict(os.environ)
        self.env["WREPORT_EXTRA_TABLES"] = os.path.join(self.top_srcdir, "tables")
        self.env["WREPORT_TESTDATA"] = os.path.join(self.top_srcdir, "extra")
        self.env["DBA_REPINFO"] = os.path.join(self.top_srcdir, "tables/repinfo.csv")
        self.env["DBA_TABLES"] = os.path.join(self.top_srcdir, "tables")
        self.env["DBA_TESTDATA"] = os.path.join(self.top_srcdir, "extra")
        self.env["DBA_INSECURE_SQLITE"] = "1"
        self.env["PYTHONPATH"] = os.path.join(self.top_srcdir, "python/.libs")

        # Run under eatmydata if available
        if os.path.exists(libeatmydata):
            preload = os.environ.get("LD_PRELOAD")
            if preload:
                # Keep whatever was already preloaded, putting eatmydata
                # first
                self.env["LD_PRELOAD"] = "{} {}".format(libeatmydata, preload)
            else:
                self.env["LD_PRELOAD"] = libeatmydata

    def make_test(self, cmd):
        """
        Create the test case for the test at path ``cmd``.

        Paths ending in ``.py`` become a PyTestCase, anything else a
        BinTestCase.
        """
        if cmd.endswith(".py"):
            return PyTestCase(cmd, env=self)
        else:
            return BinTestCase(cmd, env=self)


class TestCase:
    """
    Base class for a single test to run.

    Subclasses must set ``self.cmd`` to the command line, as a list of
    strings, that runs the test.

    Attributes:
        name: name of the test, also used as the name of its logger.
        env: Environment providing top_srcdir and the process environment.
        success: None until the test has been run, then True if the
            command exited with status 0 and False otherwise.
    """

    def __init__(self, name, env, *args, **kw):
        self.name = name
        self.log = logging.getLogger(self.name)
        self.env = env
        self.initial_dir = os.getcwd()
        self.success = None

    def is_skipped(self):
        """Return True if this test should not be run."""
        return False

    def get_cmd(self):
        """
        Return the full command line for the test, as a new list.

        If the DEBUGGER environment variable is set, the command is run
        through ``libtool --mode=execute`` followed by the debugger command.
        If ARGS is set, its shell-split contents are appended at the end.
        """
        cmd = list(self.cmd)

        debugger = os.environ.get("DEBUGGER", None)
        if debugger is not None:
            # Try to debug the libtool executable, if present
            new_cmd = [os.path.join(self.env.top_srcdir, "libtool"), "--mode=execute"]
            new_cmd.extend(shlex.split(debugger))
            new_cmd.extend(cmd)
            cmd = new_cmd

        args = os.environ.get("ARGS", None)
        if args is not None:
            cmd.extend(shlex.split(args))

        return cmd

    def run(self):
        """
        Run the test in a fresh temporary directory and record the outcome
        in ``self.success``.

        Nothing is run, and ``self.success`` stays None, if the test is
        skipped. If the PAUSE environment variable is set, an interactive
        shell is started in the test directory after the test has finished.
        """
        if self.is_skipped():
            self.log.debug("skipped.")
            return

        cmd = self.get_cmd()
        cmd_fmt = " ".join(shlex.quote(x) for x in cmd)
        with tempfile.TemporaryDirectory(prefix="dballe-test-") as tmpdir:
            self.log.debug("%s: running in %s", cmd_fmt, tmpdir)
            # Run the same command line that was just logged
            res = subprocess.call(cmd, cwd=tmpdir, env=self.env.env, preexec_fn=set_resource_limits)

            if os.environ.get("PAUSE", None) is not None:
                print("Post-test inspection requested.")
                print("Exit this shell to cleanup the test environment.")
                subprocess.call("/bin/bash", cwd=tmpdir, env=self.env.env)

            self.success = res == 0
            if not self.success:
                self.log.error("%s failed with code %d", cmd_fmt, res)
            else:
                self.log.info("%s: success", cmd_fmt)

        # NOTE: the temporary directory is always removed when the block
        # above exits. Keeping it when a PRESERVE environment variable is
        # set is not implemented.


class BinTestCase(TestCase):
    """
    Test case that runs an executable directly.

    The test is named after the file name of the executable.
    """

    def __init__(self, cmd, *args, **kw):
        test_name = os.path.basename(cmd)
        super().__init__(test_name, *args, **kw)
        self.cmd = [cmd]


class PyTestCase(TestCase):
    """
    Test case that runs a script with the Python interpreter configured in
    the environment.

    The test is named after the file name of the script, minus its last
    three characters.
    """

    def __init__(self, cmd, *args, **kw):
        script = os.path.basename(cmd)
        super().__init__(script[:-3], *args, **kw)
        self.cmd = [self.env.python, cmd]

    def is_skipped(self):
        """
        Return True if the ARGS environment variable has at least one word
        and its first word is not the name of this test.
        """
        raw = os.environ.get("ARGS", None)
        words = shlex.split(raw) if raw is not None else []
        return bool(words) and words[0] != self.name


@contextmanager
def term_red():
    """
    Context manager that turns text on the terminal red while it is active.

    The ANSI escape sequences are written to standard error, and only when
    standard error is a terminal; otherwise the context manager does
    nothing.
    """
    # isatty() on the stream also works for replacement streams that have no
    # underlying file descriptor
    if not sys.stderr.isatty():
        yield
        return

    sys.stderr.buffer.write(b"\x1b[31m")
    try:
        yield
    finally:
        # Restore the default attributes even if the body raised, so the
        # terminal is not left red
        sys.stderr.buffer.write(b"\x1b[0m")


@contextmanager
def term_green():
    """
    Context manager that turns text on the terminal green while it is
    active.

    The ANSI escape sequences are written to standard error, and only when
    standard error is a terminal; otherwise the context manager does
    nothing.
    """
    # isatty() on the stream also works for replacement streams that have no
    # underlying file descriptor
    if not sys.stderr.isatty():
        yield
        return

    sys.stderr.buffer.write(b"\x1b[32m")
    try:
        yield
    finally:
        # Restore the default attributes even if the body raised, so the
        # terminal is not left green
        sys.stderr.buffer.write(b"\x1b[0m")


def main():
    """
    Command line entry point: run the requested tests and report the
    results.

    Each test named on the command line is looked up relative to
    --command-dir and run in turn. A summary is logged at INFO level, so it
    is only visible with --verbose or --debug. The process exits with
    status 1 unless every requested test ran and succeeded.
    """
    parser = argparse.ArgumentParser(
            description="run DB-All.e unit tests")
    parser.add_argument("--verbose", "-v", action="store_true", help="verbose output")
    parser.add_argument("--debug", action="store_true", help="debug output")
    parser.add_argument("-C", "--command-dir", action="store", default=".",
                        help="directory where tests to run are found (default: .)")
    parser.add_argument("test", nargs="+", help="test executable(s) to run")
    args = parser.parse_args()

    log_format = "%(asctime)-15s %(levelname)s %(name)s: %(message)s"
    if args.debug:
        level = logging.DEBUG
    elif args.verbose:
        level = logging.INFO
    else:
        level = logging.WARNING
    logging.basicConfig(level=level, stream=sys.stderr, format=log_format)

    env = Environment()
    tests = [
        env.make_test(os.path.realpath(os.path.join(args.command_dir, test)))
        for test in args.test
    ]

    for test in tests:
        try:
            test.run()
        except Exception:
            # A test that cannot even be started is left with success set
            # to None, and is counted as not run
            test.log.exception("could not run test")

    success = sum(1 for test in tests if test.success is True)
    not_run = sum(1 for test in tests if test.success is None)
    fail = len(tests) - success - not_run

    summary = "%d tests requested, %d succeeded, %d failed, %d did not run."
    counts = (len(tests), success, fail, not_run)
    if fail > 0:
        with term_red():
            log.info(summary, *counts)
    elif not_run == 0:
        with term_green():
            log.info(summary, *counts)
    else:
        # Nothing failed but some tests did not run: still report it, since
        # the exit status below will be nonzero
        log.info(summary, *counts)

    if success != len(tests):
        sys.exit(1)


if __name__ == "__main__":
    # Fail carries a message for the user: report it without a traceback and
    # exit with a nonzero status
    try:
        main()
    except Fail as err:
        log.error("%s", err)
        sys.exit(1)
