% This wrapper script computes the standard errors (SE) reported in Panel B of Table 3,
% evaluated around a given set of point estimates.

clc
clear all

%% Parallel environment: make sure a local pool with at least npool workers is running
npool = 7;                        % number of workers wanted
p = gcp('nocreate');              % current pool, or empty when none is running
poolsize = 0;                     % an absent pool counts as zero workers
if ~isempty(p)
    poolsize = p.NumWorkers;
end
if poolsize < npool
    % Shut down the undersized pool (if any) before starting a fresh one
    delete(p)
    parpool('local',npool)
end
% Suppress the 'maybe uninitialized temporary' warning
% (presumably raised by parfor code inside the estimation routines - confirm)
warning('off','MATLAB:mir_warning_maybe_uninitialized_temporary');

%% Run switches
% savedata = 1 stores the entire output of the solution function on every run; 0 does not.
savedata = 0;
% irflag = 1 also runs the IR (used to compute elasticities) as part of the solution
% and saves the results; 0 skips it.
irflag = 0;

%% Non-estimated parameters, fixed according to the MRS estimation and parametrization
% Entry order: psi_l1, psi_l2, rho_l, psi_t1, psi_t2, rho_t, eta_cp
PSI = [0.211, 0.162, 0.535, 0.107, 0.411, -0.288, 0.903];

% Tax-function parameters (two entries each; what each entry indexes is not visible here)
chi   = [2.2 2.2];   % level parameter
taumu = [0.1 0.1];   % slope parameter
% Minimum pre-tax transfers received regardless of work, with the scaling by 4160 applied
tau_b = [12*(526+476)/4160, 12*367/4160];


%% Run once at THETA_0 to confirm the point estimates and record the baseline moment gap
THETA_0 = csvread('THETA.csv');   % point estimates around which the SEs are computed
dirname_estimates = main_estimation_se(THETA_0,PSI,chi,taumu,tau_b,savedata,irflag);
% Load into a struct instead of the workspace: a bare load would silently overwrite
% any script variable that shares a name with a variable saved by the run.
est_0 = load(fullfile(dirname_estimates,'estimation_output'));
% Take the parameter vector as stored by the run, transposed
% (NOTE(review): the transpose presumably fixes the stored orientation - confirm).
THETA_0 = est_0.THETA';
% Baseline gap between columns 2 and 1 of moments_out; the finite-difference
% derivatives below are measured relative to this gap.
diff_0 = est_0.moments_out(:,2) - est_0.moments_out(:,1);

%% Derivatives of the moment gaps w.r.t. each parameter in THETA
% For every parameter, four one-sided finite differences are taken with relative
% steps of +5%, -5%, +10% and -10% of the parameter value; they are averaged later.
rel_steps = [0.05 -0.05 0.1 -0.1];
n_mom = numel(diff_0);                 % number of moments, from the baseline run
derivatives_mom = zeros(n_mom,length(THETA_0),numel(rel_steps));

for i = 1:length(THETA_0)
    % Steps scale with the parameter. A zero parameter would give delta = 0 and a
    % 0/0 = NaN derivative, so fall back to absolute steps in that case.
    step_scale = THETA_0(i);
    if step_scale == 0
        warning('se:zeroTHETA','THETA_0(%d) is zero; using absolute finite-difference steps.',i);
        step_scale = 1;
    end

    for j = 1:numel(rel_steps)
        delta = rel_steps(j)*step_scale;
        THETA = THETA_0;               % perturb exactly one parameter per run
        THETA(i) = THETA_0(i) + delta;
        dirname = main_estimation_se(THETA,PSI,chi,taumu,tau_b,savedata,irflag);

        % Struct load so the saved variables cannot clobber THETA, i, j, delta, ...
        est = load(fullfile(dirname,'estimation_output'));
        dmom = est.moments_out(:,2) - est.moments_out(:,1);
        derivatives_mom(:,i,j) = (dmom - diff_0)/delta;
    end
end

save derivatives_raw_THETA

%% Derivatives of the moment gaps w.r.t. each first-stage parameter in PSI
% Columns follow the parameter order of the first-stage (data-psi) covariance matrix:
%   psi_L2, eta_cp, psi_L1, rho_L, psi_T2, psi_T1, rho_T
% i.e. PSI(2), PSI(7), PSI(1), PSI(3), PSI(5), PSI(4), PSI(6).
psi_cov_order = [2 7 1 3 5 4 6];
psi_rel_steps = [0.05 -0.05 0.1 -0.1];   % same +/-5%, +/-10% relative steps as for THETA

PSI_0 = PSI;
assert(numel(psi_cov_order) == numel(PSI_0), 'se:psiOrder', ...
    'psi_cov_order has %d entries but PSI has %d.', numel(psi_cov_order), numel(PSI_0));
derivatives_mom_PSI = zeros(numel(diff_0),length(PSI_0),numel(psi_rel_steps));

THETA = THETA_0;                          % THETA stays at the point estimates throughout

for i = 1:length(PSI_0)
    k = psi_cov_order(i);                 % position in PSI of the i-th covariance parameter

    % Guard against a zero parameter, where a relative step would be 0 (0/0 = NaN).
    step_scale = PSI_0(k);
    if step_scale == 0
        warning('se:zeroPSI','PSI(%d) is zero; using absolute finite-difference steps.',k);
        step_scale = 1;
    end

    for j = 1:numel(psi_rel_steps)
        delta = psi_rel_steps(j)*step_scale;
        PSI = PSI_0;                      % perturb exactly one parameter per run
        PSI(k) = PSI_0(k) + delta;
        dirname = main_estimation_se(THETA,PSI,chi,taumu,tau_b,savedata,irflag);

        % Struct load so the saved variables cannot clobber THETA, PSI, k, delta, ...
        est = load(fullfile(dirname,'estimation_output'));
        dmom = est.moments_out(:,2) - est.moments_out(:,1);
        derivatives_mom_PSI(:,i,j) = (dmom - diff_0)/delta;
    end
end
PSI = PSI_0;                              % leave the unperturbed first-stage values in place

save derivatives_raw

%% Average the four finite-difference derivatives (third dimension)
derivatives_mom_final = mean(derivatives_mom,3);
derivatives_mom_final_PSI = mean(derivatives_mom_PSI,3);

%% SEs of the SMM point estimates, incorporating the correction as in GP
SIGMA    = csvread('SIGMA.csv');
W        = diag(1./diag(SIGMA));           % diagonal weighting matrix: inverse of diag(SIGMA)
VAR_psid = csvread('SIGMA_psid.csv');
VAR_atus = csvread('SIGMA_atus.csv');
% Block-diagonal first-stage covariance (no cross-covariance between the two blocks)
VAR_first = blkdiag(VAR_psid, VAR_atus);
G     = derivatives_mom_final;             % d(moments)/d(THETA)
G_PSI = derivatives_mom_final_PSI;         % d(moments)/d(PSI), columns in covariance order
assert(size(G_PSI,2) == size(VAR_first,1), 'se:dimMismatch', ...
    'G_PSI has %d columns but the first-stage covariance is %d x %d.', ...
    size(G_PSI,2), size(VAR_first,1), size(VAR_first,2));

% NOTE(review): a single averaged first-stage sample size is applied to both blocks
% (presumably 11195 for the 4 PSID parameters and 2921 for the 3 ATUS ones) - confirm
% that block-specific scaling is not intended.
N_first_inv = mean([11195 11195 11195 11195 2921 2921 2921])^-1;
N_smm   = mean(csvread('moments_N.csv'));

% Sandwich formula VAR = (G'WG)^-1 G'W * Omega * WG (G'WG)^-1, written as B*Omega*B'
% with B = (G'WG)^-1 G'W. W and G'WG are symmetric, so B' = WG(G'WG)^-1. Solving once
% with mldivide avoids forming the explicit inverse twice.
Omega = SIGMA + N_smm*G_PSI*N_first_inv*VAR_first*G_PSI';
B     = (G'*W*G) \ (G'*W);
VAR   = B*Omega*B';
VAR   = (VAR + VAR')/2;                    % remove round-off asymmetry (diagonal unchanged)
if any(diag(VAR) < 0)
    warning('se:negVar','Negative variance on the diagonal of VAR; the corresponding SEs will be complex.');
end
SE_smm = sqrt(diag(VAR));

% Delta method for the variance of the log of the estimates: d log(theta)/d theta = 1/theta
logder = diag(1./THETA_0);
VAR_logder = logder*VAR*logder;
SE_smm_log = sqrt(diag(VAR_logder));

dlmwrite('SE_table3_panel_B.csv',SE_smm);
dlmwrite('SE_table3_panel_B_log.csv',SE_smm_log);

