#include "McRate.h"


#include "errorMsg.h"
#include "amino.h"
#include "nucleotide.h"
#include "maseFormat.h"
#include "molphyFormat.h"
#include "clustalFormat.h"
#include "fastaFormat.h"
#include "phylipFormat.h"
#include "uniDistribution.h"
#include "datMatrixHolder.h"
#include "readDatMatrix.h"
#include "chebyshevAccelerator.h"
#include "nucJC.h"
#include "aaJC.h"
#include "trivialAccelerator.h"
#include "someUtil.h"


// Constructs a McRate object from an already-built options object
// (used to initialize m_options) and immediately executes the whole
// analysis by calling run().
McRate::McRate(McRateOptions &options)
:m_options(options)
{
	run();
}

// Constructs a McRate object straight from the command-line arguments:
// argc/argv are forwarded to the m_options constructor, after which the
// whole analysis is executed by calling run().
McRate::McRate(int argc, char* argv[])
 :m_options(argc, argv)
{
	run();
}



// Destructor: releases the multi-chain object created in run() and the
// alphabet created in getStartingSequenceData(), and resets both pointers.
McRate::~McRate()
{
	// deleting a NULL pointer is a no-op, so no explicit guard is needed
	delete m_pMultiChain;
	m_pMultiChain = NULL;

	delete m_pAlph;
	m_pAlph = NULL;

	// NOTE(review): m_pSp (allocated in getStartingStochasticProcess) is not
	// released here - confirm whether another object takes ownership of it.
}


// run: drives the whole analysis - prints the banner and the options, loads
// the inputs, builds the chains, runs them and finally writes the rates.
void McRate::run()
{
	printRunInfo(cout);
	printOptionParameters();

	// input loading; the stochastic process must be ready before the tree
	getStartingSequenceData();
	getStartingStochasticProcess();
	getStartingEvolTree();

	// the starting tree is handed to the chains only when a tree file was given
	const bool treeFileGiven = !m_options.m_treefileStr.empty();
	if (treeFileGiven)
	{
		m_pMultiChain = new MultiChain(m_tree, m_options.m_chainsNum, m_seqContainer,
									   m_pSp, m_options.m_outDirStr, m_options);
	}
	else
	{
		m_pMultiChain = new MultiChain(m_options.m_chainsNum, m_seqContainer,
									   m_pSp, m_options.m_outDirStr, m_options);
	}

	computeRate4Site();
	computeData4Pos(); // per-position count of sequences with a non-negative character code

	//removePositionInRateVectorWithGaps();
	printRates();
}

//printRunInfo: write the project banner (project name and author contacts)
//to the given stream, with one blank line before and one after it
void McRate::printRunInfo(ostream& out)
{
	static const char* const bannerLines[] = {
		" =======================================================",
		" The McRate project:                                    ",
		" Itay Mayrose:  itaymay@post.tau.ac.il                  ",
		" Amir Mitchel:  mitchel@post.tau.ac.il                  ",
		" Dan Graur:     graur@post.tau.ac.il                    ",
		" Tal Pupko:     pupko@csit.fsu.edu                      ",
		" ======================================================="
	};
	const int lineCount = sizeof(bannerLines) / sizeof(bannerLines[0]);

	out << endl;
	for (int idx = 0; idx < lineCount; ++idx)
	{
		out << bannerLines[idx] << endl;
	}
	out << endl;
}


//printOptionParameters: print the options received from the user to the screen (cout).
//File names are printed only when they are non-empty, and the reference
//sequence only when it differs from "none". Every model accepted by
//getStartingStochasticProcess() has a matching line here.
void McRate::printOptionParameters()
{
	cout<<"\n ---------------------- THE PARAMETERS ----------------------------"<<endl;
	if (m_options.m_treefileStr.size() > 0)
		cout << "The tree file is : " << m_options.m_treefileStr << endl;
	if (m_options.m_seqFileStr.size() > 0) 
		cout << "The MSA file is : " << m_options.m_seqFileStr << endl;
	if (m_options.m_outDirStr.size() > 0)
		cout << "The output directory is : " << m_options.m_outDirStr << endl;
	if 	(strcmp(m_options.m_referenceSeq.c_str(),"none") != 0)
		cout << "The reference sequence is : " << m_options.m_referenceSeq << endl;
	switch (m_options.m_modelName)
	{
	case (McRateOptions::DAY):
		cout << "probablistic_model = DAY" << endl;
		break;
	case (McRateOptions::JTT):
		cout << "probablistic_model = JTT" << endl; 
		break;
	case (McRateOptions::REV):
		cout << "probablistic_model = REV" << endl;
		break;
	case (McRateOptions::WAG):
		// WAG is accepted by getStartingStochasticProcess() but was never reported here
		cout << "probablistic_model = WAG" << endl;
		break;
	case (McRateOptions::CPREV):
		// CPREV is accepted by getStartingStochasticProcess() but was never reported here
		cout << "probablistic_model = CPREV" << endl;
		break;
	case (McRateOptions::AAJC):
		cout << "probablistic_model = AAJC" << endl;
		break;
	case (McRateOptions::NUCJC):
		cout << "probablistic_model = NUCJC" <<endl;
		break;
	}

	// NOTE(review): the value printed here comes from the options, while
	// getStartingStochasticProcess() builds the gamma distribution with
	// GAMMA_CATEGORIES - confirm the two always agree.
	cout << "number of discrete gamma categories = " << m_options.m_numCategories<< endl;

	if (m_options.m_bRemoveGaps)
	{
		cout << " positions with gaps were removed from the analysis." << endl;
		cout << " the number of each position in the results refers to the gapless alignment." << endl;
	}
	else 
	{
		cout << "(gaps characters were treated as missing data.)" << endl;
	}
	cout<<"\n ---------------------- Chain Parameters ----------------------------"<<endl;
	cout << " Number of burning steps is:" << m_options.m_burningTime <<endl;
	cout << " Number of inference steps is:" << m_options.m_inferenceTime<<endl;
	cout << " calibration cycle every :" << m_options.m_calibrationCycle<<endl;
	cout << " thinning:" << m_options.m_thinning<<endl;


	cout << "\n -----------------------------------------------------------------" << endl;
}


//getStartingSequenceData:
//initialize the sequence container (m_seqContainer) from the MSA file.
//Steps: validate the file name, open the file, create the alphabet that
//matches the configured alphabet size (4 -> nucleotide, 20 -> amino), read
//the alignment with the reader of the configured input format, and store it
//with gaps changed to missing data.
//Problems (no file name, file cannot be opened, unsupported alphabet size or
//format, gap removal requested) are reported through errorMsg::reportError.
void McRate::getStartingSequenceData()
{
	if (m_options.m_seqFileStr == "")
	{
		errorMsg::reportError("Please give a sequence file name in the command line");
	}
	ifstream in(m_options.m_seqFileStr.c_str());
	if (!in)
	{
		// report the failure here instead of handing a dead stream to a format reader
		const std::string openFailureMsg = "unable to open the sequence file: " + m_options.m_seqFileStr;
		errorMsg::reportError(openFailureMsg.c_str());
	}

	int alphabetSize = m_options.m_alphabet_size;
	if (alphabetSize == 4)
		m_pAlph = new nucleotide();
	else if (alphabetSize == 20)
		m_pAlph = new amino();
	else 
		errorMsg::reportError("no such alphabet in function McRate::getStartingSequenceData");

	sequenceContainer1G original;
	switch (m_options.m_seqInputFormat)
	{
	case McRateOptions::MASE: 
		original = maseFormat::read(in, m_pAlph);
		break;
	case McRateOptions::MOLPHY: 
		original = molphyFormat::read(in, m_pAlph);
		break;
	case McRateOptions::CLUSTAL: 
		original = clustalFormat::read(in, m_pAlph);
		break;
	case McRateOptions::FASTA:
		original = fastaFormat::read(in, m_pAlph);
		break;
	case McRateOptions::PHYLIP:
		original = phylipFormat::read(in, m_pAlph);
		break;
	default:
		errorMsg::reportError(" format not implemented yet in this version... ",1);
	}

	if (m_options.m_bRemoveGaps)
	{
		errorMsg::reportError("removeGaps currently not implemented");
	}
	else
	{
		// gaps are kept in the alignment but treated as missing data
		original.changeGapsToMissingData();
		m_seqContainer = original;
	}
}


//getStartingStochasticProcess:
//initialize the stochastic process to be used
void McRate::getStartingStochasticProcess()
{
	distribution *pDist = new gammaDistribution(1.0, GAMMA_CATEGORIES); 
	replacementModel* pProbMod = NULL;
	pijAccelerator* pPijAcc = NULL;
	switch (m_options.m_modelName)
	{
	case (McRateOptions::DAY):
		pProbMod = new pupAll(datMatrixHolder::dayhoff);
		pPijAcc = new chebyshevAccelerator(pProbMod);
		break;
	case (McRateOptions::JTT):
		pProbMod = new pupAll(datMatrixHolder::jones);
		pPijAcc = new chebyshevAccelerator(pProbMod);
		break;
	case (McRateOptions::REV):
		pProbMod = new pupAll(datMatrixHolder::mtREV24);
		pPijAcc = new chebyshevAccelerator(pProbMod);
		break;
	case (McRateOptions::WAG):
		pProbMod = new pupAll(datMatrixHolder::wag);
		pPijAcc = new chebyshevAccelerator(pProbMod);
		break;
	case (McRateOptions::CPREV):
		pProbMod = new pupAll(datMatrixHolder::cpREV45);
		pPijAcc = new chebyshevAccelerator(pProbMod);
		break;
	case (McRateOptions::NUCJC):
		pProbMod = new nucJC; 
		pPijAcc = new trivialAccelerator(pProbMod);
		break;
	case (McRateOptions::AAJC):
		pProbMod = new aaJC;
		pPijAcc = new trivialAccelerator(pProbMod);
		break;
	default:
		errorMsg::reportError("the specified probablistic model is not yet available");
	}

	m_pSp = new stochasticProcess(pDist, pPijAcc);
	if (pProbMod) 
		delete pProbMod;
	if (pPijAcc)
		delete pPijAcc;
	if (pDist) 
		delete pDist;
}


//getStartingEvolTree: load the starting tree from the tree file named in
//the options; does nothing when no tree file name was supplied
//(per the original note, each chain then creates a random tree)
void McRate::getStartingEvolTree()
{
	const bool treeFileGiven = !m_options.m_treefileStr.empty();
	if (!treeFileGiven)
		return;

	getStartingTreeFromTreeFile();
}

// getStartingTreeFromTreeFile: builds m_tree from the tree file named in the
// options and then passes it through rootToUnrootedTree().
void McRate::getStartingTreeFromTreeFile()
{
	m_tree = tree(m_options.m_treefileStr);
	rootToUnrootedTree(m_tree);
}

// computeRate4Site: runs the chains, passing the configured burning time
// (m_burningTime) and inference time (m_inferenceTime) to runChains().
// m_pMultiChain must already have been created (see run()).
void McRate::computeRate4Site()
{
	m_pMultiChain->runChains(m_options.m_burningTime, m_options.m_inferenceTime);
	//m_rate4site = m_pMultiChain->getRatesAllChains()
}

/*
multiChainParameters McRate::getChainsParam()
{
	multiChainParameters res;
	res.m_kThinning = m_options.m_thinning;
	res.m_calibrationCycle = m_options.
	res.m_alternateSteps = m_options.m_alternateSteps;
	res.
	return res;
}

*/


// computeData4Pos: fills m_seqNum4pos with, for every alignment position,
// the number of sequences whose character code at that position is not
// negative.
void McRate::computeData4Pos()
{
	const int alignmentLength = m_seqContainer.seqLen();
	const int sequenceCount = m_seqContainer.numberOfSequences();

	m_seqNum4pos.resize(alignmentLength);
	for (int col = 0; col < alignmentLength; ++col)
	{
		// count the sequences that carry a non-negative code in this column
		int informativeCount = 0;
		for (int row = 0; row < sequenceCount; ++row)
		{
			if (m_seqContainer[row][col] >= 0)
				++informativeCount;
		}
		m_seqNum4pos[col] = informativeCount;
	}
}



void McRate::printRates()
{
	m_options.out() <<"# rates were created with McRate"<<endl;
	m_options.out() <<"++++++++++++++++++++++++++++++++++++++++++++" <<endl;
	m_options.out() <<" average alpha is: " <<getAverageAlpha() <<endl;
	
	m_options.out() << "#POS"<<" "<<"SEQ"<<"   " << "SCORE"; 

	const sequence* pSeq = NULL;
	if (strcmp(m_options.m_referenceSeq.c_str(), "none") == 0)
	{
		pSeq = &(m_seqContainer[0]);
	}
	else 
	{
		int id1 = m_seqContainer.getId(m_options.m_referenceSeq, true);
		pSeq = &(m_seqContainer[id1]);
	}

	Vdouble rates = getRates();
	int posInBigSeq=-1;
	for (int i = 0; i < rates.size(); ++i)
	{
		posInBigSeq++;

		while ((*pSeq)[posInBigSeq] == pSeq->getAlphabet()->unknown()) 
		{
			posInBigSeq++;
		}

		m_options.out() << i+1<< "    ";
		m_options.out() << pSeq->getAlphabet()->fromInt((*pSeq)[posInBigSeq]) <<"   ";
		m_options.out() << rates[i] << "\t";
		m_options.out() << endl;
	}

	MDOUBLE ave = computeAverage(rates);
	MDOUBLE std = computeStd(rates);
	if (((ave<1e-9)) && (ave>(-(1e-9)))) ave=0;
	if ((std>(1-(1e-9))) && (std< (1.0+(1e-9)))) std=1.0;
	m_options.out() << "Average = " << ave << endl;
	m_options.out() << "Standard Deviation = " << std << endl;
}
