#!/usr/bin/env python2

import argparse
import bisect
import collections
import logging
import os

import fast5

import signal
# Install the OS default action for SIGPIPE instead of Python's own handling.
# NOTE(review): presumably so that piping the printed table into a command
# that closes its input early (e.g. `head`) ends the process quietly rather
# than raising an IOError on write -- confirm intent.
signal.signal(signal.SIGPIPE, signal.SIG_DFL)

def construct_codewords(val_d, cw_d, n, s, rg):
    """Assign a bit-string codeword to every leaf in the subtree rooted at ``n``.

    A node whose id is below ``rg[1] + 1`` is treated as a leaf (a symbol);
    any other id is an internal node whose two child ids are stored in
    ``val_d[n][1]``.  Descending to the first child appends "0" to the
    prefix, descending to the second child appends "1".

    Args:
        val_d: dict mapping node id to ``[count, children]``.
        cw_d: dict filled in place with ``leaf id -> codeword``.
        n: id of the node to start from.
        s: codeword prefix accumulated on the way down to ``n``.
        rg: sequence whose element 1 is the largest leaf id.
    """
    first_internal = rg[1] + 1
    if n < first_internal:
        # Leaf: the prefix built so far is the complete codeword.
        cw_d[n] = s
        return
    children = val_d[n][1]
    for bit in (0, 1):
        construct_codewords(val_d, cw_d, children[bit], s + str(bit), rg)

def codeword_sort(s0, s1):
    """Compare two codewords cmp-style: shorter first, then lexicographic.

    Args:
        s0: first codeword string.
        s1: second codeword string.

    Returns:
        A negative int if ``s0`` sorts before ``s1``, a positive int if it
        sorts after, and 0 if the two codewords are equal.  For codewords of
        different length the result is the length difference; for equal
        lengths it is -1, 0 or 1.
    """
    if len(s0) != len(s1):
        return len(s0) - len(s1)
    # Equal inputs must compare as 0 for the comparator to be consistent.
    return (s0 > s1) - (s0 < s1)

def load_rw(args):
    """Histogram the differences between consecutive raw samples.

    For every file in ``args.input`` that has raw samples, each difference
    ``a[i] - a[i-1]`` is counted under its own value when it lies inside
    ``args.rw_range`` (inclusive), and under the escape key
    ``args.rw_range[0] - 1`` otherwise.  Files without raw samples are
    skipped.

    Args:
        args: namespace providing ``rw_range`` (``[low, high]``) and
            ``input`` (iterable of fast5 file names).

    Returns:
        dict mapping every key in ``low - 1 .. high`` to ``[count, None]``;
        every count starts at 1.
    """
    lo, hi = args.rw_range[0], args.rw_range[1]
    escape = lo - 1
    val_d = dict((v, [1, None]) for v in range(escape, hi + 1))
    for fn in args.input:
        f = fast5.File(fn)
        if f.have_raw_samples():
            samples = f.get_raw_int_samples()
            for i in range(len(samples) - 1):
                delta = samples[i + 1] - samples[i]
                key = delta if lo <= delta <= hi else escape
                val_d[key][0] += 1
    return val_d

def load_ed_len(args):
    """Histogram the ``length`` field of event-detection events.

    For every file in ``args.input`` that has event-detection events in
    group ``args.gr``, each event length is counted under its own value when
    it lies inside ``args.ed_len_range`` (inclusive), and under the escape
    key ``args.ed_len_range[0] - 1`` otherwise.  Files without such events
    are skipped.  The mean length is logged at debug level through the
    module-level ``logger`` when at least one event was seen.

    Args:
        args: namespace providing ``ed_len_range`` (``[low, high]``), ``gr``
            (group passed to the fast5 accessors) and ``input`` (iterable of
            fast5 file names).

    Returns:
        dict mapping every key in ``low - 1 .. high`` to ``[count, None]``;
        every count starts at 1.
    """
    lo, hi = args.ed_len_range[0], args.ed_len_range[1]
    escape = lo - 1
    val_d = {val: [1, None] for val in range(escape, hi + 1)}
    val_sum = 0
    val_cnt = 0
    for fn in args.input:
        f = fast5.File(fn)
        if not f.have_eventdetection_events(args.gr):
            continue
        for e in f.get_eventdetection_events(args.gr):
            val = e.length
            val_sum += val
            val_cnt += 1
            val_idx = val if lo <= val <= hi else escape
            val_d[val_idx][0] += 1
    # With no events at all there is no mean; skip the log line instead of
    # dividing by zero.
    if val_cnt:
        logger.debug("mean val: " + str(float(val_sum)/val_cnt))
    return val_d

def load_fq_qv(args):
    """Histogram the quality values of basecall FASTQ records.

    For every file in ``args.input`` that has a basecall FASTQ entry for
    strand 0 in group ``args.gr``, each character of the quality line is
    decoded as ``ord(c) - 33`` and counted under its own value when it lies
    inside ``args.fq_qv_range`` (inclusive), and under the escape key
    ``args.fq_qv_range[0] - 1`` otherwise.  Files without such an entry are
    skipped.  The mean quality value is logged at debug level through the
    module-level ``logger`` when at least one value was seen.

    Args:
        args: namespace providing ``fq_qv_range`` (``[low, high]``), ``gr``
            (group passed to the fast5 accessors) and ``input`` (iterable of
            fast5 file names).

    Returns:
        dict mapping every key in ``low - 1 .. high`` to ``[count, None]``;
        every count starts at 1.
    """
    lo, hi = args.fq_qv_range[0], args.fq_qv_range[1]
    escape = lo - 1
    val_d = {val: [1, None] for val in range(escape, hi + 1)}
    val_sum = 0
    val_cnt = 0
    for fn in args.input:
        f = fast5.File(fn)
        if not f.have_basecall_fastq(0, args.gr):
            continue
        fq = f.get_basecall_fastq(0, args.gr)
        # The quality string is the 4th *line* of the record.  Splitting on
        # lines rather than on any whitespace keeps this correct when the
        # header line contains spaces.
        # NOTE(review): assumes a single four-line FASTQ record -- confirm.
        qv = fq.strip().splitlines()[3]
        for c in qv:
            val = ord(c) - 33
            val_sum += val
            val_cnt += 1
            val_idx = val if lo <= val <= hi else escape
            val_d[val_idx][0] += 1
    # With no quality values at all there is no mean; skip the log line
    # instead of dividing by zero.
    if val_cnt:
        logger.debug("mean val: " + str(float(val_sum)/val_cnt))
    return val_d

def run_build_tree(val_d, rg):
    """Build a Huffman tree from a histogram and print the codeword table.

    The two lowest-count nodes are merged repeatedly until one root remains;
    ties on count are broken by node id because the work list holds
    ``(count, id)`` tuples.  The resulting codewords are printed to stdout
    between ``{`` and ``}`` lines, one ``"symbol", "codeword",`` line each,
    ordered by codeword length and then lexicographically.  The escape symbol
    ``rg[0] - 1`` is printed as ".".  Progress is logged at debug level
    through the module-level ``logger``.

    Args:
        val_d: dict mapping every symbol in ``rg[0] - 1 .. rg[1]`` to
            ``[count, None]``.  Modified in place: each internal node is
            added under an id from ``rg[1] + 1`` upwards as
            ``[combined_count, [child0, child1]]``.
        rg: ``[low, high]`` symbol range.
    """
    # initialize codes
    kw_l = sorted((val_d[val][0], val) for val in range(rg[0] - 1, rg[1] + 1))
    logger.debug("smallest frequency: " + str(kw_l[:10]))
    logger.debug("highest frequency: " + str(kw_l[-10:]))
    next_node = rg[1] + 1
    # main loop: merge the two least frequent nodes into a new internal node
    while len(kw_l) > 1:
        e = [kw_l.pop(0), kw_l.pop(0)]
        logger.debug('e=' + str(e))
        val_d[next_node] = [e[0][0] + e[1][0], [e[0][1], e[1][1]]]
        logger.debug('next_node=' + str(next_node) + ' val=' + str(val_d[next_node]))
        # keep the work list sorted so the next two minima stay at the front
        bisect.insort(kw_l, (val_d[next_node][0], next_node))
        next_node += 1
    # construct codewords; the surviving entry must be an internal node
    assert kw_l[0][1] > rg[1]
    cw_d = dict()
    construct_codewords(val_d, cw_d, kw_l[0][1], "", rg)
    # Key-based sort (length, then text) works on both Python 2 and 3; the
    # codewords are distinct, so the order is fully determined.
    cw_key_l = sorted(cw_d, key=lambda k: (len(cw_d[k]), cw_d[k]))
    print("{")
    for cw in cw_key_l:
        print('"%s", "%s",' % ('.' if cw == rg[0] - 1 else cw, cw_d[cw]))
    print("}")

if __name__ == "__main__":
    # Command-line entry point: parse options, set up logging, build the
    # histogram for the requested command and print its Huffman code table.
    description = """
    Toolkit for constructing Huffman codes for encoding fast5 files.
    """
    parser = argparse.ArgumentParser(description=description, epilog="")
    parser.add_argument("--log-level", default="info",
                        help="log level")
    parser.add_argument("--gr", type=str, default="",
                        help="Group")
    parser.add_argument("--rw-range", default=[-100,100],
                        help="Encoding range for raw sample differences.")
    parser.add_argument("--ed-skip-range", default=[0,1],
                        help="Encoding range for ed skip values.")
    parser.add_argument("--ed-len-range", default=[1,100],
                        help="Encoding range for ed length values.")
    parser.add_argument("--fq-qv-range", default=[0,31],
                        help="Encoding range for fq qv values.")
    parser.add_argument("command", choices=["rw", "ed-skip", "ed-len", "fq-qv"])
    parser.add_argument("input", nargs="*",
                        help="Fast5 file")
    args = parser.parse_args()

    numeric_log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(numeric_log_level, int):
        raise ValueError("Invalid log level: '%s'" % args.log_level)
    logging.basicConfig(level=numeric_log_level,
                        format="%(asctime)s %(name)s.%(levelname)s %(message)s",
                        datefmt="%Y/%m/%d %H:%M:%S")
    # Module-level logger; the load_* and run_build_tree functions use it.
    logger = logging.getLogger(os.path.basename(__file__))
    logger.debug("args: " + str(args))

    # Defaults are already [low, high] lists; a value given on the command
    # line arrives as a "LOW,HIGH" string and is converted here.  Only the
    # first two comma-separated fields are used.
    for opt in ("rw_range", "ed_skip_range", "ed_len_range", "fq_qv_range"):
        raw = getattr(args, opt)
        if isinstance(raw, list):
            continue
        flag = "--" + opt.replace("_", "-")
        try:
            bounds = [int(tok) for tok in raw.split(',')][:2]
        except ValueError:
            parser.error("%s: expected two integers LOW,HIGH, got '%s'" % (flag, raw))
        if len(bounds) != 2 or not bounds[0] < bounds[1]:
            parser.error("%s: expected LOW,HIGH with LOW < HIGH, got '%s'" % (flag, raw))
        setattr(args, opt, bounds)
    logger.debug("args: " + str(args))

    if args.command == "rw":
        d = load_rw(args)
        run_build_tree(d, args.rw_range)
    elif args.command == "ed-len":
        d = load_ed_len(args)
        run_build_tree(d, args.ed_len_range)
    elif args.command == "fq-qv":
        d = load_fq_qv(args)
        run_build_tree(d, args.fq_qv_range)
    else:
        # "ed-skip" is accepted by the parser but has no loader.
        logger.warning("command '%s' is not implemented; nothing to do", args.command)
