#  ### GAD development file  ###
#  ## hiv_genotyping.py
#  ## Version : 1.0
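#  ## Description : genotype HIV integrase / reverse transcriptase / protease codons from a sample BAM and
#  ##               call amino-acid changes against an in-run control BAM and a background-noise model
#  ## Usage : hiv_genotyping.py -s sample.bam -R reference.fasta -b control.bam
#  ##         [-i integrase.tsv] [-r rt.tsv] [-p protease.tsv] [-S (stop-codon mode)]
#  ## Output : files under <sample>/ : codon table (codons.tsv / codons.stop.tsv), mutation.temp.tsv,
#  ##          and per-reference .percent.data, .value.data and .value.density tables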
#  ## Requirements : python 2.7+ - pySAM - scipy 
#  
#  ## Author: yannis.duffourd@u-bourgogne.fr
#  ## Creation Date: 2016-10-18
#  ## last revision date: 2018-03-19
#  ## Known bugs: None
#  ## TODO: 
#  





import os
import sys

# Site-specific environment: make the packages of this miniconda2 install
# (pysam, scipy) importable before the third-party imports below.
sys.path.insert( 0 , "/user1/gad/ya0902du/miniconda2/lib/python2.7/site-packages")

# Debug trace of the import path. It goes to stderr like every other
# diagnostic of this script so that stdout is not polluted.
sys.stderr.write( "%s\n" % ( sys.path , ) )


import getopt
import logging
import pysam 
import logging
import math
import threading
from scipy import stats
# ~ import matplotlib.pyplot as plt

# Default parameters, overridable from the command line (getopt below).
currentThread = 0
integraseFile , samFile , rtFile , proteaseFile , referenceFile = "" , "" , "" , "" , ""
depthThreshold = 200        # -D
ratioThreshold = 0          # -l
pvalueThreshold = 1.0       # -P
annotationIsDefined = False
inRunControlBam = ""        # -b : control BAM of the same run
allControlBamList = ""      # -a
nbThread = 1                # -n
countStop = False           # -S : stop-codon mode
outputFile = ""             # -o
codonResultFile = ""        # -c : codon table file name; default chosen below


opts, args = getopt.getopt(sys.argv[1:], 's:i:r:p:R:D:l:P:o:c:b:a:n:S')
for opt, arg in opts:
    # getopt returns the exact option string, so plain equality is enough
    if opt == "-i":
        integraseFile = arg
    elif opt == "-s":
        samFile = arg
    elif opt == "-r":
        rtFile = arg
    elif opt == "-p":
        proteaseFile = arg
    elif opt == "-R":
        referenceFile = arg
    elif opt == "-D":
        depthThreshold = int(arg)
    elif opt == "-l":
        ratioThreshold = int(arg)
    elif opt == "-P":
        pvalueThreshold = float(arg)
    elif opt == "-o":
        outputFile = arg
    elif opt == "-c":
        codonResultFile = arg
    elif opt == "-b":
        inRunControlBam = arg
    elif opt == "-a":
        allControlBamList = arg
    elif opt == "-n":
        nbThread = int( arg )
    elif opt == "-S":
        countStop = True
        sys.stderr.write('Stop codon detection activated.\n')

# sample name = BAM file name without directory, ".bam" and any further extension
sample = samFile.split( "/" ) [-1]
sample = sample.replace( ".bam" , "" )
sample = sample.split( "." )[0]

# -c wins when given; otherwise the name depends on the stop-codon mode
if codonResultFile == "":
    if countStop:
        codonResultFile = "codons.stop.tsv"
    else:
        codonResultFile = "codons.tsv"



sys.stderr.write('start\n')

class myThread (threading.Thread):
    """Worker thread running one Fisher exact test (sample vs in-run control).

    Expects module-level globals: ``threadLock`` (a lock serialising the
    stderr / result-file writes), ``mutationFileResult`` (open output file)
    and the ``currentThread`` counter, decremented when the worker finishes.
    Rows with p < 0.01 are appended to ``mutationFileResult``.
    """

    def __init__( self , ref , refCodon , refAA , position , codon , obsAA , count , totalOnPos , countCtrl , totalOnPosReference , nb ) :
        threading.Thread.__init__(self)
        self.ref = ref
        self.refCodon = refCodon
        self.refAA = refAA
        self.position = position
        self.nb = nb
        self.codon = codon
        self.obsAA = obsAA
        self.count = count
        self.totalOnPos = totalOnPos
        self.countCtrl = countCtrl
        self.totalOnPosReference = totalOnPosReference

        # `with` releases the lock even if the write raises; a leaked lock
        # would block every other worker forever
        with threadLock:
            sys.stderr.write( "New thread number %s , launch with params : %s , %s , %s , %s , %s , %s , %s , %s , %s , %s\n" % ( nb , ref , refCodon , refAA , position , codon , obsAA , count , totalOnPos , countCtrl , totalOnPosReference ) )

    def run( self ):
        """Run the test, record significant rows, then release the worker slot."""
        global currentThread
        try:
            oddsratio , pValue = stats.fisher_exact( [[self.count , self.totalOnPos] , [self.countCtrl , self.totalOnPosReference]] )
            with threadLock:
                sys.stderr.write( "p-value : %s \n" % pValue )
                if pValue < 0.01:
                    mutationFileResult.write( "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" % ( self.ref , self.refCodon , self.refAA , self.position , self.codon , self.obsAA , self.count , self.totalOnPos , self.countCtrl , self.totalOnPosReference , pValue ))
        finally:
            # always give the slot back, even when the test or the write failed
            with threadLock:
                currentThread -= 1



    
def phredToInt( incString ):
    """Return the Phred score of one quality character (Phred+33 encoding).

    The score is the character's ASCII code minus 33, e.g. "I" -> 40.
    """
    PHRED_OFFSET = 33
    return ord( incString ) - PHRED_OFFSET


def igf( lddl , chiStat ):
    """Lower incomplete gamma function gamma(s, x) evaluated by its power series.

    gamma(s, x) ~= x**s * exp(-x) / s * (1 + sum_{k=1..199} x**k / ((s+1)...(s+k)))
    with s = ``lddl`` and x = ``chiStat``. A negative ``chiStat`` gives 0.0 and
    results are floored at 1e-15. Intermediate values are logged to stderr.
    """
    if chiStat < 0.0:
        return 0.0

    prefactor = ( 1.0 / lddl ) * chiStat**lddl * math.exp( -chiStat )
    sys.stderr.write( "sc value = %s\n" % ( prefactor ) )

    series = 1.0
    termNumerator = 1.0
    termDenominator = 1.0
    shape = lddl
    for _ in range( 199 ):
        termNumerator *= chiStat
        shape += 1
        termDenominator *= shape
        series += termNumerator / termDenominator

    floor = 0.000000000000001
    value = series * prefactor
    if value < floor:
        sys.stderr.write( "igf value = %s\n" % (floor) )
        return floor

    sys.stderr.write( "igf value = %s\n" % (value) )
    return value
    

def gamma( lddl ):
    """Approximate Gamma(z) for z = ``lddl`` > 0 and log the result to stderr.

    Closed form used: sqrt(2*pi/z) * ((z + 1 / (12*z - 1/(10*z))) / e)**z.
    """
    invE = 0.36787944117144232159552377016147   # 1 / e
    twoPi = 6.283185307179586476925286766559
    correction = 1.0 / ( ( 12 * lddl ) - 1.0 / ( 10.0 * lddl ) )
    base = ( correction + lddl ) * invE
    result = base**lddl * math.sqrt( twoPi / lddl )

    sys.stderr.write( "gamma = %s\n" % result )
    return result


def adequationChiTwo( effectTable , localRef , localAlt  ):
    """One-sided chi-square goodness-of-fit test of one position against the background.

    :param effectTable: non-empty list of per-position alternative percentages
        (0-100), e.g. as returned by getEffectiveTable().
    :param localRef: reference read count at the tested position.
    :param localAlt: alternative read count at the tested position.
    :returns: 1 when the observed alternative percentage does not exceed the
        background mean; otherwise the upper-tail p-value (1 degree of
        freedom) of the statistic, computed by pochisq().
    """
    # theoretical (background) and observed alternative percentages;
    # float() avoids Python 2 integer division on an all-integer table
    thMean = sum( effectTable ) / float( len( effectTable ) )
    obsMean = localAlt * 100 / ( float( localRef ) + float( localAlt  ) )

    # only an excess of alternative alleles is of interest. An equal mean gives
    # a null statistic (p = 1); testing "<=" here also avoids dividing 0 by 0
    # below when both means are 0.
    if obsMean <= thMean:
        return 1

    # keep refDev's denominator non-zero when every read is alternative
    if obsMean == 100:
        obsMean = 99.99999
    sys.stderr.write( "Th mean : %s ; obsMean : %s\n" % ( thMean , obsMean ))

    refDev = ( ( (100 - obsMean) - (100 - thMean) )**2 ) /  (100 - obsMean)
    altDev = ( ( obsMean - thMean )**2 ) / obsMean
    qSquare = refDev + altDev

    sys.stderr.write( "X2 statistical value = %s\n" % qSquare )

    # With 1 degree of freedom the 5% critical value is 3.841 (5.991 is the
    # 2-degree value): H0 "this position behaves like the rest of the
    # reference" is rejected at 5% risk when qSquare > 3.841. pochisq turns
    # the statistic into its p-value.
    return pochisq( qSquare , 1 )
    
def pochisq(x, df) :
    """Upper-tail probability P(X >= x) of a chi-square variable with ``df`` degrees of freedom.

    Returns 1.0 when x <= 0 or df < 1. Odd df use the normal CDF poz() for the
    first term. Even df start from exp(-x/2). Both add the remaining series
    terms for df > 2, accumulated in log space when x/2 > 20 to avoid
    overflow. The result is logged to stderr.
    """
    LOG_SQRT_PI = 0.5723649429247000870717135   # log(sqrt(pi))
    I_SQRT_PI = 0.5641895835477562869480795     # 1 / sqrt(pi)
    BIGX = 20.0

    if x <= 0.0 or df < 1 :
        return 1.0

    a = 0.5 * x
    even = ( df % 2 == 0 )
    y = math.exp( -a ) if df > 1 else 0.0
    if even:
        s = y
    else:
        s = 2.0 * poz( - math.sqrt( x ) )

    if df > 2:
        last = 0.5 * ( df - 1.0 )
        z = 1.0 if even else 0.5
        if a > BIGX:
            # terms exp(z*log(a) - a - log(Gamma-like product)) in log space
            e = 0.0 if even else LOG_SQRT_PI
            c = math.log( a )
            while z <= last:
                e = math.log( z ) + e
                s += math.exp( c * z - a - e )
                z += 1.0
        else:
            e = 1.0 if even else ( I_SQRT_PI / math.sqrt( a ) )
            c = 0.0
            while z <= last:
                e = e * ( a / z )
                c = c + e
                z += 1.0
            s = c * y + s

    sys.stderr.write( " p-value : %s\n" % s )
    return s
 
def poz(z) :
    """Standard normal cumulative probability P(Z <= z), by polynomial approximation.

    |z| >= 6 is treated as the full tail (0.0 or 1.0). The incoming value and
    the result are logged to stderr.
    """
    sys.stderr.write( "Incoming z : %s\n" % z )

    Z_MAX = 6.0
    # Horner coefficients, highest degree first, in w = (|z|/2)**2 for |z| < 2
    SMALL_COEFFS = ( 0.000124818987 , -0.001075204047 , 0.005198775019 ,
                     -0.019198292004 , 0.059054035642 , -0.151968751364 ,
                     0.319152932694 , -0.531923007300 , 0.797884560593 )
    # Horner coefficients, highest degree first, in |z|/2 - 2 for 2 <= |z| < 6
    LARGE_COEFFS = ( -0.000045255659 , 0.000152529290 , -0.000019538132 ,
                     -0.000676904986 , 0.001390604284 , -0.000794620820 ,
                     -0.002034254874 , 0.006549791214 , -0.010557625006 ,
                     0.011630447319 , -0.009279453341 , 0.005353579108 ,
                     -0.002141268741 , 0.000535310849 , 0.999936657524 )

    def horner( coeffs , t ):
        acc = coeffs[0]
        for coeff in coeffs[1:]:
            acc = acc * t + coeff
        return acc

    half = 0.5 * abs( z )
    if z == 0.0:
        x = 0.0
    elif half >= ( Z_MAX * 0.5 ):
        x = 1.0
    elif half < 1.0:
        x = horner( SMALL_COEFFS , half * half ) * half * 2.0
    else:
        x = horner( LARGE_COEFFS , half - 2.0 )

    # x approximates P(|Z| <= |z|); fold it onto the requested side
    if z > 0.0:
        result = ( x + 1.0 ) * 0.5
    else:
        result = ( 1.0 - x ) * 0.5
    sys.stderr.write( "Output : %s\n" % result )
    return result

# need the bam already parsed to be used.
def getEffectiveTable( ref , excludeList ):
    """Return the alternative amino-acid percentage of every codon position of ``ref``.

    Reads the module-level ``bamCodon``, ``referenceCodon`` and ``codonTable``,
    so the sample BAM must already be parsed. Positions listed in
    ``excludeList`` are skipped; the others are visited in ``bamCodon[ref]``
    order. Each entry is 100 * (reads whose codon translates to an amino acid
    different from the reference one) / (all reads at the position), or the
    integer 0 when there is no such read.
    """
    percentages = []
    for position , codonCounts in bamCodon[ref].items():
        if position in excludeList:
            continue
        referenceAA = codonTable[ referenceCodon[ref][position] ]
        depth = sum( codonCounts.values() )
        altDepth = sum( n for codon , n in codonCounts.items() if codonTable[ codon ] != referenceAA )
        percentages.append( altDepth * 100 / float( depth ) if altDepth else 0 )
    return percentages

# compute factorial
def fact(n):
    """fact(n): return the factorial of n (integer >= 0); any n < 2 gives 1.

    Iterative, so large n does not hit the interpreter recursion limit.
    Factors are multiplied from the smallest upwards, in the same order as the
    recursive definition n * fact(n - 1).
    """
    factors = []
    while n >= 2:
        factors.append( n )
        n -= 1
    result = 1
    for factor in reversed( factors ):
        result = factor * result
    return result

# compute log factorial
def logFactoriel( inc ):
    """Return log(inc!) (natural logarithm); 0 when inc <= 0.

    Integer arguments use math.lgamma(inc + 1), which is O(1) instead of an
    O(inc) sum of logarithms. This matters because the Fisher exact test
    helpers call this function many times per table. Non-integer arguments
    keep the sum log(inc) + log(inc - 1) + ... over the positive terms.
    """
    if inc <= 0:
        return 0
    if inc == int( inc ):
        return math.lgamma( inc + 1 )
    ret = 0
    while inc > 0:
        ret += math.log( inc )
        inc -= 1
    return ret

# compute log of the hypergeometric probability for FET
def logHypergeometricProb( a ,  b , c  , d ):
    """Natural log of the probability of the 2x2 table [[a, b], [c, d]] given its margins.

    log( (a+b)! (c+d)! (a+c)! (b+d)! / (a! b! c! d! n!) ), n = a+b+c+d,
    computed with logFactoriel().
    """
    logMargins = logFactoriel( a + b ) + logFactoriel( c + d ) + logFactoriel( a + c ) + logFactoriel( b + d )
    return logMargins - logFactoriel( a ) - logFactoriel( b ) - logFactoriel( c ) - logFactoriel( d ) - logFactoriel( a + b + c + d )

# compute pvalue from FET
def FETPvalue( a , b , c, d ):
    """Two-sided Fisher exact test p-value for the 2x2 table [[a, b], [c, d]].

    Sums the probabilities of every table with the same margins that is not
    more likely than the observed one. Probabilities equal to the observed one
    up to a relative 1e-7 count as ties, so rounding noise cannot drop
    equally likely tables. The result is capped at 1. Uses
    logHypergeometricProb().
    """
    RELATIVE_TOLERANCE = 1e-7
    logpCutOff = logHypergeometricProb( a , b , c , d )

    # Tables sharing the margins are indexed by their top-left cell x. Every
    # cell must stay >= 0: x >= 0, a+b-x >= 0, a+c-x >= 0 and d-a+x >= 0.
    lowest = max( 0 , a - d )
    highest = min( a + b , a + c )
    pFraction = 0.0
    for x in range( lowest , highest + 1 ):
        l = logHypergeometricProb( x , a + b - x , a + c - x , d - a + x )
        # in log space, +1e-7 ~ a relative tolerance of 1e-7 on the probability
        if l <= logpCutOff + RELATIVE_TOLERANCE:
            pFraction += math.exp( l - logpCutOff )

    # pFraction >= 1: the observed table (x == a) is always in the range
    return min( 1.0 , math.exp( logpCutOff + math.log( pFraction ) ) )

# determine if a codon is a stop or not
def isStop( incCodon ):
    """Return True when ``incCodon`` is TAG, TAA or TGA (exact, case-sensitive)."""
    return incCodon in ( "TAG" , "TAA" , "TGA" )

# deal with the ref genome
# referenceSequence[name][1-based position] = base, one entry per FASTA record
# (name = header line without '>'). Sequence lines are stripped and numbering
# continues across lines, so wrapped (multi-line) records are supported.
sys.stderr.write( "Parsing reference genome ..." )
referenceSequence = {}
currentChr = ""
with open( referenceFile , "r" ) as referenceStream:
    for line in referenceStream:
        line = line.strip()
        if not line:
            continue
        # name of the sequence
        if line.startswith( ">" ):
            currentChr = line.replace( ">" , "" )
            referenceSequence[currentChr] = {}
        else:
            bases = referenceSequence[currentChr]
            for base in line:
                bases[len( bases ) + 1] = base

# codon table: referenceCodon[name][codon number, 1-based] = upper-case codon
# made of positions 3k-2, 3k-1 and 3k; a trailing partial codon is dropped.
referenceCodon = {}
for chrom in referenceSequence.keys():
    referenceCodon[chrom] = {}
    codon = ""
    for pos in sorted( referenceSequence[chrom].keys() ):
        codon += referenceSequence[chrom][pos]
        if pos % 3 == 0 :
            referenceCodon[chrom][pos // 3] = codon.upper()
            codon = ""

sys.stderr.write( " done\n" )

# knownMutationTable[reference name][key] = annotation, loaded from the optional
# integrase / reverse transcriptase / protease files
knownMutationTable = {}


def _parseGenotypingFile( fileName ):
    """Load one annotation file into knownMutationTable (shared by the three file types).

    A "#<reference name>" line opens a section. Lines starting with a space,
    "Position" or "AA", and lines shorter than 10 characters, are skipped.
    Data lines are tab-separated. Rows whose 3rd and 5th columns are equal are
    ignored; otherwise key = column 3 + column 1 + column 5 maps to column 6.
    (Presumably reference AA + position + mutated AA, e.g. "M184V" — confirm
    against the file format.) Exits with status 1 when a data line precedes
    every "#" line.
    """
    refFound = False
    ref = None
    with open( fileName , "r" ) as stream:
        for line in stream:
            if line.startswith( ( " " , "Position" , "AA" ) ) or len( line ) < 10:
                continue

            if line.startswith( "#" ) :
                ref = line.strip().split( "#" )[1]
                knownMutationTable[ref] = {}
                refFound = True
                continue

            if not refFound :
                sys.stderr.write( "Error : %s file is malformed. Please redo the analysis without it or correct the file (see doc at http://blabla.html)\n" % fileName )
                sys.exit( 1 )

            lineTable = line.strip().split( "\t" )
            if lineTable[2] == lineTable[4]:
                continue
            knownMutationTable[ref][lineTable[2] + lineTable[0] + lineTable[4]] = lineTable[5]


sys.stderr.write( "Parsing genotyping tables ...\n" )
for annotationFile , label in ( ( integraseFile , "INTEGRASE" ) ,
                                ( rtFile , "REVERSETRANSCRIPTASE" ) ,
                                ( proteaseFile , "PROTEASE" ) ):
    if annotationFile != "":
        annotationIsDefined = True
        _parseGenotypingFile( annotationFile )
        sys.stderr.write( "\t%s : OK\n" % label )
    else:
        sys.stderr.write( "\t%s : no file provided\n" % label )

sys.stderr.write( " ... done\n" )
sys.stderr.write( "Parsing bam to genotype ... " )

resultTable = {}
qualityTable = {}

resultTable["ReverseTranscriptase"] = {}
resultTable["Protease"] = {}
resultTable["Integrase"] = {}

qualityTable["ReverseTranscriptase"] = {}
qualityTable["Protease"] = {}
qualityTable["Integrase"] = {}

effectifTable = {}

###### deal with inRunControlBam
# bamCodonReference[reference name][codon number, 1-based][codon] = number of
# control reads whose three bases at that codon all have base quality >= 30.
# Each read contributes at most once per codon.
sys.stderr.write( "Parsing bam control file : %s ..." % ( inRunControlBam ) )
bamIterRef = pysam.AlignmentFile( inRunControlBam , "r" )
bamCodonReference = {}

for line in bamIterRef:
    # keep only primary, mapped alignments with mapping quality >= 30
    if line.is_unmapped or line.is_secondary or line.is_supplementary or int( line.mapping_quality ) < 30 :
        continue

    ref = line.reference_name
    # 1-based reference position of read base 0
    position = int( line.reference_start ) + 1
    sequence = line.query_sequence
    quality = line.query_qualities

    # non aligned reads, theoretically unreachable because tested before
    if ref == "*" :
        continue

    # manage new chr encounters
    if ref not in bamCodonReference:
        logging.info( 'New reference found : ' + ref )
        bamCodonReference[ref] = {}

    # reads with insertions, deletions or clipping are skipped: only then do
    # read bases map one-to-one onto consecutive reference positions
    cigar = line.cigarstring
    if "I" in cigar or "H" in cigar or "S" in cigar or "D" in cigar:
        continue

    # without sequence or base qualities the quality filter cannot be applied
    if sequence is None or quality is None:
        continue

    # reference positions 1, 4, 7, ... start a codon; find the first read base
    # sitting on such a position, then walk the read codon by codon
    firstCodonStart = ( 1 - position ) % 3
    for i in range( firstCodonStart , len( sequence ) - 2 , 3 ):
        if min( quality[i:i + 3] ) < 30:
            continue
        codon = sequence[i:i + 3].upper()
        codonNumber = ( position + i + 2 ) // 3
        codonCounts = bamCodonReference[ref].setdefault( codonNumber , {} )
        codonCounts[codon] = codonCounts.get( codon , 0 ) + 1

bamIterRef.close()
sys.stderr.write( " done\n" )



############ deal with bam file
# bamCodon[reference name][codon number, 1-based][codon] = number of sample
# reads whose three bases at that codon all have base quality >= 30.
sys.stderr.write( "Parsing bam file : %s ..." % ( samFile ) )

# Names of reads in which a stop codon was met (filled only with -S); later
# alignments with the same name (their mates) are skipped. The set mirrors
# the list for O(1) membership tests.
bannedReadID = []
bannedReadIDSet = set()

bamIter = pysam.AlignmentFile( samFile , "r" )
bamCodon = {}

for line in bamIter:
    # keep only primary, mapped alignments with mapping quality >= 30
    if line.is_unmapped or line.is_secondary or line.is_supplementary or int( line.mapping_quality ) < 30 :
        sys.stderr.write( 'Passing sequence : bad quality\n' )
        continue

    # pass stop pairs if activated
    if line.query_name in bannedReadIDSet :
        continue

    ref = line.reference_name
    # 1-based reference position of read base 0
    position = int( line.reference_start ) + 1
    sequence = line.query_sequence
    quality = line.query_qualities

    # non aligned reads, theoretically unreachable because tested before
    if ref == "*" :
        sys.stderr.write( 'Passing sequence : not aligned : %s \n' % ref )
        continue

    # manage new chr encounters
    if ref not in bamCodon:
        sys.stderr.write( 'New reference found : %s\n' % ref )
        bamCodon[ref] = {}

    # reads with insertions, deletions or clipping are skipped: only then do
    # read bases map one-to-one onto consecutive reference positions
    cigar = line.cigarstring
    if "I" in cigar or "H" in cigar or "S" in cigar or "D" in cigar:
        continue

    # without sequence or base qualities the quality filter cannot be applied
    if sequence is None or quality is None:
        continue

    # reference positions 1, 4, 7, ... start a codon; find the first read base
    # sitting on such a position, then walk the read codon by codon
    firstCodonStart = ( 1 - position ) % 3
    for i in range( firstCodonStart , len( sequence ) - 2 , 3 ):
        codon = sequence[i:i + 3].upper()
        if min( quality[i:i + 3] ) >= 30:
            codonCounts = bamCodon[ref].setdefault( ( position + i + 2 ) // 3 , {} )
            codonCounts[codon] = codonCounts.get( codon , 0 ) + 1

        # with -S, stop at the first stop codon (whatever its base quality)
        # and ban this read name
        if countStop and isStop( codon ):
            bannedReadID.append( line.query_name )
            bannedReadIDSet.add( line.query_name )
            break

bamIter.close()
sys.stderr.write( " done\n" )


sys.stderr.write( "%s sequence passed truncated due to stop codon.\n" % len( bannedReadID ) )


# Standard genetic code: codonTable[codon] = one-letter amino acid ("*" = stop).
# The 64-letter string lists amino acids in TCAG order (first base slowest,
# third base fastest). Codons are inserted second-base-major (TTT, TTC, TTA,
# TTG, CTT, ...).
_CODON_BASES = "TCAG"
_CODON_AMINO_ACIDS = ( "FFLLSSSSYY**CC*W" + "LLLLPPPPHHQQRRRR" +
                       "IIIMTTTTNNKKSSRR" + "VVVVAAAADDEEGGGG" )
codonTable = {}
for _second in range( 4 ):
    for _first in range( 4 ):
        for _third in range( 4 ):
            _codon = _CODON_BASES[_first] + _CODON_BASES[_second] + _CODON_BASES[_third]
            codonTable[_codon] = _CODON_AMINO_ACIDS[16 * _first + 4 * _second + _third]


# per-reference totals kept for further statistical tests
totalRefSequencedBase = {}
totalAltSequencedBase = {}


# let's modelize the background noise :
# For every reference write <sample>/<sample>.<ref>.percent.data: one row per
# codon position giving, for each amino acid other than the reference one,
# the percentage of the position's reads whose codon translates to it.
sys.stderr.write( "Counting codons ... " )
for ref in bamCodon.keys():
    with open( sample + "/" + sample + "." + ref + ".percent.data" , "w" ) as outStream:
        noiseTable = {}
        for position in bamCodon[ref].keys():
            counts = bamCodon[ref][position]
            altByAA = noiseTable.setdefault( position , {} )
            refAA = codonTable[ referenceCodon[ref][position] ]
            totalOnPos = 0
            for codon in counts.keys():
                totalOnPos += counts[codon]
                obsAA = codonTable[ codon ]
                if obsAA != refAA:
                    altByAA[obsAA] = altByAA.get( obsAA , 0 ) + counts[codon]

            # counts -> percentages of the position depth
            for aa in altByAA.keys():
                altByAA[aa] = altByAA[aa] * 100 / float( totalOnPos )

        listofAA = list( set( codonTable.values() ) )
        outStream.write( "#position\t%s\n" % "\t".join( listofAA ) )
        for position in noiseTable.keys():
            cells = [ "%s" % position ]
            for aa in listofAA:
                if aa in noiseTable[position]:
                    cells.append( "%s" % noiseTable[position][aa] )
                else:
                    cells.append( "0" )
            outStream.write( "\t".join( cells ) + "\n" )
sys.stderr.write( " done\n " )
    
    

# let's modelize the background noise :
# For every reference write <sample>/<sample>.<ref>.value.data: one row per
# codon position giving, for each amino acid other than the reference one,
# the number of reads carrying it. Also sets, per reference:
#   totalAltSequencedBase[ref] = reads (all positions) translating to a non-reference AA
#   totalRefSequencedBase[ref] = all other reads
sys.stderr.write( "Modelizing background noise  ... " )

for ref in bamCodon.keys():
    totalAltSequencedBase[ref] = 0
    bigTotal = 0
    noiseTable = { ref : {} }

    with open( sample + "/" + sample + "." + ref + ".value.data" , "w" ) as outStream:
        for position in bamCodon[ref].keys():
            counts = bamCodon[ref][position]
            altByAA = noiseTable[ref].setdefault( position , {} )
            refAA = codonTable[ referenceCodon[ref][position] ]
            for codon in counts.keys():
                bigTotal += counts[codon]
                obsAA = codonTable[ codon ]
                if obsAA != refAA:
                    altByAA[obsAA] = altByAA.get( obsAA , 0 ) + counts[codon]
                    totalAltSequencedBase[ref] += counts[codon]

        # computed once, after every position has been summed (adding running
        # totals at each position would count earlier positions again and again)
        totalRefSequencedBase[ref] = bigTotal - totalAltSequencedBase[ref]

        listofAA = list( set( codonTable.values() ) )
        outStream.write( "#position\t%s\n" % "\t".join( listofAA ) )
        for position in noiseTable[ref].keys():
            outStream.write( "%s" % position )
            for a in listofAA :
                outStream.write( "\t%s" % noiseTable[ref][position].get( a , 0 ) )
            outStream.write( "\n" )
sys.stderr.write( " done\n " )



# calculate distributions
# For every reference, read back <sample>/<sample>.<ref>.value.data and write
# <sample>/<sample>.<ref>.value.density: for each distinct integer found in
# the amino-acid columns (all positions), its relative frequency.
sys.stderr.write( "Computing distribution  ... " )
for ref in bamCodon.keys():
    sys.stderr.write( "Calculating distribution for %s \n" % ref )
    contingencyTable = {}
    with open( sample + "/" + sample + "." + ref+".value.data" , "r" ) as incStream:
        for line in incStream:
            if line.startswith( "#position" ):
                continue
            for cell in line.split( "\t" )[1:]:
                value = int( cell )
                contingencyTable[value] = contingencyTable.get( value , 0 ) + 1

    # counts -> probabilities
    totalSum = sum( contingencyTable.values() )
    for value in contingencyTable.keys():
        contingencyTable[value] = contingencyTable[value] / float( totalSum )

    with open( sample + "/" + sample + "." + ref+".value.density" , "w" ) as outStream:
        outStream.write( "#value\tcount\n" )
        for value in contingencyTable.keys():
            outStream.write( "%s\t%s\n" % ( value , contingencyTable[value] ) )

sys.stderr.write( " done\n " )


# detect the potential background noise
# Iterative per-reference estimation of the background:
#   bg[ref]    = frequency of every non-reference codon at the kept positions
#   bgEff[ref] = [mean alternative count, mean depth] over the kept positions
# A non-reference codon whose frequency departs from bg[ref] (chi-square,
# p < 0.01) marks its position as a putative mutation. That position is
# excluded from the next turn. Turns stop when none is found.
a = 0
b = 1
turn = 1
# excluded positions, tracked per reference: positionToDiscard[ref] = set of
# positions (so codon N of one gene does not exclude codon N of another)
positionToDiscard = {}
bg = {}
while b != 0:
    bg = {}
    bgEff = {}
    b = 0

    sys.stderr.write( "BG estimation : turn %s\n" % turn )

    # first : get data for chisquare test
    for ref in bamCodon.keys():
        altCount = []
        refCount = []
        discarded = positionToDiscard.setdefault( ref , set() )
        bg[ref] = []
        bgEff[ref] = []

        for position in bamCodon[ref].keys():
            if position in discarded:
                sys.stderr.write( "discarded pos : %s Passing \n" % position )
                continue

            codonCounts = bamCodon[ref][position]
            positionTotalCount = sum( codonCounts.values() )
            sys.stderr.write( "On %s : pos %s , totalD : %s\n" % ( ref , position , positionTotalCount ) )

            refCodon = referenceCodon[ref][position]
            positionAltCount = 0
            for codon in codonCounts.keys():
                if codon != refCodon :
                    # NOTE(review): overwritten, not summed: only the last
                    # non-reference codon of the position reaches altCount
                    positionAltCount = codonCounts[codon]
                    sys.stderr.write( "On %s : pos %s , alt : %s\n" % ( ref , position , positionAltCount ) )
                    bg[ref].append( positionAltCount / float( positionTotalCount ) )

            if positionAltCount > positionTotalCount :
                sys.stderr.write( "Error : more alt allele than the total on position %s\n" % position )
                sys.exit( 1 )

            altCount.append( positionAltCount )
            refCount.append( positionTotalCount )

        # NOTE(review): reported as an error but exits with status 0
        if len( altCount ) == 0 :
            sys.stderr.write( "Error : all position excluded on ref %s\n" % ref )
            sys.exit( 0 )

        bgEff[ref].append( sum( altCount ) / len( altCount ) )
        bgEff[ref].append( sum( refCount ) / len( refCount ) )
        sys.stderr.write( "%s : %s , %s\n" % ( ref , bgEff[ref][0] , bgEff[ref][1] ) )

    # now redo a loop to find mutations
    for ref in bamCodon.keys():
        discarded = positionToDiscard[ref]
        for position in bamCodon[ref].keys():
            if position in discarded:
                continue

            codonCounts = bamCodon[ref][position]
            positionTotalCount = sum( codonCounts.values() )
            refCodon = referenceCodon[ref][position]
            for codon in codonCounts.keys():
                if codon != refCodon :
                    positionAltCount = codonCounts[codon]
                    # chisquare test to determine if position is significantly different from ref
                    t , p = stats.chisquare( positionAltCount / float( positionTotalCount ) , bg[ref] )
                    if p < 0.01:
                        discarded.add( position )
                        b += 1
                        a += 1

    sys.stderr.write( "End of turn %s with %s new mutation found\n" % ( turn , b ) )
    turn += 1


# background noise is estimated with a control sample.
# Every codon observed in the sample gets one row in
# <sample>/<sample>.<codonResultFile> (sample and control counts / ratios).
# A codon translating to a non-reference amino acid, at a position deeper
# than depthThreshold (-D) with more than 30 supporting reads and more reads
# than in the control, is tested against the in-run control (Fisher exact
# test) and then against the background model bgEff (chi-square). Rows with
# both p < 0.01 are written to <sample>/<sample>.mutation.temp.tsv.
sys.stderr.write( "Calling variants ... \n" )

nbMutTotal = 0
resultTable = {}
testToPerform = []

with open( sample + "/" + sample + "." + codonResultFile , "w" ) as codonFileResult , \
     open( sample + "/" + sample + ".mutation.temp.tsv" , "w" ) as mutationFileResult:
    codonFileResult.write( "#Reference\treferenceCodon\treferenceAA\tposition\tsequencedCodon\tsequencedAA\tcounts\tallelicratio\tctrlCounts\tctrlRatio\n" )

    for ref in bamCodon.keys():
        resultTable[ref] = {}
        sys.stderr.write( "ref : %s\n" % ref )
        # the control may have no usable read at all on this reference
        controlCodons = bamCodonReference.get( ref , {} )
        for position in bamCodon[ref].keys():
            refCodon = referenceCodon[ref][position]
            refAA = codonTable[refCodon]
            sampleCounts = bamCodon[ref][position]
            controlCounts = controlCodons.get( position , {} )
            totalOnPos = sum( sampleCounts.values() )
            totalOnPosReference = sum( controlCounts.values() )

            for codon in sampleCounts.keys():
                count = sampleCounts[codon]
                obsAA = codonTable[codon]
                countCtrl = controlCounts.get( codon , 0 )

                ratio = 0
                ratioCtrl = 0
                if totalOnPos != 0:
                    ratio = ( count * 100 ) / float(totalOnPos)
                if totalOnPosReference != 0:
                    ratioCtrl = ( countCtrl * 100 ) / float(totalOnPosReference)

                codonFileResult.write( "%s\t%s\t%s\t%s\t%s\t%s\t%s/%s\t%s\t%s\t%s\n" % ( ref , refCodon , refAA , position , codon , obsAA , count , totalOnPos , ratio , countCtrl , ratioCtrl))

                # output mutation
                if obsAA != refAA and totalOnPos > depthThreshold and count > 30 and countCtrl < count :
                    sys.stderr.write( "Testing position : %s on %s , AA : %s (ref %s ), depth : %s , ratio : %s , " % ( position , ref , obsAA , refAA , totalOnPos , ratio  ) )
                    # NOTE(review): each row is [variant count, total depth], so the
                    # second cell includes the first; a classic 2x2 table would use
                    # total - count. Kept to preserve the calling behaviour.
                    oddsratio , pValue = stats.fisher_exact( [[count , totalOnPos] , [countCtrl , totalOnPosReference]] )
                    sys.stderr.write( "p-value : %s \n" % pValue )
                    if pValue < 0.01:
                        # test for bg noise
                        chi2 , p , dof , ex = stats.chi2_contingency( [[count , totalOnPos ] , [ bgEff[ref][0] , bgEff[ref][1] ]] , correction=False )
                        if p < 0.01:
                            mutationFileResult.write( "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" % ( ref , refCodon , refAA , position , codon , obsAA , count , totalOnPos , countCtrl , totalOnPosReference , pValue , p ))
                        else:
                            sys.stderr.write( "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s was different from control but in background noise\n" % ( ref , refCodon , refAA , position , codon , obsAA , count , totalOnPos , countCtrl , totalOnPosReference , pValue , p ) )

sys.stderr.write( "Done, performing statistical tests ...\n" )

# stats for graphics
# statsForGraph[ref][position][codon] = {
#     'FET': Fisher exact p-value of the sample vs the in-run control,
#     'CHI': chi-square p-value of the sample vs the background model bgEff,
#     'codonDepth': sample read count of the codon }
statsForGraph = {}
for ref in bamCodon.keys():
    statsForGraph[ref] = {}
    logging.info( "ref : " + ref )
    # the control may have no usable read at all on this reference
    controlCodons = bamCodonReference.get( ref , {} )
    for position in bamCodon[ref].keys():
        statsForGraph[ref][position] = {}
        sampleCounts = bamCodon[ref][position]
        controlCounts = controlCodons.get( position , {} )
        totalOnPos = sum( sampleCounts.values() )
        totalOnPosReference = sum( controlCounts.values() )

        for codon in sampleCounts.keys():
            codonStats = statsForGraph[ref][position][codon] = {}
            count = sampleCounts[codon]
            countCtrl = controlCounts.get( codon , 0 )

            # FET
            oddsratio , pValue = stats.fisher_exact( [[count , totalOnPos] , [countCtrl , totalOnPosReference]] )
            codonStats['FET'] = pValue

            # chisquare
            chi2 , p , dof , ex = stats.chi2_contingency( [[count , totalOnPos ] , [ bgEff[ref][0] , bgEff[ref][1] ]] , correction=False )
            codonStats['CHI'] = p
            codonStats['codonDepth'] = count



# ~ # graphics
# ~ for ref in statsForGraph.keys():
    # ~ plt.figure( figsize=( 12,10 ) ) 
    # ~ title = "Analysis for " + str( ref )
    # ~ gName = str( ref ) + "_analysis.png" 
    # ~ plt.title( title )
    # ~ plt.ylabel( "-log( pValue )" )
    # ~ plt.xlabel( "Position" )

    # ~ x = []
    # ~ y = []
    # ~ z = []
    # ~ ac = []
    # ~ b = []
    
    # ~ for position in statsForGraph[ref].keys():
	# ~ x.append( position )
	# ~ totalOnPos = 0
	# ~ for a in bamCodon[ref][position].keys():
	    # ~ totalOnPos += bamCodon[ref][position][a]
	# ~ ac.append( totalOnPos )
    
	# ~ maxPValFET = -1
	# ~ maxPValCHI = -1
	# ~ temp  = []
	# ~ for codon in statsForGraph[ref][position].keys():
	    # ~ sys.stderr.write( "fet : %s ; chi : %s \n" % ( statsForGraph[ref][position][codon]['FET'] , statsForGraph[ref][position][codon]['CHI'] ) )
	    
	    
	    # ~ if statsForGraph[ref][position][codon]['FET'] == 0:
		# ~ fetPValue = 500
	    # ~ else:
		# ~ fetPValue = - math.log10( statsForGraph[ref][position][codon]['FET'] )
	    # ~ if fetPValue > maxPValFET:
		# ~ maxPValFET = fetPValue
	    
	    # ~ if statsForGraph[ref][position][codon]['CHI'] == 0:
		# ~ chiPValue = 500
	    # ~ else:
		# ~ chiPValue = - math.log10( statsForGraph[ref][position][codon]['CHI'] )

	    # ~ if chiPValue > maxPValCHI:
		# ~ maxPValCHI = chiPValue
	    # ~ temp.append( statsForGraph[ref][position][codon]['codonDepth'] )
	
	# ~ y.append( maxPValFET )
	# ~ z.append( maxPValCHI )
	# ~ b.append( temp )
	
    # ~ plt.plot( x, y, linewidth=2.0 , color = 'r' )
    # ~ plt.plot( x , z , linewidth=2.0 , color = 'b' )
    # ~ plt.savefig( gName )

# writing final results after annotation
#
# Re-reads the temporary mutation table written above. Its 12 tab-separated
# columns are:
#   ref, refCodon, refAA, position, codon, obsAA, count, totalOnPos,
#   countCtrl, totalOnPosReference, control FET p-value, background chi-square p-value
# It then writes the user-facing report, annotated from knownMutationTable when an
# annotation was loaded.
# NOTE(review): the "FET bg pValue" column actually holds the chi2_contingency
# p-value; the header text is kept as-is for downstream compatibility.
with open( outputFile , "w" ) as finalFileResult :
    finalFileResult.write( "#Reference\taa Position\tReference Codon\tReference aa\tAlternative Codon\tAlternative aa\tDepth\tAlternative Allele Ratio\tFET ctrl p-value\tFET bg pValue\tAnnotation\n")
    with open( sample + "/" + sample + ".mutation.temp.tsv" , "r" ) as mutationFileResult :
        for line in mutationFileResult:
            # tolerate blank lines instead of failing on the column lookups
            if not line.strip():
                continue
            lineTable = line.strip().split( "\t" )
            ref = lineTable[0]
            annotation = "No annotation found"

            if annotationIsDefined :
                # lookup key is refAA + position + obsAA
                index = lineTable[2] + lineTable[3] + lineTable[5]
                # references without any known mutation keep the default annotation
                annotation = knownMutationTable.get( ref , {} ).get( index , annotation )

            depth = int( lineTable[7] )
            # alternative allele ratio, in percent of the position depth
            ratio = ( float( lineTable[6] ) * 100 ) / depth

            finalFileResult.write( "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s pct\t%s\t%s\t%s\n" % ( ref , lineTable[3] , lineTable[1] , lineTable[2] , lineTable[4] , lineTable[5] , depth , ratio , lineTable[10], lineTable[11] ,  annotation ) )


# writing consensus sequence
# For each reference, writes a FASTA-style record whose sequence is the
# concatenation, in ascending position order, of the most-read codon at each
# position present in bamCodon.
# - Ties keep the first codon met in dict iteration order, because of the strict ">".
# - A position whose codon counts are all 0 contributes an empty string.
# NOTE(review): positions missing from bamCodon are skipped rather than padded
# (e.g. with "NNN"), so the consensus is not position-aligned across coverage gaps.
with open( sample + "/" + sample + ".cs" , "w" ) as consFile :
    for ref in bamCodon.keys():
        consFile.write( ">" + ref + "\n" )
        for position in sorted( bamCodon[ref].keys() ) :
            count = 0
            retainedCodon = ""
            for codon , nbReads in bamCodon[ref][position].items():
                if nbReads > count:
                    count = nbReads
                    retainedCodon = codon
            consFile.write( retainedCodon )
        consFile.write( "\n" )