#!/usr/bin/env Rscript

# Takes in a MEX format cell-bioproduct expression datatable, and any cell metadata we generate in the
# sequence analysis pipeline. Creates a Seurat ( https://satijalab.org/seurat/ ) RDS file, making it
# easy to load in Seurat in the future. Currently targeting Seurat 4

library(Matrix)
library(TFBSTools)
library(GenomicRanges)
library(IRanges)

source(file.path(.Platform$file.sep, Sys.getenv(c("RHAPSODY_HOME")), "rscript", "Mist_R_Logger.R"))

import_motif_positions <- function(motif_positions_file, motif_names) {
    # Import motif positions from a TSV file and convert them to a GRangesList.
    #
    # Args:
    #     motif_positions_file: Path to TSV/TSV.GZ file with motif positions.
    #                          Required columns: motif_name, chromosome, start, end, strand.
    #                          Every other column is attached to the ranges as a metadata column.
    #     motif_names: Vector of motif names in the order they appear in the motif object
    #
    # Returns:
    #     GRangesList with one element per entry of motif_names, in the same order.
    #     Motifs without any row in the file get an empty GRanges.
    #     Returns NULL if the file is empty, a required column is missing, or any
    #     error is raised while reading or converting (the error is logged).

    required_cols <- c("motif_name", "chromosome", "start", "end", "strand")

    tryCatch({
        log_info(paste("Importing motif positions from:", motif_positions_file))

        # read.csv() opens the path itself (gzip-compressed files are decompressed
        # transparently) and closes the connection when it is done, so no
        # connection is left open for either .tsv or .tsv.gz input.
        motif_positions_df <- read.csv(motif_positions_file, sep = "\t")

        if (nrow(motif_positions_df) == 0) {
            log_warn("Motif positions file is empty")
            return(NULL)
        }

        if (!all(required_cols %in% colnames(motif_positions_df))) {
            log_error(paste("Missing required columns in motif positions file. Expected:", paste(required_cols, collapse=", ")))
            return(NULL)
        }

        extra_cols <- setdiff(colnames(motif_positions_df), required_cols)

        # Group the row indices by motif in a single pass instead of re-scanning
        # the whole table for every motif. Rows with a missing motif name are
        # dropped by split() and therefore never assigned to any motif.
        rows_by_motif <- split(seq_len(nrow(motif_positions_df)), motif_positions_df$motif_name)

        build_granges <- function(motif_name) {
            row_idx <- rows_by_motif[[motif_name]]
            if (length(row_idx) == 0) {
                # Empty GRanges for motifs with no positions
                return(GRanges())
            }

            gr <- GRanges(
                seqnames = motif_positions_df$chromosome[row_idx],
                ranges = IRanges(
                    start = motif_positions_df$start[row_idx],
                    end = motif_positions_df$end[row_idx]
                ),
                strand = motif_positions_df$strand[row_idx]
            )

            # Carry over any additional columns as metadata
            for (col in extra_cols) {
                mcols(gr)[[col]] <- motif_positions_df[[col]][row_idx]
            }
            gr
        }

        motif_keys <- as.character(motif_names)
        granges_per_motif <- lapply(motif_keys, build_granges)
        names(granges_per_motif) <- motif_keys

        # Count on the plain list; vapply() keeps the result an integer vector
        # even when there are no motifs at all.
        total_positions <- sum(vapply(granges_per_motif, length, integer(1)))

        granges_list <- GRangesList(granges_per_motif)
        log_info(paste("Imported", total_positions, "motif position records for", length(granges_list), "motifs"))

        return(granges_list)

    }, error = function(e) {
        log_error(paste("Failed to import motif positions:", conditionMessage(e)))
        return(NULL)
    })
}

sanitize_feature_quotes <- function(gz_features_file, sanitized_file) {
    # Write a copy of a (gzip-compressed) feature TSV in which every single and
    # double quote is replaced by "_"; otherwise the TSV is not read correctly.
    # Matching is done byte-wise so no assumption is made about the encoding.
    feature_lines <- readLines(gz_features_file)
    writeLines(gsub("['\"]", "_", feature_lines, useBytes = TRUE), sanitized_file, useBytes = TRUE)
}

warn_on_underscore_features <- function(feature_names, assay_label) {
    # Log a warning (plus debug details) when any feature name contains "_".
    has_underscore <- grepl("_", feature_names)
    if (!any(has_underscore)) {
        return(invisible(NULL))
    }
    log_warn(paste0(assay_label, " feature names contain underscores ('_'), which may cause issues in Seurat."))
    log_debug(paste(
        paste0("DEBUG: ", assay_label, " features with underscores:"),
        sum(has_underscore), "out of", length(feature_names)
    ))
    log_debug(paste(
        paste0("DEBUG: ", assay_label, " features containing underscores (first 10):"),
        paste(head(feature_names[has_underscore], 10), collapse=", ")
    ))
    invisible(NULL)
}

last_file_with_suffix <- function(files, suffix) {
    # Return the last entry of `files` ending in `suffix`, or NULL if there is none.
    matches <- files[endsWith(files, suffix)]
    if (length(matches) == 0) {
        return(NULL)
    }
    matches[length(matches)]
}

read_cell_metadata_csv <- function(csv_file, ...) {
    # Read a comma-separated table with a header whose first column holds the row names.
    # Extra arguments are forwarded to read.table().
    read.table(csv_file, skip = 0, sep = ",", header = TRUE, row.names = 1, ...)
}

add_coordinates_reduction <- function(seurat_object, coordinates_file, reduction_name, file_tag, display_name) {
    # Attach a coordinates CSV to the Seurat object as a dimensional reduction.
    # The assay recorded on the reduction is derived from the file name suffix:
    # "_Joint_<tag>_coordinates.csv" -> "RNA+peaks", "_ATAC_<tag>_coordinates.csv" -> "peaks",
    # anything else -> "RNA".
    log_info(paste("Adding", display_name, "coordinates"))
    embeddings <- data.matrix(read_cell_metadata_csv(coordinates_file))

    if (endsWith(coordinates_file, paste0("_Joint_", file_tag, "_coordinates.csv"))) {
        used_assay <- "RNA+peaks"
    } else if (endsWith(coordinates_file, paste0("_ATAC_", file_tag, "_coordinates.csv"))) {
        used_assay <- "peaks"
    } else {
        used_assay <- "RNA"
    }

    seurat_object[[reduction_name]] <- CreateDimReducObject(
        embeddings = embeddings,
        key = paste0(file_tag, "_"),
        global = TRUE,
        assay = used_assay
    )
    seurat_object
}

generate_seurat <- function(opt) {
    # Build a Seurat object from the pipeline outputs named in `opt` and save it
    # as "<base_name>_Seurat.rds" in the working directory.
    #
    # Args:
    #     opt: List of parsed command-line options. Each entry is either NA or a
    #          character value. At least one of opt$data_tables (RNA/Abseq MEX
    #          tables) and opt$atac_seurat_rds (existing ATAC Seurat object) must
    #          be set; every other entry adds optional metadata, motif data or
    #          dimensional-reduction coordinates.
    #
    # Side effects:
    #     Unpacks MEX archives into the working directory and deletes the
    #     unpacked files again once they have been read.
    #
    # Raises:
    #     An error if neither input is given or no RSEC_MolsPerCell_MEX.zip data
    #     table is found.

    rna_included <- !is.na(opt$data_tables)
    atac_included <- !is.na(opt$atac_seurat_rds)
    abseq_included <- FALSE

    if (!rna_included && !atac_included) {
        log_error("--data-tables or --atac-seurat-rds is required")
        stop("--data-tables or --atac-seurat-rds is required", call. = FALSE)
    }

    log_info("Loading Seurat")
    # library() fails immediately if Seurat is missing; require() would only return FALSE.
    suppressPackageStartupMessages(library(Seurat))

    # Process RNA/Abseq data, if present, for Seurat Assay objects
    if (rna_included) {
        data_table_vec <- unlist(strsplit(opt$data_tables, ","))
        # fixed = TRUE: the dots in the file name are literal, not regex wildcards.
        rsec_candidates <- data_table_vec[grepl("RSEC_MolsPerCell_MEX.zip", data_table_vec, fixed = TRUE)]
        if (length(rsec_candidates) == 0) {
            log_error("Unable to find data-table ending in RSEC_MolsPerCell_MEX.zip")
            stop("Unable to find data-table ending in RSEC_MolsPerCell_MEX.zip", call. = FALSE)
        }
        # When several tables match, the last one listed is used.
        rsec_mols_per_cell_file <- rsec_candidates[length(rsec_candidates)]

        log_info("Unzipping: {rsec_mols_per_cell_file}")
        unzip(rsec_mols_per_cell_file, overwrite = TRUE)
        sanitize_feature_quotes("features.tsv.gz", "features_filtered.tsv")

        log_info("Creating expression matrix")
        expression_matrix <- ReadMtx(mtx = "matrix.mtx.gz", cells = "barcodes.tsv.gz", features = "features_filtered.tsv")

        log_info("Getting Abseq Names")
        abseq_rows <- grep("pAbO$", rownames(expression_matrix))
        if (length(abseq_rows) > 0) {
            abseq_included <- TRUE
            log_info("Separating Abseq to independent Assay")
            # drop = FALSE keeps a matrix (with its dimnames) even when only one
            # feature is selected; CreateAssayObject requires a matrix.
            mat_abseq <- expression_matrix[abseq_rows, , drop = FALSE]
            mat_mRNA <- expression_matrix[-abseq_rows, , drop = FALSE]

            log_info("Creating RNA Seurat Assay")
            warn_on_underscore_features(rownames(mat_mRNA), "RNA")
            rna_assay <- CreateAssayObject(counts = mat_mRNA)

            log_info("Creating Abseq Seurat Assay")
            warn_on_underscore_features(rownames(mat_abseq), "Abseq")
            abseq_assay <- CreateAssayObject(counts = mat_abseq)
        } else {
            log_info("No Abseq Detected, Creating RNA Seurat Assay")
            warn_on_underscore_features(rownames(expression_matrix), "RNA")
            rna_assay <- CreateAssayObject(counts = expression_matrix)
        }
    }

    # If ATAC Seurat object exists, add RNA assay to it.
    # Otherwise, construct a new Seurat object.
    if (atac_included) {
        log_info("Loading ATAC Seurat object")
        seurat_object <- readRDS(opt$atac_seurat_rds)

        if (!is.na(opt$atac_peak_seq)) {
            log_info("Adding peak sequence to the ATAC Seurat object")
            # read.csv() decompresses .gz input itself and closes its connection.
            peak_seq <- read.csv(opt$atac_peak_seq)
            seurat_object[["peaks"]]@meta.features <- cbind(seurat_object[["peaks"]]@meta.features, peak_seq["Peak_seq"])
        }

        if (!is.na(opt$atac_background_peaks)) {
            log_info("Adding background peaks to the ATAC Seurat object")

            background_peaks <- read.csv(opt$atac_background_peaks)
            # Drop the peak name column; drop = FALSE keeps a data frame even
            # when only a single background column remains.
            background_peaks <- background_peaks[, -1, drop = FALSE]
            colnames(background_peaks) <- paste0("bg_peaks", seq_len(ncol(background_peaks)))

            seurat_object[["peaks"]]@meta.features <- cbind(seurat_object[["peaks"]]@meta.features, background_peaks)
        }

        if (!is.na(opt$atac_motif_data_table)) {
            atac_motif_data_table <- unlist(strsplit(opt$atac_motif_data_table, ","))
            atac_cell_by_motif_file <- last_file_with_suffix(atac_motif_data_table, "_ATAC_Cell_by_Motif_MEX.zip")
            atac_peak_by_motif_file <- last_file_with_suffix(atac_motif_data_table, "_ATAC_Peak_by_Motif_MEX.zip")

            if (!is.null(atac_cell_by_motif_file)) {
                log_info("Updating ATAC Seurat object with the cell by motif table")
                unzip(atac_cell_by_motif_file, overwrite = TRUE)
                sanitize_feature_quotes("atac-features.tsv.gz", "atac_features_filtered.tsv")

                atac_motif_by_cell_matrix <- ReadMtx(mtx = "atac-matrix.mtx.gz", cells = "atac-barcodes.tsv.gz", features = "atac_features_filtered.tsv")
                chromvar_assay <- CreateAssayObject(data = atac_motif_by_cell_matrix)
                seurat_object[["chromvar"]] <- chromvar_assay

                unlink(c(
                    "atac-matrix.mtx.gz",
                    "atac-barcodes.tsv.gz",
                    "atac-features.tsv.gz",
                    "atac_features_filtered.tsv"
                ))
            }

            if (!is.null(atac_peak_by_motif_file)) {
                log_info("Updating ATAC Seurat object with the peak by motif table")
                unzip(atac_peak_by_motif_file, overwrite = TRUE)
                atac_motif_by_peak_matrix <- t(readMM("peak-motif-matrix.mtx.gz"))

                motif_names <- read.table("peak-motif-motifs.tsv.gz")[, 1]
                rownames(atac_motif_by_peak_matrix) <- motif_names
                # peak name format is chr-start-end in Signac, which is different from what's in MEX file (chr:start-end),
                # so the peak names are taken from the Seurat object and peak-motif-peaks.tsv.gz is not read.
                colnames(atac_motif_by_peak_matrix) <- rownames(seurat_object@assays$peaks@counts)

                unlink(c("peak-motif-matrix.mtx.gz", "peak-motif-motifs.tsv.gz", "peak-motif-peaks.tsv.gz"))

                if (!is.na(opt$pfm_file)) {
                    log_info("Creating motif object")
                    pfm <- TFBSTools::readJASPARMatrix(opt$pfm_file, matrixClass = "PFM")
                    # This doesn't actually change the internal names of the motifs,
                    # but it should be good enough for our purposes.
                    names(pfm) <- motif_names

                    # Import motif positions if available
                    motif_positions_granges <- NULL
                    if (!is.na(opt$atac_motif_positions)) {
                        motif_positions_granges <- import_motif_positions(opt$atac_motif_positions, motif_names)
                    }

                    # Make motif_names into a named vector.
                    names(motif_names) <- motif_names

                    motif_object <- Signac::CreateMotifObject(
                        data = as(t(atac_motif_by_peak_matrix), "CsparseMatrix"),
                        pwm = pfm,
                        motif.names = motif_names,
                        positions = motif_positions_granges
                    )

                    seurat_object@assays$peaks@motifs <- motif_object
                }
            }
        }

        if (rna_included) {
            log_info("Updating ATAC Seurat object with RNA info")
            seurat_object[["RNA"]] <- rna_assay
            DefaultAssay(seurat_object) <- "RNA"
        }
    } else {
        log_info("Creating Seurat object")
        seurat_object <- CreateSeuratObject(counts = rna_assay, assay = "RNA", project = opt$base_name)
    }

    # Add Abseq assay if it is present.
    if (abseq_included) {
        seurat_object[["ADT"]] <- abseq_assay
    }

    # Add relevant metadata
    if (!is.na(opt$bioproduct_stats)) {
        log_info("Adding Bioproduct stats")
        bioproduct_stats_meta <- read.table(opt$bioproduct_stats, sep = ",", header = TRUE, comment.char = "#", row.names = 1)
        # Seurat automatically converts "|" into "-"
        rownames(bioproduct_stats_meta) <- gsub("\\|", "-", rownames(bioproduct_stats_meta))

        abseq_rowIdx <- grep("pAbO$", rownames(bioproduct_stats_meta))
        if ((length(abseq_rowIdx) > 0) && abseq_included) {
            log_info("  Adding Abseq bioproduct stats")

            # Rows are selected by the assay's own feature names so that the
            # metadata lines up with the assay.
            feature_in_abseq <- rownames(seurat_object@assays$ADT@counts)
            feature_metadata_abseq <- bioproduct_stats_meta[feature_in_abseq, 1:6]
            seurat_object[["ADT"]] <- AddMetaData(
                object = seurat_object[["ADT"]],
                metadata = feature_metadata_abseq,
                col.name = colnames(feature_metadata_abseq)
            )
        }

        mRNA_bioproduct_num <- nrow(bioproduct_stats_meta) - length(abseq_rowIdx)
        if ((mRNA_bioproduct_num > 0) && rna_included) {
            log_info("  Adding mRNA bioproduct stats")

            feature_in_mRNA <- rownames(seurat_object@assays$RNA@counts)
            feature_metadata_mRNA <- bioproduct_stats_meta[feature_in_mRNA, 1:6]
            seurat_object[["RNA"]] <- AddMetaData(
                object = seurat_object[["RNA"]],
                metadata = feature_metadata_mRNA,
                col.name = colnames(feature_metadata_mRNA)
            )
        }
    }

    if (!is.na(opt$cell_type_experimental)) {
        log_info("Adding Cell Type metadata")
        cell_type_meta <- read_cell_metadata_csv(opt$cell_type_experimental, stringsAsFactors = TRUE)
        seurat_object <- AddMetaData(object = seurat_object, metadata = cell_type_meta)
    }

    if (!is.na(opt$sample_tag_calls)) {
        log_info("Adding Sample Tag Calls metadata")
        tag_calls_meta <- read_cell_metadata_csv(opt$sample_tag_calls, stringsAsFactors = TRUE)
        seurat_object <- AddMetaData(object = seurat_object, metadata = tag_calls_meta)
    }

    if (!is.na(opt$sample_tag_csvs)) {
        sample_tag_csvs <- unlist(strsplit(opt$sample_tag_csvs, ","))
        sample_tag_reads_per_cell_file <- last_file_with_suffix(sample_tag_csvs, "_Sample_Tag_ReadsPerCell.csv")
        if (!is.null(sample_tag_reads_per_cell_file)) {
            log_info("Adding Sample Tag Reads per Cell metadata")
            sample_tag_reads_per_cell <- read_cell_metadata_csv(sample_tag_reads_per_cell_file)
            seurat_object <- AddMetaData(object = seurat_object, metadata = sample_tag_reads_per_cell)
        }
    }

    if (!is.na(opt$putative_cells_origin)) {
        log_info("Adding Putative Cell Origin metadata")
        putative_origin_meta <- read_cell_metadata_csv(opt$putative_cells_origin)
        seurat_object <- AddMetaData(object = seurat_object, metadata = putative_origin_meta)
    }

    if (!is.na(opt$protein_aggregates_experimental)) {
        log_info("Adding Protein Aggregates metadata")
        prot_agg_meta <- read_cell_metadata_csv(opt$protein_aggregates_experimental)
        seurat_object <- AddMetaData(object = seurat_object, metadata = prot_agg_meta)
    }

    if (!is.na(opt$vdj_per_cell)) {
        log_info("Adding VDJ per cell metadata")
        vdj_meta <- read_cell_metadata_csv(opt$vdj_per_cell)
        # Cell_Type_Experimental is redundant with the cell type import above.
        # setdiff() also copes with files that do not carry the column at all.
        vdj_meta <- vdj_meta[, setdiff(colnames(vdj_meta), "Cell_Type_Experimental"), drop = FALSE]
        seurat_object <- AddMetaData(object = seurat_object, metadata = vdj_meta)
    }

    # Coordinate files come in a comma separated list, get each separately
    tsne_coordinates_file <- NULL
    umap_coordinates_file <- NULL
    if (!is.na(opt$coordinates_file_list)) {
        coord_file_vec <- unlist(strsplit(opt$coordinates_file_list, ","))
        tsne_coordinates_file <- last_file_with_suffix(coord_file_vec, "tSNE_coordinates.csv")
        umap_coordinates_file <- last_file_with_suffix(coord_file_vec, "UMAP_coordinates.csv")
    }

    if (!is.null(tsne_coordinates_file)) {
        seurat_object <- add_coordinates_reduction(seurat_object, tsne_coordinates_file, "tsne", "tSNE", "t-SNE")
    }

    if (!is.null(umap_coordinates_file)) {
        seurat_object <- add_coordinates_reduction(seurat_object, umap_coordinates_file, "umap", "UMAP", "UMAP")
    }

    output_file <- paste0(opt$base_name, "_Seurat.rds")
    log_info("Writing Seurat object to disk: {output_file}")
    saveRDS(seurat_object, file = output_file)

    log_info("Removing intermediate files")
    # Files that do not exist (e.g. in an ATAC-only run) are silently skipped by unlink().
    files_to_remove <- c(
        "features_filtered.tsv",
        "features.tsv.gz",
        "barcodes.tsv.gz",
        "matrix.mtx.gz"
    )
    unlink(files_to_remove)
}

# library() stops with a clear error if optparse is not installed.
suppressPackageStartupMessages(library(optparse))

# Command-line options, as flag -> help text. Every option stores an optional
# character value and defaults to NA, which generate_seurat() treats as "not provided".
option_help <- c(
    "--data-tables" = "A comma-separated list of data-table files.",
    "--bioproduct-stats" = "Name of bioproduct stats .csv file.",
    "--cell-type-experimental" = "Name of cell type experimental .csv file.",
    "--putative-cells-origin" = "Name of putative cells origin .csv file.",
    "--sample-tag-calls" = "Name of sample tag calls .csv file.",
    "--sample-tag-csvs" = "A comma-separated list of sample tag csv files",
    "--coordinates-file-list" = "Comma separated list of dim reduction coordinates for tSNE and UMAP",
    "--protein-aggregates-experimental" = "Name of protein aggregates experimental .csv file.",
    "--vdj-per-cell" = "Name of VDJ per cell file.",
    "--atac-seurat-rds" = "Name of ATAC Seurat RDS file.",
    "--atac-motif-data-table" = "Name of ATAC motif data table MEX files",
    "--atac-peak-seq" = "Name of ATAC peak sequence .csv.gz file",
    "--atac-background-peaks" = "Name of ATAC background peaks .csv.gz file",
    "--pfm-file" = "Name of JASPAR-format position frequency matrix (PFM) file",
    "--atac-motif-positions" = "Name of ATAC motif positions .tsv.gz file",
    "--base-name" = "Base output file name"
)

option_list <- lapply(names(option_help), function(flag) {
    make_option(
        flag,
        action = "store",
        default = NA,
        type = "character",
        help = option_help[[flag]]
    )
})

# Parse the command line. Hyphens in option names are converted to underscores,
# so "--data-tables" is read back as opt$data_tables.
parser <- OptionParser(option_list = option_list)
opt <- parse_args(parser, convert_hyphens_to_underscores = TRUE)

# Run the conversion. An error is logged and printed here instead of being
# re-raised, and the closing log line is written whether or not the run failed.
# NOTE(review): since the error is handled at this level, the script appears to
# finish normally even when generate_seurat() fails -- confirm that is intended.
tryCatch(
    generate_seurat(opt),
    error = function(err) {
        log_error("Error in GenerateSeurat:")
        print(err)
    },
    # warning = function(w) {
    #     message('Caught an warning!')
    #     print(w)
    # },
    finally = log_info("Finished GenerateSeurat")
)
