% Start from a clean workspace, with no open figures and an empty console.
clear
close all
clc

%% generate data
% Alternative data sets (uncomment exactly one line to use it instead of
% the three-blob default below):
%[X, truelabels] = simulate_data('twogaussians');
%[X, truelabels] = simulate_data('threegaussians');

%[X, truelabels] = simulate_data('twocircles');
%[X, truelabels] = simulate_data('threecircles');

%load twomoons

% Default data: three Gaussian blobs with standard deviation 0.5, centred
% at (-3,-3), (0,0) and (5,5), with n1 points each (stacked in that order).
n1 = 100000;
X = [0.5*randn(n1,2) - 3
     0.5*randn(n1,2)
     0.5*randn(n1,2) + 5];

% 1 x 3*n1 row of ground-truth labels: label j for the j-th blob.
truelabels = kron(1:3, ones(1, n1));

k = 3; % number of clusters to search for

% Left-hand panel: the full data set colored by ground truth.
figure; 
subplot(121)
gscatter(X(:,1), X(:,2), truelabels, 'rbm', 'po.', 16)
set(gca,'fontsize',16)
legend('fontsize',16)
title('given data', 'FontSize',18)

%% randomly sample points as landmarks
% Draw m distinct row indices of X uniformly at random (no replacement),
% then put them in ascending order before extracting the landmark rows.

n = size(X, 1);   % number of data points
m = 100*k;        % number of landmark points (100 per cluster)

subset = randsample(n, m);
subset = sort(subset);
landmarks = X(subset, :);

% Right-hand panel: the landmarks, colored by their ground-truth labels.
subplot(122)
gscatter(landmarks(:,1), landmarks(:,2), truelabels(subset), 'rbm', 'po.', 16)
set(gca,'fontsize',16)
legend('fontsize',16)
title('selected landmarks', 'FontSize',18)

%% Compute the similarity matrix A
% A is the n x m bipartite affinity between data points (rows) and
% landmarks (columns). Each point is linked to its r nearest landmarks
% with weight exp(-d^2 / (2*sigma^2)); all other entries are zero.

% Number of nearest landmarks per data point, kept within [1, m].
r = max(1, min(ceil(log(n)), m));

% X (queries) and landmarks (reference set) are different point sets, so
% there is no "self" neighbour to skip: keep all r neighbours. Dropping
% column 1 would throw away every non-landmark point's closest landmark
% and could leave a landmark column with zero total weight.
[knnidx, knndists] = knnsearch(landmarks, X, 'K', r);

% Bandwidth: mean distance to the middle-ranked (ceil(r/2)-th) neighbour.
% ceil must wrap the whole quotient so the column index is an integer.
sigma = mean(knndists(:, ceil(r/2)));

A_knn = exp(-knndists.^2 / (2*sigma^2)); % n x r Gaussian weights

% Assemble A directly in sparse form instead of filling a dense n x m
% matrix first. knnidx(:) is column-major (all first neighbours, then all
% second neighbours, ...), so the matching row indices are 1..n repeated
% r times.
A = sparse(repmat((1:n)', r, 1), knnidx(:), A_knn(:), n, m);

figure; 
imagesc(A); 
colorbar
title('Similarity matrix A','fontsize',18)

%% Find the top k eigenvectors of W = [O A; A^T O] from SVD of A
% Normalise A as Atilde = D1^(-1/2) * A * D2^(-1/2), where D1 and D2 hold
% the row (data point) and column (landmark) degrees. The left and right
% singular vectors of Atilde, rescaled by D^(-1/2), give the spectral
% embedding of the data points and the landmarks respectively.
% (Note: the variable W below is that embedding, not the block matrix.)

D1 = sum(A,2); % row sums
D2 = sum(A,1); % column sums

% Inverse square-root degrees. A zero degree (isolated vertex) maps to 0
% rather than Inf, so it cannot inject NaN/Inf into Atilde or W.
isd1 = 1 ./ sqrt(full(D1));
isd1(~isfinite(isd1)) = 0;
isd2 = 1 ./ sqrt(full(D2(:)));
isd2(~isfinite(isd2)) = 0;

[nrows, ncols] = size(A);
Atilde = spdiags(isd1, 0, nrows, nrows) * A * spdiags(isd2, 0, ncols, ncols);

% 2*k singular triplets are computed so the plot below shows the spectral
% gap; only columns 2..k are used, skipping the leading (trivial) pair.
[U,S,V] = svds(Atilde, 2*k);

% Rows 1..n embed the data points, the remaining m rows embed landmarks.
W = [U(:,2:k) .* isd1;  V(:,2:k) .* isd2];

figure; 
plot(diag(S), '.', 'markersize', 16)
title('top singular values of Atilde', 'fontsize',18)
set(gca,'fontsize',16)
grid on

%% perform kmeans clustering in eigenvector space
% W stacks the n data-point embeddings on top of the m landmark
% embeddings, so one k-means run labels both sets consistently.

labels_ncut = kmeans(W, k, 'replicates', 10);

% With k = 2 the embedding has a single column; plot it against zeros.
if size(W,2) >= 2
    w2 = W(:,2);
else
    w2 = zeros(size(W,1), 1);
end

figure; 
subplot(1,2,1)
plot(W(:,1), w2, '.', 'markersize', 16)
set(gca,'fontsize',16)
title('second and third smallest eigenvectors','fontsize',18)
subplot(1,2,2)
gscatter(W(:,1), w2, labels_ncut, 'rbm', 'po.', 16)
set(gca,'fontsize',16)
legend('fontsize',16)
title('clusters found by kmeans in eigenvector space','fontsize',18)

figure; 
subplot(1,2,1)
gscatter(X(:,1), X(:,2), labels_ncut(1:n), 'rbm', 'po.', 16)
set(gca,'fontsize',16)
legend('fontsize',16)
title('Data clusters found by Ncut','fontsize',18)
subplot(1,2,2)
gscatter(landmarks(:,1), landmarks(:,2), labels_ncut(n+1:end), 'rbm', 'po.', 16)
set(gca,'fontsize',16)
legend('fontsize',16)
title('Landmark clusters found by Ncut','fontsize',18)

% Score only the n data-point labels (labels_ncut also holds the m
% landmark labels). Both inputs are passed as column vectors of equal length.
% NOTE(review): clustering_error is defined elsewhere; confirm it expects
% (predicted, true) label vectors in this order.
err = clustering_error(labels_ncut(1:n), truelabels(:))
