% This function takes as input the simulated shocks and adds to those
% shocks an aggregate tax shift, which is either transitory or permanent. It
% then re-simulates all households with the additional shocks.
% Taking u1, u2, v1 and v2 from simul_output as inputs ensures that, when comparing
% economies for elasticity calculations, the aggregate tax change is the only thing that changes.


function IR_output = simul_all_IR(parameters,simul_output,wage_grids_outputs,output_vf,shock_u1,shock_u2,shock_v1,shock_v2,shock_t)
% SIMUL_ALL_IR  Re-simulate all households after adding shocks in period shock_t.
%
%   IR_output = simul_all_IR(parameters, simul_output, wage_grids_outputs, ...
%                            output_vf, shock_u1, shock_u2, shock_v1, shock_v2, shock_t)
%
%   Inputs
%     parameters         - struct; fields read: TT, n, KT, chi, taumu, tau_b,
%                          L_bar, r, amin. KT(tau) indexes chi/taumu/tau_b.
%     simul_output       - struct from a baseline simulation; fields read:
%                          u1_simul, u2_simul, v1_simul, v2_simul,
%                          w0_1_simul, w0_2_simul and A_simul (only column 1,
%                          the initial assets, is used).
%     wage_grids_outputs - struct with log_w_det_1/2 (entries 2:TT+1 are used)
%                          and lw_grid_1/2, u1_grid/u2_grid (column tau is the
%                          grid for period tau).
%     output_vf          - struct with A_grid and the arrays V_star_E2_{0,1},
%                          A_star_E2_{0,1}, L{1,2}_star_E2_{0,1},
%                          T{1,2}_star_E2_{0,1}, indexed as
%                          (A, F1, F2, u1, u2, tau).
%     shock_u1, shock_u2 - added to u1/u2 in column shock_t.
%     shock_v1, shock_v2 - added to v1/v2 in column shock_t; because F1/F2 are
%                          cumulative sums of v1/v2, they affect every period
%                          from shock_t onward.
%     shock_t            - column (period) in which the shocks are added.
%
%   Output
%     IR_output - struct of n-by-TT simulated paths (A_simul is n-by-(TT+1),
%                 with a zero terminal column). Y1, Y2, Y, H1, H2, L1, L2, T1, T2
%                 are multiplied by 4160; atY, C and A by 4160^(1-taumu).

%% Read parameters
TT     = parameters.TT;
n      = parameters.n;
KT     = parameters.KT;
CHI    = parameters.chi;
TAUMU  = parameters.taumu;
TAU_B  = parameters.tau_b;
L_bar  = parameters.L_bar;
r      = parameters.r;
amin   = parameters.amin;

log_w_det_1 = wage_grids_outputs.log_w_det_1(2:TT+1)';
log_w_det_2 = wage_grids_outputs.log_w_det_2(2:TT+1)';
lw_grid_1   = wage_grids_outputs.lw_grid_1;
lw_grid_2   = wage_grids_outputs.lw_grid_2;
u1_grid     = wage_grids_outputs.u1_grid;
u2_grid     = wage_grids_outputs.u2_grid;

A_grid = output_vf.A_grid;
A_max  = max(A_grid);

% Scale factor applied to hours and pre-tax income when converting back to
% non-normalized numbers; consumption-type quantities use sc_L^(1-taumu).
sc_L = 4160;

u1   = simul_output.u1_simul;
u2   = simul_output.u2_simul;
v1   = simul_output.v1_simul;
v2   = simul_output.v2_simul;
w0_1 = simul_output.w0_1_simul;
w0_2 = simul_output.w0_2_simul;

%% Add the shocks in the period specified by the user
u1(:,shock_t) = u1(:,shock_t) + shock_u1;
u2(:,shock_t) = u2(:,shock_t) + shock_u2;
v1(:,shock_t) = v1(:,shock_t) + shock_v1;
v2(:,shock_t) = v2(:,shock_t) + shock_v2;

%% Wage processes for husband and wife
F1 = permanent_component(w0_1, v1, log_w_det_1, n, TT);
F2 = permanent_component(w0_2, v2, log_w_det_2, n, TT);

% In the very few cases where simulated wages or transitory shocks go off
% the grid, assign them the maximal or minimal grid point of that period.
F1 = clamp_to_grid(F1, lw_grid_1, n);
F2 = clamp_to_grid(F2, lw_grid_2, n);
u1 = clamp_to_grid(u1, u1_grid, n);
u2 = clamp_to_grid(u2, u2_grid, n);

lw_1 = F1 + u1;
lw_2 = F2 + u2;

%% Prepare matrices
E2_simul = zeros(n,TT);
A_simul  = zeros(n,TT+1);
L1_simul = zeros(n,TT);
L2_simul = zeros(n,TT);
T1_simul = zeros(n,TT);
T2_simul = zeros(n,TT);

% First-period assets as in the baseline simulation, converted back to
% model units with that period's scale factor.
A_simul(:,1) = simul_output.A_simul(:,1)/sc_L^(1-TAUMU(KT(1)));

%% Employment choice from the interpolated VFs, then the chosen branch's policies
for tau = 1:TT
    grids = {A_grid, lw_grid_1(:,tau), lw_grid_2(:,tau), u1_grid(:,tau), u2_grid(:,tau)};
    state = {A_simul(:,tau), F1(:,tau), F2(:,tau), u1(:,tau), u2(:,tau)};
    % Evaluate a period-tau array from output_vf at every household's state.
    at_state = @(P) interpn(grids{:}, squeeze(P(:,:,:,:,:,tau)), state{:});

    E2 = (at_state(output_vf.V_star_E2_1) >= at_state(output_vf.V_star_E2_0));
    E2_simul(:,tau) = E2;

    if tau < TT
        % Next-period assets of the chosen branch, kept within [amin, max(A_grid)].
        a1 = min(max(at_state(output_vf.A_star_E2_1), amin), A_max);
        a0 = min(max(at_state(output_vf.A_star_E2_0), amin), A_max);
        A_simul(:,tau+1) = select_branch(E2, a1, a0);
    end

    L1_simul(:,tau) = select_branch(E2, at_state(output_vf.L1_star_E2_1), at_state(output_vf.L1_star_E2_0));
    L2_simul(:,tau) = select_branch(E2, at_state(output_vf.L2_star_E2_1), at_state(output_vf.L2_star_E2_0));
    T1_simul(:,tau) = select_branch(E2, at_state(output_vf.T1_star_E2_1), at_state(output_vf.T1_star_E2_0));
    T2_simul(:,tau) = select_branch(E2, at_state(output_vf.T2_star_E2_1), at_state(output_vf.T2_star_E2_0));
end

A_simul(:,TT+1) = 0;    % terminal condition: zero assets at the end of the last period

%% Income, consumption and hours from the budget and time constraints,
%% rescaled to non-normalized numbers
A_model   = A_simul;    % model-unit assets; A_simul is overwritten with scaled values
H1_simul  = zeros(n,TT);
H2_simul  = zeros(n,TT);
Y1_simul  = zeros(n,TT);
Y2_simul  = zeros(n,TT);
Y_simul   = zeros(n,TT);
atY_simul = zeros(n,TT);
C_simul   = zeros(n,TT);
kt_simul  = zeros(n,TT);

for tau = 1:TT
    kt    = KT(tau);
    chi   = CHI(kt);
    taumu = TAUMU(kt);
    tau_b = TAU_B(kt);
    sc_C  = sc_L^(1-taumu);

    H1  = L_bar - L1_simul(:,tau) - T1_simul(:,tau);
    H2  = L_bar - L2_simul(:,tau) - T2_simul(:,tau);
    Y1  = exp(lw_1(:,tau)).*H1;
    Y2  = exp(lw_2(:,tau)).*H2;
    Y   = tau_b + Y1 + Y2;
    atY = chi.*Y.^(1-taumu);
    C   = A_model(:,tau) + atY - (1/(1+r))*A_model(:,tau+1);

    Y1_simul(:,tau)  = sc_L*Y1;
    Y2_simul(:,tau)  = sc_L*Y2;
    Y_simul(:,tau)   = sc_L*Y;
    atY_simul(:,tau) = sc_C*atY;
    C_simul(:,tau)   = sc_C*C;
    A_simul(:,tau)   = sc_C*A_model(:,tau);
    H1_simul(:,tau)  = sc_L*H1;
    H2_simul(:,tau)  = sc_L*H2;
    L1_simul(:,tau)  = sc_L*L1_simul(:,tau);
    L2_simul(:,tau)  = sc_L*L2_simul(:,tau);
    T1_simul(:,tau)  = sc_L*T1_simul(:,tau);
    T2_simul(:,tau)  = sc_L*T2_simul(:,tau);
    kt_simul(:,tau)  = kt;
end

%% Assign to struct
IR_output.F1_simul   = F1;
IR_output.F2_simul   = F2;
IR_output.u1_simul   = u1;
IR_output.u2_simul   = u2;
IR_output.v1_simul   = v1;
IR_output.v2_simul   = v2;
IR_output.w0_1_simul = w0_1;
IR_output.w0_2_simul = w0_2;
IR_output.lw_1_simul = lw_1;
IR_output.lw_2_simul = lw_2;
IR_output.Y1_simul   = Y1_simul;
IR_output.Y2_simul   = Y2_simul;
IR_output.Y_simul    = Y_simul;
IR_output.atY_simul  = atY_simul;
IR_output.C_simul    = C_simul;
IR_output.A_simul    = A_simul;
IR_output.H1_simul   = H1_simul;
IR_output.H2_simul   = H2_simul;
IR_output.L1_simul   = L1_simul;
IR_output.L2_simul   = L2_simul;
IR_output.T1_simul   = T1_simul;
IR_output.T2_simul   = T2_simul;
IR_output.kt_simul   = kt_simul;
IR_output.E2_simul   = E2_simul;

end

function F = permanent_component(w0, v, log_w_det, n, TT)
% Permanent log-wage component: w0 + v(:,1) in period 1, then the running
% sum of v, plus the deterministic profile log_w_det (a 1-by-TT row).
F = zeros(n,TT);
F(:,1) = w0 + v(:,1);
for tau = 1:TT-1
    F(:,tau+1) = F(:,tau) + v(:,tau+1);
end
F = ones(n,1)*log_w_det + F;
end

function X = clamp_to_grid(X, grid, n)
% Replace entries of X above (below) the max (min) of grid column tau by
% that max (min); column tau of X is checked against column tau of grid.
hi = ones(n,1)*max(grid);
lo = ones(n,1)*min(grid);
above = X > hi;
X(above) = hi(above);
below = X < lo;
X(below) = lo(below);
end

function out = select_branch(E2, val_1, val_0)
% Take val_1 where E2 is true and val_0 elsewhere. Selecting (rather than
% computing val_1.*E2 + val_0.*(1-E2)) keeps a NaN/Inf in the branch that
% was NOT chosen from contaminating the result.
out = val_0;
out(E2) = val_1(E2);
end