#! /usr/bin/env python3
# Author: Martin C. Frith 2019
# SPDX-License-Identifier: MIT

from __future__ import print_function

import collections
import gzip
import itertools
import logging
import math
import optparse
import os
import shutil
import signal
import subprocess
import sys
import tempfile
from operator import itemgetter

def openFile(fileName):
    """Open fileName for reading text and return the file object.

    "-" means standard input (returned as-is).  A name ending in ".gz" is
    opened as a gzip-compressed text file.
    """
    if fileName == "-":
        return sys.stdin
    isGzipped = fileName.endswith(".gz")
    # xxx gzip text mode is dubious for Python2
    return gzip.open(fileName, "rt") if isGzipped else open(fileName)

def fastxInput(lines):
    """Group FASTA/FASTQ lines into records, yielding each as a list.

    Each list holds the record's non-blank lines with trailing whitespace
    removed.  A ">" line starts a new record when the current record is
    FASTA.  An "@" line starts a new record once the current record has
    4 lines (single-line FASTQ).
    """
    record = []
    for line in lines:
        startsFasta = line[0] == ">" and record and record[0][0] == ">"
        startsFastq = line[0] == "@" and len(record) == 4
        if startsFasta or startsFastq:
            yield record
            record = []
        if not line.isspace():
            record.append(line.rstrip())
    if record:
        yield record

def nameAndSeqFromFastx(seqLines):
    """Return (name, sequence) for one record from fastxInput.

    The name is the first word after the leading ">" or "@".  For FASTA,
    all lines after the header are joined; for FASTQ, the second line is
    the sequence.
    """
    header = seqLines[0]
    name = header[1:].split()[0]
    isFasta = header[0] == ">"
    seq = "".join(seqLines[1:]) if isFasta else seqLines[1]
    return name, seq

##### Routines for getting score parameters from a last-train file:

# Built-in score parameter sets, keyed by a lower-case name that may be
# given instead of a last-train file name (parametersFromLastTrain looks
# up trainFile.lower() here).  Each value is text in the same layout as a
# last-train output file: parametersFromLastTrain parses it line by line.
builtinTrainFiles = {
    "promethion-2019": """
# scale of score parameters: 4.5512
# delOpenProb: 0.0369615
# insOpenProb: 0.0340916
# delExtendProb: 0.439744
# insExtendProb: 0.403943
# probability matrix (query letters = columns, reference letters = rows):
#   A              C              G              T             
# A 0.278908       0.00109169     0.012899       0.00107855    
# C 0.00154506     0.197869       0.00044938     0.00272089    
# G 0.0272926      0.000508552    0.177789       0.000859018   
# T 0.00126293     0.00328421     0.000545484    0.291896
""",
    "promethion-rna-2019" :"""
# scale of score parameters: 4.5512
# delOpenProb: 0.0578071
# insOpenProb: 0.0297262
# delExtendProb: 0.438808
# insExtendProb: 0.380129
# probability matrix (query letters = columns, reference letters = rows):
#   A              C              G              T             
# A 0.246603       0.00134        0.00742782     0.00762015    
# C 0.00168584     0.221462       0.000388905    0.0143456     
# G 0.00693157     0.000421241    0.239649       0.000805103   
# T 0.00803073     0.0131941      0.000617237    0.229477      
""",
    "sequel-ii-clr-2019": """
# scale of score parameters: 4.5512
# delOpenProb: 0.0329087
# insOpenProb: 0.0343556
# delExtendProb: 0.0674455
# insExtendProb: 0.288643
# probability matrix (query letters = columns, reference letters = rows):
#   A              C              G              T             
# A 0.292742       0.00340172     0.000378204    0.00239594    
# C 0.000849175    0.20731        0.000104201    0.00037284    
# G 0.0013137      0.000304338    0.187333       0.00680248    
# T 0.0135957      0.000785525    0.00475765     0.277554      
"""
}

def parametersFromLastTrain(trainFile):
    """Get score parameters from last-train output.

    trainFile is either the name of a built-in parameter set (a key of
    builtinTrainFiles, matched case-insensitively) or a file to read
    ("-" means standard input; a ".gz" suffix means gzip-compressed).

    Returns (scale, probMatrix, gapProbs):
      scale: the "scale of score parameters";
      probMatrix: a list of 4 rows of 4 floats, taken from the last
        "probability matrix" in the input;
      gapProbs: (delOpenProb, delExtendProb, insOpenProb, insExtendProb).

    Raises RuntimeError if any of these values is missing.
    """
    builtin = builtinTrainFiles.get(trainFile.lower())
    lines = builtin.splitlines() if builtin else openFile(trainFile)

    alphabetSize = 4
    scale = delOpenProb = delExtendProb = insOpenProb = insExtendProb = -1
    # None until a "probability matrix" header is seen, so a missing matrix
    # is detected below.  Each header restarts the matrix, so the last one
    # in the input wins.
    probMatrix = None
    try:
        for line in lines:
            fields = line.split()
            if "scale of score parameters" in line:
                scale = float(fields[5])
            if "delOpenProb" in line:
                delOpenProb = float(fields[2])
            if "insOpenProb" in line:
                insOpenProb = float(fields[2])
            if "delExtendProb" in line:
                delExtendProb = float(fields[2])
            if "insExtendProb" in line:
                insExtendProb = float(fields[2])
            if "probability matrix" in line:
                probMatrix = []
            # matrix rows look like: "# A p1 p2 p3 p4"
            if (probMatrix is not None and
                    len(fields) == 2 + alphabetSize and
                    len(probMatrix) < alphabetSize):
                probMatrix.append([float(i) for i in fields[2:]])
    finally:
        # close files we opened, but never the caller's stdin
        if not builtin and lines is not sys.stdin:
            lines.close()
    gapProbs = delOpenProb, delExtendProb, insOpenProb, insExtendProb
    if (scale < 0 or -1 in gapProbs or probMatrix is None or
            len(probMatrix) != alphabetSize):
        raise RuntimeError("can't read the last-train data")
    return scale, probMatrix, gapProbs

def scoreFromProb(scale, prob):
    """Return scale times the natural log of prob (prob must be > 0)."""
    assert prob > 0
    return scale * math.log(prob)

def costFromProb(scale, prob):
    """Return the cost (negated score) of prob: -scale * ln(prob)."""
    cost = -scoreFromProb(scale, prob)
    return cost

def matrixWithComplementSymmetricRows(matrix):  # xxx breaks matrix symmetry
    """Yield the rows of matrix, rescaled so complementary rows have equal sums.

    Row i is paired with row len(matrix)-1-i (its complement, for an ACGT
    matrix).  Each row is scaled so its sum becomes the mean of the two
    rows' sums.
    """
    for row, mirrorRow in zip(matrix, reversed(matrix)):
        rowSum = sum(row)
        avg = (rowSum + sum(mirrorRow)) * 0.5
        yield [x * avg / rowSum for x in row]

def substitutionProbMatrix(refMatrix, qryMatrix):
    """Combine two joint probability matrices through their shared row letter.

    Returns m with m[i][j] = sum over k of
    refMatrix[k][i] * qryMatrix[k][j] / (row sum k of refMatrix).
    Works for any n-by-n matrices (4-by-4 for DNA), like
    scoreMatrixFromProbMatrix.
    """
    r = range(len(refMatrix))
    rowSums = [sum(row) for row in refMatrix]
    return [[sum(refMatrix[k][i] * qryMatrix[k][j] / rowSums[k] for k in r)
             for j in r]
            for i in r]

def scoreMatrixFromProbMatrix(scale, probMatrix):
    """Return log-odds scores for a square joint probability matrix.

    Entry [i][j] is scoreFromProb(scale, p[i][j] / (rowSum[i] * colSum[j])).
    """
    indices = range(len(probMatrix))
    rowSums = [sum(probMatrix[i][j] for j in indices) for i in indices]
    colSums = [sum(probMatrix[i][j] for i in indices) for j in indices]
    return [[scoreFromProb(scale,
                           probMatrix[i][j] / (rowSums[i] * colSums[j]))
             for j in indices]
            for i in indices]

def substitutionScoreMatrix(scale, refMatrix, qryMatrix):
    """Score matrix from the combination of two joint probability matrices."""
    return scoreMatrixFromProbMatrix(
        scale, substitutionProbMatrix(refMatrix, qryMatrix))

def newGapExtendProb(gapProbs):
    """Merge deletion and insertion extension probabilities into one.

    gapProbs is (delOpenProb, delExtendProb, insOpenProb, insExtendProb).
    Returns the two extend probabilities averaged with weights
    openProb / (1 - extendProb), i.e. open probability times the mean gap
    length implied by the extend probability.
    """
    delOpen, delExtend, insOpen, insExtend = gapProbs
    delWeight = delOpen / (1 - delExtend)
    insWeight = insOpen / (1 - insExtend)
    weightedSum = delWeight * delExtend + insWeight * insExtend
    return weightedSum / (delWeight + insWeight)

def gapScoresFromProbs(scale, gapProbs):
    """Return (gapExistCost, gapExtendCost) for LAST from gap probabilities.

    gapProbs is (delOpenProb, delExtendProb, insOpenProb, insExtendProb).
    Deletions and insertions are merged into one kind of gap.  The open
    probability is that of at least one of them opening; the extend
    probability comes from newGapExtendProb.  Both merged probabilities
    are logged at INFO level.
    """
    delOpenProb, delExtendProb, insOpenProb, insExtendProb = gapProbs
    openProb = 1 - (1 - delOpenProb) * (1 - insOpenProb)
    extendProb = newGapExtendProb(gapProbs)
    logging.info("gapOpenProb: {0:.6}".format(openProb))
    logging.info("gapExtendProb: {0:.6}".format(extendProb))
    firstProb = openProb * (1 - extendProb)
    totalExtendProb = extendProb + firstProb
    return (costFromProb(scale, firstProb / totalExtendProb),
            costFromProb(scale, totalExtendProb))

def alignmentScoresFromProbs(scale, probMatrix, gapProbs):
    """Return (fwdMatrix, revMatrix, gapExistCost, gapExtendCost).

    fwdMatrix scores two reads from the same strand.  revMatrix scores a
    read against one from the opposite strand, using probMatrix with its
    rows and columns both reversed (complemented, for ACGT order).
    """
    complemented = [row[::-1] for row in reversed(probMatrix)]
    fwdMatrix = substitutionScoreMatrix(scale, probMatrix, probMatrix)
    revMatrix = substitutionScoreMatrix(scale, probMatrix, complemented)
    existCost, extendCost = gapScoresFromProbs(scale, gapProbs)
    return fwdMatrix, revMatrix, existCost, extendCost

##### Routines for printing the score parameters in MAFFT format:

def mafftNum(x):
    """Format a number for MAFFT with 3 significant digits."""
    return "{0:.3}".format(x)

def mafftGapOptions(scores):
    """Return MAFFT gap options (a list of strings) equivalent to LAST's scores.

    scores is (fwdMatrix, revMatrix, gapExistCost, gapExtendCost).  Costs
    are divided by meanDiff, the mean diagonal score of fwdMatrix minus the
    mean of all its entries.
    """
    fwdMatrix, revMatrix, gapExistCost, gapExtendCost = scores
    size = len(fwdMatrix)
    # xxx assume a, c, g, t have equal frequency:
    meanMatch = sum(fwdMatrix[k][k] for k in range(size)) / size
    meanUnrelated = sum(map(sum, fwdMatrix)) / (size ** 2)
    meanDiff = meanMatch - meanUnrelated
    existCost = gapExistCost / meanDiff
    extendCost = (gapExtendCost * 2 + meanUnrelated) / meanDiff
    gop = mafftNum(-existCost)
    return ["--op", mafftNum(existCost),
            "--gop", gop,
            "--ep", mafftNum(extendCost),
            "--gep", mafftNum(-extendCost),
            "--gexp", "0",
            "--lop", gop,
            "--lep", mafftNum(-meanUnrelated / meanDiff),
            "--lexp", mafftNum(-gapExtendCost / meanDiff)]

def printMafftMatrix(mafftAlphabet, fwdMatrix, revMatrix, outFile):
    """Write a lower-triangular MAFFT score matrix over mafftAlphabet.

    Letters A, C, G, T stand for + strand bases.  E, Q, F, P stand for
    - strand bases, in that order matching A, C, G, T.  Two letters on the
    same strand are scored with fwdMatrix and opposite strands with
    revMatrix; when the row letter is on the - strand, both indices are
    complemented (3 - index).  Any other letter scores 0.0.
    """
    codes = {}
    for strand, letters in enumerate(("ACGT", "EQFP")):  # + strand, - strand
        for index, letter in enumerate(letters):
            codes[letter] = strand, index

    def pairScore(x, y):
        if x not in codes or y not in codes:
            return 0.0
        xStrand, xIndex = codes[x]
        yStrand, yIndex = codes[y]
        if xStrand == 0:
            matrix = fwdMatrix if yStrand == 0 else revMatrix
            return matrix[xIndex][yIndex]
        matrix = revMatrix if yStrand == 0 else fwdMatrix
        return matrix[3 - xIndex][3 - yIndex]

    print(" ", *[format(letter, ">5") for letter in mafftAlphabet],
          file=outFile)
    for rowNum, x in enumerate(mafftAlphabet):
        cells = [" " + format(pairScore(x, y), "5.3")
                 for y in mafftAlphabet[:rowNum + 1]]
        outFile.write(x + "".join(cells) + "\n")

def printMafftParams(fwdMatrix, revMatrix, mafftGapOpts, outFile):
    """Write a MAFFT "--aamatrix" file that uses amino-acid letters for bases.

    The file has a comment with the gap options, the score matrix, and a
    frequency line with 0.25 for A, C, G, T and 0.0 for the other letters.
    """
    alphabet = "ARNDCQEGHILKMFPSTWYV"
    print("# MAFFT cost:", *mafftGapOpts, file=outFile)
    printMafftMatrix(alphabet, fwdMatrix, revMatrix, outFile)
    # must be consistent with meanMatch
    freqs = [0.25 if letter in "ACGT" else 0.0 for letter in alphabet]
    print("frequency", *freqs, file=outFile)

##### Routines for making a consensus sequence:

def alignmentColumnsForConsensus(opts, rows):
    """Yield the alignment columns to use for the consensus, as strings.

    rows are aligned sequences, with "-" for gaps.  A row covers the
    columns from its first to its last non-gap character; all-gap rows
    cover nothing.

    Flanking columns are dropped: those before the first, or after the
    last, column covered by at least opts.seq_min rows.  opts.seq_min is
    a count, or a percentage of the rows if it ends with "%".  Of the
    remaining columns, a column is used only if its gap percentage is
    <= opts.gap_max.  Normally only gaps inside covering rows count,
    with the covering rows as the denominator.  With opts.end, all rows
    and all their gaps count.
    """
    alignmentLength = max(map(len, rows))
    # difference array: +1 where a row's coverage starts, -1 where it ends
    seqNumChange = [0] * (alignmentLength + 1)
    for i in rows:
        beg = len(i) - len(i.lstrip("-"))
        end = len(i.rstrip("-"))
        if beg < end:  # an all-gap row (beg == len, end == 0) covers nothing
            seqNumChange[beg] += 1
            seqNumChange[end] -= 1

    if opts.seq_min.endswith("%"):
        minPercent = float(opts.seq_min[:-1])
        minSeqs = len(rows) * (minPercent / 100)
    else:
        minSeqs = int(opts.seq_min)

    # [beg, end) spans the first to last column with enough coverage
    beg = end = alignmentLength
    seqNum = 0
    for i, x in enumerate(seqNumChange):
        seqNum += x
        if seqNum >= minSeqs:
            if beg > i:
                beg = i
            end = i + 1

    # xxx try using gapProbs?

    columns = zip(*rows)
    seqNum = len(rows) if opts.end else 0
    for i, z in enumerate(zip(seqNumChange, columns)):
        x, y = z
        if not opts.end:
            seqNum += x
        if i >= beg and i < end:
            # gaps among the rows counted in seqNum
            gapNum = seqNum + y.count("-") - len(y)
            if gapNum * 100 <= seqNum * opts.gap_max:
                yield "".join(y)

def strandScore(scoreRow, column, bases):
    """Sum of scoreRow[k] times the count of bases[k] in column."""
    counts = [column.count(base) for base in bases]
    return sum(scoreRow[k] * counts[k] for k in range(len(counts)))

def columnScore(priorScores, scoreMatrix, column, fwdBaseIndex):
    """Log-likelihood score of one consensus base for an alignment column.

    Upper-case ACGT letters are scored with the matrix row for the base.
    Lower-case tgca letters (reverse-strand reads) are scored with the row
    for the complementary base.  The base's prior score is added.
    """
    fwdRow = scoreMatrix[fwdBaseIndex]
    revRow = scoreMatrix[3 - fwdBaseIndex]
    return (priorScores[fwdBaseIndex]
            + strandScore(fwdRow, column, "ACGT")
            + strandScore(revRow, column, "tgca"))

def asciiFromInvProb(invProb):
    """FASTQ quality character for a base whose probability is 1/invProb.

    Phred score floor(-10*log10(error prob)), error prob floored at 1e-10,
    encoded with offset 33 and capped at "~".
    """
    errProb = max(1 - 1 / invProb, 1e-10)
    phred = int(math.floor(-10 * math.log10(errProb)))
    return chr(min(phred + 33, 126))

def consensusCol(priorScores, scoreMatrix, column):
    """Return (best base, invProb) for one alignment column.

    invProb is the sum over bases of exp(score - bestScore), so the best
    base's posterior probability is 1 / invProb.
    """
    bases = "acgt"
    column = column.replace("U", "T") # keep
    scores = [columnScore(priorScores, scoreMatrix, column, k)
              for k in range(len(bases))]
    bestIndex = max(range(len(scores)), key=scores.__getitem__)
    bestScore = scores[bestIndex]
    invProb = sum(math.exp(s - bestScore) for s in scores)
    return bases[bestIndex], invProb

def consensusSeq(opts, probMatrix, alignedSequences):
    """Return a list of (base, invProb), one per used alignment column.

    alignedSequences is a list of (name, alignedSeq).  Prior scores are
    the logs of probMatrix's row sums; column scores are the logs of each
    row normalized to sum to 1.

    Raises RuntimeError if there are no sequences.
    """
    if not alignedSequences:
        raise RuntimeError("can't make a consensus of zero sequences")

    priorScores = [math.log(sum(row)) for row in probMatrix]
    scoreMatrix = []
    for row in probMatrix:
        rowTotal = sum(row)
        scoreMatrix.append([math.log(x / rowTotal) for x in row])

    rows = [seq for _, seq in alignedSequences]
    return [consensusCol(priorScores, scoreMatrix, col)
            for col in alignmentColumnsForConsensus(opts, rows)]

def printConsensusSeq(opts, probMatrix, alignedSequences):
    """Print the consensus to stdout as FASTA, or as FASTQ if opts.format says so."""
    columns = consensusSeq(opts, probMatrix, alignedSequences)
    seq = "".join(base for base, invProb in columns)
    if opts.format.lower() not in ("fastq", "fq"):
        print(">" + opts.name, seq, sep="\n")
        return
    qual = "".join(asciiFromInvProb(invProb) for base, invProb in columns)
    print("@" + opts.name, seq, "+", qual, sep="\n")

##### Routines for aligning the sequences:

def toInt(x):
    """Round x to the nearest integer with the built-in round(), as an int."""
    rounded = round(x)
    return int(rounded)

def printScoreParameters(scoreMatrix, gapExistCost, gapExtendCost, out):
    """Write a LAST score file: a "#last -a -b" line plus a rounded ACGT matrix."""
    alphabet = "ACGT"
    existCost = toInt(gapExistCost)
    extendCost = toInt(gapExtendCost)
    print("#last", "-a", existCost, "-b", extendCost, file=out)
    print(" ", *alphabet, file=out)
    for letter, scoreRow in zip(alphabet, scoreMatrix):
        print(letter, *[toInt(s) for s in scoreRow], file=out)

# Complement letters in a "fake" alphabet: + strand A, C, G, T become the
# - strand letters P, F, Q, E, and ambiguity codes map to their complements.
xCrazyComplement = "ACGT" "RYKMBDHV" "U"  # uracil
yCrazyComplement = "PFQE" "YRMKVHDB" "E"
def crazyComplement(base):
    """Return the fake-alphabet complement of base, or base if it has none."""
    pos = xCrazyComplement.find(base)
    if pos < 0:
        return base
    return yCrazyComplement[pos]

def transliteratedFastaLines(lines):
    """Yield FASTA lines with the fake - strand alphabet turned back into DNA.

    In sequence lines following a header that starts with ">_R_" (the
    prefix given to reversed sequences by printSequencesWithFakeAlphabet),
    E, Q, F, P become a, c, g, t and the whole line is lower-cased.  All
    other lines, including any before the first header, pass through
    unchanged.
    """
    isRev = False  # lines before any header are treated as + strand
    for i in lines:
        if i.startswith(">"):  # also safe for an empty string
            isRev = i.startswith(">_R_")
        elif isRev:
            i = i.replace(*"Ea").replace(*"Qc").replace(*"Fg").replace(*"Pt")
            i = i.lower()
        yield i

def printFasta(name, seq, out):
    """Write one FASTA record (header line, then a single sequence line)."""
    header = ">" + str(name)
    print(header, file=out)
    print(seq, file=out)

def printNumberedSequences(sequences, out):
    """Write (name, seq) pairs as FASTA, naming them 0, 1, 2, ... instead."""
    for number, (_, seq) in enumerate(sequences):
        printFasta(number, seq, out)

def printSequencesWithFakeAlphabet(opts, seqsZip, out):
    """Write the kept sequences as FASTA input for MAFFT.

    seqsZip yields (seqRank, (name, seq), isRev, beg, end).  Sequences
    with rank < 1 are skipped.  Reversed sequences get a "_R_" name prefix
    and are reverse-complemented into the fake alphabet; others have U
    changed to T.  Unless opts.all, only seq[beg:end] is written (empty if
    beg >= end).
    """
    for seqRank, record, isRev, beg, end in seqsZip:
        if seqRank >= 1:
            name, seq = record
            seq = seq.upper()  # xxx ???
            if isRev:
                name = "_R_" + name
                seq = "".join(crazyComplement(base) for base in seq[::-1])
            else:
                seq = seq.replace("U", "T")
            if not opts.all:
                seq = seq[beg:end] if beg < end else seq[0:0]
            printFasta(name, seq, out)

def alignmentInput(opts, lines):
    """Parse lastal MAF output that includes "p" (probability) lines.

    For each alignment between two different sequences, yields
    (negScore, qry, ref, probCodes).  ref and qry are the records from the
    first and second "s" lines: (seqNum, seqLen, strand, beg, end,
    alignedText).  probCodes is the "p" line's second field.  With
    opts.verbose > 1 every input line is echoed to stderr; "#" lines with
    more than 9 "=" are logged.
    """
    for line in lines:
        if opts.verbose > 1:
            sys.stderr.write(line)
        fields = line.split()
        kind = line[0]
        if kind == "#":
            if line.count("=") > 9:
                logging.info(line.rstrip())
        elif kind == "a":
            for field in fields:
                if field.startswith("score="):
                    score = int(field[6:])
                    seqRecords = []
        elif kind == "s":
            num = int(fields[1])
            beg = int(fields[2])
            end = beg + int(fields[3])
            seqRecords.append((num, int(fields[5]), fields[4], beg, end,
                               fields[6]))
        elif kind == "p":
            ref, qry = seqRecords
            if qry[0] != ref[0]:
                yield -score, qry, ref, fields[1]

def rankPerSeq(groupNumPerSeq, keptGroupNum):
    """Yield, per sequence, its 1-based rank within the kept group, or 0."""
    nextRank = 1  # the first rank is 1, not 0, for MAFFT
    for groupNum in groupNumPerSeq:
        if groupNum != keptGroupNum:
            yield 0
            continue
        yield nextRank
        nextRank += 1

def findRange(oldRanges, newRange):
    """Find where newRange fits among sorted, colinear ranges.

    Ranges are (qryBeg, qryEnd, refBeg, refEnd).  Returns the index of the
    first old range that newRange lies strictly before in all four
    coordinates, or len(oldRanges) if it lies strictly after all of them.
    Returns -1 if newRange is neither strictly before nor strictly after
    some old range, i.e. it overlaps or crosses it.
    """
    qryBeg, qryEnd, refBeg, refEnd = newRange
    for index, (oldQryBeg, oldQryEnd, oldRefBeg, oldRefEnd) in enumerate(oldRanges):
        isBefore = (qryBeg < oldQryBeg and qryEnd < oldQryEnd and
                    refBeg < oldRefBeg and refEnd < oldRefEnd)
        if isBefore:
            return index
        isAfter = (qryBeg > oldQryBeg and qryEnd > oldQryEnd and
                   refBeg > oldRefBeg and refEnd > oldRefEnd)
        if not isAfter:
            return -1
    return len(oldRanges)

def isBigDiagonalChange(opts, oldRanges, newRange, newRangeIndex):
    """True if inserting newRange at newRangeIndex jumps too far off-diagonal.

    The diagonal is ref coordinate minus qry coordinate.  The start
    diagonal of newRange is compared with the end diagonal of the
    preceding range, and its end diagonal with the start diagonal of the
    following range.  A difference > opts.diagonal_max counts as big.
    """
    qryBeg, qryEnd, refBeg, refEnd = newRange
    limit = opts.diagonal_max
    if newRangeIndex > 0:
        _, prevQryEnd, _, prevRefEnd = oldRanges[newRangeIndex - 1]
        if abs((prevRefEnd - prevQryEnd) - (refBeg - qryBeg)) > limit:
            return True
    if newRangeIndex < len(oldRanges):
        nextQryBeg, _, nextRefBeg, _ = oldRanges[newRangeIndex]
        if abs((nextRefBeg - nextQryBeg) - (refEnd - qryEnd)) > limit:
            return True
    return False

def layoutOfSeqs(opts, numOfSequences, alignmentsSortedByScore):
    """Greedily link sequences into groups using pairwise alignments.

    Alignments are examined in the given order (pairwiseAlignments sorts
    them by negated score, i.e. best first).  An alignment between two
    different groups merges them.  An alignment within one group is kept
    only if: its relative strand agrees with the sequences' current
    orientations; it is colinear with the alignments already kept between
    the same two sequences (findRange); and it doesn't jump too far in
    diagonal (isBigDiagonalChange).

    Returns (dataPerSeq, alignmentOrder, keptAlignments):
      dataPerSeq: for each sequence, (group number, whether the sequence
        is reversed relative to its group);
      alignmentOrder: (minGroupNum, maxGroupNum) for each merge, in order;
      keptAlignments: the alignments that were used, in input order.
    """
    # initially every sequence is its own group, in + orientation
    dataPerSeq = [(i, False) for i in range(numOfSequences)]
    alignmentOrder = []
    keptAlignments = []
    # (lower seqNum, higher seqNum) -> sorted list of kept aligned ranges
    alignedRanges = collections.defaultdict(list)
    for aln in alignmentsSortedByScore:
        negScore, qry, ref, probCodes = aln
        # make qry the lower-numbered sequence
        if qry[0] > ref[0]:
            qry, ref = ref, qry
        refNum, refLen, refStrand, refBeg, refEnd, refAln = ref
        qryNum, qryLen, qryStrand, qryBeg, qryEnd, qryAln = qry
        refGroupNum, isRevRef = dataPerSeq[refNum]
        qryGroupNum, isRevQry = dataPerSeq[qryNum]
        isOpposite = (isRevQry != isRevRef)
        isRevAln = (qryStrand != refStrand)
        # True if this alignment's relative strand disagrees with the
        # two sequences' current relative orientation
        isFlip = (isRevAln != isOpposite)
        oldRanges = alignedRanges[qryNum, refNum]
        # convert both ranges to coordinates in which ref is on its + strand
        if refStrand == "-":
            refBeg, refEnd = refLen - refEnd, refLen - refBeg
            qryBeg, qryEnd = qryLen - qryEnd, qryLen - qryBeg
        newRange = qryBeg, qryEnd, refBeg, refEnd
        rpos = findRange(oldRanges, newRange)
        if refGroupNum == qryGroupNum:
            # already linked: skip alignments inconsistent with the layout
            if (isFlip or rpos < 0 or
                isBigDiagonalChange(opts, oldRanges, newRange, rpos)):
                continue
        else:
            # merge the higher-numbered group into the lower-numbered one,
            # flipping its members' orientation if needed
            minGroupNum = min(refGroupNum, qryGroupNum)
            maxGroupNum = max(refGroupNum, qryGroupNum)
            alignmentOrder.append((minGroupNum, maxGroupNum))
            for i, x in enumerate(dataPerSeq):
                groupNum, isRev = x
                if groupNum == maxGroupNum:
                    dataPerSeq[i] = minGroupNum, isRev != isFlip
        keptAlignments.append(aln)
        oldRanges.insert(rpos, newRange)
    return dataPerSeq, alignmentOrder, keptAlignments

def updateRange(begPerSeq, endPerSeq, isRevPerSeq, seqRecord):
    """Widen a sequence's aligned range to include one alignment record.

    seqRecord is (seqNum, seqLen, strand, beg, end, alignedText).  The
    range is first converted to the sequence's orientation in the layout
    (isRevPerSeq).  begPerSeq and endPerSeq are updated in place.
    """
    seqNum, seqLen, seqStrand, seqBeg, seqEnd, _ = seqRecord
    needsFlip = (seqStrand == "-") != isRevPerSeq[seqNum]
    if needsFlip:
        seqBeg, seqEnd = seqLen - seqEnd, seqLen - seqBeg
    begPerSeq[seqNum] = min(begPerSeq[seqNum], seqBeg)
    endPerSeq[seqNum] = max(endPerSeq[seqNum], seqEnd)

def alignedRangePerSeq(keptAlignments, isRevPerSeq):
    """Return (begPerSeq, endPerSeq): each sequence's span covered by alignments.

    Sequences with no alignment keep beg = sys.maxsize and end = 0.
    """
    seqCount = len(isRevPerSeq)
    begPerSeq = [sys.maxsize] * seqCount
    endPerSeq = [0] * seqCount
    for aln in keptAlignments:
        _, qry, ref, _ = aln
        for record in (ref, qry):
            updateRange(begPerSeq, endPerSeq, isRevPerSeq, record)
    return begPerSeq, endPerSeq

def pairwiseAnchors(opts, minProbCode, alns, seqRanks, isRevPerSeq, begPerSeq):
    """Yield MAFFT anchors (restrictions) from pairwise alignments.

    Each anchor is (qryRank, refRank, qBeg, rBeg, size): a run of `size`
    consecutive aligned columns, contiguous in both sequences, whose
    probability codes are all >= minProbCode.  qBeg and rBeg are 0-based
    offsets in the sequences as written for MAFFT: oriented according to
    isRevPerSeq and, unless opts.all, relative to begPerSeq.  qry is the
    lower-numbered sequence of each pair.  Alignments involving a sequence
    with rank < 1 (not in the kept group) are skipped.
    """
    for negScore, qry, ref, probCodes in alns:
        # make qry the lower-numbered sequence
        if qry[0] > ref[0]:
            qry, ref = ref, qry
        refNum, refLen, refStrand, refBeg, refEnd, refAln = ref
        qryNum, qryLen, qryStrand, qryBeg, qryEnd, qryAln = qry
        refRank = seqRanks[refNum]
        qryRank = seqRanks[qryNum]
        if refRank < 1 or qryRank < 1:
            continue
        # If ref's strand here differs from its orientation in the layout,
        # reverse the alignment and convert its start coordinates.
        # NOTE(review): only ref's orientation is checked; this assumes
        # qry's layout orientation is consistent with this alignment.
        if isRevPerSeq[refNum] != (refStrand == "-"):
            refAln = refAln[::-1]
            qryAln = qryAln[::-1]
            probCodes = probCodes[::-1]
            refBeg = refLen - refEnd
            qryBeg = qryLen - qryEnd
        # make coordinates relative to the trimmed sequences given to MAFFT
        if not opts.all:
            refBeg -= begPerSeq[refNum]
            qryBeg -= begPerSeq[qryNum]
        qBeg = rBeg = size = 0
        # walk the columns, tracking current positions in qryBeg/refBeg
        for p, q, r in zip(probCodes, qryAln, refAln):
            if q != "-":
                if r != "-":
                    if p >= minProbCode:
                        # this column doesn't directly extend the current
                        # anchor: emit it (if any) and start a new one here
                        if qBeg + size < qryBeg or rBeg + size < refBeg:
                            if size:
                                yield qryRank, refRank, qBeg, rBeg, size
                            qBeg = qryBeg
                            rBeg = refBeg
                            size = 0
                        size += 1
                    refBeg += 1
                qryBeg += 1
            elif r != "-":
                refBeg += 1
        if size:
            yield qryRank, refRank, qBeg, rBeg, size

def prepareToAlign(opts, scores, sequences, tmpdir):
    """Write LAST input files into tmpdir and build a LAST database.

    Writes the forward and reverse score files and the numbered sequences,
    then runs lastdb.  Returns the paths (fwdMat, revMat, seqFile, db).
    Raises subprocess.CalledProcessError if lastdb fails.
    """
    threadsOpt = "-P" + str(opts.P)
    fwdMatrix, revMatrix, gapExistCost, gapExtendCost = scores

    def tmpPath(baseName):
        return os.path.join(tmpdir, baseName)

    fwdMat = tmpPath("fwd.mat")
    revMat = tmpPath("rev.mat")
    for fileName, matrix in ((fwdMat, fwdMatrix), (revMat, revMatrix)):
        with open(fileName, "w") as f:
            printScoreParameters(matrix, gapExistCost, gapExtendCost, f)

    seqFile = tmpPath("x.fa")
    with open(seqFile, "w") as f:
        printNumberedSequences(sequences, f)

    db = tmpPath("db")
    if opts.u:
        seedOpts = ("-uRY" + str(opts.u),)
    else:
        seedOpts = ("-uNEAR", "-W" + str(opts.W))
    cmd = ("lastdb", "-c", threadsOpt) + seedOpts + (db, seqFile)
    logging.info(" ".join(cmd))
    subprocess.check_call(cmd)

    return fwdMat, revMat, seqFile, db

def addStrandAlignments(alignments, opts, db, seqFile, matFile, strandNum):
    """Run lastal with score file matFile and strand option -s<strandNum>.

    The parsed alignments (as yielded by alignmentInput) are appended to
    the alignments list.  Raises subprocess.CalledProcessError if lastal
    exits with a non-zero status.
    """
    optP = "-P" + str(opts.P)
    optD = "-D1e9"
    mOpt = "-m" + str(opts.m)
    zOpt = "-z" + str(opts.z) + "g"
    sOpt = "-s" + str(strandNum)
    cmd = ("lastal", sOpt, "-j4", optD, optP, mOpt, zOpt, "-p", matFile,
           db, seqFile)
    logging.info(" ".join(cmd))
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        alignments.extend(alignmentInput(opts, proc.stdout))
    finally:
        # Close our end of the pipe and reap lastal even if parsing raised,
        # so no file descriptor or zombie process is left behind.
        proc.stdout.close()
        retcode = proc.wait()
    if retcode:
        raise subprocess.CalledProcessError(retcode, cmd)

def pairwiseAlignments(opts, fwdMat, revMat, seqFile, db):
    """Return all-vs-all lastal alignments, sorted (best score first).

    lastal is run with fwdMat and -s1, then with revMat and -s0.
    """
    alignments = []
    for matFile, strandNum in ((fwdMat, 1), (revMat, 0)):
        addStrandAlignments(alignments, opts, db, seqFile, matFile, strandNum)
    alignments.sort()
    return alignments

def _versionNumbers(version):
    """Return the leading numbers of a version string: "v7.453" -> (7, 453).

    Parsing stops at the first dot-separated part that does not start
    with a digit; trailing non-digits in a part are ignored.
    """
    numbers = []
    for part in version.lstrip("v").split("."):
        digits = "".join(itertools.takewhile(lambda c: c in "0123456789",
                                             part))
        if not digits:
            break
        numbers.append(int(digits))
    return tuple(numbers)

def checkMafftVersion():
    """Raise RuntimeError unless "mafft --version" reports a new enough version.

    MAFFT's version text is read from its standard error, which must
    start with "v".  It is logged at INFO level.
    """
    cmd = "mafft", "--version"
    p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                         universal_newlines=True)
    out, err = p.communicate()
    if not err.startswith("v"):
        raise RuntimeError("can't get MAFFT version")
    logging.info("MAFFT " + err.rstrip())
    v = err.split()[0]
    minVersion = "v4.442"
    # compare numerically: string comparison would misorder e.g. "v10.1"
    if _versionNumbers(v) < _versionNumbers(minVersion):
        raise RuntimeError("MAFFT version must be >= " + minVersion)

def assemble(opts, scale, probMatrix, gapProbs, sequences, tmpdir):
    """Align the sequences with LAST and MAFFT, and print the result.

    Gets pairwise alignments with lastal and links the sequences into
    groups (layoutOfSeqs).  Then writes MAFFT input files for the largest
    group, named opts.out + ".fa"/".tree"/".pair"/".mtx" (inside tmpdir if
    opts.out is not set).  If opts.out is set, stops there.  Otherwise
    runs MAFFT and prints either its alignment (opts.alignment) or a
    consensus sequence.

    Raises subprocess.CalledProcessError if MAFFT exits with an error.
    """
    scores = alignmentScoresFromProbs(scale, probMatrix, gapProbs)
    fwdMatrix, revMatrix, gapExistCost, gapExtendCost = scores
    mafftGapOpts = mafftGapOptions(scores)
    tmpFileNames = prepareToAlign(opts, scores, sequences, tmpdir)

    alignments = pairwiseAlignments(opts, *tmpFileNames)
    layoutData = layoutOfSeqs(opts, len(sequences), alignments)
    dataPerSeq, alignmentOrder, keptAlignments = layoutData
    groupNumPerSeq, isRevPerSeq = zip(*dataPerSeq)

    numOfSeqsPerGroup = [0] * len(sequences)
    for i in groupNumPerSeq:
        numOfSeqsPerGroup[i] += 1
    maxSeqsPerGroup = max(numOfSeqsPerGroup)

    numOfGroups = len(sequences) - len(alignmentOrder)
    text = "groups (of sequences linked by pairwise alignments): {0}"
    logging.info(text.format(numOfGroups))

    if maxSeqsPerGroup < len(sequences):
        text = "using {0} out of {1} sequences (linked by pairwise alignments)"
        logging.warning(text.format(maxSeqsPerGroup, len(sequences)))

    # only the (first) largest group is given to MAFFT
    keptGroupNum = numOfSeqsPerGroup.index(maxSeqsPerGroup)
    seqRanks = list(rankPerSeq(groupNumPerSeq, keptGroupNum))

    begPerSeq, endPerSeq = alignedRangePerSeq(keptAlignments, isRevPerSeq)

    outName = opts.out if opts.out else os.path.join(tmpdir, "mafft")

    with open(outName + ".fa", "w") as f:
        seqsZip = zip(seqRanks, sequences, isRevPerSeq, begPerSeq, endPerSeq)
        printSequencesWithFakeAlphabet(opts, seqsZip, f)

    # guide tree: the order in which groups were merged
    with open(outName + ".tree", "w") as f:
        for i, j in alignmentOrder:
            iRank = seqRanks[i]
            jRank = seqRanks[j]
            if iRank > 0 and jRank > 0:
                print(iRank, jRank, 0.1, 0.1, file=f)

    minProbScore = int(math.ceil(-10 * math.log10(opts.prob)))
    minProbCode = chr(33 + minProbScore)
    logging.info("minProbScore: " + str(minProbScore))
    logging.info("minProbCode: " + minProbCode)

    with open(outName + ".pair", "w") as f:
        # the following groupby saves memory:
        for negScore, alns in itertools.groupby(keptAlignments, itemgetter(0)):
            p = pairwiseAnchors(opts, minProbCode, alns,
                                seqRanks, isRevPerSeq, begPerSeq)
            # sorted + groupby drops duplicate anchors
            for k, v in itertools.groupby(sorted(p)):
                qryRank, refRank, qBeg, rBeg, size = k
                print(qryRank, refRank, qBeg+1, qBeg+size, rBeg+1, rBeg+size,
                      -negScore, file=f)

    with open(outName + ".mtx", "w") as f:
        printMafftParams(fwdMatrix, revMatrix, mafftGapOpts, f)

    if opts.out:
        return

    checkMafftVersion()

    cmd = ["mafft", "--amino"] + mafftGapOpts
    if not opts.verbose:
        cmd += ["--quiet"]
    if opts.mafft:
        cmd += opts.mafft.split()
    cmd += ["--aamatrix", outName+".mtx", "--treein", outName+".tree",
            "--anchors", outName+".pair", outName+".fa"]
    logging.info(" ".join(cmd))
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        fastaLines = transliteratedFastaLines(proc.stdout)
        if opts.alignment:
            for i in fastaLines:
                print(i, end="")
        else:
            alignedSeqs = [nameAndSeqFromFastx(i)
                           for i in fastxInput(fastaLines)]
    finally:
        # always release the pipe and reap MAFFT
        proc.stdout.close()
        retcode = proc.wait()
    # don't trust (possibly empty or truncated) output from a failed run
    if retcode:
        raise subprocess.CalledProcessError(retcode, cmd)
    if not opts.alignment:
        printConsensusSeq(opts, probMatrix, alignedSeqs)

def main(opts, trainFile, seqFile):
    """Read parameters and sequences, then print a consensus or alignment.

    trainFile is a last-train output file or a built-in parameter name.
    seqFile is FASTA/FASTQ ("-" for stdin, optionally gzipped).  With
    opts.consensus the sequences are taken as already aligned.  Otherwise
    gaps are removed and they are assembled in a temporary directory,
    which is always deleted afterwards.

    Raises RuntimeError for unusable input.
    """
    logLevel = logging.INFO if opts.verbose else logging.WARNING
    logging.basicConfig(format="%(filename)s: %(message)s", level=logLevel)

    scale, probMatrix, gapProbs = parametersFromLastTrain(trainFile)
    probMatrix = list(matrixWithComplementSymmetricRows(probMatrix))

    seqLines = openFile(seqFile)
    try:
        fastx = fastxInput(seqLines)
        sequences = [nameAndSeqFromFastx(i) for i in fastx]
    finally:
        if seqLines is not sys.stdin:  # never close the caller's stdin
            seqLines.close()

    okSymbols = "ACGT" "U" "MRWSYKVHDBN" "-"
    okSet = set(okSymbols.upper() + okSymbols.lower())
    for name, seq in sequences:
        diff = set(seq).difference(okSet)
        if diff:
            bad = "".join(sorted(diff))
            raise RuntimeError("not allowed in nucleotide sequences: " + bad)

    if opts.consensus:
        return printConsensusSeq(opts, probMatrix, sequences)

    if not sequences:
        raise RuntimeError("zero sequences")

    sequences = [(name, seq.replace("-", "")) for name, seq in sequences]

    tmpdir = tempfile.mkdtemp(prefix="lamassemble")
    try:
        assemble(opts, scale, probMatrix, gapProbs, sequences, tmpdir)
    finally:
        shutil.rmtree(tmpdir)

if __name__ == "__main__":
    # Restore the default SIGPIPE action, so that writing to a closed pipe
    # (e.g. output piped into "head") ends the program without a traceback.
    signal.signal(signal.SIGPIPE, signal.SIG_DFL)  # avoid silly error message
    usage = "%prog [options] last-train.out sequences.fx > consensus.fa"
    descr = "Merge DNA sequences into a consensus sequence."
    op = optparse.OptionParser(usage=usage, description=descr)
    # Output and consensus options:
    op.add_option("-a", "--alignment", action="store_true",
                  help="print an alignment, not a consensus")
    op.add_option("-c", "--consensus", action="store_true",
                  help="just make a consensus, of already-aligned sequences")
    op.add_option("-f", "--format", metavar="FMT", default="fasta", help=
                  "output format: fasta/fa or fastq/fq (default=%default)")
    op.add_option("-g", "--gap-max", metavar="G", type="float", default=50,
                  help="use alignment columns with <= G% gaps "
                  "(default=%default)")
    op.add_option("--end", action="store_true",
                  help="... including gaps past the ends of the sequences")
    op.add_option("-s", "--seq-min", metavar="S", default="1", help="omit "
                  "consensus flanks with < S sequences (default=%default)")
    op.add_option("-n", "--name", default="lamassembled",
                  help="name of the consensus sequence (default=%default)")
    # Options controlling the MAFFT run:
    op.add_option("-o", "--out", metavar="BASE",
                  help="just write MAFFT input files, named BASE.xxx")
    op.add_option("-p", "--prob", metavar="P", type="float", default=0.002,
                  help="use pairwise restrictions with error probability <= P "
                  "(default=%default)")
    op.add_option("-d", "--diagonal-max", type="int", default=1000,
                  metavar="D", help="max change in alignment diagonal "
                  "between pairwise alignments (default=%default)")
    op.add_option("-v", "--verbose", action="count", default=0,
                  help="show progress messages")
    op.add_option("--all", action="store_true",
                  help="use all of each sequence, not just aligning part")
    op.add_option("--mafft", metavar="ARGS",
                  help="additional arguments for MAFFT")

    # Options passed on to LAST: -P to lastdb and lastal, -u and -W to
    # lastdb, -m and -z to lastal.
    og = optparse.OptionGroup(op, "LAST options")
    og.add_option("-P", type="int", default=1,
                  help="number of parallel threads (default=%default)")
    og.add_option("-u", metavar="RY", type="int", help=
                  "use ~1 per this many initial matches")
    og.add_option("-W", type="int", default=19, help="use minimum positions "
                  "in length-W windows (default=%default)")
    og.add_option("-m", type="int", default=5, help=
                  "max initial matches per query position (default=%default)")
    og.add_option("-z", type="int", default=30,
                  help="max gap length (default=%default)")
    op.add_option_group(og)

    opts, args = op.parse_args()
    # Exactly two positional arguments are needed; otherwise just print
    # the help text.
    if len(args) == 2:
        try:
            main(opts, *args)
        except RuntimeError as e:
            # report expected input problems as a one-line error message
            prog = os.path.basename(sys.argv[0])
            sys.exit(prog + ": error: " + str(e))
    else:
        op.print_help()
