Probabilistic & Unsupervised Learning Factored Variational - - PowerPoint PPT Presentation

SLIDE 1

Probabilistic & Unsupervised Learning Factored Variational Approximations and Variational Bayes

Maneesh Sahani

maneesh@gatsby.ucl.ac.uk

Gatsby Computational Neuroscience Unit, and MSc ML/CSML, Dept Computer Science University College London Term 1, Autumn 2015

SLIDE 2

Expectations in Statistical Modelling

◮ Parameter estimation

$$\hat\theta = \arg\max_\theta \int dY\, P(Y|\theta)\,P(X|Y,\theta)$$

(or, using EM)

$$\theta^{\text{new}} = \arg\max_\theta \int dY\, P(Y|X,\theta^{\text{old}})\,\log P(X,Y|\theta)$$

◮ Prediction

$$p(x|D,m) = \int d\theta\, p(\theta|D,m)\,p(x|\theta,D,m)$$

◮ Model selection or weighting (by marginal likelihood)

$$p(D|m) = \int d\theta\, p(\theta|m)\,p(D|\theta,m)$$

These integrals are often intractable:

◮ Analytic intractability: integrals may not have a closed form in non-linear, non-Gaussian models ⇒ numerical integration.

◮ Computational intractability: the numerical integral (or sum, if Y or θ are discrete) may be exponential in data or model size.

SLIDE 3

Examples of Intractability

◮ Marginal likelihood/model evidence for a Mixture of Gaussians: exact computations are exponential in the number of data points:

$$p(x_1,\dots,x_N) = \int d\theta\, p(\theta)\prod_{i=1}^{N}\sum_{s_i} p(x_i|s_i,\theta)\,p(s_i|\theta) = \sum_{s_1}\sum_{s_2}\cdots\sum_{s_N}\int d\theta\, p(\theta)\prod_{i=1}^{N} p(x_i|s_i,\theta)\,p(s_i|\theta)$$

◮ Computing the conditional probabilities in a very large multiply-connected DAG:

$$p(x_i|X_j=a) = \sum_{\text{all settings of } y\setminus\{i,j\}} p(x_i, y, X_j=a)\,/\,p(X_j=a)$$

◮ Computing the hidden state distribution in a general nonlinear dynamical system:

$$p(y_t|x_1,\dots,x_t) \propto \int dy_{t-1}\, p\big(y_t|f(y_{t-1})\big)\, p\big(x_t|g(y_t)\big)\, p(y_{t-1}|x_1,\dots,x_{t-1})$$
SLIDE 10

Distributed models

[Figure: factorial HMM with hidden state chains s(1), s(2), s(3) over time steps 1, 2, 3, …, T, jointly generating observations x1, x2, x3, …, xT]

Consider an FHMM with M state variables taking on K values each.

◮ Moralisation puts simultaneous states $(s^{(1)}_t, s^{(2)}_t, \dots, s^{(M)}_t)$ into a single clique.

◮ Triangulation extends cliques to size M + 1.

◮ Each state takes K values ⇒ sums over $K^{M+1}$ terms.

◮ Factorial prior ⇏ factorial posterior (explaining away).

Variational methods approximate the posterior, often in a factored form. To see how they work, we need to review the free-energy interpretation of EM.
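The explaining-away point can be checked on the smallest possible example. The sketch below is not from the lecture: it assumes two independent binary causes with a noisy-OR observation, and all probabilities (`P_CAUSE`, `P_FIRE`, `LEAK`) are made-up values. The prior factorises, but conditioning on x = 1 couples the causes, so the exact posterior is not a product of its marginals.

```python
import itertools

# Hypothetical toy model (all numbers are illustrative assumptions):
# two independent binary causes s1, s2 with prior P(s = 1) = 0.1 each, and one binary
# observation x produced by a noisy-OR: each active cause turns x on with prob. 0.9,
# and a "leak" turns it on with prob. 0.01 regardless of the causes.
P_CAUSE, P_FIRE, LEAK = 0.1, 0.9, 0.01

def prior(s1, s2):
    # factorial prior: P(s1, s2) = P(s1) P(s2)
    return (P_CAUSE if s1 else 1 - P_CAUSE) * (P_CAUSE if s2 else 1 - P_CAUSE)

def p_x_on(s1, s2):
    # x stays off only if the leak and every active cause all fail to fire
    return 1 - (1 - LEAK) * (1 - P_FIRE * s1) * (1 - P_FIRE * s2)

# Exact posterior P(s1, s2 | x = 1) by enumerating the four joint states
joint = {(a, b): prior(a, b) * p_x_on(a, b) for a, b in itertools.product((0, 1), repeat=2)}
Z = sum(joint.values())
post = {k: v / Z for k, v in joint.items()}

m1 = post[(1, 0)] + post[(1, 1)]   # P(s1 = 1 | x = 1)
m2 = post[(0, 1)] + post[(1, 1)]   # P(s2 = 1 | x = 1)

# A factorial posterior would have P(1, 1 | x) = m1 * m2; here it is much smaller.
print(f"P(s1=1, s2=1 | x=1) = {post[(1, 1)]:.3f} vs product of marginals {m1 * m2:.3f}")
# Learning that s2 is on "explains away" x and lowers the belief in s1.
print(f"P(s1=1 | x=1, s2=1) = {post[(1, 1)] / m2:.3f} vs P(s1=1 | x=1) = {m1:.3f}")
```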

SLIDE 13

The Free Energy for a Latent Variable Model

Observed data $X = \{x_i\}$; latent variables $Y = \{y_i\}$; parameters θ. Goal: maximise the log likelihood wrt θ (i.e. ML learning):

$$\ell(\theta) = \log P(X|\theta) = \log \int P(Y,X|\theta)\,dY$$

Any distribution q(Y) over the hidden variables can be used to obtain a lower bound on the log likelihood using Jensen's inequality:

$$\ell(\theta) = \log \int q(Y)\,\frac{P(Y,X|\theta)}{q(Y)}\,dY \;\ge\; \int q(Y)\log\frac{P(Y,X|\theta)}{q(Y)}\,dY \;\stackrel{\text{def}}{=}\; \mathcal{F}(q,\theta)$$

$$\int q(Y)\log\frac{P(Y,X|\theta)}{q(Y)}\,dY = \int q(Y)\log P(Y,X|\theta)\,dY - \int q(Y)\log q(Y)\,dY = \int q(Y)\log P(Y,X|\theta)\,dY + H[q],$$

where H[q] is the entropy of q(Y). So:

$$\mathcal{F}(q,\theta) = \big\langle \log P(Y,X|\theta)\big\rangle_{q(Y)} + H[q]$$
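A minimal numerical sketch of the bound, assuming a made-up discrete model with one three-state latent variable (the array `joint` is hypothetical). It checks that F(q, θ) ≤ ℓ(θ) for random q, that F = ℓ − KL[q‖P(Y|X, θ)], and that the bound is tight at the exact posterior.

```python
import numpy as np

# Hypothetical discrete toy model: a single latent Y with 3 states and a fixed observed X.
# joint[y] = P(Y = y, X = x_obs | theta); the numbers are made up for illustration.
joint = np.array([0.12, 0.03, 0.05])

log_lik = np.log(joint.sum())      # l(theta) = log sum_y P(y, x_obs | theta)
posterior = joint / joint.sum()    # P(Y | X, theta)

def free_energy(q):
    """F(q, theta) = <log P(Y, X | theta)>_q + H[q]; assumes every q[y] > 0."""
    return np.sum(q * np.log(joint)) - np.sum(q * np.log(q))

def kl(q, p):
    return np.sum(q * (np.log(q) - np.log(p)))

rng = np.random.default_rng(0)
for _ in range(5):
    q = rng.dirichlet(np.ones(3))                                  # an arbitrary q(Y)
    assert free_energy(q) <= log_lik + 1e-12                       # Jensen: F <= l
    assert np.isclose(free_energy(q), log_lik - kl(q, posterior))  # F = l - KL[q || P(Y|X)]

# The bound becomes tight when q is the exact posterior.
print(free_energy(posterior), log_lik)
```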

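SLIDE 14

The E and M steps of EM

The log likelihood is bounded below by:

$$\mathcal{F}(q,\theta) = \big\langle\log P(Y,X|\theta)\big\rangle_{q(Y)} + H[q] = \ell(\theta) - \mathrm{KL}\big[q(Y)\,\|\,P(Y|X,\theta)\big]$$

EM alternates between:

E step: optimise $\mathcal{F}(q,\theta)$ wrt the distribution over hidden variables, holding parameters fixed:

$$q^{(k)}(Y) := \arg\max_{q(Y)} \mathcal{F}\big(q(Y),\theta^{(k-1)}\big) = P(Y|X,\theta^{(k-1)})$$

M step: maximise $\mathcal{F}(q,\theta)$ wrt parameters, holding the hidden distribution fixed:

$$\theta^{(k)} := \arg\max_{\theta} \mathcal{F}\big(q^{(k)}(Y),\theta\big) = \arg\max_{\theta} \big\langle\log P(Y,X|\theta)\big\rangle_{q^{(k)}(Y)}$$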

SLIDE 15

EM as Coordinate Ascent in F

SLIDE 16

EM Never Decreases the Likelihood

The E and M steps together never decrease the log likelihood:

$$\ell\big(\theta^{(k-1)}\big) \;\underset{\text{E step}}{=}\; \mathcal{F}\big(q^{(k)},\theta^{(k-1)}\big) \;\underset{\text{M step}}{\le}\; \mathcal{F}\big(q^{(k)},\theta^{(k)}\big) \;\underset{\text{Jensen}}{\le}\; \ell\big(\theta^{(k)}\big)$$

◮ The E step brings the free energy to the likelihood.
◮ The M step maximises the free energy wrt θ.
◮ $\mathcal{F} \le \ell$ by Jensen, or, equivalently, from the non-negativity of KL.

If the M step is executed so that $\theta^{(k)} \ne \theta^{(k-1)}$ iff $\mathcal{F}$ increases, then the overall EM iteration will step to a new value of θ iff the likelihood increases.
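To see the monotonicity in practice, here is a sketch of exact EM for a two-component, one-dimensional mixture of Gaussians on synthetic data (the data, initial values and iteration count are arbitrary assumptions). The E step sets q to the exact posterior responsibilities, the M step is the closed-form maximiser of the expected log joint, and the recorded log likelihood should never decrease.

```python
import numpy as np

def component_densities(x, pi, mu, var):
    # pi_k N(x_n; mu_k, var_k) for every data point n and component k
    return pi * np.exp(-(x[:, None] - mu) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)

def log_likelihood(x, pi, mu, var):
    # log P(X | theta), summing out the latent component of each point
    return np.sum(np.log(component_densities(x, pi, mu, var).sum(axis=1)))

def em_step(x, pi, mu, var):
    # E step: q(s_n) = P(s_n | x_n, theta_old), the exact posterior responsibilities
    dens = component_densities(x, pi, mu, var)
    r = dens / dens.sum(axis=1, keepdims=True)
    # M step: closed-form maximiser of <log P(X, S | theta)>_q
    nk = r.sum(axis=0)
    pi = nk / len(x)
    mu = (r * x[:, None]).sum(axis=0) / nk
    var = (r * (x[:, None] - mu) ** 2).sum(axis=0) / nk
    return pi, mu, var

rng = np.random.default_rng(1)
x = np.concatenate([rng.normal(-2.0, 1.0, 150), rng.normal(3.0, 0.5, 100)])  # synthetic data

pi, mu, var = np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0])
lls = [log_likelihood(x, pi, mu, var)]
for _ in range(50):
    pi, mu, var = em_step(x, pi, mu, var)
    lls.append(log_likelihood(x, pi, mu, var))

# The log likelihood sequence should be non-decreasing (up to round-off).
assert all(b >= a - 1e-9 for a, b in zip(lls, lls[1:]))
print(lls[0], lls[-1])
```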

SLIDE 20

Intractability

The M-step for a graphical model is usually (relatively) easy.

[Figure: DAG over A, B, C, D, E]

$$P(A,B,C,D,E) = \underbrace{P(A)\,P(B)\,P(C|A,B)}_{f_1(A,B,C)}\;\underbrace{P(D|B,C)}_{f_2(B,C,D)}\;\underbrace{P(E|C,D)}_{f_3(C,D,E)}$$

◮ Need expected sufficient stats from marginal posteriors on each factor group.
◮ Then (at least for a DAG) can optimise each factor parameter vector separately.
◮ Intractability in EM comes from the difficulty of computing marginal posteriors in graphs with large tree-width or non-linear/non-conjugate conditionals.
◮ [For non-DAG models, the partition function (normalising constant) may also be intractable.]
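As an illustration of the first two bullets, the sketch below performs the M step for the table P(C|A, B) inside f1, assuming binary variables and assuming some E step (exact or approximate) has already produced, for each data point n, a marginal q_n(A, B, C) over that factor group. The function name `mstep_cpt` and the array layout are hypothetical.

```python
import numpy as np

def mstep_cpt(q_abc):
    """M step for P(C | A, B) from clique marginals (hypothetical helper).

    q_abc: array of shape (N, 2, 2, 2); q_abc[n, a, b, c] is the (approximate)
    posterior marginal of (A=a, B=b, C=c) for data point n.
    Returns theta[a, b, c] = P(C=c | A=a, B=b), the maximiser of the expected log joint.
    Assumes every parent setting (a, b) has a non-zero expected count.
    """
    counts = q_abc.sum(axis=0)                           # expected sufficient statistics
    return counts / counts.sum(axis=-1, keepdims=True)  # normalise over c per parent setting

# Usage: with fully observed (delta) marginals this reduces to ordinary ML counting.
obs = [(0, 0, 1), (0, 0, 0), (0, 0, 1), (1, 1, 1), (0, 1, 0), (1, 0, 1)]
q = np.zeros((len(obs), 2, 2, 2))
for n, (a, b, c) in enumerate(obs):
    q[n, a, b, c] = 1.0
theta = mstep_cpt(q)
print(theta[0, 0])   # P(C | A=0, B=0): one of three cases has C=0, two have C=1
```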

SLIDE 24

Free-energy-based variational approximation

What if finding expected sufficient stats under P(Y|X, θ) is computationally intractable?

For the generalised EM algorithm, we argued that intractable maximisations could be replaced by gradient M-steps.

◮ Each step increases the likelihood.
◮ A fixed point of the gradient M-step must be at a mode of the expected log-joint.

For the E-step we could:

◮ Parameterise $q = q_\rho(Y)$ and take a gradient step in ρ.
◮ Assume some simplified form for q, usually factored: $q = \prod_i q_i(Y_i)$, where the $Y_i$ partition Y, and maximise within this form.

In either case, we choose q from within a limited set $\mathcal{Q}$:

VE step: maximise $\mathcal{F}(q,\theta)$ wrt the constrained latent distribution, given parameters:

$$q^{(k)}(Y) := \arg\max_{q(Y)\in\mathcal{Q}} \mathcal{F}\big(q(Y),\theta^{(k-1)}\big) \qquad (\text{constraint: } q(Y)\in\mathcal{Q})$$

M step: unchanged

$$\theta^{(k)} := \arg\max_\theta \mathcal{F}\big(q^{(k)}(Y),\theta\big) = \arg\max_\theta \int q^{(k)}(Y)\log p(Y,X|\theta)\,dY$$

Unlike in GEM, the fixed point may not be at an unconstrained optimum of $\mathcal{F}$.

SLIDE 30

What do we lose?

What does restricting q to $\mathcal{Q}$ cost us?

◮ Recall that the free energy is bounded above by Jensen:

$$\mathcal{F}(q,\theta) \le \ell(\theta^{\text{ML}})$$

Thus, as long as every step increases $\mathcal{F}$, convergence is still guaranteed.

◮ But, since $P(Y|X,\theta^{(k)})$ may not lie in $\mathcal{Q}$, we no longer saturate the bound after the E-step. Thus, the likelihood may not increase on each full EM step:

$$\ell\big(\theta^{(k-1)}\big) \;\underset{\text{E step}}{\ge}\; \mathcal{F}\big(q^{(k)},\theta^{(k-1)}\big) \;\underset{\text{M step}}{\le}\; \mathcal{F}\big(q^{(k)},\theta^{(k)}\big) \;\underset{\text{Jensen}}{\le}\; \ell\big(\theta^{(k)}\big)$$

◮ This means we may not converge to a maximum of ℓ.

The hope is that by increasing a lower bound on ℓ we will find a decent solution. [Note that if $P(Y|X,\theta^{\text{ML}}) \in \mathcal{Q}$, then $\theta^{\text{ML}}$ is a fixed point of the variational algorithm.]

SLIDE 31

KL divergence

Recall that

$$\mathcal{F}(q,\theta) = \big\langle\log P(X,Y|\theta)\big\rangle_{q(Y)} + H[q] = \big\langle\log P(X|\theta) + \log P(Y|X,\theta)\big\rangle_{q(Y)} - \big\langle\log q(Y)\big\rangle_{q(Y)} = \big\langle\log P(X|\theta)\big\rangle_{q(Y)} - \mathrm{KL}\big[q\,\|\,P(Y|X,\theta)\big].$$

Thus, the

E step: maximise $\mathcal{F}(q,\theta)$ wrt the distribution over latents, given parameters:

$$q^{(k)}(Y) := \arg\max_{q(Y)\in\mathcal{Q}} \mathcal{F}\big(q(Y),\theta^{(k-1)}\big)$$

is equivalent to:

E step: minimise $\mathrm{KL}[q\,\|\,p(Y|X,\theta)]$ wrt the distribution over latents, given parameters:

$$q^{(k)}(Y) := \arg\min_{q(Y)\in\mathcal{Q}} \int q(Y)\log\frac{q(Y)}{p(Y|X,\theta^{(k-1)})}\,dY$$

So, in each E step, the algorithm is trying to find the best approximation to P(Y|X) in $\mathcal{Q}$ in a KL sense. This is related to ideas in information geometry. It also suggests generalisations to other distance measures.
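The direction of the KL divergence matters. The sketch below, an illustration rather than part of the lecture, grid-searches the factored family q(y1, y2) = q1(y1)q2(y2) against a made-up, strongly correlated target p over two binary variables, once under KL[q‖p] (the variational E step) and once under the reversed KL[p‖q], as one example of an "other distance measure".

```python
import numpy as np

# Hypothetical target p(y1, y2) over two binary variables: almost all mass on the two
# "agreeing" states, so p has two well-separated modes.
p = np.array([[0.49, 0.01],
              [0.01, 0.49]])

def factored(a, b):
    # q(y1, y2) = q1(y1) q2(y2) with q1(1) = a and q2(1) = b
    return np.outer([1 - a, a], [1 - b, b])

def kl(u, v):
    return np.sum(u * (np.log(u) - np.log(v)))

grid = np.linspace(0.01, 0.99, 99)   # step 0.01; endpoints avoid log(0)
pairs = [(a, b) for a in grid for b in grid]
qp = min(pairs, key=lambda ab: kl(factored(*ab), p))   # variational E step: min KL[q || p]
pq = min(pairs, key=lambda ab: kl(p, factored(*ab)))   # reversed divergence: min KL[p || q]

# KL[q || p] is mode-seeking: both factors end up near the same extreme (one mode of p).
print("argmin KL[q||p]: q1(1) = %.2f, q2(1) = %.2f" % qp)
# KL[p || q] matches the marginals of p: q1(1) = q2(1) = 0.5, spreading mass over all states.
print("argmin KL[p||q]: q1(1) = %.2f, q2(1) = %.2f" % pq)
```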
SLIDE 35

Factored Variational E-step

The most common form of variational approximation partitions Y into disjoint sets $Y_i$ with

$$\mathcal{Q} = \Big\{ q \;\Big|\; q(Y) = \prod_i q_i(Y_i) \Big\}.$$

In this case the E-step is itself iterative:

(Factored VE step)$_i$: maximise $\mathcal{F}(q,\theta)$ wrt $q_i(Y_i)$ given the other $q_j$ and parameters:

$$q_i^{(k)}(Y_i) := \arg\max_{q_i(Y_i)} \mathcal{F}\Big(q_i(Y_i)\prod_{j\ne i} q_j(Y_j),\; \theta^{(k-1)}\Big).$$

◮ The $q_i$ updates are iterated to convergence to "complete" the VE-step.
◮ In fact, every (VE)$_i$-step separately increases $\mathcal{F}$, so any schedule of (VE)$_i$- and M-steps will converge. The choice can be dictated by practical issues (it is rarely efficient to fully converge the E-step before updating parameters).

SLIDE 42

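Factored Variational E-step

The Factored Variational E-step has a general form. The free energy is:

$$\begin{aligned}
\mathcal{F}\Big(\prod_j q_j(Y_j),\theta^{(k-1)}\Big) &= \Big\langle \log P(X,Y|\theta^{(k-1)})\Big\rangle_{\prod_j q_j(Y_j)} + H\Big[\prod_j q_j(Y_j)\Big] \\
&= \int dY_i\; q_i(Y_i)\,\Big\langle \log P(X,Y|\theta^{(k-1)})\Big\rangle_{\prod_{j\ne i} q_j(Y_j)} + H[q_i] + \sum_{j\ne i} H[q_j]
\end{aligned}$$

Now, taking the variational derivative of the Lagrangian (enforcing normalisation of $q_i$):

$$\frac{\delta}{\delta q_i}\Big(\mathcal{F} + \lambda\Big(\int q_i - 1\Big)\Big) = \Big\langle \log P(X,Y|\theta^{(k-1)})\Big\rangle_{\prod_{j\ne i} q_j(Y_j)} - \log q_i(Y_i) - \frac{q_i(Y_i)}{q_i(Y_i)} + \lambda \quad (=0)$$

$$\Rightarrow\quad q_i(Y_i) \propto \exp \Big\langle \log P(X,Y|\theta^{(k-1)})\Big\rangle_{\prod_{j\ne i} q_j(Y_j)}$$

In general, this depends only on the expected sufficient statistics under the $q_j$. Thus, again, we don't actually need the entire distributions, just the relevant expectations (now for approximate inference as well as learning).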
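The update above can be sketched for the simplest non-trivial case, assuming the latents are split into two discrete blocks Y1 (4 states) and Y2 (3 states) and that `log_joint[a, b]` = log P(X, Y1 = a, Y2 = b) is a hypothetical random table for the observed X. Each block update is the exponentiated expected log joint; the recorded free energy should never decrease and should stay below log P(X).

```python
import numpy as np

def free_energy(log_joint, q1, q2):
    # F = <log P(X, Y1, Y2)>_{q1 q2} + H[q1] + H[q2]   (X is fixed and folded into log_joint)
    q = np.outer(q1, q2)
    return np.sum(q * log_joint) - np.sum(q1 * np.log(q1)) - np.sum(q2 * np.log(q2))

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

rng = np.random.default_rng(2)
log_joint = rng.normal(size=(4, 3))   # hypothetical log P(X, Y1 = a, Y2 = b) for a fixed X
q1, q2 = np.full(4, 1 / 4), np.full(3, 1 / 3)

Fs = [free_energy(log_joint, q1, q2)]
for _ in range(20):
    q1 = softmax(log_joint @ q2)      # q1(a) ∝ exp <log P(X, a, Y2)>_{q2}
    Fs.append(free_energy(log_joint, q1, q2))
    q2 = softmax(q1 @ log_joint)      # q2(b) ∝ exp <log P(X, Y1, b)>_{q1}
    Fs.append(free_energy(log_joint, q1, q2))

log_lik = np.log(np.exp(log_joint).sum())              # exact log P(X) for comparison
assert all(b >= a - 1e-12 for a, b in zip(Fs, Fs[1:]))  # every (VE)_i step increases F
assert Fs[-1] <= log_lik + 1e-12                        # and F stays below the likelihood
print(Fs[-1], log_lik)
```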

SLIDE 47

Mean-field approximations

If $Y_i = y_i$ (i.e., q is factored over all variables) then the variational technique is often called a "mean field" approximation.

◮ Suppose P(X, Y) has sufficient statistics that are separable in the latent variables: e.g. the Boltzmann machine

$$P(X,Y) = \frac{1}{Z}\exp\Big(\sum_{ij} W_{ij}\, s_i s_j + \sum_i b_i s_i\Big)$$

with some $s_i \in Y$ and others observed.

◮ Expectations wrt a fully-factored q distribute over all $s_i \in Y$ (with $W_{ii} = 0$):

$$\Big\langle \log P(X,Y)\Big\rangle_{\prod q_i} = \sum_{ij} W_{ij}\,\langle s_i\rangle_{q_i}\langle s_j\rangle_{q_j} + \sum_i b_i\,\langle s_i\rangle_{q_i} - \log Z$$

(where $q_i$ for $s_i \in X$ is a delta function on the observed value).

◮ Thus, we can update each $q_i$ in turn given the means (or, in general, mean sufficient statistics) of the others.
◮ Each variable sees the mean field imposed by its neighbours, and we update these fields until they all agree.
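A sketch of the mean-field updates for a small Boltzmann machine, assuming binary units s_i ∈ {0, 1}, a weight matrix with zero diagonal, and the ordered-pair sum Σ_ij exactly as written above. Under these assumptions q_i(s_i = 1) = σ(h_i) with field h_i = Σ_j (W_ij + W_ji) m_j + b_i, where m_j = ⟨s_j⟩. The weights, biases and clamped units are hypothetical; the free energy (up to −log Z) is recorded after each single-unit update and should never decrease.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bernoulli_entropy(m):
    m = np.clip(m, 1e-12, 1 - 1e-12)
    return -(m * np.log(m) + (1 - m) * np.log(1 - m))

def mean_field_bm(W, b, observed, n_sweeps=100):
    """Mean-field E step for P(s) ∝ exp(sum_ij W_ij s_i s_j + sum_i b_i s_i), s_i in {0, 1}.

    Assumes W has a zero diagonal. `observed` maps unit index -> clamped value (0 or 1);
    every other unit is latent. Returns the means m_i = q_i(s_i = 1) and the free energy
    (up to the constant -log Z) recorded after every single-unit update.
    """
    m = np.full(len(b), 0.5)
    for i, v in observed.items():
        m[i] = v                                  # delta-function q_i on the observed value
    latent = [i for i in range(len(b)) if i not in observed]
    Fs = []
    for _ in range(n_sweeps):
        for i in latent:
            # mean field on unit i: s_i appears in W_ij s_i s_j and in W_ji s_j s_i
            h = (W[i] + W[:, i]) @ m + b[i]
            m[i] = sigmoid(h)                     # q_i(s_i) ∝ exp(s_i * h)
            Fs.append(m @ W @ m + b @ m + bernoulli_entropy(m[latent]).sum())
    return m, Fs

# Usage with hypothetical random weights (each pair stored once, zero diagonal)
rng = np.random.default_rng(3)
n = 6
W = np.triu(rng.normal(scale=0.5, size=(n, n)), k=1)
b = rng.normal(size=n)
m, Fs = mean_field_bm(W, b, observed={0: 1.0, 1: 0.0})
print(np.round(m, 3))
```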

SLIDE 53

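Mean-field FHMM

[Figure: factorial HMM with hidden state chains s(1), s(2), s(3) over time steps 1, 2, 3, …, T, jointly generating observations x1, x2, x3, …, xT]

$$q(s^{1:M}_{1:T}) = \prod_{m,t} q^m_t(s^m_t)$$

$$\begin{aligned}
q^m_t(s^m_t) &\propto \exp\Big\langle \log P(s^{1:M}_{1:T}, x_{1:T})\Big\rangle_{\prod_{\neg(m,t)} q^{m'}_{t'}(s^{m'}_{t'})} \\
&= \exp\Big\langle \sum_\mu\sum_\tau \log P(s^\mu_\tau|s^\mu_{\tau-1}) + \sum_\tau \log P(x_\tau|s^{1:M}_\tau)\Big\rangle_{\prod_{\neg(m,t)} q^{m'}_{t'}} \\
&\propto \exp\Big[\underbrace{\big\langle\log P(s^m_t|s^m_{t-1})\big\rangle_{q^m_{t-1}} + \big\langle\log P(x_t|s^{1:M}_t)\big\rangle_{\prod_{\neg m} q^{m'}_t}}_{\log\alpha^m_t} + \underbrace{\big\langle\log P(s^m_{t+1}|s^m_t)\big\rangle_{q^m_{t+1}}}_{\log\beta^m_t}\Big]
\end{aligned}$$

With $\Phi^m_{ji} = P(s^m_t = i\,|\,s^m_{t-1} = j)$ and $A_i(x_t)$ the output likelihood when $s^m_t = i$:

$$\alpha^m_t(i) \propto e^{\sum_j \log\Phi^m_{ji}\, q^m_{t-1}(j)} \cdot e^{\langle\log A_i(x_t)\rangle_{q^{\neg m}_t}} \qquad \text{cf. forward-backward: } \alpha_t(i) \propto \sum_j \alpha_{t-1}(j)\,\Phi_{ji}\,A_i(x_t)$$

$$\beta^m_t(i) \propto e^{\sum_j \log\Phi^m_{ij}\, q^m_{t+1}(j)} \qquad \text{cf. forward-backward: } \beta_t(i) \propto \sum_j \Phi_{ij}\,A_j(x_{t+1})\,\beta_{t+1}(j)$$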
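The mean-field updates can be sketched for M = 2 chains, assuming discrete states, the convention Φ^m[j, i] = P(s^m_t = i | s^m_{t−1} = j), and a precomputed table `loglik[t, i, j]` = log P(x_t | s^1_t = i, s^2_t = j). The interface `mean_field_fhmm` is hypothetical and the sweep order (chain by chain, forward in time) is one arbitrary schedule; each single-site update uses only q^m_{t−1}, q^m_{t+1} and the other chain's marginal at time t.

```python
import numpy as np

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

def mean_field_fhmm(log_pi, log_Phi, loglik, n_sweeps=50):
    """Fully factored mean-field E step for a 2-chain FHMM (hypothetical interface).

    log_pi[m]  : (K,)   log P(s^m_1 = i)
    log_Phi[m] : (K, K) log Phi^m[j, i] = log P(s^m_t = i | s^m_{t-1} = j)
    loglik     : (T, K, K) log P(x_t | s^1_t = i, s^2_t = j)
    Returns [q1, q2], each of shape (T, K), with q[m][t] = q^m_t(.).
    """
    T, K, _ = loglik.shape
    q = [np.full((T, K), 1.0 / K) for _ in range(2)]
    for _ in range(n_sweeps):
        for m in range(2):
            other = q[1 - m]
            for t in range(T):
                # <log A_i(x_t)> averaged over the other chain's marginal at time t
                out = loglik[t] @ other[t] if m == 0 else other[t] @ loglik[t]
                # forward term: sum_j q^m_{t-1}(j) log Phi_ji (initial-state prior at t = 1)
                fwd = log_pi[m] if t == 0 else q[m][t - 1] @ log_Phi[m]
                # backward term: sum_j log Phi_ij q^m_{t+1}(j) (absent at t = T)
                bwd = log_Phi[m] @ q[m][t + 1] if t < T - 1 else 0.0
                q[m][t] = softmax(fwd + out + bwd)
    return q

# Sanity check (illustrative): with uniform initial/transition probabilities and an additive
# output term f[t, i] + g[t, j], the exact posterior is fully factored, so mean field is exact.
rng = np.random.default_rng(6)
T, K = 4, 3
uniform = np.log(np.full((K, K), 1.0 / K))
f, g = rng.normal(size=(T, K)), rng.normal(size=(T, K))
q1, q2 = mean_field_fhmm([uniform[0]] * 2, [uniform] * 2, f[:, :, None] + g[:, None, :])
print(np.allclose(q1, [softmax(ft) for ft in f]), np.allclose(q2, [softmax(gt) for gt in g]))
```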

SLIDE 58

Mean-field FHMM

◮ Yields a message-passing algorithm like forward-backward.
◮ Updates depend only on immediate neighbours in the chain.
◮ Chains couple only through the joint output.
◮ Multiple passes; messages depend on (approximate) marginals.
◮ Evidence does not appear explicitly in the backward message (cf. Kalman smoothing).

SLIDE 61

Structured variational approximation

◮ q(Y) need not be completely factorised.
◮ For example, suppose Y can be partitioned into sets $Y_1$ and $Y_2$ such that computing the expected sufficient statistics under $P(Y_1|Y_2,X)$ and $P(Y_2|Y_1,X)$ would be tractable.
⇒ Then the factored approximation $q(Y) = q(Y_1)\,q(Y_2)$ is tractable.
◮ In particular, any factorisation of q(Y) into a product of distributions on trees yields a tractable approximation.

[Figure: network over variables A_t, B_t, C_t, D_t repeated across time slices t, t+1, t+2, …]

SLIDE 65

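Structured FHMM

[Figure: factorial HMM with hidden state chains s(1), s(2), s(3) over time steps 1, 2, 3, …, T, jointly generating observations x1, x2, x3, …, xT]

For the FHMM we can factor the chains: $q(s^{1:M}_{1:T}) = \prod_m q^m(s^m_{1:T})$

$$\begin{aligned}
q^m(s^m_{1:T}) &\propto \exp\Big\langle \log P(s^{1:M}_{1:T}, x_{1:T})\Big\rangle_{\prod_{\neg m} q^{m'}(s^{m'}_{1:T})} \\
&= \exp\Big\langle \sum_\mu\sum_t \log P(s^\mu_t|s^\mu_{t-1}) + \sum_t \log P(x_t|s^{1:M}_t)\Big\rangle_{\prod_{\neg m} q^{m'}} \\
&\propto \exp\Big[\sum_t \log P(s^m_t|s^m_{t-1}) + \sum_t \big\langle\log P(x_t|s^{1:M}_t)\big\rangle_{\prod_{\neg m} q^{m'}(s^{m'}_t)}\Big] \\
&= \prod_t P(s^m_t|s^m_{t-1}) \prod_t e^{\langle\log P(x_t|s^{1:M}_t)\rangle_{\prod_{\neg m} q^{m'}(s^{m'}_t)}}
\end{aligned}$$

This looks like a standard HMM joint, with a modified likelihood term ⇒ cycle through multiple forward-backward passes, updating the likelihood terms each time.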
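The structured update can be sketched along the same lines, again assuming two discrete chains, the convention log_Phi[j, i] = log P(s_t = i | s_{t−1} = j), and a hypothetical `loglik[t, i, j]` table. Each chain gets a standard scaled forward-backward pass in which the likelihood is replaced by exp⟨log P(x_t | s_t)⟩ under the other chain's current single-time marginals. The check at the end uses an additive output term, for which the chains are independent a posteriori, so the approximation should match brute-force enumeration.

```python
import itertools
import numpy as np

def forward_backward(log_pi, log_Phi, log_lik):
    """Scaled forward-backward returning single-time posterior marginals gamma[t, i].
    log_lik[t, i] may be unnormalised, e.g. the modified likelihood term."""
    T, K = log_lik.shape
    Phi = np.exp(log_Phi)
    lik = np.exp(log_lik - log_lik.max(axis=1, keepdims=True))  # per-t rescaling cancels
    alpha, beta = np.zeros((T, K)), np.ones((T, K))
    a = np.exp(log_pi) * lik[0]
    alpha[0] = a / a.sum()
    for t in range(1, T):
        a = (alpha[t - 1] @ Phi) * lik[t]       # alpha_t(i) ∝ sum_j alpha_{t-1}(j) Phi_ji A_i(x_t)
        alpha[t] = a / a.sum()
    for t in range(T - 2, -1, -1):
        b = Phi @ (lik[t + 1] * beta[t + 1])    # beta_t(i) ∝ sum_j Phi_ij A_j(x_{t+1}) beta_{t+1}(j)
        beta[t] = b / b.sum()
    gamma = alpha * beta
    return gamma / gamma.sum(axis=1, keepdims=True)

def structured_fhmm(log_pi, log_Phi, loglik, n_iters=20):
    """Structured E step for a 2-chain FHMM, q = q^1(s^1_{1:T}) q^2(s^2_{1:T})."""
    T, K, _ = loglik.shape
    gamma = [np.full((T, K), 1.0 / K) for _ in range(2)]
    for _ in range(n_iters):
        gamma[0] = forward_backward(log_pi[0], log_Phi[0], np.einsum('tij,tj->ti', loglik, gamma[1]))
        gamma[1] = forward_backward(log_pi[1], log_Phi[1], np.einsum('tij,ti->tj', loglik, gamma[0]))
    return gamma

# Sanity check on a tiny hypothetical model with an additive output term f[t, i] + g[t, j]
rng = np.random.default_rng(4)
T, K = 3, 2
def rand_dist(*shape):
    p = rng.random(shape) + 0.1
    return p / p.sum(axis=-1, keepdims=True)
log_pi = [np.log(rand_dist(K)) for _ in range(2)]
log_Phi = [np.log(rand_dist(K, K)) for _ in range(2)]
f, g = rng.normal(size=(T, K)), rng.normal(size=(T, K))
loglik = f[:, :, None] + g[:, None, :]
gamma = structured_fhmm(log_pi, log_Phi, loglik)

# Brute-force marginals by enumerating all K**(2T) joint state sequences
exact, Z = np.zeros((2, T, K)), 0.0
for p1 in itertools.product(range(K), repeat=T):
    for p2 in itertools.product(range(K), repeat=T):
        lw = log_pi[0][p1[0]] + log_pi[1][p2[0]] + sum(loglik[t, p1[t], p2[t]] for t in range(T))
        lw += sum(log_Phi[0][p1[t - 1], p1[t]] + log_Phi[1][p2[t - 1], p2[t]] for t in range(1, T))
        Z += np.exp(lw)
        for t in range(T):
            exact[0, t, p1[t]] += np.exp(lw)
            exact[1, t, p2[t]] += np.exp(lw)
print(np.allclose(gamma[0], exact[0] / Z), np.allclose(gamma[1], exact[1] / Z))
```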

SLIDE 69

Messages on an arbitrary graph

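Consider a DAG:

[Figure: example DAG over nodes A, B, C, D, E]

with $P(X, Y) = \prod_k P(Z_k \mid \mathrm{pa}(Z_k))$, where the $Z_k$ range over all observed ($X$) and latent ($Y$) nodes, and let $q(Y) = \prod_i q_i(Y_i)$ for disjoint sets $\{Y_i\}$.

We have that the VE update for $q_i$ is given by $q^*_i(Y_i) \propto \exp \langle \log P(Y, X) \rangle_{q_{\neg i}(Y)}$, where $\langle \cdot \rangle_{q_{\neg i}(Y)}$ denotes averaging with respect to $q_j(Y_j)$ for all $j \neq i$.

Then, since every factor that does not involve $Y_i$ averages to a constant:

$$
\begin{aligned}
\log q^*_i(Y_i) &= \left\langle \sum_k \log P(Z_k \mid \mathrm{pa}(Z_k)) \right\rangle_{q_{\neg i}(Y)} + \text{const} \\
&= \sum_{j \in Y_i} \left\langle \log P(Y_j \mid \mathrm{pa}(Y_j)) \right\rangle_{q_{\neg i}(Y)} + \sum_{j \in \mathrm{ch}(Y_i)} \left\langle \log P(Z_j \mid \mathrm{pa}(Z_j)) \right\rangle_{q_{\neg i}(Y)} + \text{const}
\end{aligned}
$$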

This defines messages that are passed between nodes in the graph. Each node receives messages from its Markov boundary: parents, children and parents of children (all neighbours in the corresponding factor graph).
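The Markov boundary that supplies these messages can be read directly off the DAG. A small sketch, with the graph given as a dictionary from each node to its list of parents; the edge set in the example is hypothetical, since the edges of the slide's DAG did not survive extraction.

```python
def markov_boundary(parents, node):
    """Parents, children and co-parents (other parents of children) of `node`.

    parents: dict mapping every node of the DAG to a list of its parents."""
    children = [c for c, ps in parents.items() if node in ps]
    boundary = set(parents[node]) | set(children)
    for c in children:
        boundary |= set(parents[c])
    boundary.discard(node)
    return boundary


# Hypothetical edges: A -> C, B -> C, C -> D, B -> E, C -> E
dag = {"A": [], "B": [], "C": ["A", "B"], "D": ["C"], "E": ["B", "C"]}
print(sorted(markov_boundary(dag, "A")))  # ['B', 'C']
```

Only the factors $P(Y_i \mid \mathrm{pa}(Y_i))$ and $P(Z_j \mid \mathrm{pa}(Z_j))$ for children $Z_j$ involve $Y_i$, so their expectations need the $q$ distributions of exactly these boundary nodes.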

slide-71
SLIDE 71

Non-factored variational methods

The term variational approximation is used whenever a bound on the likelihood (or on another estimation cost function) is optimised, but does not necessarily become tight. Many further variational approximations have been developed, including:

◮ parametric forms (e.g. Gaussian) for non-linear models
◮ non-free-energy-based bounds (both upper and lower) on the likelihood.

We can also see MAP- or zero-temperature EM and recognition models as parametric forms of variational inference.

Variational methods can also be used to find an approximate posterior on the parameters.

slide-75
SLIDE 75

Variational Bayes

So far, we have applied Jensen’s bound and factorisations to help with integrals over latent variables. We can do the same for integrals over parameters in order to bound the log marginal likelihood or evidence:

$$
\begin{aligned}
\log P(X \mid M) &= \log \int dY\, d\theta\; P(X, Y \mid \theta, M)\, P(\theta \mid M) \\
&= \max_{Q} \int dY\, d\theta\; Q(Y, \theta) \log \frac{P(X, Y, \theta \mid M)}{Q(Y, \theta)} \\
&\geq \max_{Q_Y, Q_\theta} \int dY\, d\theta\; Q_Y(Y)\, Q_\theta(\theta) \log \frac{P(X, Y, \theta \mid M)}{Q_Y(Y)\, Q_\theta(\theta)}
\end{aligned}
$$

The constraint that the distribution $Q$ must factor into the product $Q_Y(Y)\,Q_\theta(\theta)$ leads to the variational Bayesian EM algorithm, or just “Variational Bayes”.

slide-78
SLIDE 78

Variational Bayesian EM . . .

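Coordinate maximization of the VB free-energy lower bound

$$\mathcal{F}(Q_Y, Q_\theta) = \int dY\, d\theta\; Q_Y(Y)\, Q_\theta(\theta) \log \frac{P(X, Y, \theta \mid M)}{Q_Y(Y)\, Q_\theta(\theta)}$$

leads to EM-like updates:

$$Q^*_Y(Y) \propto \exp \langle \log P(Y, X \mid \theta) \rangle_{Q_\theta(\theta)} \qquad \text{(E-like step)}$$

$$Q^*_\theta(\theta) \propto P(\theta) \exp \langle \log P(Y, X \mid \theta) \rangle_{Q_Y(Y)} \qquad \text{(M-like step)}$$

Maximizing $\mathcal{F}$ is equivalent to minimizing the KL divergence between the approximate posterior, $Q_\theta(\theta)\,Q_Y(Y)$, and the true posterior, $P(\theta, Y \mid X)$:

$$
\begin{aligned}
\log P(X) - \mathcal{F}(Q_Y, Q_\theta) &= \log P(X) - \int dY\, d\theta\; Q_Y(Y)\, Q_\theta(\theta) \log \frac{P(X, Y, \theta)}{Q_Y(Y)\, Q_\theta(\theta)} \\
&= \int dY\, d\theta\; Q_Y(Y)\, Q_\theta(\theta) \log \frac{Q_Y(Y)\, Q_\theta(\theta)}{P(Y, \theta \mid X)} \\
&= \mathrm{KL}(Q \,\|\, P)
\end{aligned}
$$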
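A small numerical check of the updates and the KL identity above, on a hypothetical toy problem where Y takes 3 values and θ is restricted to a 4-point grid, so every integral becomes a sum. The table `log_p` (made-up numbers, with the prior folded in) plays the role of $\log P(X, Y, \theta)$ for the single observed X. The sketch alternates the E-like and M-like updates and records $\mathcal{F}$, which should never decrease and should stay below $\log P(X)$, with the gap equal to $\mathrm{KL}(Q \,\|\, P)$.

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical toy joint log P(X, Y=y, theta=j) for one fixed observed X:
# rows index 3 latent states y, columns a 4-point grid of theta (prior folded in).
log_p = np.log(rng.uniform(0.1, 1.0, size=(3, 4)))
log_evidence = np.log(np.exp(log_p).sum())  # log P(X)


def normalise_log(v):
    """Turn unnormalised log-probabilities into a probability vector."""
    w = np.exp(v - v.max())
    return w / w.sum()


def free_energy(qy, qt):
    q = np.outer(qy, qt)  # factorised Q_Y(Y) Q_theta(theta)
    return np.sum(q * (log_p - np.log(q)))


def kl_to_posterior(qy, qt):
    q = np.outer(qy, qt)
    log_post = log_p - log_evidence  # log P(Y, theta | X)
    return np.sum(q * (np.log(q) - log_post))


qy, qt = np.full(3, 1 / 3), np.full(4, 1 / 4)
history = [free_energy(qy, qt)]
for _ in range(50):
    # E-like step: Q_Y(y) ∝ exp <log P(X, y, theta)>_{Q_theta}
    # (the prior term is constant in y, so this matches the slide's update)
    qy = normalise_log(log_p @ qt)
    # M-like step: Q_theta(j) ∝ P(theta_j) exp <log P(X, Y | theta_j)>_{Q_Y}
    qt = normalise_log(qy @ log_p)
    history.append(free_energy(qy, qt))

# the two numbers should agree up to rounding
print(log_evidence - history[-1], kl_to_posterior(qy, qt))
```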

slide-79
SLIDE 79

Conjugate-Exponential models

Let’s focus on conjugate-exponential (CE) latent-variable models:

◮ Condition (1). The joint probability over variables is in the exponential family:

$$P(Y, X \mid \theta) = f(Y, X)\, g(\theta) \exp\left\{ \phi(\theta)^\top T(Y, X) \right\}$$

where $\phi(\theta)$ is the vector of natural parameters and $T$ are sufficient statistics.

◮ Condition (2). The prior over parameters is conjugate to this joint probability:

$$P(\theta \mid \nu, \tau) = h(\nu, \tau)\, g(\theta)^\nu \exp\left\{ \phi(\theta)^\top \tau \right\}$$

where $\nu$ and $\tau$ are hyperparameters of the prior.

Conjugate priors are computationally convenient and have an intuitive interpretation:

◮ $\nu$: number of pseudo-observations
◮ $\tau$: values of pseudo-observations (their summed sufficient statistics)
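A worked instance of the pseudo-observation reading, using a Bernoulli observation as a hypothetical example. Writing $P(x \mid \theta) = (1-\theta) \exp\{x \log\frac{\theta}{1-\theta}\}$ gives $g(\theta) = 1-\theta$, $\phi(\theta) = \log\frac{\theta}{1-\theta}$ and $T(x) = x$, so the conjugate prior $g(\theta)^\nu e^{\phi(\theta)\tau} = \theta^\tau (1-\theta)^{\nu-\tau}$ is a Beta$(\tau+1, \nu-\tau+1)$ density. This example has no latent variables, so the update is exact Bayes rather than VB.

```python
def beta_params(nu, tau):
    """Beta(a, b) form of the conjugate Bernoulli prior
    g(theta)^nu exp(phi(theta) tau) = theta^tau (1 - theta)^(nu - tau)."""
    return tau + 1.0, nu - tau + 1.0


def conjugate_update(nu, tau, xs):
    """nu grows by the number of observations, tau by the summed T(x) = x."""
    return nu + len(xs), tau + sum(xs)


nu0, tau0 = 4.0, 1.0  # hypothetical prior: 4 pseudo-flips, 1 of them a head
nu1, tau1 = conjugate_update(nu0, tau0, [1, 0, 1, 1, 0, 1])
print(beta_params(nu0, tau0))  # (2.0, 4.0)
print(beta_params(nu1, tau1))  # (6.0, 6.0)
```

In a latent-variable model, VB would add the expected statistics $\langle T(y_i, x_i) \rangle_{Q_Y}$ in place of the observed counts.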

slide-80
SLIDE 80

Conjugate-Exponential examples

In the CE family:

◮ Gaussian mixtures
◮ factor analysis, probabilistic PCA
◮ hidden Markov models and factorial HMMs
◮ linear dynamical systems and switching models
◮ discrete-variable belief networks

Other as-yet-undreamt-of models built from combinations of Gaussian, Gamma, Poisson, Dirichlet, Wishart, Multinomial and other distributions.

Not in the CE family:

◮ Boltzmann machines, MRFs (no simple conjugacy)
◮ logistic regression (no simple conjugacy)
◮ sigmoid belief networks (not exponential)
◮ independent components analysis (not exponential)

Note: one can often approximate such models with a suitable choice from the CE family.

slide-84
SLIDE 84

Conjugate-exponential VB

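Given an iid data set $D = (x_1, \ldots, x_n)$, if the model is CE then:

◮ $Q_\theta(\theta)$ is also conjugate, i.e.

$$
\begin{aligned}
Q_\theta(\theta) &\propto P(\theta) \exp \left\langle \sum_i \log P(y_i, x_i \mid \theta) \right\rangle_{Q_Y} \\
&= h(\nu, \tau)\, g(\theta)^\nu e^{\phi(\theta)^\top \tau}\; g(\theta)^n\, e^{\langle \log f(Y, X) \rangle_{Q_Y}}\, e^{\phi(\theta)^\top \sum_i \langle T(y_i, x_i) \rangle_{Q_Y}} \\
&\propto h(\tilde\nu, \tilde\tau)\, g(\theta)^{\tilde\nu} e^{\phi(\theta)^\top \tilde\tau}
\end{aligned}
$$

with $\tilde\nu = \nu + n$ and $\tilde\tau = \tau + \sum_i \langle T(y_i, x_i) \rangle_{Q_Y}$ ⇒ only need to track $\tilde\nu, \tilde\tau$.

◮ $Q_Y(Y) = \prod_{i=1}^n Q_{y_i}(y_i)$ takes the same form as in the E-step of regular EM:

$$Q_{y_i}(y_i) \propto \exp \langle \log P(y_i, x_i \mid \theta) \rangle_{Q_\theta} \propto f(y_i, x_i)\, e^{\langle \phi(\theta) \rangle_{Q_\theta}^\top T(y_i, x_i)} = P(y_i \mid x_i, \bar\phi)$$

with natural parameters $\bar\phi = \langle \phi(\theta) \rangle_{Q_\theta}$ ⇒ inference unchanged from regular EM.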

slide-90
SLIDE 90

The Variational Bayesian EM algorithm

EM for MAP estimation
Goal: maximize $P(\theta \mid X, m)$ wrt $\theta$
E Step: compute $Q_Y(Y) \leftarrow p(Y \mid X, \theta)$
M Step: $\theta \leftarrow \operatorname{argmax}_\theta \int dY\, Q_Y(Y) \log P(Y, X, \theta)$

Variational Bayesian EM
Goal: maximise bound on $P(X \mid m)$ wrt $Q_\theta$
VB-E Step: compute $Q_Y(Y) \leftarrow p(Y \mid X, \bar\phi)$
VB-M Step: $Q_\theta(\theta) \propto \exp \int dY\, Q_Y(Y) \log P(Y, X, \theta)$

Properties:

◮ Reduces to the EM algorithm if $Q_\theta(\theta) = \delta(\theta - \theta^*)$.
◮ $\mathcal{F}_m$ increases monotonically, and incorporates the model complexity penalty.
◮ Analytical parameter distributions (but not constrained to be Gaussian).
◮ VB-E step has the same complexity as the corresponding E step.
◮ We can use the junction tree, belief propagation, Kalman filter, etc., algorithms in the VB-E step of VB-EM, but using expected natural parameters, $\bar\phi$.
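A minimal VB-EM sketch for a deliberately simple, hypothetical CE model: a one-dimensional mixture of K unit-variance Gaussians with fixed equal weights and a conjugate prior $\mu_k \sim \mathcal{N}(m_0, s_0^2)$ on each mean, so that $Q_\theta$ is a product of Gaussians $\mathcal{N}(m_k, v_k)$. The VB-E step is the ordinary responsibility computation, except that $\mu_k$ enters only through $\langle \mu_k \rangle = m_k$ and $\langle \mu_k^2 \rangle = m_k^2 + v_k$; the VB-M step is the conjugate update, adding expected counts and sums to the prior's pseudo-data. The recorded free energy illustrates the monotonicity property.

```python
import numpy as np


def expected_loglik(x, m, v):
    """<log N(x_i | mu_k, 1)> under Q(mu_k) = N(m_k, v_k); shape (N, K)."""
    return -0.5 * np.log(2 * np.pi) - 0.5 * ((x[:, None] - m) ** 2 + v)


def free_energy(x, r, m, v, m0, s02):
    K = len(m)
    data = np.sum(r * (np.log(1.0 / K) + expected_loglik(x, m, v)))
    entropy = -np.sum(r * np.log(np.clip(r, 1e-300, None)))
    # KL( N(m_k, v_k) || N(m0, s02) ), summed over components
    kl = 0.5 * np.sum(v / s02 + (m - m0) ** 2 / s02 - 1.0 + np.log(s02 / v))
    return data + entropy - kl


def vb_em_mixture(x, K, m0=0.0, s02=10.0, n_iter=50):
    m = np.linspace(x.min(), x.max(), K)  # deterministic initial Q(mu_k) means
    v = np.full(K, s02)                   # initial Q(mu_k) variances
    F_history = []
    for _ in range(n_iter):
        # VB-E step: responsibilities from expected log-likelihoods
        # (equal mixing weights cancel in the normalisation)
        ell = expected_loglik(x, m, v)
        r = np.exp(ell - ell.max(axis=1, keepdims=True))
        r /= r.sum(axis=1, keepdims=True)
        # VB-M step: conjugate Gaussian update, prior pseudo-data plus expected statistics
        Nk = r.sum(axis=0)
        v = 1.0 / (1.0 / s02 + Nk)
        m = v * (m0 / s02 + r.T @ x)
        F_history.append(free_energy(x, r, m, v, m0, s02))
    return m, v, r, F_history


rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-3.0, 1.0, 100), rng.normal(3.0, 1.0, 100)])
m, v, r, F = vb_em_mixture(x, K=2)
print(np.round(np.sort(m), 2), F[-1])
```

Because each step is an exact coordinate maximisation of $\mathcal{F}$, the values in `F` should be non-decreasing; the prior settings `m0`, `s02` and the initialisation are arbitrary choices.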

slide-94
SLIDE 94

VB and model selection

◮ Variational Bayesian EM yields an approximate posterior $Q_\theta$ over model parameters.
◮ It also yields an optimised lower bound on the (log) model evidence:

$$\max \mathcal{F}_M(Q_Y, Q_\theta) \leq \log P(D \mid M)$$

◮ These lower bounds can be compared amongst models to learn the right (structure, connectivity, … of the) model.
◮ If a continuous domain of models is specified by a hyperparameter $\eta$, then the VB free energy depends on that parameter:

$$\mathcal{F}(Q_Y, Q_\theta, \eta) = \int dY\, d\theta\; Q_Y(Y)\, Q_\theta(\theta) \log \frac{P(X, Y, \theta \mid \eta)}{Q_Y(Y)\, Q_\theta(\theta)} \leq \log P(X \mid \eta)$$

A hyper-M step maximises the current bound wrt $\eta$:

$$\eta \leftarrow \operatorname{argmax}_\eta \int dY\, d\theta\; Q_Y(Y)\, Q_\theta(\theta) \log P(X, Y, \theta \mid \eta)$$
slide-100
SLIDE 100

ARD for unsupervised learning

Recall that ARD (automatic relevance determination) was a hyperparameter method to select relevant or useful inputs in regression.

◮ A similar idea used with variational Bayesian methods can learn a latent dimensionality.
◮ Consider factor analysis:

$$x \sim \mathcal{N}(\Lambda y, \Psi), \qquad y \sim \mathcal{N}(0, I)$$

with a column-wise prior

$$\Lambda_{:i} \sim \mathcal{N}\left(0, \alpha_i^{-1} I\right)$$

◮ The VB free energy is

$$\mathcal{F}(Q_Y(Y), Q_\Lambda(\Lambda), \Psi, \alpha) = \left\langle \log P(X, Y \mid \Lambda, \Psi) + \log P(\Lambda \mid \alpha) + \log P(\Psi) \right\rangle_{Q_Y Q_\Lambda} + \ldots$$

and so hyperparameter optimisation requires

$$\alpha \leftarrow \operatorname{argmax}_\alpha \left\langle \log P(\Lambda \mid \alpha) \right\rangle_{Q_\Lambda}$$

◮ Now $Q_\Lambda$ is Gaussian, with the same form as in linear regression, but with expected moments of $y$ appearing in place of the inputs.
◮ Optimisation wrt the distributions, $\Psi$ and $\alpha$ in turn causes some $\alpha_i$ to diverge, as in regression ARD.
◮ In this case, these parameters select “relevant” latent dimensions, effectively learning the dimensionality of $y$.
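For the column-wise prior above, $\langle \log P(\Lambda \mid \alpha) \rangle_{Q_\Lambda} = \sum_i \left[\tfrac{D}{2} \log \alpha_i - \tfrac{\alpha_i}{2} \langle \lVert \Lambda_{:i} \rVert^2 \rangle\right] + \text{const}$, where D is the dimension of x, so setting the derivative to zero gives the closed-form hyper-M step $\alpha_i = D / \langle \lVert \Lambda_{:i} \rVert^2 \rangle$. A minimal sketch, assuming $Q_\Lambda$ factorises over the rows of $\Lambda$ (each row Gaussian with its own covariance, as in the linear-regression analogy); the argument names are hypothetical.

```python
import numpy as np


def ard_hyper_m_step(lam_mean, lam_row_covs):
    """Hyper-M step alpha_i = D / <||Lambda[:, i]||^2> for the prior
    Lambda[:, i] ~ N(0, I / alpha_i).

    lam_mean:     (D, K) posterior means of Lambda
    lam_row_covs: (D, K, K) posterior covariance of each row Lambda[d, :],
                  assuming Q(Lambda) factorises over rows."""
    D = lam_mean.shape[0]
    # <Lambda_di^2> = mean_di^2 + Cov_d[i, i]; sum over rows d
    second_moment = np.sum(lam_mean ** 2, axis=0) + np.einsum("dii->i", lam_row_covs)
    return D / second_moment


D, K = 5, 3
mean = np.zeros((D, K))
mean[:, 0], mean[:, 1] = 1.0, 0.5  # third column has shrunk to zero
covs = np.tile(0.01 * np.eye(K), (D, 1, 1))
alpha = ard_hyper_m_step(mean, covs)
```

A column whose posterior mean and variance shrink toward zero receives a large $\alpha_i$, which shrinks it further at the next VB-M step; such columns can then be pruned. Any numerical pruning threshold is a practical choice rather than part of the method.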

slide-101
SLIDE 101

Augmented Variational Methods

In our examples so far, the approximate variational distribution has been over the “natural” latent variables (and parameters) of the generative model. Sometimes it may be useful to introduce additional latent variables, solely to achieve computational tractability. Two examples are GP regression and the GPLVM.

slide-104
SLIDE 104

Sparse GP approximations

GP predictions:

$$y' \mid X, Y, x' \sim \mathcal{N}\left( K_{x'X} K_{XX}^{-1} Y,\; K_{x'x'} - K_{x'X} K_{XX}^{-1} K_{Xx'} \right)$$

Evidence (for learning kernel hyperparameters):

$$\log P(Y \mid X) = -\tfrac{1}{2} \log \left| 2\pi (K_{XX} + \sigma^2 I) \right| - \tfrac{1}{2} Y^\top (K_{XX} + \sigma^2 I)^{-1} Y$$

Computing either form requires inverting the $N \times N$ matrix $K_{XX}$, in $O(N^3)$ time. One proposal to make this more efficient is to find (or select) a smaller set of possibly fictitious measurements $U$ at inputs $Z$ such that predictions made on the basis of $U$ are close to those made with $Y$. Where (and what values) should the $U$ lie?

slide-105
SLIDE 105

Variational Sparse GP approximations

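We write $F$ for the (smooth) GP function values that underlie $Y$ (so $Y \sim \mathcal{N}(F, \sigma^2 I)$).

Introduce additional latent measurements $U$ at inputs $Z$. Then the likelihood is

$$P(Y \mid X) = \int dF\, dU\; P(Y, F, U \mid X, Z) = \int dF\, dU\; P(Y \mid F)\, P(F \mid U, X, Z)\, P(U \mid Z)$$

The $U$ and $F$ are latent, so we introduce a variational distribution $q(F, U)$ to form a free energy:

$$\mathcal{F}(q(F, U), \theta) = \left\langle \log \frac{P(Y \mid F)\, P(F \mid U, X, Z)\, P(U \mid Z)}{q(F, U)} \right\rangle_{q(F, U)}$$

Now, choose the variational form $q(F, U) = P(F \mid U, X, Z)\, q(U)$. That is, fix $F \mid U$ without reference to $Y$, so information about $Y$ will need to be “compressed” into $q(U)$. Then

$$
\begin{aligned}
\mathcal{F}(q(F, U), \theta) &= \left\langle \log \frac{P(Y \mid F)\, P(F \mid U, X, Z)\, P(U \mid Z)}{P(F \mid U, X, Z)\, q(U)} \right\rangle_{P(F \mid U)\, q(U)} \\
&= \left\langle \langle \log P(Y \mid F) \rangle_{P(F \mid U)} + \log P(U \mid Z) - \log q(U) \right\rangle_{q(U)}
\end{aligned}
$$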
slide-106
SLIDE 106

Variational Sparse GP approximations

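$$\mathcal{F}(q(U), \theta) = \left\langle \langle \log P(Y \mid F) \rangle_{P(F \mid U)} + \log P(U \mid Z) - \log q(U) \right\rangle_{q(U)}$$

Now $P(F \mid U)$ is fixed by the generative model (rather than being subject to free optimisation), so we can evaluate that expectation. Using the GP conditional $P(F \mid U) = \mathcal{N}\left(K_{XZ} K_{ZZ}^{-1} U,\; K_{XX} - K_{XZ} K_{ZZ}^{-1} K_{ZX}\right)$, and writing $\langle\langle F F^\top \rangle\rangle$ for the covariance $\langle F F^\top \rangle - \langle F \rangle \langle F \rangle^\top$:

$$
\begin{aligned}
\langle \log P(Y \mid F) \rangle_{P(F \mid U)} &= \left\langle -\tfrac{1}{2} \log \left| 2\pi\sigma^2 I \right| - \tfrac{1}{2\sigma^2} \operatorname{Tr}\left[ (Y - F)(Y - F)^\top \right] \right\rangle_{P(F \mid U)} \\
&= -\tfrac{1}{2} \log \left| 2\pi\sigma^2 I \right| - \tfrac{1}{2\sigma^2} \operatorname{Tr}\left[ (Y - \langle F \rangle_{P(F \mid U)})(Y - \langle F \rangle_{P(F \mid U)})^\top \right] - \tfrac{1}{2\sigma^2} \operatorname{Tr}\left[ \langle\langle F F^\top \rangle\rangle_{P(F \mid U)} \right] \\
&= \log \mathcal{N}\left( Y \mid K_{XZ} K_{ZZ}^{-1} U, \sigma^2 I \right) - \tfrac{1}{2\sigma^2} \operatorname{Tr}\left[ K_{XX} - K_{XZ} K_{ZZ}^{-1} K_{ZX} \right]
\end{aligned}
$$

So,

$$\mathcal{F}(q(U), \theta) = \left\langle \log \mathcal{N}\left( Y \mid K_{XZ} K_{ZZ}^{-1} U, \sigma^2 I \right) + \log P(U \mid Z) - \log q(U) \right\rangle_{q(U)} - \tfrac{1}{2\sigma^2} \operatorname{Tr}\left[ K_{XX} - K_{XZ} K_{ZZ}^{-1} K_{ZX} \right].$$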
slide-107
SLIDE 107

Variational Sparse GP approximations

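$$\mathcal{F}(q(U), \theta) = \left\langle \log \frac{\mathcal{N}\left( Y \mid K_{XZ} K_{ZZ}^{-1} U, \sigma^2 I \right) P(U \mid Z)}{q(U)} \right\rangle_{q(U)} - \tfrac{1}{2\sigma^2} \operatorname{Tr}\left[ K_{XX} - K_{XZ} K_{ZZ}^{-1} K_{ZX} \right].$$

Now, we may recognise the expectation as the free energy of a PPCA-like model with normal prior $U \sim \mathcal{N}(0, K_{ZZ})$ and loading matrix $K_{XZ} K_{ZZ}^{-1}$. The maximum of the free energy is the log-likelihood (and it is achieved with $q$ equal to the posterior under this PPCA model). This gives

$$
\begin{aligned}
\mathcal{F}(q^*(U), \theta) &= \log \mathcal{N}\left( Y \mid 0, K_{XZ} K_{ZZ}^{-1} K_{ZZ} K_{ZZ}^{-1} K_{ZX} + \sigma^2 I \right) - \tfrac{1}{2\sigma^2} \operatorname{Tr}\left[ K_{XX} - K_{XZ} K_{ZZ}^{-1} K_{ZX} \right] \\
&= \log \mathcal{N}\left( Y \mid 0, K_{XZ} K_{ZZ}^{-1} K_{ZX} + \sigma^2 I \right) - \tfrac{1}{2\sigma^2} \operatorname{Tr}\left[ K_{XX} - K_{XZ} K_{ZZ}^{-1} K_{ZX} \right].
\end{aligned}
$$

Note that we have eliminated all terms in $K_{XX}^{-1}$.

We can optimise this free energy numerically with respect to $Z$ and $\theta$ to adjust the GP prior and the quality of the variational approximation. A similar approach can be used to learn $X$ if they are unobserved (i.e. in the GPLVM). Assume $q(X, F, U) = q(X)\, P(F \mid X, U)\, q(U)$. Then

$$\mathcal{F} = \left\langle \log \frac{P(Y, F, U \mid X)\, P(X)}{q(X)\, P(F \mid X, U)\, q(U)} \right\rangle_{q(X, F, U)}$$

which simplifies into tractable components in much the same way as above.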
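A dense sketch of the collapsed bound above for 1-D inputs, assuming a squared-exponential kernel with unit lengthscale and variance (hypothetical choices; the slides leave the kernel unspecified) and a small jitter added to $K_{ZZ}$ for numerical stability. It forms the $N \times N$ matrices directly, so it is still $O(N^3)$ and is only meant to check the algebra: the bound should lie below the exact log evidence and should match it when $Z = X$. A practical implementation would instead use the matrix inversion lemma so that only $M \times M$ systems are solved.

```python
import numpy as np


def rbf(a, b, lengthscale=1.0, variance=1.0):
    """Squared-exponential kernel matrix between 1-D input arrays a and b."""
    d2 = (a[:, None] - b[None, :]) ** 2
    return variance * np.exp(-0.5 * d2 / lengthscale ** 2)


def log_gauss(y, cov):
    """log N(y | 0, cov) via a Cholesky factorisation."""
    L = np.linalg.cholesky(cov)
    w = np.linalg.solve(L, y)
    return -0.5 * w @ w - np.sum(np.log(np.diag(L))) - 0.5 * len(y) * np.log(2 * np.pi)


def exact_log_evidence(x, y, sigma2):
    return log_gauss(y, rbf(x, x) + sigma2 * np.eye(len(x)))


def collapsed_bound(x, y, z, sigma2, jitter=1e-8):
    """log N(y | 0, Kxz Kzz^-1 Kzx + sigma2 I) - Tr(Kxx - Kxz Kzz^-1 Kzx) / (2 sigma2)."""
    Kzz = rbf(z, z) + jitter * np.eye(len(z))
    Kxz = rbf(x, z)
    Qxx = Kxz @ np.linalg.solve(Kzz, Kxz.T)
    Qxx = 0.5 * (Qxx + Qxx.T)  # symmetrise against round-off
    trace_term = np.trace(rbf(x, x) - Qxx) / (2.0 * sigma2)
    return log_gauss(y, Qxx + sigma2 * np.eye(len(x))) - trace_term


rng = np.random.default_rng(0)
x = np.linspace(0.0, 5.0, 40)
y = np.sin(x) + 0.1 * rng.normal(size=x.size)
print("exact", exact_log_evidence(x, y, 0.01))
for n_inducing in (3, 6, 12):
    z = np.linspace(0.0, 5.0, n_inducing)
    print(n_inducing, collapsed_bound(x, y, z, 0.01))
```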
