from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import copy
import gzip
import os
from ftplib import FTP


# Get a VDJ specific FASTA reference from the Gencode data - human and mouse

# Species to process. Each name is used as a key into both gencode_urls and cFragments,
# and as the prefix of the output FASTA file names.
species_list = ["HomoSapiens", "MusMusculus"]
#species_list = ["HomoSapiens"]

# FTP host, and per-species path on that host, of the gzipped Gencode transcript FASTA.
# The path is split on its last '/' into a remote directory and a file name; the file name
# is also the local name checked for / downloaded to in the working directory.
gencode_server = 'ftp.ebi.ac.uk'
gencode_urls = {"HomoSapiens": "pub/databases/gencode/Gencode_human/release_47/gencode.v47.transcripts.fa.gz",
                "MusMusculus": "pub/databases/gencode/Gencode_mouse/release_M36/gencode.vM36.transcripts.fa.gz"
                }

# These c fragments are based on the location of our N2 primers for these genes
# Layout: species -> fragment name -> nucleotide sequence (plain str).
# The first two characters of a fragment name give the major type (IG / TR) and the first
# three give the chain type (e.g. IGL, TRB); addCfragments relies on this to file each
# fragment under the right chain.
# NOTE(review): some sequences contain 'N' - presumably an any-base placeholder where the
# underlying genes/alleles differ; confirm with whoever designed the primers.
cFragments = {"HomoSapiens": { "TRAC": "ATATCCAGAACCCTGACCCTGCCGTGTACCAGCTGAGAGACTCTAAATCCAGTGACAAGTCTGTCTGCCTATTCACCGATTTTGAT",
                               "TRBC": "AGGACCTGAANAANGTGTTCCCACCCNAGGTCGCTGTGTTTGAGCCATCAGAAGCAGAGATC",
                               "TRGC": "ATAAACAACTTGATGCAGATGTTTCCC",
                               "TRDC": "GAAGTCAGCCTCATACCAAACCATCCGTTTTTGTCATGAAAAATGGAACAAATGTCGCTTGTCTGGTGAAGGAATTCTACCCCAAGGATAT",
                               "IGKC": "GAACTGTGGCTGCACCATCTGTCTTCATCTTCCCGCCATCTGA",
                               "IGLC1": "GTCAGCCCAAGGCCAACCCCACTGTCACTCTGTTCCCGCCCTCCTCTGAGGAGCTCCAAGCCAACAAGGCCACACT",
                               "IGLC2": "GTCAGCCCAAGGCTGCCCCCTCGGTCACTCTGTTCCCGCCCTCCTCTGAGGAGCTTCAAGCCAACAAGGCCACACTGGT",
                               "IGLC3": "GTCAGCCCAAGGCTGCCCCCTCGGTCACTCTGTTCCCACCCTCCTCTGAGGAGCTTCAAGCCAACAAGGCCACACTGGT",
                               "IGHM": "GGAGTGCATCCGCCCCAACCCTTTTCCCCCTCGTCTCCTGT",
                               "IGHD": "CACCCACCAAGGCTCCGGATGTGTTCCCCATCATATCAGGGTGCAGACA",
                               "IGHA": "CATCCCCGACCAGCCCCAAGGTCTTCCCGCTGAGCCTCNNCAGCACCCNNCNAGATGGGAACGTGGTCNTCGCNTGCCTGGTCCAGGGCTTCTTCCCCCAGGAGCCACTCAGTGTGACCTGGAGCGAAAG",
                               "IGHG": "CNTCCACCAAGGGCCCATCGGTCTTCCCCCTGGCNCCCTNCTCCANGAGCACCTCNGNGNGCACAGCNGCCCTGGGCTGCCTGGTCAAGGACTACTT",
                               "IGHE": "CCTCCACACAGAGCCCATCCGTCTTCCCCTTGAC"
                            },
               "MusMusculus": { "TRAC": "ACATCCAGAACCCAGAACCT",
                                "TRBC": "AGGATCTGAGAAATGTGACTCCACCCAAGGTCTCCTTGTTTGAGCCATCAAAAGCAGAGATTG",
                                "TRGC1": "ACAAAAGGCTTGATGCAGACATTTCCCCCAA",
                                "TRGC3": "ACAAAAAGCTTGATGCAGACATTTCCCCCAA",
                                "TRGC4": "ACAAACGCACTGACTCAGACTTTTCTCCCAAGCCTACTAT",
                                "TRDC": "AAAGCCAGCCTCCGGCCAAACCATCTGTTTTCATCATGAAAAATGGAACAAATGTTGCTTGTCTGGTGAAAGATTTCTAC",
                                "IGKC": "GGGCTGATGCTGCACCAACTGTATCCATCTTCCCACCATCCAGTGAGCAGTTAACATCT",
                                "IGLC1": "GCCAGCCCAAGTCTTCGCCATCAGTCACCCTGTTTCCACCTTCCTCTGAAGAGCTCGAGACTAAC",
                                "IGLC2": "GGTCAGCCCAAGTCCACTCCCACTCTCACCGTGTTTCCACCTTCCTCTGAGGAGCTCAAGGAAAACAAAGCCACACTG",
                                "IGLC3": "GTCAGCCCAAGTCCACTCCCACACTCACCATGTTTCCACCTTCCCCTGAGGAGCTCCAGGAAAACAAAGCCACACT",
                                "IGLC4": "GCCAACCCAAGGCTACACCCTCAGTTAATCTGTTCCCACCTTCCTCTGAAGAGCTC",
                                "IGHM": "AGAGTCAGTCCTTCCCAAATGTC",
                                "IGHD": "GTAATGAAAAGGGACCTGACATGTTCCTCCTCTCAGAGTGCAAAGCCCCAGAGGAAAATGAAAAGATAAACCTGGGCTGTTTAGTAATTGGAAGTCAG",
                                "IGHA": "AGNCTGCNAGANANCCCACCATCTACCCACTGACA",
                                "IGHG1": "CCAAAACGACACCCCCATCTGTCTATCCACTG",
                                "IGHG2b": "CCAAAACAACACCCCCATCAGTCTATCCAC",
                                "IGHG2c": "CCAAAACAACAGCCCCAT",
                                "IGHG3": "CTACAACAACAGCCCCATCTGTCTATCCCTTGGTCCCTGGCTGCGGTGACACATCTGGATCCTCGGTGACACTGGGATGCCTTGTCAAAGGCTACTTCCCT",
                                "IGHE": "CTATCAGGAACCCTCAGCTCTA"
                            }
             }


# Chain-type keys created under Segments["TR"] and Segments["IG"] in gencode_VDJ.
# They are matched against the first three characters (upper-cased) of each Gencode gene
# symbol and against the first three characters of each cFragments key.
TRchainTypes = ["TRA", "TRB", "TRG", "TRD"]
IGchainTypes = ["IGK", "IGL", "IGH"]


# We want to balance the need for sequences long enough that 200+ bp reads align successfully against the desire for speed
# More sequences in the reference will slow down the alignment
# V segments are long enough to be aligned against on their own
# J and C region segments are not always long enough - so these are concatenated in all possible combinations
# D segments are generally too short to matter, so they are not added to the reference


def _downloadGencodeFasta(ftp_path, ftp_file):
    """Download ftp_file from directory ftp_path on gencode_server into the working directory.

    The data is written to '<ftp_file>.part' and only renamed to ftp_file once the
    transfer has completed. An interrupted download therefore never leaves a truncated
    file under the final name, which a later run would mistake for a valid local copy.
    The FTP connection is closed even when the transfer fails.
    """
    partialFile = ftp_file + '.part'
    try:
        # The context manager sends QUIT / closes the socket on both success and error
        with FTP(gencode_server) as ftp:
            ftp.login("anonymous", "")
            ftp.cwd(ftp_path)
            with open(partialFile, 'wb') as gencodeFastaGZHandle:
                ftp.retrbinary('RETR %s' % ftp_file, gencodeFastaGZHandle.write)
        os.replace(partialFile, ftp_file)
    finally:
        # Only still present if something above failed before the rename
        if os.path.exists(partialFile):
            os.remove(partialFile)


def _assembleReverseRecords(Segments):
    """Return {"TR": [...], "IG": [...]} lists of reverse-complemented SeqRecords.

    For every chain type, each V segment is added on its own (they are long enough to
    align with on their own), followed by every J segment concatenated with every
    cFragment of that chain. The full C segments are deliberately not used, because we
    don't want to collect reads from the 3' end of the C gene.
    """
    # Last part of the record id: <chainType>_TVDJ for TR chains, <chainType>_BVDJ for IG chains
    headerSuffix = {"TR": "TVDJ", "IG": "BVDJ"}
    SeqRecords = {"TR": [], "IG": []}

    for majorType, records in SeqRecords.items():
        for chainType, chainSegments in Segments[majorType].items():
            endHeader = f'{chainType}_{headerSuffix[majorType]}'

            for vgene, vSeq in chainSegments["V"].items():
                records.append(getReverseRecord(vgene, vSeq, endHeader))

            for jgene, jSeq in chainSegments["J"].items():
                for cfrag, cSeq in chainSegments["cFragments"].items():
                    records.append(getReverseRecord(jgene + '|' + cfrag, jSeq + cSeq, endHeader))

    return SeqRecords


def gencode_VDJ():
    """Build the VDJ reference FASTA files for every species in species_list.

    Per species:
      1. Use the gzipped Gencode transcript FASTA from the working directory, downloading
         it from gencode_server first if it is not already there.
      2. Collect the IG/TR segments with getSegmentsFromFasta and add the primer-based
         constant-region fragments with addCfragments.
      3. Write '<species>_TR_VDJsegments.fasta' and '<species>_IG_VDJsegments.fasta'
         (two-line FASTA) containing the reverse-complemented V segments and J+cFragment
         combinations.
      4. Delete the Gencode file again if (and only if) this run downloaded it.

    Errors from the FTP transfer, file access or FASTA parsing are not caught here.
    """
    for species in species_list:

        print(f'Start {species}')
        toDeleteGencodeGz = False

        # Assemble a nested dict to hold all the gene and sequence segment data:
        # Segments[majorType][chainType][segmentKind][name] -> sequence
        protoSegDict = { "V": {}, "D": {}, "J": {}, "C": {}, "cFragments": {} }
        Segments = {
            "TR": {chainType: copy.deepcopy(protoSegDict) for chainType in TRchainTypes},
            "IG": {chainType: copy.deepcopy(protoSegDict) for chainType in IGchainTypes},
        }

        # Check if the file exists in the local path, if so, skip the FTP download
        ftp_path, ftp_file = gencode_urls[species].rsplit('/', 1)
        gencodeFastaGZ = ftp_file

        if not os.path.exists(ftp_file):
            print(f'Downloading: {gencode_server} {gencode_urls[species]}')
            _downloadGencodeFasta(ftp_path, ftp_file)
            # Only files fetched by this run are removed afterwards; a pre-existing
            # local copy is left untouched.
            toDeleteGencodeGz = True

        print('Reading gencode fasta and collecting segments')
        getSegmentsFromFasta(gencodeFastaGZ, Segments)

        addCfragments(species, Segments)

        # Now ready to assemble the segments - in reverse complement form
        print('Assembling segments')
        SeqRecords = _assembleReverseRecords(Segments)

        print('Writing new VDJ fasta')
        for majorType, records in SeqRecords.items():
            SeqIO.write(records, species + "_" + majorType + "_VDJsegments" + ".fasta", "fasta-2line")

        if toDeleteGencodeGz:
            print('Deleting downloaded Gencode fasta')
            os.remove(gencodeFastaGZ)


def getSegmentsFromFasta(gencodeFastaGZ, Segments):
    """Read a gzipped Gencode transcript FASTA and file its IG/TR segments into Segments.

    gencodeFastaGZ: path of the gzipped FASTA file.
    Segments: nested dict Segments[mainType][chainType][kind] -> {geneSymbol: sequence},
        filled in place for kind in V, D, J, C. A later record with the same gene symbol
        replaces an earlier one.

    The record id is split on '|'; field 5 is taken as the gene symbol and field 7 as the
    biotype. Only records whose biotype contains 'IG_' or 'TR_' are considered. Records
    with an unrecognised chain type or segment kind are reported on stdout and skipped.
    Nothing is returned.
    """
    # Downstream tools can choke on certain characters in fasta headers
    bracketsToUnderscore = str.maketrans('()[]', '____')
    segmentKinds = ("V", "D", "J", "C")

    with gzip.open(gencodeFastaGZ, "rt") as handle:
        for transcript in SeqIO.parse(handle, "fasta"):
            transcript.id = transcript.id.translate(bracketsToUnderscore)

            headerFields = transcript.id.split("|")
            geneSymbol = headerFields[5]
            biotype = headerFields[7]

            # Special case due to incorrect biotype annotation in human Gencode v47
            if geneSymbol == 'TRBV11-2':
                biotype = 'TR_V_gene'

            # Anything that is not an immunoglobulin / T-cell receptor biotype is ignored
            if "IG_" not in biotype and "TR_" not in biotype:
                continue

            mainType = biotype[0:2]  #  Should be IG or TR
            chainType = geneSymbol[0:3].upper()  # e.g. IGL or TRB

            if mainType not in Segments or chainType not in Segments[mainType]:
                print("Unknown chainType: " + biotype + ' : ' + geneSymbol)
                continue

            # First matching kind wins, checked in the order V, D, J, C
            for kind in segmentKinds:
                if "_" + kind in biotype:
                    Segments[mainType][chainType][kind][geneSymbol] = transcript.seq.upper()
                    break
            else:
                print("Unknown biotype: " + biotype + ' : ' + geneSymbol)


def addCfragments(species, Segments):
    """Copy the constant-region fragments of one species into Segments, in place.

    species: key into the module-level cFragments table.
    Segments: nested dict; each fragment is stored under
        Segments[<first 2 chars of name>][<first 3 chars of name>]['cFragments'][name],
        e.g. 'IGLC1' goes under Segments['IG']['IGL'].
    """
    for fragName, fragSeq in cFragments[species].items():
        majorKey, chainKey = fragName[:2], fragName[:3]  # e.g. ('IG', 'IGL') or ('TR', 'TRB')
        Segments[majorKey][chainKey]['cFragments'][fragName] = fragSeq


def getReverseRecord(newID, seq, end_header_section):
    """Return a SeqRecord holding the upper-cased reverse complement of seq.

    The record id is '<newID>|<end_header_section>'; description and name are left empty.
    seq must provide upper() and reverse_complement() on the upper-cased result.
    """
    recordID = "|".join((newID, end_header_section))
    upperSeq = seq.upper()
    return SeqRecord(upperSeq.reverse_complement(), id=recordID, description="", name="")






    # if species == "HomoSapiens":
    #     ENST00000526893rc = Seq("GTCAATGAGGATATTTATTGGGGTTTCATGAGTGCAGGGAGAAGGGCTGGATGACTTGGGATGGGGAGAGAGACCCCTCCCCTGGGATCCTGCAGCTCCAGGCTCCCGTGGGTGGGGTTAGAGTTGGGAACCTATGAACATTCTGTAGGGGCCACTGTCTTCTCCACGGTGCTCCCTTCATGCGTGACCTGGCAGCTGTAGCTTCTGTGGGACTTCCACTGCTCGGGCGTCAGGCTCAGGTAGCTGCTGGCCGCGTACTTGTTGTTGCTCTGTTTGGAGGGTTTGGTGGTCTCCACTCCCGCCTTGACGGGGCTGCCATCTGCCTTCCAGGCCACTGTCACAGCTCCCGGGTAGAAGTCACTGATCAGACACACTAGTGTGGCCTTGTTGGCTTGGAGCTCCTCAGAGGAGGGCGGGAACAGAGTGACAGTGGGGTTGGCCTTGGGCTGACCTAGGACGGTGACCTTGGTCCCAGTTCCGAAGACATAACACAGTGACTGAGGCTCAGACCAAAACCCCCGGGGCCAGCACCTGGGGTCTGCTCTCTGGGGGCTGGGCTGGAGCAGGAGCCTGCCCCACAGGCTCCGCAGGCTGGATCGGCTGCTTCCAACTGAGGCTCCAGGGTCTGGGTCCCCGCTTTGCGGTGCAACCATTGGGCGCAGCAGGCCATGGGCGACCATGGCCAGACCCAGCAGCAGCAGGGGCCAGCGCTGCCTGG")
    #     SeqRecords["IG"].append(SeqRecord(ENST00000526893rc, id="IGLL5|ENST00000526893.5|rcIGLC_capture", description="", name=""))
    #     ENST00000531372rc = Seq("GTCAATGAGGATATTTATTGGGGTTTCATGAGTGCAGGGAGAAGGGCTGGATGACTTGGGATGGGGAGAGAGACCCCTCCCCTGGGATCCTGCAGCTCCAGGCTCCCGTGGGTGGGGTTAGAGTTGGGAACCTATGAACATTCTGTAGGGGCCACTGTCTTCTCCACGGTGCTCCCTTCATGCGTGACCTGGCAGCTGTAGCTTCTGTGGGACTTCCACTGCTCGGGCGTCAGGCTCAGGTAGCTGCTGGCCGCGTACTTGTTGTTGCTCTGTTTGGAGGGTTTGGTGGTCTCCACTCCCGCCTTGACGGGGCTGCCATCTGCCTTCCAGGCCACTGTCACAGCTCCCGGGTAGAAGTCACTGATCAGACACACTAGTGTGGCCTTGTTGGCTTGGAGCTCCTCAGAGGAGGGCGGGAACAGAGTGACAGTGGGGTTGGCCTTGGGCTGACCTGCCCCACAGGCTCCGCAGGCTGGATCGGCTGCTTCCAACTGAGGCTCCAGGGTCTGGGTCCCCGCTTTGCGGTGCAACCATTGGGCGCAGCAGGCCATGGGCGACCATGGCCAGACCCAGCAGCAGCAGGGGCCAGCGCTGCCTGGGACCAG")
    #     SeqRecords["IG"].append(SeqRecord(ENST00000531372rc, id="IGLL5|ENST00000531372.1|rcIGLC_capture", description="", name=""))



def main():
    """Script entry point: run gencode_VDJ(), which takes no arguments."""
    gencode_VDJ()

# Run only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
