#include "OptimizationParams.h"

#include <cstdio>
#include <cstring>

// Debug helpers (defined at the end of this file) that describe MEX input
// arrays: printDimensions formats an array's shape into a caller buffer,
// printParams/printSingleParams print argument summaries via mexPrintf.
void printDimensions(char * buffer, int size, const mxArray * prhs);
void printParams(int len, const mxArray * prhs[]);
void printSingleParams(int i, const mxArray * prhs);

/// Constructs an empty parameter set.
/// itemCount/params are zeroed explicitly because the destructor inspects both;
/// without this, destroying an instance on which SetParams() was never called
/// would read indeterminate values.
/// NOTE(review): indexMap and command are still only set by SetParams*/SetCommand.
OptimizationParams::OptimizationParams() : itemCount(0), params(nullptr) { }

/// Releases every mexCellItem created by SetParams() and the pointer array holding them.
/// NOTE(review): this class owns raw pointers; unless the header deletes/defines the
/// copy constructor and copy assignment, copying an instance would double-delete.
OptimizationParams::~OptimizationParams()
{
	if (params != nullptr)
	{
		for (int i = 0; i < itemCount; i++)
		{
			delete params[i];
		}

		// params is allocated with new[] in SetParams(), so it must be freed with delete[]
		// (this also covers the zero-length array created for nrhs == 0).
		delete[] params;
	}
}


int OptimizationParams::GetCount()
{
	return itemCount;
}

/// Returns the wrapped MEX argument at 0-based position i.
/// An index outside [0, GetCount()) is reported via mexErrMsgTxt instead of
/// reading past the params array (callers such as the command lookup pass
/// indices taken straight from the index map).
mexCellItem * OptimizationParams::GetParamCell(int i)
{
	if (params == nullptr || i < 0 || i >= itemCount)
	{
		mexErrMsgTxt("Parameter index out of range.");
		return nullptr;
	}

	return params[i];
}

void OptimizationParams::SetParamsWithMap(int nrhs, const mxArray * prhs[], IndexMap map)
{
	indexMap = map;
	SetParams(nrhs, prhs);
	SetCommand(GetParamCell(indexMap.command)->getString());
}
// Wraps each of the nrhs MEX input arrays in a newly allocated mexCellItem.
// The pointer array and its elements are released by the destructor; cells
// captured by an earlier call are not freed here.
void OptimizationParams::SetParams(int nrhs, const mxArray * prhs[])
{
	itemCount = nrhs;
	params = new mexCellItem*[itemCount];

	for (int slot = 0; slot < itemCount; ++slot)
		params[slot] = new mexCellItem(prhs[slot]);
}

void OptimizationParams::SetParamsWithCommand(int nrhs, const mxArray * prhs[], int commandIndex)
{
	SetParams(nrhs, prhs);
	SetCommand(GetParamCell(commandIndex)->getString());
}

/// Returns the kt argument converted from MATLAB's 1-based to a 0-based index.
/// Reports "param.kt is not set." via mexErrMsgTxt when the active layout has no kt.
int OptimizationParams::Get_kt()
{
	if (indexMap.kt == -1)
	{
		mexErrMsgTxt("param.kt is not set.");
		return -1;
	}

	return (OptimizationParams::GetParamCell(indexMap.kt)->getInt() - 1); // MATLAB (1-based) to C++ (0-based) indexing
}
// Returns the tt argument as a 0-based index (the MATLAB value is 1-based).
// Reports an error via mexErrMsgTxt when the active layout has no tt slot.
int OptimizationParams::Get_tt()
{
	if (indexMap.tt != -1)
		return OptimizationParams::GetParamCell(indexMap.tt)->getInt() - 1;

	mexErrMsgTxt("param.tt is not set.");
	return -1;
}
// Returns the i argument as a 0-based index (the MATLAB value is 1-based).
// Reports an error via mexErrMsgTxt when the active layout has no i slot.
int OptimizationParams::Get_i()
{
	const auto slot = indexMap.i;
	if (slot == -1)
	{
		mexErrMsgTxt("param.i is not set.");
		return -1;
	}

	const auto matlabIndex = OptimizationParams::GetParamCell(slot)->getInt();
	return matlabIndex - 1;
}

// 0-based j index derived from the 1-based MATLAB argument.
// Reports an error via mexErrMsgTxt when the active layout has no j slot.
int OptimizationParams::Get_j()
{
	if (indexMap.j != -1)
		return OptimizationParams::GetParamCell(indexMap.j)->getInt() - 1;

	mexErrMsgTxt("j is not set.");
	return -1;
}

// 0-based k index derived from the 1-based MATLAB argument.
// Reports an error via mexErrMsgTxt when the active layout has no k slot.
int OptimizationParams::Get_k()
{
	const auto slot = indexMap.k;
	if (slot != -1)
	{
		const auto matlabIndex = OptimizationParams::GetParamCell(slot)->getInt();
		return matlabIndex - 1;
	}

	mexErrMsgTxt("k is not set.");
	return -1;
}

int OptimizationParams::Get_l()
{
	if (indexMap.l == -1)
	{
		mexErrMsgTxt("l is not set.");
		return -1;
	}

	return (OptimizationParams::GetParamCell(indexMap.l)->getInt() - 1);//Matlab to c++ Indexing 
}

int OptimizationParams::Get_m()
{
	if (indexMap.m == -1)
	{
		mexErrMsgTxt("m is not set.");
		return -1;
	}

	return (OptimizationParams::GetParamCell(indexMap.m)->getInt() - 1);//Matlab to c++ Indexing 
}

// Maps the optional print-level string argument ("numeric" / "equation") to a
// printLevel value. A missing slot or unrecognised text yields enum_printLevel_unknown.
printLevel OptimizationParams::Get_print_level()
{
	const auto slot = indexMap.printLevel;
	if (slot >= 0)
	{
		const char * requested = OptimizationParams::GetParamCell(slot)->getString();

		if (strcmp(requested, "numeric") == 0)
			return printLevel::enum_printLevel_numeric;
		if (strcmp(requested, "equation") == 0)
			return printLevel::enum_printLevel_equation;
	}

	return printLevel::enum_printLevel_unknown;
}

double OptimizationParams::Get_lw_grid_1_val(int j, int tt)
{
	if (indexMap.lw_grid_1 == -1)
	{
		mexErrMsgTxt("param.lw_grid_1 is not set.");
		return -1;
	}

	return OptimizationParams::GetParamCell(indexMap.lw_grid_1)->getDouble(j,tt);
}
double OptimizationParams::Get_lw_grid_2_val(int k, int tt)
{
	if (indexMap.lw_grid_2 == -1)
	{
		mexErrMsgTxt("param.lw_grid_2 is not set.");
		return -1;
	}

	return OptimizationParams::GetParamCell(indexMap.lw_grid_2)->getDouble(k,tt);
}
// Reads u1_grid(l, tt). Reports an error via mexErrMsgTxt when the active
// layout has no u1_grid argument.
double OptimizationParams::Get_u1_grid_val(int l, int tt)
{
	const auto slot = indexMap.u1_grid;
	if (slot != -1)
		return OptimizationParams::GetParamCell(slot)->getDouble(l, tt);

	mexErrMsgTxt("param.u1_grid is not set.");
	return -1;
}
double OptimizationParams::Get_u2_grid_val(int m, int tt)
{
	if (indexMap.u2_grid == -1)
	{
		mexErrMsgTxt("param.u2_grid is not set.");
		return -1;
	}

	return OptimizationParams::GetParamCell(indexMap.u2_grid)->getDouble(m,tt);
}

double OptimizationParams::Get_L1_star_f_val(int j, int k, int l, int m)
{
	if (indexMap.L1_star_f == -1)
	{
		mexErrMsgTxt("param.L1_star_f is not set.");
		return -1;
	}

	int i = OptimizationParams::Get_i();
	return OptimizationParams::GetParamCell(indexMap.L1_star_f)->getDouble(i,j,k,l,m);
}
double OptimizationParams::Get_T1_star_f_val(int j, int k, int l, int m)
{
	if (indexMap.T1_star_f == -1)
	{
		mexErrMsgTxt("param.T1_star_f is not set.");
		return -1;
	}

	int i = OptimizationParams::Get_i();
	return OptimizationParams::GetParamCell(indexMap.T1_star_f)->getDouble(i, j, k, l, m);
}
/// Reads L2_star_f(i, j, k, l, m), with the leading index i taken from Get_i().
/// Reports "param.L2_star_f is not set." via mexErrMsgTxt when the active layout
/// has no L2_star_f argument (the message previously named T1_star_f).
double OptimizationParams::Get_L2_star_f_val(int j, int k, int l, int m)
{
	if (indexMap.L2_star_f == -1)
	{
		mexErrMsgTxt("param.L2_star_f is not set.");
		return -1;
	}

	int i = OptimizationParams::Get_i();
	return OptimizationParams::GetParamCell(indexMap.L2_star_f)->getDouble(i, j, k, l, m);
}
/// Reads A_star_f(i, j, k, l, m), with the leading index i taken from Get_i().
/// Reports "param.A_star_f is not set." via mexErrMsgTxt when the active layout
/// has no A_star_f argument (the message previously named T1_star_f).
double OptimizationParams::Get_A_star_f_val(int j, int k, int l, int m)
{
	if (indexMap.A_star_f == -1)
	{
		mexErrMsgTxt("param.A_star_f is not set.");
		return -1;
	}

	int i = OptimizationParams::Get_i();
	return OptimizationParams::GetParamCell(indexMap.A_star_f)->getDouble(i, j, k, l, m);
}

double OptimizationParams::GetStructParamDouble(const char * fieldName)
{
	if (indexMap.param_struct == -1)
	{
		mexErrMsgTxt("param.param_struct is not set.");
		return -1;
	}

	return OptimizationParams::GetParamCell(indexMap.param_struct)->getStructField(fieldName)->getDouble();
}

// Returns field fieldName of the parameter struct argument, or nullptr when
// that field does not exist. Reports an error via mexErrMsgTxt (and returns
// nullptr) when the active layout has no param_struct.
mexCellItem * OptimizationParams::GetStructParamCell(const char * fieldName)
{
	if (indexMap.param_struct == -1)
	{
		mexErrMsgTxt("param.param_struct is not set.");
		return nullptr;
	}

	mexCellItem * structCell = OptimizationParams::GetParamCell(indexMap.param_struct);
	const bool present = structCell->checkStructFieldExists(fieldName);
	return present ? structCell->getStructField(fieldName) : nullptr;
}

int OptimizationParams::GetStructParamInt(const char * fieldName)
{
	if (indexMap.param_struct == -1)
	{
		mexErrMsgTxt("param.param_struct is not set.");
		return -1;
	}

	return OptimizationParams::GetParamCell(indexMap.param_struct)->getStructField(fieldName)->getInt();
}

// Returns the command selected by SetCommand(). Reports an error via
// mexErrMsgTxt (and returns enum_unknown) when the layout has no command slot.
mexCommand OptimizationParams::GetCommand()
{
	if (indexMap.command != -1)
		return command;

	mexErrMsgTxt("param.command is not set.");
	return mexCommand::enum_unknown;
}

/// Builds a PartwiseLinFunc from cell (j, k) of the EUFUNC argument.
/// @param nullIfUnset  when true and the active layout has no EUFUNC argument,
///                     return nullptr instead of reporting an error.
/// @return a newly allocated PartwiseLinFunc that the caller must delete, or nullptr.
PartwiseLinFunc * OptimizationParams::Get_EUfunc(int j, int k, bool nullIfUnset)
{
	if (indexMap.EUFUNC == -1)
	{
		// Must return here in both cases: continuing would call GetParamCell(-1),
		// indexing the params array out of bounds.
		if (nullIfUnset)
			return nullptr;

		mexErrMsgTxt("param.EUFUNC is not set.");
		return nullptr;
	}

	mexCellItem * cell = GetParamCell(indexMap.EUFUNC);
	mexCellItem * innerCell = cell->getCell(j, k); // An inner cell is extracted here, and needs to be deleted
	PartwiseLinFunc * func = new PartwiseLinFunc(innerCell);
	delete innerCell; // The inner cell is deleted once the function object has been built.
	return func;
}

/// Builds a PartwiseLinFunc from cell (j, k) of the EVFUNC argument.
/// @param nullIfUnset  when true and the active layout has no EVFUNC argument,
///                     return nullptr instead of reporting an error.
/// @return a newly allocated PartwiseLinFunc that the caller must delete, or nullptr.
PartwiseLinFunc * OptimizationParams::Get_EVfunc(int j, int k, bool nullIfUnset)
{
	if (indexMap.EVFUNC == -1)
	{
		// Must return here in both cases: continuing would call GetParamCell(-1),
		// indexing the params array out of bounds.
		if (nullIfUnset)
			return nullptr;

		mexErrMsgTxt("param.EVFUNC is not set.");
		return nullptr;
	}

	mexCellItem * cell = GetParamCell(indexMap.EVFUNC);
	mexCellItem * innerCell = cell->getCell(j, k); // An inner cell is extracted here, and needs to be deleted
	PartwiseLinFunc * func = new PartwiseLinFunc(innerCell);
	delete innerCell; // The inner cell is deleted once the function object has been built.
	return func;
}

double OptimizationParams::Get_A_grid_val(int i)
{
	if (indexMap.A_grid == -1)
	{
		mexErrMsgTxt("param.A_grid is not set.");
		return -1;
	}

	return OptimizationParams::GetParamCell(indexMap.A_grid)->getDouble(i);
}

// Returns the wpoints argument unchanged (no index conversion is applied).
// Reports an error via mexErrMsgTxt when the active layout has no wpoints slot.
int OptimizationParams::Get_wpoints()
{
	const auto slot = indexMap.wpoints;
	if (slot != -1)
		return OptimizationParams::GetParamCell(slot)->getInt();

	mexErrMsgTxt("param.wpoints is not set.");
	return -1;
}

int OptimizationParams::Get_upoints()
{
	if (indexMap.upoints == -1)
	{
		mexErrMsgTxt("param.upoints is not set.");
		return -1;
	}

	return OptimizationParams::GetParamCell(indexMap.upoints)->getInt();
}

void OptimizationParams::SetCommand(const char * commandString)
{
	if (!strcmp("foc_int", commandString))
	{
		//Generate Parameter Map:
		IndexMap map = {
			map.command = 0,
			map.A_grid = 1,
			map.lw_grid_1 = 2,
			map.lw_grid_2 = 3,
			map.u1_grid = 4,
			map.u2_grid = 5,
			map.param_struct = 6,
			map.tt = 7,
			map.wpoints = 8,
			map.upoints = 9,
			map.i = 10,
			map.kt = 11, 
			map.EUFUNC = 12,
			map.EVFUNC = 13,
			map.L1_star_f = 14,
			map.T1_star_f = 15,  
			map.L2_star_f = -1,
			map.A_star_f = -1,
			map.j = -1, 
			map.k = -1, 
			map.l = -1,
			map.m = -1,
			map.printLevel = -1
		};
		indexMap = map;
		command = mexCommand::enum_foc_int;
		return;
	}

	if (!strcmp("solve_loop_E1_0", commandString))
	{
		//Generate Parameter Map:
		IndexMap map = {
			map.command = 0,
			map.A_grid = 1, //A_grid
			map.lw_grid_1 = 2, //lw_grid_1
			map.lw_grid_2 = 3, //lw_grid_2
			map.u1_grid = 4, // u1_grid
			map.u2_grid = 5, // u2_grid
			map.param_struct = 6, //parameters
			map.tt = 7, //tt
			map.wpoints = 8, //wpoints
			map.upoints = 9, //upoints
			map.i = 10, //i
			map.kt = 11, //kt 
			map.EUFUNC = 12, //EUFUNC, 
			map.EVFUNC = 13, //EVFUNC, 
			map.L1_star_f = 14, //L1_star_f, 
			map.T1_star_f = 15, //T1_star_f, 
			map.L2_star_f = 16,
			map.A_star_f = 17,
			map.j = -1, 
			map.k = -1, 
			map.l = -1,
			map.m = -1,
			map.printLevel = -1
		};
		indexMap = map;
		command = mexCommand::enum_solve_loop_E1_0;
		return;
	}

	if (!strcmp("solve_loop_E1_1", commandString))
	{
		//Generate Parameter Map:
		IndexMap map = {
			map.command = 0,
			map.A_grid = 1, //A_grid
			map.lw_grid_1 = 2, //lw_grid_1
			map.lw_grid_2 = 3, //lw_grid_2
			map.u1_grid = 4, // u1_grid
			map.u2_grid = 5, // u2_grid
			map.param_struct = 6, //parameters
			map.tt = 7, //tt
			map.wpoints = 8, //wpoints
			map.upoints = 9, //upoints
			map.i = 10, //i
			map.kt = 11, //kt 
			map.EUFUNC = 12, //EUFUNC, 
			map.EVFUNC = 13, //EVFUNC, 
			map.L1_star_f = 14, //L1_star_f, 
			map.T1_star_f = 15, //T1_star_f, 
			map.L2_star_f = 16,
			map.A_star_f = 17,
			map.j = -1, 
			map.k = -1,  
			map.l = -1,
			map.m = -1,
			map.printLevel = -1
		};
		indexMap = map;
		command = mexCommand::enum_solve_loop_E1_1;
		return;
	}
	
	if (!strcmp("solve_loop_E1_0_nk", commandString))
	{
		//Generate Parameter Map:
		IndexMap map = {
			map.command = 0,
			map.A_grid = 1, //A_grid
			map.lw_grid_1 = 2, //lw_grid_1
			map.lw_grid_2 = 3, //lw_grid_2
			map.u1_grid = 4, // u1_grid
			map.u2_grid = 5, // u2_grid
			map.param_struct = 6, //parameters
			map.tt = 7, //tt
			map.wpoints = 8, //wpoints
			map.upoints = 9, //upoints
			map.i = 10, //i
			map.kt = 11, //kt 
			map.EUFUNC = 12, //EUFUNC, 
			map.EVFUNC = 13, //EVFUNC, 
			map.L1_star_f = 14, //L1_star_f, 
			map.T1_star_f = 15, //T1_star_f, 
			map.L2_star_f = 16,
			map.A_star_f = 17,
			map.j = -1, 
			map.k = -1, 
			map.l = -1,
			map.m = -1,
			map.printLevel = -1
		};
		indexMap = map;
		command = mexCommand::enum_solve_loop_E1_0_nk;
		return;
	}

	if (!strcmp("solve_loop_E1_1_nk", commandString))
	{
		//Generate Parameter Map:
		IndexMap map = {
			map.command = 0,
			map.A_grid = 1, //A_grid
			map.lw_grid_1 = 2, //lw_grid_1
			map.lw_grid_2 = 3, //lw_grid_2
			map.u1_grid = 4, // u1_grid
			map.u2_grid = 5, // u2_grid
			map.param_struct = 6, //parameters
			map.tt = 7, //tt
			map.wpoints = 8, //wpoints
			map.upoints = 9, //upoints
			map.i = 10, //i
			map.kt = 11, //kt 
			map.EUFUNC = 12, //EUFUNC, 
			map.EVFUNC = 13, //EVFUNC, 
			map.L1_star_f = 14, //L1_star_f, 
			map.T1_star_f = 15, //T1_star_f, 
			map.L2_star_f = 16,
			map.A_star_f = 17,
			map.j = -1, 
			map.k = -1, 
			map.l = -1,
			map.m = -1,
			map.printLevel = -1
		};
		indexMap = map;
		command = mexCommand::enum_solve_loop_E1_1_nk;
		return;
	}

	if (!strcmp("foc_int_terminal_nk", commandString))
	{
		//Generate Parameter Map:
		IndexMap map = {
			map.command = 0,
			map.A_grid = 1,
			map.lw_grid_1 = 2,
			map.lw_grid_2 = 3,
			map.u1_grid = 4,
			map.u2_grid = 5,
			map.param_struct = 6,
			map.tt = 7,
			map.wpoints = 8,
			map.upoints = 9,
			map.i = 10,
			map.kt = 11, //kt 
			map.EUFUNC = 12,
			map.EVFUNC = 13,
			map.L1_star_f = 14,//L1_star_f, 
			map.T1_star_f = 15, //T1_star_f, 
			map.L2_star_f = -1,
			map.A_star_f = -1,
			map.j = -1, 
			map.k = -1,  
			map.l = -1,
			map.m = -1,
			map.printLevel = -1
		};
		indexMap = map;
		command = mexCommand::enum_foc_int_terminal_nk;
		return;
	}

	if (!strcmp("foc_calc_single", commandString))
	{
		//Generate Parameter Map:
		IndexMap map = {
			map.command = 0,
			map.A_grid = 1, //A_grid
			map.lw_grid_1 = 2, //lw_grid_1
			map.lw_grid_2 = 3, //lw_grid_2
			map.u1_grid = 4, // u1_grid
			map.u2_grid = 5, // u2_grid
			map.param_struct = 6, //parameters
			map.tt = 7, //tt
			map.wpoints = 8, //wpoints
			map.upoints = 9, //upoints
			map.i = 10, //i
			map.kt = 11, //kt 
			map.EUFUNC = 12, //EUFUNC, 
			map.EVFUNC = 13, //EVFUNC, 
			map.L1_star_f = 14, //L1_star_f, 
			map.T1_star_f = 15, //T1_star_f, 
			map.L2_star_f = 16,
			map.A_star_f = 17,
			map.j = 18, 
			map.k = 19, 
			map.l  = 20,
			map.m = 21,
			map.printLevel = 22
		};
		indexMap = map;

		command = mexCommand::enum_foc_calc_single;
		return;
	}
}

// Debug helper: prints a summary of each of the len MEX arguments in order,
// delegating the per-argument output to printSingleParams().
void printParams(int len, const mxArray * prhs[])
{
	int position = 0;
	while (position < len)
	{
		printSingleParams(position, prhs[position]);
		++position;
	}
}

void printSingleParams(int i, const mxArray * prhs)
{
	const int buffSize = 1000;
	char buffer[buffSize];
	printDimensions(buffer, buffSize, prhs);
	int total_num_of_elements = mxGetNumberOfElements(prhs);
	mexPrintf("Param[%i] := %s[%s] - number of elements = %i ", i, mxGetClassName(prhs), buffer, total_num_of_elements);
	if (mxGetClassID(prhs) == mxSTRUCT_CLASS)
	{
		mexPrintf("\n");
		
		int number_of_fields = mxGetNumberOfFields(prhs);
		
		for (int index = 0; index < total_num_of_elements; index++) {
			mexPrintf("Struct Content: \n");
			for (int field_index = 0; field_index < number_of_fields; field_index++) {
				mexPrintf("(Field) %s: ", mxGetFieldNameByNumber(prhs, field_index));
				const mxArray * field_array_ptr = mxGetFieldByNumber(prhs, index, field_index);
				if (field_array_ptr != NULL) {
					printSingleParams(field_index, field_array_ptr);
				}
			}
			mexPrintf("End struct.\n");
		}
	}
	else
	{
		if (total_num_of_elements == 400) {
			mwIndex indexItem = 1;
			mexPrintf("Second element Pr is: %f, ", ((double*)mxGetData(prhs))[indexItem]);
		}
		mexPrintf("\n");
	}
}

/// Writes the dimensions of prhs into buffer as "d1;d2;...;dn" (e.g. "3;4").
/// When the ';'-joined list (counting its trailing ';') is longer than 16
/// characters, writes "<n>-D" instead. Output is truncated to fit in size bytes.
/// @param buffer  destination; must hold at least size bytes (ignored if NULL)
/// @param size    capacity of buffer in bytes (nothing is written if <= 0)
/// @param prhs    array whose dimensions are described
void printDimensions(char * buffer, int size, const mxArray * prhs)
{
	// Longest joined list (including the trailing ';') that is printed verbatim.
	const size_t max_shape_length = 16;

	if (buffer == NULL || size <= 0)
		return;

	const size_t number_of_dimensions = mxGetNumberOfDimensions(prhs);
	const size_t * dims = mxGetDimensions(prhs);

	// At most max_shape_length characters are ever stored here, so a small
	// fixed buffer suffices regardless of how large the extents are.
	char shape_string[64];
	shape_string[0] = '\0';
	size_t length = 0;
	bool too_long = false;

	for (size_t c = 0; c < number_of_dimensions; c++) {
		char item[32];
		// Extents are size_t: print through unsigned long long rather than %i.
		const int written = snprintf(item, sizeof(item), "%llu;", (unsigned long long)dims[c]);
		if (written < 0 || length + (size_t)written > max_shape_length) {
			too_long = true;
			break;
		}
		strcat(shape_string, item);
		length += (size_t)written;
	}

	if (too_long) {
		snprintf(buffer, (size_t)size, "%u-D", (unsigned)number_of_dimensions);
		return;
	}

	if (length > 0)
		shape_string[length - 1] = '\0'; // drop the trailing ';'
	snprintf(buffer, (size_t)size, "%s", shape_string);
}