#include "LGT3State.h"
#include "threeStateAlphabet.h"
#include "treeIt.h"
#include "matrixUtils.h"
#include "trivialAccelerator.h"
#include "uniDistribution.h"
#include "sequence.h"
#include "someUtil.h"
#include "optimizeThreeStateModel.h"
#include "evaluateCharacterFreq.h"
#include "recognizeFormat.h"
#include "seqContainerTreeMap.h"
#include "computeDownAlg.h"
#include "talRandom.h"
#include "datMatrixHolder.h"
#include "replacementModel.h"
#include "readDatMatrix.h"
#include "treeUtil.h"

int main(int argc, char **argv) {
	// The assess3stateLGT constructor performs the whole analysis (option parsing,
	// optimization, post-optimization computations and printing of the results).
	// The object is destroyed when main returns.
	assess3stateLGT analysis(argc, argv);
	// Alternative entry points, currently disabled:
	// test(); testDown(); loopLRT(); removeOrganisms(); sumBranchLengths();
	return 0;
}

assess3stateLGT::assess3stateLGT(int argc, char* argv[]): _sp(NULL),_isNullModel(false),
_useMarkovLimiting(true),_multipleStartingPoints(10),
_numSimulations(100),_upperBoundMuVals(5),_probChanges(NULL) {
	// Command-line driver: parse options and read the inputs, then open the log.
	initialize(argc, argv);
	myLog::setLog(_logFile, 10);
	myLog::printArgv(5,argc,argv);
	//initializeStatesVector();

	const bool useSeveralStarts = (_multipleStartingPoints > 1);
	if (!useSeveralStarts) {
		startStochasticProcess();
		checkIfNullModel(); // the input parameters themselves may describe a null model
		runOptimization();
	} else {
		runMultipleStartingPoints(_multipleStartingPoints);
	}

	// The optimized parameters may also have converged to a null model.
	checkIfNullModel();
	runComputationsAfterOptimization();
	printResults();
	LOG(5,<<endl<<"END OF LOG FILE "<<endl);
	myLog::endLog();
}

assess3stateLGT::assess3stateLGT(tree &tr, sequenceContainer &sc,bool isNullModel ,
										   string *rootAt, 
										   MDOUBLE *fixedMu1,MDOUBLE *fixedMu2,
										   MDOUBLE *fixedMu3,MDOUBLE *fixedMu4,
										   string *resFile, 
										   string *logFile):
_tr(tr),_sc(sc),_sp(NULL),_isNullModel(isNullModel),_useMarkovLimiting(true),_multipleStartingPoints(10),_numSimulations(100),
_upperBoundMuVals(5),_probChanges(NULL)

{
	// Programmatic entry point: the caller supplies the tree and the sequences.
	// Optional pointers (NULL = use the default) give the root node name, starting
	// values for mu1..mu4, the results file and the log file.
	// NOTE(review): despite their names, fixedMu1..fixedMu4 are only used as starting
	// values; nothing here disables their optimization - confirm this is intended.

	// initialize() is not called on this path, so the per-parameter vectors must be
	// sized here: they are indexed 0..5 below and in startStochasticProcess()/runOptimization().
	_initParameters.resize(6,NULL);
	_doOptimizeParams.resize(6,true);

	if (logFile)
		_logFile = *logFile;
	else
		_logFile = "log.txt";
	if (resFile)
		_outPutFile = *resFile;
	else
		_outPutFile = "results.txt";
	if (rootAt)
		_rootAt = *rootAt;
	else
		_rootAt = "";

	// Open the log first so that the rooting messages end up in it as well.
	myLog::setLog(_logFile, 5);
	LOG(5,<<"Running internal call of assess3stateLGT (input tree and seq.container)"<<endl);
	rootTree();

	if (fixedMu1)
		_initParameters[0] = new MDOUBLE(*fixedMu1);
	if (fixedMu2)
		_initParameters[1] = new MDOUBLE(*fixedMu2);
	if (fixedMu3)
		_initParameters[2] = new MDOUBLE(*fixedMu3);
	if (fixedMu4)
		_initParameters[3] = new MDOUBLE(*fixedMu4);

	initializeStatesVector();
	startStochasticProcess();
	runOptimization();
	runComputationsAfterOptimization();
	printResults();
}


assess3stateLGT::~assess3stateLGT() {
	// Release everything this object allocated with new.
	// Deleting a NULL pointer is a no-op, so no guards are required.
	delete _sp;
	for (size_t k = 0; k < _initParameters.size(); ++k)
		delete _initParameters[k];
	delete _probChanges;
}


void assess3stateLGT::initialize(int argc, char* argv[]) {
	_initParameters.resize(6,NULL);
	_doOptimizeParams.resize(6,true);
	string inputTree;// = "D:\\My Documents\\projects\\Thy\\species trees\\ancestralReconstruct\\Ciccarelli.parsed.wBL.noBP.removed.ph";
	string seqsFile;
	//_outPutFile = "D:\\My Documents\\projects\\Thy\\species trees\\ancestralReconstruct\\ML.results";
	_rootAt = "";
	_outPutFile = "results.txt";
	for (int ix = 0; ix < argc; ix++) {
		char *pchar=argv[ix];
		switch (pchar[0]) {
		case '-':
			switch (pchar[1]) {
			case 'a'://gain thyX
				if (_initParameters[0]) delete _initParameters[0];
				_initParameters[0] = new MDOUBLE(atof(argv[++ix]));
				break;
			case 'b'://gain thyA
				if (_initParameters[1]) delete _initParameters[1];
				_initParameters[1] = new MDOUBLE(atof(argv[++ix]));
				break;
			case 'c'://loss thyX
				if (_initParameters[2]) delete _initParameters[2];
				_initParameters[2] = new MDOUBLE(atof(argv[++ix]));
				break;
			case 'd'://loss thyA
				if (_initParameters[3]) delete _initParameters[3];
				_initParameters[3] = new MDOUBLE(atof(argv[++ix]));
				break;
			case 'e':case 'E': 
			switch (pchar[2]) {
				case 'm': case 'M':  _useMarkovLimiting=true; break; 
			case 'n': case 'N':  _useMarkovLimiting=false;break; 
				default: _useMarkovLimiting = true;; break;
			}
			break;
			case 'f':case 'F': 
			switch (pchar[2]) {
				case 'a': case 'A':  _doOptimizeParams[0]=false; break; //no optimization of gain X	
				case 'b': case 'B':  _doOptimizeParams[1]=false; break; //no optimization of gain A	
				case 'c': case 'C':  _doOptimizeParams[2]=false; break; //no optimization of loss X	
				case 'd': case 'D':  _doOptimizeParams[3]=false; break; //no optimization of loss A
				case 'p': case 'P':  _doOptimizeParams[4]=false; break; //no optimization of freq A	
				case 'q': case 'Q': _doOptimizeParams[5]=false; break; // no optimization of freq X	
				default: errorMsg::reportError("Must specify which parameter is fixed"); break;
			}
			break;

			case 'h':
				cout <<"USAGE:	"<<argv[0]<<" [-options] "<<endl <<endl;
				cout <<"----------------------------------------------"<<endl;
				cout <<"-t    input treeFile                          "<<endl;
				cout <<"-s    input seqsFile "<<endl;
				cout <<"-l    logFile				                   "<<endl;
				cout <<"-o    output file					           "<<endl;
				cout <<"-r    root at (input the name of the node)    "<<endl;
				cout <<"----------------------------------------------"<<endl;
				cout <<"-a    initial value for mu1 (gain thyX)	   "<<endl;
				cout <<"-b    initial value for mu2 (gain thyA)	   "<<endl;
				cout <<"-c    initial value for mu3 (loss thyX)		"<<endl;
				cout <<"-d    initial value for mu4 (loss thyA)	   "<<endl;
				cout <<"-p    initial value for pi0 (freq thyA)	   "<<endl;
				cout <<"-q    initial value for pi1 (freq thyX)	   "<<endl;
				cout <<"----------------------------------------------"<<endl;
				cout <<"-Fa    fixed value for mu1 (gain thyX)		   "<<endl;
				cout <<"-Fb    fixed value for mu2 (gain thyA)		   "<<endl;
				cout <<"-Fc    fixed value for mu3 (loss thyX)		   "<<endl;
				cout <<"-Fd    fixed value for mu4 (loss thyA)		   "<<endl;
				cout <<"-Fp    fixed value for pi0 (freq thyA)	   "<<endl;
				cout <<"-Fq    fixed value for pi1 (freq thyX)	   "<<endl;
				cout <<"----------------------------------------------"<<endl;
				cout <<"-E    Evaluation of root frequencies:	   "<<endl;
				cout <<"-Em   Use the markov limiting distribution based on the Q matrix (default)"<<endl;
				cout <<"-En   Do not use markov limiting distribution:"<<endl;
				cout <<"      Perform optimization for pi0 and pi1 OR"<<endl;	   
				cout <<"      Use fixed values, if given in the input "<<endl;
				cout <<"----------------------------------------------"<<endl;
				cout <<"-n    run null model (mu1=mu2=0, pi2=1)	   "<<endl;
				cout <<"----------------------------------------------"<<endl;
				cout <<"-m    number of multiple starting points (default=1)"<<endl;
				cout <<"----------------------------------------------"<<endl;
				cout <<"-i    number of simulations (default=100)"<<endl;
				cout <<"----------------------------------------------"<<endl;
				cout <<"-h or -? or -H     help                       "<<endl;
				cout <<"capital and lowercase letters are ok          "<<endl;
				cout <<"----------------------------------------------"<<endl;
			cout<<endl;	cerr<<" please press 0 to exit "; int d; cin>>d;exit (0);
			case 'i':
				_numSimulations=atoi(argv[++ix]);
				break;
			case 'l':
				_logFile=argv[++ix];
				break;
			case 'm':
				_multipleStartingPoints=atoi(argv[++ix]);
				break;
			case 'n':
				_isNullModel=true;
				break;
			case 'o':
				_outPutFile=argv[++ix];
				break;
			case 'p'://freq thyA
				if (_initParameters[4]) delete _initParameters[4];
				_initParameters[4] = new MDOUBLE(atof(argv[++ix]));
				break;
			case 'q'://freq thyX
				if (_initParameters[5]) delete _initParameters[5];
				_initParameters[5] = new MDOUBLE(atof(argv[++ix]));
				break;
			case 'r':
				_rootAt=argv[++ix];
				break;
			case 's':
				seqsFile=argv[++ix];
				break;
			case 't':
				inputTree=argv[++ix];
				break;
			}
		}
	}
	checkParameters();
	
	//tree 
	tree t(inputTree);
	_tr = t;
	rootTree();
	//sequence container
	threeStateAlphabet alph;
	ifstream in(seqsFile.c_str());
	_sc = recognizeFormat::read(in,&alph);
}

void assess3stateLGT::checkParameters() 
{
	// Validate combinations of command-line options; reports an error on the first violation.

	// pi0 (-p) and pi1 (-q) must be given together.
	const bool hasPi0 = (_initParameters[4] != NULL);
	const bool hasPi1 = (_initParameters[5] != NULL);
	if (hasPi0 && !hasPi1)
		errorMsg::reportError("cannot specify only pi0  without pi1");
	if (hasPi1 && !hasPi0)
		errorMsg::reportError("cannot specify only pi1  without pi0");

	bool allMuFixed = true;
	for (int p = 0; p < 4; ++p) {
		if (_doOptimizeParams[p])
			allMuFixed = false;
	}
	const bool bothPiFixed = !_doOptimizeParams[4] && !_doOptimizeParams[5];

	// With nothing left to optimize, several starting points make no sense.
	const bool nothingToOptimize = allMuFixed && (_useMarkovLimiting || bothPiFixed);
	if (nothingToOptimize && (_multipleStartingPoints > 1))
		errorMsg::reportError("All parameters are fixed. cannot run with multiples starting points");

	// Fixed pi values conflict with deriving the frequencies from the Markov limiting distribution.
	if (_useMarkovLimiting && bothPiFixed)
		errorMsg::reportError("Cannot fix pi0 and pi1 when use of markov limiting distribution is specified");
}


void assess3stateLGT::checkIfNullModel(){
	if (checkIfZeroFreqs()){
		MDOUBLE gainX = static_cast<threeStateModel*>((*_sp).getPijAccelerator()->getReplacementModel())->getMu3();
		MDOUBLE gainA = static_cast<threeStateModel*>((*_sp).getPijAccelerator()->getReplacementModel())->getMu4();
		if ((DEQUAL(gainX, 0.0)) && ((DEQUAL(gainA, 0.0))))
			_isNullModel = true;
	}
}

//return true if 2 of 3 freqs are zero: if the product of 2 of them is zero and also the third is zero or 1
bool assess3stateLGT::checkIfZeroFreqs(){
	if ((DEQUAL(_sp->freq(0)*_sp->freq(1),0.0))){
		if (((DEQUAL(_sp->freq(2),0.0))) || (((DEQUAL(_sp->freq(2),1.0)))))
			return true;
	}
	return false;
}

void assess3stateLGT::rootTree(){
	if (!(_rootAt =="")){
		tree::nodeP myroot = _tr.findNodeByName(_rootAt); //returns NULL if not found
		if (myroot){
			_tr.rootAt(myroot);
			LOGnOUT(5,<<"tree rooted at "<<myroot->name()<<" id, "<<myroot->id()<<endl);
			LOGnOUT(5,<<"sons of root are "<<_tr.getRoot()->getSon(0)->name()<<" , "<<_tr.getRoot()->getSon(1)->name()<<" , "<<_tr.getRoot()->getSon(2)->name()<<endl);
			return;
		}
	}
	LOGnOUT(5,<<"default rooting used, root name is "<<_tr.getRoot()->name()<<endl);
	LOGnOUT(5,<<"sons of root are "<<_tr.getRoot()->getSon(0)->name()<<" , "<<_tr.getRoot()->getSon(1)->name()<<endl);

}

void assess3stateLGT::startStochasticProcess(){
	Vdouble init_mu_vals(4,0.5);
	Vdouble freq(3,-1.0);
	if (_initParameters[0])
		init_mu_vals[0] = *(_initParameters[0]);
	if (_initParameters[1])
		init_mu_vals[1] = *(_initParameters[1]);
	if (_initParameters[2])
		init_mu_vals[2] = *(_initParameters[2]);
	if (_initParameters[3])
		init_mu_vals[3] = *(_initParameters[3]);
	if (_isNullModel){
		init_mu_vals[0] = 0; //gain x
		init_mu_vals[1] = 0; // gain a
		_doOptimizeParams[0]=false;
		_doOptimizeParams[1]=false;
		freq[0]=0.0;
		freq[1]=0.0;
		freq[2]=1.0; // a+x freq must be one, since 0 and 1 are absorbing states
		_doOptimizeParams[4]=false; // don't optimize pi0
		_doOptimizeParams[5]=false; // don't optimize pi1
	}
	else {
		if (_initParameters[4]){//if pi0 is specified so is pi1
			freq[0] = *(_initParameters[4]);
			freq[1] = *(_initParameters[5]);
			MDOUBLE sum = freq[0]+freq[1];
			if (sum>1.0){
				string strErr = "error in startStochasticProcess, the sum of initial values for pi0 and pi1=";
				strErr+=double2string(sum);
				strErr+=" exceed 1";
				errorMsg::reportError(strErr);
			}
			freq[2] = 1.0 - freq[0] - freq[1];			
		}
		else {
			freq=evaluateCharacterFreq(_sc);
			LOGnOUT(5,<<"freqs from data"<<endl);
		}
	}
	LOGnOUT(5,<<"----init mu vals are: "<<(init_mu_vals[0])<<","<<(init_mu_vals[1])<<","<<(init_mu_vals[2])<<","<<(init_mu_vals[3])<<endl);
	LOGnOUT(5,<<"----init frequencies are: "<<freq[0]<<" "<<freq[1]<<" "<<freq[2]<<endl);
	threeStateModel glm(init_mu_vals[0],init_mu_vals[1],
		init_mu_vals[2],init_mu_vals[3],freq,_useMarkovLimiting);
	trivialAccelerator pijAcc(&glm);

	uniDistribution uniDistr;
	_sp = new stochasticProcess(&uniDistr,&pijAcc,false);



}

void assess3stateLGT::initializeStatesVector(){
	_states.resize(_tr.getNodesNum(),-1000);
	checkThatNamesInTreeAreSameAsNamesInSequenceContainer(_tr,_sc);
	seqContainerTreeMap scTreeMap(_sc,_tr);	
	vector <tree::nodeP> leaves;
	_tr.getAllLeaves(leaves,_tr.getRoot());
	for (int i=0; i< leaves.size();i++){
		int myleafId = (leaves[i])->id();
		int mySeqId = scTreeMap.seqIdOfNodeI(myleafId);
		_states[myleafId] = _sc[mySeqId][0];
	}
}



void assess3stateLGT::runOptimization(){
	if (_isNullModel){
		_doOptimizeParams[0] = false; // gain x
		_doOptimizeParams[1] = false; // gain a
		_doOptimizeParams[4] = false; // pi0
		_doOptimizeParams[5] = false; // pi1
	}
	if (_useMarkovLimiting) {
		_doOptimizeParams[4] = false; // pi0
		_doOptimizeParams[5] = false; // pi1
	}
	optimizeThreeStateModel opt(_tr,*_sp,_sc,_doOptimizeParams[0],_doOptimizeParams[1],
		_doOptimizeParams[2],_doOptimizeParams[3],_doOptimizeParams[4],_doOptimizeParams[5],_upperBoundMuVals,0.001);
	_likelihoodGivenOptimizedParams = opt.getBestL();
}

// After optimization: simulate jumps, compute branch/ancestral posteriors, the
// expected number of each state change over the tree, and the per-branch change
// probabilities (_probChanges). Skipped for degenerate root frequencies or the null model.
void assess3stateLGT::runComputationsAfterOptimization(){
	if (checkIfZeroFreqs() || _isNullModel)
		return;

	threeStateAlphabet alph;
	LOGnOUT(5,<<endl<<"running "<<_numSimulations<<" simulations"<<endl);
	simulateJumps sim(_tr,*_sp,&alph);
	sim.runSimulation(_numSimulations);
	cout<<"finished simulations"<<endl;

	VVVdouble posteriors;
	computePosteriorOfChangeGivenTerminals(posteriors);
	computeAncestralPosterior(posteriors);
	_exp01=computeExpectationOfChange(sim,posteriors,0,1);
	_exp10=computeExpectationOfChange(sim,posteriors,1,0);
	_exp02=computeExpectationOfChange(sim,posteriors,0,2);
	_exp12=computeExpectationOfChange(sim,posteriors,1,2);
	_exp20=computeExpectationOfChange(sim,posteriors,2,0);
	_exp21=computeExpectationOfChange(sim,posteriors,2,1);

	computePosterior(sim,posteriors);
}

void assess3stateLGT::printResults(){
	ofstream oStream(_outPutFile.c_str());
	Vstring data;
	if ((!checkIfZeroFreqs()) && (!_isNullModel)) {
	    preparePrintData(data);
	    printDataOnTreeAsBPValues(oStream,data,_tr);
		oStream<<endl<<"================="<<endl;
		oStream<<endl<<"Bootstrap values in the tree represent posterior probabilities of: gainX//gainA//lossX//lossA[nodeName]"<<endl;
		oStream<<endl<<"================="<<endl;
	}
 
	
	oStream<<"Parameters are: "<<endl;
	oStream<<"Mu1 (gain 1)= "<<static_cast<threeStateModel*>((*_sp).getPijAccelerator()->getReplacementModel())->getMu1()<<endl;
	oStream<<"Mu2 (gain 0)= "<<static_cast<threeStateModel*>((*_sp).getPijAccelerator()->getReplacementModel())->getMu2()<<endl;
	oStream<<"Mu3 (loss 1)= "<<static_cast<threeStateModel*>((*_sp).getPijAccelerator()->getReplacementModel())->getMu3()<<endl;
	oStream<<"Mu4 (loss 0)= "<<static_cast<threeStateModel*>((*_sp).getPijAccelerator()->getReplacementModel())->getMu4()<<endl;
	for (int i=0; i<3; ++i){
		oStream<<"Freq ("<<i<<")= "<<static_cast<threeStateModel*>((*_sp).getPijAccelerator()->getReplacementModel())->freq(i)<<endl;
	}
	MDOUBLE res = getLikelihood();
	oStream<<endl<<endl<<"Likelihood of data given parameters is "<<res<<endl<<endl;

	cout<<endl<<endl<<"Likelihood of data given parameters is "<<res<<endl<<endl;
	if ((!checkIfZeroFreqs()) && (!_isNullModel)) {
		oStream<<"Expectation 0 to 1 ="<<_exp01<<endl;
		oStream<<"Expectation 1 to 0 ="<<_exp10<<endl;
		oStream<<"Expectation 0 to 2 ="<<_exp02<<endl;
		oStream<<"Expectation 1 to 2 ="<<_exp12<<endl;
		oStream<<"Expectation 2 to 0 ="<<_exp20<<endl;
		oStream<<"Expectation 2 to 1 ="<<_exp21<<endl;
				treeIterTopDownConst tIt(_tr);
		for (tree::nodeP mynode = tIt.first(); mynode != tIt.end(); mynode = tIt.next()) {
			if ((*_probChanges)[mynode->id()][0][2] > 0.40)
				oStream<<"Gain X in branch "<<mynode->name()<<"  probability "<<(*_probChanges)[mynode->id()][0][2]<<endl;
			if ((*_probChanges)[mynode->id()][1][2] > 0.40){
				oStream<<"Gain A in branch "<<mynode->name()<<"  probability "<<(*_probChanges)[mynode->id()][1][2]<<endl;
			}
		}
		for (tree::nodeP mynode = tIt.first(); mynode != tIt.end(); mynode = tIt.next()) {
			if ((*_probChanges)[mynode->id()][2][0] > 0.40)
				oStream<<"Loss X in branch "<<mynode->name()<<"  probability "<<(*_probChanges)[mynode->id()][2][0]<<endl;
			if ((*_probChanges)[mynode->id()][2][1] > 0.40)
				oStream<<"Loss A in branch "<<mynode->name()<<"  probability "<<(*_probChanges)[mynode->id()][2][1]<<endl;
		}
		vector<tree::nodeP> leaves;
		_tr.getAllLeaves(leaves,_tr.getRoot());
		oStream<<endl<<"====================="<<endl<<"All leaves probs"<<endl;
		for (int n=0;n<leaves.size();++n) {
			tree::nodeP mynode = leaves[n];
			oStream<<mynode->name()<<":"<<endl;
			oStream<<"Gain X, probability "<<(*_probChanges)[mynode->id()][0][2]<<endl;
			oStream<<"Gain A, probability "<<(*_probChanges)[mynode->id()][1][2]<<endl;
			oStream<<"Loss X, probability "<<(*_probChanges)[mynode->id()][2][0]<<endl;
			oStream<<"Loss A, probability "<<(*_probChanges)[mynode->id()][2][1]<<endl;
		}
		oStream<<endl<<"====================="<<endl<<"ancestral state probabilities"<<endl;
		oStream<<"NODE"<<"\t";
		int letter = 0;
		int alphabetSize = _sp->alphabetSize();
		for (; letter < alphabetSize; ++letter)
			oStream<<letter<<"\t";
		oStream<<endl;
		for (tree::nodeP mynode = tIt.first(); mynode != tIt.end(); mynode = tIt.next()) 
		{
			oStream<<mynode->name()<<"\t";
			for (letter = 0; letter < alphabetSize; ++letter)
				oStream<<_ancestralProbs[mynode->id()][letter]<<"\t";
			oStream<<endl;
		}
	}
}

// Fills data[nodeId] with "gainX//gainA//lossX//lossA[nodeName]", the per-branch change
// probabilities from _probChanges, to be printed as bootstrap values on the tree.
void assess3stateLGT::preparePrintData(Vstring &data){
	data.resize(_tr.getNodesNum());
	const int numEntries = static_cast<int>(data.size());
	for (int nodeId = 0; nodeId < numEntries; ++nodeId) {
		const VVdouble &branchProbs = (*_probChanges)[nodeId];
		const string gainX = double2string(branchProbs[0][2]);
		const string gainA = double2string(branchProbs[1][2]);
		const string lossX = double2string(branchProbs[2][0]);
		const string lossA = double2string(branchProbs[2][1]);
		tree::nodeP node = _tr.findNodeById(nodeId);
		if (node == NULL)
			errorMsg::reportError("error in assess3stateLGT::preparePrintData, cannot find node");
		data[nodeId] = gainX + "//" + gainA + "//" + lossX + "//" + lossA + "[" + node->name() + "]";
	}
}

// Runs the optimization from numStartingPoints starting points and keeps the best one.
// Start 0 uses the user-supplied/default values; later starts draw random rates for
// every free mu in [0, _upperBoundMuVals]; even starts also draw random root frequencies,
// odd starts use the frequencies estimated from the data. The best parameters are finally
// written into the model held by _sp and _likelihoodGivenOptimizedParams is set.
void assess3stateLGT::runMultipleStartingPoints(int numStartingPoints){
	MDOUBLE currBestLikelihood = -VERYBIG;
	Vdouble bestParams(_initParameters.size(),-1);
	int bestPoint(-1);
	LOGnOUT(5,<<endl<<endl<<"RUNNING MULTIPLE STARTING POINTS"<<endl<<endl);
	for (int it=0; it < numStartingPoints; ++it) {
		LOGnOUT(5,<<"****STARTING POINT NO. "<<it<<" *****"<<endl);
		if (it != 0) {
			// random loss rates
			if (_doOptimizeParams[2]) {
				if (_initParameters[2]) delete _initParameters[2];
				_initParameters[2] = new MDOUBLE(talRandom::giveRandomNumberBetweenTwoPoints(0.0,_upperBoundMuVals));
			}
			if (_doOptimizeParams[3]) {
				if (_initParameters[3]) delete _initParameters[3];
				_initParameters[3] = new MDOUBLE(talRandom::giveRandomNumberBetweenTwoPoints(0.0,_upperBoundMuVals));
			}
			if (!_isNullModel) {
				// random gain rates
				if (_doOptimizeParams[0]) {
					if (_initParameters[0]) delete _initParameters[0];
					_initParameters[0] = new MDOUBLE(talRandom::giveRandomNumberBetweenTwoPoints(0.0,_upperBoundMuVals));
				}
				if (_doOptimizeParams[1]) {
					if (_initParameters[1]) delete _initParameters[1];
					_initParameters[1] = new MDOUBLE(talRandom::giveRandomNumberBetweenTwoPoints(0.0,_upperBoundMuVals));
				}
				if (it%2==0) {
					// random root frequencies with pi0+pi1 <= 1
					// NOTE(review): drawn regardless of _doOptimizeParams[4]/[5] - confirm intended.
					if (_initParameters[4]) delete _initParameters[4];
					_initParameters[4] = new MDOUBLE(talRandom::giveRandomNumberBetweenTwoPoints(0.0,1.0));
					if (_initParameters[5]) delete _initParameters[5];
					_initParameters[5] = new MDOUBLE(talRandom::giveRandomNumberBetweenTwoPoints(0.0,1.0-(*_initParameters[4])));
				}
				else {
					// NULL frequencies make startStochasticProcess() use the data frequencies.
					// Free values drawn at an earlier start first (they used to be leaked here).
					if (_initParameters[4]) delete _initParameters[4];
					_initParameters[4] = NULL;
					if (_initParameters[5]) delete _initParameters[5];
					_initParameters[5] = NULL;
				}
			}
		}
		if (_sp)
			delete _sp;
		_sp = NULL; // never keep a dangling pointer, even until startStochasticProcess() reassigns it
		startStochasticProcess();
		runOptimization();
		if (_likelihoodGivenOptimizedParams > currBestLikelihood){
			threeStateModel* model = static_cast<threeStateModel*>((*_sp).getPijAccelerator()->getReplacementModel());
			bestPoint=it;
			currBestLikelihood = _likelihoodGivenOptimizedParams;
			bestParams[0]=model->getMu1();
			bestParams[1]=model->getMu2();
			bestParams[2]=model->getMu3();
			bestParams[3]=model->getMu4();
			bestParams[4]=model->freq(0);
			bestParams[5]=model->freq(1);
			// debug
			cout <<"bestParams[] is: " << bestParams[0] << '\t' << bestParams[1] << '\t'<< bestParams[2] << '\t'<< bestParams[3] << '\t'<< bestParams[4] << '\t'<< bestParams[5] << endl;
			// endof debug
		}
	}
	// Without an accepted start bestParams still holds the -1 placeholders; refuse to apply them.
	if (bestPoint < 0)
		errorMsg::reportError("error in assess3stateLGT::runMultipleStartingPoints, no starting point produced a valid likelihood");

	//set the best likelihood:
	_likelihoodGivenOptimizedParams=currBestLikelihood;
	threeStateModel* finalModel = static_cast<threeStateModel*>((*_sp).getPijAccelerator()->getReplacementModel());
	finalModel->setMu1(bestParams[0]);
	finalModel->setMu2(bestParams[1]);
	finalModel->setMu3(bestParams[2]);
	finalModel->setMu4(bestParams[3]);
	Vdouble freq(3,-1.0);
	freq[0]=bestParams[4];
	freq[1]=bestParams[5];
	freq[2]=1.0-freq[0]-freq[1];
	finalModel->setFreq(freq);

	LOGnOUT(5,<<"FINISHED MULTIPLE STARTING POINTS, best params are:"<<endl);
	LOGnOUT(5,<<"----BEST POINT IS POINT NUMBER " << bestPoint << " : "<<bestParams[0]<<","<<bestParams[1]<<","<<bestParams[2]<<","<<bestParams[3]<<","<<bestParams[4]<<","<<bestParams[5]<<endl);
	LOGnOUT(5,<<"likelihood is "<<_likelihoodGivenOptimizedParams<<endl<<endl);
}




// Log-likelihood of position 0 of _sc on _tr under the current stochastic process.
MDOUBLE assess3stateLGT::getLikelihood(){
	computePijHom transitionProbs;
	transitionProbs.fillPij(_tr,*_sp,0,false);
	return convert(log(likelihoodComputation::getLofPos(0,_tr,_sc,transitionProbs,*_sp)));
}


// Upward (leaves-to-root) pass of the maximum joint reconstruction.
// For a non-root node and every state "letter" at its father:
//   upL[node][letter] = max over letter_here of P(letter->letter_here) * prod_sons upL[son][letter_here]
//   backtrack[node][letter] = the maximizing letter_here (the observed state for leaves).
// For the root: upL[root][state] = freq(state) * prod_sons upL[son][state].
// Both vectors are (re)sized here; leaf states are read from _states.
void assess3stateLGT::traverseUpML(VVdouble &upL, VVint &backtrack){ // input as empty vector to be filled
	const int numNodes = _tr.getNodesNum();
	const int alphabetSize = _sp->alphabetSize();
	upL.resize(numNodes);
	backtrack.resize(numNodes);
	for (int n = 0; n < numNodes; ++n) {
		upL[n].resize(alphabetSize);
		backtrack[n].resize(alphabetSize);
	}

	treeIterDownTopConst tIt(_tr);
	for (tree::nodeP node = tIt.first(); node != tIt.end(); node = tIt.next()) {
		const int id = node->id();
		// The leaf test must come first: a node may be both a leaf and the root.
		if (node->isLeaf()) {
			const int observed = _states[id];
			for (int fatherState = 0; fatherState < alphabetSize; ++fatherState) {
				upL[id][fatherState] = _sp->Pij_t(fatherState,observed,node->dis2father());
				backtrack[id][fatherState] = observed;
			}
		}
		else if (!node->isRoot()) {
			for (int fatherState = 0; fatherState < alphabetSize; ++fatherState) {
				MDOUBLE bestVal = -1;
				int bestState = -1;
				for (int candidate = 0; candidate < alphabetSize; ++candidate) {
					MDOUBLE val = _sp->Pij_t(fatherState,candidate,node->dis2father());
					for (int s = 0; s < node->getNumberOfSons(); ++s)
						val *= upL[node->getSon(s)->id()][candidate];
					if (val > bestVal) {
						bestVal = val;
						bestState = candidate;
					}
				}
				if ((bestVal < 0) || (bestState < 0))
					errorMsg::reportError("Error in traverseUpML: cannot find maximum");
				upL[id][fatherState] = bestVal;
				backtrack[id][fatherState] = bestState;
			}
		}
		else {
			for (int rootState = 0; rootState < alphabetSize; ++rootState) {
				MDOUBLE val = _sp->freq(rootState);
				for (int s = 0; s < node->getNumberOfSons(); ++s)
					val *= upL[node->getSon(s)->id()][rootState];
				upL[id][rootState] = val;
			}
		}
	}
}

// Downward (root-to-leaves) pass of the maximum joint reconstruction.
// Picks the best root state from upL, then assigns each internal node the backtracked
// state given its father's state, storing states in _states. Counts father->node state
// pairs in transitionTypeCount and prints every leaf whose state differs from its father's.
// Returns the log of the maximal joint likelihood.
MDOUBLE assess3stateLGT::traverseDownML(VVdouble &upL, VVint &backtrack, 
								VVint &transitionTypeCount) { // input as already filled vector
	if (backtrack.empty()) 
		errorMsg::reportError("error in assess3stateLGT::traverseDownML, input vector backtrack must be filled (call traverseUpML() first)");
	MDOUBLE LofJoint;
	int stateOfRoot;
	const int rootId = (_tr.getRoot())->id();
	findMaxInVector(upL[rootId], LofJoint, stateOfRoot);
	_states[rootId] = stateOfRoot;

	const int alphabetSize = _sp->alphabetSize();
	transitionTypeCount.resize(alphabetSize);
	for (int row = 0; row < alphabetSize; ++row)
		transitionTypeCount[row].resize(alphabetSize,0);

	treeIterTopDownConst tIt(_tr);
	for (tree::nodeP node = tIt.first(); node != tIt.end(); node = tIt.next()) {
		if (node->isRoot())
			continue;
		const int id = node->id();
		const int fatherState = _states[node->father()->id()];
		const bool leaf = node->isLeaf();
		if (!leaf)
			_states[id] = backtrack[id][fatherState]; // leaves keep their observed state
		const int nodeState = _states[id];
		transitionTypeCount[fatherState][nodeState]++;
		if (leaf && (nodeState != fatherState))
			cout<<"switch from "<<node->father()->name()<<"("<<fatherState<<") to "<<node->name()<<"("<<nodeState<<")"<<endl;
	}
	return log(LofJoint);
}

// Log-likelihood of the fully reconstructed tree (states in _states for every node):
// log freq(root state) + sum over branches of log P(father state -> son state | branch length).
// Accumulated in log space: the original product of all transition probabilities
// underflows to 0 (and log gives -inf) on large trees. A zero probability still yields -inf.
MDOUBLE assess3stateLGT::logLikelihoodOfRecontructedTree(){
	MDOUBLE logLikelihood = log(_sp->freq(_states[_tr.getRoot()->id()]));
	treeIterTopDownConst tIt(_tr);
	for (tree::nodeP mynode = tIt.first(); mynode != tIt.end(); mynode = tIt.next()) {
		if (mynode->isLeaf()) 
			continue;
		const int mystate = _states[mynode->id()];
		for (int son = 0; son < mynode->getNumberOfSons(); ++son){
			const int stateSon = _states[mynode->getSon(son)->id()];
			const MDOUBLE dis2Son = mynode->getSon(son)->dis2father();
			logLikelihood += log(_sp->Pij_t(mystate,stateSon,dis2Son));
		}
	}
	return logLikelihood;
}

/*
Joint posterior of the states at both ends of every branch, for position 0:
P(Node=x,Father=y|D) = P(D,Node=x,Father=y)/P(D)
Storage: posteriorPerNodePer2States[nodeId][y][x], i.e. [father state][node state]
(the order read by computeAncestralPosterior and computePosteriorPerBranch).
The root has no father: P(Root=x|D) is stored in posteriorPerNodePer2States[root_id][0][x].
*/
void assess3stateLGT::computePosteriorOfChangeGivenTerminals(
	VVVdouble &posteriorPerNodePer2States){
	const int numNodes = _tr.getNodesNum();
	const int alphabetSize = _sp->alphabetSize();
	posteriorPerNodePer2States.resize(numNodes);
	for (int n = 0; n < numNodes; ++n)
		resizeMatrix(posteriorPerNodePer2States[n],alphabetSize,alphabetSize);

	// up / down partial likelihoods (computeUpAlg / computeDownAlg) and transition probabilities
	suffStatGlobalHomPos sscUp;
	suffStatGlobalGamPos sscDown;
	sscUp.allocatePlace(numNodes,alphabetSize);
	computePijHom pi;
	pi.fillPij(_tr,*_sp);
	computeUpAlg upAlg;
	computeDownAlg downAlg;
	upAlg.fillComputeUp(_tr,_sc,0,pi,sscUp);
	downAlg.fillComputeDownNonReversible(_tr,_sc,0,pi,sscDown,sscUp);

	MDOUBLE ll = convert(likelihoodComputation::getLofPos(0,_tr,_sc,pi,*_sp)); // P(D), not log
	treeIterTopDownConst tIt(_tr);
	for (tree::nodeP node = tIt.first(); node != tIt.end(); node = tIt.next()) {
		const int id = node->id();
		if (!node->isRoot()) {
			for (int sonState = 0; sonState < alphabetSize; ++sonState)
				for (int fatherState = 0; fatherState < alphabetSize; ++fatherState)
					posteriorPerNodePer2States[id][fatherState][sonState] =
						computePosterioGivenTerminalsPerBranch(id,sonState,fatherState,sscUp,sscDown,pi,ll);
			continue;
		}
		for (int letter = 0; letter < alphabetSize; ++letter)
			posteriorPerNodePer2States[id][0][letter] = convert(sscUp.get(id, letter)) * _sp->freq(letter) / ll;
	}
}

/*
Joint posterior of one state pair on the branch above node nodeId:
P(Node=sonState, Father=fatherState | D) = P(D, Node=sonState, Father=fatherState) / P(D)
The numerator sums, over the root state, freq(root) * down(root, node, fatherState)
* up(node, sonState) * Pij(father -> son); LLData is the normalising data likelihood.
*/
MDOUBLE assess3stateLGT::computePosterioGivenTerminalsPerBranch
(int nodeId,int sonState, int fatherState,suffStatGlobalHomPos &sscUp,
 suffStatGlobalGamPos &sscDown,computePijHom &pi, MDOUBLE &LLData)
{
	const int alphabetSize = _sp->alphabetSize();
	MDOUBLE joint = 0.0;
	for (int rootState = 0; rootState < alphabetSize; ++rootState) {
		const MDOUBLE term = _sp->freq(rootState)
			* convert(sscDown.get(rootState,nodeId,fatherState))
			* convert(sscUp.get(nodeId,sonState))
			* pi.getPij(nodeId,fatherState,sonState);
		joint += term;
	}
	return joint/LLData;
}


/*
compute Prob(letter at Node N is x|Data): the posterior probabilities at ancestral states 
Use the pre-calculated joint posterior probability P(N=x, father(N)=y|D) and just sum over these probs:
Prob(N=x|Data) = sum{fatherState}[P(N=x, father(N)=y|D)]}
stores results in member VVdouble[node][state] _ancestralProbs
*/
void assess3stateLGT::computeAncestralPosterior(const VVVdouble& jointPost)
{
	int numNodes = _tr.getNodesNum();
	int alphabetSize = _sp->alphabetSize();
	resizeMatrix(_ancestralProbs, numNodes, alphabetSize);
    
	treeIterTopDownConst tIt(_tr);
	int letter;
	for (tree::nodeP mynode = tIt.first(); mynode != tIt.end(); mynode = tIt.next()) {
	    if (mynode->isRoot()) {
		for(letter = 0; letter<alphabetSize; ++letter) 
		    _ancestralProbs[mynode->id()][letter] = jointPost[mynode->id()][0][letter];
		continue;
	    }
	    for(letter = 0; letter < alphabetSize; ++letter) {
		MDOUBLE sum = 0.0;
		for(int fatherLetter = 0; fatherLetter < alphabetSize; ++fatherLetter) {
		    sum += jointPost[mynode->id()][fatherLetter][letter];
		}
		_ancestralProbs[mynode->id()][letter] = sum;
	    }
	}
}


/*
Expected number of changes from state fromState to state toState over the whole tree:
the sum, over all nodes, of computeExpectationOfChangePerBranch, which weights
Exp(changes | end states) (taken from the simulations) by the posterior of those end states.
*/
MDOUBLE assess3stateLGT::computeExpectationOfChange(
	simulateJumps &sim,  //input given from simulation studies
	VVVdouble &posteriorProbs,
	int fromState, int toState)
{
	MDOUBLE total = 0;
	treeIterTopDownConst tIt(_tr);
	tree::nodeP node = tIt.first();
	while (node != tIt.end()) {
		total += computeExpectationOfChangePerBranch(sim,posteriorProbs,node,fromState,toState);
		node = tIt.next();
	}
	return total;
}


// Fills _probChanges[nodeId][fromState][toState] with the posterior probability of a
// fromState->toState change on the branch above each node (off-diagonal entries only;
// diagonal entries keep whatever resizeMatrix initialises them to).
// Any table from an earlier call is released first so repeated calls do not leak.
void assess3stateLGT::computePosterior(
	simulateJumps &sim, //input given from simulation studies
	VVVdouble &posteriorProbs)
{
	const int numNodes = _tr.getNodesNum();
	const int alphabetSize = _sp->alphabetSize();
	if (_probChanges)
		delete _probChanges;
	_probChanges = new VVVdouble;
	(*_probChanges).resize(numNodes);
	for (int n=0;n<numNodes;++n)
		resizeMatrix((*_probChanges)[n],alphabetSize,alphabetSize);
	
	treeIterTopDownConst tIt(_tr);
	for (tree::nodeP mynode = tIt.first(); mynode != tIt.end(); mynode = tIt.next()) {
		for (int fromState=0;fromState<alphabetSize;++fromState)
		{
			for (int toState=0;toState<alphabetSize;++toState)
			{
				if (fromState==toState) 
					continue;
				(*_probChanges)[mynode->id()][fromState][toState]=
					computePosteriorPerBranch(sim,posteriorProbs,mynode,fromState,toState);
			}
		}
	}
}


// Posterior probability of a fromState->toState change on the branch above node:
// sum over end-state pairs of sim.getProb(...) weighted by posteriorProbs[node][i][j]
// (posteriorProbs is indexed [father state][node state]).
MDOUBLE assess3stateLGT::computePosteriorPerBranch(
	simulateJumps &sim, //input given from simulation studies
	VVVdouble &posteriorProbs,
	tree::nodeP node,
	int fromState, int toState)
{
	const int alphabetSize = _sp->alphabetSize();
	const VVdouble &branchPost = posteriorProbs[node->id()];
	MDOUBLE prob = 0;
	for (int fatherState = 0; fatherState < alphabetSize; ++fatherState)
		for (int nodeState = 0; nodeState < alphabetSize; ++nodeState)
			prob += sim.getProb(node->name(),fatherState,nodeState,fromState,toState) * branchPost[fatherState][nodeState];
	return prob;
}

// Expected number of fromState->toState changes on the branch above node:
// sum over end-state pairs of posteriorProbs[node][i][j] * sim.getExpectation(...)
// (posteriorProbs is indexed [father state][node state]).
MDOUBLE assess3stateLGT::computeExpectationOfChangePerBranch(
	simulateJumps &sim, //input given from simulation studies
	VVVdouble &posteriorProbs,
	tree::nodeP node,int fromState, int toState)
{
	const int alphabetSize = _sp->alphabetSize();
	const VVdouble &branchPost = posteriorProbs[node->id()];
	MDOUBLE expectation = 0;
	for (int fatherState = 0; fatherState < alphabetSize; ++fatherState)
		for (int nodeState = 0; nodeState < alphabetSize; ++nodeState)
			expectation += (branchPost[fatherState][nodeState] *
				sim.getExpectation(node->name(),fatherState,nodeState,fromState,toState));
	return expectation;
}


