function [indices, objFun, moreOutput] = hosc_ejs(X,d,K,opts)

% Higher Order Spectral Clustering (HOSC)
%
% USAGE
%   [indices, objFun] = hosc_ejs(X,d,K,opts)
%   [indices, objFun, moreOutput] = hosc_ejs(X,d,K,opts)
%
% INPUT
%   X: NxD data matrix, rows are points. X is rescaled internally by its
%     largest absolute entry, so every coordinate lies in [-1,1].
%   d: common dimension of the manifolds
%   K: number of manifolds
%   opts: a structure array of the following optional parameters:
%       .m: m is the number of points used for computing affinity measures
%          (default = d+2)
%       .kernel: 'heat'(default) or 'simple' kernel used in the affinity
%          (any value other than 'heat' selects the simple kernel)
%       .knn: number of nearest neighbors (default = 10); must be < N
%       .KNNtype: type of k nearest neighbors (case-insensitive)
%          'oneway': assigns an edge if one point is the knn of the other
%             OR vice versa
%          'mutual': assigns an edge if one point is the knn of the other
%             AND vice versa (default)
%       .eta: flatness parameter, a 2-vector containing upper and lower bounds
%          (default = [0.1 0.001]), from which the optimal eta is selected.
%          The candidates are eta(1), eta(1)/sqrt(2), eta(1)/2, ... while
%          they stay >= eta(2). A scalar eta tests that single value.
%       .nTuples: number of (m-2)-tuples used to compute the pairwise weights
%          W(i,j); default = nchoosek(knn-1, m-2), i.e., to use all tuples.
%          Larger values are capped to nchoosek(knn-1, m-2); smaller values
%          draw that many random tuples for every pair (i,j).
%       .alpha: percentage of outliers (if <1); number (if >=1)
%
% OUTPUT
%   indices: a vector of group labels from (1,2,...K) of the data points
%       (outliers have label zero; in outlier-detection mode all inliers
%       get label 1)
%   objFun: smallest total variance of the clusters in the eigenspace
%       (when without outliers) or largest degree gap between inliers and
%       outliers (for outliers detection)
%   moreOutput: (optional) structure with fields
%       .curv        N x knn x nTuples array; entry (i,nj,k) is the sum of
%                    squared singular values beyond d of the centered
%                    m-tuple {i, nj-th neighbor of i, k-th tuple}, over m
%       .optEta      eta giving the best objective
%       .optW        affinity matrix at optEta
%       .optZ, .optU, .optEigvals
%                    normalized affinity, row-normalized spectral embedding
%                    and eigenvalues at optEta (clustering mode only)
%
% (c)2011 Ery Arias-Castro, Guangliang Chen and Gilad Lerman
%   Questions about the code should be directed to glchen@math.duke.edu.

%% set optional parameters
ABSOLUTE_MINIMUM = 1e-16;

if nargin<3
    error('X,d,K must all be provided!');
end

% Rescale by the largest absolute entry; skipped for an all-zero X, which
% would otherwise turn every coordinate into NaN.
scale = max(abs(X(:)));
if scale > 0
    X = X./scale;
end

N = size(X,1);

if nargin<4
    opts = struct();
end

if ~isfield(opts, 'm')
    opts.m = d+2;
end

if ~isfield(opts,'kernel')
    opts.kernel = 'heat';
end

if ~isfield(opts,'knn')
    opts.knn = 10;
end
if opts.knn >= N
    error('opts.knn must be smaller than the number of points N.');
end

if ~isfield(opts,'KNNtype')
    opts.KNNtype = 'mutual';
end
if ~any(strcmpi(opts.KNNtype, {'oneway','mutual'}))
    error('opts.KNNtype must be ''oneway'' or ''mutual''.');
end

if ~isfield(opts, 'eta')
    opts.eta = [1e-1 1e-3]; % upper and lower bounds, resp.
elseif isscalar(opts.eta)
    % a single value: the eta search below runs exactly once
    opts.eta = [opts.eta opts.eta];
elseif opts.eta(1)<opts.eta(2)
    warning('wrong ordering!') %#ok<WNTAG>
    opts.eta = opts.eta([2 1]);
end

nTuplesMax = nchoosek(opts.knn-1, opts.m-2);

if ~isfield(opts, 'nTuples')
    opts.nTuples = nTuplesMax;
end
% Cap before preallocating curv so it carries no unused all-Inf slices.
opts.nTuples = min(opts.nTuples, nTuplesMax);
useAllTuples = (opts.nTuples == nTuplesMax);

if ~isfield(opts,'alpha')
    opts.alpha = 0;
end

%% find nearest neighbors for all points

% fast knn search
[~, neighbors] = nrsearch(X', [], opts.knn+1, [], [], struct('ReturnAsArrays',1));

% %When the data set has a moderate or small size, can use the following
% %lines instead for finding nearest neighbors.
% G = X*X';
% lengths = diag(G);
% dists = repmat(lengths,1,N) + repmat(lengths',N,1) - (G+G); % matrix of pairwise distances (squared)
% [~, neighbors] = sort(dists,2,'ascend');

% The first column of the matrix 'neighbors' stores the query points,
% so we discard it.
neighbors = neighbors(:, 2:end);

%% compute all m-affinities

curv = inf(N, opts.knn, opts.nTuples);
tuples = zeros(opts.nTuples, opts.m-2);

for i = 1:N

    for nj = 1:opts.knn

        j = neighbors(i,nj);

        % The remaining m-2 points of each m-tuple {i, j, ...} are taken
        % from the other knn-1 neighbors of i.
        pool = neighbors(i,[1:nj-1 nj+1:opts.knn]);

        if useAllTuples
            % enumerate all (m-2)-tuples in the neighborhood
            tuples = nchoosek(pool, opts.m-2);
        else
            % sample a fixed number of (m-2)-tuples instead of using all
            for h = 1:opts.nTuples
                tuples(h,:) = pool(randperm(numel(pool), opts.m-2));
            end
        end

        for k = 1:opts.nTuples
            points = X([i j tuples(k,:)],:);
            sigvals = svd(points-repmat(mean(points,1),opts.m,1), 0);
            % residual energy outside the best d-dimensional affine fit
            curv(i,nj,k) = sum(sigvals(d+1:end).^2)/opts.m;
        end

    end

end

%%
if nargout>2
    moreOutput = struct();
    moreOutput.curv = curv;
    moreOutput.optEta = [];
    moreOutput.optW = [];
    moreOutput.optZ = [];
    moreOutput.optU = [];
    moreOutput.optEigvals = [];
end

%% Compute number of outliers

if opts.alpha>0 % when outliers are present
    if opts.alpha>=1 % number of outliers
        nOutliers = round(opts.alpha);
    else % <1, percentage of outliers
        nOutliers = round(N*opts.alpha);
    end
    % keep at least one outlier and one inlier so both degree means exist
    nOutliers = min(max(nOutliers, 1), N-1);
end

%% find clusters or detect outliers for each eta
eta = opts.eta(1);
if strcmpi(opts.kernel, 'heat')
    ieta = 0.5/eta^2;
    curv_exp = zeros(size(curv));
    finiteEntries = ~isinf(curv);
    curv_exp(finiteEntries) = exp(-ieta.*curv(finiteEntries));
else
    curv_exp = (curv<eta^2);
end

%% initialization
objFun = Inf;
indices = [];
useOneway = strcmpi(opts.KNNtype, 'oneway');

while eta >= opts.eta(2)

    % Rebuild W from scratch for every eta. With 'oneway' symmetrization the
    % entries W(j,i) created by max(W,W') are not rewritten by row j's own
    % neighbor assignment, so reusing W would keep weights from the previous
    % (larger) eta.
    W = zeros(N,N);
    for i = 1:N
        W(i,neighbors(i,1:opts.knn)) = sum(curv_exp(i,:,:),3);
    end

    if useOneway
        W = max(W, W');
    else % 'mutual'
        W = min(W, W');
    end

    degrees = sum(W,2);

    if opts.alpha > 0 % when outliers are present

        [~, ind_sort] = sort(degrees,'ascend');
        inliers = ind_sort(nOutliers+1:end);

        indices1 = zeros(N,1);
        indices1(inliers) = 1;
        % (mean outlier degree - mean inlier degree) / max degree; more
        % negative means a larger gap, hence minimization below
        objFun1 = (mean(degrees(ind_sort(1:nOutliers))) - mean(degrees(inliers))) / degrees(ind_sort(end));

        if objFun > objFun1
            objFun = objFun1;
            indices = indices1;
            if nargout>2
                moreOutput.optEta = eta;
                moreOutput.optW = W;
            end
        end

    else % no outliers

        % avoid division by zero for (near-)isolated points
        degrees(degrees<ABSOLUTE_MINIMUM) = 1;

        D_invsqrt = 1./sqrt(degrees);
        Z = repmat(D_invsqrt,1,N).*W.*repmat(D_invsqrt',N,1);

        Z = real(Z);
        Z = (Z+Z')/2;

        % Keep the eigenvectors of the algebraically largest eigenvalues:
        % Z = D^-1/2 W D^-1/2 with W >= 0 has eigenvalues in [-1,1], and
        % 'LM' could pick a strongly negative one instead. Sort explicitly
        % rather than relying on the order eigs returns.
        [Ui,eigvals] = eigs(Z,2*K,'LA',struct('isreal', true, 'issym', true, 'disp', 0));
        [eigvals, order] = sort(diag(eigvals), 'descend');
        Ui = Ui(:, order);

        U = Ui(:,1:K);
        % row-normalize the spectral embedding
        % NOTE(review): an all-zero row of U becomes NaN here, and kmeans
        % then returns a NaN label for that point.
        V = U./repmat(sqrt(sum(U.^2,2)),1,K);

        % apply Kmeans 10 times and use the best result
        indices1 = zeros(N,1);
        objFun1 = Inf;
        for rep = 1:10
            [inds,~,SUMD] = kmeans(V,K,'EmptyAction','drop');
            if objFun1 > sum(SUMD)
                indices1 = inds;
                objFun1 = sum(SUMD);
            end
        end

        if objFun1 < objFun
            objFun = objFun1;
            indices = indices1;
            if nargout>2
                moreOutput.optEta = eta;
                moreOutput.optW = W;
                moreOutput.optZ = Z;
                moreOutput.optU = V;
                moreOutput.optEigvals = eigvals;
            end
        end

    end

    eta = eta/sqrt(2);

    if strcmpi(opts.kernel, 'heat')
        % exp(-c/(2*(eta/sqrt(2))^2)) = exp(-c/(2*eta^2))^2
        curv_exp(finiteEntries) = curv_exp(finiteEntries).^2;
    else
        curv_exp = (curv<eta^2);
    end

end

%%
if opts.alpha > 0
    objFun = -objFun;
end
