#include "Chain.h"
#include "ProposeAlpha.h"
#include "MRateCont.h" 
#include "MRateDiscrete.h" 
#include "DrateDiscrete.h"
#include "McRateUtils.h"

#include "talRandom.h"
#include "someUtil.h"
#include "nj.h"
#include "likelihoodComputation.h"
#include "gammaDistribution.h"
using namespace likelihoodComputation;

#define ALL 0


//Build a chain whose starting tree is t1. If t1 has no branch lengths,
//random ones are drawn by initTree().
Chain::Chain(const tree& t1,
			 const sequenceContainer1G & sc,
			 const stochasticProcess* pSp,
			 const string baseStr,
			 const McRateOptions& options)
	: m_seqContainer(sc),
	  m_kThining(options.getThinning())
{
	initTree(t1);                 // m_tree must be set before init() evaluates the starting likelihood
	init(pSp, baseStr, options);
}


//Build a chain whose starting tree is random: setRandomTree() feeds a random
//distance matrix to neighbour-joining.
Chain::Chain(const sequenceContainer1G & sc,
			 const stochasticProcess* pSp,
			 const string baseStr,
			 const McRateOptions& options)
	: m_seqContainer(sc),
	  m_kThining(options.getThinning())
{
	setRandomTree();              // m_tree must be set before init() evaluates the starting likelihood
	init(pSp, baseStr, options);
}



//~Chain(): close the three output files opened by init() and release everything
//init() allocated: the prior distribution, the private copy of the stochastic
//process, and the per-position rate objects.
//NOTE(review): the class owns raw pointers; unless the header disables copying,
//copying a Chain would lead to a double delete here - confirm.
Chain::~Chain()
{
	m_resFile.close();
	m_logFile.close();
	m_treesFile.close(); // opened in init() together with the two files above

	// delete on a NULL pointer is a no-op, so no checks are needed
	delete m_pPriorDist;
	m_pPriorDist = NULL;
	delete m_pSp;
	m_pSp = NULL;

	// Bound the loops by the vectors themselves (not seqLen()) so a size
	// mismatch can never index past the end.
	for (size_t i = 0; i < m_currentRates.size(); ++i)
		delete m_currentRates[i];
	for (size_t i = 0; i < m_MetaRates.size(); ++i)
		delete m_MetaRates[i];
}


void Chain::init(const stochasticProcess* pSp,
				 const string baseStr, 
				 const McRateOptions& options)
{
	m_sampledSteps = 0;
	m_avgAlpha = 0;
	m_bChanged = true; //so that rates will be inferred after first step
	m_Counters.resize(5);  
	m_expRates.resize(seqLen());
	m_currentExpRates.resize(seqLen());
	m_expBranches.resize(m_tree.iNodes());
	m_currentRates.resize(seqLen());
	m_MetaRates.resize(seqLen());
	for (int i = 0; i <seqLen(); ++i)
	{
		m_currentRates[i] = new DrateDiscrete();
		m_MetaRates[i] = new MRateDiscrete();
	}

	m_pSp = new stochasticProcess(*pSp); //make a local copy for each chain  
	
	//init a random alpha
	MDOUBLE randomAlpha = talRandom::giveRandomNumberBetweenZeroAndEntry(2.0);
	//@@@@@@@@@@ to start with alpaha = 1 for debug : remark next line
	(static_cast<gammaDistribution*>(m_pSp->distr()))->setAlpha(randomAlpha);
	m_pPriorDist =  new gammaDistribution(randomAlpha, options.getCategoriesNum());
	m_curLikelihood = likelihoodComputation::getTreeLikelihoodAllPosAlphTheSame(m_tree, m_seqContainer, *m_pSp);
	
	//init output files
	string resStr = baseStr + string("_res.txt");
	string logStr = baseStr + string("_Log.txt");
	string treesStr = baseStr + string("_trees.txt");
	m_resFile.open(resStr.c_str());
	m_logFile.open(logStr.c_str());
	m_treesFile.open(treesStr.c_str());
	m_resFile <<"step"<<"\t"<< "likelihood" <<"\t" << "alpha" << endl;
	initProposalProbs();	
}


//initProposalProbs(): build the table from which getStepType() draws the next
//proposal type. Each proposal appears in m_proposalProb as many times as its
//integer weight, so a uniformly drawn entry selects a proposal with probability
//weight / kTableSize. Current weights (out of 10): LOCAL_BRANCH 0 (disabled),
//GLOBAL_BRANCH 5, ALPHA 2, NNI 3. The weights must sum to kTableSize.
void Chain::initProposalProbs()
{
	const int kTableSize = 10;
	typedef pair<ProposeName, int> proposalP;
	vector<proposalP> probVec;
	probVec.push_back(proposalP(LOCAL_BRANCH, 0));
	probVec.push_back(proposalP(GLOBAL_BRANCH, 5));
	probVec.push_back(proposalP(ALPHA, 2));
	probVec.push_back(proposalP(NNI, 3));

	// Append rather than write into a pre-sized table, so weights that sum to more
	// than kTableSize can never write past the end; the total is validated below.
	m_proposalProb.clear();
	m_proposalProb.reserve(kTableSize);
	for (size_t i = 0; i < probVec.size(); ++i)
	{
		for (int j = 0; j < probVec[i].second; ++j)
		{
			m_proposalProb.push_back(probVec[i].first);
		}
	}
	if (static_cast<int>(m_proposalProb.size()) != kTableSize)
		errorMsg::reportError("total proposals probability is not 1.0 in initProposalProbs()");
}


//makeBurnInSteps(): 
//stepNum= number of burnin steps to make. 
//calibrateCycle= after how many steps to make parameter calibrations. 
//if bTopology = false then make no topology changes
void Chain::makeBurnInSteps(const int stepsNum, const int calibrateCycle, const bool bTopology/*=false*/)
{
	for (int i = 0; i < stepsNum ; ++i) 
	{
		makeSingleStep(bTopology);
		if (getTotalSteps() % calibrateCycle == 0) // Calibration is needed
		{
			calibrate();
		}
		if (getTotalSteps() % m_kThining == 0)
		{
			printCurrentState();
		}
	}
	m_logFile <<"------------------------------------------" <<endl <<endl;
	m_logFile <<"alpha acc= " << m_Counters[ALPHA].getTotalAccRate() <<endl;
	m_logFile <<"local acc= " << m_Counters[LOCAL_BRANCH].getTotalAccRate() <<endl;
	m_logFile <<"global acc= " << m_Counters[GLOBAL_BRANCH].getTotalAccRate() <<endl;
	m_logFile <<"NNI acc= " << m_Counters[NNI].getTotalAccRate() <<endl;


}

//makeInferenceSteps(): 
//stepNum= number of inference steps to make. 
//bScaleRates = if true then scale the rates so their the average is 1.0 
//if bTopology = false then make no topology changes
//if bInferRates = true then infer rate4site
void Chain::makeInferenceSteps(const int stepsNum, bool bInferRates, bool bScaleRates, bool bTopology/*=false*/)
{
  	for (int i = 0; i < stepsNum ; ++i) 
	{
		makeSingleStep(bTopology);
		if (getTotalSteps() % m_kThining == 0)
		{
			++m_sampledSteps;
			if (bInferRates == true)
			{
				if (m_bChanged == true)
				{//if tree has changed since last time - then find new rates 
					findRates(bScaleRates);
				}

				//add cur rates to global rate vector
				cout << "alpha = " <<(static_cast<gammaDistribution*>(m_pSp->distr()))->getAlpha()<<endl;
				for (int j = 0; j < seqLen(); ++j)
				{
					m_MetaRates[j]->addDRate(m_currentRates[j]);
				}

				//calc whole chain average rate4site 
				for (int ri = 0; ri <seqLen(); ++ri)
				{
					MDOUBLE curR = m_currentExpRates[ri];
					MDOUBLE averageR = (m_expRates[ri] * (getWeight() - 1) + curR) / getWeight();
					m_expRates[ri] = averageR;
				}
			}

			//calc whole chain average alpha
			MDOUBLE curAlpha  = (static_cast<gammaDistribution*>(m_pSp->distr()))->getAlpha();
			m_avgAlpha = (m_avgAlpha * (m_sampledSteps - 1) + curAlpha) / m_sampledSteps;

			//add cur branch lengths to global vector 
			if (bTopology == false)
			{
				vector<tree::nodeP> nodesVec;
				m_tree.getAllNodes(nodesVec, m_tree.iRoot());
				for (int nodei = 0; nodei <nodesVec.size(); ++nodei)
				{
					if (!nodesVec[nodei]->isRoot())
					{
						int stepNum = m_sampledSteps;
						MDOUBLE curD = nodesVec[nodei]->dis2father();
						m_expBranches[nodei] = (m_expBranches[nodei] * (stepNum - 1) + curD) / stepNum;
					}
				}
			}

			printCurrentState();
			m_bChanged = false;

		}
	}
}

//initTree(): initialize the tree. 
//if inTree has no branchLengths then initialize random lengths  
void Chain::initTree(const tree& inTree)
{
	m_tree = inTree;
	if (m_tree.WithBranchLength() == false)
	{
		vector<tree::nodeP> nodesVec;
		m_tree.getAllNodes(nodesVec, m_tree.iRoot());
		for (int i=0; i<nodesVec.size(); ++i)
		{
			if (!nodesVec[i]->isRoot())
			{
				MDOUBLE length = talRandom::giveRandomNumberBetweenZeroAndEntry(2.0);
				if (length == 0.0)
					length = ZERO_DIST;

				nodesVec[i]->setDisToFather(length);
			}
		}
	}
	m_tree.output("initTree.txt");
}

//setRandomTree: initalize a random tree.
//create a random distance matrix and then send to Neighbour-Joining algorithm
void Chain::setRandomTree()
{
	VVdouble disTab;
	vector<string> vNames;

	disTab.resize(m_seqContainer.numberOfSequences());
	for (int s = 0; s < m_seqContainer.numberOfSequences(); ++s)
		disTab[s].resize(m_seqContainer.numberOfSequences(), 0.0);

	for (int i = 0; i < m_seqContainer.numberOfSequences(); ++i)
	{
		for (int j = i+1; j < m_seqContainer.numberOfSequences(); ++j)
		{
			MDOUBLE tmpDist = talRandom::giveRandomNumberBetweenZeroAndEntry(2.0);
			disTab[i][j] = tmpDist;
			disTab[j][i] = tmpDist;
		}
		vNames.push_back(m_seqContainer[i].name());
	}

	NJalg nj1;
	m_tree = nj1.computeNJtree(disTab, vNames);
}

//makeSingleStep():
//1. propose a new state  
//2. accept or reject
//if bTopology=false then make no topology changes
bool Chain::makeSingleStep(const bool bTopology)
{ 
	//@@@debug
	ofstream outTree("chainTree.txt");
	m_tree.output(outTree);
	outTree.close();

	Proposal::ParamProposal param_proposal = Proposal::MULTIPLY;
	bool bIsAccepted; 
	ProposeName prop_type = getStepType(bTopology);
	MDOUBLE hastingsRatio;
	MDOUBLE priorRatio = 1.0;
	MDOUBLE curAlpha = (static_cast<gammaDistribution*>(m_pSp->distr()))->getAlpha(); //get the alpha of the chain
	
	switch (prop_type)
	{	
	case ALPHA:
	{//alpha change
		MDOUBLE newAlpha = curAlpha;
		hastingsRatio = m_propAlpha.proposeNewState(m_tree, newAlpha, param_proposal);
		(static_cast<gammaDistribution*>(m_pSp->distr()))->setAlpha(newAlpha);
	    bIsAccepted = IsNewStateAccepted(m_tree, hastingsRatio, priorRatio, prop_type);
		if (bIsAccepted) 
		{
			static_cast<gammaDistribution*>(m_pPriorDist)->setAlpha(newAlpha);
		}
		else // returning the original alpha 
		{
			(static_cast<gammaDistribution*>(m_pSp->distr()))->setAlpha(curAlpha);
		}
		break;
	}
	case LOCAL_BRANCH:
	{
		tree propTree = m_tree;	
		hastingsRatio = m_propLocal.proposeNewState(propTree, curAlpha, param_proposal);

		bIsAccepted = IsNewStateAccepted(propTree, hastingsRatio, priorRatio, prop_type);
		if (bIsAccepted) 
		{
			m_tree = propTree;		
		}
		break;
	}
	case GLOBAL_BRANCH:
	{
		tree propTree = m_tree;
		hastingsRatio = m_propGlobal.proposeNewState(propTree, curAlpha, param_proposal);
		bIsAccepted = IsNewStateAccepted(propTree, hastingsRatio, priorRatio, prop_type);
		if (bIsAccepted) 
		{
			m_tree = propTree;		
		}
		break;
	}
	case NNI:
	{
		tree propTree = m_tree;
		hastingsRatio = m_propNni.proposeNewState(propTree, curAlpha, param_proposal);
		bIsAccepted = IsNewStateAccepted(propTree, hastingsRatio, priorRatio, prop_type);
		if (bIsAccepted) 
		{
			m_tree = propTree;		
		}
		break;
	}
	default:
		errorMsg::reportError("specified proposed move is not implemented. in Chain::makeSingleStep()");
	}
	
	if (bIsAccepted == true)
	{
		m_bChanged = true;
	}

	cout << ".";
	return bIsAccepted;
}



//decide whether to accept or reject the new state. 
//in both cases - update the counter of the proposed move type 
bool Chain::IsNewStateAccepted(const tree& propTree, const MDOUBLE hastingsRatio, const MDOUBLE priorRatio, const ProposeName proName)
{
	bool bIsAccepted;
	//@@@@@@@CAN MAKE LIKELIHOOD COMPUTATON FASTER IF ONLY ONE BRANCH CHANGED
	MDOUBLE propL = likelihoodComputation::getTreeLikelihoodAllPosAlphTheSame(propTree, m_seqContainer, *m_pSp);
	MDOUBLE LRatio = exp(propL - m_curLikelihood);
	MDOUBLE moveRatio = LRatio * priorRatio * hastingsRatio;
	MDOUBLE u = talRandom::giveRandomNumberBetweenZeroAndEntry(1.0);
	if (u < moveRatio) 
	{//proposed state is accepted
		m_curLikelihood = propL;
		bIsAccepted = true;
	}
	else 
	{
		bIsAccepted = false;
	}
	m_Counters[proName].update(bIsAccepted);
	m_Counters[ALL].update(bIsAccepted);
	return bIsAccepted;
}

//calibrate(): for every proposal type tried more than min_steps_for_calibration
//times since its counter was last zeroed, pass its recent acceptance rate to the
//proposal's updateParameters() and zero its counter. For the alpha, local and
//global branch proposals, the lambda before and after the update is logged.
//The overall (ALL) counter is always zeroed, and the total step number is logged.
void Chain::calibrate()
{
	const int min_steps_for_calibration = 50;
	if(m_Counters[ALPHA].getLastStepNum() > min_steps_for_calibration)
	{
		m_logFile << "alpha acc= " << m_Counters[ALPHA].getLastAccRate()<<" lambda before = "<<m_propAlpha.m_lambda<<" ";
		m_propAlpha.updateParameters(m_Counters[ALPHA].getLastAccRate());
		m_logFile << " lambda after = "<<m_propAlpha.m_lambda<<endl;
		m_Counters[ALPHA].zeroCounters();
	}
	if(m_Counters[LOCAL_BRANCH].getLastStepNum() > min_steps_for_calibration)
	{
		m_logFile << "local acc= " << m_Counters[LOCAL_BRANCH].getLastAccRate() <<" lambda before = "<<m_propLocal.m_lambda<<" ";
		m_propLocal.updateParameters(m_Counters[LOCAL_BRANCH].getLastAccRate());
		m_logFile << " lambda after = "<<m_propLocal.m_lambda<<endl;
		m_Counters[LOCAL_BRANCH].zeroCounters();
	}
	if(m_Counters[GLOBAL_BRANCH].getLastStepNum() > min_steps_for_calibration)
	{
		m_logFile << "global acc= " << m_Counters[GLOBAL_BRANCH].getLastAccRate() << " lambda before= "<<m_propGlobal.m_lambda<<" ";
		m_propGlobal.updateParameters(m_Counters[GLOBAL_BRANCH].getLastAccRate());
		m_logFile << " lambda after = "<<m_propGlobal.m_lambda<<endl;
		m_Counters[GLOBAL_BRANCH].zeroCounters();
	}
	if(m_Counters[NNI].getLastStepNum() > min_steps_for_calibration)
	{
		// end the line like the other proposals do, so the "step= " entry below
		// is not appended to the NNI acceptance rate
		m_logFile << "NNI acc= " << m_Counters[NNI].getLastAccRate() << endl;
		m_propNni.updateParameters(m_Counters[NNI].getLastAccRate());
		m_Counters[NNI].zeroCounters();
	}

	m_Counters[ALL].zeroCounters();
	m_logFile << "step= "<<getTotalSteps()<<endl;
}


//getStepType(): decide which step to make.
//Draws a uniform entry of the weighted proposal table m_proposalProb (built by
//initProposalProbs()).
//If bTopology=false then no topology changes are made: NNI entries (the nearest-
//neighbour-interchange topology move) are excluded, and the draw is made among the
//remaining entries, which keeps their relative weights.
Chain::ProposeName Chain::getStepType(const bool bTopology) const
{
	if (!bTopology)
	{
		vector<ProposeName> allowed;
		allowed.reserve(m_proposalProb.size());
		for (size_t i = 0; i < m_proposalProb.size(); ++i)
		{
			if (m_proposalProb[i] != NNI)
				allowed.push_back(m_proposalProb[i]);
		}
		if (!allowed.empty())
		{
			int r = talRandom::giveIntRandomNumberBetweenZeroAndEntry(static_cast<int>(allowed.size()));
			return allowed[r];
		}
		errorMsg::reportError("no non-topological proposal is available in Chain::getStepType()");
		// if reportError() returns, fall back to an unrestricted draw
	}
	int r = talRandom::giveIntRandomNumberBetweenZeroAndEntry(m_proposalProb.size());
	return m_proposalProb[r];
}

void Chain::printCurrentState()
{
	MDOUBLE curAlpha = (static_cast<gammaDistribution*>(m_pSp->distr()))->getAlpha();
	m_resFile << getTotalSteps() << "\t" << m_curLikelihood << "\t" << curAlpha << endl;
	//m_tree.output(m_treesFile);
}



//find the rate for each position.
//the returned value is a vector of the rate expectation of each position  
void Chain::findRates(const bool bScaleRates)
{
	computeRate4site();
	for (int i = 0; i < seqLen(); ++i)
	{
		m_currentExpRates[i] = m_currentRates[i]->getExpectation();
	}
		
	if (bScaleRates == true)
	{
		MDOUBLE scaleFactor = McRateUtils::scaleVec(m_currentExpRates, 1.0);
		for (int i = 0; i < seqLen(); ++i)
		{
			m_currentRates[i]->scale(scaleFactor);
		}
	}
}





/////////////
//computeRate4site: calculates the posterior rates:
//the rate for each position is the expectation of that rate 
//for each position: 
//E(R|Data) = sigma{ (P(Data|Ri)*P(Ri)*Ri) / sigma{(P(Data|Ri)*P(Ri))}} = 
//1/sigma{(P(Data|Ri)*P(Ri))} * sigma{ (P(Data|Ri)*P(Ri)*Ri) }
//the Ri's are the mean of the priorDistribution categories
void Chain::computeRate4site()
{
	//calc posterior distribution using the categories of the prior
	for (int pos = 0 ; pos < m_seqContainer.seqLen(); ++pos)
	{
		MDOUBLE totalProb = 0.0;
		
		//calc probability of each Ri
		MDOUBLE rate, prob_Ri, prior_Ri, LRi;
		Vdouble ratesProb(m_pPriorDist->categories());
		totalProb; 
		for (int r = 0; r < m_pPriorDist->categories(); ++r)
		{
			rate = m_pPriorDist->rates(r);
			if (rate == 0.0)
				rate = ZERO_DIST;
			//////////////////
			//@@@@@debug
			if (rate < 0.0)
			{
				cout << "negative rate!!!!  r = " <<rate <<endl;
				cout << "alpha = " << (static_cast<gammaDistribution*>(m_pSp->distr()))->getAlpha() <<endl;
			}
			/////////////////
			prior_Ri = m_pPriorDist->ratesProb(r);
			LRi = likelihoodComputation::getLofPos(pos, m_tree, m_seqContainer, *m_pSp, rate);
			prob_Ri = LRi * prior_Ri;
			ratesProb[r] = prob_Ri;
			totalProb += prob_Ri;
		}
			
		MDOUBLE sumRF = 0.0;
		for (int ri = 0; ri < m_pPriorDist->categories(); ++ri)
		{
			ratesProb[ri] /= totalProb;
		}
		m_currentRates[pos]->setProb(ratesProb, static_cast<gammaDistribution*>(m_pPriorDist)); 
	}
}



void Chain::printRates(ofstream& outFile, const bool bDrate)
{
	printTime(outFile);
	
	outFile <<"total number of steps: " << getTotalSteps()<< endl;
	outFile <<"number of inference steps: " << m_Counters[ALL].getLastStepNum()<< endl;
	outFile <<"infer rates every:  " << m_kThining<< " steps" <<endl;
	outFile <<"number of inference rates: " << m_MetaRates[0]->getWeight()<< endl;
	outFile <<"++++++++++++++++++++++++++++++++++++++++++++" <<endl;
	outFile <<"Acceptance Rates:   Alpha = " << m_Counters[ALPHA].getTotalAccRate();
	outFile <<"  LOCAL =  "<< m_Counters[LOCAL_BRANCH].getTotalAccRate(); 
	outFile	<<"  GLOBAL = " << m_Counters[GLOBAL_BRANCH].getTotalAccRate();
	outFile <<"  NNI = "<<m_Counters[NNI].getTotalAccRate() << endl;
	outFile <<"average alpha is: " << m_avgAlpha << endl;


	outFile<<"#POS"<<"\t"<<"SEQ"<<"\t"<<"META_EXP"<<"\t"<<"AVG_RATE"<<endl;
	const sequence* pSeq = &(m_seqContainer[0]);
	Vdouble expRates(m_seqContainer.seqLen());
	for (int i=0; i < m_MetaRates.size(); ++i) 
	{
		expRates[i] = m_MetaRates[i]->getExpectation();
		outFile<<i+1;
		outFile<< "\t"<< pSeq->getAlphabet()->fromInt((*pSeq)[i]);
		outFile<< "\t"<< expRates[i];
		outFile<< "\t"<< m_expRates[i];
		outFile<<endl;
	}

	MDOUBLE ave = computeAverage(expRates);
	MDOUBLE std = computeStd(expRates);
	if (((ave < 1e-9)) && (ave > (-(1e-9)))) 
		ave=0;
	if ((std > (1-(1e-9))) && (std < (1.0 + (1e-9)))) 
		std=1.0;
	outFile<<"# Average = "<<ave<<endl;
	outFile<<"# Standard Deviation = "<<std<<endl;


	if (bDrate == true)
	{
		outFile << endl <<"__________________________________________________________________"<<endl;
		outFile << "#POS" << "\t" << "SEQ" << "\t" << "EXP" << "\t" <<endl;
		const sequence* pSeq = &(m_seqContainer[0]);
		for (int i=0; i < m_MetaRates.size(); ++i) 
		{
			expRates[i] = m_currentRates[i]->getExpectation();
			outFile<<i+1;
			outFile<< "\t"<< pSeq->getAlphabet()->fromInt((*pSeq)[i]);
			outFile<< "\t"<< expRates[i];
			outFile<<endl;
		}

		MDOUBLE ave = computeAverage(expRates);
		MDOUBLE std = computeStd(expRates);
		if (((ave < 1e-9)) && (ave > (-(1e-9)))) 
			ave=0;
		if ((std > (1-(1e-9))) && (std < (1.0 + (1e-9)))) 
			std=1.0;
		outFile<<"# Average = "<<ave<<endl;
		outFile<<"# Standard Deviation = "<<std<<endl;
	}
}


//resetCounters(): call resetCounters() on every acceptance counter in m_Counters,
//including the overall ALL counter at index 0.
void Chain::resetCounters()
{
	const size_t counterNum = m_Counters.size();
	for (size_t c = 0; c != counterNum; ++c)
		m_Counters[c].resetCounters();
}

//getBayesTree(): return a copy of the MCMC tree.
//If bTopology=false, the branch length of every non-root node is replaced by its
//chain average m_expBranches, indexed in the order returned by getAllNodes()
//(the same order makeInferenceSteps() uses when accumulating the averages).
//bTopology=true is not implemented yet and is reported as an error.
tree Chain::getBayesTree(const bool bTopology) const
{
	if (bTopology)
	{
		errorMsg::reportError("not implemented yet: Chain::getTree(bTopology=true)");
		return m_tree;
	}

	tree bayesTree = m_tree;
	vector<tree::nodeP> nodes;
	bayesTree.getAllNodes(nodes, bayesTree.iRoot());
	for (size_t n = 0; n < nodes.size(); ++n)
	{
		if (nodes[n]->father() == NULL)
			continue; // the root has no branch to its father
		nodes[n]->setDisToFather(m_expBranches[n]);
	}
	return bayesTree;
}



