import glob
import os
import subprocess
import textwrap
import time
from typing import Tuple

import numpy as np

from mist.lib import MistLogger as logging
from mist.lib.MistShellUtils import shell_command_log_stderr, tail


def choose_bam_idx_type(bam_file: str, size_threshold: int = 500_000_000) -> str:
    """
    Pick the index flavour ('bai' or 'csi') suitable for a BAM file.

    The BAM header is read with `samtools view -H` and every `@SQ` record's
    `LN:` (contig length) field is compared against `size_threshold`.

    Args:
        bam_file: Path of the BAM file whose header is inspected.
        size_threshold: Contig length above which a CSI index is chosen. Default is 500Mb.

    Returns:
        'csi' as soon as one contig is longer than the threshold, otherwise 'bai'.

    Raises:
        RuntimeError: If samtools exits with a non-zero status while reading the header.
    """
    header = subprocess.run(['samtools', 'view', '-H', bam_file], capture_output=True, text=True)

    if header.returncode != 0:
        logging.error(f'Error extracting BAM header: {header.stderr}')
        raise RuntimeError(f'Error extracting BAM header from {bam_file}')

    # Only the reference sequence dictionary (@SQ) records carry contig lengths.
    sq_records = (record for record in header.stdout.splitlines() if record.startswith('@SQ'))
    for line in sq_records:
        ln_fields = [field for field in line.split('\t') if field.startswith('LN:')]
        contig_length = int(ln_fields[0].split(':')[1])
        if contig_length > size_threshold:
            logging.info(f'Contig larger than {size_threshold}: {line}')
            return 'csi'

    # No contig exceeded the threshold, so the regular BAI index is sufficient.
    return 'bai'


def get_bwa_samtools_threads(threads: int) -> Tuple[int, int]:
    """Split the available threads between bwa-mem2 and samtools.

    The samtools share is obtained by linear interpolation over a small table
    of (available threads, samtools threads) pairs and rounded down; values
    outside the table are clamped to its first/last entry by `np.interp`.
    bwa-mem2 receives the remainder, but never fewer than one thread.

    Args:
        threads: Total no. of threads available for use
    Returns:
        (bwa_mem2_threads, samtools_threads)
    """
    # (threads available, samtools threads assigned) - observed in local testing.
    calibration = np.array(
        [
            (8, 1),
            (16, 1),
            (24, 2),
            (36, 4),
            (48, 6),
            (64, 6),
            (96, 10),
            (128, 16),
        ]
    )
    available, assigned = calibration[:, 0], calibration[:, 1]

    # Interpolate the samtools share for the requested thread count and round down.
    interpolated = np.interp(threads, available, assigned)
    samtools_threads = int(np.floor(interpolated))

    # Whatever is left goes to bwa-mem2, with a floor of a single thread.
    bwa_threads = threads - samtools_threads
    if bwa_threads < 1:
        bwa_threads = 1

    return (bwa_threads, samtools_threads)


def select_bwa_mem2_bin() -> str:
    """Select appropriate bwa-mem2 binary to use.

    bwa-mem2 fails to choose an appropriate binary on certain processors (e.g. AMD EPYC).
    In such cases manually try binaries compiled for using older SIMD instruction sets.

    Each candidate is probed by running `<binary> version`; the first one that
    exits successfully is selected. A candidate that exits with an error, or
    that cannot be launched at all (not on PATH, not executable, ...), is
    logged and skipped so that the remaining candidates are still tried.

    Returns:
        Name of the first candidate binary that ran successfully.

    Raises:
        Exception: If none of the candidate binaries could be run.
    """
    binaries = [
        'bwa-mem2',
        # order: newer to older generation
        'bwa-mem2.avx512bw',
        'bwa-mem2.avx2',
        'bwa-mem2.avx',
        'bwa-mem2.sse42',
        'bwa-mem2.sse41',
    ]
    logging.info('Checking bwa-mem2 compatibility')
    for b in binaries:
        try:
            subprocess.run(
                [b, 'version'],
                check=True,
                capture_output=True,
                text=True,
            )
        except subprocess.CalledProcessError as e:
            # The binary started but exited with a non-zero status.
            indented_stderr = textwrap.indent(e.stderr or '', prefix='    ')
            logging.info(f'{b} failed, stderr:\n{indented_stderr}')
        except OSError as e:
            # The binary could not be launched (e.g. FileNotFoundError when it is
            # not on PATH). Keep probing the other candidates instead of aborting.
            logging.info(f'{b} could not be executed: {e}')
        else:
            logging.info(f'{b} works')
            return b

    raise Exception('bwa-mem2 cannot be run on this processor')


def run_bwa_mem2(ref_index: str, threads: int, library_name: str, extra_mapping_params: str = ''):
    """Align compressed fastqs from Qualcy with bwa-mem2.

    The bwa-mem2 output is piped into `samtools fixmate`, which writes
    `<library_name>.fixmate.bam`. stderr of both tools is redirected to
    per-library log files in the current working directory.
    Note: The fastqs will be deleted post alignment.

    Args:
        ref_index: Reference index for bwa-mem2. The index itself is looked up in
            the `bwa-mem2_index` sub-directory and identified by its single `.ann` file.
        threads: No. of threads to use for alignment + bam generation.
        library_name: Name of the library being aligned. Used to infer R1/R2 fastq file names from Qualcy.
        extra_mapping_params: Additional command line mapping parameters to use with the bwa-mem2

    Raises:
        FileNotFoundError: If no `.ann` file is found in the bwa-mem2 index directory.
        ValueError: If more than one `.ann` file is found in the bwa-mem2 index directory.
        Exception: If the alignment command (or the fastq clean up) fails.
    """
    # Get reference index path for bwa-mem2.
    # ------------------------------------------------------------------------------
    logging.info('Aligning cell barcode annotated fastqs using bwa-mem2')
    bwa_mem2_index = os.path.join(ref_index, 'bwa-mem2_index')
    ann_files = [f for f in glob.glob(f'{bwa_mem2_index}/*') if f.endswith('.ann')]
    if not ann_files:
        # Fail with a clear message instead of an IndexError on an empty list.
        raise FileNotFoundError(
            f'No bwa-mem2 index (.ann file) found in the bwa-mem2 index directory: {bwa_mem2_index}'
        )
    if len(ann_files) > 1:
        raise ValueError(
            f'Found multiple index files in the bwa-mem2 index directory: {ann_files} ! '
            'Please ensure there is only 1.'
        )
    # bwa-mem2 expects the index prefix, i.e. the path without the '.ann' extension.
    reference_index_prefix = ann_files[0].removesuffix('.ann')

    input_fastqs = (
        f'{library_name}_R1.annotated.fastq.gz',
        f'{library_name}_R2.annotated.fastq.gz',
    )
    bwa_log_file = f'{library_name}.bwa-mem2.alignment.log.txt'
    samtools_log_file = f'{library_name}.samtools-fixmate.log.txt'

    # Run the alignment command.
    # ------------------------------------------------------------------------------
    bwa_mem2_bin = select_bwa_mem2_bin()
    bwa_threads, samtools_threads = get_bwa_samtools_threads(threads)
    # NOTE(review): in a plain shell pipeline the exit status is that of the last
    # command (samtools) unless pipefail is enabled - confirm how
    # shell_command_log_stderr detects a bwa-mem2 failure, since the input
    # fastqs are deleted once the command is considered successful.
    alignment_cmd = (
        f'{bwa_mem2_bin} mem -C -t {bwa_threads} {extra_mapping_params} '
        f'{reference_index_prefix} {" ".join(input_fastqs)} 2>{bwa_log_file} '
        ' | '
        f'samtools fixmate -m -u --input-fmt-option \'filter= ! flag.supplementary\' -@ {samtools_threads} - {library_name}.fixmate.bam 2>{samtools_log_file}'
    )
    try:
        shell_command_log_stderr(alignment_cmd, shell=True)
        for f in input_fastqs:
            os.remove(f)
    except Exception as e:
        logging.exception(e)
        logging.error(f'    bwa-mem2 logfile contents:\n    {tail(100, bwa_log_file)}')
        logging.error(f'    samtools logfile contents:\n    {tail(100, samtools_log_file)}')
        raise Exception(f'Alignment with bwa-mem2 failed for library: {library_name}') from e


def sort_and_generate_metrics(threads: int, mito_contigs: str, library_name: str):
    """Generate Alignment related metrics and sort the intermediate bam - both will be run simultaneously.

    `samtools sort` (writing `<library_name>.sort.bam` plus its index) and
    `AtacMetrics` (writing `<library_name>.alignment.metrics.txt`) are started
    as background shell commands on `<library_name>.fixmate.bam`. Once both
    have finished successfully the intermediate fixmate bam is removed.

    Args:
        threads: Total no. of threads available; split between samtools sort and AtacMetrics.
        mito_contigs: Comma separated string of the mitochondrial contigs. Used to infer nuclear and mitochondrial read pair counts.
        library_name: Name of the library being processed. Used to infer the input/output file names.

    Raises:
        Exception: If either of the two commands exits with a non-zero return code.
    """
    # Does not help to have more than 2 threads
    atac_metrics_threads = 2 if threads >= 4 else 1
    # Does not help much to have more than 24 threads; capping the thread count
    # also keeps memory low, since `-m 1G` below is a per-thread setting.
    samtools_threads = min(24, max(threads - atac_metrics_threads, 1))

    logging.info('Sorting aligned bam and generating related metrics.')
    # Order must match `background_processes` below: [samtools sort, AtacMetrics].
    log_files = [
        f'{library_name}.samtools-sort.log.txt',
        f'{library_name}.AtacMetrics.log.txt',
    ]
    fixmate_bam = f'{library_name}.fixmate.bam'
    idx_ext = choose_bam_idx_type(fixmate_bam)
    sorted_bam = f'{library_name}.sort.bam'
    sorted_bam_index = f'{sorted_bam}.{idx_ext}'

    samtools_pid = shell_command_log_stderr(
        f'samtools sort -m 1G -@ {samtools_threads} --write-index -o {sorted_bam}##idx##{sorted_bam_index} {fixmate_bam} 2> {log_files[0]}'
        ' && '
        # Touch the index that was actually written (bai or csi) to avoid the samtools
        # warning of index timestamp being older than bam. Touching a hard-coded
        # '.bai' would leave an empty, bogus index file next to a csi index.
        f'touch {sorted_bam_index}',
        backgroundProcess=True,
        shell=True,
        preexec_fn=os.setsid,
    )
    metrics_pid = shell_command_log_stderr(
        f'AtacMetrics {fixmate_bam} {mito_contigs} {atac_metrics_threads} > {library_name}.alignment.metrics.txt 2> {log_files[1]}',
        backgroundProcess=True,
        shell=True,
        preexec_fn=os.setsid,
    )
    # The returned handles are used below via poll(), returncode and args.
    background_processes = [samtools_pid, metrics_pid]
    # Wait until both commands have finished.
    while any(p.poll() is None for p in background_processes):
        time.sleep(0.5)

    # Check return codes; log the tail of the log file of every failed command.
    failures = []
    for p, log_file in zip(background_processes, log_files):
        if p.returncode != 0:
            failures.append((f'Command Failed: {p.args}', tail(50, log_file)))

    for m, tail_50 in failures:
        logging.error(f'\n    {m}:\n    {tail_50}\n')

    if failures:
        # Report every failed command, not just the first one.
        raise Exception('; '.join(m for m, _ in failures))

    # Clean up intermediate bam.
    os.remove(fixmate_bam)
