EM알고리즘은 잠재변수(latent variable)를 갖는 확률 모델의 MLE 값을 찾는 과정에서 활용되는 기법이다. 잠재변수를 활용하는 가우시안 혼합 모델에 관한 추정에 자주 활용되는 기법이기도 하다.

우선 관측변수 X={x1,...xN} 이라 하자. 잠재변수 Z={z1,z2,...zn} 는 이 논의를 진행하는 과정에서 이산형 변수라고 가정하자. 만약 Z가 연속형이라면, 아래의 과정에서 합표기가 되어있는 것을 적분으로 바꾸면 된다. 그리고 모델에 활용되는 Parameter들을 Θ 라고 표현하자.

우리의 목표는 P(XΘ) 혹은 l(ΘX) 를 최대화시키는 Θ 값을 구하는 것이다. 이 P(XΘ)에 대한 log-likelihood는 다음과 같이 표현할 수 있다.

lnP(XΘ)=ln{ZP(X,ZΘ)}

X 의 관측값에 따른 잠재변수 Z를 알게 되었다고 한다면, {X,Z} 를 완전한(complete) 데이터 집합이라 표현할 수 있다. Z는 기존에 알지 못했던 것이기 때문에 missing value, X는 기존에 알고 있었기 때문에 observed value라고 할 수 있다.

EM알고리즘은 observed variable X를 이용한 log-likelihood의 최대화보다 complete data {X,Z} 를 활용한 log-likelihood의 최대화가 보다 쉽다는 가정에서 시작한다.

실제 상황에서는 완전한 데이터 집합 {X,Z}이 주어지지 않고 불완전한 X만 주어질 가능성이 아주 높다. 잠재변수 Z에 대해서는 Z에 대한 posterior distribution인 P(ZX,Θ)를 통해서만 확인할 수 있다. 그래서 우리는 잠재 변수의 Posterior distribution을 통해 기대값을 구하고 이를 활용하여 완전한 데이터셋 {X,Z}에 대한 log-likelihood를 구할 수 있다.

여기서 잠재변수의 Posterior distribution을 활용해 기대값을 구하는 것이 EM 알고리즘의 E(Expectation) 단계이다. M(Maximization) 단계는 E단계에서 구한 기대값을 최대화시키는 Θ 값을 추정하는 단계이다. 각 parameter에 대한 현재값을 Θ(t)라 하고 E,M 단계를 거쳐서 수정된 parameter값을 Θ(t+1) 라고 표기한다.

E단계에서는 현재의 parameter Θ(t)를 이용하여, P(ZX,Θ(t)) 형태의 Z 의 Posterior distribution을 먼저 구한다. 그리고 이 Posterior distribution을 이용하여 다음의 값을 구한다.

Q(ΘΘ(t))=EZ[l(ΘX,Z)X,Θ(t)]=ZP(ZX,Θ(t))lnP(X,ZΘ)

M단계에서는 E단계에서 구한 Q 함수를 최대화시키는 parameter 값을 찾는다.

Θ(t+1)=argmaxΘQ(ΘΘ(t))

전반적인 과정을 살펴본다면 EM 알고리즘은 다음의 4가지 단계로 구성된다.

  1. P(X,ZΘ) 의 log-likelihood를 구한다.
  2. 1단계에서 구한 log-likelihood를 이용하여 Q함수를 찾는다. (E-step)
  3. 2단계에서 구한 Q함수를 최대화시키는 parameter 값을 구한다. (M-step)
  4. parameter 값이 수렴할 때까지 해당 과정을 반복하고 parameter의 변동이 없다고 생각될 때, 반복 시행을 중단한다.

EM알고리즘의 Ascent Property

EM알고리즘은 매 반복마다 l(ΘX)가 증가한다는 성질을 갖는다. 이 말을 좀 더 수식으로 이야기하자면, l(Θ(t+1)X)l(ΘtX) 라는 것이다. 이는 매 반복시행마다 log-likelihood를 증가시키는 Θ(t+1)를 찾는다는 것으로 이 알고리즘이 MLE를 찾는다는 것과 같은 의미이다.

완전한 데이터 {X,Z} 의 joint distribution은 다음과 같이 표현된다.

P(X,ZΘ)=P(XΘ)P(ZX,Θ)lnP(X,ZΘ)=lnP(XΘ)+lnP(ZX,Θ)

아래의 식 양변에 기대값을 취해보자.

EZ[lnP(XΘ)X,Θ(t)]=EZ[lnP(X,ZΘ)X,Θ(t)]EZ[lnP(ZX,Θ)X,Θ(t)]lnP(XΘ)=Q(ΘΘ(t))EZ[lnP(ZX,Θ)X,Θ(t)]

이제 이 식을 활용하여 l(Θ(t+1)X)l(ΘtX)0 임을 보일 것이다. 수식 표기의 간결성을 위해 EZ[lnP(ZX,Θ)X,Θ(t)]H(ΘΘ(t)) 로 표기한다.

l(Θ(t+1)X)l(ΘtX)=Q(Θ(t+1)Θ(t))Q(Θ(t)Θ(t))H(Θ(t+1)Θ(t))+H(Θ(t+1)Θ(t))

이며 Θ(t+1)=argmaxΘQ(ΘΘ(t) 임을 고려하면 Q(Θ(t+1)Θ(t))Q(Θ(t)Θ(t))0 인것은 자명하다. 그렇다면, 우리는 H(Θ(t+1)Θ(t))H(Θ(t+1)Θ(t))0 임을 보이기만 하면 된다.

여기서 다시 H 함수를 원래의 EZ[lnP(ZX,Θ)X,Θ(t)] 로 바꾼다.

H(Θ(t+1)Θ(t))H(Θ(t+1)Θ(t))=EZ[lnP(ZX,Θ(t+1))X,Θ(t)]EZ[lnP(ZX,Θ(t))X,Θ(t)]=EZ[lnP(ZX,Θ(t+1))lnP(ZX,Θ(t))X,Θ(t)]ln[EZ[P(ZX,Θ(t+1))P(ZX,Θ(t))]X,Θ(t)]=0

로그함수와 같이 concave한 함수에서는 E[g(X)]g[E(X)] 라는 Jensen’s Inequality 공식을 활용하였다.

예제

혼합 가우시안 분포(Gaussian Mixture)에 대한 EM 알고리즘 적용 예제를 살펴보도록 하자. 이 예제에서는 2개의 가우시안 분포가 혼합된 형태를 가정한다. 또한 각각의 데이터가 가우시안 분포로부터 서로 독립적으로 n개 추출되는 것으로 가정하자.

Xi(1π)N(xiμ1,σ21)+πN(xiμ2,σ22)

먼저 이 분포에 대한 likelihood와 log-likelihood를 구한다. 편의상 μσ2 로 parameter를 묶어서 표현하겠다.

L(π,μ,σ2X)=ni=1[(1π)N(xiμ1,σ21)+πN(xiμ2,σ22)]l(π,μ,σ2X)=ni=1ln[(1π)N(xiμ1,σ21)+πN(xiμ2,σ22)]

데이터 추출과정은 다음과 같다고 할 수 있다. 아래와 같은 확률로 각각의 데이터 포인트가 어느 가우시안 분포에서 추출될 것인지가 결정되고

zi={1with probabilityπ0with probability1π

이제 Z 가 주어진 상태이기 때문에 데이터 포인트의 추출은 조건부 분포를 따를 것이다.

XiZi={N(xiμ1,σ21)ifzi=0N(xiμ2,σ22)ifzi=1

더불어 P(X,Z midΘ)=P(XZ,Θ)P(ZΘ) 임을 고려한다면, 다음과 같은 likelihood 식을 구할 수 있다.

ni=1{πN(xiμ1,σ21)}zi{(1π)N(xiμ2,σ22)}1zi

그래서 최종적인 log-likelihood는 다음과 같이 표현할 수 있다.

l(π,μ,σ2X,Z)=ni=1zi{lnπ+lnN(xiμ1,σ21)}+ni=1(1zi){ln(1π)+lnN(xiμ2,σ22)}

이제 E-step으로 들어가자. E-step에서는 잠재변수에 대한 조건부 기대값과 Q 함수를 도출해야 한다.

잠재변수에 대한 조건부 기대값은 다음의 과정을 통해 구할 수 있다.

EZ[zixi,Θ(t)]=1×P(zi=1xi,Θ(t))+0×P(zi=0xi,Θ(t))=P(xizi=1,Θ(t))P(zi=1Θ(t))P(xiΘ(t))=P(xizi=1,Θ(t))P(zi=1Θ(t))P(xizi=1,Θ(t))P(zi=1Θ(t))+P(xizi=0,Θ(t))P(zi=0Θ(t))=π(t)N(xiμ(t)1,(σ21)2)π(t)N(xiμ(t)1,(σ21)2)+(1π(t))N(xiμ(t)2,(σ22)2)

zixi,Θ(t)는 다음의 베르누이 분포를 따른다고 할 수 있으며 이 확률이 zixi,Θ(t)의 기대값 ^zi라 할 수 있다.

zixi,Θ(t)Ber(π(t)N(xiμ(t)1,(σ21)2)π(t)N(xiμ(t)1,(σ21)2)+(1π(t))N(xiμ(t)2,(σ22)2))

Q함수 Q(ΘΘ(t)EZ[l(ΘX,Z)X,Θ(t)] 이며 다음과 같다.

Q(ΘΘ(t))=ni=1^zi(lnπ+lnN(xiμ1,σ21))+ni=1(1^zi)(ln(1π)+lnN(xiμ2,σ22))

마지막으로 Q함수를 최대화시키는 5가지 모수에 대한 값을 찾는 것이다. 각 paramter별로 미분하여 최대값이 나오는 값을 찾아 업데이트 한다.

Q(ΘΘ(t))π=ni=1^zi1π+ni=1(1^zi)(11π)=0π(t+1)=ni=1^zin

나머지 parameter에 대해서는 업데이트 되는 값이 다음과 같다.

μ(t+1)1=ni=1^zixini=1^ziμ(t+1)2=ni=1(1^zi)xini=1(1^zi)(σ21)(t+1)=ni=1^zi(xiμ1)2ni=1^zi(σ22)(t+1)=ni=1(1^zi)(xiμ2)2ni=1(1^zi)

이 과정을 각각의 parameter에 대해 수렴하는 시점까지(t+1시점 값과 t시점 값의 차이가 일정 수준 이하가 될 때까지) 시행하여 최종적인 값을 구해낼 수 있다.

이 포스팅의 예제 코드는 다음의 주소에서 확인할 수 있습니다.