#!/usr/bin/env Rscript

library(Signac)
library(Seurat)
library(rtracklayer)
library(jsonlite)

source(file.path(.Platform$file.sep, Sys.getenv(c("RHAPSODY_HOME")), "rscript", "Mist_R_Logger.R"))

#' Build the initial ATAC Seurat object and per-cell metrics for cell calling.
#'
#' Imports gene annotations from the GTF and peaks from the peak file, counts
#' transposase cut sites per cell and peak, and assembles a Seurat object that
#' carries per-cell fragment counts, FRiP values, nucleosome signal and TSS
#' enrichment.
#'
#' Files written to the current working directory:
#'   - atac.gene_info.json (only when opt$map_gene_name_id is TRUE)
#'   - <base_name>_Total_Fragment_Metrics.json
#'   - <base_name>_ATAC_Cell_Calling_Data.csv
#'   - <base_name>_Initial_Seurat.rds
#'
#' @param opt List of options with the elements base_name, fragments,
#'            transposase_sites, peaks, gtf and map_gene_name_id.
#' @return Called for its side effects (log output and the files above).
#' @details Stops with an error when a required input is missing, when the
#'          GTF has neither a 'gene_type' nor a 'gene_biotype' attribute, or
#'          when the peak file holds one peak or fewer.
atac_cell_by_peak <- function(opt) {

    # NOTE(review): every warning raised inside this block is suppressed.
    suppressWarnings({

        get_fragment_metrics <- TRUE

        # For local testing
        #opt$base_name <- "Local_Testing"
        #opt$fragments <- "/data/ATAC_Generate_Fragments/run_dir/AM053_3_S3_chr21.fragments.bed.gz"
        #opt$transposase_sites <- "/data/ATAC_Generate_Fragments/run_dir/AM053_3_S3_chr21.cutsites.bed.gz"
        #opt$peaks <- "/data/ATAC_Call_Peaks/run_dir/AM053_3_S3_chr21.peaks.bed.gz"
        #opt$gtf <- "/data/ATAC_Cell_by_Peak/gencodev29-20181205.gtf"

        required_inputs <- list(
            "run base name" = opt$base_name,
            "ATAC fragments file" = opt$fragments,
            "ATAC transposase cut sites file" = opt$transposase_sites,
            "ATAC peaks file" = opt$peaks,
            "GTF file" = opt$gtf
        )

        for (input_name in names(required_inputs)) {
            input_value <- required_inputs[[input_name]]
            # An option can be absent (NULL) as well as NA; is.na(NULL) has
            # length zero and would break a bare `if`, so check NULL first.
            if (is.null(input_value) || length(input_value) != 1 || is.na(input_value)) {
                log_error("The {input_name} is a required input")
                stop("The ", input_name, " is a required input", call. = FALSE)
            }
        }

        # extract gene annotations from gtf
        log_info("Loading gene annotations from gtf...")
        gene_annotations <- rtracklayer::import(opt$gtf)
        # This becomes unnecessary if we add gene_biotype to the reference gtf file, but is necessary for compatibility with older references.
        if ("gene_type" %in% names(mcols(gene_annotations))) {
            gene_annotations$gene_biotype <- gene_annotations$gene_type
        }
        if (! "gene_biotype" %in% names(mcols(gene_annotations))) {
            log_error("Please ensure you have either the 'gene_type' or 'gene_biotype' attribute in your gtf.")
            stop("The gtf has neither a 'gene_type' nor a 'gene_biotype' attribute.", call. = FALSE)
        }
        if ("transcript_id" %in% names(mcols(gene_annotations))) {
            gene_annotations$tx_id <- gene_annotations$transcript_id
        }
        # isTRUE() keeps a missing flag from raising a zero-length condition error
        if (isTRUE(opt$map_gene_name_id)) {
            log_info("Mapping gene names and ids in GTF...")
            gene_name_map <- create_gene_name_map(gene_annotations)
            log_info("Saving ATAC gene_name_map JSON...")
            write_json(gene_name_map, "atac.gene_info.json", auto_unbox = TRUE)
        }
        log_info("Gene annotations loaded... {head(gene_annotations)}")

        # Get the peaks granges
        log_info("Loading peaks file: {opt$peaks}")
        extraColTypes <- c(
            fold_change = "numeric",
            neglog10pvalue = "numeric",
            neglog10qvalue = "numeric",
            summit = "integer"
        )

        # Try to read the extra peak columns first; fall back to plain
        # coordinates when the import with extra columns fails.
        peaks_granges <- tryCatch(
            {
                rtracklayer::import(opt$peaks, extraCols = extraColTypes)
            },
            error = function(e) {
                log_warn("  Extra information of the peaks (fold change, p value, q value, summit of the peak) is not available. Reading only peak coordinates.")
                rtracklayer::import(opt$peaks)
            }
        )

        # If there is one peak, `CreateChromatinAssay` will fail
        if (length(peaks_granges) <= 1) {
            log_error("Too few peaks ({length(peaks_granges)} peak/peaks) were detected in the peak file.  Skipping all ATAC downstream steps...")
            stop("Too few peaks were detected in the peak file.", call. = FALSE)
        }

        log_info("Peaks: {head(peaks_granges)}")

        # Remove chromosomes from annotations GRanges that are not in peaks GRanges
        seqlevels(gene_annotations, pruning.mode = "coarse") <- seqlevels(peaks_granges)

        # Use Signac functions to read in and process information about transposase sites
        log_info("Loading transposase file: {opt$transposase_sites}")
        transposase_object <- CreateFragmentObject(opt$transposase_sites)
        #log_debug("Transposase Cut-Sites: {head(transposase_object)}")

        # Generate a cell-by-peak matrix from the transposase sites
        # FeatureMatrix() excludes cell barcodes that do not have any peak transposase sites
        log_info("Creating Feature Matrix from transposase sites and peaks...")
        transposase_peaks_counts <- FeatureMatrix(
            fragments = transposase_object,
            features = peaks_granges,
            sep = c("%-%", "%-%"),
            verbose = FALSE
        )
        # sorting the matrix using cell labels (labels are treated as numbers)
        sorted_cells <- sprintf("%d", sort(as.numeric(colnames(transposase_peaks_counts))))
        transposase_peaks_counts <- transposase_peaks_counts[, sorted_cells]
        log_info("Done creating transposase by peaks matrix...")

        # The cell-by-peaks matrix and the transposase site fragments are combined to create a Seurat object
        # 'CreateChromatinAssay' fails with only one peak
        log_info("Creating chromatin assay from transposase by peaks matrix...")
        transposase_chromatin_assay <- CreateChromatinAssay(
            transposase_peaks_counts,
            ranges = peaks_granges,
            annotation = gene_annotations,
            fragments = transposase_object,
            sep = c("%-%", "%-%")
        )
        log_info("Done creating Chromatin Assay...")

        log_info("Creating Seurat Object...")
        seurat_object <- CreateSeuratObject(
            transposase_chromatin_assay,
            assay = "peaks",
            project = opt$base_name
        )
        #log_debug("Transposase Seurat Object: {head(seurat_object)}")

        # Update the seurat object with fragment counts from fragments file
        log_info("Creating Count Fragments Object from file: {opt$fragments}")
        total_fragments <- CountFragments(opt$fragments)
        # Row names are the cell barcodes so rows can be looked up per cell
        rownames(total_fragments) <- total_fragments$CB
        #log_debug("Total Fragments: {head(total_fragments)}")

        log_info("Update Seurat Object with fragments...")
        object_cells <- colnames(seurat_object)
        seurat_object$fragments <- total_fragments[object_cells, "frequency_count"]
        # Two transposase cut sites per fragment
        seurat_object$transposase_sites <- 2 * seurat_object$fragments
        seurat_object$read_pairs <- total_fragments[object_cells, "reads_count"]
        seurat_object$NFR_fragments <- total_fragments[object_cells, "nucleosome_free"]
        seurat_object$mononucleosomal_fragments <- total_fragments[object_cells, "mononucleosomal"]
        seurat_object$ratio_fragments_to_readpairs <- seurat_object$fragments / seurat_object$read_pairs

        # Store total values that apply to the whole experiment
        fragment_metrics <- list(
            "Total_Fragments" = sum(total_fragments$frequency_count),
            "Total_Transposase_Sites" = 2 * sum(total_fragments$frequency_count),
            "Total_Nucleosome_Free_Fragments" = sum(total_fragments$nucleosome_free),
            "Total_Mononucleosomal_Fragments" = sum(total_fragments$mononucleosomal)
        )
        log_info("Fragment metrics... {fragment_metrics}")
        log_info("Fragment metrics names... {names(fragment_metrics)}")

        output_file <- paste0(opt$base_name, "_Total_Fragment_Metrics.json")
        log_info("Saving ATAC Fragment Metrics JSON: {output_file}")
        write_json(fragment_metrics, output_file, pretty = TRUE, auto_unbox = TRUE)

        log_info("Get FRiP score for transposase sites in peaks...")
        # Get fraction of transposase sites in peaks
        seurat_object <- FRiP(
            object = seurat_object,
            assay = 'peaks',
            total.fragments = 'transposase_sites',
            col.name = 'fraction_transposase_sites_in_peaks'
        )
        log_info("Updated Seurat Object with fraction transposase sites in peaks: {head(seurat_object)}")

        # Copy the metadata so it has a more readable name
        seurat_object$transposase_sites_in_peaks <- seurat_object$nCount_peaks

        # Add the gene information to the object
        Annotation(seurat_object) <- gene_annotations

        log_info("Loading fragments file: {opt$fragments}")
        fragments_object <- CreateFragmentObject(
            opt$fragments,
            cells = Cells(seurat_object),
            max.lines = NULL
        )
        log_info("Fragments: {head(fragments_object)}")

        # Change the Fragments object from transposase sites to original fragments
        Fragments(seurat_object) <- NULL
        Fragments(seurat_object) <- fragments_object
        log_info("Updated Seurat Object with fragments file: {head(seurat_object)}")

        # Compute nucleosome signal score on fragments per cell
        seurat_object <- NucleosomeSignal(seurat_object, n = NULL)

        # Compute TSS enrichment score on fragments per cell
        seurat_object <- TSSEnrichment(seurat_object, fast = FALSE)

        if (get_fragment_metrics) {
            log_info("Getting fragments by peaks matrix...")
            # Compute Fraction Fragments overlapping Peaks also
            fragment_peaks_counts <- FeatureMatrix(
                fragments = fragments_object,
                features = peaks_granges,
                cells = Cells(seurat_object),
                sep = c("%-%", "%-%"),
                verbose = FALSE
            )
            log_info("Done creating fragments by peaks matrix...")

            log_info("Creating chromatin assay from fragments by peaks matrix...")
            # The argument is `annotation` (singular), as in the call above;
            # a misspelt name would fall into `...` and never reach the assay.
            fragment_chromatin_assay <- CreateChromatinAssay(
                fragment_peaks_counts,
                annotation = gene_annotations,
                fragments = fragments_object,
                sep = c("%-%", "%-%")
            )
            log_info("Done creating Chromatin Assay...")

            log_info("Creating Fragments Seurat Object...")
            fragment_seurat_object <- CreateSeuratObject(
                fragment_chromatin_assay,
                assay = "peaks",
                project = opt$base_name
            )
            log_info("Initial Fragments Seurat Object: {head(fragment_seurat_object)}")

            log_info("Get Fragments FRiP score...")
            # Look the totals up with this object's own cells so each value
            # lines up with the column it is assigned to.
            fragment_seurat_object$fragments <- total_fragments[colnames(fragment_seurat_object), "frequency_count"]
            fragment_seurat_object <- FRiP(
                object = fragment_seurat_object,
                assay = 'peaks',
                total.fragments = 'fragments',
                col.name = 'fraction_fragments_overlapping_peaks'
            )
            #log_info("Updated Fragments Seurat Object with FRiP: {head(seurat_object)}")

            # Add the fragment metadata to the original seurat object
            seurat_object$fragments_overlapping_peaks <- fragment_seurat_object$nCount_peaks
            seurat_object$fraction_fragments_overlapping_peaks <- fragment_seurat_object$fraction_fragments_overlapping_peaks
        }

        cell_calling_file <- paste0(seurat_object@project.name, "_ATAC_Cell_Calling_Data.csv")
        atac_data_columns <- c("transposase_sites_in_peaks", "fraction_transposase_sites_in_peaks", "fragments")
        log_info("Outputting the {atac_data_columns[1]}, {atac_data_columns[2]}, and {atac_data_columns[3]} for ATAC cell calling: {cell_calling_file}")
        # Written without a header row; the first field of each row is the row name
        write.table(
            seurat_object[[atac_data_columns]],
            file = cell_calling_file,
            quote = FALSE,
            sep = ",",
            row.names = TRUE,
            col.names = FALSE
        )

        output_file <- paste0(seurat_object@project.name, "_Initial_Seurat.rds")
        log_info("Writing ATAC Initial Seurat object to disk: {output_file}")
        saveRDS(seurat_object, file = output_file)
    })
}

#' Map each gene name to one gene_id / gene_name pair.
#'
#' Records where both 'gene_name' and 'gene_id' are NA are dropped. A missing
#' name is replaced by "NO_NAME_<gene_id>" and a missing id by
#' "NO_ID_<gene_name>". When a gene name is linked to several gene IDs the
#' first one is kept. Names with several IDs, and IDs with several names, are
#' reported through the logger.
#'
#' @param gene_annotations A GRanges object with gene annotations; the
#'                         'gene_name' and 'gene_id' metadata columns are read.
#'                         A column that is absent is logged as an error and
#'                         treated as all NA.
#' @return A named list keyed by gene name; every element is a list with the
#'         entries 'gene_id' and 'gene_name'.
#' @examples
#' # gene_name_map <- create_gene_name_map(gene_annotations)
create_gene_name_map <- function(gene_annotations) {
    # Add NA placeholders for attributes the annotations do not carry
    annotation_columns <- names(mcols(gene_annotations))
    if (!("gene_name" %in% annotation_columns)) {
        log_error("The 'gene_name' attribute is missing from the gene annotations.")
        gene_annotations$gene_name <- NA
    }
    if (!("gene_id" %in% annotation_columns)) {
        log_error("The 'gene_id' attribute is missing from the gene annotations.")
        gene_annotations$gene_id <- NA
    }

    # Keep only records that carry a name, an id, or both
    has_identifier <- !is.na(gene_annotations$gene_name) | !is.na(gene_annotations$gene_id)
    records <- gene_annotations[has_identifier]

    # Derive a placeholder name from the id, then a placeholder id from the name
    unnamed <- is.na(records$gene_name)
    records$gene_name[unnamed] <- paste0("NO_NAME_", records$gene_id[unnamed])
    without_id <- is.na(records$gene_id)
    records$gene_id[without_id] <- paste0("NO_ID_", records$gene_name[without_id])

    # Distinct ids per name, and distinct names per id
    ids_by_name <- lapply(split(records$gene_id, records$gene_name), unique)
    names_by_id <- lapply(split(records$gene_name, records$gene_id), unique)

    # Report ambiguous associations in both directions
    for (name_key in names(ids_by_name)) {
        linked_ids <- ids_by_name[[name_key]]
        if (length(linked_ids) > 1) {
            log_info(paste0("Warning: Gene name '", name_key, "' is associated with multiple gene IDs: ", paste(linked_ids, collapse = ", ")))
        }
    }
    for (id_key in names(names_by_id)) {
        linked_names <- names_by_id[[id_key]]
        if (length(linked_names) > 1) {
            log_info(paste0("Warning: Gene ID '", id_key, "' is associated with multiple gene names: ", paste(linked_names, collapse = ", ")))
        }
    }

    # One entry per gene name, paired with the first id recorded for it
    map_names <- names(ids_by_name)
    feature_map <- lapply(seq_along(ids_by_name), function(position) {
        list(
            gene_id = ids_by_name[[position]][[1]],
            gene_name = map_names[[position]]
        )
    })
    names(feature_map) <- map_names

    return(feature_map)
}

suppressPackageStartupMessages(require(optparse))
suppressPackageStartupMessages(require(Signac))
suppressPackageStartupMessages(require(Seurat))

# Command-line interface: declare the options, parse them and run the step.
# Every file option defaults to NA so that atac_cell_by_peak() can report a
# missing required input by name.
option_list <- list(
    make_option(
        c("--fragments"),
        action = "store",
        default = NA,
        type = "character",
        help = "Fragments file"
    ),
    make_option(
        c("--transposase-sites"),
        action = "store",
        default = NA,
        type = "character",
        help = "Transposase sites file"
    ),
    make_option(
        c("--peaks"),
        action = "store",
        default = NA,
        type = "character",
        help = "Peaks file"
    ),
    # action/default/type/help must be arguments of make_option(), not members
    # of the flag vector; otherwise the NA default and help text are lost.
    make_option(
        c("--gtf"),
        action = "store",
        default = NA,
        type = "character",
        help = "Name of the gtf file to use for importing gene annotations"
    ),
    make_option(
        c("--map-gene-name-id"),
        action = "store_true",
        default = FALSE,
        help = "If this flag is provided, the script will generate the gene name/id json file."
    ),
    make_option(
        c("--base-name"),
        action = "store",
        default = NA,
        type = "character",
        help = "Base output file name"
    )
)

parser <- OptionParser(option_list = option_list)
# Hyphens become underscores, so --base-name is read as opt$base_name
opt <- parse_args(parser, convert_hyphens_to_underscores = TRUE)
#log_debug("Running with options: {jsonlite::toJSON(opt, auto_unbox = TRUE, pretty = TRUE)}")

atac_cell_by_peak(opt)
