function B = vaidya(A, partition_goal)
% B = vaidya(A, partition_goal)
%
% Form the Vaidya preconditioner.
% Inputs:
%   A - the actual matrix. It is first converted to a weighted graph, so it
%       is expected to have non-positive off-diagonal entries and satisfy
%       A(i,i) >= sum (j ~= i) of -A(i,j) (assumed symmetric).
%   (optional) partition_goal - Vaidya's partition goal (used when splitting
%       the tree). Defaults to 0.8 when omitted or passed as [].
% Output:
%   B - The preconditioner matrix (not factored).


% params - fill in defaults for omitted (or empty) optional arguments
if nargin < 2 || isempty(partition_goal)
    partition_goal = 0.8;
end

% First convert the matrix to a graph plus the per-row diagonal excess.
[GA, diagonal_overweight] = convert_to_graph(A);

% Build the preconditioner graph.
GB = sparsify_graph(GA, partition_goal);

% Convert the sparsified graph back to a matrix.
B = convert_to_matrix(GB, diagonal_overweight);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%  INTERNAL FUNCTIONS
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [G, diagonal_overweight] = convert_to_graph(A)
% Converts the matrix to a graph. The matrix is assumed to have the
% following properties:
%   1) All off-diagonal values are non-positive.
%   2) For all i we have A(i,i) >= sum (j ~= i) of -A(i,j).
%   3) A is symmetric: only the strictly lower triangle is read to build
%      the edge list.
% In the graph, edge (i, j) has the (non-negative) weight -A(i,j).
% The diagonal "overweight" is returned in the output parameter
% diagonal_overweight, that is:
%   diagonal_overweight(i) = A(i,i) - [sum (j ~= i) of -A(i,j)]
%                          = sum of row i of A.
%
% Format of G:
%   G.n     - the number of vertices; vertices are 1..G.n.
%   G.edges - an Mx3 matrix where M is the number of edges. Each row
%             represents a single edge: row [i j w] (with i > j) means that
%             (i,j) is an edge with weight w. When there are no edges it is
%             0x3, so column indexing such as G.edges(:, 1) stays valid.

G.n = size(A, 1);

% Edge list: take each off-diagonal entry once, from the strictly lower
% triangle, negating it so the edge weight is positive.
[i, j, w] = find(A);
lower = i > j;
G.edges = [i(lower), j(lower), -w(lower)];
if isempty(G.edges)
    % For e.g. a 1x1 A, find returns scalars and the logical indexing above
    % can yield 0x0 pieces; normalize to a proper 0x3 edge list.
    G.edges = zeros(0, 3);
end

% A(i,i) + sum (j ~= i) of A(i,j) is exactly the row sum of A.
diagonal_overweight = full(sum(A, 2));

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function GB = sparsify_graph(GA, partition_goal)
% This function sparsifies the graph using Vaidya's method.
%
% Inputs:
%   GA             - graph struct with fields n and edges (see
%                    convert_to_graph for the format).
%   partition_goal - Vaidya's partition goal. NOTE: not used yet, because
%                    the sparsification below is still a stub.
% Output:
%   GB - the sparsified graph. The stub currently returns GA unchanged.

% Keep this! This sets the recursion level to be big enough (at least GA.n),
% presumably for the recursive implementation that will replace the stub.
% The onCleanup object restores the caller's original limit when this
% function exits, normally or through an error, so the global root
% property does not stay modified after the call.
old_recursion_limit = get(0, 'RecursionLimit');
restore_recursion_limit = onCleanup(@() set(0, 'RecursionLimit', old_recursion_limit)); %#ok<NASGU>
set(0, 'RecursionLimit', max(GA.n, old_recursion_limit));

% STUB
GB = GA;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function B = convert_to_matrix(G, diagonal_overweight)
% Convert the graph back to a matrix and add the diagonal overweight.
%
% Each edge [a b w] contributes w to B(a,a) and B(b,b), and -w to B(a,b)
% and B(b,a); sparse() sums duplicate (row, col) entries. Finally
% diagonal_overweight(i) is added to B(i,i).
%
% Inputs:
%   G                   - graph struct: G.n vertices, G.edges an Mx3 list of
%                         [a b w] rows (an empty list of any shape, e.g.
%                         0x0, is treated as no edges).
%   diagonal_overweight - length-G.n vector (row or column) of extra
%                         diagonal weight.
% Output:
%   B - G.n x G.n sparse matrix.

% Normalize the edge list to Mx3 so an empty 0x0 list is still indexable.
E = reshape(G.edges, [], 3);
a = E(:, 1);
b = E(:, 2);
ew = E(:, 3);
vertices = (1:G.n)';

u = [a; a; b; b; vertices];
v = [a; b; a; b; vertices];
w = [ew; -ew; -ew; ew; diagonal_overweight(:)];
B = sparse(u, v, w, G.n, G.n);
    
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%  UNION FIND
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%% Example of using union-find:
%
% Make the sets:
% ==============
% ufs = makesets(5);
%
% Unite sets:
% ===========
% ufs = unionsets(ufs, 1, 2);
% ufs = unionsets(ufs, 1, 4);
% ufs = unionsets(ufs, 3, 5);
%
% Finding set representatives and checking whether two objects share a set
% ========================================================================
% [ufs, rep1] = findset(ufs, 1);
% [ufs, rep4] = findset(ufs, 4);
% [ufs, rep3] = findset(ufs, 3);
% At the end:
% rep1 and rep4 will be equal.
% rep1 and rep3 will be different.


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function ufs = makesets(n)
% Create a union-find structure over the elements 1..n in which every
% element starts as its own parent (a singleton set) with rank 0.
%   ufs.parent - 1xn parent pointers
%   ufs.rank   - 1xn rank values used by link
ufs = struct('parent', 1:n, 'rank', zeros(1, n));

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function ufs = unionsets(ufs, x, y)
% Inside ufs, unite the set containing x with the set containing y.
% Returns the updated structure; structs are passed by value in MATLAB,
% so callers must reassign: ufs = unionsets(ufs, x, y).
% If x and y are already in the same set, the only change is the path
% compression done by the two findset lookups; ranks are left untouched.

[ufs, xp] = findset(ufs, x);
[ufs, yp] = findset(ufs, y);
% link() increments the rank on a tie, so linking a root to itself would
% inflate its rank and weaken union-by-rank. Only link distinct roots.
if xp ~= yp
    ufs = link(ufs, xp, yp);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [ufs, set] = findset(ufs, x)
% Return the representative of the set containing x, together with the
% updated structure in which every node on the path from x up to that
% representative points directly at it (path compression).

% Pass 1: climb parent pointers until reaching a node that is its own parent.
root = x;
while ufs.parent(root) ~= root
    root = ufs.parent(root);
end

% Pass 2: walk the same path again and attach each node straight to root.
node = x;
while node ~= root
    next_node = ufs.parent(node);
    ufs.parent(node) = root;
    node = next_node;
end

set = root;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function ufs = link(ufs, x, y)
% Internal to union-find: attach root x and root y by rank.
% The root of lower rank is hung under the other. On a tie, x goes under y
% and y's rank grows by one.

rank_x = ufs.rank(x);
rank_y = ufs.rank(y);

if rank_x > rank_y
    ufs.parent(y) = x;
elseif rank_x < rank_y
    ufs.parent(x) = y;
else
    ufs.parent(x) = y;
    ufs.rank(y) = rank_y + 1;
end