Variational inference
Probabilistic Graphical Models, Sharif University of Technology, Soleymani, Spring 2018
Some slides are adapted from Xing’s slides
Exact methods for inference
Variable elimination
Message Passing: shared terms
Sum-product (belief propagation), Max-product, Junction Tree
General algorithm on graphs with cycles: message passing on junction trees
Messages m_{i→j}(S_{ij}) and m_{j→i}(S_{ij}) are passed between neighboring cliques C_i and C_j over their separator S_{ij}
The computational complexity of the junction tree algorithm is exponential in the size of the largest elimination clique (the largest clique in the triangulated graph).
For a distribution p associated with a complex graph, computing the marginal (or conditional) probability of arbitrary random variable(s) is intractable.
The tree-width of an n × n grid is n, so even for grids exact inference is exponential in n.
Learning usually needs inference:
For the Bayesian approach, inference is one of its principal foundations, since learning itself amounts to computing a posterior.
For the Maximum Likelihood approach we also need inference, e.g., to compute expectations over the latent variables in EM.
Approximate inference techniques:
Variational algorithms: loopy belief propagation, mean field approximation, expectation propagation
Stochastic simulation / sampling methods
“Variational”: many problems can be expressed in terms of an optimization problem.
Variational inference is a deterministic framework that casts approximate inference as optimization over a family of distributions.
Constructing an approximation to the target distribution p:
We define a target class 𝒬 of distributions, search for an instance q* in 𝒬 that is the best approximation to p, and answer queries using q* rather than p.
𝒬: a given family of distributions.
Simpler families make solving the optimization problem tractable; however, such a family may not be sufficiently expressive to capture the target distribution accurately.
This is a constrained optimization problem.
Assume that we are interested in the posterior distribution
p(z | x, α) = p(z, x | α) / ∫ p(z, x | α) dz
where x = {x_1, …, x_N} are observed variables, z = {z_1, …, z_M} are hidden variables, and α denotes fixed parameters.
The problem of computing the posterior is an instance of a more general problem that variational inference solves.
Main idea:
We pick a family of distributions over the latent variables with its own variational parameters.
Then we find the setting of the parameters that makes q close to the posterior.
We use q with the fitted parameters as an approximation for the posterior.
Goal: approximate a difficult distribution p(z|x) with a new distribution q(z) such that:
p(z|x) and q(z) are close
Computation on q(z) is easy
Typically, the true posterior is not in the variational family. How should we measure the distance between distributions?
The Kullback-Leibler divergence (KL divergence) between two distributions is the standard choice.
Kullback-Leibler divergence between p and q:
KL(p ‖ q) = ∫ p(z) ln [ p(z) / q(z) ] dz
A result from information theory: for any p and q, KL(p ‖ q) ≥ 0, and
KL(p ‖ q) = 0 if and only if p ≡ q
KL is asymmetric: in general KL(p ‖ q) ≠ KL(q ‖ p)
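A quick numeric illustration of this asymmetry (a minimal sketch, not from the slides; Python with NumPy, distributions chosen arbitrarily):

import numpy as np

def kl(p, q):
    # KL(p || q) for two discrete distributions on the same support
    p, q = np.asarray(p, float), np.asarray(q, float)
    return float(np.sum(p * np.log(p / q)))

p = np.array([0.7, 0.2, 0.1])   # target distribution p
q = np.array([0.4, 0.4, 0.2])   # approximating distribution q
print(kl(p, q), kl(q, p))       # the two values differ: KL is not symmetric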
We wish to find a distribution q such that q is a “good” approximation to the posterior p(z|x).
We can therefore use the KL divergence as a scoring function.
But KL(p(z|x) ‖ q(z)) ≠ KL(q(z) ‖ p(z|x)), so we must choose a direction.
M-projection of q onto p:  q^M = argmin_{q∈𝒬} KL(p ‖ q)
I-projection of q onto p:  q^I = argmin_{q∈𝒬} KL(q ‖ p)
These two differ only when KL is minimized over a restricted family 𝒬 that does not contain p; otherwise both recover q = p.
Example: let p be a correlated 2D Gaussian and q a Gaussian with diagonal covariance.
M-projection: q* = argmin_q ∫ p(z) ln [ p(z)/q(z) ] dz
I-projection: q* = argmin_q ∫ q(z) ln [ q(z)/p(z) ] dz
In both cases the mean is matched, E_p[z] = E_q[z]; the M-projection is broader (covers the support of p) while the I-projection is more compact.
[Figure from Bishop: p in green, q* in red]
Example: let p be a mixture of two 2D Gaussians and q a single 2D Gaussian.
M-projection: q* = argmin_q ∫ p(z) ln [ p(z)/q(z) ] dz — moment matching: E_p[z] = E_q[z] and Cov_p[z] = Cov_q[z], so q* spans both modes.
I-projection: q* = argmin_q ∫ q(z) ln [ q(z)/p(z) ] dz — q* locks onto one of the modes: there are two good solutions!
[Figure from Bishop: p in blue, q* in red]
Computing KL(p ‖ q) requires inference on p:
KL(p ‖ q) = Σ_z p(z) ln [ p(z) / q(z) ]
When q is in the exponential family, minimizing KL(p ‖ q) amounts to matching the expected sufficient statistics (moment projection).
So inference on p (which is exactly what is difficult) is required!
KL(q ‖ p):
Most variational inference algorithms make use of KL(q ‖ p), since computing expectations w.r.t. q is tractable (by choosing a suitable family).
We choose a restricted family of distributions such that the expectations can be evaluated and optimized efficiently, and yet which is still sufficiently flexible to give a good approximation.
[Figure from Bishop: comparison of a variational (global) approximation and the Laplace (local) approximation to a distribution]
We can maximize the lower bound (ELBO)
ℒ(q) = E_q[ ln p(x, z) ] − E_q[ ln q(z) ]
which is equivalent to minimizing the KL divergence, since ln p(x) = ℒ(q) + KL(q ‖ p(z|x)).
If we allow any possible choice for q(z), the maximum of the lower bound occurs when the KL divergence vanishes, which occurs when q(z) equals the posterior distribution p(z|x).
The gap between ln p(x), which the ELBO bounds from below, and ℒ(q) is exactly the KL divergence.
We will also refer to ℒ(q) as the energy functional F[p, q] later.
x = {x_1, …, x_N}: observed variables;  z = {z_1, …, z_M}: hidden variables
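A small numeric check of the decomposition ln p(x) = ℒ(q) + KL(q ‖ p(z|x)) (illustrative sketch, not from the slides; the numbers are arbitrary):

import numpy as np

p_z = np.array([0.5, 0.3, 0.2])          # prior p(z) over a discrete latent variable
p_x_given_z = np.array([0.1, 0.6, 0.3])  # likelihood p(x|z) for one fixed observation x
joint = p_z * p_x_given_z                # p(x, z)
log_px = np.log(joint.sum())             # ln p(x)
posterior = joint / joint.sum()          # p(z|x)

q = np.array([0.2, 0.5, 0.3])            # an arbitrary variational distribution q(z)
elbo = np.sum(q * (np.log(joint) - np.log(q)))
kl = np.sum(q * (np.log(q) - np.log(posterior)))
print(np.isclose(elbo + kl, log_px))     # True: the gap between ln p(x) and the ELBO is the KL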
ℒ(q) is a lower bound on the log marginal likelihood; this quantity should increase monotonically with each iteration.
We maximize the ELBO to find the variational parameters that give as tight a bound as possible; the ELBO converges to a local optimum.
Variational inference is closely related to EM.
Mean field: restrict q to fully factorized distributions
q(z) = ∏_j q_j(z_j)
ℒ(q) = ∫ ∏_j q_j(z_j) [ ln p(x, z) − Σ_j ln q_j(z_j) ] dz
Coordinate ascent to optimize ℒ(q): we first find ℒ_k(q), the part of ℒ that is a functional of q_k:
ℒ_k(q) = ∫ q_k(z_k) [ ∫ ln p(x, z) ∏_{j≠k} q_j(z_j) dz_j ] dz_k − ∫ q_k(z_k) ln q_k(z_k) dz_k + const
⇒ ℒ_k(q) = ∫ q_k(z_k) E_{−k}[ ln p(x, z) ] dz_k − ∫ q_k(z_k) ln q_k(z_k) dz_k + const
The only restriction on the distributions is the factorization assumption, and
E_{−k}[ ln p(x, z) ] = ∫ ln p(x, z) ∏_{j≠k} q_j(z_j) dz_j
Maximize ℒ_k(q) subject to normalization, using a Lagrange multiplier:
L(q_k, λ) = ℒ_k(q) + λ ( Σ_{z_k} q_k(z_k) − 1 )
dL/dq_k(z_k) = E_{−k}[ ln p(x, z) ] − ln q_k(z_k) − 1 + λ = 0
⇒ q*(z_k) ∝ exp{ E_{−k}[ ln p(x, z) ] }
Equivalently, q*(z_k) ∝ exp{ E_{−k}[ ln p(z_k | z_{−k}, x) ] }
The above formula determines the form of the optimal q(z_k); we didn't specify the form in advance, and only the factorization has been assumed.
Depending on that form, the optimal q(z_k) might not be easy to work with. Nonetheless, for many models it is easy.
Since we are replacing the neighboring variables by their mean (expected) values, the method is known as mean field.
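As a concrete instance of this update (following the example in Bishop, Section 10.1.2, not shown on these slides), let p(z_1, z_2) = 𝒩(z | μ, Λ^{-1}) be a bivariate Gaussian and q(z) = q_1(z_1) q_2(z_2):

\[
\ln q_1^*(z_1) = \mathbb{E}_{q_2}[\ln p(z_1, z_2)] + \text{const}
= -\tfrac{1}{2}\Lambda_{11} z_1^2 + z_1\big(\Lambda_{11}\mu_1 - \Lambda_{12}(\mathbb{E}[z_2]-\mu_2)\big) + \text{const},
\]

so \(q_1^*(z_1) = \mathcal{N}\big(z_1 \mid m_1, \Lambda_{11}^{-1}\big)\) with \(m_1 = \mu_1 - \Lambda_{11}^{-1}\Lambda_{12}(\mathbb{E}[z_2]-\mu_2)\), and symmetrically for \(q_2^*(z_2)\); iterating the two updates is exactly the coordinate ascent described above.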
Example: mixture of K unit-variance Gaussians. For simplicity, assume that the data-generating covariance is the identity.
p(μ) = ∏_{k=1}^K 𝒩(μ_k | m_0, Λ_0^{-1})
p(z_k^(n) = 1 | π) = π_k
p(x^(n) | z_k^(n) = 1, μ) = 𝒩(x^(n) | μ_k, I)
Generative process:
For k = 1, …, K: draw μ_k ~ 𝒩(m_0, Λ_0^{-1})
For n = 1, …, N: draw z^(n) ~ Mult(π), then draw x^(n) ~ ∏_{k=1}^K 𝒩(μ_k, I)^{z_k^(n)}
[Plate diagram: π → z^(n) → x^(n) ← μ_k, with plates over n = 1…N and k = 1…K]
z = {z^(1), …, z^(N), μ_1, …, μ_K},  x = {x^(1), …, x^(N)}
p(z^(1), …, z^(N), μ_1, …, μ_K | x^(1), …, x^(N)) = ∏_{k=1}^K p(μ_k) ∏_{n=1}^N p(z^(n)) p(x^(n) | z^(n), μ_1, …, μ_K) / ∫_{μ_1,…,μ_K} Σ_{z^(1),…,z^(N)} ∏_{k=1}^K p(μ_k) ∏_{n=1}^N p(z^(n)) p(x^(n) | z^(n), μ_1, …, μ_K) dμ
The denominator (the marginal likelihood of the data) is difficult to compute.
Consider the factorized variational distribution q(z^(1), …, z^(N), μ_1, …, μ_K) = q(z^(1), …, z^(N)) q(μ_1, …, μ_K).
This factorization is the only assumption we are required to make in order to obtain a tractable solution; the functional forms of the factors follow automatically.
ln q(z^(1), …, z^(N)) = E_{μ_1,…,μ_K}[ ln p(z^(1), …, z^(N), μ_1, …, μ_K, x^(1), …, x^(N)) ] + const
 = E_{μ_1,…,μ_K}[ ln ( ∏_{k=1}^K p(μ_k) ∏_{n=1}^N p(z^(n)) p(x^(n) | z^(n), μ_1, …, μ_K) ) ] + const
 = E_{μ_1,…,μ_K}[ Σ_{k=1}^K ln p(μ_k) + Σ_{n=1}^N ln p(z^(n)) + Σ_{n=1}^N ln p(x^(n) | z^(n), μ_1, …, μ_K) ] + const
where
ln p(x^(n) | z^(n), μ_1, …, μ_K) = Σ_{k=1}^K z_k^(n) ln 𝒩(x^(n) | μ_k, I) = −(d/2) ln 2π − (1/2) Σ_{k=1}^K z_k^(n) (x^(n) − μ_k)^T (x^(n) − μ_k)
ln p(z^(n)) = Σ_{k=1}^K z_k^(n) ln π_k
ln q(z^(1), …, z^(N)) = Σ_{n=1}^N ln q(z^(n))  ⇒  q(z^(1), …, z^(N)) = ∏_{n=1}^N q(z^(n))
ln q(z^(n)) = E_{μ_1,…,μ_K}[ Σ_{k=1}^K z_k^(n) ln π_k − (1/2) Σ_{k=1}^K z_k^(n) (x^(n) − μ_k)^T (x^(n) − μ_k) ] + const
ln q(z^(n)) = E_{μ_1,…,μ_K}[ Σ_{k=1}^K z_k^(n) ln π_k − (1/2) Σ_{k=1}^K z_k^(n) (x^(n) − μ_k)^T (x^(n) − μ_k) ] + const
 = Σ_{k=1}^K z_k^(n) [ ln π_k + x^(n)T E[μ_k] − (1/2) E[μ_k^T μ_k] − (1/2) x^(n)T x^(n) ] + const
⇒ q(z^(n)) = Mult(r_{n1}, …, r_{nK}),  E[z_k^(n)] = r_{nk}
r_{nk} = exp{ ln π_k + x^(n)T E[μ_k] − (1/2) E[μ_k^T μ_k] − (1/2) x^(n)T x^(n) } / Σ_{k'=1}^K exp{ ln π_{k'} + x^(n)T E[μ_{k'}] − (1/2) E[μ_{k'}^T μ_{k'}] − (1/2) x^(n)T x^(n) }
Similarly, for the other factor:
ln q(μ_1, …, μ_K) = E_{z^(1),…,z^(N)}[ Σ_{k=1}^K ln p(μ_k) + Σ_{n=1}^N ln p(x^(n) | z^(n), μ_1, …, μ_K) ] + const
with ln p(x^(n) | z^(n), μ_1, …, μ_K) = Σ_{k=1}^K z_k^(n) ln 𝒩(x^(n) | μ_k, I)
= Σ_{k=1}^K ln p(μ_k) + Σ_{n=1}^N Σ_{k=1}^K E[z_k^(n)] ln 𝒩(x^(n) | μ_k, I) + const
⇒ q(μ_1, …, μ_K) = ∏_{k=1}^K q(μ_k),  q(μ_k) ∝ exp{ ln p(μ_k) + Σ_{n=1}^N E[z_k^(n)] ln 𝒩(x^(n) | μ_k, I) }
⇒ q(μ_k) = 𝒩(μ_k | m_k, Λ_k^{-1})
Λ_k = Λ_0 + ( Σ_{n=1}^N E[z_k^(n)] ) I
m_k = Λ_k^{-1} ( Λ_0 m_0 + Σ_{n=1}^N E[z_k^(n)] x^(n) )
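A compact sketch of these two coordinate-ascent updates (illustrative only; it assumes, beyond the slides, an isotropic prior precision Λ_0 = λ_0 I, known mixing weights π, and a fixed number of sweeps):

import numpy as np

def mean_field_unit_variance_gmm(X, pi, m0, lambda0=1.0, n_iters=50):
    # X: (N, d) data; pi: (K,) mixing weights; m0: (d,) prior mean.
    N, d = X.shape
    K = len(pi)
    M = X[np.random.choice(N, K, replace=False)]          # E[mu_k], initialized from random data points
    E_mumu = np.einsum('kd,kd->k', M, M) + d / lambda0     # E[mu_k^T mu_k] = ||m_k||^2 + tr(Lambda_k^{-1})
    for _ in range(n_iters):
        # q(z^(n)): r_nk ∝ exp(ln pi_k + x^(n)T E[mu_k] - 0.5 E[mu_k^T mu_k]); the -0.5 x^T x term cancels after normalization
        log_r = np.log(pi) + X @ M.T - 0.5 * E_mumu
        log_r -= log_r.max(axis=1, keepdims=True)
        R = np.exp(log_r)
        R /= R.sum(axis=1, keepdims=True)
        # q(mu_k) = N(m_k, Lambda_k^{-1}) with Lambda_k = (lambda0 + sum_n r_nk) I
        Nk = R.sum(axis=0)
        lam = lambda0 + Nk
        M = (lambda0 * m0 + R.T @ X) / lam[:, None]        # m_k = Lambda_k^{-1} (Lambda_0 m_0 + sum_n r_nk x^(n))
        E_mumu = np.einsum('kd,kd->k', M, M) + d / lam
    return R, M, lam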
In this example, the variational posterior distributions have the same functional form as the corresponding factors in the joint distribution.
This is a general result and is a consequence of the choice of conjugate distributions: the form of the posteriors is determined by the form of the likelihood and prior.
There are general results for the whole class of conjugate-exponential models.
The additional factorizations of the variational posterior distributions are a consequence of the interaction between the assumed factorization and the conditional independencies in p.
Suppose each complete conditional is in the exponential family:
p(z_k | z_{−k}, x) = h(z_k) exp{ η(z_{−k}, x)^T t(z_k) − A(η(z_{−k}, x)) }
ln p(z_k | z_{−k}, x) = ln h(z_k) + η(z_{−k}, x)^T t(z_k) − A(η(z_{−k}, x))
Mean field variational inference is then straightforward:
ln q(z_k) = E_{q_{−k}}[ ln p(z_k | z_{−k}, x) ] + const = ln h(z_k) + E_{q_{−k}}[ η(z_{−k}, x) ]^T t(z_k) − E_{q_{−k}}[ A(η(z_{−k}, x)) ]
⇒ q(z_k) ∝ h(z_k) exp{ E_{q_{−k}}[ η(z_{−k}, x) ]^T t(z_k) }
so q(z_k) is in the same exponential family as the complete conditional.
Give each hidden variable z_k a variational parameter ν_k, and restrict q(z_k | ν_k) to the same exponential family as the complete conditional p(z_k | z_{−k}, x).
In each iteration of coordinate ascent, set each natural variational parameter ν_k to the expectation of the natural parameter of the complete conditional for variable z_k:
ν_k* = E_{q_{−k}}[ η(z_{−k}, x) ]
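Tying this back to the mixture example (an illustrative connection, not shown on the slides): the complete conditional of μ_k is Gaussian, hence in the exponential family,

\[
p(\boldsymbol{\mu}_k \mid \mathbf{z}, \mathbf{x}) \propto
\exp\Big\{\big(\boldsymbol{\Lambda}_0\mathbf{m}_0 + \textstyle\sum_n z_k^{(n)}\mathbf{x}^{(n)}\big)^{T}\boldsymbol{\mu}_k
- \tfrac{1}{2}\,\boldsymbol{\mu}_k^{T}\big(\boldsymbol{\Lambda}_0 + \textstyle\sum_n z_k^{(n)}\mathbf{I}\big)\boldsymbol{\mu}_k\Big\},
\]

so its natural parameter is a function of z and x; taking the expectation E_{q_{−k}}[·] simply replaces each z_k^(n) by r_{nk}, which recovers the Λ_k and m_k updates derived earlier.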
When the joint distribution of each data point is in the exponential family,
p(x, z | η) = ∏_{n=1}^N h(x^(n), z^(n)) exp{ η^T t(x^(n), z^(n)) − A(η) }
we also use a conjugate prior for η:
p(η | ν_0, χ_0) = f(ν_0, χ_0) exp{ η^T χ_0 − ν_0 A(η) }
z = {z^(1), …, z^(N)},  x = {x^(1), …, x^(N)}
Suppose q(z, η) = q(z) q(η):
⇒ q(z) = ∏_{n=1}^N q(z^(n))
q*(z^(n)) = h(x^(n), z^(n)) exp{ E_η[η]^T t(x^(n), z^(n)) − A(E_η[η]) }
q*(η) = f(ν_N, χ_N) exp{ η^T χ_N − ν_N A(η) }
ν_N = ν_0 + N,   χ_N = χ_0 + Σ_{n=1}^N E_{z^(n)}[ t(x^(n), z^(n)) ]
Bayesian inference with incomplete data:
For complete data, we can derive closed-form solutions to this inference problem under suitable (conjugacy) assumptions.
In the case of incomplete data, these closed-form solutions do not exist, and so we need to resort to approximate inference.
Variational Bayes EM (VBEM) provides a way to approximate Bayesian estimation at a computational cost that is essentially the same as that of EM.
Thus, it often gives us the speed benefits of ML or MAP estimation together with the statistical benefits of the Bayesian approach.
ln p(𝒟) = ln ∫ Σ_ℋ p(𝒟, ℋ | θ) p(θ) dθ
ln p(𝒟) ≥ ∫ Σ_ℋ q(ℋ, θ) ln [ p(𝒟, ℋ, θ) / q(ℋ, θ) ] dθ
Mean field: q(ℋ, θ) = q_ℋ(ℋ) q_θ(θ)
ln p(𝒟) ≥ ∫ Σ_ℋ q_ℋ(ℋ) q_θ(θ) ln [ p(𝒟, ℋ, θ) / (q_ℋ(ℋ) q_θ(θ)) ] dθ
ln p(𝒟) ≥ ∫ Σ_ℋ q_ℋ(ℋ) q_θ(θ) ln p(𝒟, ℋ, θ) dθ + H[q_ℋ] + H[q_θ] = F[p, q]
Here the hidden quantities are z = ℋ ∪ {θ} and the observations are x = 𝒟.
We want to find q* = argmax_q F[p, q]
We assume the factorization q(ℋ, θ) = q_ℋ(ℋ) q_θ(θ)
Initialization: randomly select a starting distribution q_θ^(1)
Repeat:
E-step: given the current q_θ, find the posterior over the hidden data: q_ℋ^(t+1) = argmax_{q_ℋ} F[p, q_ℋ, q_θ^(t)]
M-step: given the posterior over the hidden data, find the distribution over the parameters: q_θ^(t+1) = argmax_{q_θ} F[p, q_ℋ^(t+1), q_θ]
Until convergence
F[p, q_ℋ, q_θ] = ∫ Σ_ℋ q_ℋ(ℋ) q_θ(θ) ln p(𝒟, ℋ, θ) dθ + H[q_ℋ] + H[q_θ]
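A schematic version of this loop (a sketch only; the two update functions and the ELBO evaluator are hypothetical placeholders for the model-specific argmax steps, not part of the slides):

def vbem(data, init_q_theta, update_q_H, update_q_theta, elbo, tol=1e-6, max_iters=100):
    q_theta = init_q_theta
    q_H = None
    prev = -float('inf')
    for _ in range(max_iters):
        q_H = update_q_H(data, q_theta)        # E-step: argmax over q_H of F[p, q_H, q_theta]
        q_theta = update_q_theta(data, q_H)    # M-step: argmax over q_theta of F[p, q_H, q_theta]
        cur = elbo(data, q_H, q_theta)         # F is non-decreasing over iterations
        if cur - prev < tol:
            break
        prev = cur
    return q_H, q_theta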
Mean field for a general Markov network: let p factorize into factors f_a over scopes x_a,
p(x) = (1/Z) ∏_{a∈ℱ} f_a(x_a)
The energy functional (ELBO) then becomes
ℒ(q) = H[q] + Σ_{a∈ℱ} E_q[ ln f_a(x_a) ]
which lower-bounds ln Z.
q(x) = ∏_{i=1}^n q_i(x_i)
ℒ(q) = Σ_{a∈ℱ} E_q[ ln f_a(x_a) ] + H[q]
E_q[ ln f_a(x_a) ] = Σ_{x_a∈Val(X_a)} ( ∏_{i∈𝒩(a)} q_i(x_i) ) ln f_a(x_a)
H[q] = Σ_{i=1}^n H[q_i]
Thus ℒ[q] can be rewritten simply as a sum of expectations over small subsets of variables plus local entropies.
𝒩(a) = { i : X_i ∈ Scope(f_a) }
Update rule: we can optimize each q_i given the other marginals:
q_i(x_i) = (1/Z_i) exp{ Σ_{a: i∈𝒩(a)} Σ_{x_a∈Val(X_a)} q(x_a | x_i) ln f_a(x_a) }
where q(x_a | x_i) = ∏_{k∈𝒩(a), k≠i} q_k(x_k) with X_i clamped to x_i.
Proof: collect the terms of ℒ(q) that involve q_i,
ℒ_i[q] = Σ_{a: i∈𝒩(a)} Σ_{x_a} ( ∏_{k∈𝒩(a)} q_k(x_k) ) ln f_a(x_a) + H[q_i]
and add a Lagrange multiplier for normalization:
L_i(q, λ) = ℒ_i[q] + λ_i ( Σ_{x_i∈Val(X_i)} q_i(x_i) − 1 )
∂L_i/∂q_i(x_i) = 0  ⇒  q_i(x_i) = (1/e^{1−λ_i}) exp{ Σ_{a: i∈𝒩(a)} Σ_{x_a} q(x_a | x_i) ln f_a(x_a) }
The coordinate ascent algorithm repeatedly optimizes a single marginal:
While not converged: iterate over each of the variables i ∈ 𝒱 and maximize the objective with respect to q_i(x_i), ∀ x_i ∈ Val(X_i), using the above formula.
All of the terms on the right-hand side involve expectations of variables other than X_i and do not depend on the choice of q_i(X_i), so this is block coordinate ascent.
ℒ_i is concave in q_i(X_i), so the update of q_i is guaranteed to increase (or at least not decrease) ℒ.
Mean field iterations are guaranteed to converge:
Each step of the coordinate ascent procedure is monotonically non-decreasing in ℒ.
Because ℒ is bounded, the sequence of distributions represented by successive iterations of mean field must converge.
At the convergence point, the fixed-point equations hold simultaneously for all of the marginals q_i.
As a consequence, the convergence point is a stationary point of the energy functional subject to the constraints.
The result of the mean field approximation is a local maximum; different initializations can lead to different stationary points.
When updating q_k, we only need to reason about the factors whose scope contains X_k:
The expectations required to evaluate q_k involve only those factors; the other terms get absorbed into the constant.
The optimization of q_k can therefore be expressed as a local computation over the Markov blanket of X_k.
Each algorithm can be explained from two perspectives:
As a message-passing algorithm
As one way of solving a constrained optimization problem
p(x) = (1/Z) exp{ Σ_{(i,j)∈ℰ} θ_{ij}(x_i, x_j) + Σ_{i∈𝒱} θ_i(x_i) }
q* = argmax_{q∈𝒬} ℒ(q)
Subject to:
q(x) = ∏_{i=1}^n q_i(x_i)
Σ_{x_i∈Val(X_i)} q_i(x_i) = 1 for every i
p: pairwise MRF
p(x) = (1/Z) ∏_{(i,j)∈ℰ} φ_{ij}(x_i, x_j) ∏_{i∈𝒱} φ_i(x_i) = (1/Z) exp{ Σ_{(i,j)∈ℰ} θ_{ij}(x_i, x_j) + Σ_{i∈𝒱} θ_i(x_i) }
Mean field update:
q_i(x_i) = (1/Z_i) exp{ θ_i(x_i) + Σ_{j∈𝒩(i)} Σ_{x_j} q_j(x_j) θ_{ij}(x_i, x_j) }
⇒ q_i(x_i) ∝ φ_i(x_i) ∏_{j∈𝒩(i)} m_{j→i}(x_i),  where m_{j→i}(x_i) ∝ exp{ Σ_{x_j} q_j(x_j) θ_{ij}(x_i, x_j) }
with θ_i = ln φ_i and θ_{ij} = ln φ_{ij}
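A small sketch of this update on a binary grid MRF (an Ising-style parameterization chosen for illustration: θ_i(x_i) = h·x_i and θ_{ij}(x_i, x_j) = J·x_i·x_j with x_i ∈ {−1, +1}; grid size, J, and h are assumptions, not slide content):

import numpy as np

def naive_mean_field_grid(n=10, J=0.3, h=0.1, n_sweeps=100):
    m = np.zeros((n, n))                          # m[i, j] = E_q[x_ij]
    for _ in range(n_sweeps):
        for i in range(n):
            for j in range(n):
                s = 0.0
                for di, dj in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                    a, b = i + di, j + dj
                    if 0 <= a < n and 0 <= b < n:
                        s += m[a, b]              # neighbors enter only through their means
                m[i, j] = np.tanh(h + J * s)      # q_i(x_i) ∝ exp{x_i (h + J Σ_j E[x_j])}
    return (1 + m) / 2                            # q_i(x_i = +1)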
p(x) = (1/Z) ∏_{(i,j)∈ℰ} φ_{ij}(x_i, x_j) ∏_{i∈𝒱} φ_i(x_i)
Mean field:
q_i(x_i) ∝ φ_i(x_i) ∏_{j∈𝒩(i)} m_{j→i}(x_i),  m_{j→i}(x_i) ∝ exp{ Σ_{x_j} q_j(x_j) θ_{ij}(x_i, x_j) }
Belief propagation (sum-product):
b_i(x_i) ∝ φ_i(x_i) ∏_{j∈𝒩(i)} m_{j→i}(x_i),  m_{i→j}(x_j) ∝ Σ_{x_i} φ_i(x_i) φ_{ij}(x_i, x_j) ∏_{l∈𝒩(i)∖j} m_{l→i}(x_i)
with θ_{ij}(x_i, x_j) = ln φ_{ij}(x_i, x_j)
Mean field methods are all very similar: just compute each node's full conditional, and average out the neighbors.
For a directed model p(x) = ∏_j p(x_j | pa_j):
ln q(x_k) = E_{q_{−k}}[ Σ_{j∈{k}∪ch(k)} ln p(x_j | pa_j) ] + const
It is possible to derive a general-purpose set of update equations that work for any DGM for which all CPDs are in the exponential family and all parent nodes have conjugate distributions.
Nodes are updated one at a time, i.e., posterior beliefs are updated using local operations at each node; each update increases a lower bound on the log evidence (unless already at a local maximum).
Variational message passing: the updates can be carried out in closed form, given that the model is conjugate-exponential:
ln q(x_k) = E_{q_{−k}}[ Σ_{j=1}^N ln p(x_j | pa_j) ] + const
ln q(x_k) = E_{q_{−k}}[ ln p(x_k | pa_k) ] + Σ_{j∈ch(k)} E_{q_{−k}}[ ln p(x_j | pa_j) ] + const
ln q(x_k) = E_{q_{−k}}[ η(pa_k) ]^T t(x_k) + ln h(x_k) + Σ_{j∈ch(k)} E_{q_{−k}}[ η(x_j, cp_j) ]^T t(x_k) + const
where cp_j denotes the co-parents of x_k in the CPD of child x_j.
Winn and Bishop, Variational Message Passing, JMLR 2005.
Mean field variants:
Naïve mean field
Structured mean field
Naïve mean field can lead to very poor approximations.
We must then use a richer class of distributions 𝒬 with greater expressive power (capturing some of the dependencies in p).
Use network structures of different complexity: a subgraph of H_p over which exact computation of H[q] (and of the required expectations) is feasible.
Example: for a grid H_p, a collection of independent chain structures; exact inference with such chains is linear.
p(x) = (1/Z) ∏_{j=1}^J φ_j(x_j)      q(x) = (1/Z_q) ∏_{k=1}^K ψ_k(x_k)
F[p, q] = Σ_{j=1}^J E_q[ ln φ_j(x_j) ] − E_q[ ln q ]
F[p, q] = Σ_{j=1}^J E_q[ ln φ_j(x_j) ] − Σ_{k=1}^K E_q[ ln ψ_k(x_k) ] + ln Z_q
ψ_k is a stationary point of the energy functional iff:
ψ_k(x_k) ∝ exp{ E_q[ log p(x) | x_k ] − Σ_{l≠k} E_q[ log ψ_l(x_l) | x_k ] }
ψ_k(x_k) ∝ exp{ Σ_j E_q[ log φ_j(x_j) | x_k ] − Σ_{l≠k} E_q[ log ψ_l(x_l) | x_k ] }
We need to perform inference in q after each update step in order to compute these conditional expectations.
ψ_k(x_k) does not affect the right-hand side of the fixed-point equation defining its value.
Both the quality and the computational complexity of the approximation depend on the structure chosen for q.
We want to be able to perform efficient inference in the approximating network, so we often select the structure so that the resulting factorization leads to a tractable network (that is, one of low tree-width).
Example: factorial HMM, with M hidden chains evolving independently and jointly generating the observation sequence of length T.
Exact inference costs O(T M K^{M+1}) for chains with K states each, so we use variational inference instead, e.g., a structured approximation in which each chain is kept intact but the coupling between chains is removed.
Ghahramani and Jordan, Factorial Hidden Markov Models, Machine Learning 1997.
Loopy belief propagation (LBP): a fixed-point iteration procedure that tries to minimize an approximate free energy.
Start by initializing all messages to one; while not converged, compute (i.e., update) the messages on all edges.
LBP optimizes approximate versions of the energy functional: it works with an approximate F[p, q] (the Bethe free energy) defined directly over pseudo-marginals, which may not be consistent with any joint distribution.
LBP does not always converge, and even when it does, the resulting beliefs are only approximations to the true marginals.
If BP is used on graphs with loops, messages may circulate indefinitely.
But we can run it anyway and hope for the best; stop message passing when a fixed number of iterations is reached, or when no significant change in the beliefs occurs.
Empirically, a good approximation is often achievable.
If the solution is not oscillatory but converges, it usually is a good approximation.
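An illustrative loopy BP sketch for a pairwise MRF (the data structures, a dict of node potentials and a dict of edge potentials keyed by node pairs, are assumptions for the example, not slide content):

import numpy as np

def loopy_bp(node_pot, edge_pot, n_iters=50):
    # node_pot[i]: length-K array phi_i; edge_pot[(i, j)]: K_i x K_j array phi_ij, stored once per edge.
    msgs, neighbors = {}, {i: [] for i in node_pot}
    for (i, j), pot in edge_pot.items():
        msgs[(i, j)] = np.ones(pot.shape[1])      # message i -> j, initialized to one
        msgs[(j, i)] = np.ones(pot.shape[0])      # message j -> i
        neighbors[i].append(j)
        neighbors[j].append(i)
    for _ in range(n_iters):
        new = {}
        for (i, j) in msgs:
            pot = edge_pot[(i, j)] if (i, j) in edge_pot else edge_pot[(j, i)].T
            prod = node_pot[i].copy()
            for k in neighbors[i]:
                if k != j:
                    prod *= msgs[(k, i)]          # product of incoming messages except the one from j
            m = prod @ pot                        # sum over x_i of phi_i * phi_ij * incoming messages
            new[(i, j)] = m / m.sum()             # normalize for numerical stability
        msgs = new
    beliefs = {}
    for i in node_pot:
        b = node_pot[i].copy()
        for k in neighbors[i]:
            b *= msgs[(k, i)]
        beliefs[i] = b / b.sum()                  # b_i(x_i) ∝ phi_i(x_i) * product of incoming messages
    return beliefs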
References:
C.M. Bishop, “Pattern Recognition and Machine Learning”, Springer, 2006 (Chapter 10).
D. Koller and N. Friedman, “Probabilistic Graphical Models: Principles and Techniques”, MIT Press, 2009.
Example: Latent Dirichlet Allocation (LDA) [D.M. Blei, A.Y. Ng, M.I. Jordan, 2003]
[Figure from Blei et al., 2003: an example density p(w | α, β) under LDA for three words and four topics]
The problematic coupling between θ and β arises due to the edges between θ, z, and w; the variational family is obtained by dropping these edges (and the w nodes) and endowing the simplified model with free variational parameters.
Variational EM procedure:
E-step: for each document, find the optimizing values of the variational parameters.
M-step: maximize the resulting lower bound on the log likelihood with respect to the model parameters α and β.
Smoothed LDA: we further assume that each row of β is independently drawn from an exchangeable Dirichlet prior and treat β as a latent variable; the VBEM procedure described earlier then applies.
q(β_{1:K}, z_{1:M}, θ_{1:M}) = q(β_{1:K}) q(z_{1:M}, θ_{1:M})
⇒ q(β_{1:K}, z_{1:M}, θ_{1:M}) = ∏_{k=1}^K q(β_k) ∏_{d=1}^M q(θ_d, z_d)
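For reference, Blei et al. (2003) give each factor of this family free parameters (Dirichlet parameters λ_k for the topics, and per-document Dirichlet and multinomial parameters γ_d, φ_d):

\[
q(\boldsymbol{\beta}_{1:K}, \mathbf{z}_{1:M}, \boldsymbol{\theta}_{1:M} \mid \lambda, \gamma, \phi)
= \prod_{k=1}^{K}\mathrm{Dir}(\boldsymbol{\beta}_k \mid \lambda_k)
\prod_{d=1}^{M} \Big(\mathrm{Dir}(\boldsymbol{\theta}_d \mid \gamma_d)\prod_{n=1}^{N_d}\mathrm{Mult}(z_{d,n} \mid \phi_{d,n})\Big).
\]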
D.M. Blei, A.Y. Ng, M.I. Jordan, “Latent Dirichlet Allocation”, Journal of Machine Learning Research 3:993–1022, 2003.