% OurGMM  Simplify a given mixture model f to a smaller one, g.
%
%   [T, W, B, class] = OurGMM(H, weight, data, class)
%
% Input:
%   H      : d-by-d covariance matrix. The code uses this single matrix for
%            every component of f (it is passed to det() and added to the
%            d-by-d scatter matrices).
%   weight : n-element vector, weight(i) is the weight of component i in f.
%   data   : n-by-d matrix, each row containing the center of one component
%            in f.
%   class  : n-element vector of initial cluster labels for the components
%            of f. The number of components of g is K = max(class).
% Output:
%   T      : K-by-d matrix, each row containing the center of one component
%            in g.
%   W      : 1-by-K vector, W(i) = sqrt(det(2*B_i)/det(H+B_i)) times the sum
%            over the members of cluster i of weight * exp(-v*inv(B_i+H)*v'/2),
%            with v the offset of the member from T(i,:).
%   B      : d-by-d-by-K array, each d-by-d block containing the covariance
%            of one component in g.
%   class  : the labels after the last re-assignment step (same shape as the
%            input).
%
% A cluster that has no members (or whose members have non-positive total
% weight) is skipped: its W entry is set to 0, its T and B entries keep
% whatever value they had (zeros if it was never populated), and no point is
% assigned to it.
function [T,W,B, class] = OurGMM(H,weight,data, class)

Const = 1/sqrt(det(2*H));   % term of the per-point cost that does not depend on g
bear = 0.005;               % stopping criteria (relative change)

[n,dim] = size(data);
K = max(class);

W = zeros(1,K);
T = zeros(K,dim);
B = zeros(dim,dim,K);
invH = zeros(dim,dim,K);
S1 = zeros(1,K);
S2 = zeros(1,K);
Z = zeros(1,K);             % total input weight currently held by each cluster
% Per-point assignment costs. This must be its own variable: the 1-by-dim
% difference vectors used below live in 'delta', so they can never leave
% stale entries in the cost vector when dim > K.
cost = zeros(1,K);

% Controlling maximum number of iterations in different levels
% of cycling (for computational consideration).
C1 = 15;    % highest level of iterations
C2 = 3;     % alternations between updating center and covariance
            % (can be set as a very small integer without influencing too
            % much the result)
C3 = 10;    % iterations for determining the center t_i
C4 = 10;    % iterations for determining the covariance tilde_H_i

GMM_err0 = 1e50;
for cycle = 1:C1

    % ---- Step 1: refit every component of g to its current members ----
    for i = 1:K
        dex = find(class == i);
        pts = data(dex,:);
        num = length(dex);
        wgt = weight(dex);
        sumwgt = sum(wgt);
        Z(i) = sumwgt;

        % Empty / zero-weight cluster: every update below would divide by
        % zero and feed NaN into inv/pinv, so skip it.
        if ~(sumwgt > 0)
            W(i) = 0;
            continue;
        end

        % Initial center: weighted mean of the members.
        t = zeros(1,dim);
        for j = 1:num
            t = t + pts(j,:) * wgt(j);
        end
        t = t / sumwgt;

        % Initial covariance: H plus the weighted scatter around t.
        covset = zeros(dim,dim);
        for j = 1:num
            delta = t - pts(j,:);
            covset = covset + wgt(j) * delta' * delta;
        end
        Hi = H + covset / sumwgt;

        for cycle1 = 1:C2
            invHiH = inv(Hi + H);

            % Center update: repeated weighted average of the members, with
            % weights wgt(j) * exp(-delta*inv(Hi+H)*delta'/2).
            Ms_err0 = 1e50;
            for mscycle = 1:C3
                ct = zeros(1,dim);
                c1 = 0;
                for j = 1:num
                    delta = t - pts(j,:);
                    c = exp(-delta*invHiH*delta'/2) * wgt(j);
                    ct = ct + pts(j,:) * c;
                    c1 = c1 + c;
                end
                t1 = ct / c1;
                Ms_err1 = norm(t1 - t);
                t = t1;
                if abs(Ms_err0 - Ms_err1) <= bear * Ms_err0
                    break;
                else
                    Ms_err0 = Ms_err1;
                end
            end

            % Scatter around the updated center, using the same weights.
            covset = zeros(dim,dim);
            c1 = 0;
            for j = 1:num
                delta = t - pts(j,:);
                c = exp(-delta*invHiH*delta'/2) * wgt(j);
                covset = covset + c * delta' * delta;
                c1 = c1 + c;
            end
            covset = 2 * covset / c1;

            % Covariance update: iterate Hi <- H + Hi*pinv(Hi+H)*covset.
            Hi_err0 = 1e50;
            for Hicycle = 1:C4
                Hi1 = H + Hi * pinv(Hi + H) * covset;
                Hi_err1 = norm(Hi(:) - Hi1(:));
                Hi = Hi1;
                if abs(Hi_err1 - Hi_err0) < bear * Hi_err0
                    break;
                else
                    Hi_err0 = Hi_err1;
                end
            end

            % Both inner loops stopped within two passes: no need for
            % another center/covariance alternation.
            if mscycle <= 2 && Hicycle <= 2
                break;
            end
        end

        T(i,:) = t;
        B(:,:,i) = Hi;

        invHiH = inv(Hi + H);
        c = 0;
        for j = 1:num
            delta = t - pts(j,:);
            c = c + exp(-delta * invHiH * delta'/2) * wgt(j);
        end
        W(i) = sqrt(det(2*Hi) / det(H + Hi)) * c;
    end

    % ---- Step 2: per-component quantities used by the assignment cost ----
    for i = 1:K
        if Z(i) > 0
            invH(:,:,i) = inv(H + B(:,:,i));
            S1(i) = det(2*B(:,:,i))^(0.5);
            S2(i) = det(B(:,:,i) + H)^(0.5);
        end
    end

    % ---- Step 3: re-assign every input component to its cheapest cluster ----
    GMM_err1 = 0;
    for i = 1:n
        for j = 1:K
            if Z(j) > 0
                f = data(i,:) - T(j,:);
                cost(j) = Const + (W(j)/Z(j))^2/S1(j) - 2 * W(j)/Z(j)/S2(j)*exp(-f*invH(:,:,j)*f'/2);
            else
                cost(j) = Inf;  % skipped cluster: never the minimum
            end
        end
        [xx,class(i)] = min(cost);
        GMM_err1 = GMM_err1 + xx * weight(i);
    end

    % Stop once the weighted total cost changes by at most 'bear' relative
    % to the previous cycle.
    if abs(GMM_err1 - GMM_err0) <= bear * GMM_err0
        break;
    else
        GMM_err0 = GMM_err1;
    end
end