#! /usr/bin/env python
# Copyright 2010, 2011, 2013, 2014 Martin C. Frith
# Read MAF-format alignments: write them in other formats.
# Seems to work with Python 2.x, x>=6

# By "MAF" we mean "multiple alignment format" described in the UCSC
# Genome FAQ, not e.g. "MIRA assembly format".

from __future__ import print_function

from itertools import *
import gzip
import math, optparse, os, signal, sys

try:
    from future_builtins import map, zip
except ImportError:
    pass

def myOpen(fileName):
    """Return a readable text handle for fileName.

    "-" selects standard input, and a name ending in ".gz" is opened
    through gzip in text mode; anything else is opened as a plain file.
    """
    if fileName == "-":
        handle = sys.stdin
    elif fileName.endswith(".gz"):
        handle = gzip.open(fileName, "rt")  # xxx dubious for Python2
    else:
        handle = open(fileName)
    return handle

def maxlen(s):
    """Return the length of the longest item in s (s must not be empty)."""
    return max(len(item) for item in s)

def pairOrDie(sLines, formatName):
    """Return sLines unchanged if it holds exactly 2 sequences, else raise.

    formatName is only used to word the error message.
    """
    if len(sLines) == 2:
        return sLines
    raise Exception("for %s, each alignment must have 2 sequences" % formatName)

def isMatch(alignmentColumn):
    """Return True if every symbol in the column is the same, ignoring case."""
    # No special treatment of ambiguous bases/residues: same as NCBI BLAST.
    reference = alignmentColumn[0].upper()
    return all(symbol.upper() == reference for symbol in alignmentColumn[1:])

def gapRunCount(row):
    """Count the runs of consecutive gap ("-") characters in row."""
    return len([symbol for symbol, run in groupby(row) if symbol == "-"])

def alignmentRowsFromColumns(columns):
    """Transpose alignment columns into an iterator of row strings."""
    return ("".join(symbols) for symbols in zip(*columns))

def symbolSize(symbol, letterSize):
    """Get the sequence length consumed by one alignment symbol."""
    # A backslash advances by 1 and a slash steps back by 1;
    # every other symbol counts as one whole letter.
    specialSizes = {"\\": 1, "/": -1}
    return specialSizes.get(symbol, letterSize)

def insertSize(row, letterSize):
    """Get the length of sequence included in the row.

    Gap characters ("-") contribute nothing, each slash contributes -1,
    each backslash contributes +1, and every remaining symbol contributes
    letterSize.  These are the same per-symbol sizes that symbolSize uses,
    so the result is consistent with it for any letterSize (the previous
    formula hard-coded corrections that were only right for letterSize 3).
    """
    slashes = row.count("/")
    backslashes = row.count("\\")
    letters = len(row) - row.count("-") - slashes - backslashes
    return letters * letterSize + backslashes - slashes

def matchAndInsertSizes(alignmentColumns, letterSizes):
    """Get sizes of gapless blocks, and of the inserts between them.

    Yields strings: a gapless block is given as its number of columns, and
    an insert as "D:I", where D and I are the unaligned lengths in the 1st
    and 2nd sequence respectively.
    """
    letterSizeA, letterSizeB = letterSizes
    delSize, insSize, subSize = 0, 0, 0
    for x, y in alignmentColumns:
        if x != "-" and y != "-":
            subSize += 1
            continue
        if subSize:
            # A gap begins right after a gapless block: emit whatever
            # preceded the gap, then start counting afresh.
            if delSize or insSize:
                yield "%s:%s" % (delSize, insSize)
            yield str(subSize)
            delSize, insSize, subSize = 0, 0, 0
        if x == "-":
            insSize += symbolSize(y, letterSizeB)
        else:
            delSize += symbolSize(x, letterSizeA)
    if delSize or insSize:
        yield "%s:%s" % (delSize, insSize)
    if subSize:
        yield str(subSize)

##### Routines for reading MAF format: #####

def updateEvalueParameters(opts, line):
    """Pick up "lambda=" and "K=" values from a header line.

    lambda sets opts.bitScoreA (lambda / ln 2) and K sets opts.bitScoreB
    (log2 of K).  Words that are not of the form name=number, or whose
    value cannot be used, are silently ignored.
    """
    for field in line.split():
        parts = field.split("=")
        if len(parts) != 2:
            continue
        key, text = parts
        try:
            value = float(text)
        except ValueError:
            continue
        if key == "lambda":
            opts.bitScoreA = value / math.log(2)
        elif key == "K":
            try:
                opts.bitScoreB = math.log(value, 2)
            except ValueError:  # K <= 0 has no logarithm
                pass

def scoreAndEvalue(aLine):
    """Extract the "score=" and "E=" values, as strings, from an "a" line.

    A value that is absent is returned as None; if a key is repeated, the
    last occurrence wins.
    """
    found = {"score=": None, "E=": None}
    for word in aLine.split():
        for prefix in ("score=", "E="):
            if word.startswith(prefix):
                found[prefix] = word[len(prefix):]
                break
    return found["score="], found["E="]

def mafInput(opts, lines):
    """Yield (aLine, sLines, qLines, pLines) for each alignment in lines.

    Alignments are separated by whitespace-only lines.  aLine, and the
    items of qLines and pLines, are raw input lines.  Each item of sLines
    is a tuple:
        (seqName, seqLen, strand, letterSize, beg, end, row)
    where letterSize is 3 if the row contains a slash or backslash, or has
    fewer non-gap symbols than the stated span, and 1 otherwise.

    "#" lines are passed to updateEvalueParameters, and echoed to standard
    output if opts.isKeepComments is set.
    """
    aLine, sLines, qLines, pLines = "", [], [], []
    for line in lines:
        lineType = line[0]
        if lineType == "s":
            junk, seqName, beg, span, strand, seqLen, row = line.split()
            beg, span, seqLen = int(beg), int(span), int(seqLen)
            isTranslated = ("\\" in row or "/" in row
                            or len(row) - row.count("-") < span)
            letterSize = 3 if isTranslated else 1
            sLines.append((seqName, seqLen, strand, letterSize,
                           beg, beg + span, row))
        elif lineType == "a":
            aLine = line
        elif lineType == "q":
            qLines.append(line)
        elif lineType == "p":
            pLines.append(line)
        elif line.isspace():
            if sLines:
                yield aLine, sLines, qLines, pLines
            # Make new lists (rather than clearing the old ones), so that
            # the alignment just yielded stays intact.
            aLine, sLines, qLines, pLines = "", [], [], []
        elif lineType == "#":
            updateEvalueParameters(opts, line)
            if opts.isKeepComments:
                print(line, end="")
    if sLines:
        yield aLine, sLines, qLines, pLines

def isJoinable(opts, oldMaf, newMaf):
    """Check whether two consecutive alignments can be joined.

    They can be if, for every sequence, the name, length, strand and
    letter size agree, and the later alignment starts at or after the end
    of the earlier one, with a gap of at most opts.join.  If the last
    sequence of oldMaf is on the "-" strand, the two are compared in the
    opposite order.
    """
    earlier, later = oldMaf[1], newMaf[1]
    if earlier[-1][2] == "-":
        earlier, later = later, earlier
    for i, j in zip(earlier, later):
        if i[:4] != j[:4]:
            return False
        gap = j[4] - i[5]
        if not (0 <= gap <= opts.join):
            return False
    return True

def fixOrder(mafs):
    """Reverse mafs in place if its 1st alignment's last strand is "-"."""
    lastStrand = mafs[0][1][-1][2]
    if lastStrand == "-":
        mafs.reverse()

def mafGroupInput(opts, lines):
    """Yield lists of consecutive alignments that can be joined together.

    Each list is put in order by fixOrder before being yielded.
    """
    group = []
    for maf in mafInput(opts, lines):
        if group and not isJoinable(opts, group[-1], maf):
            fixOrder(group)
            yield group
            group = [maf]
        else:
            group.append(maf)
    if group:
        fixOrder(group)
        yield group

##### Routines for converting to AXT format: #####

# Running counter of alignments written so far: writeAxt and writeChain
# both take their serial numbers from it.
axtCounter = count()

def writeAxt(maf):
    """Print one alignment in AXT format.

    The header line holds a serial number, then name, start and end for
    each sequence (plus the strand for every sequence but the 1st), then
    the score if there is one.  The alignment rows and a blank line follow.

    Raises an Exception if the 1st sequence is not on the "+" strand.
    """
    aLine, sLines, qLines, pLines = maf

    if sLines[0][2] != "+":
        raise Exception("for AXT, the 1st strand in each alignment must be +")

    seqWords = []
    for index, fields in enumerate(sLines):
        seqName, strand, beg, end = fields[0], fields[2], fields[4], fields[5]
        # Convert to AXT's 1-based coordinates:
        seqWords += [seqName, str(beg + 1), str(end)]
        if index:
            seqWords.append(strand)

    outWords = [str(next(axtCounter))] + seqWords

    score, evalue = scoreAndEvalue(aLine)
    if score:
        outWords.append(score)

    print(" ".join(outWords))
    for fields in sLines:
        print(fields[6])
    print()  # print a blank line at the end

def mafConvertToAxt(opts, lines):
    """Write every alignment found in lines in AXT format."""
    for alignment in mafInput(opts, lines):
        writeAxt(alignment)

##### Routines for converting to tabular format: #####

def writeTab(maf):
    """Print one alignment as a single tab-separated line.

    The fields are: the score ("0" if the "a" line has none); then name,
    start, aligned length, strand and sequence length for each sequence;
    then the comma-separated block and insert sizes; then any other words
    from the "a" line that are longer than one character.
    """
    aLine, sLines, qLines, pLines = maf

    score = "0"
    endWords = []
    for word in aLine.split():
        if word.startswith("score="):
            score = word[6:]
        elif len(word) > 1:
            endWords.append(word)

    outWords = [score]
    for seqName, seqLen, strand, letterSize, beg, end, row in sLines:
        outWords += [seqName, str(beg), str(end - beg), strand, str(seqLen)]

    letterSizes = [fields[3] for fields in sLines]
    alignmentColumns = zip(*[fields[6] for fields in sLines])
    outWords.append(",".join(matchAndInsertSizes(alignmentColumns, letterSizes)))
    outWords.extend(endWords)

    print("\t".join(outWords))

def mafConvertToTab(opts, lines):
    """Write every alignment found in lines in tabular format."""
    for alignment in mafInput(opts, lines):
        writeTab(alignment)

##### Routines for converting to chain format: #####

def writeChain(maf):
    """Print one alignment in chain format.

    The header line holds "chain", the score ("0" if absent), then name,
    length, strand, start and end for each sequence, then a serial number
    (starting from 1).  Each following line gives a gapless block size and
    the two insert sizes that follow it; the last line gives only the
    final block size.
    """
    aLine, sLines, qLines, pLines = maf

    score = "0"
    for word in aLine.split():
        if word.startswith("score="):
            score = word[len("score="):]

    outWords = ["chain", score]
    for seqName, seqLen, strand, letterSize, beg, end, row in sLines:
        outWords += [seqName, str(seqLen), strand, str(beg), str(end)]
    outWords.append(str(next(axtCounter) + 1))
    print(" ".join(outWords))

    letterSizes = [fields[3] for fields in sLines]
    alignmentColumns = zip(*[fields[6] for fields in sLines])
    blockSize = "0"
    for part in matchAndInsertSizes(alignmentColumns, letterSizes):
        if ":" in part:
            delSize, insSize = part.split(":")
            print("\t".join((blockSize, delSize, insSize)))
            blockSize = "0"
        else:
            blockSize = part
    print(blockSize)

def mafConvertToChain(opts, lines):
    """Write every alignment found in lines in chain format."""
    for alignment in mafInput(opts, lines):
        writeChain(alignment)

##### Routines for converting to PSL format: #####

def pslBlocks(opts, mafs, outCounts):
    """Get sizes and start coordinates of gapless blocks in an alignment.

    Yields (size, begA, begB) for each gapless block, where size is in
    alignment columns.  Once the generator is exhausted, outCounts[0:4]
    holds (matches, mismatches, repMatches, nCount).
    """
    # repMatches is always zero
    # for proteins, nCount is always zero, because that's what BLATv34 does
    normalBases = "ACGTU"
    matches = mismatches = repMatches = nCount = 0

    for maf in mafs:
        fieldsA, fieldsB = pairOrDie(maf[1], "PSL")
        letterSizeA, begA, endA, rowA = fieldsA[3:7]
        letterSizeB, begB, endB, rowB = fieldsB[3:7]

        size = 0
        for x, y in zip(rowA.upper(), rowB.upper()):
            if x != "-" and y != "-":
                size += 1
                isCountable = (x in normalBases and y in normalBases) or opts.protein
                if not isCountable:
                    nCount += 1
                elif x == y:
                    matches += 1
                else:
                    mismatches += 1
                continue
            if size:
                # A gap ends the current block: emit it and step past it.
                yield size, begA, begB
                begA += size * letterSizeA
                begB += size * letterSizeB
                size = 0
            if x == "-":
                begB += symbolSize(y, letterSizeB)
            else:
                begA += symbolSize(x, letterSizeA)
        if size:
            yield size, begA, begB

    outCounts[0:4] = matches, mismatches, repMatches, nCount

def pslNumInserts(blocks, letterSizeA, letterSizeB):
    """Count the gaps between consecutive blocks, separately per sequence.

    blocks holds (size, begA, begB) items.  A gap is counted in a sequence
    whenever a block starts beyond the end of the previous block.
    """
    numInsertA = numInsertB = 0
    previousEnds = None
    for size, begA, begB in blocks:
        if previousEnds is not None:
            endA, endB = previousEnds
            if begA > endA:
                numInsertA += 1
            if begB > endB:
                numInsertB += 1
        previousEnds = (begA + size * letterSizeA, begB + size * letterSizeB)
    return numInsertA, numInsertB

def pslCommaString(things):
    """Format things as a comma-separated list with a trailing comma."""
    # UCSC software seems to prefer a trailing comma
    words = [str(thing) for thing in things]
    return ",".join(words) + ","

def pslEnds(seqLen, strand, beg, end):
    """Return (beg, end), flipped to the opposite strand if strand is "-"."""
    if strand != "-":
        return beg, end
    return seqLen - end, seqLen - beg

def writePsl(opts, mafs):
    """Print one PSL line for a list of (joined) pairwise alignments.

    mafs is a list of alignments as yielded by mafInput; together they
    are written as one PSL record.  The 2nd sequence of each alignment is
    written in PSL's query columns and the 1st in its target columns.
    Nothing is printed if there are no gapless blocks.

    Raises an Exception if an alignment doesn't have exactly 2 sequences,
    or if a non-translated alignment has its 1st sequence on the "-"
    strand.
    """
    matchCounts = [0] * 4
    blocks = list(pslBlocks(opts, mafs, matchCounts))
    matches, mismatches, repMatches, nCount = matchCounts
    numGaplessColumns = sum(matchCounts)

    if not blocks:
        return

    fieldsA, fieldsB = mafs[0][1]
    headSize, headBegA, headBegB = blocks[0]
    tailSize, tailBegA, tailBegB = blocks[-1]

    # Block sizes are counted in alignment columns, so the end coordinate
    # of the last block needs scaling by the letter size (as is done in
    # pslBlocks and pslNumInserts).  This matters for translated
    # alignments, where one column covers 3 letters of one sequence.
    seqNameA, seqLenA, strandA, letterSizeA = fieldsA[0:4]
    tailEndA = tailBegA + tailSize * letterSizeA
    begA, endA = pslEnds(seqLenA, strandA, headBegA, tailEndA)
    baseInsertA = endA - begA - numGaplessColumns * letterSizeA

    seqNameB, seqLenB, strandB, letterSizeB = fieldsB[0:4]
    tailEndB = tailBegB + tailSize * letterSizeB
    begB, endB = pslEnds(seqLenB, strandB, headBegB, tailEndB)
    baseInsertB = endB - begB - numGaplessColumns * letterSizeB

    numInsertA, numInsertB = pslNumInserts(blocks, letterSizeA, letterSizeB)

    # Translated alignments get a 2-character strand field: the strand of
    # the 2nd sequence followed by that of the 1st.
    strand = strandB
    if letterSizeA > 1 or letterSizeB > 1:
        strand += strandA
    elif strandA != "+":
        raise Exception("for non-translated PSL, the 1st strand in each alignment must be +")

    blockCount = len(blocks)
    blockSizes, blockStartsA, blockStartsB = map(pslCommaString, zip(*blocks))

    outWords = (matches, mismatches, repMatches, nCount,
                numInsertB, baseInsertB, numInsertA, baseInsertA, strand,
                seqNameB, seqLenB, begB, endB, seqNameA, seqLenA, begA, endA,
                blockCount, blockSizes, blockStartsB, blockStartsA)

    print("\t".join(map(str, outWords)))

def mafConvertToPsl(opts, lines):
    """Write the alignments in lines as PSL, joining them if opts.join is set."""
    if opts.join:
        groups = mafGroupInput(opts, lines)
    else:
        # Without joining, each alignment is a group on its own.
        groups = ([maf] for maf in mafInput(opts, lines))
    for group in groups:
        writePsl(opts, group)

##### Routines for converting to SAM format: #####

def readGroupId(readGroupItems):
    """Return the text after "ID:" in the first item that starts with it.

    Raises an Exception if no item starts with "ID:".
    """
    identifiers = (item[3:] for item in readGroupItems if item.startswith("ID:"))
    for identifier in identifiers:
        return identifier
    raise Exception("readgroup must include ID")

def readSequenceLengths(fileNames):
    """Read name & length of topmost sequence in each maf block.

    Yields (name, length) pairs, where both are strings taken from the 1st
    "s" line of each block.  Blocks are separated by whitespace-only lines.

    Each file is closed even if reading it fails, or if the caller stops
    iterating early.
    """
    for fileName in fileNames:
        f = myOpen(fileName)
        try:
            isInBlock = False
            for line in f:
                if isInBlock:
                    # Skip the rest of this block: only its 1st "s" line counts.
                    if line.isspace():
                        isInBlock = False
                elif line[0] == "s":
                    fields = line.split()
                    isInBlock = True
                    yield fields[1], fields[5]
        finally:
            f.close()

def naturalSortKey(s):
    """Return a key that sorts strings in "natural" order."""
    key = []
    for isNumber, characters in groupby(s, str.isdigit):
        text = "".join(characters)
        # Runs of digits become integers, so that e.g. 2 sorts before 10.
        key.append(int(text) if isNumber else text)
    return key

def karyotypicSortKey(s):
    """Attempt to sort chromosomes in GATK's ridiculous order.

    "chrM" gets an empty key, which sorts before everything else.  "MT"
    is keyed as "~", which compares greater than ASCII letters and digits.
    Other names are sorted in natural order.

    Each part of the key is tagged with whether it is text, so that
    numbers sort before text without an integer ever being compared to a
    string: Python 3 refuses such comparisons (TypeError), whereas Python 2
    put integers first.  The resulting order is the same as in Python 2.
    """
    if s == "chrM": return []
    if s == "MT": return [(True, "~")]
    return [(isinstance(part, str), part) for part in naturalSortKey(s)]

def copyDictFile(lines):
    """Copy the "@SQ" lines of a header to standard output.

    Other lines that start with "@" are skipped.  Copying stops at the
    first line that doesn't start with "@".
    """
    for line in lines:
        if line[0] != "@":
            break
        if line.startswith("@SQ"):
            sys.stdout.write(line)

def writeSamHeader(opts, fileNames):
    """Print the SAM header lines.

    This is always an @HD line; then @SQ lines made from the input files
    (if opts.dictionary) and/or copied from opts.dictfile; then an @RG
    line made from the words of opts.readgroup, if given.
    """
    print("@HD\tVN:1.3\tSO:unsorted")

    if opts.dictionary:
        lengthOf = dict(readSequenceLengths(fileNames))
        for name in sorted(lengthOf, key=karyotypicSortKey):
            print("@SQ\tSN:%s\tLN:%s" % (name, lengthOf[name]))

    if opts.dictfile:
        dictFile = myOpen(opts.dictfile)
        copyDictFile(dictFile)
        dictFile.close()

    if opts.readgroup:
        readGroupWords = opts.readgroup.split()
        print("@RG\t" + "\t".join(readGroupWords))

# Mapping-quality values for SAM output.  They are kept as strings because
# they are joined straight into the output line.
mapqMissing = "255"  # used by writeSam when there is no "mismap=" value
mapqMaximum = "254"  # the largest value that mapqFromProb returns
mapqMaximumNum = float(mapqMaximum)  # numeric form of the cap, for comparisons

def mapqFromProb(probString):
    """Convert an error probability, given as text, to a mapping quality.

    The result is -10 * log10(probability), rounded to the nearest
    integer and capped at mapqMaximum, returned as a string.  A
    probability of zero gives mapqMaximum.

    Raises an Exception if probString is not a number between 0 and 1
    inclusive (this includes NaN).
    """
    try: p = float(probString)
    except ValueError: raise Exception("bad probability: " + probString)
    # This is written as a positive range test so that NaN is rejected
    # too: every comparison with NaN is false, so "p < 0 or p > 1" would
    # let it through, to fail later in int(round(...)).
    if not 0 <= p <= 1: raise Exception("bad probability: " + probString)
    if p == 0: return mapqMaximum
    phred = -10 * math.log(p, 10)
    if phred >= mapqMaximumNum: return mapqMaximum
    return str(int(round(phred)))

def cigarParts(beg, alignmentColumns, end):
    """Yield the pieces of a CIGAR string for a pairwise alignment.

    beg and end are the sizes of the hard clips ("H") written before and
    after the alignment; a size of zero is omitted.

    alignmentColumns must be an iterator (not e.g. a list) of (x, y)
    pairs: the loops below all draw from it in turn, each carrying on
    from where the previous one stopped.

    Runs of columns are written as "=" (x equals y), "I" (x is a gap),
    "D" (y is a gap), or "X" (anything else).
    """
    if beg: yield str(beg) + "H"

    # (doesn't handle translated alignments)
    # uses "read-ahead" technique, aiming to be as fast as possible:
    isActive = True
    # Fetch the 1st column, if there is one:
    for x, y in alignmentColumns: break
    else: isActive = False
    # At the top of each iteration, (x, y) is a column that has been read
    # but not yet counted: it starts a new run, hence size = 1.
    while isActive:
        size = 1
        if x == y:  # xxx assumes no gap-gap columns, ignores ambiguous bases
            # Each inner loop extends the run until a column breaks it.
            # That column is left in (x, y) for the next iteration.  If
            # the columns run out instead, the "else" ends the outer loop.
            for x, y in alignmentColumns:
                if x != y: break
                size += 1
            else: isActive = False
            yield str(size) + "="
        elif x == "-":
            for x, y in alignmentColumns:
                if x != "-": break
                size += 1
            else: isActive = False
            yield str(size) + "I"
        elif y == "-":
            for x, y in alignmentColumns:
                if y != "-": break
                size += 1
            else: isActive = False
            yield str(size) + "D"
        else:
            for x, y in alignmentColumns:
                if x == y or x == "-" or y == "-": break
                size += 1
            else: isActive = False
            yield str(size) + "X"

    if end: yield str(end) + "H"

def writeSam(readGroup, maf):
    """Print one pairwise alignment as a SAM record.

    The 1st sequence is treated as the reference and the 2nd as the read.
    readGroup is an optional tag (e.g. "RG:Z:...") that is appended to the
    record if it is non-empty.

    Raises an Exception if the alignment doesn't have exactly 2 sequences,
    looks translated, or has its 1st sequence on the "-" strand.
    """
    aLine, sLines, qLines, pLines = maf
    fieldsA, fieldsB = pairOrDie(sLines, "SAM")
    seqNameA, seqLenA, strandA, letterSizeA, begA, endA, rowA = fieldsA
    seqNameB, seqLenB, strandB, letterSizeB, begB, endB, rowB = fieldsB

    if max(letterSizeA, letterSizeB) > 1:
        raise Exception("this looks like translated DNA - can't convert to SAM format")

    if strandA != "+":
        raise Exception("for SAM, the 1st strand in each alignment must be +")

    score = evalue = None
    mapq = mapqMissing
    for word in aLine.split():
        key, separator, value = word.partition("=")
        if not separator:
            continue
        if key == "score":
            if value.isdigit():  # it must be an integer
                score = "AS:i:" + value
        elif key == "E":
            evalue = "EV:Z:" + value
        elif key == "mismap":
            mapq = mapqFromProb(value)

    alignmentColumns = list(zip(rowA.upper(), rowB.upper()))

    # The unaligned parts of the read, before and after the alignment,
    # are written as hard clips:
    cigar = "".join(cigarParts(begB, iter(alignmentColumns), seqLenB - endB))

    seq = rowB.replace("-", "")

    qual = "*"
    if qLines:
        qFields = qLines[-1].split()
        if qFields[1] == seqNameB:
            # Keep the quality symbols that aren't opposite gaps in the read:
            qual = "".join(q for base, q in zip(rowB, qFields[2]) if base != "-")

    # It's hard to get all the pair info, so this is very
    # incomplete, but hopefully good enough.
    # I'm not sure whether to add 2 and/or 8 to flag.
    pairedFlags = {
        "/1": ("99", "83"),    # 1 + 2 + 32 + 64,  1 + 2 + 16 + 64
        "/2": ("163", "147"),  # 1 + 2 + 32 + 128,  1 + 2 + 16 + 128
    }
    suffix = seqNameB[-2:]
    if suffix in pairedFlags:
        seqNameB = seqNameB[:-2]
        forwardFlag, reverseFlag = pairedFlags[suffix]
    else:
        forwardFlag, reverseFlag = "0", "16"
    flag = forwardFlag if strandB == "+" else reverseFlag

    # no special treatment of ambiguous bases: might be a minor bug
    differingColumns = [1 for x, y in alignmentColumns if x != y]
    editDistance = "NM:i:" + str(len(differingColumns))

    pos = str(begA + 1)  # convert to 1-based coordinate
    out = [seqNameB, flag, seqNameA, pos, mapq, cigar, "*\t0\t0", seq, qual,
           editDistance]
    out.extend(tag for tag in (score, evalue, readGroup) if tag)
    print("\t".join(out))

def mafConvertToSam(opts, lines):
    """Write every alignment found in lines as a SAM record."""
    if opts.readgroup:
        readGroup = "RG:Z:" + readGroupId(opts.readgroup.split())
    else:
        readGroup = ""
    for alignment in mafInput(opts, lines):
        writeSam(readGroup, alignment)

##### Routines for converting to BLAST-like format: #####

def pairwiseMatchSymbol(alignmentColumn):
    """Return "|" for a column of matching symbols, else a space."""
    return "|" if isMatch(alignmentColumn) else " "

def strandText(strand):
    """Return "Plus" for the "+" strand, and "Minus" for anything else."""
    return "Plus" if strand == "+" else "Minus"

def blastBegCoordinate(zeroBasedCoordinate, strand, seqLen):
    """Convert a zero-based start to a 1-based forward-strand coordinate.

    The result is returned as a string.
    """
    if strand != "+":
        return str(seqLen - zeroBasedCoordinate)
    return str(zeroBasedCoordinate + 1)

def blastEndCoordinate(zeroBasedCoordinate, strand, seqLen):
    """Convert a zero-based end to a 1-based forward-strand coordinate.

    The result is returned as a string.
    """
    if strand != "+":
        return str(seqLen - zeroBasedCoordinate + 1)
    return str(zeroBasedCoordinate)

def nextCoordinate(coordinate, row, letterSize):
    """Return the coordinate just after the sequence contained in row."""
    advance = insertSize(row, letterSize)
    return coordinate + advance

def chunker(things, chunkSize):
    """Yield successive slices of things, each with up to chunkSize items."""
    for start in range(0, len(things), chunkSize):
        stop = start + chunkSize
        yield things[start:stop]

def blastChunker(sLines, lineSize, alignmentColumns):
    """Split an alignment into chunks of up to lineSize columns.

    Yields (columns, rows, begs, ends) for each chunk, where begs and
    ends hold each sequence's start and end coordinates within the chunk,
    as strings.
    """
    seqLens, strands, letterSizes, coords = [
        [fields[index] for fields in sLines] for index in (1, 2, 3, 4)]
    for chunkCols in chunker(alignmentColumns, lineSize):
        chunkRows = list(alignmentRowsFromColumns(chunkCols))
        begs = [blastBegCoordinate(coord, strand, seqLen)
                for coord, strand, seqLen in zip(coords, strands, seqLens)]
        # Step each coordinate past the sequence shown in this chunk:
        coords = [nextCoordinate(coord, row, letterSize)
                  for coord, row, letterSize in zip(coords, chunkRows, letterSizes)]
        ends = [blastEndCoordinate(coord, strand, seqLen)
                for coord, strand, seqLen in zip(coords, strands, seqLens)]
        yield chunkCols, chunkRows, begs, ends

def writeBlast(opts, maf, oldQueryName):
    """Print one pairwise alignment in BLAST-like format.

    The 2nd sequence is shown as the query and the 1st as the subject.
    The "Query=" heading is printed only if the query name differs from
    oldQueryName.  A bit score is shown if the alignment has a score and
    both opts.bitScoreA and opts.bitScoreB are set.

    Raises an Exception if the alignment doesn't have exactly 2 sequences.
    """
    aLine, sLines, qLines, pLines = maf
    fieldsA, fieldsB = pairOrDie(sLines, "Blast")
    seqNameA, seqLenA, strandA, letterSizeA, begA, endA, rowA = fieldsA
    seqNameB, seqLenB, strandB, letterSizeB, begB, endB, rowB = fieldsB

    if seqNameB != oldQueryName:
        print("Query= " + seqNameB)
        print("         (%s letters)" % seqLenB)
        print()

    print(">" + seqNameA)
    print("          Length = %s" % seqLenA)
    print()

    score, evalue = scoreAndEvalue(aLine)

    if score and opts.bitScoreA is not None and opts.bitScoreB is not None:
        bitScore = opts.bitScoreA * float(score) - opts.bitScoreB
        scoreLine = " Score = %.3g bits (%s)" % (bitScore, score)
    else:
        scoreLine = " Score = %s" % score
    if evalue:
        scoreLine += ", Expect = %s" % evalue
    print(scoreLine)

    alignmentColumns = list(zip(rowA, rowB))
    alnSize = len(alignmentColumns)

    matches = len([1 for x, y in alignmentColumns if x.upper() == y.upper()])
    # percentages are rounded down, like BLAST
    identLine = " Identities = %s/%s (%s%%)" % (
        matches, alnSize, 100 * matches // alnSize)
    gaps = rowA.count("-") + rowB.count("-")
    if gaps:
        identLine += ", Gaps = %s/%s (%s%%)" % (
            gaps, alnSize, 100 * gaps // alnSize)
    print(identLine)

    print(" Strand = %s / %s" % (strandText(strandB), strandText(strandA)))
    print()

    chunks = blastChunker(sLines, opts.linesize, alignmentColumns)
    for cols, rows, begs, ends in chunks:
        begWidth = maxlen(begs)
        matchSymbols = "".join(pairwiseMatchSymbol(col) for col in cols)
        print("Query: %-*s %s %s" % (begWidth, begs[1], rows[1], ends[1]))
        print("       %-*s %s"    % (begWidth, " ", matchSymbols))
        print("Sbjct: %-*s %s %s" % (begWidth, begs[0], rows[0], ends[0]))
        print()

def mafConvertToBlast(opts, lines):
    """Write every alignment found in lines in BLAST-like format."""
    previousQueryName = ""
    for alignment in mafInput(opts, lines):
        writeBlast(opts, alignment, previousQueryName)
        # Remember the query (2nd sequence) name, so that its heading
        # isn't repeated for the next alignment.
        previousQueryName = alignment[1][1][0]

def blastDataFromMafFields(fields):
    """Return (name, start, end, row) for one sequence, in BLAST style.

    The coordinates are 1-based, forward-strand, and given as strings.
    The row is converted to upper case.
    """
    seqName, seqLen, strand, letterSize, beg, end, row = fields
    if strand == "+":
        first, last = beg + 1, end
    else:
        first, last = seqLen - beg, seqLen - end + 1
    return seqName, str(first), str(last), row.upper()

def writeBlastTab(opts, maf):
    """Print one pairwise alignment as a line of BLAST-like tabular format.

    The E-value is appended if the alignment has one, and in that case a
    bit score is appended too if the alignment has a score and both
    opts.bitScoreA and opts.bitScoreB are set.

    Raises an Exception if the alignment doesn't have exactly 2 sequences.
    """
    aLine, sLines, qLines, pLines = maf
    fieldsA, fieldsB = pairOrDie(sLines, "BlastTab")
    seqNameA, begA, endA, rowA = blastDataFromMafFields(fieldsA)
    seqNameB, begB, endB, rowB = blastDataFromMafFields(fieldsB)

    alignmentColumns = list(zip(rowA, rowB))
    alnSize = len(alignmentColumns)
    matches = len([1 for x, y in alignmentColumns if x == y])
    matchPercent = "%.2f" % (100.0 * matches / alnSize)
    gapColumns = rowA.count("-") + rowB.count("-")
    mismatches = alnSize - matches - gapColumns
    gapOpens = gapRunCount(rowA) + gapRunCount(rowB)

    out = [seqNameB, seqNameA, matchPercent]
    out += [str(number) for number in (alnSize, mismatches, gapOpens)]
    out += [begB, endB, begA, endA]

    score, evalue = scoreAndEvalue(aLine)
    if evalue:
        out.append(evalue)
        if score and opts.bitScoreA is not None and opts.bitScoreB is not None:
            out.append("%.3g" % (opts.bitScoreA * float(score) - opts.bitScoreB))

    print("\t".join(out))

def mafConvertToBlastTab(opts, lines):
    """Write every alignment found in lines in BLAST-like tabular format."""
    for alignment in mafInput(opts, lines):
        writeBlastTab(opts, alignment)

##### Routines for converting to HTML format: #####

def writeHtmlHeader():
    """Print the start of the HTML page: head, styles, and colour key."""
    header = '''
<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01//EN"
 "http://www.w3.org/TR/html4/strict.dtd">
<html lang="en"><head>
<meta http-equiv="Content-type" content="text/html; charset=UTF-8">
<title>Reliable Alignments</title>
<style type="text/css">
/* Try to force monospace, working around browser insanity: */
pre {font-family: "Courier New", monospace, serif; font-size: 0.8125em}
.a {background-color: #3333FF}
.b {background-color: #9933FF}
.c {background-color: #FF66CC}
.d {background-color: #FF3333}
.e {background-color: #FF9933}
.f {background-color: #FFFF00}
.key {display:inline; margin-right:2em}
</style>
</head><body>

<div style="line-height:1">
<pre class="key"><span class="a">  </span> prob &gt; 0.999</pre>
<pre class="key"><span class="b">  </span> prob &gt; 0.99 </pre>
<pre class="key"><span class="c">  </span> prob &gt; 0.95 </pre>
<pre class="key"><span class="d">  </span> prob &gt; 0.9  </pre>
<pre class="key"><span class="e">  </span> prob &gt; 0.5  </pre>
<pre class="key"><span class="f">  </span> prob &le; 0.5  </pre>
</div>
'''
    print(header)

def probabilityClass(probabilityColumn):
    """Return the style class ('a' to 'f') for a column of probability symbols.

    The class depends on the smallest symbol in the column: its character
    code minus 33 is compared against a series of thresholds.
    """
    level = ord(min(probabilityColumn)) - 33
    thresholds = ((30, 'a'), (20, 'b'), (13, 'c'), (10, 'd'), (3, 'e'))
    for threshold, className in thresholds:
        if level >= threshold:
            return className
    return 'f'

def identicalRuns(s):
    """Yield (item, start, end) for each run of identical items in s."""
    start = 0
    for item, run in groupby(s):
        stop = start + sum(1 for member in run)
        yield item, start, stop
        start = stop

def htmlSpan(text, classRun):
    """Return one piece of text, wrapped in a <span> if it has a class.

    classRun is (class name, start, end): the piece is text[start:end].
    """
    className, beg, end = classRun
    piece = text[beg:end]
    if not className:
        return piece
    return '<span class="%s">%s</span>' % (className, piece)

def multipleMatchSymbol(alignmentColumn):
    """Return "*" for a column of matching symbols, else a space."""
    return "*" if isMatch(alignmentColumn) else " "

def writeHtml(opts, maf):
    """Print one alignment as HTML, coloured by its "p" lines if any.

    The alignment is shown in chunks of up to opts.linesize columns, with
    a line of match symbols below each chunk.
    """
    aLine, sLines, qLines, pLines = maf

    score, evalue = scoreAndEvalue(aLine)
    title = "Alignment"
    if score:
        title += " score=" + score
        if evalue:
            title += ", expect=" + evalue
    print("<h3>%s:</h3>" % title)

    # "classes" has to be an iterator: each chunk below takes the next
    # opts.linesize items from it.
    if pLines:
        probRows = [line.split()[1] for line in pLines]
        classes = (probabilityClass(column) for column in zip(*probRows))
    else:
        classes = repeat(None)

    seqNames = [fields[0] for fields in sLines]
    nameWidth = maxlen(seqNames)
    alignmentColumns = list(zip(*[fields[6] for fields in sLines]))

    print('<pre>')
    chunks = blastChunker(sLines, opts.linesize, alignmentColumns)
    for cols, rows, begs, ends in chunks:
        begWidth = maxlen(begs)
        endWidth = maxlen(ends)
        matchSymbols = ''.join(multipleMatchSymbol(col) for col in cols)
        classRuns = list(identicalRuns(islice(classes, opts.linesize)))
        for name, beg, row, end in zip(seqNames, begs, rows, ends):
            spans = ''.join(htmlSpan(row, run) for run in classRuns)
            print('%-*s %*s %s %*s'
                  % (nameWidth, name, begWidth, beg, spans, endWidth, end))
        print(' ' * nameWidth, ' ' * begWidth, matchSymbols)
        print()
    print('</pre>')

def mafConvertToHtml(opts, lines):
    """Write every alignment found in lines as HTML."""
    for alignment in mafInput(opts, lines):
        writeHtml(opts, alignment)

##### Main program: #####

def isFormat(myString, myFormat):
    """Check whether myString is a non-empty abbreviation of myFormat.

    Returns True if myString is a prefix of myFormat (or the whole of it).
    An empty string is rejected: it is a prefix of every format name, so
    it would otherwise be taken to mean whichever format is tested first.
    """
    return bool(myString) and myFormat.startswith(myString)

def mafConvertOneFile(opts, formatName, lines):
    """Convert the alignments in lines to the format called formatName.

    formatName may be abbreviated: the first format in the list below
    that it matches (according to isFormat) is used.

    Raises an Exception if formatName matches no format.
    """
    converters = (
        ("axt", mafConvertToAxt),
        ("blast", mafConvertToBlast),
        ("blasttab", mafConvertToBlastTab),
        ("chain", mafConvertToChain),
        ("html", mafConvertToHtml),
        ("psl", mafConvertToPsl),
        ("sam", mafConvertToSam),
        ("tabular", mafConvertToTab),
    )
    for name, convert in converters:
        if isFormat(formatName, name):
            convert(opts, lines)
            return
    raise Exception("unknown format: " + formatName)

def mafConvert(opts, args):
    """Convert the MAF files named in args[1:] to the format named args[0].

    Standard input is read if no files are named.  Unless opts.noheader
    is set, the HTML header and footer, or the SAM header, are written
    for those formats, and comment lines are kept for tabular format.
    """
    formatName = args[0].lower()
    fileNames = args[1:]

    opts.isKeepComments = False
    opts.bitScoreA = None
    opts.bitScoreB = None

    wantHeader = not opts.noheader
    wantHtmlWrapper = wantHeader and isFormat(formatName, "html")

    if wantHtmlWrapper:
        writeHtmlHeader()
    if wantHeader and isFormat(formatName, "sam"):
        writeSamHeader(opts, fileNames)
    if wantHeader and isFormat(formatName, "tabular"):
        opts.isKeepComments = True

    for fileName in (fileNames or ["-"]):
        f = myOpen(fileName)
        mafConvertOneFile(opts, formatName, f)
        f.close()

    if wantHtmlWrapper:
        print("</body></html>")

if __name__ == "__main__":
    # Exit quietly if the output pipe gets closed (e.g. when piping into
    # "head").  SIGPIPE is only defined on Unix-like platforms, so skip
    # this where it doesn't exist instead of crashing at start-up.
    if hasattr(signal, "SIGPIPE"):
        signal.signal(signal.SIGPIPE, signal.SIG_DFL)  # avoid silly error message

    usage = """
  %prog --help
  %prog axt mafFile(s)
  %prog blast mafFile(s)
  %prog blasttab mafFile(s)
  %prog chain mafFile(s)
  %prog html mafFile(s)
  %prog psl mafFile(s)
  %prog sam mafFile(s)
  %prog tab mafFile(s)"""

    description = "Read MAF-format alignments & write them in another format."

    op = optparse.OptionParser(usage=usage, description=description)
    op.add_option("-p", "--protein", action="store_true",
                  help="assume protein alignments, for psl match counts")
    op.add_option("-j", "--join", type="float", metavar="N",
                  help="join co-linear alignments separated by <= N letters")
    op.add_option("-n", "--noheader", action="store_true",
                  help="omit any header lines from the output")
    op.add_option("-d", "--dictionary", action="store_true",
                  help="include dictionary of sequence lengths in sam format")
    op.add_option("-f", "--dictfile",
                  help="get sequence dictionary from DICTFILE")
    op.add_option("-r", "--readgroup",
                  help="read group info for sam format")
    op.add_option("-l", "--linesize", type="int", default=60, #metavar="CHARS",
                  help="line length for blast and html formats (default: %default)")
    opts, args = op.parse_args()
    if opts.linesize <= 0: op.error("option -l: should be >= 1")
    if opts.dictionary and opts.dictfile: op.error("can't use both -d and -f")
    if len(args) < 1: op.error("I need a format-name and some MAF alignments")
    # Option -d reads the input files an extra time (to collect sequence
    # lengths for the header), which can't be done with standard input:
    if opts.dictionary and (len(args) == 1 or "-" in args[1:]):
        op.error("need file (not pipe) with option -d")

    # Report any failure as a one-line message, not a traceback:
    try: mafConvert(opts, args)
    except Exception as e:
        prog = os.path.basename(sys.argv[0])
        sys.exit(prog + ": error: " + str(e))
