#include <fstream>
using namespace std;

#include "SimulateRates.h"

#include "stochasticProcess.h"
#include "someUtil.h"
#include "talRandom.h"
#include "simulateTree.h"
#include "amino.h"
#include "maseFormat.h"
#include "uniDistribution.h"
#include "brLenOptEM.h"
#include "trivialAccelerator.h" 
#include "chebyshevAccelerator.h"

#include "McRateUtils.h"
using namespace McRateUtils;


//////////////////////////////////////////////////////////////////////
// Construction/Destruction
//////////////////////////////////////////////////////////////////////

// Constructs a rate simulator.
//   in_tree       - tree stored in m_tree
//   ratesDistFile - path of the tab-separated rates file loaded into m_ratesDist
//   colNum        - number of columns in ratesDistFile (2 or 3, see
//                   loadRatesDistributionFromFile)
//   sp            - stochastic process stored in m_sp
//   palph         - alphabet pointer stored in m_pAlph; it is not deleted by this class's destructor
SimulateRates::SimulateRates(const tree& in_tree,
							 string ratesDistFile,
							 int colNum,
							 const stochasticProcess& sp,
							 const alphabet* palph)
: m_sp(sp), m_tree(in_tree)
{
	m_pAlph = palph;

	// The leaves are collected into a scratch vector that is not read afterwards.
	// NOTE(review): looks like leftover code - confirm getAllLeaves has no needed side effect before removing.
	vector<tree::nodeP> leaves;
	m_tree.getAllLeaves(leaves, m_tree.iRoot());

	loadRatesDistributionFromFile(ratesDistFile, colNum);
}




// Destructor. The body is intentionally empty: nothing is released explicitly here
// (in particular the alphabet pointer m_pAlph received in the constructor is not deleted).
SimulateRates::~SimulateRates()
{

}

//runSimulation: makes the simulation
//1. choose the rates at each position
//2. generate the leaf sequences
//3. run the mcmc estimate
//4. calculates the distance between the estimates and the real rates vector
/*void SimulateRates::runSimulation(int seqLength)
{
	simulateTree simTree(m_tree, m_sp, m_pAlph);

	m_simulateRates.resize(seqLength);
	fillRateVector(seqLength);
	simTree.generate_seqWithRateVector(m_simulateRates, seqLength);
	sequenceContainer1G seqContainer = simTree.toSeqDataWithoutInternalNodes();

	////BUG: problem with the seqContainer - extract to MSA file and read from it again
	string outStr = "Runs/checkMase.txt";
	ofstream seqFile(outStr.c_str());
	maseFormat::write(seqFile, seqContainer);
	seqFile.close();
	ifstream msaFile(outStr.c_str());
	seqContainer = maseFormat::read(msaFile,m_pAlph);
	/////////////////////////////////
	
	BayesTreeMng mng(m_tree, 1, m_pAlph, seqContainer, &m_sp, BayesTreeMng::UNIFORM);
	//Vdouble mcmcRates = mng.RunRatesChains(10, 30, true, false);
	//MDOUBLE dist = calcDistBetweenRatesVectors(mcmcRates, m_simulateRates, SUM_SQUARES);
}
*/

//1. make a new rates vector based on m_ratesDist and save to m_simulateRates
//2. simulate seuences based on the rate vector and write to MSA file 
void SimulateRates::generateSequencesToMsaFile(int seqLength, ofstream& outFile, SimType simulationType, distribution* pDist)
{
	m_simulateRates.clear();
	m_simulateRates.resize(seqLength);
	fillRateVector(seqLength, simulationType, pDist);

	simulateTree simTree(m_tree, m_sp, m_pAlph);
	simTree.generate_seqWithRateVector(m_simulateRates, seqLength);
	sequenceContainer1G seqContainer = simTree.toSeqDataWithoutInternalNodes();

	maseFormat::write(outFile, seqContainer);
}


// Overload that also records the rates that were used.
//1. run the four-argument generateSequencesToMsaFile(): fill m_simulateRates according to
//   simulationType / pDist, simulate the sequences and write them to msaFile
//2. print the simulated rates (m_simulateRates) to ratesFile via printSimulatedRates()
void SimulateRates::generateSequencesToMsaFile(int seqLength, ofstream& msaFile, ofstream& ratesFile, SimType simulationType, distribution* pDist)
{
	generateSequencesToMsaFile(seqLength, msaFile, simulationType, pDist);
	printSimulatedRates(ratesFile);
}


//1. make a new rates vector based on according to a fixed distribution
//2. simulate seuences based on the rate vector and write to MSA file 
void SimulateRates::generateSequencesToMsaFileFixedRates(int seqLength, ofstream& msaFile, ofstream& ratesFile, MDOUBLE max_simRate)
{
	m_simulateRates.clear();
	m_simulateRates.resize(seqLength);
	fillRateVectorFixed(seqLength, max_simRate);

	simulateTree simTree(m_tree, m_sp, m_pAlph);
	simTree.generate_seqWithRateVector(m_simulateRates, seqLength);
	sequenceContainer1G seqContainer = simTree.toSeqDataWithoutInternalNodes();

	maseFormat::write(msaFile, seqContainer);
	printSimulatedRates(ratesFile);
}



void SimulateRates::printSimulatedRates(ofstream& outFile)
{
	printTime(outFile);
	outFile<<"# rates were created with SimulateRates"<<endl;
	
	for (int i=0; i < m_simulateRates.size(); ++i) 
	{
		outFile<<i+1;
		outFile<< "\t"<< m_simulateRates[i];
		outFile<<endl;
	}

	MDOUBLE ave = computeAverage(m_simulateRates);
	MDOUBLE std = computeStd(m_simulateRates);
	if (((ave<1e-9)) && (ave>(-(1e-9)))) 
		ave=0;
	if ((std>(1-(1e-9))) && (std< (1.0+(1e-9)))) 
		std=1.0;
	outFile<<"# Average = "<<ave<<endl;
	outFile<<"# Standard Deviation = "<<std<<endl;
}


void SimulateRates::loadRatesDistributionFromFile(string ratesDistFile, int colNum)
{
	ifstream ratesFile(ratesDistFile.c_str());
	vector<string> distFileData;
	putFileIntoVectorStringArray(ratesFile,distFileData);
	if (distFileData.empty()){
		errorMsg::reportError("unable to open file, or file is empty in SimulateRates::loadRatesDistFromFile");
	}


	vector<string>::const_iterator it= distFileData.begin();
	for (; it!= distFileData.end(); ++it) 
	{
		if (it->empty())
		{// empty line continue
			continue; 
		}
		if ((*it)[0]=='#')
		{// remark line 
			continue;  
		}

		if (colNum == 2)
		{

			//in ratesDistFile: only pos# and rate 
			int startRate = 1+ it->find("\t", 0);
			int endRate = 1+ it->find("\t", startRate);
			if (startRate>0)
			{
				string rateStr = it->substr(startRate, endRate-startRate -1);
				MDOUBLE rate = string2double(rateStr);
				m_ratesDist.push_back(rate);
			}
		}
		else if (colNum == 3)
		{
			//in ratesDistFile: column1 is pos#, column2 = aa, column3 = rate 
			int startAA = 1+ it->find("\t", 0);
			int startRate = 1+ it->find("\t", startAA);
			int endRate = 1+ it->find("\t", startRate);
			if (startAA>0)
			{
				string AA = it->substr(startAA, startRate-startAA -1);
				string rateStr = it->substr(startRate, endRate-startRate -1);
				MDOUBLE rate = string2double(rateStr);
				m_ratesDist.push_back(rate);
			}
		}
		else
			errorMsg::reportError("unknown number of columns in SimulateRates::loadRatesDistributionFromFile");

	}
}

//fill the simulation rates vector with randomely chosen rates from the distribution
//if simulationType = RANDOM draw rates randomely
//if simulationType = FILE draw from the rates distribution file
//if simulationType = DISTRIBUTION draw from the m_sp given distribution
void SimulateRates::fillRateVector(int seqLength, SimType simulationType, distribution* pDist, bool bScale/*=true*/)
{
	int i;
	switch (simulationType)
	{
	case RANDOM:
		for (i =0; i<seqLength; ++i)
		{
			MDOUBLE rate = talRandom::giveRandomNumberBetweenZeroAndEntry(10.0);
			m_simulateRates[i] = rate;
		}
		break;
	case FILE_DISTRIBUTION:
		if ( m_ratesDist.empty() == true)
		{
			errorMsg::reportError("the rates distribution is empty in function SimulateRates::fillRateVector()");
		}
		for (i =0; i<seqLength; ++i)
		{
			int distSize = m_ratesDist.size();
			int rateIndex = talRandom::giveIntRandomNumberBetweenZeroAndEntry(distSize-1);
			m_simulateRates[i] = m_ratesDist[rateIndex];
		}
		break;
	case FILE_FIXED:
		if ( (m_ratesDist.empty() == true) || (m_ratesDist.size() !=  seqLength))
		{
			errorMsg::reportError("the rates distribution is empty OR fixed_distribution_file is not same size as seqLen in function SimulateRates::fillRateVector()");
		}
		for (i =0; i<seqLength; ++i)
		{
			MDOUBLE r = m_ratesDist[i];
			m_simulateRates[i] = m_ratesDist[i];
		}
		break;
	case DISTRIBUTION:
		assert (pDist);
		//we assume that the probability of each category is equal
		for (i =0; i<seqLength; ++i)
		{
			int catNum = pDist->categories();
			int category = talRandom::giveIntRandomNumberBetweenZeroAndEntry(catNum);
			MDOUBLE rate = pDist->rates(category);
			m_simulateRates[i] = rate;
		}
		break;
	default:
		errorMsg::reportError("unknown simulation type in SimulateRates::fillRateVector()");
	}

	if (bScale == true)
		McRateUtils::scaleVec(m_simulateRates, 1.0);
}





//fill the simulation rates vector with randomely chosen rates from the distribution
//if bRandom = true - draw rates randomely and scale the rates so that avg rate is one
//if bRandom = false draw from the rates distribution file
void SimulateRates::fillRateVectorFixed(int seqLength, MDOUBLE max_rate, int repetition/*=5*/,bool bScale/*=false*/)
{
	MDOUBLE interval = (max_rate * repetition) / seqLength;
	MDOUBLE rate = 0.0;
	for (int pos =0; pos<seqLength; )
	{

		for (int j=0; j<repetition; ++j)
		{
			if (rate == 0)
				m_simulateRates[pos++] = 0.000001;
			else
				m_simulateRates[pos++] = rate;
		}
		rate += interval;
	}
	
	if (bScale == true)
		McRateUtils::scaleVec(m_simulateRates, 1.0);
}



