#!/usr/bin/python

import argparse
import collections
import csv
import itertools
import multiprocessing
import os
from os import path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.linalg import eigh
from scipy.optimize import minimize_scalar
from scipy.spatial.distance import cdist
from scipy.stats import norm
from sklearn.cluster import AgglomerativeClustering
from sklearn.mixture import GaussianMixture

from mist.apps import cell_label_noise as run_second_derivative
from mist.apps import utils
from mist.lib import MistLogger as logging

# Cluster labels used in every 'cellCluster*' column of the per-cell table.
NONCELL_CLUST = 0
CELL_CLUST = 1

# z multiplier applied to a component's standard deviation for its confidence bounds
# (1.645 corresponds to a two-sided 90% confidence interval).
CONFIDENCE_CONST = 1.645

RANDOM_SEED = 2023  # seed shared by the landmark sampling and every GaussianMixture
MIN_CELL_NUM = 50  # fewer cells than this -> the FRiP threshold is not chosen dynamically
FRIP_THRESHOLD = 0.1  # fixed FRiP threshold used when no dynamic threshold is computed


def cli():
    """
    ATAC Call Cells

    Parse the command-line arguments of the ATAC cell-calling step.

    Returns:
        dict mapping each argument's ``dest`` name to its parsed value. Optional arguments that
        were not given are None ('use_dynamic_FRiP_threshold' defaults to False).
        'expected_cell_count' is parsed as an int.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--atac-cell-calling-data-file',
        dest='atac_cell_calling_data_file',
        required=False,
        help='CSV file with cell label, ATAC transposase sites in peaks, '
        'fraction of transposase sites in peaks, and total fragments',
    )
    parser.add_argument(
        '--use-dynamic-FRiP-threshold',
        dest='use_dynamic_FRiP_threshold',
        required=False,
        action='store_true',
        help='If True (non-default), the FRiP threshold is dynamically determined. If False (default), the FRiP threshold is 0.1.',
    )

    # A cell count is numeric; without type=int argparse would hand downstream code a string.
    parser.add_argument(
        '--expected-cell-count',
        dest='expected_cell_count',
        type=int,
        required=False,
        help='The number of expected cell number. Used to determine the log odds ratio threshold',
    )

    parser.add_argument(
        '--run-base-name',
        dest='run_base_name',
        required=True,
        help='The prefix of the file names of cell calling plots',
    )

    parser.add_argument(
        '--output-header', dest='output_header', required=False, help='The header of the output csv file'
    )

    return vars(parser.parse_args())


def call_cells(
    atac_data_per_cell, use_dynamic_FRiP_threshold, expected_cell_count, run_base_name, output_header, outdir
):
    """
    Call putative cells from per-cell ATAC statistics and write the algorithm stats CSV.

    Pipeline:
      1. Keep cells whose 'fraction_transposase_sites_in_peaks' (FRiP) is >= a threshold. The threshold is
         FRIP_THRESHOLD, or one chosen by choose_FRiP_threshold when requested and at least
         MIN_CELL_NUM cells exist.
      2. Among those, drop cells whose log10('transposase_sites_in_peaks') is <= the cutoff returned
         by select_logCount_cutoff.
      3. First GMM on log counts (perform_GMM1_cellCalling).
      4. Second GMM on (log counts, FRiP) (perform_GMM2_cellCalling). Its call is used only when the
         two components' confidence ellipses do not overlap.
      5. Refit (perform_refit_cellCalling). It removes false positives from the GMM1 call when the
         ellipses overlap; otherwise it recovers false negatives missed by the GMM2 call.

    A '{run_base_name}_CellLabelAlgorithmStats.csv' file is written to ``outdir`` on every path.

    Args:
        atac_data_per_cell: DataFrame indexed by cell label. It needs 'transposase_sites_in_peaks' and
            'fraction_transposase_sites_in_peaks'; the dynamic FRiP threshold also reads 'fragments'.
            Columns (at least 'log_counts') are added to this frame in place.
        use_dynamic_FRiP_threshold: choose the FRiP threshold from the data instead of FRIP_THRESHOLD.
        expected_cell_count: forwarded to the refit step's second-derivative cutoff.
        run_base_name: prefix of the stats CSV and diagnostic plot file names.
        output_header: header passed to write_cell_label_stats.
        outdir: output directory, created (with parents) if missing; None means the current directory.

    Returns:
        List of index labels of the cells called as cells (empty when none are found).
    """
    # os.path.join() cannot take None, so a missing outdir means the current directory.
    # makedirs also creates missing parent directories and tolerates an existing directory.
    if outdir is None:
        outdir = os.curdir
    os.makedirs(outdir, exist_ok=True)

    stats_file = os.path.join(outdir, f'{run_base_name}_CellLabelAlgorithmStats.csv')

    atac_putative_cell_list = []

    num_cell = atac_data_per_cell.shape[0]

    if num_cell == 0:
        # create the CellLabelAlgorithmStats.csv and return an empty cell list
        write_cell_label_stats(stats_file, output_header, None)
        return atac_putative_cell_list

    ## Transforming transposase_sites_in_peaks counts into log scale
    atac_data_per_cell['log_counts'] = np.log10(atac_data_per_cell['transposase_sites_in_peaks'])

    ## Choose the FRiP threshold
    if use_dynamic_FRiP_threshold:
        if num_cell < MIN_CELL_NUM:
            FRiP_threshold_used = FRIP_THRESHOLD
            logging.info(
                f'     The threshold for fraction of transposase sites in peaks cannot be determined dynamically \n'
                f'     because the number of cell is less than {MIN_CELL_NUM} \n'
                f'     {FRiP_threshold_used} is  used instead'
            )
        else:
            FRiP_threshold_used = choose_FRiP_threshold(atac_data_per_cell)
            # remove the helper columns added by choose_FRiP_threshold
            atac_data_per_cell = atac_data_per_cell.drop(columns=['is_landmark', 'cluster_FRiP'])
    else:
        FRiP_threshold_used = FRIP_THRESHOLD

    logging.info(f'     The threshold of fraction of transposase sites in peaks: {FRiP_threshold_used}')

    ## Select cells with FRiP >= FRiP_threshold_used
    atac_data_per_cell['cellCluster'] = [NONCELL_CLUST] * num_cell
    atac_data_per_cell.loc[
        atac_data_per_cell['fraction_transposase_sites_in_peaks'] >= FRiP_threshold_used, 'cellCluster'
    ] = CELL_CLUST

    ## Determine the threshold of log_counts among the high-FRiP cells
    logCount_perCellATAC_highFRiP = atac_data_per_cell.loc[
        atac_data_per_cell['cellCluster'] == CELL_CLUST, ['log_counts']
    ].values
    logCount_cutoff = select_logCount_cutoff(logCount_perCellATAC_highFRiP, run_base_name, outdir)
    if np.isnan(logCount_cutoff):
        logging.info('     No cell was found in atac because any count cutoff was not valid.')
        # create the CellLabelAlgorithmStats.csv
        write_cell_label_stats(stats_file, output_header, None)
        return atac_putative_cell_list

    # A cell is filtered when log10(count) <= cutoff, i.e. when count <= 10**cutoff.
    logging.info(
        f'     Cells with {int(10**logCount_cutoff)} or fewer transposase sites in peaks are filtered out'
    )

    ## Filter out cells using logCount_cutoff
    atac_data_per_cell.loc[atac_data_per_cell['log_counts'] <= logCount_cutoff, 'cellCluster'] = NONCELL_CLUST

    numNoisyCell = atac_data_per_cell[atac_data_per_cell['cellCluster'] == NONCELL_CLUST].shape[0]
    logging.info(
        f'     The number of cells filtered out either by low fraction of transposase sites in peaks or low transposase sites in peaks: {numNoisyCell}'
    )
    numPutativeCell_after_filtering = get_number_putative_cell(atac_data_per_cell['cellCluster'])
    logging.info(
        f'     After filtering out cells with low fraction of transposase sites in peaks or low transposase sites in peaks: {numPutativeCell_after_filtering} cells left'
    )

    if numPutativeCell_after_filtering == 0:
        # create the CellLabelAlgorithmStats.csv
        write_cell_label_stats(stats_file, output_header, None)
        return atac_putative_cell_list

    ## Perform the first GMM using logCount
    atac_data_per_cell = perform_GMM1_cellCalling(atac_data_per_cell, run_base_name, outdir)

    numPutativeCell_gmm1 = get_number_putative_cell(atac_data_per_cell['cellCluster_gmm1'])
    logging.info(f'     After the first cell calling step: {numPutativeCell_gmm1} putative cells found')

    if numPutativeCell_gmm1 == 0:
        atac_data_per_cell['cellType'] = 'Non cells'
        num_cells_inferred_by_basic_algo = 0
        I_overlap = None
    else:
        ## Perform the second GMM using logCount and FRiP
        atac_data_per_cell, I_overlap = perform_GMM2_cellCalling(atac_data_per_cell, run_base_name, outdir)

        if I_overlap:
            # GMM2's components overlap, so its call is not used; GMM1's call is the basic result
            num_cells_inferred_by_basic_algo = numPutativeCell_gmm1
        else:
            numPutativeCell_gmm2 = get_number_putative_cell(atac_data_per_cell['cellCluster_gmm2'])
            logging.info(f'     After the second cell calling step: {numPutativeCell_gmm2} putative cells found')
            # the number of putative cell after 2nd GMM is stored as the number of cells from basic algorithm
            num_cells_inferred_by_basic_algo = numPutativeCell_gmm2

        ## Refit step: removes false positives when I_overlap is True,
        ## recovers false negatives when I_overlap is False
        if num_cells_inferred_by_basic_algo == 0:
            atac_data_per_cell['cellType'] = 'Non cells'
        else:
            atac_data_per_cell = perform_refit_cellCalling(
                atac_data_per_cell, expected_cell_count, I_overlap, run_base_name, outdir
            )

            numPutativeCell_refit = get_number_putative_cell(atac_data_per_cell['cellCluster_refit'])

            if I_overlap:
                logging.info(
                    f'     After filtering out false positive cells: {numPutativeCell_refit} putative cells found'
                )
            else:
                logging.info(
                    f'     After recovering false negative cells: {numPutativeCell_refit} putative cells found'
                )

            if numPutativeCell_refit == 0:
                atac_data_per_cell['cellType'] = 'Non cells'
            else:
                atac_data_per_cell.loc[atac_data_per_cell['cellCluster_refit'] == NONCELL_CLUST, 'cellType'] = (
                    'Non cells'
                )
                atac_data_per_cell.loc[atac_data_per_cell['cellCluster_refit'] == CELL_CLUST, 'cellType'] = 'Cells'

    atac_putative_cell_list = list(atac_data_per_cell.index[atac_data_per_cell['cellType'] == 'Cells'])

    ## make algorithm stat file
    gmm_stats = get_gmm_stats(
        atac_data_per_cell,
        FRiP_threshold_used,
        logCount_cutoff,
        numNoisyCell,
        num_cells_inferred_by_basic_algo,
        I_overlap,
    )
    cell_num_stats = get_cell_num_stats(atac_data_per_cell)
    Tn5_cutsites_in_peaks_mean_stats, Tn5_cutsites_in_peaks_median_stats = get_Tn5_cutsites_in_peaks_stats(
        atac_data_per_cell
    )
    FRiP_mean_stats, FRiP_median_stats = get_FRiP_stats(atac_data_per_cell)

    algo_stats = gmm_stats
    algo_stats += cell_num_stats
    algo_stats += Tn5_cutsites_in_peaks_mean_stats
    algo_stats += Tn5_cutsites_in_peaks_median_stats
    algo_stats += FRiP_mean_stats
    algo_stats += FRiP_median_stats

    write_cell_label_stats(stats_file, output_header, algo_stats)

    return atac_putative_cell_list


def choose_FRiP_threshold(atac_data_per_cell):
    """
    Choose a FRiP threshold from the data by splitting cells into two clusters.

    Landmark cells are clustered into two groups by average-linkage agglomerative clustering. The
    distance is cosine distance between ('fragments', 'transposase_sites_in_peaks') vectors. Landmarks
    are up to 200 random cells, plus the 100 cells with the most 'fragments' and the 100 with the most
    'transposase_sites_in_peaks'. Every other cell takes the majority cluster of its 11 nearest
    landmarks under the same distance. The cluster whose landmarks have the higher mean
    'fraction_transposase_sites_in_peaks' is the cell cluster.

    Side effects: reseeds NumPy's global RNG with RANDOM_SEED, and adds the columns 'is_landmark' and
    'cluster_FRiP' to ``atac_data_per_cell`` in place.

    Returns:
        The minimum 'fraction_transposase_sites_in_peaks' over cells assigned to the cell cluster.
    """
    np.random.seed(RANDOM_SEED)

    feature_cols = ['fragments', 'transposase_sites_in_peaks']

    ### Selecting landmark cells
    num_cell = atac_data_per_cell.shape[0]
    num_landmark_cells = min(200, num_cell)

    landmark_cell_index = np.random.choice(list(atac_data_per_cell.index), size=num_landmark_cells, replace=False)
    atac_data_per_cell['is_landmark'] = False
    atac_data_per_cell.loc[landmark_cell_index, 'is_landmark'] = True

    # adding outlier cells with the most fragments
    added_landmark_cell_index = atac_data_per_cell.sort_values(['fragments'], ascending=False).iloc[0:100].index
    atac_data_per_cell.loc[added_landmark_cell_index, 'is_landmark'] = True

    # adding outlier cells with the most transposase sites in peaks
    # NOTE(review): this sorts by the raw in-peak count, not by FRiP
    # ('fraction_transposase_sites_in_peaks'); confirm this is the intended outlier criterion.
    added_landmark_cell_index = (
        atac_data_per_cell.sort_values(['transposase_sites_in_peaks'], ascending=False).iloc[0:100].index
    )
    atac_data_per_cell.loc[added_landmark_cell_index, 'is_landmark'] = True

    ### Clustering the landmark cells
    atac_data_per_cell_landmark = atac_data_per_cell.loc[atac_data_per_cell['is_landmark']].copy()

    dist_matrix = cdist(
        atac_data_per_cell_landmark[feature_cols],
        atac_data_per_cell_landmark[feature_cols],
        'cosine',
    )
    cluster_landmark = AgglomerativeClustering(n_clusters=2, metric='precomputed', linkage='average').fit(dist_matrix)
    atac_data_per_cell_landmark['cluster_FRiP'] = cluster_landmark.labels_

    mean_FRiP_FRiP_clust0 = np.mean(
        atac_data_per_cell_landmark.loc[
            atac_data_per_cell_landmark['cluster_FRiP'] == 0, 'fraction_transposase_sites_in_peaks'
        ]
    )
    mean_FRiP_FRiP_clust1 = np.mean(
        atac_data_per_cell_landmark.loc[
            atac_data_per_cell_landmark['cluster_FRiP'] == 1, 'fraction_transposase_sites_in_peaks'
        ]
    )
    cell_cluster_FRiP = 0 if mean_FRiP_FRiP_clust0 > mean_FRiP_FRiP_clust1 else 1

    ### assign the cluster for the rest of the cells
    atac_data_per_cell['cluster_FRiP'] = -1
    atac_data_per_cell.loc[atac_data_per_cell_landmark.index, 'cluster_FRiP'] = atac_data_per_cell_landmark[
        'cluster_FRiP'
    ]

    is_non_landmark = ~atac_data_per_cell['is_landmark']
    atac_data_per_cell_non_landmark = atac_data_per_cell[is_non_landmark]
    num_non_landmark = atac_data_per_cell_non_landmark.shape[0]

    # When every cell is a landmark there is nothing left to assign, so no worker pool is started.
    if num_non_landmark > 0:
        # distance between landmarks (rows) and non-landmarks (columns)
        dist = cdist(
            atac_data_per_cell_landmark[feature_cols],
            atac_data_per_cell_non_landmark[feature_cols],
            'cosine',
        )

        ## select k nearest landmarks for every non-landmark cell
        k_neighbor = 11
        k_neighbor_Idx = np.argsort(dist, axis=0)[:k_neighbor]
        # k_neighbor_clust[r, c] = cluster of the r-th nearest landmark of non-landmark cell c
        landmark_labels = atac_data_per_cell_landmark['cluster_FRiP'].to_numpy()
        k_neighbor_clust = landmark_labels[k_neighbor_Idx]

        ## majority vote in parallel, splitting by cells (columns); never more workers than cells
        num_cpu = max(1, min(multiprocessing.cpu_count() - 5, num_non_landmark))
        task_decide_clust_FRiP = np.array_split(k_neighbor_clust, num_cpu, axis=1)
        with multiprocessing.Pool(num_cpu) as pool:
            FRiP_cluster_result = pool.map_async(decide_clust_FRiP_for_non_landmarks, task_decide_clust_FRiP).get()

        # chunks come back in order, so concatenating them restores the non-landmark row order
        atac_data_per_cell.loc[is_non_landmark, 'cluster_FRiP'] = list(
            itertools.chain.from_iterable(FRiP_cluster_result)
        )

    FRiP_threshold = np.min(
        atac_data_per_cell.loc[
            atac_data_per_cell['cluster_FRiP'] == cell_cluster_FRiP, 'fraction_transposase_sites_in_peaks'
        ]
    )

    return FRiP_threshold


def decide_clust_FRiP_for_non_landmarks(subtask_decide_clust_FRiP):
    """
    Majority-vote one FRiP cluster label per non-landmark cell.

    Args:
        subtask_decide_clust_FRiP: 2-D array of shape (k_neighbors, n_cells). Column i holds the
            cluster labels of the nearest landmark cells of cell i.

    Returns:
        List with exactly one label per column: the most frequent label. On a tie, the label that
        appears first in the column wins. Exactly one label per cell is required so the caller
        can align the results with the cells.
    """
    # Counter.most_common(1) returns the first-encountered label among those tied for the top count.
    return [collections.Counter(column).most_common(1)[0][0] for column in subtask_decide_clust_FRiP.T]


def select_logCount_cutoff(logCount_perCellATAC_highFRiP, run_base_name, outdir):
    """
    Choose the log10 in-peak count cutoff that separates low-count noise from cells.

    Candidate cutoffs run from max(1, 1st percentile) up to, but excluding, the 99th percentile of
    ``logCount_perCellATAC_highFRiP``, in steps of 0.1. For each cutoff, a 2-component 1-D GMM is
    fitted to the values strictly above it. The scan stops once fewer than 10 values remain. For
    each cutoff it records:

      * I_consider: the component with the higher upper CI bound (mean + CONFIDENCE_CONST * sd)
        also has the strictly higher lower CI bound.
      * weight_diff: weight of the component with the lower upper CI bound minus the weight of
        the other component.
      * weight_diff_delta: weight_diff at the next cutoff minus weight_diff here (0 for the last).
      * stability_score: number of consecutive negative weight_diff_delta values starting here,
        i.e. how stably the GMM keeps finding a non-cell cluster.

    Among cutoffs with I_consider true, it picks the highest stability_score, then the highest
    weight_diff, then the highest logCount_cutoff. Two diagnostic plots are saved to ``outdir``
    whenever at least one cutoff was evaluated.

    Args:
        logCount_perCellATAC_highFRiP: array of log10 counts of the cells that passed the FRiP filter
            (call_cells passes an (n, 1) array).
        run_base_name: plot title and file-name prefix.
        outdir: directory for the diagnostic plots.

    Returns:
        The selected cutoff as a float, or NaN when no candidate cutoff is usable. This includes
        empty input and inputs too small for any GMM fit.
    """
    logCounts = np.asarray(logCount_perCellATAC_highFRiP)

    scan_rows = []
    if logCounts.size > 0:
        logCount_min = max([1, np.percentile(logCounts, 1)])
        logCount_max = np.percentile(logCounts, 99)
        logCount_cutoff_arr = np.round(np.arange(logCount_min, logCount_max, 0.1), 1)

        for logCount_cutoff in logCount_cutoff_arr:
            logCounts_above = logCounts[logCounts > logCount_cutoff]
            if len(logCounts_above) < 10:
                break

            gm = GaussianMixture(n_components=2, covariance_type='full', random_state=RANDOM_SEED).fit(
                logCounts_above.reshape(-1, 1)
            )

            # CONFIDENCE_CONST-scaled confidence interval of each component
            clust0_std = np.sqrt(gm.covariances_[0][0][0])
            clust0_mean = gm.means_[0][0]
            clust0_CI_low = clust0_mean - CONFIDENCE_CONST * clust0_std
            clust0_CI_high = clust0_mean + CONFIDENCE_CONST * clust0_std

            clust1_std = np.sqrt(gm.covariances_[1][0][0])
            clust1_mean = gm.means_[1][0]
            clust1_CI_low = clust1_mean - CONFIDENCE_CONST * clust1_std
            clust1_CI_high = clust1_mean + CONFIDENCE_CONST * clust1_std

            if clust1_CI_high > clust0_CI_high:
                weight_diff = gm.weights_[0] - gm.weights_[1]
                I_consider = bool(clust1_CI_low > clust0_CI_low)
            else:
                weight_diff = gm.weights_[1] - gm.weights_[0]
                I_consider = bool(clust1_CI_low < clust0_CI_low)

            scan_rows.append([logCount_cutoff, I_consider, weight_diff])

    if not scan_rows:
        # No cutoff could be evaluated (empty input, or fewer than 10 values above the lowest cutoff).
        return np.nan

    result_logCount_cutoff = pd.DataFrame(scan_rows, columns=['logCount_cutoff', 'I_consider', 'weight_diff'])

    # change of weight_diff to the next cutoff; the last cutoff has no successor, so its delta is 0
    result_logCount_cutoff['weight_diff_delta'] = (
        result_logCount_cutoff['weight_diff'].shift(-1) - result_logCount_cutoff['weight_diff']
    ).fillna(0)

    # stability score = length of the run of negative deltas starting at each cutoff,
    # built right-to-left in one pass
    weight_diff_delta = result_logCount_cutoff['weight_diff_delta'].to_numpy()
    stability_score = [0] * len(weight_diff_delta)
    run_length = 0
    for i in range(len(weight_diff_delta) - 1, -1, -1):
        run_length = run_length + 1 if weight_diff_delta[i] < 0 else 0
        stability_score[i] = run_length
    result_logCount_cutoff['stability_score'] = stability_score

    logCount_cutoff_toConsider = result_logCount_cutoff.loc[result_logCount_cutoff['I_consider']].sort_values(
        by=['stability_score', 'weight_diff', 'logCount_cutoff'], ascending=False
    )

    if logCount_cutoff_toConsider.empty:
        selected_log_Tn5_cutoff = np.nan
    else:
        selected_log_Tn5_cutoff = float(logCount_cutoff_toConsider['logCount_cutoff'].iloc[0])

    # diagnostic plots
    _plot_logCount_cutoff_diagnostic(
        result_logCount_cutoff,
        'weight_diff',
        selected_log_Tn5_cutoff,
        run_base_name,
        outdir,
        'weight_diff_logCount_threshold',
    )
    _plot_logCount_cutoff_diagnostic(
        result_logCount_cutoff,
        'stability_score',
        selected_log_Tn5_cutoff,
        run_base_name,
        outdir,
        'stability_score_logCount_threshold',
    )

    return selected_log_Tn5_cutoff


def _plot_logCount_cutoff_diagnostic(
    result_logCount_cutoff, y_column, selected_log_Tn5_cutoff, run_base_name, outdir, file_suffix
):
    """Save a line+scatter plot of one per-cutoff statistic, marking the selected cutoff (if any) in red."""
    sns.set_style('white')
    sns.set_context('paper')
    fig = plt.figure(figsize=(6, 6))

    sns.lineplot(data=result_logCount_cutoff, x='logCount_cutoff', y=y_column)
    sns.scatterplot(
        data=result_logCount_cutoff,
        x='logCount_cutoff',
        y=y_column,
        style='I_consider',
        markers={True: 'o', False: 'X'},
    )
    if not np.isnan(selected_log_Tn5_cutoff):
        plt.axvline(x=selected_log_Tn5_cutoff, color='red')
    plt.tick_params(labelsize=10)
    plt.title(run_base_name)
    plt.savefig(os.path.join(outdir, f'{run_base_name}_{file_suffix}.png'), dpi=300)
    plt.close(fig)


def perform_GMM1_cellCalling(atac_data_per_cell, run_base_name, outdir):
    """
    First cell-calling step: a 2-component GMM on the log counts of cells that passed the filters.

    The non-cell component is the one with the lower upper confidence bound
    (mean + CONFIDENCE_CONST * sd). 'cellCluster_gmm1' starts as a copy of 'cellCluster'. Every cell
    whose log_counts is at most the largest log count assigned to the non-cell component is then set
    to NONCELL_CLUST. A plot of the fit is saved.

    Args:
        atac_data_per_cell: DataFrame with 'cellCluster' and 'log_counts' columns; modified in place.
        run_base_name: plot title / file-name prefix.
        outdir: plot output directory.

    Returns:
        ``atac_data_per_cell`` with the added 'cellCluster_gmm1' column.
    """
    ### select cells that pass FRIP and logCount threshold
    atac_data_per_cell_forGMM1 = atac_data_per_cell.loc[atac_data_per_cell['cellCluster'] == CELL_CLUST]
    log_counts_forGMM1 = atac_data_per_cell_forGMM1['log_counts'].values.reshape(-1, 1)

    # fit_predict fits the model and labels the same data in a single pass
    gmm1 = GaussianMixture(n_components=2, covariance_type='full', random_state=RANDOM_SEED)
    gmm1_cluster_result = gmm1.fit_predict(log_counts_forGMM1)

    ### determine which clust label is the non-cell cluster
    clust0_CI_high = gmm1.means_[0][0] + CONFIDENCE_CONST * np.sqrt(gmm1.covariances_[0][0][0])
    clust1_CI_high = gmm1.means_[1][0] + CONFIDENCE_CONST * np.sqrt(gmm1.covariances_[1][0][0])
    gmm1_nonCellClust = 1 if clust0_CI_high > clust1_CI_high else 0

    ### determine log_counts cutoff based on gmm1 result
    # NaN when no cell is labelled non-cell; the `<=` comparison below then filters nothing
    gmm1_nonCell_idx = list(np.where(gmm1_cluster_result == gmm1_nonCellClust)[0])
    max_logCount_nonCell = atac_data_per_cell_forGMM1.iloc[gmm1_nonCell_idx]['log_counts'].max()

    atac_data_per_cell['cellCluster_gmm1'] = atac_data_per_cell['cellCluster'].copy()
    atac_data_per_cell.loc[atac_data_per_cell['log_counts'] <= max_logCount_nonCell, 'cellCluster_gmm1'] = NONCELL_CLUST

    plot_GMM1_result(gmm1, atac_data_per_cell_forGMM1['log_counts'], run_base_name, outdir)

    return atac_data_per_cell


def perform_GMM2_cellCalling(atac_data_per_cell, run_base_name, outdir):
    """
    Second cell-calling step: a 2-component GMM on (log_counts, FRiP) of the cells that passed GMM1.

    The fit is always plotted. If the components' CONFIDENCE_CONST-scaled ellipsoids do not overlap,
    the component with the larger mean log count is the cell cluster. In that case
    'cellCluster_gmm2' is added as a copy of 'cellCluster_gmm1' with the other component's cells set
    to NONCELL_CLUST, and the result is plotted. If the ellipsoids overlap, no 'cellCluster_gmm2'
    column is added.

    Args:
        atac_data_per_cell: DataFrame with 'log_counts', 'fraction_transposase_sites_in_peaks' and
            'cellCluster_gmm1'; modified in place.
        run_base_name: plot title / file-name prefix.
        outdir: plot output directory.

    Returns:
        (atac_data_per_cell, I_overlap) where I_overlap is the ellipsoid_intersection_test result.
    """
    ### select cells that pass GMM1
    cellIdx_forGMM2 = atac_data_per_cell.index[atac_data_per_cell['cellCluster_gmm1'] == CELL_CLUST]
    features_forGMM2 = atac_data_per_cell.loc[cellIdx_forGMM2, ['log_counts', 'fraction_transposase_sites_in_peaks']]

    # fit_predict fits the model and labels the same data in a single pass
    gmm2 = GaussianMixture(n_components=2, covariance_type='full', random_state=RANDOM_SEED)
    gmm2_cluster_result = gmm2.fit_predict(features_forGMM2)

    I_overlap = ellipsoid_intersection_test(gmm2)
    plot_GMM2_fit(
        gmm2,
        atac_data_per_cell[['log_counts', 'fraction_transposase_sites_in_peaks', 'cellCluster_gmm1']],
        run_base_name,
        outdir,
    )

    ### If the confidence intervals don't overlap, choose the cell cluster
    if not I_overlap:
        ### the component with the larger mean log count is the cell cluster
        gmm2_nonCellClust = 1 if gmm2.means_[0][0] > gmm2.means_[1][0] else 0

        atac_data_per_cell['cellCluster_gmm2'] = atac_data_per_cell['cellCluster_gmm1'].copy()

        gmm2_nonCellIdx = cellIdx_forGMM2[np.where(gmm2_cluster_result == gmm2_nonCellClust)[0]]
        atac_data_per_cell.loc[gmm2_nonCellIdx, 'cellCluster_gmm2'] = NONCELL_CLUST

        plot_GMM2_result(
            gmm2,
            atac_data_per_cell[['log_counts', 'fraction_transposase_sites_in_peaks', 'cellCluster_gmm2']],
            run_base_name,
            outdir,
        )

    return atac_data_per_cell, I_overlap


def perform_refit_cellCalling(atac_data_per_cell, expected_cell_count, I_overlap, run_base_name, outdir):
    """
    Refine the basic cell call by refitting one 2-D Gaussian to each side of it.

    The current call is 'cellCluster_gmm1' when I_overlap is true, otherwise 'cellCluster_gmm2'.
    Among cells with 'cellCluster' == CELL_CLUST, that call is split into cells and non-cells. A
    single-component GMM over (log_counts, fraction_transposase_sites_in_peaks) is fitted to each
    group. Each cell's log odds ratio (cell log-likelihood minus non-cell log-likelihood) is compared
    with a cutoff derived from run_second_derivative.main. If that returns 0 cells, the cutoff is 0.

    When I_overlap is true:
        cells called by GMM1 whose log odds ratio is below the cutoff are reset to non-cells
        (false positives are removed).

    When I_overlap is false:
        cells called non-cells by GMM2 whose log odds ratio reaches the cutoff are set to cells
        (false negatives are recovered).

    Adds the columns 'I_include' and 'cellCluster_refit' to ``atac_data_per_cell`` in place, and
    saves a log-odds histogram and a refit scatter plot.

    Returns:
        ``atac_data_per_cell``.
    """
    refit_features = ['log_counts', 'fraction_transposase_sites_in_peaks']

    cellIdx_forRefit = atac_data_per_cell.index[atac_data_per_cell['cellCluster'] == CELL_CLUST]
    atac_data_per_cell_forRefit = atac_data_per_cell.loc[cellIdx_forRefit]
    X_forRefit = atac_data_per_cell_forRefit[refit_features]

    # GMM1's call when the GMM2 ellipses overlap, otherwise GMM2's call
    clustResult_toUse = 'cellCluster_gmm1' if I_overlap else 'cellCluster_gmm2'

    ### Refit the cell cluster
    refit_1 = GaussianMixture(n_components=1, covariance_type='full', random_state=RANDOM_SEED).fit(
        X_forRefit.loc[atac_data_per_cell_forRefit[clustResult_toUse] == CELL_CLUST]
    )
    refit_1_log_ll = refit_1.score_samples(X_forRefit)

    ### Refit the nonCell cluster
    refit_2 = GaussianMixture(n_components=1, covariance_type='full', random_state=RANDOM_SEED).fit(
        X_forRefit.loc[atac_data_per_cell_forRefit[clustResult_toUse] == NONCELL_CLUST]
    )
    refit_2_log_ll = refit_2.score_samples(X_forRefit)

    ### Calculate log odds ratio
    log_odds_ratio = refit_1_log_ll - refit_2_log_ll

    # The shift parameter is necessary as the second derivative method uses log transformation:
    # shifted values are all >= 1.
    shift_parameter = np.min(log_odds_ratio) - 1
    log_odds_ratio_shifted = np.sort(log_odds_ratio - shift_parameter)[::-1]

    putative_cell_num = run_second_derivative.main(
        log_odds_ratio_shifted,
        run_base_name,
        0,
        output_header=None,
        outdir=outdir,
        do_plot=True,
        reads_per_cell_minimum=0,
        expected_cell_count=expected_cell_count,
        no_json_csv_output=True,
    )
    if putative_cell_num == 0:
        log_odds_ratio_cutoff = 0
    else:
        # smallest (unshifted) log odds ratio among the top putative_cell_num cells
        log_odds_ratio_cutoff = np.min(log_odds_ratio_shifted[:putative_cell_num]) + shift_parameter

    plot_hist_log_odds_ratio(log_odds_ratio, log_odds_ratio_cutoff, run_base_name, outdir)

    I_include = log_odds_ratio >= log_odds_ratio_cutoff

    atac_data_per_cell['I_include'] = np.array([False] * atac_data_per_cell.shape[0])
    atac_data_per_cell.loc[cellIdx_forRefit, 'I_include'] = I_include

    atac_data_per_cell['cellCluster_refit'] = atac_data_per_cell[clustResult_toUse].copy()
    if I_overlap:  # remove false positives
        cellIdx_toCorrect = atac_data_per_cell.index[
            (atac_data_per_cell[clustResult_toUse] == CELL_CLUST) & (atac_data_per_cell['I_include'] == False)
        ]
        atac_data_per_cell.loc[cellIdx_toCorrect, 'cellCluster_refit'] = NONCELL_CLUST
    else:  # I_overlap is False: add false negatives
        cellIdx_toCorrect = atac_data_per_cell.index[
            (atac_data_per_cell[clustResult_toUse] == NONCELL_CLUST) & (atac_data_per_cell['I_include'] == True)
        ]
        atac_data_per_cell.loc[cellIdx_toCorrect, 'cellCluster_refit'] = CELL_CLUST

    plot_Refit_result(
        atac_data_per_cell[['log_counts', 'fraction_transposase_sites_in_peaks', 'cellCluster_refit']],
        refit_1,
        refit_2,
        run_base_name,
        outdir,
    )

    return atac_data_per_cell


def ellipsoid_intersection_test(gmm2, tau=CONFIDENCE_CONST):
    """
    Test whether the tau-scaled confidence ellipsoids of the first two GMM components intersect.

    Uses K(s) from K_function, built from the generalized eigen-decomposition of the two covariance
    matrices. The ellipsoids are reported as intersecting when the minimum of K over s in (0, 1),
    found with Brent's method on the bracket [0, 0.5, 1], is >= 0.

    Args:
        gmm2: fitted GaussianMixture with at least two 'full'-covariance components.
        tau: scale of each ellipsoid in standard deviations.

    Returns:
        True if the ellipsoids intersect, else False.
    """
    sigma_1 = gmm2.covariances_[0]
    sigma_2 = gmm2.covariances_[1]
    mu_1 = gmm2.means_[0]
    mu_2 = gmm2.means_[1]

    lambdas, Phi = eigh(sigma_1, b=sigma_2)
    v_squared = np.dot(Phi.T, mu_1 - mu_2) ** 2

    # K(0) == K(1) == 1, and with positive-definite covariances every term subtracted from 1 is
    # non-negative on [0, 1]. K(0.5) >= 1 therefore means the component means coincide
    # (numerically), so the ellipsoids overlap. [0, 0.5, 1] is then also not a valid Brent
    # bracket (f(0.5) < f(0) fails), so minimize_scalar would raise.
    if K_function(0.5, lambdas, v_squared, tau) >= 1.0:
        return True

    res = minimize_scalar(K_function, bracket=[0.0, 0.5, 1.0], args=(lambdas, v_squared, tau))

    return bool(res.fun >= 0)


def K_function(s, lambdas, v_squared, tau):
    """
    Evaluate K(s) for the ellipsoid overlap test.

        K(s) = 1 - (1 / tau**2) * sum_i v_squared[i] * s * (1 - s) / (1 + s * (lambdas[i] - 1))

    Args:
        s: scalar in [0, 1].
        lambdas: generalized eigenvalues of the two covariance matrices.
        v_squared: squared projections of the mean difference onto the generalized eigenvectors.
        tau: ellipsoid scale in standard deviations.
    """
    inverse_tau_sq = 1.0 / tau**2
    per_axis_weight = (s * (1.0 - s)) / (1.0 + s * (lambdas - 1.0))
    return 1.0 - inverse_tau_sq * np.sum(v_squared * per_axis_weight)


def plot_GMM1_result(gmm1, log_counts, run_base_name, outdir):
    """
    Save a histogram of the GMM1 input log counts with the fitted mixture density overlaid.

    Bins are 0.1 wide on a 0.1-aligned grid that covers the whole data range. The mixture density is
    evaluated at the bin midpoints and rescaled so its peak equals the tallest bar. The figure is
    written to '{run_base_name}_GMM1_fit.png' in ``outdir``.

    Args:
        gmm1: fitted 1-D GaussianMixture ('full' covariance).
        log_counts: pandas Series of the log10 counts the model was fitted on.
        run_base_name: plot title / file-name prefix.
        outdir: output directory.
    """
    # 0.1-aligned bin edges spanning [min, max] with at least one bin. Rounding the endpoints
    # instead would drop values at either end, and a narrow range could yield no bins at all.
    bin_start = np.floor(log_counts.min() * 10) / 10
    bin_stop = np.ceil(log_counts.max() * 10) / 10
    num_bins = max(1, int(round((bin_stop - bin_start) * 10)))
    xaxis_shared = np.round(bin_start + 0.1 * np.arange(num_bins + 1), 1)

    ### plot the fit of the GMM1, evaluated at the bin midpoints
    xaxis_for_gmm1 = np.round((xaxis_shared[:-1] + xaxis_shared[1:]) / 2, 2)
    yaxis_fitted = np.zeros_like(xaxis_for_gmm1)
    for weight, mean, covariance in zip(gmm1.weights_, gmm1.means_, gmm1.covariances_):
        yaxis_fitted += weight * norm.pdf(xaxis_for_gmm1, mean[0], np.sqrt(covariance[0][0]))

    ### scale the density so its peak matches the tallest histogram bar
    max_hist_counts = np.max(np.histogram(log_counts.values, bins=xaxis_shared)[0])
    scaled_pdf = yaxis_fitted * (max_hist_counts / np.max(yaxis_fitted))

    sns.set_style('white')
    sns.set_context('paper')
    fig = plt.figure(figsize=(8, 6))

    sns.histplot(data=log_counts, bins=xaxis_shared)
    plt.plot(xaxis_for_gmm1, scaled_pdf, linewidth=3, color='red')
    plt.tick_params(labelsize=10)
    plt.title(run_base_name)
    plt.xlabel('Log Counts (Tn5 cutsites in peaks)')
    plt.tight_layout()
    plt.savefig(os.path.join(outdir, f'{run_base_name}_GMM1_fit.png'), dpi=300)
    plt.close(fig)


def _scatter_with_gmm_ellipses(atac_data_per_cell_3col, hue_column, gmms, title, run_base_name, outdir, file_suffix):
    """
    Save a log-count vs FRiP scatter plot coloured by a cell/non-cell column, with GMM overlays.

    For every model in ``gmms`` the confidence ellipses (make_ellipses) and the component means
    (red dots) are drawn. The figure is written to '{run_base_name}_{file_suffix}.png' in ``outdir``.
    """
    sns.set_style('white')
    sns.set_context('paper')
    fig = plt.figure(figsize=(8, 6))

    color_dict = {CELL_CLUST: 'darkorange', NONCELL_CLUST: 'steelblue'}
    ax = sns.scatterplot(
        data=atac_data_per_cell_3col,
        x='log_counts',
        y='fraction_transposase_sites_in_peaks',
        palette=color_dict,
        hue=hue_column,
        s=10,
        alpha=0.3,
    )
    for gmm in gmms:
        make_ellipses(gmm, ax)
    for gmm in gmms:
        for mean in gmm.means_:
            plt.plot(mean[0], mean[1], marker='o', markersize=10, markeredgecolor='black', markerfacecolor='red')
    plt.title(title)
    plt.tick_params(labelsize=10)
    sns.move_legend(ax, 'upper left', bbox_to_anchor=(1.01, 0.5), ncol=1, frameon=False)
    # set_aspect expects a scalar; np.diff() of the limits would give 1-element arrays
    x_min, x_max = ax.get_xlim()
    y_min, y_max = ax.get_ylim()
    ax.set_aspect((x_max - x_min) / (y_max - y_min))
    plt.xlabel('Log Counts (Tn5 cutsites in peaks)')
    plt.tight_layout()
    plt.savefig(os.path.join(outdir, f'{run_base_name}_{file_suffix}.png'), dpi=300)
    plt.close(fig)


def plot_GMM2_fit(gmm2, atac_data_per_cell_3col, run_base_name, outdir):
    """Plot the GMM2 fit over cells coloured by 'cellCluster_gmm1'; saves '{run_base_name}_GMM2_fit.png'."""
    _scatter_with_gmm_ellipses(
        atac_data_per_cell_3col,
        hue_column='cellCluster_gmm1',
        gmms=[gmm2],
        title=run_base_name,
        run_base_name=run_base_name,
        outdir=outdir,
        file_suffix='GMM2_fit',
    )


def plot_GMM2_result(gmm2, atac_data_per_cell_3col, run_base_name, outdir):
    """Plot the GMM2 call ('cellCluster_gmm2') with the GMM2 ellipses; saves '{run_base_name}_GMM2_result.png'."""
    _scatter_with_gmm_ellipses(
        atac_data_per_cell_3col,
        hue_column='cellCluster_gmm2',
        gmms=[gmm2],
        title=run_base_name,
        run_base_name=run_base_name,
        outdir=outdir,
        file_suffix='GMM2_result',
    )


def plot_Refit_result(atac_data_per_cell_3col, refit_1, refit_2, run_base_name, outdir):
    """
    Plot the refit call ('cellCluster_refit') with the cell (refit_1) and non-cell (refit_2) Gaussians.

    The title includes the number of putative cells; saves '{run_base_name}_Refit.png'.
    """
    numPutativeCell = int((atac_data_per_cell_3col['cellCluster_refit'] == CELL_CLUST).sum())
    _scatter_with_gmm_ellipses(
        atac_data_per_cell_3col,
        hue_column='cellCluster_refit',
        gmms=[refit_1, refit_2],
        title=f'{run_base_name} ({numPutativeCell} putative cells)',
        run_base_name=run_base_name,
        outdir=outdir,
        file_suffix='Refit',
    )


def plot_hist_log_odds_ratio(log_odds_ratio, log_odds_ratio_cutoff, run_base_name, outdir):
    """
    Save a histogram of per-cell log odds ratios with the chosen cutoff drawn as a red vertical line.

    The figure is written to '{run_base_name}_hist_log_odds_ratio.png' in ``outdir``.
    """
    out_png = os.path.join(outdir, f'{run_base_name}_hist_log_odds_ratio.png')

    sns.set_style('white')
    sns.set_context('paper')
    figure = plt.figure(figsize=(8, 6))

    axes = sns.histplot(data=log_odds_ratio)
    axes.axvline(x=log_odds_ratio_cutoff, color='red')
    axes.tick_params(labelsize=10)
    axes.set_title(run_base_name)
    axes.set_xlabel('log odds ratio')

    figure.tight_layout()
    figure.savefig(out_png, dpi=300)
    figure.clf()
    plt.close(figure)


def make_ellipses(gmm2, ax, nstd=CONFIDENCE_CONST):
    """Draw confidence ellipses for (at most) the first two components of a fitted mixture on ``ax``.

    Each ellipse is centred at the component mean (first two dimensions). Along each principal axis
    of the component covariance it spans ``nstd`` standard deviations on either side of the mean.

    Args:
        gmm2: fitted mixture model exposing ``n_components``, ``means_`` and ``covariances_``.
            ``covariances_[n][:2, :2]`` assumes one 2-D covariance matrix per component
            (e.g. ``covariance_type='full'``) -- NOTE(review): confirm against how callers fit it.
        ax: matplotlib Axes to draw on; its aspect is set to ``'equal'`` with ``'datalim'`` adjustment.
        nstd: half-axis length in standard deviations (defaults to ``CONFIDENCE_CONST``).
    """
    # One colour per drawn component; components beyond the second are not drawn.
    colors = ['navy'] if gmm2.n_components == 1 else ['navy', 'turquoise']

    for n, color in enumerate(colors):
        covariance = gmm2.covariances_[n][:2, :2]

        eigvals, eigvecs = np.linalg.eigh(covariance)
        # eigh returns eigenvectors as COLUMNS (ascending eigenvalues). Column 0 is the direction of
        # the ellipse width. Taking the row w[0] is only right when eigh returns a symmetric matrix.
        u = eigvecs[:, 0] / np.linalg.norm(eigvecs[:, 0])
        angle = np.degrees(np.arctan2(u[1], u[0]))
        # Full axis lengths = 2 * nstd * sigma; clip tiny negative eigenvalues caused by round-off.
        width, height = 2.0 * nstd * np.sqrt(np.clip(eigvals, 0.0, None))

        ell = mpl.patches.Ellipse(gmm2.means_[n, :2], width, height, angle=angle, color=color)
        ell.set_clip_box(ax.bbox)
        ell.set_alpha(0.5)
        ax.add_artist(ell)

    ax.set_aspect('equal', 'datalim')


def get_number_putative_cell(cellClust_eachStep):
    """Return how many entries of ``cellClust_eachStep`` carry the cell-cluster label ``CELL_CLUST``."""
    is_cell = cellClust_eachStep == CELL_CLUST
    return len(cellClust_eachStep[is_cell])


def get_gmm_stats(
    atac_data_per_cell, FRiP_threshold_used, logCount_cutoff, numNoisyCell, num_cells_inferred_by_basic_algo, I_overlap
):
    """
    Build the '#GMM_Algorithm_Stats#' row, in this order:
        FRiP_Threshold,
        Tn5_Cutsites_in_Peaks_Threshold   (int(10 ** logCount_cutoff)),
        Num_Cells_Low_FRiP_or_Low_Tn5_Cutsites_in_Peaks,
        Num_Cells_Inferred_by_Basic_Algo,
        Num_Cells_Flagged_Noise_by_Refit,
        Num_Additional_Cells_Detected_by_Refit,
        Did_2nd_GMM_Used                  ('TRUE' / 'FALSE')

    ``I_overlap`` decides how the two refit counts are obtained:
        None   -> both refit counts are 0 and the 2nd GMM is reported as unused.
        truthy -> count rows with 'cellCluster_gmm1' == CELL_CLUST whose 'I_include' is False.
        falsy  -> the 2nd GMM is reported as used; count rows with
                  'cellCluster_gmm2' == NONCELL_CLUST whose 'I_include' is True.
    """
    flagged_noise = 0
    additional_cells = 0
    second_gmm_flag = 'FALSE'

    if I_overlap is not None:
        if I_overlap:
            noisy_mask = (atac_data_per_cell['cellCluster_gmm1'] == CELL_CLUST) & (
                atac_data_per_cell['I_include'] == False  # noqa: E712 -- element-wise comparison
            )
            flagged_noise = len(atac_data_per_cell.index[noisy_mask])
        else:
            second_gmm_flag = 'TRUE'
            rescued_mask = (atac_data_per_cell['cellCluster_gmm2'] == NONCELL_CLUST) & (
                atac_data_per_cell['I_include'] == True  # noqa: E712 -- element-wise comparison
            )
            additional_cells = len(atac_data_per_cell.index[rescued_mask])

    return [
        FRiP_threshold_used,
        int(10**logCount_cutoff),
        numNoisyCell,
        num_cells_inferred_by_basic_algo,
        flagged_noise,
        additional_cells,
        second_gmm_flag,
    ]


def get_cell_num_stats(atac_data_per_cell):
    """
    Calculate  num_putative_cells,
               num_noise_cells,
               ratio_noise:putative_cells

    Counts come from the 'cellType' column ('Cells' vs 'Non cells'). A label that does not occur
    counts as 0, so a run where every barcode is called a cell no longer raises KeyError. The ratio
    is reported as 0 when there are no putative cells.
    """
    cell_num_counts = atac_data_per_cell['cellType'].value_counts()
    num_cells = cell_num_counts.get('Cells', 0)
    num_non_cells = cell_num_counts.get('Non cells', 0)

    # Also covers a categorical 'cellType' where 'Cells' is a category with a zero count.
    if num_cells == 0:
        return [0, num_non_cells, 0]

    return [num_cells, num_non_cells, num_non_cells / num_cells]


def get_Tn5_cutsites_in_peaks_stats(atac_data_per_cell):
    """
    Summarise 'transposase_sites_in_peaks' for putative cells ('Cells') vs noise cells ('Non cells').

    Returns a tuple of two 3-element lists:
        1. mean stats  : [mean per putative cell, mean per noise cell, putative_mean / noise_mean]
        2. median stats: [median per putative cell, median per noise cell,
                          log10(putative_median / noise_median)]
    When there are no putative cells, the putative value and the ratio are both 0 in each list.
    """
    column = 'transposase_sites_in_peaks'
    cell_type = atac_data_per_cell['cellType']
    putative = atac_data_per_cell.loc[cell_type == 'Cells', column]
    noise = atac_data_per_cell.loc[cell_type == 'Non cells', column]

    noise_mean = np.mean(noise)
    noise_median = np.median(noise)

    if len(putative) == 0:
        return [0, noise_mean, 0], [0, noise_median, 0]

    putative_mean = np.mean(putative)
    putative_median = np.median(putative)

    mean_stats = [putative_mean, noise_mean, putative_mean / noise_mean]
    median_stats = [putative_median, noise_median, np.log10(putative_median / noise_median)]
    return mean_stats, median_stats


def get_FRiP_stats(atac_data_per_cell):
    """
    Summarise 'fraction_transposase_sites_in_peaks' (FRiP) for putative vs noise cells.

    Returns a tuple of two 3-element lists:
        1. mean stats  : [mean FRiP per putative cell, mean FRiP per noise cell, putative_mean / noise_mean]
        2. median stats: [median FRiP per putative cell, median FRiP per noise cell,
                          putative_median / noise_median]
    When no row is labelled 'Cells', the putative value and the ratio are both 0 in each list.
    """
    column = 'fraction_transposase_sites_in_peaks'
    cell_type = atac_data_per_cell['cellType']
    putative = atac_data_per_cell.loc[cell_type == 'Cells', column]
    noise = atac_data_per_cell.loc[cell_type == 'Non cells', column]

    noise_mean = np.mean(noise)
    noise_median = np.median(noise)

    if len(putative) == 0:
        return [0, noise_mean, 0], [0, noise_median, 0]

    putative_mean = np.mean(putative)
    putative_median = np.median(putative)

    mean_stats = [putative_mean, noise_mean, putative_mean / noise_mean]
    median_stats = [putative_median, noise_median, putative_median / noise_median]
    return mean_stats, median_stats


def write_cell_label_stats(stats_file, output_header, algo_stats):
    """Output the stats calculated from noise algorithm to the file [run_base_name]_CellLabelAlgorithmStats.csv

    The file begins with the optional ``output_header`` lines. Six sections follow, each made of a
    ``#Title#`` row, a row of column names and a row of values taken from ``algo_stats``.

    Args:
        stats_file: path of the CSV file to (over)write.
        output_header: iterable of header lines written verbatim before the stats, a single string
            (written as one line), or a falsy value for no header.
        algo_stats: sequence of 22 values, written as indices [0:7], [7:10], [10:13], [13:16],
            [16:19] and [19:22] of the six sections below. ``None`` writes all-zero defaults;
            any other value is first passed through ``utils.clean_up_decimals``.
    """
    if algo_stats is None:
        algo_stats = [
            0.0,
            0.0,
            0,
            0,
            0,
            0,
            'FALSE',
            0,
            0,
            0.0,
            0.0,
            0.0,
            0.0,
            0,
            0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
        ]
    else:
        algo_stats = utils.clean_up_decimals(algo_stats)

    # (section title, column names, indices into algo_stats)
    sections = (
        (
            '#GMM_Algorithm_Stats#',
            [
                'FRiP_Threshold',
                'Tn5_Cutsites_in_Peaks_Threshold',
                'Num_Cells_Low_FRiP_or_Low_Tn5_Cutsites_in_Peaks',
                'Num_Cells_Inferred_by_Basic_Algo',
                'Num_Cells_Flagged_Noise_by_Refit',
                'Num_Additional_Cells_Detected_by_Refit',
                'Did_2nd_GMM_Used',
            ],
            range(0, 7),
        ),
        ('#Counts#', ['Putative_Cells', 'Noise_Cells', 'Ratio_Noise:Putative_Cells'], range(7, 10)),
        (
            '#Averages_of_Tn5_Cutsites_in_Peaks#',
            [
                'Tn5_Cutsites_in_Peaks_per_Putative_Cell',
                'Tn5_Cutsites_in_Peaks_per_Noise_Cell',
                'Ratio_Mean_Tn5_Cutsites_in_Peaks_per_Putative:Noise_Cell',
            ],
            range(10, 13),
        ),
        (
            '#Medians_of_Tn5_Cutsites_in_Peaks#',
            [
                'Tn5_Cutsites_in_Peaks_per_Putative_Cell',
                'Tn5_Cutsites_in_Peaks_per_Noise_Cell',
                'Log10_Ratio_Median_Tn5_Cutsites_in_Peaks_per_Putative:Noise_Cell',
            ],
            range(13, 16),
        ),
        (
            '#Averages_FRiP#',
            ['FRiP_per_Putative_Cell', 'FRiP_per_Noise_Cell', 'Ratio_Mean_FRiP_per_Putative:Noise_Cell'],
            range(16, 19),
        ),
        (
            '#Medians_FRiP#',
            ['FRiP_per_Putative_Cell', 'FRiP_per_Noise_Cell', 'Ratio_Median_FRiP_per_Putative:Noise_Cell'],
            range(19, 22),
        ),
    )

    if isinstance(output_header, str):
        # Iterating a str would emit one character per line; treat it as a single header line.
        output_header = [output_header]

    # newline='' as required by the csv module: the writer emits its own '\r\n' terminators,
    # which text-mode newline translation would otherwise turn into '\r\r\n' on Windows.
    with open(stats_file, 'w', newline='') as fout:
        for line in output_header or ():
            fout.write(f'{line}\n')
        writer = csv.writer(fout)
        for title, columns, indices in sections:
            writer.writerow([title])
            writer.writerow(columns)
            # Explicit indexing (not slicing) so a short algo_stats still fails loudly.
            writer.writerow([algo_stats[i] for i in indices])


def main():
    """Command-line entry point for ATAC-based cell calling.

    Reads the table named by ``--atac-cell-calling-data-file``. The first column is used as the
    index; the remaining columns are transposase sites in peaks, their fraction, and fragments.
    The table is sorted by transposase sites in peaks (descending), ``./Cell_Label_Filtering`` is
    created if needed, and everything is passed to ``call_cells``. If no data file is given, a
    warning is logged and nothing else happens.
    """
    args = cli()

    if not args['atac_cell_calling_data_file']:
        logging.warning('Selected putative cell calling using ATAC, but no ATAC_data_per_cell file provided')
        return

    atac_data_per_cell = pd.read_csv(
        args['atac_cell_calling_data_file'],
        delimiter=',',
        index_col=0,
        names=['transposase_sites_in_peaks', 'fraction_transposase_sites_in_peaks', 'fragments'],
        dtype={
            'transposase_sites_in_peaks': np.int32,
            'fraction_transposase_sites_in_peaks': np.float32,
            'fragments': np.int32,
        },
    ).sort_values(by='transposase_sites_in_peaks', ascending=False)

    cell_label_filtering_dir = path.join(os.getcwd(), 'Cell_Label_Filtering')
    # exist_ok avoids the check-then-create race of `path.exists` followed by `os.mkdir`.
    os.makedirs(cell_label_filtering_dir, exist_ok=True)

    call_cells(
        atac_data_per_cell=atac_data_per_cell,
        use_dynamic_FRiP_threshold=args['use_dynamic_FRiP_threshold'],
        expected_cell_count=args['expected_cell_count'],
        run_base_name=args['run_base_name'],
        output_header=args['output_header'],
        outdir=cell_label_filtering_dir,
    )


if __name__ == '__main__':
    main()
