#include "likeDist.h"
#include "numRec.h"
#include "sequence.h"

// Functor evaluating, for a candidate distance `dist` between two aligned
// sequences, the value
//   -sum_pos w_pos * log( sum_cat freq(x) * ratesProb(cat) * P )
// where P = Pij_t(s1[pos], s2[pos], dist*rates(cat)) when both characters are
// known, and P = 1 when exactly one of them is the unknown code -2.
// Positions where both characters are unknown are skipped.
// This is the quantity minimized by likeDist when estimating a distance.
class C_evalLikeDistDirect{
private:
	const stochasticProcess& _sp;
	const sequence& _s1;
	const sequence& _s2;
	const vector<MDOUBLE>  * _weights;
public:
	// inS1    - stochastic process supplying Pij_t, freq, rates and ratesProb.
	// s1, s2  - the two sequences, compared position by position.
	// weights - optional per-position weights; NULL means weight 1 everywhere.
	C_evalLikeDistDirect(const stochasticProcess& inS1,
		const sequence& s1,
		const sequence& s2,
		const vector<MDOUBLE>  * weights): _sp(inS1),_s1(s1),_s2(s2),_weights(weights) {};

	MDOUBLE operator() (MDOUBLE dist) {
		const int unknownChar = -2; // character code treated as "unknown"
		const int len = _s1.seqLen();
		MDOUBLE logLikelihood = 0.0;
		for (int pos = 0; pos < len; ++pos) {
			const bool firstKnown = (_s1[pos] != unknownChar);
			const bool secondKnown = (_s2[pos] != unknownChar);
			if (!firstKnown && !secondKnown) {
				continue; // both characters unknown: position contributes nothing
			}
			// Frequency of whichever character is known (s1 preferred).
			const MDOUBLE charFreq = firstKnown ? _sp.freq(_s1[pos]) : _sp.freq(_s2[pos]);
			MDOUBLE siteL = 0.0;
			for (int cat = 0; cat < _sp.categories(); ++cat) {
				MDOUBLE pij = 1; // exactly one side unknown: transition term taken as 1
				if (firstKnown && secondKnown) {
					pij = _sp.Pij_t(_s1[pos], _s2[pos], dist * _sp.rates(cat));
					if (pij == 0) {
						pij = 0.000000001; // keep log() finite
					}
				}
				siteL += pij * charFreq * _sp.ratesProb(cat);
			}
			assert(siteL != 0);
			const MDOUBLE w = (_weights == NULL) ? 1.0 : (*_weights)[pos];
			logLikelihood += log(siteL) * w;
		}
		return -logLikelihood;
	}
};

// Derivative with respect to `dist` of the value computed by
// C_evalLikeDistDirect, i.e.
//   -sum_pos w_pos * ( sum_cat c * dP/d(dist) ) / ( sum_cat c * P )
// with c = freq(x) * ratesProb(cat) and P = Pij_t(s1[pos], s2[pos], dist*rate).
// When exactly one character is the unknown code -2, P = 1 and dP = 0; positions
// where both characters are unknown are skipped. Used as the derivative
// functor passed to dbrent by likeDist::giveDistance2.
class C_evalLikeDistDirect_d{ // derivative.
private:
	const stochasticProcess& _sp;
	const sequence& _s1;
	const sequence& _s2;
	const vector<MDOUBLE>  * _weights;
public:
	// inS1    - stochastic process supplying Pij_t, dPij_dt, freq, rates, ratesProb.
	// s1, s2  - the two sequences, compared position by position.
	// weights - optional per-position weights; NULL means weight 1 everywhere.
	C_evalLikeDistDirect_d(const stochasticProcess& inS1,
		const sequence& s1,
		const sequence& s2,
		const vector<MDOUBLE>  * weights): _sp(inS1),_s1(s1),_s2(s2),_weights(weights) {};

	MDOUBLE operator() (MDOUBLE dist) {
		const int unknownChar = -2; // character code treated as "unknown"
		MDOUBLE sumL = 0.0;
		for (int pos = 0; pos < _s1.seqLen(); ++pos) {
			const bool firstKnown = (_s1[pos] != unknownChar);
			const bool secondKnown = (_s2[pos] != unknownChar);
			if (!firstKnown && !secondKnown) {
				continue; // two unknowns: position contributes nothing
			}
			// Frequency of whichever character is known (s1 preferred); it does
			// not depend on the rate category, so compute it once per position.
			const MDOUBLE charFreq = firstKnown ? _sp.freq(_s1[pos]) : _sp.freq(_s2[pos]);
			MDOUBLE sumR = 0.0;   // site likelihood
			MDOUBLE sumR_d = 0.0; // its derivative w.r.t. dist
			for (int rateCategor = 0; rateCategor < _sp.categories(); ++rateCategor) {
				MDOUBLE pij = 1;  // unknown paired with a known character
				MDOUBLE dpij = 0;
				if (firstKnown && secondKnown) {
					const MDOUBLE rate = _sp.rates(rateCategor);
					pij = _sp.Pij_t(_s1[pos], _s2[pos], dist * rate);
					// Chain rule: d/d(dist) of P(dist*rate) is P'(dist*rate) * rate.
					dpij = _sp.dPij_dt(_s1[pos], _s2[pos], dist * rate) * rate;
					if (pij == 0) {
						pij = 0.000000001; // same floor as the value functor
					}
				}
				// Named to avoid shadowing the standard exp() function.
				const MDOUBLE weightFactor = charFreq * _sp.ratesProb(rateCategor);
				sumR += pij * weightFactor;
				sumR_d += dpij * weightFactor;
			}
			assert(sumR != 0);
			// d/d(dist) log(sumR) = sumR_d / sumR
			sumL += (sumR_d / sumR) * (_weights ? (*_weights)[pos] : 1);
		}
		return -sumL;
	}
};

// Despite its name, this does not estimate a distance. It returns the
// (unweighted) log-likelihood of the sequence pair (s1, s2) evaluated at the
// fixed distance dis2evaluate, using the stochastic-process member _s1.
const MDOUBLE likeDist::giveDistance(const sequence& s1, const sequence& s2,
									 const MDOUBLE dis2evaluate) {
	// The functor yields minus the log-likelihood; negate it back.
	const MDOUBLE minusLogL = C_evalLikeDistDirect(_s1, s1, s2, NULL)(dis2evaluate);
	return -minusLogL;
}



// Estimates the distance between s1 and s2 by running dbrent on
// C_evalLikeDistDirect (minus the pairwise log-likelihood) and its derivative
// C_evalLikeDistDirect_d. The search runs over [0, _maxPairwiseDistance] with
// tolerance _toll.
//   weights - optional per-position weights (NULL: every position weight 1).
//   score   - if non-NULL, receives the negated dbrent result, presumably the
//             log-likelihood at the returned distance.
// Returns the distance written by dbrent into `dist`.
const MDOUBLE likeDist::giveDistance2(const sequence& s1,
								const sequence& s2,
								const vector<MDOUBLE>  * weights,
								MDOUBLE* score) const {

	const MDOUBLE MAXDISTANCE=_maxPairwiseDistance;
	const MDOUBLE ax=0;
	const MDOUBLE cx=MAXDISTANCE;
	// dbrent (Numerical Recipes contract) expects the starting point bx to lie
	// between ax and cx. Keep the traditional start of 1.0, but fall back to
	// the interval midpoint when the maximal allowed distance does not exceed 1.
	// Otherwise the search would start outside [0, MAXDISTANCE].
	const MDOUBLE bx = (cx > 1.0) ? 1.0 : cx / 2.0;
	const MDOUBLE tol=_toll;
	MDOUBLE dist=-1.0;
	MDOUBLE resL = -dbrent(ax,bx,cx,
		  C_evalLikeDistDirect(_s1,s1,s2,weights),
		  C_evalLikeDistDirect_d(_s1,s1,s2,weights),
		  tol,
		  &dist);
	if (score) *score = resL;
	return dist;
}
