#!/usr/bin/env python3

import os
import re
import sys
import subprocess
import itertools

def replace_queue_statement_in_file(file_name, new):
    """Rewrite *file_name* in place, swapping every queue statement for *new*.

    A line that starts with ``queue`` (in any letter case) is replaced by
    *new* followed by a newline.  When that line ends with ``(``, the lines
    following it, up to and including the first line that is exactly ``)``,
    are dropped too.  Every other line is written back unchanged.
    """
    with open(file_name) as source:
        original_lines = source.readlines()

    replacement = "{}\n".format(new)
    rewritten = []
    skipping_item_list = False
    for text in original_lines:
        if skipping_item_list:
            # Inside a parenthesized block: drop the line; a lone ")" ends it.
            skipping_item_list = text != ")\n"
            continue

        if not re.match("^queue", text, re.I):
            rewritten.append(text)
            continue

        rewritten.append(replacement)
        skipping_item_list = text.endswith("(\n")

    with open(file_name, "w") as target:
        target.writelines(rewritten)


# Submit-file keys that configure late materialization, spelled as they may
# appear in a submit file.
lmkeys = [
           "max_idle",
           "max_materialize",
           "materialize_max_idle",
           "JobMaterializeLimit",
           "JobMaterializeMaxIdle",
         ]
# Case-folded forms of the keys above, for case-insensitive comparison.
LATE_MATERIALIZATION_KEYS = [ name.casefold() for name in lmkeys ]

def copy(submit_file_name, dag_submit_file_name):
    """Copy a submit file, leaving out its late-materialization settings.

    Every line of *submit_file_name* is written to *dag_submit_file_name*
    unchanged, except ``key = value`` lines whose key (ignoring letter case
    and surrounding whitespace) is in LATE_MATERIALIZATION_KEYS; those lines
    are omitted, which disables late materialization.
    """
    with open(submit_file_name) as submit_file:
        lines = submit_file.readlines()

    with open(dag_submit_file_name, "w") as dag_submit_file:
        for line in lines:
            # Only the text before the FIRST "=" is the key.  Values may
            # contain "=" themselves (e.g. 'arguments = --mode=fast'), so
            # partition() is used instead of unpacking a full split(),
            # which raised ValueError on such lines.
            key, separator, _ = line.partition("=")
            if separator and key.strip().casefold() in LATE_MATERIALIZATION_KEYS:
                continue
            dag_submit_file.write(line)


#
# Given 'QUEUE <COUNT> <VARS> FROM <FILE>', writes a set of lines of the
# form 'JOB <INDEX> <dag-submit-file-name>' and 'VARS <INDEX> <VAR>="<VALUE>"'
# for each VAR in VARS, with each VALUE taken from the corresponding column
# and line INDEX in FILE.
#
def _parse_item_slice(slice_string):
    """Parse ' [start:stop]' or ' [start:stop:step]' into a 3-tuple of ints.

    Omitted fields are returned as None.  Returns None when *slice_string*
    has neither of those two forms.
    """
    match = re.match(r"^ \[(\d+)?:(\d+)?(?::(\d+)?)?\]$", slice_string)
    if match is None:
        return None
    return tuple(None if text is None else int(text) for text in match.groups())


def _load_item_data(item_data_file_name, var_names):
    """Read, then delete, the item data file; return one dict per line.

    Each line is split on the ASCII unit separator (0x1F) and its fields are
    assigned, in order, to *var_names*.  A line with fewer fields than
    *var_names* raises IndexError.
    """
    with open(item_data_file_name) as item_data_file:
        item_data_lines = item_data_file.readlines()
    os.remove(item_data_file_name)

    item_data = []
    for line in item_data_lines:
        fields = line.split('\x1F')
        item_data.append(
            {name: fields[column].strip() for column, name in enumerate(var_names)}
        )
    return item_data


def convert_queue_from_file(queue_statement, submit_file_name):
    """Write 'generated.dag' and 'generated.submit' for a simple queue statement.

    *queue_statement* must look like 'queue <COUNT>' or
    'queue <COUNT> <VARS> [<SLICE>] from <FILE>' (keywords in any letter
    case).  For every selected line of FILE, COUNT nodes are written to
    'generated.dag', each with VARS entries for Process, ProcID, Step,
    ItemIndex, Row and every name in VARS.  FILE is deleted after it has
    been read.  Finally *submit_file_name* is copied to 'generated.submit'
    with its queue statement replaced by 'queue 1'.

    Both output files are created in the current working directory.  If
    the statement (or its slice) cannot be parsed, an error is printed to
    stderr and the process exits with status -4.
    """
    dag_submit_file_name = "generated.submit"
    dag_file_name = "generated.dag"

    # Raw strings keep the regex escapes (\d, \[) away from Python's own
    # string-escape processing.
    match = re.match(
        r"^queue (\d+)( ([^ ]+)?( \[[\d:]+\])? from (.*))?$",
        queue_statement,
        re.IGNORECASE,
    )
    if not match:
        print("Internal error, aborting.", file=sys.stderr)
        sys.exit(-4)

    count = int(match.group(1))
    vars_string, slice_string, item_data_file_name = match.group(3, 4, 5)
    var_names = vars_string.split(',') if vars_string is not None else []

    # Validate the slice before the item data file is read and deleted, so
    # an unparsable slice does not destroy the data.
    slice_bounds = None
    if slice_string is not None:
        slice_bounds = _parse_item_slice(slice_string)
        if slice_bounds is None:
            print("Internal error, aborting.", file=sys.stderr)
            sys.exit(-4)

    if item_data_file_name is not None:
        item_data = _load_item_data(item_data_file_name, var_names)
        item_indices = range(len(item_data))
        if slice_bounds is not None:
            item_indices = list(itertools.islice(item_indices, *slice_bounds))
    else:
        # No item file: a single placeholder item; var_names is empty here,
        # so the placeholder is never indexed by variable name.
        item_data = ["one-single-job"]
        item_indices = [0]

    with open(dag_file_name, "w") as dag_file:
        job_index = 0
        for item_index in item_indices:
            item = item_data[item_index]
            for step in range(count):
                dag_file.write(' JOB {} {}\n'.format(job_index, dag_submit_file_name))

                node_vars = [
                    # This next line doesn't work in any version of HTCondor yet.
                    ("Process", job_index),
                    ("ProcID", job_index),
                    ("Step", step),
                    ("ItemIndex", item_index),
                    ("Row", item_index),
                ]
                node_vars.extend((var, item[var]) for var in var_names)
                for name, value in node_vars:
                    dag_file.write('VARS {} {}="{}"\n'.format(job_index, name, value))
                dag_file.write('\n')
                job_index += 1

    copy(submit_file_name, dag_submit_file_name)
    replace_queue_statement_in_file(dag_submit_file_name, "queue 1")


#
# Given a submit file with a 'complicated' QUEUE statement, return a
# 'simple' QUEUE FROM FILE that can be fed into convert_queue_from_file().
#
# Uses the condor_submit command-line tool to generate the statement and
# the corresponding FILE.
#
def _remove_if_present(file_name):
    """Delete *file_name*; do nothing if it does not exist."""
    try:
        os.remove(file_name)
    except FileNotFoundError:
        # Best-effort cleanup of a file that may never have been created.
        pass


def simplify_complex_queue_statement(submit_file_name):
    """Have condor_submit reduce a submit file's queue statement to one line.

    Runs ``condor_submit -factory -dry-run:digest generated.classad
    -digest generated.digest <submit_file_name>`` in the current working
    directory and returns the first line of 'generated.digest' that starts
    with "Queue" (trailing newline included).

    Returns None, after printing a message to stderr, if condor_submit exits
    with a non-zero status or the digest contains no such line.  On success
    of condor_submit, both generated files are deleted before returning.
    """
    generated_digest_filename = "generated.digest"
    generated_classad_filename = "generated.classad"
    cp = subprocess.run(
        ['condor_submit', '-factory',
         '-dry-run:digest', generated_classad_filename,
         '-digest', generated_digest_filename,
         submit_file_name,
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    # condor_submit may fail before it creates this file; a missing file must
    # not turn into a traceback that hides the diagnostic output below.
    _remove_if_present(generated_classad_filename)

    if cp.returncode != 0:
        # The extra newline is deliberate.
        print("condor_submit failed to generate item data for {}, output follows.\n".format(submit_file_name), file=sys.stderr)
        print(cp.stdout, file=sys.stderr)
        return None

    with open(generated_digest_filename) as generated_digest:
        lines = generated_digest.readlines()
    # Delete the digest only after it has been closed, and whether or not it
    # turns out to contain a queue statement.
    os.remove(generated_digest_filename)

    for line in lines:
        if line.startswith("Queue"):
            return line

    print("condor_submit generated no queue statement for {}".format(submit_file_name), file=sys.stderr)
    return None


def main(argv):
    """Convert the submit file named in *argv* into a DAG; return an exit status.

    *argv* is the full command line: the program name followed by exactly
    one argument, the submit file name.

    Returns:
        0   on success,
        -1  if the argument count is wrong,
        -2  if the submit file has no queue statement,
        -3  if the queue statement is a bare 'queue' (a single-node DAG),
        -4  if a complex queue statement could not be simplified.
    """
    if len(argv) != 2:
        print("condor_submit_as_dag takes exactly one argument, the name of the submit file", file=sys.stderr)
        return -1

    submit_file_name = argv[1]
    with open(submit_file_name) as submit_file:
        lines = submit_file.readlines()

    # Find the first queue statement, ignoring letter case so that this
    # agrees with the case-insensitive 'queue <N>' check further down.
    queue_statement = next(
        (line for line in lines if re.match("queue", line, re.I)), None
    )
    if queue_statement is None:
        print("no queue statement found in submit file", file=sys.stderr)
        return -2

    # Lines from readlines() keep their trailing newline, so compare the
    # stripped text; otherwise a bare 'queue' line is never recognized.
    if queue_statement.strip().casefold() == "queue":
        print("refusing to a create a single-node DAG", file=sys.stderr)
        return -3

    # Anything other than a plain 'queue <N>' is handed to condor_submit to
    # be reduced to a simple form first.
    if not re.match(r"^queue \d+$", queue_statement, re.I):
        queue_statement = simplify_complex_queue_statement(submit_file_name)
        if queue_statement is None:
            return -4

    convert_queue_from_file(queue_statement, submit_file_name)
    return 0

# When run as a script, pass the command line to main() and exit with the
# status it returns.
if __name__ == "__main__":
    exit_status = main(sys.argv)
    sys.exit(exit_status)
