#include <float.h>
#include <stdio.h>
#include <stdlib.h>
#include "common.h"

/* Run parameters defined elsewhere in the program:
 *   na, nc    number of 'a' and 'c' labels dealt to the chromosomes in each
 *             permutation (see permu_phase1)
 *   nchromo   number of chromosomes labelled per permutation
 *   nmarkers  number of markers; sizes the per-side score arrays
 *   npermu    number of permutations
 *   seed      passed to srand() at the start of every permu_phase1 call
 *   mode      not referenced in this section
 *   fixed_k   0: calc_permu_score2 takes the minimum sub-score p-value;
 *             otherwise it uses sub-score number fixed_k (1-based) */
extern long na, nc, nchromo, nmarkers, npermu, seed;
extern char mode;
extern long fixed_k;

/* Forward declarations of helpers defined later in this file. */
int cmp_sort(const void *, const void *);
double calc_permu_score2(double *);
double calc_permu_score3(double *, long);
double pvalue(long, long);

/* Observed sub-scores (NSUBSCORES values, defined elsewhere); permu_phase1
   ranks each one against the corresponding permuted sub-scores. */
extern double obs_score1[];

/* obs_pval1: p-value of each observed sub-score (written by permu_phase1).
   obs_pval2: p-value per position m, nmarkers + 1 entries (written by
   permu_phase2). */
double obs_pval1[NSUBSCORES], *obs_pval2;
/* Observed combined score per side and marker, indexed side*nmarkers + m. */
double *obs_score2;

/* Sort scratch of npermu + 1 entries: tmp_score holds the values being ranked
   (the last slot is the observed value) and tmp_ptr holds indices that
   cmp_sort orders by tmp_score. pvalue() reads both after sorting. */
static long *tmp_ptr;
static double *tmp_score;
/* permu_score1: npermu x NSUBSCORES permuted sub-scores.
   permu_score2: per permutation i, 2*nmarkers combined scores indexed
   (2*i + side)*nmarkers + m. */
static double *permu_score1, *permu_score2;
/* p-values of permu_score1, same npermu x NSUBSCORES layout. */
static double *pval1;
/* 'a'/'c' label per chromosome for the permutation being scored. */
static char *permu_status;
/* Per permutation, the smallest phase-2 p-value seen so far (starts at 1.0). */
static double *min_permu_pval2;
/* 1 or 2: the phase currently running; selects pvalue()'s boundary handling. */
static int phase;

void alloc_permu(void)
{
	long i;

	tmp_ptr = (long *)malloc((npermu + 1) * sizeof(long));
	tmp_score = (double *)malloc((npermu + 1) * sizeof(double));

	permu_score1 = (double *)malloc(npermu * NSUBSCORES * sizeof(double));
	permu_score2 = (double *)malloc(2 * nmarkers * npermu * sizeof(double));

	pval1 = (double *)malloc(npermu * NSUBSCORES * sizeof(double));

	obs_score2 = (double *)malloc(2 * nmarkers * sizeof(double));
	obs_pval2 = (double *)malloc((nmarkers + 1) * sizeof(double));
	
	min_permu_pval2 = (double *)malloc(npermu * sizeof(double));
	for (i = 0; i < npermu; i++) min_permu_pval2[i] = 1.0;

	permu_status = (char *)malloc(nchromo * sizeof(char));
}

/* Release every array allocated by alloc_permu() and reset each pointer to
 * NULL. A repeated call then frees NULL (a no-op) instead of double-freeing,
 * and any later use sees NULL rather than a dangling pointer. */
void free_permu(void)
{
	free(tmp_ptr);
	tmp_ptr = NULL;
	free(tmp_score);
	tmp_score = NULL;

	free(permu_score1);
	permu_score1 = NULL;
	free(permu_score2);
	permu_score2 = NULL;

	free(pval1);
	pval1 = NULL;

	free(obs_pval2);
	obs_pval2 = NULL;
	free(obs_score2);
	obs_score2 = NULL;

	free(permu_status);
	permu_status = NULL;

	free(min_permu_pval2);
	min_permu_pval2 = NULL;
}

/* Phase 1 of the permutation test for marker m on one side.
 *
 * Draws npermu random 'a'/'c' labellings of the nchromo chromosomes and
 * scores each with calculate_scores() (presumably declared in common.h),
 * which fills NSUBSCORES sub-scores per labelling. Every sub-score, observed
 * (obs_score1) and permuted, is turned into an empirical p-value by ranking
 * it against the others; tied values share one p-value. The NSUBSCORES
 * p-values of each labelling are reduced to one value by calc_permu_score2()
 * and stored for permu_phase2():
 *   obs_score2[side * nmarkers + m]
 *   permu_score2[(2 * i + side) * nmarkers + m]   for permutation i
 *
 * srand(seed) is called on entry, so every call starts from the same
 * pseudo-random sequence.
 *
 * side: 0 or 1, selecting the half of the 2*nmarkers score layout.
 * m:    marker index, expected in [0, nmarkers).
 */
void permu_phase1(int side, long m)
{
	long i, j, k;

	phase = 1;
	srand(seed);

	/* Each chromosome is labelled 'a' with probability
	   (labels 'a' left) / (labels left), i.e. the na 'a' and nc 'c' labels
	   are dealt without replacement.
	   NOTE(review): assumes nchromo == na + nc; otherwise a or c can go
	   negative and the ratio can become 0/0. */
	for (i = 0; i < npermu; i++)
	{
		long a = na, c = nc;

		for (j = 0; j < nchromo; j++)
		{
			if (rand() / (RAND_MAX + 1.0) < a / (double)(a + c)) { permu_status[j] = 'a'; a--; }
			else { permu_status[j] = 'c'; c--; }
		}

		calculate_scores(permu_status, permu_score1 + i * NSUBSCORES);
	}

	for (j = 0; j < NSUBSCORES; j++)
	{
		/* Slots 0..npermu-1 hold the permuted values of sub-score j and slot
		   npermu the observed one; tmp_ptr is then sorted ascending by score. */
		for (i = 0; i < npermu; i++) { tmp_ptr[i] = i; tmp_score[i] = permu_score1[i*NSUBSCORES+j]; }
		tmp_ptr[npermu] = npermu; tmp_score[npermu] = obs_score1[j];

		qsort(tmp_ptr, npermu + 1, sizeof *tmp_ptr, cmp_sort);

		/* Walk runs of tied scores: sorted positions i..k share one value and
		   hence one p-value. pvalue() maps low ranks to low values, so it is
		   complemented here to give high scores small p-values. */
		for (i = 0; i <= npermu; i = k + 1)
		{
			long r;
			double p, sc = tmp_score[tmp_ptr[i]];

			k = i;
			while (k < npermu && tmp_score[tmp_ptr[k+1]] == sc) k++;

			p = 1.0 - pvalue(i - 1, k + 1);
			for (r = i; r <= k; r++)
			{
				if (tmp_ptr[r] == npermu) obs_pval1[j] = p;
				else pval1[tmp_ptr[r]*NSUBSCORES+j] = p;
			}
		}
	}

	/* Collapse the NSUBSCORES p-values of each labelling into one score. */
	obs_score2[side*nmarkers+m] = calc_permu_score2(obs_pval1);
	for (i = 0; i < npermu; i++)
		permu_score2[(2*i+side)*nmarkers+m] = calc_permu_score2(pval1+i*NSUBSCORES);
}

/* Phase 2: p-value for position m from the phase-1 scores of both sides.
 *
 * The observed statistic is calc_permu_score3(obs_score2, m). Each
 * permutation i contributes calc_permu_score3() of its own 2*nmarkers block
 * of permu_score2. All npermu + 1 values are ranked ascending and converted
 * with pvalue(), so smaller statistics get smaller p-values; tied values
 * share one p-value.
 *
 * Writes obs_pval2[m], and lowers min_permu_pval2[i] whenever permutation i
 * receives a smaller p-value than any it has received so far.
 *
 * m: position index, expected in [0, nmarkers] (obs_pval2 holds
 *    nmarkers + 1 entries).
 */
void permu_phase2(long m)
{
	double score = calc_permu_score3(obs_score2, m);
	long i, k;

	phase = 2;
	for (i = 0; i < npermu; i++)
	{
		tmp_ptr[i] = i;
		tmp_score[i] = calc_permu_score3(permu_score2 + 2 * i * nmarkers, m);
	}
	/* The observed statistic occupies the extra slot at index npermu. */
	tmp_ptr[npermu] = npermu; tmp_score[npermu] = score;

	qsort(tmp_ptr, npermu + 1, sizeof *tmp_ptr, cmp_sort);

	/* Walk runs of tied scores: sorted positions i..k share one p-value. */
	for (i = 0; i <= npermu; i = k + 1)
	{
		long r;
		double p, sc = tmp_score[tmp_ptr[i]];

		k = i;
		while (k < npermu && tmp_score[tmp_ptr[k+1]] == sc) k++;
		p = pvalue(i - 1, k + 1);

		for (r = i; r <= k; r++)
		{
			if (tmp_ptr[r] == npermu) obs_pval2[m] = p;
			else if (p < min_permu_pval2[tmp_ptr[r]]) min_permu_pval2[tmp_ptr[r]] = p;
		}
	}
}

/* Permutation-adjusted p-value for p: the fraction of the npermu
 * permutations whose smallest phase-2 p-value (min_permu_pval2) is at most p.
 * With npermu == 0 the result is 0.0 / 0 (NaN). */
double total_p(double p)
{
	long hits = 0;
	long idx = 0;

	while (idx < npermu)
	{
		hits += (min_permu_pval2[idx] <= p);
		idx++;
	}

	return (double)hits / npermu;
}

/* qsort() comparator for arrays of long indices into tmp_score: orders the
 * indices by ascending tmp_score. Returns 0 for equal scores, and also for
 * any pair involving NaN since neither comparison then holds. */
int cmp_sort(const void *a, const void *b)
{
	const double lhs = tmp_score[*(const long *)a];
	const double rhs = tmp_score[*(const long *)b];

	return (lhs > rhs) - (lhs < rhs);
}

/* Reduce the NSUBSCORES sub-score p-values in p to a single value.
 * fixed_k == 0: return the smallest of p[0..NSUBSCORES-1].
 * otherwise:    return p[fixed_k - 1] (fixed_k is 1-based).
 * NOTE(review): fixed_k is not range-checked; values outside 1..NSUBSCORES
 * index outside p. */
double calc_permu_score2(double *p)
{
	double best;
	long idx;

	if (fixed_k != 0)
		return p[fixed_k - 1];

	best = p[0];
	for (idx = 1; idx < NSUBSCORES; idx++)
		best = (p[idx] < best) ? p[idx] : best;
	return best;
}

/* Combine the two sides' phase-1 scores around position m.
 * p holds two halves of nmarkers values: side 0 at p[0..nmarkers-1] and
 * side 1 at p[nmarkers..2*nmarkers-1]. Returns p[m-1] * p[nmarkers+m]; a
 * factor that does not exist (m == 0 for side 0, m == nmarkers for side 1)
 * is replaced by the constant 0.5. */
double calc_permu_score3(double *p, long m)
{
	const double left = (m > 0) ? p[m - 1] : 0.5;
	const double right = (m < nmarkers) ? p[nmarkers + m] : 0.5;

	return left * right;
}

/* Map a position in the sorted scratch arrays to an empirical p-value.
 *
 * Precondition: tmp_ptr[0..npermu] has been sorted with cmp_sort, so
 * tmp_score[tmp_ptr[0..npermu]] is ascending. The score being evaluated, y
 * (the observed score or a permuted one), occupies sorted positions
 * max_lt+1 .. min_gt-1:
 *   max_lt = index of the highest score < y, or -1 if there is none;
 *   min_gt = index of the lowest score > y, or npermu+1 if there is none.
 * So min_gt - max_lt - 1 is the number of scores tied at y.
 *
 * With minp = (max_lt+1)/(npermu+1) and maxp = min_gt/(npermu+1):
 *   - several scores tied at y: returns the midpoint (minp + maxp) / 2;
 *   - y is the lowest score: phase 1 returns the midpoint; phase 2
 *     interpolates using 0 as the lower neighbour;
 *   - y is the highest score: phase 2 returns the midpoint; phase 1 returns
 *     1 - (second-highest / highest) / (npermu+1), which equals
 *     maxp - (maxp - minp) * (second-highest / highest);
 *   - otherwise: y's relative position between its neighbours, mapped
 *     linearly onto [minp, maxp].
 * Reads the file-level tmp_ptr, tmp_score, npermu and phase.
 *
 * NOTE(review): with npermu == 0 the phase-2 lowest-score branch reads
 * tmp_ptr[1], past the end of the array; the phase-1 highest-score branch
 * divides by the highest score, which is not finite if that score is 0.
 */
double pvalue(long max_lt, long min_gt)
{
	double minp = (max_lt + 1.0) / (npermu + 1.0);
	double maxp = min_gt / (npermu + 1.0);
	double low, high, obs;
	
	/* More than one score tied at y: they cannot be ordered, use the midpoint. */
	if (min_gt - max_lt > 2) return (minp + maxp) * 0.5;
	if (max_lt == -1)
	{
		/* y is the lowest score. */
		if (phase == 1) return (minp + maxp) * 0.5;
		low = 0;
	}
	else if (min_gt == npermu + 1)
	{
		/* y is the highest score; phase 1 scales by second-highest / highest. */
		if (phase == 2) return (minp + maxp) * 0.5;
		return 1.0 - 1.0/(npermu + 1.0) *
				(tmp_score[tmp_ptr[npermu-1]] / tmp_score[tmp_ptr[npermu]]);
	}
	else
	{
		low = tmp_score[tmp_ptr[max_lt]];
	}
	high = tmp_score[tmp_ptr[min_gt]];
	obs = tmp_score[tmp_ptr[max_lt+1]];
	
	/* Fraction of the way from low to high, mapped onto [minp, maxp]. */
	return minp + (obs - low) / (high - low) * (maxp - minp);
}

