#include <float.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "common.h"

/* Forward declarations for the scoring routines defined later in this file. */
void rec_calc_scores(struct Node *, char *, double *);
double new_score(long, long);

/*
 * Globals defined in another translation unit.
 * root     - the tree scored by calculate_scores().
 * na,
 * nchromo  - read by new_score() as the numerator/denominator of the base
 *            proportion pa = na / nchromo (presumably the total number of
 *            'a' chromosomes and of all chromosomes -- confirm at the
 *            definition site).
 * nc, current_location - declared for use elsewhere; current_location is
 *            not referenced in the code visible in this section.
 */
extern struct Node *root;
extern long na, nc, nchromo;
extern double current_location;

/*
 * Scratch result buffer shared by the recursion: every rec_calc_scores()
 * call leaves the MAX_SUBTREES-entry score table of the node it processed
 * here, and the next call overwrites it, so a caller must copy it out before
 * recursing again.  Because this is file-static mutable state, the scoring
 * routines are not reentrant or thread-safe.
 */
static double tmp_scoretable[MAX_SUBTREES];

/*
 * Entry point of the scoring pass: scores the whole tree hanging off the
 * global `root`.  `status` is forwarded unchanged to rec_calc_scores();
 * when `table` is non-NULL it receives the root's MAX_SUBTREES-entry
 * score table.
 */
void calculate_scores(char *status, double *table)
{
	struct Node *const tree = root;

	rec_calc_scores(tree, status, table);
}

void rec_calc_scores(struct Node *n, char *status, double *table)
{
	struct ListNode *ln = n->first_child;
	double max_scoretable[MAX_SUBTREES], scoretable[MAX_SUBTREES];
	int i;

	if (ln) /* internal */
	{
		long a = 0, c = 0;
		double new_component;

		rec_calc_scores(ln->ptr, status, NULL);
		a += ln->ptr->na; c += ln->ptr->nc;
		for (i = 0; i < MAX_SUBTREES; i++)
			scoretable[i] = max_scoretable[i] = tmp_scoretable[i];

		ln = ln->next;
		if (ln == NULL) { printf("calc_scores: malformed tree\n"); exit(-1); }

		while (ln)
		{
			int j;
			double x;

			rec_calc_scores(ln->ptr, status, NULL);
			a += ln->ptr->na; c += ln->ptr->nc;
			
			for (i = 0; i < MAX_SUBTREES - 1; i++)
				for (j = 0; j < MAX_SUBTREES - i - 1; j++)
					if (tmp_scoretable[i] > -FLT_MAX && max_scoretable[j] > -FLT_MAX)
						if ((x = tmp_scoretable[i] + max_scoretable[j]) > scoretable[i + j + 1])
							scoretable[i + j + 1] = x;
			
			for (i = 0; i < MAX_SUBTREES; i++)
			{
				if (tmp_scoretable[i] > max_scoretable[i]) max_scoretable[i] = tmp_scoretable[i];
				if (tmp_scoretable[i] > scoretable[i]) scoretable[i] = tmp_scoretable[i];
			}
			
			ln = ln->next;
		}

		new_component = new_score(a, c);
		if (new_component > scoretable[0]) scoretable[0] = new_component;

		for (i = 0; i < MAX_SUBTREES; i++) tmp_scoretable[i] = scoretable[i];
		n->na = a; n->nc = c;
	}
	else /* leaf */
	{
		for (i = 0; i < MAX_SUBTREES; i++) tmp_scoretable[i] = -FLT_MAX;
		if (status[n->chromo] == 'a') { n->na = 1; n->nc = 0; }
		else { n->na = 0; n->nc = 1; }
	}
	if (table) for (i = 0; i < MAX_SUBTREES; i++) table[i] = scoretable[i];
}


/*
 * Score a candidate component that contains `a` leaves with status 'a' and
 * `c` other leaves.
 *
 * Returns the binomial z-score (a - E[a]) / sqrt(Var[a]).  Each of the
 * a + c leaves is treated as 'a' with probability pa = na / nchromo.  na and
 * nchromo are file-external globals, presumably the overall 'a' count and
 * the total leaf count (confirm at their definition).
 *
 * When the variance is not positive the z-score is undefined; this happens
 * when nchromo == 0, na == 0, na == nchromo, or a + c == 0.  In that case
 * -FLT_MAX, the "no score" sentinel used by the score tables, is returned
 * instead of NaN or +/-inf.
 */
double new_score(long a, long c)
{
	double p = (a + c) / (double)nchromo, pa = na / (double)nchromo;
	double ea = na * p;                    /* expected number of 'a' leaves */
	double var = (a + c) * pa * (1 - pa);  /* binomial variance */

	/* !(var > 0) also catches NaN from nchromo == 0. */
	if (!(var > 0))
		return -FLT_MAX;

	return (a - ea) / sqrt(var);
}
