#include <cassert>
#include <iostream>
#include <cmath>
#include <algorithm>
using namespace std;

#include "MRateDiscrete.h"
#include "DrateDiscrete.h"
#include "McRateUtils.h"

#include "errorMsg.h"
#include "someUtil.h"

//////////////////////////////////////////////////////////////////////
// Construction/Destruction
//////////////////////////////////////////////////////////////////////


// Default constructor: an empty distribution with zero weight and zero expectation.
MRateDiscrete::MRateDiscrete()
: m_distribution(0), m_expectation(0)
{
	// addDRate()/addMetaRate() read m_weight before updating it, so it must start at 0.
	// It is assigned here rather than in the initializer list because its declaration
	// is not in this file (it may belong to a base class).
	m_weight = 0;
}

// Copy constructor: makes an exact copy of other's rate categories, weight and expectation.
// A copy must not read this object's own m_weight / m_expectation: at this point they are
// not initialized yet, so any formula that blends them in is undefined behaviour.
MRateDiscrete::MRateDiscrete(const MRateDiscrete& other)
: m_distribution(other.m_distribution), m_expectation(other.m_expectation)
{
	// m_weight is assigned in the body rather than the initializer list because its
	// declaration is not in this file (it may belong to a base class).
	m_weight = other.m_weight;
}



// Destructor: nothing to do explicitly. m_distribution releases its own storage
// when it is destroyed, so no explicit clear() is needed.
MRateDiscrete::~MRateDiscrete()
{
}

//size: returns the number of rate categories accumulated so far
//(the total over everything that has been added to this distribution).
int MRateDiscrete::size() const
{
	return static_cast<int>(m_distribution.size());
}


//addDRate: adds a DrateDiscrete (k discrete rate categories which represent the distribution) to the cummulative distribution 
void MRateDiscrete::addDRate(const Drates * pInRates)
{
	const DrateDiscrete* pRates = static_cast<const DrateDiscrete *>(pInRates);
	int rateNum = pRates->size();
	int ri;
	for (ri = 0; ri < rateNum; ++ri)
	{
		m_distribution.push_back(pRates->getRateP(ri));
		//@@@@ why not m_distribution.push_back(pRates[ri]);
	}
	

	MDOUBLE inExp = pRates->getExpectation();
	m_expectation = ((m_expectation * m_weight) + inExp) / (m_weight + 1);
	m_weight++; 
}




//addMetaRate: adds a whole MRateDiscrete to this
//if bSameWeight==FALSE then add according to relative weight of each distributon (how many Drates it "contains")
//if bSameWeight==TRUE then make a simple adition. the new m_weight is doubled
void MRateDiscrete::addMetaRate(const MetaRates * pMetaOther, bool bSameWeight)
{
	const MRateDiscrete* pOther = static_cast<const MRateDiscrete *>(pMetaOther);

	int otherRateNum = pOther->size(); 

	if (bSameWeight == true)
	{
		assert(false);
	}
	else
	{
		int i;
		for (i = 0; i < otherRateNum; ++i)
		{
			m_distribution.push_back(pOther->m_distribution[i]);
		}
		m_weight += pOther->m_weight; 
	}

	//check that total probability is 1.0
	MDOUBLE sum = 0.0;
	int i;
	for (i = 0; i < m_distribution.size(); ++i)
	{
		sum += m_distribution[i].getProb();
	}

	sum /= m_weight;
	if (!DEQUAL(sum, 1.0))
		errorMsg::reportError("total probability is not 1.0 in function MRateDiscrete::addMetaRate()"); 
}



//getExpectation: returns the stored expectation m_expectation as-is; unlike
//calcExpectation() it does not recompute anything from m_distribution.
MDOUBLE MRateDiscrete::getExpectation() const
{
	const MDOUBLE storedExpectation = this->m_expectation;
	return storedExpectation;
}

//calcExpectation: recomputes the expectation of the whole distribution from its categories,
//sum(prob * rate) / m_weight, reports an error if it disagrees with the stored m_expectation,
//and returns the recomputed value.
MDOUBLE MRateDiscrete::calcExpectation() const
{
	MDOUBLE weightedSum = 0;
	const int numRates = size();
	for (int k = 0; k < numRates; ++k)
	{
		const MDOUBLE prob = m_distribution[k].getProb();
		const MDOUBLE rate = m_distribution[k].getRate();
		weightedSum += prob * rate;
	}
	// the probabilities of the accumulated distribution sum to m_weight, not to 1
	const MDOUBLE expectation = weightedSum / m_weight;

	if (!DEQUAL(expectation, m_expectation))
		errorMsg::reportError("calculated expectation is not same as m_expectation in MRateDiscrete::calcExpectation()"); 

	return expectation;
}



//getStd: returns the standard deviation of the cumulative distribution,
//std = sqrt(E[x^2] - E[x]^2).
//Both moments are divided by m_weight because the probabilities of the accumulated
//distribution sum to m_weight rather than to 1 (addMetaRate checks this).
MDOUBLE MRateDiscrete::getStd() const
{
	MDOUBLE Ex = 0.0, Ex2 = 0.0;
	const int rateNum = size();
	for (int i = 0; i < rateNum; ++i)
	{
		const MDOUBLE r = m_distribution[i].getRate();
		const MDOUBLE Pr = m_distribution[i].getProb();
		Ex += r * Pr;
		Ex2 += r * r * Pr;
	}

	Ex /= m_weight;
	Ex2 /= m_weight;
	MDOUBLE var = Ex2 - (Ex * Ex);
	if (var < 0.0)
	{
		// E[x^2] - E[x]^2 is prone to cancellation: when all rates are (nearly) equal the
		// exact variance is ~0 and rounding alone can make the difference slightly negative.
		// Treat such values as zero instead of reporting an error.
		if (DEQUAL(var, 0.0))
			var = 0.0;
		else
			errorMsg::reportError("variance is negative in in function MRateDiscrete::getStd()");
	}

	return sqrt(var);
}


// printDistribution: writes a timestamp (via printTime), then a tab-separated table of
// "i rate prob", one line per category.
// Side effect: m_distribution is sorted in place (with the element type's operator<)
// before printing, which is why this member function is not const.
void MRateDiscrete::printDistribution(ofstream &outFile)
{
	printTime(outFile);
	const int numRates = size();

	std::sort(m_distribution.begin(), m_distribution.end());

	outFile << "i" << "\t" << "rate" << "\t" << "prob" << std::endl;
	for (int idx = 0; idx < numRates; ++idx)
	{
		outFile << idx << "\t"
		        << m_distribution[idx].getRate() << "\t"
		        << m_distribution[idx].getProb() << std::endl;
	}
}
