import pysam
import sys
import getopt
import logging
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit
from pycuda.compiler import SourceModule
import pycuda.gpuarray as gpuarray
from pycuda.autoinit import context
import multiprocessing
from multiprocessing import Process
# Options
def parse_arguments(argv=None):
    """Parse the command-line options of the script.

    Options (each takes a value):
      -b  path to the input BAM file
      -w  window size (int)
      -s  step between the starts of consecutive windows (int)
      -o  output file path
      -e  log file path
      -t  accepted by the option spec but not used

    Args:
        argv: argument list to parse, without the program name.
            Defaults to ``sys.argv[1:]``.

    Returns:
        Tuple ``(bamfile_path, window_size, step_size, output_file, logfile)``;
        any option that was not supplied is ``None``.

    Exits the process with status 2 (after printing 'Invalid argument' to
    stderr) on an unknown option, a missing option value, or a non-integer
    ``-w``/``-s`` value.
    """
    if argv is None:
        argv = sys.argv[1:]
    bamfile_path = window_size = step_size = output_file = logfile = None
    try:
        opts, _ = getopt.getopt(argv, 'b:w:s:t:o:e:')
        for opt, arg in opts:
            # Exact comparisons: `opt in ("-b")` was a substring test on a
            # plain string, not a tuple membership test.
            if opt == '-b':
                bamfile_path = arg
            elif opt == '-w':
                window_size = int(arg)
            elif opt == '-s':
                step_size = int(arg)
            elif opt == '-o':
                output_file = arg
            elif opt == '-e':
                logfile = arg
    except (getopt.GetoptError, ValueError):
        # Returning None here would only crash the caller's tuple unpacking.
        print('Invalid argument', file=sys.stderr)
        sys.exit(2)
    return bamfile_path, window_size, step_size, output_file, logfile

if __name__ == '__main__':
    bamfile_path, window_size, step_size, output_file, logfile = parse_arguments()
    # -b, -w, -s and -o are required by the pipeline; fail early with a clear
    # message instead of failing later with an unrelated TypeError.
    missing = [flag for flag, value in (("-b", bamfile_path), ("-w", window_size),
                                        ("-s", step_size), ("-o", output_file))
               if value is None]
    if missing:
        sys.stderr.write("Missing required option(s): %s\n" % ", ".join(missing))
        sys.exit(2)
    # When -e is not given, filename=None makes basicConfig log to stderr
    # instead of creating a file literally named "None".
    logging.basicConfig(filename=logfile, filemode='a', level=logging.INFO, format='%(asctime)s %(levelname)s - %(message)s')
# CUDA kernels, compiled once at import time.
# All three kernels share one launch contract:
# - one thread per window;
# - window `idx` covers the 0-based half-open range
#   [idx * step_size, idx * step_size + window_size) of the input track;
# - the output buffer holds one float per window, i.e.
#   (seq_length - window_size) / step_size + 1 slots;
# - there is no window at all when seq_length < window_size.
# Threads beyond the last window return without touching memory.
mod = SourceModule("""
// First position of window `idx`, or -1 when `idx` is not a valid window
// (surplus thread in the last block, or degenerate parameters).
__device__ int window_start(int idx, int seq_length, int window_size, int step_size) {
    if (window_size <= 0 || step_size <= 0 || seq_length < window_size) {
        return -1;
    }
    int n_windows = (seq_length - window_size) / step_size + 1;
    if (idx >= n_windows) {
        return -1;
    }
    return idx * step_size;
}

// Mean raw read depth per window.
__global__ void calcul_depth_kernel(int *depth_data, int seq_length, int window_size, int step_size, float *depth_results) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    int pos_start = window_start(idx, seq_length, window_size, step_size);
    if (pos_start < 0) {
        return;
    }
    // 64-bit accumulator: summed depth over a large window can exceed INT_MAX.
    long long count_reads = 0;
    for (int i = pos_start; i < pos_start + window_size; ++i) {
        count_reads += depth_data[i];
    }
    depth_results[idx] = (float)((double)count_reads / window_size);
}

// GC percentage (0-100) per window: 'G', 'C', 'g' and 'c' count as GC,
// every other byte (A, T, N, zero padding, ...) does not.
__global__ void calcul_gc_kernel(char *gc_data, int seq_length, int window_size, int step_size, float *gc_results) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    int pos_start = window_start(idx, seq_length, window_size, step_size);
    if (pos_start < 0) {
        return;
    }
    int gc_count = 0;
    for (int i = pos_start; i < pos_start + window_size; ++i) {
        char base = gc_data[i];
        if (base == 'G' || base == 'C' || base == 'g' || base == 'c') {
            ++gc_count;
        }
    }
    gc_results[idx] = ((float)gc_count / window_size) * 100.0f;
}

// Mean mappability score per window.
__global__ void calcul_map_kernel(float *map_data, int seq_length, int window_size, int step_size, float *map_results) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    int pos_start = window_start(idx, seq_length, window_size, step_size);
    if (pos_start < 0) {
        return;
    }
    double weighted_sum = 0.0;
    for (int i = pos_start; i < pos_start + window_size; ++i) {
        weighted_sum += map_data[i];
    }
    map_results[idx] = (float)(weighted_sum / window_size);
}
""")

# Look up the compiled kernels by name.
calcul_depth_kernel_cuda = mod.get_function("calcul_depth_kernel")
calcul_gc_kernel_cuda = mod.get_function("calcul_gc_kernel")
calcul_map_kernel_cuda = mod.get_function("calcul_map_kernel")

######<---Fonctions  mappability--->#########
def merge_intervals(intervals):
    """Collapse runs of consecutive intervals that share the same score.

    Args:
        intervals: sequence of ``(start, end, score)`` tuples, expected to be
            ordered by start.

    Returns:
        A new list of ``(start, end, score)`` tuples. Each run of adjacent
        entries with an equal score becomes one tuple spanning from the run's
        first start to its last end. An empty input gives an empty list.
    """
    merged = []
    if not intervals:
        return merged
    start, end, score = intervals[0]
    for next_start, next_end, next_score in intervals[1:]:
        if next_score == score:
            # Same score: extend the current run.
            end = next_end
        else:
            # Score changes: close the current run and start a new one.
            merged.append((start, end, score))
            start, end, score = next_start, next_end, next_score
    merged.append((start, end, score))
    return merged

def dico_mappabilite(mappability_file):
    """Load a bedGraph mappability track into per-chromosome step functions.

    Each data line is expected to be tab-separated ``chrom start end score``.
    Lines starting with ``#`` and lines with fewer than four fields (blank
    lines, ``track``/``browser`` headers) are skipped.

    Args:
        mappability_file: path to the bedGraph file.

    Returns:
        dict mapping chromosome name -> ``{interval_start: score}``. For each
        chromosome:
        - intervals are sorted by start;
        - a score-0 interval is prepended when the track does not begin at
          position 0;
        - consecutive intervals with equal scores are merged with
          ``merge_intervals``, so only the start of each run is kept.
    """
    sys.stderr.write("\t Entering dico_mappabilite\n")
    raw_intervals = {}

    with open(mappability_file, 'r') as f:
        for line in f:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 4:
                continue
            chromosome = fields[0]
            start_pos = int(fields[1])
            end_pos = int(fields[2])
            score = float(fields[3])
            raw_intervals.setdefault(chromosome, []).append((start_pos, end_pos, score))

    mappability_dico = {}
    for chromosome, intervals in raw_intervals.items():
        # The 0-position check and the merge both assume start order.
        intervals.sort()
        # Cover the region before the first interval with score 0.
        if intervals[0][0] != 0:
            intervals.insert(0, (0, intervals[0][0], 0.0))
        merged_intervals = merge_intervals(intervals)
        mappability_dico[chromosome] = {start: score for start, _, score in merged_intervals}
    sys.stderr.write("\t Leaving dico_mappabilite\n")
    return mappability_dico
def calcul_mappability(seq_length, mappability, chr):
    """Expand a chromosome's mappability step function into a per-position array.

    Args:
        seq_length: length of the output array.
        mappability: dict ``chromosome -> {interval_start: score}`` (as built
            by ``dico_mappabilite``).
        chr: chromosome name to expand.

    Returns:
        np.float32 array of length *seq_length*. Position ``p`` holds the score
        of the last interval start ``<= p``; positions before the first start
        are 0. If *chr* is absent from *mappability*, all positions are 0
        and a note is written to stderr.
    """
    sys.stderr.write("\t Entering calcul_mappability \n")
    map_data = np.zeros(seq_length, dtype=np.float32)
    bounds = mappability.get(chr)
    if bounds is None:
        sys.stderr.write("\t No mappability data for %s, using 0\n" % chr)
    else:
        prev_bound = 0
        prev_mappability = 0
        for bound in sorted(bounds):
            # Vectorised fill of [prev_bound, bound) with the previous score
            # (replaces the per-position Python loop).
            stop = min(seq_length, bound)
            if stop > prev_bound:
                map_data[prev_bound:stop] = prev_mappability
            prev_bound = bound
            prev_mappability = bounds[bound]
        # Tail after the last interval start.
        if seq_length > prev_bound:
            map_data[prev_bound:seq_length] = prev_mappability
    sys.stderr.write("\t Leaving calcul_mappability \n")
    return map_data
######<---Fonctions calcul gc--->############
def parse_fasta(gc_file):
    """Read a (multi-)FASTA file into a dict of sequences.

    Records are keyed by sequence ID: the header text after ``>`` up to the
    first whitespace. This matches reference names in BAM headers, e.g.
    ``>chr1 AC:CM000663.2`` is stored as ``chr1``. Sequence lines are
    concatenated with surrounding whitespace (including ``\\r``) removed.
    Anything before the first header is ignored, and a repeated ID keeps the
    last record.

    The file is streamed line by line, so the whole file is never held in
    memory as one string in addition to the parsed sequences.

    Args:
        gc_file: path to the FASTA file.

    Returns:
        dict mapping sequence ID -> sequence string.
    """
    sys.stderr.write("\t Entering parse_fasta\n")
    sequences = {}
    name = None
    chunks = []
    with open(gc_file, 'r') as f:
        for line in f:
            if line.startswith('>'):
                if name is not None:
                    sequences[name] = ''.join(chunks)
                tokens = line[1:].split()
                name = tokens[0] if tokens else ''
                chunks = []
            elif name is not None:
                chunks.append(line.strip())
    if name is not None:
        sequences[name] = ''.join(chunks)
    sys.stderr.write("\t Leaving parse_fasta\n")
    return sequences

def calcul_gc_content(seq_length, chr, seq):
    """Return the bases of chromosome *chr* as a fixed-length ``S1`` array.

    Args:
        seq_length: length of the output array.
        chr: chromosome name to look up in *seq*.
        seq: dict mapping chromosome name -> nucleotide string (as built by
            ``parse_fasta``).

    Returns:
        np.ndarray of dtype ``S1`` and length *seq_length*, one ASCII base
        per element. The sequence is truncated if longer than *seq_length*.
        Positions past its end, or every position when *chr* is absent from
        *seq*, are zero bytes (``b''``).

    Raises:
        UnicodeEncodeError: if the sequence contains non-ASCII characters.
    """
    sys.stderr.write("\t Entering calcul_gc_content\n")
    gc_data = np.zeros(seq_length, dtype="S1")
    sequence = seq.get(chr)
    if sequence is None:
        sys.stderr.write("\t No sequence for %s, GC will be 0\n" % chr)
    else:
        n = min(len(sequence), seq_length)
        if n:
            # One bulk copy instead of a per-character Python loop.
            gc_data[:n] = np.frombuffer(sequence[:n].encode('ascii'), dtype='S1')
    sys.stderr.write("\t Leaving calcul_gc_content\n")
    return gc_data

######<---Fonctions calcul Depth Seq--->######
def calcul_depth_seq(seq_length, bamfile_path, chr):
    """Return the per-position read depth of chromosome *chr*.

    Args:
        seq_length: length of *chr*; columns at or beyond it are ignored.
        bamfile_path: despite its name, an open alignment handle (the caller
            passes a ``pysam.AlignmentFile``). Anything with a
            ``pileup(contig)`` method works.
        chr: reference name to pile up.

    Returns:
        np.int32 array of length *seq_length*. Index ``p`` holds
        ``nsegments`` of the pileup column at 0-based position ``p``, and 0
        where no column is reported.
    """
    sys.stderr.write("\t Entering calcul_depth_seq\n")
    depth_data = np.zeros(seq_length, dtype=np.int32)
    # Restrict the pileup to this contig. Without a contig, pileup walks every
    # chromosome, and other contigs' columns would be written into this array.
    for column in bamfile_path.pileup(chr):
        pos = column.reference_pos
        if pos >= seq_length:
            continue
        depth_data[pos] = column.nsegments
    sys.stderr.write("\t Leaving calcul_depth_seq\n")
    return depth_data

######<---Fonction main--->######
def _launch_window_kernel(kernel, data, seq_length, window_size, step_size, n_windows, block_size):
    """Run one per-window CUDA kernel on *data* and return its float32 results (one per window)."""
    results = np.zeros(n_windows, dtype=np.float32)
    # One thread per window, rounded up to whole blocks.
    grid_size = (n_windows + block_size - 1) // block_size
    # cuda.In copies *data* to a temporary device buffer before the launch;
    # cuda.Out copies the device results back into *results* after it.
    kernel(cuda.In(data), np.int32(seq_length), np.int32(window_size),
           np.int32(step_size), cuda.Out(results),
           block=(block_size, 1, 1), grid=(grid_size, 1))
    return results


def main_calcul(bamfile_path, chr, seq_length, window_size, step_size, output_file):
    """Compute per-window depth, GC% and mappability for one chromosome and append them.

    Args:
        bamfile_path: open alignment handle (the caller passes a
            ``pysam.AlignmentFile``), used for the pileup.
        chr: chromosome (reference) name.
        seq_length: chromosome length.
        window_size: window length in positions (must be > 0).
        step_size: distance between consecutive window starts (must be > 0).
        output_file: path to which one line per window is appended.

    Uses these module-level globals:
        ``seq``: from ``parse_fasta``.
        ``mappability``: from ``dico_mappabilite``.
        ``num_cores``: the CUDA block size.
        ``calcul_*_kernel_cuda``: the compiled kernels.

    Raises:
        ValueError: if *window_size* or *step_size* is not positive.
    """
    sys.stderr.write("\t entering main_calcul\n")
    if window_size <= 0 or step_size <= 0:
        raise ValueError("window_size and step_size must be positive")
    if seq_length < window_size:
        sys.stderr.write("\t %s is shorter than the window, skipped\n" % chr)
        return
    n_windows = (seq_length - window_size) // step_size + 1

    # Per-position input tracks. These are computed directly: an unstarted
    # multiprocessing.Process cannot hand back an array, and the open pysam
    # handle is shared with the caller.
    map_data = calcul_mappability(seq_length, mappability, chr)
    gc_data = calcul_gc_content(seq_length, chr, seq)
    depth_data = calcul_depth_seq(seq_length, bamfile_path, chr)

    block_size = int(num_cores)
    sys.stderr.write("\t blocksize (nb de threads) = %s\n" % (num_cores))
    sys.stderr.write("\t grid_size = %s\n" % ((n_windows + block_size - 1) // block_size))

    depth_results = _launch_window_kernel(calcul_depth_kernel_cuda, depth_data, seq_length,
                                          window_size, step_size, n_windows, block_size)
    sys.stderr.write("\t appel fonction calc_depth_kernel_cuda\n")
    gc_results = _launch_window_kernel(calcul_gc_kernel_cuda, gc_data, seq_length,
                                       window_size, step_size, n_windows, block_size)
    sys.stderr.write("\t appel fonction calc_gc_kernel_cuda\n")
    map_results = _launch_window_kernel(calcul_map_kernel_cuda, map_data, seq_length,
                                        window_size, step_size, n_windows, block_size)
    sys.stderr.write("\t appel fonction calc_map_kernel_cuda\n")

    with open(output_file, 'a') as f:
        sys.stderr.write("\t ecriture des fichiers\n")
        for i, (avg_depth, avg_gc, avg_map) in enumerate(zip(depth_results, gc_results, map_results)):
            pos_start = (i * step_size) + 1
            pos_end = pos_start + window_size
            # NOTE(review): the original loop never wrote anything. This
            # tab-separated layout is a reconstruction; confirm the expected
            # output format.
            f.write("%s\t%d\t%d\t%.4f\t%.4f\t%.4f\n"
                    % (chr, pos_start, pos_end, avg_depth, avg_gc, avg_map))

# Main program.
# `num_cores` is the block size used by main_calcul. It is the device's
# maximum number of threads per block (CUDA device attribute
# MAX_THREADS_PER_BLOCK), not a core count.
device = cuda.Device(0)
num_cores = device.get_attribute(cuda.device_attribute.MAX_THREADS_PER_BLOCK)
print("Nombre de CPU: ", multiprocessing.cpu_count())
print(f"Nombre de coeurs max GPU: {num_cores}")

gc_file = '/work/gad/shared/pipeline/grch38/index/grch38_essential.fa'
mappability_file = '/work/gad/shared/analyse/test/cnvGPU/test_scalability/wgEncodeCrgMapabilityAlign100mer_no_uniq.grch38.bedgraph'

# The pipeline needs the command-line values, which are only parsed when the
# file is run as a script, so it is guarded the same way.
if __name__ == '__main__':
    # Reference tracks read as globals by main_calcul.
    seq = parse_fasta(gc_file)
    mappability = dico_mappabilite(mappability_file)
    with pysam.AlignmentFile(bamfile_path, "rb") as bamfile_handle:
        for chrom, seq_length in zip(bamfile_handle.references, bamfile_handle.lengths):
            sys.stderr.write("Chromosome : %s, seq length : %s\n" % (chrom, seq_length))

            # Compute and write the per-window averages for this chromosome.
            main_calcul(bamfile_handle, chrom, seq_length, window_size, step_size, output_file)