#include "MultiChain.h"
#include "McRateUtils.h"


#include "someUtil.h"

#include <cmath>


/*
//bInferRates - if true then infer rates during inference
//alternateSteps = the number of steps to alternate between chains during inference.
//calibrationCycle = the number of steps to update chain param during burn-in. 
//bScale = should we scale the rates at each step
multiChainParameters::multiChainParameters() 
{
	m_kThinning = 10;
	m_bScale = true;
	m_minSteps = 1000;
	m_stepsInLimit = 300; // one of the criteria for stopping the chains
	m_rLimit = 0.9;
	m_epsilonLimit = 0.01;// one of the criteria for stopping the chains
	m_calibrationCycle = 100;
	m_alternateSteps = 100;
	m_bInferRates=true;
	m_bDoCorrelationTest= false;
}
*/

//MultiChain: builds chainsNum chains that all share the same data, process and options.
//chainsNum = number of Chain objects to allocate (they are released in the destructor).
//sc, pSp = forwarded unchanged to every Chain constructor.
//baseFolderName = prefix handed to getChainsFileNames to produce one name per chain.
//options = stored in m_options; m_options is what every Chain receives.
MultiChain::MultiChain(const int chainsNum,
					   const sequenceContainer1G & sc,
					   const stochasticProcess* pSp, 
					   const string baseFolderName,
					   const McRateOptions& options) 
: m_options(options)
{
	//one name per chain, in chain order
	vector<string> namePerChain = getChainsFileNames(chainsNum, baseFolderName);
	m_pChainsVec.resize(chainsNum);

	int chainIdx = 0;
	while (chainIdx < chainsNum)
	{
		m_pChainsVec[chainIdx] = new Chain(sc, pSp, namePerChain[chainIdx], m_options);
		++chainIdx;
	}
}

//MultiChain: same as the constructor without a tree, except that t1 is also passed to every Chain.
//t1 = tree forwarded to each Chain constructor (presumably the starting tree - confirm against Chain).
//chainsNum = number of Chain objects to allocate (they are released in the destructor).
//sc, pSp = forwarded unchanged to every Chain constructor.
//baseFolderName = prefix handed to getChainsFileNames to produce one name per chain.
//options = stored in m_options; m_options is what every Chain receives.
MultiChain::MultiChain(const tree & t1,
					   const int chainsNum,
					   const sequenceContainer1G & sc,
					   const stochasticProcess* pSp, 
					   const string baseFolderName,
					   const McRateOptions& options)
: m_options(options)
{
	//one name per chain, in chain order
	vector<string> namePerChain = getChainsFileNames(chainsNum, baseFolderName);
	m_pChainsVec.resize(chainsNum);

	int chainIdx = 0;
	while (chainIdx < chainsNum)
	{
		m_pChainsVec[chainIdx] = new Chain(t1, sc, pSp, namePerChain[chainIdx], m_options);
		++chainIdx;
	}
}

//~MultiChain: releases every chain allocated by the constructors and clears the stored pointers.
MultiChain::~MultiChain()
{
	const int chainCount = numChains();
	int chainIdx = 0;
	while (chainIdx < chainCount)
	{
		delete m_pChainsVec[chainIdx];
		m_pChainsVec[chainIdx] = NULL;	//do not leave a dangling pointer in the vector
		++chainIdx;
	}
}


//runChains: performs a complete run - the burn-in phase followed by the inference phase.
//burnInTime = number of burn-in steps given to every chain.
//maxInferrenceTime = upper bound on the number of inference steps of every chain
//(runInference may stop earlier once IsChainsLimit reports a limit).
void MultiChain::runChains(const int burnInTime, const int maxInferrenceTime)
{
	//phase 1: burn-in
	runBurnIn(burnInTime);

	//phase 2: inference
	runInference(maxInferrenceTime);
}


//runBurnIn: runs the burn-in phase on each chain in turn and then resets that chain's counters.
//burnInTime = number of burn-in steps passed to every chain.
//The calibration cycle passed to the chains is taken from m_options.m_calibrationCycle.
void MultiChain::runBurnIn(const int burnInTime) 
{
	const int chainCount = numChains();
	for (int chainIdx = 0; chainIdx < chainCount; ++chainIdx)
	{
		m_pChainsVec[chainIdx]->makeBurnInSteps(burnInTime, m_options.m_calibrationCycle);
		m_pChainsVec[chainIdx]->resetCounters();
	}
}

//computeHowManyInferenceStepsToDo: 
//returns the size of the next batch of inference steps each chain should perform.
//stepsDone = number of inference steps already performed.
//maxInferrenceTime = total number of inference steps allowed.
//The batch is m_options.m_alternateSteps, shortened to the remaining steps
//when a full batch would pass maxInferrenceTime.
int MultiChain::computeHowManyInferenceStepsToDo(const int stepsDone,
												 const int maxInferrenceTime) 
{
	int nextBatch = m_options.m_alternateSteps;
	if (stepsDone + m_options.m_alternateSteps > maxInferrenceTime) 
	{	//a full batch would overshoot - do only what is left
		nextBatch = maxInferrenceTime - stepsDone;
	}
	return nextBatch;
}


void MultiChain::runInference(const int maxInferrenceTime)
{
	int stepsDone = 0;
	while (stepsDone < maxInferrenceTime) 
	{
		int infer_steps = computeHowManyInferenceStepsToDo(stepsDone,maxInferrenceTime);
		for (int i = 0; i < numChains(); ++i)
		{
			m_pChainsVec[i]->makeInferenceSteps(infer_steps, m_options.m_bInferRates, m_options.m_bScale);
		}
		stepsDone += infer_steps;
		if (IsChainsLimit(stepsDone)) //check if chains have converged to the limit
			break;
	}
}


//printResults: writes the rates of the first chain to outFile.
//outFile = the stream handed to Chain::printRates.
//NOTE: only chain 0 is printed; the other chains are not consulted here.
void MultiChain::printResults(ofstream& outFile)
{
	//the meaning of the 'false' flag is defined by Chain::printRates
	m_pChainsVec.front()->printRates(outFile, false);
}

//isEnoughInferenceSteps: returns true once the minimal number of inference steps
//(m_options.m_minSteps) has been reached, i.e. once convergence may be tested.
//stepsDone = not consulted; the value returned by getWeight() of chain 0 is used instead.
//Only chain 0 is inspected here.
bool MultiChain::isEnoughInferenceSteps(const int stepsDone)
{
	const int weightOfFirstChain = m_pChainsVec[0]->getWeight();
	const bool tooFewSteps = (weightOfFirstChain < m_options.m_minSteps);
	return !tooFewSteps;
}


//correlationTest: return true if the correlations between all of the rates chains is higher than m_mcp.m_rLimit
bool MultiChain::correlationTest() 
{
	for (int i = 0; i < numChains(); ++i)
	{
		for (int j = i +1; j < numChains(); ++j)
		{
			MDOUBLE correlation = McRateUtils::calcCorrelationCoefficient(m_pChainsVec[i]->getBayesRates(),m_pChainsVec[j]->getBayesRates());
			if (correlation < m_options.m_rLimit)
				return false;
		}
	}
	return true;
}

//IsChainsLimit: decides whether the chains have converged to a limit.
//stepsDone = number of inference steps performed so far.
//Returns true only when all of the following tests pass, evaluated in this order
//(a failing test stops the evaluation, so later tests are not run):
//1. isEnoughInferenceSteps - the minimal number of steps was performed.
//2. correlationTest - only when m_options.m_bDoCorrelationTest is set.
//3. checkIfRatesOfChainsConverged - the averaged rates reached a limit value.
//NOTE - this function assumes that all chains have performed the same number of steps!!
bool MultiChain::IsChainsLimit(const int stepsDone)
{
	return isEnoughInferenceSteps(stepsDone)
		&& (!m_options.m_bDoCorrelationTest || correlationTest())
		&& checkIfRatesOfChainsConverged(stepsDone);
}


bool MultiChain::checkIfRatesOfChainsConverged(const int stepsDone) 
{
	//if (|curRates[pos] - referenceRates[pos]| > epsilonLimit) for any pos then return false
	static Vdouble referenceRates(m_pChainsVec[0]->getBayesRates().size(), -1);
	static int collectedRefRatesStep = 0; //the step number in which referenceRates were "collected "
	Vdouble curRates = getAverageRatesAllChains();
	for (int pos = 0; pos < curRates.size(); ++pos)
	{
		MDOUBLE diff = fabs(curRates[pos] - referenceRates[pos]);
		if (diff > m_options.m_epsilonLimit)
		{	//if the 2 vectors are different by epsilonLimit in ONE position then return false
			collectedRefRatesStep = stepsDone;
			referenceRates = curRates;
			return false;
		}
	}
	if ((stepsDone - collectedRefRatesStep) < m_options.m_stepsInLimit)
	{ //the vectors (referenceRates and curRates) are ~the same but for not enough steps
		return false;
	}
	return true;
}


//getAverageAlphaOverAllChains: returns the mean of getAlpha() over all chains,
//i.e. the sum of the chains' alpha values divided by numChains().
MDOUBLE MultiChain::getAverageAlphaOverAllChains() const
{
	const int chainCount = numChains();
	MDOUBLE alphaTotal = 0;
	int chainIdx = 0;
	while (chainIdx < chainCount)
	{
		alphaTotal += m_pChainsVec[chainIdx]->getAlpha();
		++chainIdx;
	}
	return alphaTotal / chainCount;
}

//getAverageRatesAllChains: returns, for every position, the rate averaged over all chains.
//With a single chain its Bayes rates are returned as they are.
//With no chains an error is reported (and an empty vector is returned if reportError returns).
//The result has one entry per entry of chain 0's getBayesRates(); the other chains are read
//through Chain::operator[] (presumably the same per-position rate - confirm against Chain).
Vdouble MultiChain::getAverageRatesAllChains()
{
	if (numChains() < 1)
	{
		errorMsg::reportError("number of chains is smaller than 1 in MultiChain::getAverageRatesAllChains");
		return Vdouble(); //only reached if reportError returns; never index an empty chains vector
	}

	//start from the rates of the first chain and accumulate the others on top of them
	Vdouble ratesVec = m_pChainsVec[0]->getBayesRates();
	if (numChains() == 1)
	{
		return ratesVec;
	}

	const int numRates = static_cast<int>(ratesVec.size());
	for (int i = 1; i < numChains(); ++i)
	{
		//the position index j is the one that advances here; the original advanced i,
		//which never terminated on j and ran i past the last chain
		for (int j = 0; j < numRates; ++j)
		{
			MDOUBLE val = (*m_pChainsVec[i])[j];
			ratesVec[j] += val;
		}
	}
	for (int k = 0; k < numRates; ++k)
	{
		ratesVec[k] /= numChains();
	}
	return ratesVec;
}

//getChainsFileNames: returns one name per chain, in chain order.
//chainsNum = number of names to produce (a non-positive value yields an empty vector).
//baseFolderName = prefix of every name; it is used as is, no separator is appended to it.
//Each name is baseFolderName followed by "chain" and the zero-based chain index,
//for example "Runs//resultsFolder//chain0" for the first chain.
vector<string> MultiChain::getChainsFileNames(const int chainsNum,
											  const string& baseFolderName) 
{
	//createDir("", baseFolderName);
	vector<string> names;
	int chainIdx = 0;
	while (chainIdx < chainsNum)
	{
		string chainName(baseFolderName);
		chainName += string("chain");
		chainName += int2string(chainIdx);
		names.push_back(chainName);
		++chainIdx;
	}
	return names;
}

//getBayesTreeAllChains: computes the tree averaged over all chains.
//resTree = output; receives a copy of the Bayes tree of chain 0 whose branch lengths
//          are replaced by the average, over all chains, of the branch with the same node name.
//bTopology = must be false; averaging over topologies is not implemented and is reported as an error.
//Nodes are matched between chains by name, so every chain tree is expected to contain
//a node for each name found in the tree of chain 0.
void MultiChain::getBayesTreeAllChains(tree &resTree, bool bTopology)
{
	if (bTopology == true)
	{
		errorMsg::reportError("not implemented yet: MultiChain::getBayesTreeAllChains(bTopology=true)");
		return; //only reached if reportError returns
	}

	const int chainCount = numChains();
	if (chainCount < 1)
	{	//chainTree[0] below would be out of bounds
		errorMsg::reportError("number of chains is smaller than 1 in MultiChain::getBayesTreeAllChains");
		return; //only reached if reportError returns
	}

	//collect the Bayes tree of every chain
	vector<tree> chainTree;
	chainTree.reserve(chainCount);
	for (int tr = 0; tr < chainCount; ++tr)
	{
		chainTree.push_back(m_pChainsVec[tr]->getBayesTree(bTopology));
	}

	//the tree of the first chain supplies the node set of the result
	resTree = chainTree[0];
	vector<tree::nodeP> nodesVec;
	resTree.getAllNodes(nodesVec, resTree.iRoot());
	for (size_t i = 0; i < nodesVec.size(); ++i) 
	{
		if (nodesVec[i]->isRoot())
			continue; //the root has no branch to its father

		MDOUBLE branchL = 0;
		for (int j = 0; j < chainCount; ++j)
		{
			tree::nodeP pNode = chainTree[j].fromName(nodesVec[i]->name());
			if (pNode == NULL)
			{	//guard in case fromName yields NULL for an unknown name
				//(assumption - confirm against tree::fromName)
				errorMsg::reportError("node name not found in a chain tree in MultiChain::getBayesTreeAllChains");
				return; //only reached if reportError returns
			}
			branchL += pNode->dis2father();
		}
		nodesVec[i]->setDisToFather(branchL / chainCount);
	}
}
