#include "optimizeThreeStateModel.h"
#include "numRec.h"
#include "bblEM.h"
#include "logFile.h"


/*
 * optimizeThreeStateModel
 * Coordinate-wise likelihood optimization of the parameters of the
 * threeStateModel held by sp, on the fixed tree et and alignment sc.
 * In every sweep each enabled parameter (mu1, mu2, and - for the
 * non-reversible model - mu3, mu4, pi0, pi1) is optimized in turn with brent,
 * while the other parameters stay at their current best values. Sweeps are
 * repeated until one of them yields no log-likelihood gain larger than
 * epsilonOptimization, or until maxIterations sweeps have been done.
 *
 * et, sp, sc          - tree, stochastic process and alignment used for the
 *                       likelihood. sp's replacement model is static_cast to
 *                       threeStateModel without a check, so it must be one.
 * optimizeMu1..Mu4    - which rate parameters to optimize.
 * optimizePi0, Pi1    - which frequencies to optimize.
 * upperBoundMuVals    - upper end of the brent search interval for mu values
 *                       (lower end is 0).
 * epsilonOptimization - tolerance passed to brent, and minimal log-likelihood
 *                       gain that counts as an improvement.
 *
 * On return the model in sp holds the best values found. These are also
 * stored in _bestMu1.._bestMu4, _bestPi0, _bestPi1, with log-likelihood _bestL.
 */
optimizeThreeStateModel::optimizeThreeStateModel(tree& et,
					   stochasticProcess& sp, sequenceContainer &sc,
					   bool optimizeMu1 , bool optimizeMu2 ,
					   bool optimizeMu3 ,bool optimizeMu4 ,
					   bool optimizePi0, bool optimizePi1,
					   const MDOUBLE upperBoundMuVals,
					   const MDOUBLE epsilonOptimization){
	// The model is always treated as non-reversible here, so mu3, mu4 and the
	// frequencies are always eligible for optimization.
	bool	isReversible=false;
	const int maxIterations=20;
	MDOUBLE currM1=0;
	MDOUBLE currM2=0;
	MDOUBLE currM3=0; // for non-reversible model only
	MDOUBLE currM4=0; // for non-reversible model only
	MDOUBLE currPi0=0;
	MDOUBLE currPi1=0;
	threeStateModel* modelCasted = 
		static_cast<threeStateModel*>(sp.getPijAccelerator()->getReplacementModel());
	// brent search intervals: [lowerBound, upperBoundMuVals] for mu values,
	// [lowerBound, upperBoundPi] for frequencies
	const MDOUBLE upperBoundPi=1;
	const MDOUBLE lowerBound=0.0;

	// initialize from the model's current parameter values
	_bestL = VERYSMALL;
	MDOUBLE currBestL=VERYSMALL;
	_bestMu1 = modelCasted->getMu1();
	_bestMu2 = modelCasted->getMu2();
	_bestMu3 = modelCasted->getMu3();
	_bestMu4 = modelCasted->getMu4();
	_bestPi0 = modelCasted->freq(0);
	_bestPi1 = modelCasted->freq(1);

	int it=0;
	for (it=0;it<maxIterations;it++){
		bool changed=false;
		LOGnOUT(5,<<"OPTIMIZATION ITERATION="<<it<<endl);
/****mu1****/
		if (optimizeMu1){
			LOGnOUT(5,<<"Optimizing m1"<<endl);
			currBestL = -brent(lowerBound,_bestMu1,upperBoundMuVals,evalParam(et,sp,sc,evalParam::mu1,isReversible),epsilonOptimization,&currM1);
			if (currBestL>_bestL+epsilonOptimization) {
				modelCasted->setMu1(currM1);
				LOGnOUT(5,<<"currBestL="<<currBestL<<" for m1="<<currM1<<endl);
				_bestL=currBestL;			
				_bestMu1=currM1;
				changed=true;
			}
			else{
				// evalParam left the model at brent's last probe; restore the best value
				LOGnOUT(5,<<"saving last value "<<_bestMu1<<", no improvement in likelihood "<<_bestL<<endl);
				modelCasted->setMu1(_bestMu1);
			}
		}
/****mu2****/
		if (optimizeMu2){
			LOGnOUT(5,<<"Optimizing m2"<<endl);
			currBestL = -brent(lowerBound,_bestMu2,upperBoundMuVals,evalParam(et,sp,sc,evalParam::mu2,isReversible),epsilonOptimization,&currM2);
			if (currBestL>_bestL+epsilonOptimization) {
				modelCasted->setMu2(currM2);
				LOGnOUT(5,<<"currBestL="<<currBestL<<" for m2="<<currM2<<endl);
				_bestL=currBestL;
				_bestMu2=currM2;
				changed=true;
			}
			else{
				LOGnOUT(5,<<"saving last value "<<_bestMu2<<", no improvement in likelihood "<<_bestL<<endl);
				modelCasted->setMu2(_bestMu2);
			}

		}
		if (!isReversible){
/****mu3****/
			if (optimizeMu3){
				LOGnOUT(5,<<"Optimizing m3"<<endl);
				currBestL = -brent(lowerBound,_bestMu3,upperBoundMuVals,evalParam(et,sp,sc,evalParam::mu3,isReversible),epsilonOptimization,&currM3);
				if (currBestL>_bestL+epsilonOptimization) {
					modelCasted->setMu3(currM3);
					LOGnOUT(5,<<"currBestL="<<currBestL<<" for m3="<<currM3<<endl);
					_bestL=currBestL;			
					_bestMu3=currM3;
					changed=true;
				}
				else{
					LOGnOUT(5,<<"saving last value "<<_bestMu3<<", no improvement in likelihood "<<_bestL<<endl);
					modelCasted->setMu3(_bestMu3);
				}
			}
/****mu4****/
			if (optimizeMu4){
				LOGnOUT(5,<<"Optimizing m4"<<endl);
				currBestL = -brent(lowerBound,_bestMu4,upperBoundMuVals,evalParam(et,sp,sc,evalParam::mu4,isReversible),epsilonOptimization,&currM4);
				if (currBestL>_bestL+epsilonOptimization) {
					modelCasted->setMu4(currM4);
					LOGnOUT(5,<<"currBestL="<<currBestL<<" for m4="<<currM4<<endl);
					_bestL=currBestL;
					_bestMu4=currM4;
					changed=true;
				}
				else{
					LOGnOUT(5,<<"saving last value "<<_bestMu4<<", no improvement in likelihood "<<_bestL<<endl);
					modelCasted->setMu4(_bestMu4);
				}
			}
/****pi0****/
			if (optimizePi0){
				LOGnOUT(5,<<"Optimizing pi0"<<endl);
				// evalParam overwrites the model's frequencies while brent probes;
				// keep the current ones so they can be restored on no improvement.
				Vdouble freqsBefore = modelCasted->getFreqs();
				currBestL = -brent(lowerBound,_bestPi0,upperBoundPi,evalParam(et,sp,sc,evalParam::pi0,isReversible),
					epsilonOptimization,&currPi0);
				if (currBestL>_bestL+epsilonOptimization) {
					LOGnOUT(5,<<"currBestL="<<currBestL<<" for pi0="<<currPi0<<endl);
					Vdouble freqs = modelCasted->getFreqs();
					computeRelativeFreqsFollowingOneChanged(currPi0,0,freqs);
					modelCasted->setFreq(freqs);
					_bestL=currBestL;
					_bestPi0=currPi0;
					_bestPi1=modelCasted->freq(1);
					changed=true;
				}
				else {
					LOGnOUT(5,<<"saving last value "<<_bestPi0<<", no improvement in likelihood "<<_bestL<<endl);
					modelCasted->setFreq(freqsBefore);
				}

			}

/****pi1****/
			if (optimizePi1){
				LOGnOUT(5,<<"Optimizing pi1"<<endl);
				Vdouble freqsBefore = modelCasted->getFreqs();
				currBestL = -brent(lowerBound,_bestPi1,upperBoundPi,evalParam(et,sp,sc,evalParam::pi1,isReversible),
					epsilonOptimization,&currPi1);
				if (currBestL>_bestL+epsilonOptimization) {
					LOGnOUT(5,<<"currBestL="<<currBestL<<" for pi1="<<currPi1<<endl);
					Vdouble freqs = modelCasted->getFreqs();
					computeRelativeFreqsFollowingOneChanged(currPi1,1,freqs);
					// write the optimum back before reading the rescaled freq(0)
					modelCasted->setFreq(freqs);
					_bestL=currBestL;
					_bestPi1=currPi1;
					_bestPi0=modelCasted->freq(0);
					changed=true;
				}
				else {
					LOGnOUT(5,<<"saving last value "<<_bestPi1<<", no improvement in likelihood "<<_bestL<<endl);
					modelCasted->setFreq(freqsBefore);
				}
			}
		}

		if (changed==false){
			LOGnOUT(5,<<"optimization over "<<endl);
			break;
		}
	}
	if (it==maxIterations)
		LOG(5,<<"Too many iterations in optimizeThreeStateModel. Last optimized parameters are used."<<endl);
}


/*
 * evalParam::operator()
 * Objective function used with brent in optimizeThreeStateModel.
 * It writes the candidate value `val` into the parameter selected by
 * _paramName of the threeStateModel held by _sp. It then returns the negated
 * log of likelihoodComputation::getLofPos for position 0 of _sc on _et.
 * Minimizing the return value therefore maximizes that log-likelihood.
 *
 * For pi0 / pi1 the selected frequency is set to `val` and the frequency
 * vector is adjusted by computeRelativeFreqsFollowingOneChanged.
 * Side effect: the model in _sp keeps `val` after the call returns.
 */
MDOUBLE evalParam::operator() (MDOUBLE val) {
	threeStateModel* const model =
		static_cast<threeStateModel*>(_sp.getPijAccelerator()->getReplacementModel());

	switch (_paramName) {
	case mu1:
		model->setMu1(val);
		break;
	case mu2:
		model->setMu2(val);
		break;
	case mu3:
		model->setMu3(val);
		break;
	case mu4:
		model->setMu4(val);
		break;
	case pi0:
	case pi1: {
		// NOTE! we could save the time of computing pij and up, and only multiply the end result by freq
		const int changedIndex = (_paramName == pi0) ? 0 : 1;
		Vdouble newFreqs = model->getFreqs();
		computeRelativeFreqsFollowingOneChanged(val, changedIndex, newFreqs);
		model->setFreq(newFreqs);
		break;
	}
	default:
		errorMsg::reportError("evalParam::operator, illegal parameter name");
		break;
	}

	// Recompute transition probabilities for the updated model, then score.
	computePijHom pijTable;
	pijTable.fillPij(_et, _sp, 0, false); // false: treat the model as non-reversible
	const MDOUBLE logLikelihood =
		convert(log(likelihoodComputation::getLofPos(0, _et, _sc, pijTable, _sp)));
	return -logLikelihood;
}
