#!/usr/bin/env python3

# Expand a switch in this format:
#  switch (key) {  // mklookup
#     case "string1": code1;
#     case "string2": code2;
#     default: code3;
#  }
# into C code that matches the given strings, using nested switches on `len`
# and on the characters of `key`.
# break; statements in each case are ignored, and effectively added if missing.

import argparse
import logging
import sys
import re
import os
from collections import defaultdict
from contextlib import contextmanager

log = logging.getLogger("mklookup")


class Fail(Exception):
    """
    Error that the script entry point reports as a log message before exiting
    with status 1.
    """


class Builder:
    """
    Writer for the generated C lookup code.

    A Builder prints lines to `file` at a fixed indentation. Nested constructs
    are printed through child builders that share the same file and the same
    indentation step. The generated code refers to the C identifiers `key`
    and `len`.
    """

    def __init__(self, indent, indent_step=4, file=sys.stdout):
        """
        :param indent: number of spaces written before each printed line
        :param indent_step: number of spaces added for each nesting level
        :param file: file-like object that receives the output
        """
        self.indent = indent
        self.indent_step = indent_step
        self.file = file

    def _child(self, extra=None):
        """
        Return a Builder indented `extra` more spaces than this one (one
        indent_step if None), keeping the same indent_step and output file.
        """
        if extra is None:
            extra = self.indent_step
        return Builder(self.indent + extra, indent_step=self.indent_step, file=self.file)

    @staticmethod
    def _char_literal(char):
        """
        Format `char` as a C character literal, escaping single quotes.

        Backslashes are passed through unchanged: keys are used exactly as
        they are written in the source file.
        """
        return "'{}'".format(str(char).replace("'", "\\'"))

    @contextmanager
    def nest(self, indent=None):
        """
        Yield a child builder indented by `indent` spaces (default: one step).
        """
        yield self._child(indent)

    @contextmanager
    def print_switch(self, condition):
        """
        Print `switch (condition) {`, yield a builder for the switch body,
        then print the closing brace.
        """
        self.print("switch ({}) {{".format(condition))
        yield self._child()
        self.print("}")

    @contextmanager
    def print_case(self, condition, unquoted=False):
        """
        Print a `case` label, yield a builder for its body, then print the
        `break;` that ends the case.

        The condition is printed as a character literal unless `unquoted` is
        true, in which case it is printed as is.
        """
        if unquoted:
            self.print("case {}:".format(condition))
        else:
            self.print("case {}:".format(self._char_literal(condition)))
        sub = self._child()
        yield sub
        sub.print("break;")

    def print_key_match(self, pos, val):
        """
        Print the opening `if (...) {` that tests whether `key` contains `val`
        starting at offset `pos`.

        A single character is compared directly; longer values use memcmp.
        """
        if len(val) == 1:
            self.print("if (key[{}] == {}) {{".format(pos, self._char_literal(val)))
        else:
            self.print('if (memcmp(key + {}, "{}", {}) == 0) {{'.format(pos, val, len(val)))

    def print_leaf_check(self, key, code, default, pos):
        """
        Print the final test for a single remaining candidate: run `code` if
        the rest of `key` (from `pos`) matches, `default` otherwise.
        """
        self.print_key_match(pos, key[pos:])
        self.print(code, indent=self.indent_step)
        self.print("} else {")
        self.print(default, indent=self.indent_step)
        self.print("}")

    def print(self, *args, indent=0):
        """
        Print `args` on one line, preceded by this builder's indentation plus
        `indent` extra spaces.

        Only the start of the output is indented: newlines embedded in the
        arguments are written unchanged.
        """
        self.file.write(" " * (self.indent + indent))
        print(*args, file=self.file)

    def build_len_switch(self, keys, default):
        """
        Print a `switch (len)` with one case per key length, each containing
        the code that matches the keys of that length.

        :param keys: dict mapping each key string to the code to run for it
        :param default: code to run when no key matches
        """
        by_len = defaultdict(dict)
        for k, v in keys.items():
            by_len[len(k)][k] = v

        with self.print_switch("len") as b1:
            for length, same_len_keys in sorted(by_len.items()):
                with b1.print_case(length, unquoted=True) as b2:
                    b2.build_prefix_switch(same_len_keys, default, pos=0, length=length)
            b1.print("default:", default)

    def build_prefix_switch(self, keys, default, pos, length):
        """
        Print the code that tells apart `keys` (all `length` characters long)
        assuming that their first `pos` characters have already been matched.
        """
        if pos == length:
            # Nothing left to compare: the keys have been matched in full
            for val in keys.values():
                self.print(val)
            return

        by_prefix = defaultdict(dict)
        for k, v in keys.items():
            by_prefix[k[pos]][k] = v

        if len(by_prefix) > 1:
            with self.print_switch("key[{}]".format(pos)) as b1:
                for prefix, subkeys in by_prefix.items():
                    with b1.print_case(prefix) as b2:
                        b2.build_prefix_switch(subkeys, default, pos=pos + 1, length=length)
                b1.print("default:", default)
        elif len(keys) == 1:
            # Only one key left: compare the rest of it in one go
            key, val = next(iter(keys.items()))
            self.print_leaf_check(key, val, default, pos)
        else:
            # Many keys left, all sharing at least the character at pos
            self.build_common_prefix_switch(keys, default, pos, length)

    def build_common_prefix_switch(self, keys, default, pos, length):
        """
        Print a test for the prefix shared by all `keys` from `pos` onwards,
        with the code that tells the keys apart nested inside it, and
        `default` in the else branch.
        """
        # Look for the common prefix
        prefix = os.path.commonprefix([k[pos:] for k in keys])
        self.print_key_match(pos, prefix)
        # The body is nested whatever the prefix length, so that the generated
        # code is indented consistently
        with self.nest() as b:
            b.build_prefix_switch(keys, default, pos + len(prefix), length)
        self.print("} else {")
        self.print(default, indent=self.indent_step)
        self.print("}")


class Table:
    """
    Parser block that collects the cases of one `switch (key) { // mklookup`.

    Each `case "string":` line starts a new key. The lines that follow it are
    appended to that key's code. A `default:` line starts the default code,
    and the lines that follow it are appended to the default. When the
    closing brace of the switch is found, parsing is handed back to a new
    Verbatim block.
    """
    re_case_line = re.compile(r'^\s*case\s*"(?P<key>[^"]+)"\s*:\s*(?P<code>.*)$')
    # The code after `default:` may be empty, with the code on the next lines
    re_default_line = re.compile(r'^\s*default\s*:\s*(?P<code>.*)$')
    re_trailing_break = re.compile(r"\s*break\s*;\s*$")

    def __init__(self, indent):
        """
        :param indent: the whitespace found before the `switch` keyword. The
                       switch ends at the first line that starts with this
                       same whitespace followed by a closing brace.
        """
        # Code for each key, by key string
        self.keys = {}
        # Code to run when no key matches
        self.default = ""
        # Key and code of the case currently being read, if any
        self.cur_key = None
        self.cur_code = None
        # True while the lines being read belong to the default code
        self.in_default = False
        self.re_end_switch = re.compile("^" + re.escape(indent) + r"}\s*")
        self.indent = indent

    def _add_key_code(self):
        """
        Store the case being read, without its trailing `break;`, if any.
        """
        self.keys[self.cur_key] = self.re_trailing_break.sub("", self.cur_code)
        self.cur_key = None
        self.cur_code = None

    def _start_key(self, key, code):
        """
        Start reading a new case, storing the previous one if there was one.
        """
        if self.cur_key is not None:
            self._add_key_code()
        self.in_default = False
        self.cur_key = key
        self.cur_code = code

    def _add_code(self, code, lineno=None):
        """
        Append a line of code to the section being read: the default code
        after a `default:` line, otherwise the current case.

        Blank lines found before the first case are ignored.

        :raises Fail: if a non-blank line is found before the first case
        """
        if self.in_default:
            if self.default:
                self.default += "\n"
            self.default += code
        elif self.cur_code is None:
            # No case has been started yet: there is nothing to append to
            if code.strip():
                where = "line {}: ".format(lineno) if lineno is not None else ""
                raise Fail(where + "code found before the first case of a mklookup switch")
        else:
            if self.cur_code:
                self.cur_code += "\n"
            self.cur_code += code

    def _add_default(self, code):
        """
        Start (or continue) the default code, storing the case being read if
        there was one.
        """
        if self.cur_key:
            self._add_key_code()
        self.in_default = True
        if self.default:
            self.default += "\n"
        self.default += code

    def _end_switch(self):
        """
        Store the case being read, if any, at the end of the switch.
        """
        if self.cur_key:
            self._add_key_code()
        self.in_default = False

    def parse_line(self, lineno, line):
        """
        Process one line of the switch body.

        :returns: a new Verbatim block if `line` closes the switch, None if
                  the line was consumed by this block
        :raises Fail: if a non-blank line is found before the first case
        """
        if self.re_end_switch.match(line):
            self._end_switch()
            return Verbatim()

        mo = self.re_case_line.match(line)
        if mo:
            self._start_key(mo.group("key"), mo.group("code"))
            return None

        mo = self.re_default_line.match(line)
        if mo:
            self._add_default(mo.group("code"))
            return None

        self._add_code(line, lineno)
        return None

    def dump(self):
        """
        Print the keys collected so far, with their code, to standard output.
        """
        for k, v in self.keys.items():
            print("{}: {}".format(k, v))

    def print(self, file):
        """
        Write the generated lookup code to `file`, indented by as many spaces
        as there are characters in the indentation of the original switch.
        """
        builder = Builder(len(self.indent), file=file)
        builder.build_len_switch(self.keys, self.default)


class Verbatim:
    """
    Parser block holding source lines that are copied to the output unchanged.

    Lines are accumulated until the opening line of a
    `switch (key) { // mklookup` is found, at which point parsing is handed
    over to a new Table block.
    """
    re_switch_line = re.compile(r"^(?P<indent>\s*)switch\s*\(key\)\s*{\s*//\s*mklookup\s*$")

    def __init__(self):
        # Lines collected so far, without their line terminators
        self.lines = []

    def parse_line(self, lineno, line):
        """
        Process one line: return a new Table if the line opens a lookup
        switch, otherwise store the line and return None.
        """
        opening = self.re_switch_line.match(line)
        if opening:
            return Table(indent=opening.group("indent"))
        self.lines.append(line)
        return None

    def print(self, file):
        """
        Write the stored lines to `file`, one per line.
        """
        file.writelines(text + "\n" for text in self.lines)


def parse(fname):
    """
    Parse a file into a sequence of blocks.

    Each line, stripped of trailing whitespace, is given to the last block of
    the sequence. When a block returns a new block, that one is appended to
    the sequence and receives the lines that follow.

    :param fname: name of the file to read, decoded as UTF-8 like the output
                  is encoded. If None, lines are read from standard input.
    :returns: the list of blocks, in the order they appear in the input
    :raises Fail: if the input ends inside a lookup switch
    """
    blocks = [Verbatim()]

    def consume(lines):
        for lineno, line in enumerate(lines, start=1):
            next_block = blocks[-1].parse_line(lineno, line.rstrip())
            if next_block is not None:
                blocks.append(next_block)

    if fname is None:
        # The pathname is optional on the command line
        consume(sys.stdin)
    else:
        with open(fname, encoding="utf8") as fd:
            consume(fd)

    # A lookup switch hands parsing back to a Verbatim block when its closing
    # brace is found. Anything else at the end means the brace was never
    # seen, and the last case of the switch has not been stored.
    if not isinstance(blocks[-1], Verbatim):
        raise Fail("{}: unterminated mklookup switch at end of input".format(fname or "<stdin>"))
    return blocks


@contextmanager
def outfile(args):
    """
    Context manager yielding the file to write the output to.

    If `args.outfile` is set, the file is created (UTF-8 encoded) and closed
    on exit. If an exception is raised while it is in use, the partially
    written file is removed and the exception is propagated. Otherwise,
    standard output is used.
    """
    if not args.outfile:
        yield sys.stdout
        return

    # Open outside the try block: if the file cannot be opened, nothing has
    # been written and an existing file must not be removed
    fd = open(args.outfile, "wt", encoding="utf8")
    try:
        with fd:
            yield fd
    except Exception:
        # The file has been closed by the with statement at this point:
        # removing a file that is still open fails on some platforms
        if os.path.exists(args.outfile):
            os.unlink(args.outfile)
        raise


def main():
    """
    Command line entry point: parse the arguments, set up logging, parse the
    input into blocks and print each block to the selected output.
    """
    parser = argparse.ArgumentParser(description="build C code for a lookup table")
    parser.add_argument("--verbose", "-v", action="store_true", help="verbose output")
    parser.add_argument("--debug", action="store_true", help="debug output")
    parser.add_argument("-o", "--outfile", action="store", help="output file (default: stdout)")
    parser.add_argument("pathname", nargs="?", help="source file with the lookup table description")
    args = parser.parse_args()

    # --debug takes precedence over --verbose
    if args.debug:
        level = logging.DEBUG
    elif args.verbose:
        level = logging.INFO
    else:
        level = logging.WARN
    logging.basicConfig(
        level=level,
        stream=sys.stderr,
        format="%(asctime)-15s %(levelname)s %(message)s",
    )

    # The whole input is parsed before the output is opened
    parsed_blocks = parse(args.pathname)

    with outfile(args) as out:
        for parsed in parsed_blocks:
            parsed.print(out)


if __name__ == "__main__":
    # Report expected failures as a log message and a non-zero exit status
    try:
        main()
    except Fail as failure:
        log.error("%s", failure)
        raise SystemExit(1)
