"""
ATAC Call Peaks
"""

import argparse
import multiprocessing
import os
import shlex
import shutil
import tempfile
from pathlib import Path

import mist.lib.MistLogger as logging
from mist.lib.MistShellUtils import shell_command_log_stderr


def _read_contig_sizes(contig_sizes_fp):
    """
    Parse a whitespace-separated ``contig length`` file (e.g. a UCSC chrom.sizes file).

    Blank lines and lines starting with ``#`` are skipped; columns after the
    second are ignored.

    Returns:
        dict mapping contig name -> length (int)

    Raises:
        OSError:    if the file cannot be read
        ValueError: if a line lacks a length column or the length is not an integer
    """
    contig_sizes = {}
    with open(contig_sizes_fp, encoding='utf-8') as handle:
        for line_no, line in enumerate(handle, start=1):
            fields = line.split()
            if not fields or fields[0].startswith('#'):
                continue
            if len(fields) < 2:
                raise ValueError(f'line {line_no}: expected "contig<TAB>length", got {line!r}')
            contig_sizes[fields[0]] = int(fields[1])
    return contig_sizes


def cli():
    """
    ATAC Call Peaks: parse command-line arguments.

    Returns:
        dict of keyword arguments for ``call_peaks``: run_name, bed_fp,
        contig_sizes (parsed from --contig-sizes-fp), genome_size, bp_shift
        and bp_ext_size.
    """
    parser = argparse.ArgumentParser(description='ATAC Call Peaks')

    parser.add_argument('--run-name', required=True, help='The run name for this experiment')
    parser.add_argument(
        '--bed-fp', action='store', required=True, type=Path, help='An ATAC deduplicated alignment BED file'
    )
    parser.add_argument(
        '--contig-sizes-fp',
        action='store',
        required=True,
        type=Path,
        help='Tab-separated "contig<TAB>length" file (e.g. chrom.sizes) used to clamp peak ends',
    )
    # '--genome_size' is kept as an alias so existing invocations keep working;
    # argparse derives dest 'genome_size' from the first long option.
    parser.add_argument(
        '--genome-size',
        '--genome_size',
        action='store',
        type=str,
        default='hs',
        help='MACS3 "--gsize" parameter, %(default)s (human) by default',
    )
    # NOTE: argparse expands help with %-formatting, so defaults must be written
    # as %(default)s (optparse's %default crashes --help with a TypeError).
    parser.add_argument(
        '--bp-shift',
        action='store',
        type=str,
        default='-100',
        help='MACS3 "--shift" parameter, %(default)s by default',
    )
    parser.add_argument(
        '--bp-ext-size',
        action='store',
        type=str,
        default='200',
        help='MACS3 "--extsize" parameter, %(default)s by default',
    )

    args = vars(parser.parse_args())

    # call_peaks expects the parsed mapping, not the path.
    contig_sizes_fp = args.pop('contig_sizes_fp')
    try:
        args['contig_sizes'] = _read_contig_sizes(contig_sizes_fp)
    except (OSError, ValueError) as err:
        parser.error(f'cannot read --contig-sizes-fp {contig_sizes_fp}: {err}')
    return args


def _shell_quote(value):
    """Return *value* (str, Path, int, ...) quoted for safe use in a shell command line."""
    return shlex.quote(str(value))


def _build_clamp_awk_script(contig_sizes):
    """
    Return an AWK program that clamps tab-separated BED-like records to contig bounds.

    Starts below 0 are raised to 0. Ends beyond a known contig's length are
    lowered to that length. Records on contigs absent from *contig_sizes* keep
    their original end coordinate.
    """
    lines = [
        'BEGIN {',
        '    FS = "\\t"; OFS = "\\t"; # Set input and output field separators to tab',
        '    # Define contig lengths directly in the script',
    ]
    for contig, length in contig_sizes.items():
        # Escape characters that are special inside an AWK string literal
        escaped_contig = contig.replace('\\', '\\\\').replace('"', '\\"')
        lines.append(f'    contig_lengths["{escaped_contig}"] = {int(length)};')
    lines.extend(
        [
            '}',
            '{',
            '    if ($2 < 0) { $2 = 0; } # Correct start if less than 0',
            '    # Clamp the end only for known contigs: an unknown contig must not be clamped to 0',
            '    if (($1 in contig_lengths) && $3 > contig_lengths[$1]) { $3 = contig_lengths[$1]; }',
            '    print; # Print the (potentially modified) line',
            '}',
        ]
    )
    return '\n'.join(lines) + '\n'


def _clamp_peaks_to_contigs(peaks_file_path, sanity_file_path, contig_sizes):
    """Write *peaks_file_path* to *sanity_file_path* with coordinates clamped to contig bounds."""
    # The contig table goes into a script file rather than onto the command
    # line to avoid command-line length limits with many contigs.
    fd, awk_script_path = tempfile.mkstemp(suffix='.awk')
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as awk_script_file:
            awk_script_file.write(_build_clamp_awk_script(contig_sizes))
        awk_command = (
            f'awk -f {_shell_quote(awk_script_path)} '
            f'{_shell_quote(peaks_file_path)} > {_shell_quote(sanity_file_path)}'
        )
        shell_command_log_stderr(awk_command, shell=True)
    finally:
        # Remove the script even if writing it or running awk failed
        os.unlink(awk_script_path)


def call_peaks(
    run_name,
    bed_fp,
    contig_sizes=None,
    threads=None,
    genome_size='hs',
    bp_shift='-100',
    bp_ext_size='200',
):
    """
    ATAC: Call Peaks

    Runs ``macs3 callpeak`` on ``bed_fp`` into ``macs2_peaks/``, clamps the
    resulting narrowPeak coordinates to contig bounds, then bgzips and
    tabix-indexes the result. Output paths are relative to the current
    working directory.

    Args:
        run_name:        Run name for the experiment; used as the MACS --name and output prefix
        bed_fp:          ATAC deduplicated alignment BED file
        contig_sizes:    Dict of contig name -> length used to clamp peak ends. Peaks on
                         contigs missing from the dict (all peaks, if None/empty) keep
                         their original end coordinate.
        threads:         Threads passed to bgzip; defaults to multiprocessing.cpu_count()
        genome_size:     MACS --gsize parameter
        bp_shift:        MACS --shift parameter
        bp_ext_size:     MACS --extsize parameter

    Returns:
        [bgzipped peaks BED path, tabix index path]
    """
    if threads is None:
        threads = multiprocessing.cpu_count()
    if contig_sizes is None:
        contig_sizes = {}

    output_peaks_bgzip = f'{run_name}_ATAC_Peaks.bed.gz'
    output_peaks_tabix = f'{run_name}_ATAC_Peaks.bed.gz.tbi'

    # list of final outputs to return
    final_outputs = [
        output_peaks_bgzip,
        output_peaks_tabix,
    ]

    output_dir = 'macs2_peaks'

    logging.info('ATAC: Calling Peaks...')
    logging.info(f'Using {threads} CPU threads...')
    # TemporaryDirectory removes the MACS scratch directory even if macs3 fails
    with tempfile.TemporaryDirectory() as tmp_dir:
        # TODO: should any of these params be customizable?
        call_peaks_cmd = (
            'macs3 callpeak '
            f'-t {_shell_quote(bed_fp)} '
            '--format BED '
            f'--name {_shell_quote(run_name)} '
            f'--gsize {_shell_quote(genome_size)} '
            '--qval 0.01 '
            f'--tempdir {_shell_quote(tmp_dir)} '
            '--nomodel '
            '--keep-dup all '
            f'--shift {_shell_quote(bp_shift)} '
            f'--extsize {_shell_quote(bp_ext_size)} '
            f'--outdir {_shell_quote(output_dir)}'
        )
        shell_command_log_stderr(call_peaks_cmd, shell=True)

    logging.info('Sanity check peaks BED file...')
    if not contig_sizes:
        logging.info('No contig sizes given; peak ends will not be clamped')
    peaks_file_path = f'{output_dir}/{run_name}_peaks.narrowPeak'
    sanity_file_path = f'{output_dir}/{run_name}_peaks.sane'
    _clamp_peaks_to_contigs(peaks_file_path, sanity_file_path, contig_sizes)

    logging.info('bgzip peaks BED file...')
    shell_command_log_stderr(
        f'bgzip -f {_shell_quote(sanity_file_path)} -@ {_shell_quote(threads)} -c > {_shell_quote(output_peaks_bgzip)}',
        shell=True,
    )

    logging.info('tabix peaks BED file...')
    shell_command_log_stderr(
        f'tabix -p bed {_shell_quote(output_peaks_bgzip)}',
        shell=True,
    )

    # TODO: consider removing MACS intermediates not used downstream, e.g.
    # {output_dir}/{run_name}_peaks.xls and {output_dir}/{run_name}_summits.bed

    logging.info('ATAC: Done Calling Peaks')

    return final_outputs


def main():
    """Script entry point: parse command-line options and run ATAC peak calling with them."""
    parsed_kwargs = cli()
    call_peaks(**parsed_kwargs)


if __name__ == '__main__':
    main()
