import pysam
import sys
import getopt
import logging
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit
from pycuda.compiler import SourceModule
import pycuda.gpuarray as gpuarray
from pycuda.autoinit import context
import multiprocessing
from multiprocessing import Process, Queue
import time

# Options


def parse_arguments():
    """
    Parse command-line arguments.

    Recognised options (each takes a value):
        -b  path to the BAM file
        -w  window size (int)
        -s  step size (int)
        -z  Z-score threshold (float)
        -l  minimum variant length for the VCF (int)
        -o  output VCF path
        -e  log file path
    ``-t`` is accepted by the option spec but its value is currently ignored.

    Parameters
    ----------
    None (reads ``sys.argv[1:]``)

    Returns
    -------
    tuple
        (bamfile_path, window_size, step_size, zscore_threshold,
        lengthFilter, output_file, logfile). Each element is None when the
        corresponding option was not provided.

    Exits
    -----
    Prints "Invalid argument" and exits with status 1 on an unknown option,
    a missing option value, or a numeric option whose value cannot be parsed.
    """
    bamfile_path = window_size = step_size = None
    zscore_threshold = lengthFilter = output_file = logfile = None
    try:
        opts, _args = getopt.getopt(sys.argv[1:], "b:w:s:z:l:t:o:e:")
        for opt, arg in opts:
            # Plain equality: `opt in ("-b")` would be a substring test,
            # because ("-b") is a string, not a one-element tuple.
            if opt == "-b":
                bamfile_path = arg
            elif opt == "-w":
                window_size = int(arg)
            elif opt == "-s":
                step_size = int(arg)
            elif opt == "-z":
                zscore_threshold = float(arg)
            elif opt == "-l":
                lengthFilter = int(arg)
            elif opt == "-o":
                output_file = arg
            elif opt == "-e":
                logfile = arg
    except (getopt.GetoptError, ValueError):
        # ValueError covers non-numeric values for -w/-s/-z/-l.
        print("Invalid argument")
        sys.exit(1)
    return (
        bamfile_path,
        window_size,
        step_size,
        zscore_threshold,
        lengthFilter,
        output_file,
        logfile,
    )


if __name__ == "__main__":
    (
        bamfile_path,
        window_size,
        step_size,
        zscore_threshold,
        lengthFilter,
        output_file,
        logfile,
    ) = parse_arguments()
    # Pass logfile through unchanged: when -e is omitted it is None and
    # logging.basicConfig then logs to stderr. Formatting it with "%s" would
    # instead create a file literally named "None".
    logging.basicConfig(
        filename=logfile,
        filemode="a",
        level=logging.INFO,
        format="%(asctime)s %(levelname)s - %(message)s",
    )
    logging.info("start")
    global seq

# CUDA kernels, compiled at import time with PyCUDA.
mod = SourceModule(
    """
// Conventions shared by the windowed kernels below (1-D launch, one thread
// per window index idx, guarded by idx < seq_length - window_size + 1):
//   * window idx covers the half-open array range
//     [idx * step_size + 1, idx * step_size + 1 + window_size);
//   * that range is clipped to seq_length, because for the last windows
//     (and whenever step_size > 1) it can run past the end of the inputs;
//   * positions are computed in 64-bit so idx * step_size cannot overflow int.

__device__ __forceinline__ long long window_first(int idx, int step_size) {
    return (long long)idx * (long long)step_size + 1;
}

__device__ __forceinline__ long long window_stop(long long first, int window_size, int seq_length) {
    long long stop = first + (long long)window_size;
    return (stop < (long long)seq_length) ? stop : (long long)seq_length;
}

// Kernel: raw mean depth per window.

__global__ void calcul_depth_kernel(int *depth_data, int seq_length, int window_size, int step_size, float *depth_results) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;

    if (idx < seq_length - window_size + 1) {
        long long pos_start = window_first(idx, step_size);
        long long pos_end = window_stop(pos_start, window_size, seq_length);
        long long count_reads = 0;

        for (long long i = pos_start; i < pos_end; ++i) {
            count_reads += depth_data[i];
        }

        // Average over the positions actually read (window_size for every
        // window that lies fully inside the sequence).
        long long covered = pos_end - pos_start;
        depth_results[idx] = (covered > 0) ? ((float)count_reads / (float)covered) : 0.0f;
    }
}

// Kernel: GC percentage per window.

__global__ void calcul_gc_kernel(char *gc_data, int seq_length, int window_size, int step_size, float *gc_results) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;

    if (idx < seq_length - window_size + 1) {
        long long pos_start = window_first(idx, step_size);
        long long pos_end = window_stop(pos_start, window_size, seq_length);
        int gc_count = 0;

        // Half-open range: exactly the window's bases are counted, so the
        // percentage stays within [0, 100]. It is later truncated to int and
        // used as an index into gc_to_median.
        for (long long i = pos_start; i < pos_end; ++i) {
            char base = gc_data[i];
            if (base == 'G' || base == 'C' || base == 'g' || base == 'c') {
                gc_count++;
            }
        }

        long long covered = pos_end - pos_start;
        gc_results[idx] = (covered > 0) ? (((float)gc_count / (float)covered) * 100.0f) : 0.0f;
    }
}

// Kernel: mean mappability per window (every position has weight 1).

__global__ void calcul_map_kernel(float *map_data, int seq_length, int window_size, int step_size, float *map_results) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;

    if (idx < seq_length - window_size + 1) {
        long long pos_start = window_first(idx, step_size);
        long long pos_end = window_stop(pos_start, window_size, seq_length);
        float weighted_sum = 0.0f;
        float total_weight = 0.0f;

        for (long long i = pos_start; i < pos_end; ++i) {
            weighted_sum += map_data[i];
            total_weight += 1.0f;
        }

        // total_weight is 0 only when the window starts past the sequence end.
        map_results[idx] = (total_weight > 0.0f) ? (weighted_sum / total_weight) : 0.0f;
    }
}


// Kernel: depth corrected by mappability (depth / mappability).

__global__ void calcul_depth_correction_kernel(float *depth_results, float *map_results, int seq_length, int window_size, int step_size, float *depth_correction_results) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;

    if (idx < seq_length - window_size + 1) {
        float avg_depth = depth_results[idx];
        float avg_map = map_results[idx];

        // Avoid dividing by a zero mappability.
        float depth_correction = (avg_map != 0.0f) ? (avg_depth / avg_map) : 0.0f;
        depth_correction_results[idx] = depth_correction;
    }
}


// Kernel: GC normalisation of the corrected depth (depth * m / median-for-this-GC).

__global__ void normalize_depth_kernel(float *depth_correction, float *gc_results, float m, float *gc_to_median, int seq_length, int window_size, int step_size, float *depth_normalize) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;

    if (idx < seq_length - window_size + 1) {
        // gc_to_median is indexed by the integer GC percentage (0..100);
        // NOTE(review): assumes the host buffer has at least 101 entries.
        float mGC = gc_to_median[(int)gc_results[idx]];

        // Avoid dividing by a zero median.
        float depth_normalize_val = (mGC != 0.0f) ? (depth_correction[idx] * m / mGC) : 0.0f;
        depth_normalize[idx] = depth_normalize_val;
    }
}

// Kernel: per-window ratio to the chromosome mean.

__global__ void ratio_par_window_kernel(float *depth_normalize_val, float mean_chr, int seq_length, int window_size, int step_size, float *ratio_par_window_results) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;

    if (idx < seq_length - window_size + 1) {
        // Same zero-divisor policy as the correction/normalisation kernels.
        float ratio = (mean_chr != 0.0f) ? (depth_normalize_val[idx] / mean_chr) : 0.0f;
        ratio_par_window_results[idx] = ratio;
    }
}

// Kernel: per-window z-score of the ratio.

__global__ void z_score_kernel(float *ratio, float mean_ratio, float std_ratio, int seq_length, int window_size, int step_size, float *z_score_results) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;

    if (idx < seq_length - window_size + 1) {
        // A zero standard deviation yields z = 0 instead of inf/NaN.
        float z_score = (std_ratio != 0.0f) ? ((ratio[idx] - mean_ratio) / std_ratio) : 0.0f;
        z_score_results[idx] = z_score;
    }
}

// Kernel: ratio divided by the mean ratio.

__global__ void ratio_par_mean_ratio_kernel(float *ratio, float mean_ratio, int seq_length, int window_size, int step_size, float *ratio_par_mean_ratio_results) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;

    if (idx < seq_length - window_size + 1) {
        float ratio_mean_ratio = (mean_ratio != 0.0f) ? (ratio[idx] / mean_ratio) : 0.0f;
        ratio_par_mean_ratio_results[idx] = ratio_mean_ratio;
    }
}

"""
)

# Look up the compiled kernel functions.
calcul_depth_kernel_cuda = mod.get_function("calcul_depth_kernel")
calcul_gc_kernel_cuda = mod.get_function("calcul_gc_kernel")
calcul_map_kernel_cuda = mod.get_function("calcul_map_kernel")
calcul_depth_correction_kernel_cuda = mod.get_function("calcul_depth_correction_kernel")
normalize_depth_kernel_cuda = mod.get_function("normalize_depth_kernel")
ratio_par_window_kernel_cuda = mod.get_function("ratio_par_window_kernel")
z_score_kernel_cuda = mod.get_function("z_score_kernel")
ratio_par_mean_ratio_kernel_cuda = mod.get_function("ratio_par_mean_ratio_kernel")


def merge_intervals(intervals):
    """
    Merge runs of consecutive intervals that share the same score.

    Adjacent entries (in list order) with an equal score are collapsed into one
    interval spanning from the first entry's start to the last entry's end.
    Gaps or overlaps between them are not inspected.

    Parameters
    ----------
    intervals : list of tuples
        A list of (start, end, score) tuples, expected to be sorted by start.

    Returns
    -------
    list of tuples
        The merged (start, end, score) tuples; an empty list for empty input.
    """
    merged = []
    for start, end, score in intervals:
        if merged and merged[-1][2] == score:
            # Same score as the current run: extend its end.
            merged[-1] = (merged[-1][0], end, score)
        else:
            merged.append((start, end, score))
    return merged


def dico_mappabilite(mappability_file):
    """
    Create a dictionary of mappability scores from a file.

    Reads a tab-separated mappability file (lines with fewer than 4 fields are
    skipped) and builds, per chromosome, a mapping from run start to score.

    Parameters
    ----------
    mappability_file : str
        Path to the mappability file. Each line: chromosome, start_pos, end_pos,
        score (tab-separated).

    Returns
    -------
    dict
        {chromosome: {start_position: score}}. Intervals are sorted by start.
        If the first interval does not start at 0, a (0 -> 0) entry is prepended.
        Consecutive intervals with equal scores are merged into one entry.
    """
    sys.stderr.write("\t Entering dico_mappabilite\n")
    mappability_dico = {}

    with open(mappability_file, "r") as f:
        for line in f:
            fields = line.strip().split("\t")
            if len(fields) < 4:
                continue

            chromosome = fields[0]
            start_pos = int(fields[1])
            end_pos = int(fields[2])
            score = float(fields[3])

            mappability_dico.setdefault(chromosome, []).append(
                (start_pos, end_pos, score)
            )

    for chromosome, intervals in mappability_dico.items():
        # Sort so the padding check and the run merge see genomic order even
        # when the input file is not position-sorted.
        intervals.sort()
        # Add position 0 for each chromosome (score 0 up to the first interval).
        if intervals[0][0] != 0:
            intervals.insert(0, (0, intervals[0][0], 0))
        # Merge intervals with the same score; replacing the value of an
        # existing key is safe while iterating over items().
        mappability_dico[chromosome] = {
            start: score for start, _, score in merge_intervals(intervals)
        }

    sys.stderr.write("\t Leaving dico_mappabilite\n")
    return mappability_dico


def calcul_mappability(seq_length, mappability, chr):
    """
    Calculate a per-base mappability array for one chromosome.

    Each start position in ``mappability[chr]`` opens a run that lasts until
    the next start position. Every base in the run receives that start's score.
    Bases before the first start get 0. The last score extends to
    ``seq_length``.

    Parameters
    ----------
    seq_length : int
        The length of the sequence (size of the returned array).
    mappability : dict
        {chromosome: {start_position: score}}, as built by dico_mappabilite.
    chr : str
        The chromosome for which the mappability is calculated.

    Returns
    -------
    numpy.ndarray
        float32 array of length ``seq_length``.

    Raises
    ------
    KeyError
        If ``chr`` is not present in ``mappability``.
    """
    sys.stderr.write("\t Entering calcul_mappability \n")
    map_data = np.zeros(seq_length, dtype=np.float32)
    scores = mappability[chr]

    prev_bound = 0
    prev_mappability = 0

    # Fill whole runs with slice assignment instead of a per-base Python loop.
    for bound in sorted(scores):
        lo = max(prev_bound, 0)  # never let a negative bound wrap the slice
        hi = min(seq_length, bound)
        if hi > lo:
            map_data[lo:hi] = prev_mappability
        prev_bound = bound
        prev_mappability = scores[bound]

    # Fill remaining positions if sequence length exceeds last bound
    lo = max(prev_bound, 0)
    if seq_length > lo:
        map_data[lo:seq_length] = prev_mappability

    sys.stderr.write("\t Leaving calcul_mappability \n")
    return map_data


def parse_fasta(gc_file):
    """
    Parse a FASTA file and extract sequences.

    The file is streamed line by line, and each record's sequence lines are
    joined once, so the raw file text is never held in memory as a whole.
    Text before the first '>' header is ignored.

    Parameters
    ----------
    gc_file : str
        The path to the FASTA file.

    Returns
    -------
    dict
        {name: sequence}. ``name`` is the first whitespace-delimited token of
        the header line (e.g. "chr1" for ">chr1 description"), matching how
        contigs are usually named in BAM headers. Surrounding whitespace is
        stripped from sequence lines. Later duplicate names overwrite earlier
        ones.
    """
    sys.stderr.write("\t Entering parse_fasta\n")
    sequences = {}
    name = None
    chunks = []
    with open(gc_file, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith(">"):
                if name is not None:
                    sequences[name] = "".join(chunks)
                header = line[1:].strip()
                tokens = header.split()
                name = tokens[0] if tokens else header
                chunks = []
            elif name is not None:
                chunks.append(line)
    if name is not None:
        sequences[name] = "".join(chunks)

    sys.stderr.write("\t Leaving parse_fasta\n")
    return sequences


def calcul_gc_content(seq_length, chr, seq):
    """
    Build a per-base byte array of a chromosome's sequence (input to the GC kernel).

    Parameters
    ----------
    seq_length : int
        Length of the returned array.
    chr : str
        The chromosome whose sequence is copied.
    seq : dict
        {chromosome: sequence string}, as returned by parse_fasta.

    Returns
    -------
    numpy.ndarray
        Array of dtype 'S1' and length ``seq_length``. Positions beyond the end
        of the sequence are b''. A sequence longer than ``seq_length`` is
        truncated to fit.

    Raises
    ------
    KeyError
        If ``chr`` is not in ``seq``.
    UnicodeEncodeError
        If the sequence contains non-ASCII characters.
    """
    sys.stderr.write("\t Entering calcul_gc_content\n")
    gc_data = np.zeros(seq_length, dtype="S1")
    # Copy the whole sequence at once instead of one character per Python
    # iteration; truncate so a longer sequence cannot overflow the array.
    raw = seq[chr][:seq_length].encode("ascii")
    if raw:
        gc_data[: len(raw)] = np.frombuffer(raw, dtype="S1")

    sys.stderr.write("\t Leaving calcul_gc_content\n")
    return gc_data


def calcul_depth_seq(seq_length, bamfile_path, chr):
    """
    Calculate the per-base sequencing depth for a given chromosome.

    Parameters
    ----------
    seq_length : int
        The length of the sequence (size of the returned array).
    bamfile_path : pysam.AlignmentFile
        An opened alignment file. Despite the name, this is used as an
        AlignmentFile object (its ``pileup`` method is called).
    chr : str
        The contig whose depth is computed.

    Returns
    -------
    numpy.ndarray
        int32 array of length ``seq_length``. Entry ``i`` holds ``nsegments`` of
        the pileup column at 0-based position ``i``, or 0 when there is no column.
    """
    sys.stderr.write("\t Entering calcul_depth_seq\n")
    depth_data = np.zeros(seq_length, dtype=np.int32)
    # Restrict the pileup to this contig. Without a region, pileup() walks the
    # whole file from its first reference, whatever `chr` is.
    for pileupcolumn in bamfile_path.pileup(chr, 0, seq_length, truncate=True):
        if pileupcolumn.reference_pos >= seq_length:
            break
        depth_data[pileupcolumn.reference_pos] = pileupcolumn.nsegments

    sys.stderr.write("\t Leaving calcul_depth_seq\n")
    return depth_data


def calcul_med_total(depth_correction_results):
    """
    Median of the non-zero corrected depth values.

    Parameters
    ----------
    depth_correction_results : list or numpy.ndarray
        Corrected depth values. Zeros are excluded before taking the median.

    Returns
    -------
    float
        The median of the non-zero values, or 0 when none remain.
    """
    sys.stderr.write("\t entering calcul_med_total\n")
    values = np.array(depth_correction_results).ravel()
    kept = np.extract(values != 0, values)
    if kept.size == 0:
        m = 0
    else:
        m = np.median(kept)
    sys.stderr.write("\t Leaving calcul_med_total\n")
    return m


def calcul_med_same_gc(gc_results, depth_correction_results):
    """
    Calculate the median corrected depth for each distinct GC content value.

    Parameters
    ----------
    gc_results : list or numpy.ndarray
        GC content value per window.
    depth_correction_results : list or numpy.ndarray
        Corrected depth per window (same length as ``gc_results``).

    Returns
    -------
    dict
        {gc_value: median of the non-zero corrected depths of windows with that
        GC value}, or 0 for a GC value whose depths are all zero.
    """
    sys.stderr.write("\t entering calcul_med_same_gc\n")
    # Convert both inputs: with a plain list, `gc_results == gc` would be a
    # single Python bool rather than an element-wise mask.
    gc_array = np.asarray(gc_results)
    depth_array = np.asarray(depth_correction_results)

    gc_to_median = {}
    for gc in np.unique(gc_array):
        depths = depth_array[gc_array == gc]
        depths = depths[depths != 0]  # ignore zero depths
        gc_to_median[gc] = np.median(depths) if depths.size > 0 else 0

    sys.stderr.write("\t Leaving calcul_med_same_gc\n")
    return gc_to_median


def calcul_moy_totale(normalize_depth_results):
    """
    Mean of the non-zero normalized depth values (the value is also printed).

    Parameters
    ----------
    normalize_depth_results : list or numpy.ndarray
        Normalized depth values; zeros are ignored.

    Returns
    -------
    float
        The mean of the non-zero values, or 0 when none remain.
    """
    sys.stderr.write("\t entering calcul_moy_totale\n")
    flat = np.array(normalize_depth_results).ravel()
    kept = flat[flat != 0]
    mean_chr = np.mean(kept) if kept.size else 0
    print(mean_chr)
    sys.stderr.write("\t Leaving calcul_moy_totale\n")
    return mean_chr


def calcul_moy_totale_ratio(ratio_par_window_results):
    """
    Mean of the non-zero per-window ratios (the value is also printed).

    Parameters
    ----------
    ratio_par_window_results : list or numpy.ndarray
        Per-window ratio values; zeros are ignored.

    Returns
    -------
    float
        The mean of the non-zero ratios, or 0 when none remain.
    """
    sys.stderr.write("\t entering calcul_moy_totale_ratio\n")
    ratios = np.array(ratio_par_window_results).ravel()
    kept = np.extract(ratios != 0, ratios)
    if kept.size:
        mean_ratio = np.mean(kept)
    else:
        mean_ratio = 0
    print(mean_ratio)
    sys.stderr.write("\t Leaving calcul_moy_totale_ratio\n")
    return mean_ratio


def calcul_std(normalize_depth_results):
    """
    Population standard deviation of the non-zero normalized depth values
    (the value is also printed).

    Parameters
    ----------
    normalize_depth_results : list or numpy.ndarray
        Normalized depth values; zeros are ignored.

    Returns
    -------
    float
        np.std of the non-zero values, or 0 when none remain.
    """
    sys.stderr.write("\t entering calcul_std\n")
    flat = np.array(normalize_depth_results).ravel()
    nonzero_mask = flat != 0
    if not nonzero_mask.any():
        std_chr = 0
    else:
        std_chr = np.std(flat[nonzero_mask])
    print(std_chr)
    sys.stderr.write("\t Leaving calcul_std\n")
    return std_chr


def calcul_std_ratio(ratio_par_window_results):
    """
    Population standard deviation of the non-zero per-window ratios
    (the value is also printed).

    Parameters
    ----------
    ratio_par_window_results : list or numpy.ndarray
        Per-window ratio values; zeros are ignored.

    Returns
    -------
    float
        np.std of the non-zero ratios, or 0 when none remain.
    """
    sys.stderr.write("\t entering calcul_std_ratio\n")
    ratios = np.array(ratio_par_window_results).ravel()
    kept = ratios[ratios != 0]
    std_ratio = 0 if kept.size == 0 else np.std(kept)
    print(std_ratio)
    sys.stderr.write("\t Leaving calcul_std_ratio\n")
    return std_ratio


def compute_mean_and_std(ratio_par_window_results):
    """
    Summary statistics of the per-window ratios.

    Zeros and -1 values are dropped, and values >= 5 are capped at 5. The mean,
    population standard deviation and median of what remains are computed and
    printed.

    Parameters
    ----------
    ratio_par_window_results : list or numpy.ndarray
        Per-window ratio values.

    Returns
    -------
    tuple
        (mean_ratio, std_ratio, med_ratio), or (0, 0, 0) when no value remains.
    """
    sys.stderr.write("Computing stats : \n")

    ratios = np.array(ratio_par_window_results)
    # Drop zeros, then work in float64 as the per-value float() conversion would.
    ratios = ratios[ratios != 0].astype(np.float64)
    # Drop the -1 sentinel and cap large ratios at 5 (NaN propagates unchanged).
    capped = np.minimum(ratios[ratios != -1], 5)

    if capped.size:
        mean_ratio = np.mean(capped)
        std_ratio = np.std(capped)
        med_ratio = np.median(capped)
    else:
        mean_ratio = std_ratio = med_ratio = 0

    print(mean_ratio, std_ratio, med_ratio)
    sys.stderr.write("Computation done\n")

    return mean_ratio, std_ratio, med_ratio


def cn_level(x):
    """
    Map a ratio value onto a copy-number level.

    Parameters
    ----------
    x : float
        Ratio of a window to the mean ratio.

    Returns
    -------
    int
        - 0 if x < 0.1
        - 1 if 0.1 <= x <= 0.75
        - 2 if 0.75 < x < 1, or if x >= 1 and round(x) == 1
        - round(x) if round(x) > 1
        For NaN or infinite ``x``, the call to ``round`` raises
        (ValueError / OverflowError).
    """
    # `not x < 1` (rather than `x >= 1`) keeps NaN on the round() path.
    if not x < 1:
        nearest = round(x)
        return 2 if nearest == 1 else nearest
    if x < 0.1:
        return 0
    return 1 if x <= 0.75 else 2


def get_sample_name(bamfile_path):
    """
    Return the sample name (SM tag) of the first read group that has one.

    Parameters
    ----------
    bamfile_path : str
        Path to the BAM file.

    Returns
    -------
    str
        The first ``SM`` value found among the header's ``RG`` entries, or
        "UnknownSample" when none has one.
    """
    with pysam.AlignmentFile(bamfile_path, "rb") as bamfile:
        sample = next(
            (group["SM"] for group in bamfile.header.get("RG", []) if "SM" in group),
            "UnknownSample",
        )
    return sample


def create_signal(signal, chr, z_score_results, step_size):
    """
    Record each window's z-score under its 1-based start position.

    Parameters
    ----------
    signal : dict
        Mutated in place: ``signal[chr][i * step_size + 1] = z_score_results[i]``.
    chr : str
        Chromosome key.
    z_score_results : list or numpy.ndarray
        One z-score per window.
    step_size : int
        Distance between consecutive window starts.

    Returns
    -------
    None
    """
    per_chromosome = signal.setdefault(chr, {})
    for window_index, value in enumerate(z_score_results):
        per_chromosome[window_index * step_size + 1] = value


def detect_events(
    z_score_results,
    zscore_threshold,
    events,
    med_ratio,
    ratio_par_mean_ratio_results,
    chr,
    step_size=None,
    window_size=None,
):
    """
    Detect windows whose z-score reaches the threshold and record their level.

    A window is an event when ``z <= -zscore_threshold`` or
    ``z >= zscore_threshold``. Its level is computed as follows:
    - start from 0 if ``med_ratio == 0``, else from ``cn_level(ratio)``;
    - override with 3 when ``z >= zscore_threshold``;
    - otherwise change 2 to 1 when ``z <= -zscore_threshold``.

    Parameters
    ----------
    z_score_results : list or numpy.ndarray
        One z-score per window.
    zscore_threshold : float
        The threshold for detecting significant z-score events.
    events : dict
        Mutated in place: ``events[chr][(start, start + window_size)] = level``
        with ``start = i * step_size + 1``.
    med_ratio : float
        Median ratio; when 0, ``cn_level`` is not consulted.
    ratio_par_mean_ratio_results : list or numpy.ndarray
        Per-window ratio divided by the mean ratio.
    chr : str
        The chromosome for which events are detected.
    step_size, window_size : int, optional
        Window geometry. When omitted, the module-level ``step_size`` /
        ``window_size`` (set by the script's __main__ block) are used.

    Returns
    -------
    None
    """
    sys.stderr.write("\t starting detect_events\n")
    if step_size is None:
        step_size = globals()["step_size"]
    if window_size is None:
        window_size = globals()["window_size"]

    for i, z_score in enumerate(z_score_results):
        # Written as in the docstring so NaN z-scores are skipped.
        if not ((z_score <= -zscore_threshold) or (z_score >= zscore_threshold)):
            continue
        chr_events = events.setdefault(chr, {})
        if med_ratio == 0:
            c = 0
        else:
            c = cn_level(float(ratio_par_mean_ratio_results[i]))

        if z_score >= zscore_threshold:
            c = 3
        elif c == 2 and z_score <= -zscore_threshold:
            c = 1
        pos_start = (i * step_size) + 1
        pos_end = pos_start + window_size

        chr_events[(pos_start, pos_end)] = c
    sys.stderr.write("\t ending detect_events\n")


def segmentation(events, segment, window_size=None):
    """
    Group detected events into contiguous segments with the same level.

    Events of each chromosome are walked in sorted order. A segment is closed
    in three cases:
    - the next event starts more than ``window_size`` after the previous one
      (a gap); the stored end is the previous start + window_size;
    - the level changes; the stored end is the previous start;
    - the chromosome's events are exhausted; the stored end is the last
      start + window_size.

    Parameters
    ----------
    events : dict
        {chromosome: {(start, end): level}}, as filled by detect_events.
    segment : dict
        Mutated in place:
        ``segment[chr]["<start>-<last start>"] = {"start", "end", "cn"}``.
    window_size : int, optional
        Window size. When omitted, the module-level ``window_size`` (set by
        the script's __main__ block) is used.

    Returns
    -------
    None
    """
    sys.stderr.write("starting segmentation : \n")
    if window_size is None:
        window_size = globals()["window_size"]
    for k in events.keys():
        sys.stderr.write("\tfor chromosome %s\n" % k)
        starts = 0
        oldPos = 0
        oldLevel = -1
        for p in sorted(events[k].keys()):
            level = events[k][p]
            # new coordinates
            if p[0] > (oldPos + window_size):
                if (starts != 0) and (starts != p[0]):
                    if k not in segment:
                        segment[k] = {}
                    index = str(starts) + "-" + str(oldPos)
                    segment[k][index] = {}
                    segment[k][index]["start"] = starts
                    segment[k][index]["end"] = oldPos + window_size
                    segment[k][index]["cn"] = oldLevel
                    oldPos = p[0]
                    starts = p[0]
                    oldLevel = level
                    continue
                else:
                    starts = p[0]
            # case where it's contiguous but different level
            if level != oldLevel:
                if oldLevel != -1:
                    if k not in segment:
                        segment[k] = {}
                    index = str(starts) + "-" + str(oldPos)
                    segment[k][index] = {}
                    segment[k][index]["start"] = starts
                    segment[k][index]["end"] = oldPos
                    segment[k][index]["cn"] = oldLevel
                    oldPos = p[0]
                    starts = p[0]
                    oldLevel = level
                    continue
                else:
                    # First event of the chromosome: the first segment opens
                    # here, even when it lies within window_size of position 0.
                    starts = p[0]
                    oldLevel = level
            oldPos = p[0]
            oldLevel = level
        # Flush the segment that is still open after the last event.
        if oldLevel != -1:
            if k not in segment:
                segment[k] = {}
            index = str(starts) + "-" + str(oldPos)
            segment[k][index] = {}
            segment[k][index]["start"] = starts
            segment[k][index]["end"] = oldPos + window_size
            segment[k][index]["cn"] = oldLevel
    sys.stderr.write("segmentation done\n")


def display_results_vcf(sample, segment, signal, lengthFilter, output_file):
    """
    Append segmented CNV calls to a VCF file.

    The header is written once per process; the module-level ``header_written``
    flag tracks this, and a missing flag counts as False. For each
    chromosome, segments are written in order of their numeric start. A
    segment is skipped when ``start == end`` or its length is below
    ``lengthFilter``. It is written as <DEL> when the signal (z-score) at its
    start is negative, otherwise as <DUP>.

    Parameters
    ----------
    sample : str
        The sample name used as the genotype column header.
    segment : dict
        {chromosome: {key: {"start", "end", "cn"}}}, as filled by segmentation.
    signal : dict
        {chromosome: {position: z-score}}, as filled by create_signal.
    lengthFilter : int or None
        Minimum segment length (end - start); None disables the filter.
    output_file : str
        Path of the VCF file (opened in append mode).

    Returns
    -------
    None
    """
    global header_written
    sys.stderr.write("starting display results\n")
    min_length = lengthFilter if lengthFilter is not None else 0

    with open(output_file, "a") as f:
        if not globals().get("header_written", False):
            f.write("##fileformat=VCFv4.2\n")
            f.write("##source=cnvcaller\n")
            f.write(
                '##INFO=<ID=SVTYPE,Number=1,Type=String,Description="Type of structural variant">\n'
            )
            f.write(
                '##INFO=<ID=SVLEN,Number=.,Type=Integer,Description="Difference in length between REF and ALT alleles">\n'
            )
            f.write(
                '##INFO=<ID=END,Number=1,Type=Integer,Description="End position of the variant described in this record">\n'
            )
            f.write('##INFO=<ID=CN,Number=1,Type=Integer,Description="Copy number">\n')
            # VALUE and GQ are emitted in every record, so they must be declared.
            f.write(
                '##INFO=<ID=VALUE,Number=1,Type=Integer,Description="Copy number level of the segment">\n'
            )
            f.write('##ALT=<ID=DEL,Description="Deletion">\n')
            f.write('##ALT=<ID=DUP,Description="Duplication">\n')
            f.write('##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n')
            f.write(
                '##FORMAT=<ID=GQ,Number=1,Type=Integer,Description="Genotype quality">\n'
            )
            f.write(
                "#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t%s\n" % (sample)
            )
            header_written = True

        for k in segment.keys():
            # Keys are "start-pos" strings whose lexicographic order is not
            # genomic order; sort on the numeric coordinates instead.
            ordered = sorted(
                segment[k].values(), key=lambda seg: (seg["start"], seg["end"])
            )
            for seg in ordered:
                start, end = seg["start"], seg["end"]
                if start == end:
                    continue
                if (end - start) < min_length:
                    continue
                # Compare the z-score itself: int() would truncate e.g. -0.5 to 0.
                svtype = "DEL" if float(signal[k][start]) < 0 else "DUP"
                f.write(
                    "%s\t%s\t.\tN\t<%s>\t.\t.\tSVTYPE=%s;END=%s;VALUE=%s\tGT:GQ\t./.:0\n"
                    % (k, start, svtype, svtype, end, int(seg["cn"]))
                )


#################################
###### <---Fonction main--->######
#################################
def main_calcul(
    bamfile_path,
    chr,
    seq_length,
    window_size,
    step_size,
    zscore_threshold,
    lengthFilter,
    output_file,
    sample,
):
    """
    Perform structural variant detection for one chromosome and write its VCF records.

    Per-base mappability, GC and depth tracks are computed on the host, copied
    to the GPU and reduced to per-window values by a chain of CUDA kernels
    (depth, GC, mappability, mappability-corrected depth, GC normalisation,
    ratio to the chromosome mean, ratio to the mean ratio, Z-score). The
    Z-scores are then turned into events, segmented, and the segments are
    written by ``display_results_vcf``.

    Parameters
    ----------
        bamfile_path : str or pysam.AlignmentFile
            BAM source forwarded unchanged to ``calcul_depth_seq``. Note that
            this script's own driver passes the open ``pysam.AlignmentFile``
            handle rather than a path.
        chr : str
            Chromosome (reference) name being analysed.
        seq_length : int
            Length of the chromosome sequence.
        window_size : int
            Size of the sliding window used for analysis.
        step_size : int
            Size of the step when moving the window along the chromosome.
        zscore_threshold : float
            Z-score threshold passed to ``detect_events``.
        lengthFilter : int
            Minimum segment length kept in the VCF output.
        output_file : str
            Path to the output VCF file.
        sample : str
            Name of the sample being analyzed.

    Returns
    -------
        None
            A chromosome too short to contain any window
            (``int((seq_length - window_size) / step_size) + 1 <= 0``) is
            skipped with a message on stderr instead of aborting the run.

    Notes
    -----
        Reads the module-level globals ``mappability``, ``seq`` and
        ``num_cores`` (used as the CUDA block size), plus the compiled
        ``*_kernel_cuda`` kernel functions.
    """

    sys.stderr.write("\t entering main_calcul\n")
    events = {}
    segment = {}
    signal = {}

    # Number of windows every per-window array / kernel launch is sized for.
    n_windows = _window_count(seq_length, window_size, step_size)
    if n_windows <= 0:
        # np.zeros(<negative>) or cuda.mem_alloc(0) would otherwise abort the
        # whole multi-chromosome run on a single short contig.
        sys.stderr.write(
            "\t skipping %s: seq length %s too short for window size %s\n"
            % (chr, seq_length, window_size)
        )
        return

    # Compute the per-base input tracks on the host.
    map_data = calcul_mappability(seq_length, mappability, chr)
    gc_data = calcul_gc_content(seq_length, chr, seq)
    depth_data = calcul_depth_seq(seq_length, bamfile_path, chr)

    # Transfer the NumPy input tracks to the GPU.
    d_depth_data = _to_device(depth_data, "depth_data")
    d_gc_data = _to_device(gc_data, "gc_data")
    d_map_data = _to_device(map_data, "map_data")

    # Launch configuration: block_size * grid_size threads >= n_windows
    # (presumably one thread per window; the kernels are defined elsewhere).
    block_size = num_cores
    sys.stderr.write("\t blocksize (nb de threads) = %s\n" % (num_cores))
    grid_size = int(n_windows / block_size) + 1
    sys.stderr.write("\t grid_size = %s\n" % grid_size)

    # Host-side result buffers, one float32 per window.
    depth_results = np.zeros(n_windows, dtype=np.float32)
    sys.stderr.write("\t Definition de depth_results\n")

    gc_results = np.zeros(n_windows, dtype=np.float32)
    sys.stderr.write("\t Definition de gc_results\n")

    map_results = np.zeros(n_windows, dtype=np.float32)
    sys.stderr.write("\t Definition de map_results\n")

    depth_correction_results = np.zeros(n_windows, dtype=np.float32)
    sys.stderr.write("\t Definition de depth_correction_results\n")

    normalize_depth_results = np.zeros(n_windows, dtype=np.float32)
    sys.stderr.write("\t Definition de normalize_depth_results\n")

    ratio_par_window_results = np.zeros(n_windows, dtype=np.float32)
    sys.stderr.write("\t Definition de ratio_par_window\n")

    z_score_results = np.zeros(n_windows, dtype=np.float32)
    sys.stderr.write("\t Definition de z_score_results\n")

    ratio_par_mean_ratio_results = np.zeros(n_windows, dtype=np.float32)
    sys.stderr.write("\t Definition de ratio_par_mean_ratio_results\n")

    # Matching (uninitialised) device-side result buffers.
    d_depth_results = _alloc_device_like(depth_results, "depth_results")
    d_gc_results = _alloc_device_like(gc_results, "gc_results")
    d_map_results = _alloc_device_like(map_results, "map_results")
    d_depth_correction_results = _alloc_device_like(
        depth_correction_results, "depth_correction_results"
    )
    d_normalize_depth_results = _alloc_device_like(
        normalize_depth_results, "normalize_depth_results"
    )
    d_ratio_par_window_results = _alloc_device_like(
        ratio_par_window_results, "ratio_par_window_results"
    )
    d_z_score_results = _alloc_device_like(z_score_results, "z_score_results")
    d_ratio_par_mean_ratio_results = _alloc_device_like(
        ratio_par_mean_ratio_results, "ratio_par_mean_ratio_results"
    )

    # Per-window depth, GC and mappability, then mappability-corrected depth.
    calcul_depth_kernel_cuda(
        d_depth_data,
        np.int32(seq_length),
        np.int32(window_size),
        np.int32(step_size),
        d_depth_results,
        block=(block_size, 1, 1),
        grid=(grid_size, 1),
    )
    sys.stderr.write("\t appel fonction calcul_depth_kernel_cuda\n")

    calcul_gc_kernel_cuda(
        d_gc_data,
        np.int32(seq_length),
        np.int32(window_size),
        np.int32(step_size),
        d_gc_results,
        block=(block_size, 1, 1),
        grid=(grid_size, 1),
    )
    sys.stderr.write("\t appel fonction calcul_gc_kernel_cuda\n")

    calcul_map_kernel_cuda(
        d_map_data,
        np.int32(seq_length),
        np.int32(window_size),
        np.int32(step_size),
        d_map_results,
        block=(block_size, 1, 1),
        grid=(grid_size, 1),
    )
    sys.stderr.write("\t appel fonction calcul_map_kernel_cuda\n")

    calcul_depth_correction_kernel_cuda(
        d_depth_results,
        d_map_results,
        np.int32(seq_length),
        np.int32(window_size),
        np.int32(step_size),
        d_depth_correction_results,
        block=(block_size, 1, 1),
        grid=(grid_size, 1),
    )
    sys.stderr.write("\t appel fonction calcul_depth_correction_kernel_cuda\n")

    context.synchronize()

    # Copy the results back from the GPU to the host (memcpy_dtoh(dest, src)).
    cuda.memcpy_dtoh(depth_results, d_depth_results)
    sys.stderr.write(
        "\t Copie les resultats du GPU (d_depth_results) vers le CPU (depth_results)\n"
    )

    cuda.memcpy_dtoh(gc_results, d_gc_results)
    sys.stderr.write(
        "\t Copie les resultats du GPU (d_gc_results) vers le CPU (gc_results)\n"
    )

    cuda.memcpy_dtoh(map_results, d_map_results)
    sys.stderr.write(
        "\t Copie les resultats du GPU (d_map_results) vers le CPU (map_results)\n"
    )

    cuda.memcpy_dtoh(depth_correction_results, d_depth_correction_results)
    sys.stderr.write(
        "\t Copie les resultats du GPU (d_depth_correction_results) vers le CPU (depth_correction_results)\n"
    )

    ### NORMALISATION ###

    # Global median and per-GC-value medians of the corrected depth.
    sys.stderr.write("\t appel fonctions calcul medianes\n")
    m = calcul_med_total(depth_correction_results)
    gc_to_median = calcul_med_same_gc(gc_results, depth_correction_results)

    # Dense lookup table indexed by the integer GC value of a window, so the
    # medians can be shipped to the device as a plain float32 array.
    sys.stderr.write("\t Conversion medianes en tableau numpy\n")
    gc_to_median_array = np.zeros(int(max(gc_results)) + 1, dtype=np.float32)
    for gc, median in gc_to_median.items():
        gc_to_median_array[int(gc)] = median

    sys.stderr.write("\t Allocation mémoire médianes GPU\n")
    d_gc_to_median = cuda.mem_alloc(gc_to_median_array.nbytes)
    cuda.memcpy_htod(d_gc_to_median, gc_to_median_array)

    normalize_depth_kernel_cuda(
        d_depth_correction_results,
        d_gc_results,
        np.float32(m),
        d_gc_to_median,
        np.int32(seq_length),
        np.int32(window_size),
        np.int32(step_size),
        d_normalize_depth_results,
        block=(block_size, 1, 1),
        grid=(grid_size, 1),
    )
    sys.stderr.write("\t appel fonction normalize_depth_kernel_cuda\n")

    context.synchronize()

    cuda.memcpy_dtoh(normalize_depth_results, d_normalize_depth_results)

    ### Ratio per window ###

    sys.stderr.write("\t appel fonction calcul moyenne\n")
    mean_chr = calcul_moy_totale(normalize_depth_results)

    ratio_par_window_kernel_cuda(
        d_normalize_depth_results,
        np.float32(mean_chr),
        np.int32(seq_length),
        np.int32(window_size),
        np.int32(step_size),
        d_ratio_par_window_results,
        block=(block_size, 1, 1),
        grid=(grid_size, 1),
    )
    sys.stderr.write("\t appel fonction ratio_par_window_kernel_cuda\n")

    context.synchronize()

    cuda.memcpy_dtoh(ratio_par_window_results, d_ratio_par_window_results)

    # Summary statistics of the per-window ratio.
    mean_ratio, std_ratio, med_ratio = compute_mean_and_std(ratio_par_window_results)

    # Ratio divided by the mean ratio.
    ratio_par_mean_ratio_kernel_cuda(
        d_ratio_par_window_results,
        np.float32(mean_ratio),
        np.int32(seq_length),
        np.int32(window_size),
        np.int32(step_size),
        d_ratio_par_mean_ratio_results,
        block=(block_size, 1, 1),
        grid=(grid_size, 1),
    )
    sys.stderr.write("\t appel fonction ratio_par_mean_ratio_kernel_cuda\n")

    # Z-score of each window's ratio.
    z_score_kernel_cuda(
        d_ratio_par_window_results,
        np.float32(mean_ratio),
        np.float32(std_ratio),
        np.int32(seq_length),
        np.int32(window_size),
        np.int32(step_size),
        d_z_score_results,
        block=(block_size, 1, 1),
        grid=(grid_size, 1),
    )
    sys.stderr.write("\t appel fonction z_score_kernel_cuda\n")

    context.synchronize()

    cuda.memcpy_dtoh(ratio_par_mean_ratio_results, d_ratio_par_mean_ratio_results)
    cuda.memcpy_dtoh(z_score_results, d_z_score_results)

    # Post-processing on the host: signal, events, segments, VCF output.
    create_signal(signal, chr, z_score_results, step_size)

    detect_events(
        z_score_results,
        zscore_threshold,
        events,
        med_ratio,
        ratio_par_mean_ratio_results,
        chr,
    )

    segmentation(events, segment)

    display_results_vcf(sample, segment, signal, lengthFilter, output_file)


def _window_count(seq_length, window_size, step_size):
    """
    Return the number of windows ``main_calcul`` sizes its arrays and launches for.

    Kept identical to the original inline expression
    ``int((seq_length - window_size) / step_size) + 1`` (note: ``int`` truncates
    toward zero, unlike ``//``), so per-chromosome results are unchanged. The
    result is 0 or negative when the chromosome is much shorter than a window.
    """
    return int((seq_length - window_size) / step_size) + 1


def _to_device(host_array, label):
    """
    Allocate device memory for ``host_array``, copy it to the GPU, log it and
    return the allocation.

    Only the device address and byte size are logged; the allocation lives in
    device memory and is never wrapped as a host buffer.
    """
    d_array = cuda.mem_alloc(host_array.nbytes)
    cuda.memcpy_htod(d_array, host_array)
    sys.stderr.write(
        "\t d_%s : 0x%x, %s bytes\n" % (label, int(d_array), host_array.nbytes)
    )
    return d_array


def _alloc_device_like(host_array, label):
    """
    Allocate uninitialised device memory with the byte size of ``host_array``,
    log its address and size, and return the allocation.
    """
    d_array = cuda.mem_alloc(host_array.nbytes)
    sys.stderr.write("\t d_%s = 0x%x\n" % (label, int(d_array)))
    sys.stderr.write("\t %s.nbytes = %s\n" % (label, host_array.nbytes))
    return d_array


# Main program

# Whether the VCF header has already been written; the VCF writer sets it to
# True after emitting the header so it appears only once across chromosomes.
header_written = False

# Configure logging once, before any work. Doing it only at the end is a no-op
# if an earlier module-level logging call already installed a default handler,
# and formatting the name with "%s" created a file literally called "None"
# when no log file was given; filename=None logs to stderr instead.
logging.basicConfig(
    filename=logfile,
    filemode="a",
    level=logging.INFO,
    format="%(asctime)s %(levelname)s - %(message)s",
)

sample = get_sample_name(bamfile_path)

# Query the GPU limit used as the CUDA block size by main_calcul (global
# `num_cores`). The original magic key `1` is CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK,
# i.e. the maximum threads per block, not a core count.
device = cuda.Device(0)
attributes = device.get_attributes()
num_cores = attributes[cuda.device_attribute.MAX_THREADS_PER_BLOCK]
print("Nombre de CPU: ", multiprocessing.cpu_count())
print(f"Nombre de coeurs max GPU: {num_cores}")

# Site-specific reference files (reference FASTA for GC content, mappability track).
gc_file = "/work/gad/shared/pipeline/grch38/index/grch38_essential.fa"
mappability_file = "/work/gad/shared/analyse/test/cnvGPU/test_scalability/wgEncodeCrgMapabilityAlign100mer_no_uniq.grch38.bedgraph"
seq = parse_fasta(gc_file)
mappability = dico_mappabilite(mappability_file)

with pysam.AlignmentFile(bamfile_path, "rb") as bamfile_handle:
    for i, seq_length in enumerate(bamfile_handle.lengths):
        chr = bamfile_handle.references[i]
        sys.stderr.write("Chromosome : %s, seq length : %s\n" % (chr, seq_length))

        # Run the full detection pipeline for this chromosome.
        main_calcul(
            bamfile_handle,
            chr,
            seq_length,
            window_size,
            step_size,
            zscore_threshold,
            lengthFilter,
            output_file,
            sample,
        )

logging.info("end")