SLIDE 1

Simultaneous Inference for Massive Data: Distributed Bootstrap

Yang Yu¹, Shih-Kang Chao², Guang Cheng¹

¹Purdue University, ²University of Missouri

ICML 2020

SLIDE 4

We have $N$ i.i.d. data points: $Z_1, \ldots, Z_N$

Estimation: Fit a model that has an unknown parameter $\theta \in \mathbb{R}^d$ by minimizing the empirical risk

$$\widehat{\theta} := \arg\min_{\theta \in \mathbb{R}^d} \frac{1}{N} \sum_{i=1}^{N} L(\theta; Z_i)$$

Ideally, we want $\widehat{\theta}$ to be close to the expected risk minimizer

$$\theta^* := \arg\min_{\theta \in \mathbb{R}^d} \mathbb{E}_Z[L(\theta; Z)]$$

Examples:

◮ Linear regression: $Z = (X, Y)$, $L(\theta; Z) = (Y - X^\top\theta)^2/2$
◮ Logistic regression: $Z = (X, Y)$, $L(\theta; Z) = -Y X^\top\theta + \log(1 + \exp[X^\top\theta])$
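The two example losses can be written in NumPy as follows. This is a minimal sketch: the function names are hypothetical, the gradient functions return one row per data point (the per-sample gradients $\nabla L(\theta; Z_i)$ that the later bootstrap slides reuse), and the Hessian functions return the sample average $\frac{1}{N}\sum_i \nabla^2 L(\theta; Z_i)$.

```python
import numpy as np

def linear_loss(theta, X, y):
    """Empirical risk (1/N) sum_i (y_i - x_i^T theta)^2 / 2."""
    r = y - X @ theta
    return 0.5 * np.mean(r ** 2)

def linear_grads(theta, X, y):
    """Per-sample gradients -(y_i - x_i^T theta) x_i, shape (N, d)."""
    r = y - X @ theta
    return -r[:, None] * X

def linear_hessian(theta, X, y):
    """Average Hessian (1/N) sum_i x_i x_i^T (constant in theta)."""
    return X.T @ X / X.shape[0]

def logistic_loss(theta, X, y):
    """Empirical risk (1/N) sum_i [-y_i x_i^T theta + log(1 + exp(x_i^T theta))]."""
    z = X @ theta
    return np.mean(-y * z + np.logaddexp(0.0, z))  # logaddexp(0, z) = log(1 + e^z), stable

def logistic_grads(theta, X, y):
    """Per-sample gradients (sigmoid(x_i^T theta) - y_i) x_i, shape (N, d)."""
    p = 1.0 / (1.0 + np.exp(-(X @ theta)))
    return (p - y)[:, None] * X

def logistic_hessian(theta, X, y):
    """Average Hessian (1/N) sum_i p_i (1 - p_i) x_i x_i^T with p_i = sigmoid(x_i^T theta)."""
    p = 1.0 / (1.0 + np.exp(-(X @ theta)))
    w = p * (1.0 - p)
    return (X * w[:, None]).T @ X / X.shape[0]
```

Averaging the rows of `logistic_grads` gives the gradient of the empirical risk, which can be checked against finite differences of `logistic_loss`.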

SLIDE 10

Inference: Find an interval $[L, U]$ such that $P(\theta^*_1 \in [L, U]) \approx 95\%$

Simultaneous inference: Find intervals $[L_1, U_1], \ldots, [L_d, U_d]$ such that $P(\theta^*_1 \in [L_1, U_1], \ldots, \theta^*_d \in [L_d, U_d]) \approx 95\%$

How to perform simultaneous inference:

Step 1: Compute the point estimator $\widehat{\theta}$
Step 2: Estimate the 0.95-quantile $c(0.95)$ of $\|\sqrt{N}(\widehat{\theta} - \theta^*)\|_\infty$ (by bootstrap)
Step 3: For $l = 1, \ldots, d$, set $L_l = \widehat{\theta}_l - c(0.95)/\sqrt{N}$ and $U_l = \widehat{\theta}_l + c(0.95)/\sqrt{N}$
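Steps 2 and 3 can be sketched as below, assuming Step 2 has already produced $B$ bootstrap draws of the sup-norm statistic (the function name and arguments are hypothetical). The intervals are simultaneous because every coordinate satisfies $|\widehat{\theta}_l - \theta^*_l| \le c/\sqrt{N}$ exactly when $\|\sqrt{N}(\widehat{\theta} - \theta^*)\|_\infty \le c$, so the probability that all $d$ intervals cover $\theta^*$ is $P(\|\sqrt{N}(\widehat{\theta} - \theta^*)\|_\infty \le c)$.

```python
import numpy as np

def simultaneous_intervals(theta_hat, boot_stats, N, level=0.95):
    """Step 3: intervals theta_hat_l +/- c(level) / sqrt(N) for l = 1, ..., d.

    theta_hat  : (d,) point estimate from Step 1.
    boot_stats : (B,) bootstrap draws of the sup-norm statistic from Step 2.
    N          : total sample size.
    """
    c = np.quantile(boot_stats, level)   # bootstrap estimate of the quantile c(level)
    half_width = c / np.sqrt(N)          # one common half-width for all d coordinates
    return theta_hat - half_width, theta_hat + half_width
```

All $d$ intervals share the same width $2c(0.95)/\sqrt{N}$, which is the price of controlling the maximum deviation rather than each coordinate separately.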

SLIDE 11

Distributed framework: Distribute the $N$ data points evenly across $k$ machines so that each machine stores $n = N/k$ data points

◮ 1 master node $M_1$
◮ $k - 1$ worker nodes $M_2, M_3, \ldots, M_k$
◮ $Z_{ij}$: the $i$-th data point at machine $M_j$
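To make the indexing $Z_{ij}$ concrete, here is a minimal sketch that simulates the even split on a single process, assuming the data sit in NumPy arrays. The helper `partition` is hypothetical; machine $M_j$ corresponds to the 0-based block `j - 1`, so block 0 is the master $M_1$.

```python
import numpy as np

def partition(X, y, k):
    """Split N rows evenly into k blocks of n = N/k rows.

    Returns Xs with shape (k, n, d) and ys with shape (k, n); Xs[j - 1, i - 1]
    is the covariate part of Z_ij, the i-th data point stored on machine M_j.
    """
    N = X.shape[0]
    if N % k != 0:
        raise ValueError("N must be divisible by k so that n = N/k is an integer")
    n = N // k
    # Consecutive rows go to the same machine: row (j-1)*n + (i-1) becomes Z_ij.
    return X.reshape(k, n, X.shape[1]), y.reshape(k, n)
```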

SLIDE 14

Distributed Simultaneous Inference

Step 1: Compute $\widehat{\theta}$
◮ Can be approximated by existing efficient distributed estimation methods

Step 2: Bootstrap $c(0.95)$
◮ The traditional bootstrap cannot be applied efficiently in the distributed framework
◮ BLB¹ and SDB² are computationally expensive due to repeated resampling, and are not suitable for large $k$

¹Kleiner et al., "A scalable bootstrap for massive data," JRSS-B (2014)
²Sengupta et al., "A subsampled double bootstrap for massive data," JASA (2016)

SLIDE 16

Question: How can we efficiently do Step 2 in a distributed manner?

Our contributions:

◮ We propose communication-efficient and computation-efficient distributed bootstrap methods: k-grad and n+k-1-grad
◮ We derive a sufficient number of communication rounds that guarantees statistical accuracy and efficiency

SLIDE 19

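Approximate by a sample average:

$$\|\sqrt{N}(\widehat{\theta} - \theta^*)\|_\infty \approx \Big\| \mathbb{E}[\nabla^2 L(\theta^*; Z)]^{-1} \frac{1}{\sqrt{N}} \sum_{i=1}^{n} \sum_{j=1}^{k} \nabla L(\theta^*; Z_{ij}) \Big\|_\infty$$

Multiplier bootstrap: $\epsilon_{ij} \overset{iid}{\sim} N(0, 1)$ for $i = 1, \ldots, n$ and $j = 1, \ldots, k$

$$\|\sqrt{N}(\widehat{\theta} - \theta^*)\|_\infty \overset{D}{\approx} \Big\| \mathbb{E}[\nabla^2 L(\theta^*; Z)]^{-1} \frac{1}{\sqrt{N}} \sum_{i=1}^{n} \sum_{j=1}^{k} \epsilon_{ij} \nabla L(\theta^*; Z_{ij}) \Big\|_\infty \,\Big|\, \{Z_{ij}\}_{i,j}$$

k-grad (computed at $M_1$): $\epsilon_j \overset{iid}{\sim} N(0, 1)$ for $j = 1, \ldots, k$

$$\|\sqrt{N}(\widehat{\theta} - \theta^*)\|_\infty \overset{D}{\approx} W := \Big\| \widetilde{\Theta} \frac{1}{\sqrt{k}} \sum_{j=1}^{k} \epsilon_j \sqrt{n}\,(g_j - \bar{g}) \Big\|_\infty \,\Big|\, \{Z_{ij}\}_{i,j}$$

where $\bar{\theta}$ is a preliminary estimate of $\theta^*$ and
◮ $g_j = \frac{1}{n} \sum_{i=1}^{n} \nabla L(\bar{\theta}; Z_{ij})$ is computed at $M_j$ and transmitted to $M_1$
◮ $\bar{g} = \frac{1}{k} \sum_{j=1}^{k} g_j$ is averaged at $M_1$
◮ $\widetilde{\Theta} = \big(\frac{1}{n} \sum_{i=1}^{n} \nabla^2 L(\bar{\theta}; Z_{i1})\big)^{-1}$ is computed at $M_1$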
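A minimal NumPy sketch of the k-grad computation at $M_1$, assuming the $k$ local average gradients $g_j$ (all evaluated at the same $\bar{\theta}$) and the inverse local Hessian $\widetilde{\Theta}$ have already been gathered there. The function name `k_grad` and the number of bootstrap draws `B` are our own choices. Each draw needs only one Gaussian multiplier per machine, so generating $B$ draws costs $O(Bkd + Bd^2)$ at the master and no further communication.

```python
import numpy as np

def k_grad(g, Theta, n, B=500, rng=None):
    """B draws of the k-grad statistic W, all computed at the master M1.

    g     : (k, d) array; row j-1 is the local average gradient g_j from M_j
            (row 0 is M1's own g_1), all evaluated at the same theta_bar.
    Theta : (d, d) inverse of M1's local Hessian at theta_bar.
    n     : number of data points per machine.
    """
    rng = np.random.default_rng() if rng is None else rng
    k = g.shape[0]
    centered = np.sqrt(n) * (g - g.mean(axis=0))   # rows sqrt(n) (g_j - g_bar)
    eps = rng.standard_normal((B, k))               # eps_j ~ N(0, 1), fresh for each draw
    draws = eps @ centered / np.sqrt(k)             # (B, d): (1/sqrt(k)) sum_j eps_j sqrt(n)(g_j - g_bar)
    return np.abs(draws @ Theta.T).max(axis=1)      # sup-norm of Theta applied to each draw
```

The bootstrap quantile is then, for example, `np.quantile(k_grad(g, Theta, n), 0.95)`.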

SLIDE 21

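k-grad fails for small $k$!

Solution: n+k-1-grad (computed at $M_1$): $\epsilon_{i1}, \epsilon_j \overset{iid}{\sim} N(0, 1)$ for $i = 1, \ldots, n$ and $j = 2, \ldots, k$

$$\widetilde{W} := \Big\| \widetilde{\Theta} \frac{1}{\sqrt{n + k - 1}} \Big( \sum_{i=1}^{n} \epsilon_{i1} (g_{i1} - \bar{g}) + \sum_{j=2}^{k} \epsilon_j \sqrt{n}\,(g_j - \bar{g}) \Big) \Big\|_\infty \,\Big|\, \{Z_{ij}\}_{i,j}$$

where $g_{i1} = \nabla L(\bar{\theta}; Z_{i1})$ is computed at $M_1$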
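The n+k-1-grad statistic can be sketched the same way, assuming $M_1$ keeps its $n$ per-sample gradients $g_{i1}$ and has received $g_2, \ldots, g_k$ from the workers (the function name `nk1_grad` and the draw count `B` are hypothetical). The only change from k-grad is that $M_1$'s own data contribute $n$ individual multipliers instead of one, giving $n + k - 1$ multipliers per draw even when $k$ is small.

```python
import numpy as np

def nk1_grad(g1_local, g_workers, Theta, B=500, rng=None):
    """B draws of the n+k-1-grad statistic, computed at the master M1.

    g1_local  : (n, d) per-sample gradients g_i1 = grad L(theta_bar; Z_i1) at M1.
    g_workers : (k-1, d) local average gradients g_2, ..., g_k from the workers.
    Theta     : (d, d) inverse of M1's local Hessian at theta_bar.
    """
    rng = np.random.default_rng() if rng is None else rng
    n = g1_local.shape[0]
    k = g_workers.shape[0] + 1
    # g_bar averages the k local means; M1's local mean g_1 is the mean of its g_i1.
    g_bar = (g1_local.mean(axis=0) + g_workers.sum(axis=0)) / k
    terms = np.vstack([g1_local - g_bar,                    # n rows: g_i1 - g_bar
                       np.sqrt(n) * (g_workers - g_bar)])   # k-1 rows: sqrt(n)(g_j - g_bar)
    eps = rng.standard_normal((B, n + k - 1))               # one multiplier per row, per draw
    draws = eps @ terms / np.sqrt(n + k - 1)
    return np.abs(draws @ Theta.T).max(axis=1)              # sup-norm of Theta applied to each draw
```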

SLIDE 26

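An example algorithm: apply k-grad/n+k-1-grad with the CSL (communication-efficient surrogate likelihood) estimator³, where $L_j(\theta) = \frac{1}{n} \sum_{i=1}^{n} L(\theta; Z_{ij})$ is the local empirical risk at $M_j$

Step 1: compute the point estimator ($\tau$ rounds of communication)
1: $\widetilde{\theta}^{(0)} \leftarrow \arg\min_\theta L_1(\theta)$ at $M_1$
2: for $t = 1, \ldots, \tau$ do
3:     Transmit $\widetilde{\theta}^{(t-1)}$ to $\{M_j\}_{j=2}^{k}$
4:     Compute $\nabla L_1(\widetilde{\theta}^{(t-1)})$ and $\nabla^2 L_1(\widetilde{\theta}^{(t-1)})^{-1}$ at $M_1$
5:     for $j = 2, \ldots, k$ do
6:         Compute $\nabla L_j(\widetilde{\theta}^{(t-1)})$ at $M_j$
7:         Transmit $\nabla L_j(\widetilde{\theta}^{(t-1)})$ to $M_1$
8:     $\nabla L_N(\widetilde{\theta}^{(t-1)}) \leftarrow k^{-1} \sum_{j=1}^{k} \nabla L_j(\widetilde{\theta}^{(t-1)})$ at $M_1$
9:     $\widetilde{\theta}^{(t)} \leftarrow \widetilde{\theta}^{(t-1)} - \nabla^2 L_1(\widetilde{\theta}^{(t-1)})^{-1} \nabla L_N(\widetilde{\theta}^{(t-1)})$ at $M_1$

Step 2: bootstrap the quantile $c(0.95)$ (0 rounds of communication, since the local gradients at $\widetilde{\theta}^{(\tau-1)}$ already reached $M_1$ in round $\tau$)
10: Run k-grad/n+k-1-grad with $\bar{\theta} = \widetilde{\theta}^{(\tau-1)}$ at $M_1$

Step 3: 95% simultaneous confidence intervals
11: $\widetilde{\theta}^{(\tau)}_l \pm N^{-1/2} c(0.95)$ for $l = 1, \ldots, d$

In total, $\tau$ rounds of communication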

Question: How many rounds of communication are sufficient?

³Jordan et al., "Communication-efficient distributed statistical inference," JASA (2019)
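The whole pipeline can be simulated on one process for linear regression. This is a minimal sketch under stated assumptions, not the authors' implementation: the function name `csl_kgrad_linear` is hypothetical, machine $M_j$ holds the $(j-1)$-th block of $n$ consecutive rows, and Step 2 uses k-grad. For linear regression the local Hessian of $L_1$ does not depend on $\theta$, so line 4's inverse is computed once.

```python
import numpy as np

def csl_kgrad_linear(X, y, k, tau, B=500, level=0.95, rng=None):
    """Simulate CSL (tau rounds) plus k-grad intervals for linear regression.

    Machine M_{j+1} holds rows j*n, ..., (j+1)*n - 1 of (X, y); j = 0 is the master M1.
    Returns theta^(tau) and the lower/upper simultaneous confidence bounds.
    """
    rng = np.random.default_rng() if rng is None else rng
    N, d = X.shape
    if N % k != 0 or tau < 1:
        raise ValueError("need N divisible by k and tau >= 1")
    n = N // k
    Xs, ys = X.reshape(k, n, d), y.reshape(k, n)

    def local_grad(j, theta):
        # Gradient of L_{j+1}(theta) = (1/n) sum_i (y_ij - x_ij^T theta)^2 / 2 on block j
        return -Xs[j].T @ (ys[j] - Xs[j] @ theta) / n

    H1 = Xs[0].T @ Xs[0] / n                             # Hessian of L_1 (constant in theta here)
    Theta = np.linalg.inv(H1)                            # line 4's inverse Hessian at M1
    theta = np.linalg.solve(H1, Xs[0].T @ ys[0] / n)     # line 1: argmin of L_1 at M1

    for _ in range(tau):                                 # lines 2-9: one round per iteration
        g = np.array([local_grad(j, theta) for j in range(k)])   # rows g_j at theta^(t-1)
        theta = theta - Theta @ g.mean(axis=0)                     # lines 8-9: CSL update

    # Line 10: g now holds the local gradients at theta^(tau-1), already at M1,
    # so k-grad needs no extra communication.
    centered = np.sqrt(n) * (g - g.mean(axis=0))
    eps = rng.standard_normal((B, k))
    W = np.abs((eps @ centered / np.sqrt(k)) @ Theta.T).max(axis=1)

    # Line 11: theta^(tau)_l +/- c(level) / sqrt(N)
    half_width = np.quantile(W, level) / np.sqrt(N)
    return theta, theta - half_width, theta + half_width
```

Because each CSL step for linear regression contracts the distance to the full-sample least-squares fit by a factor governed by how close $\nabla^2 L_1$ is to the full-sample Hessian, a moderate number of rounds already reproduces the full-sample estimate when $n$ is large relative to $d$.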

SLIDE 29

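Assume $n = d^{\gamma_n}$ and $k = d^{\gamma_k}$ for constants $\gamma_n, \gamma_k > 0$, and let $c_W(\alpha)$ denote the $\alpha$-quantile of the bootstrap statistic $W$ (k-grad or n+k-1-grad) conditional on the data.

Statistical accuracy (CSL estimator $\widetilde{\theta}^{(\tau)}$): $\sup_{\alpha \in (0,1)} \big|P(\|\sqrt{N}(\widetilde{\theta}^{(\tau)} - \theta^*)\|_\infty \le c_W(\alpha)) - \alpha\big| = o(1)$

Statistical efficiency (full-sample estimator $\widehat{\theta}$): $\sup_{\alpha \in (0,1)} \big|P(\|\sqrt{N}(\widehat{\theta} - \theta^*)\|_\infty \le c_W(\alpha)) - \alpha\big| = o(1)$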

SLIDE 34

Illustration of main results for linear models (left: k-grad; right: n+k-1-grad). Blue areas: accuracy and efficiency are guaranteed if $\tau \ge \tau_{\min}$. Gray areas: accuracy and efficiency are not guaranteed.

[Figure: two panels with x-axis $\gamma_n = \log_d n$ and y-axis $\gamma_k = \log_d k$; regions labeled $\tau_{\min} = 2, 3, 4, \ge 5$ for k-grad and $\tau_{\min} = 1, 2, 3, 4, \ge 5$ for n+k-1-grad]

◮ $\tau_{\min}$ increases logarithmically as $k$ increases, $n$ decreases, and $d$ increases
◮ $\tau_{\min,\text{n+k-1-grad}} \le \tau_{\min,\text{k-grad}}$
◮ $\tau_{\min,\text{n+k-1-grad}} \ge 1$, $\tau_{\min,\text{k-grad}} \ge 2$
◮ $\gamma_k$ has to be large for k-grad, but not for n+k-1-grad

SLIDE 35

Illustration of main results for generalized linear models (left: k-grad; right: n+k-1-grad)

[Figure: two panels with x-axis $\gamma_n = \log_d n$ and y-axis $\gamma_k = \log_d k$; regions labeled $\tau_{\min} = 2, 3, 4, \ge 5$ for k-grad and $\tau_{\min} = 1, 2, 3, 4, \ge 5$ for n+k-1-grad]

SLIDE 36

Simulations: logistic regression, $N = 2^{16}$. Top left: k-grad, $d = 2^3$; top right: k-grad, $d = 2^5$; bottom left: n+k-1-grad, $d = 2^3$; bottom right: n+k-1-grad, $d = 2^5$

[Figure: coverage (0 to 1) and interval width (×10⁻¹) plotted against $\log_2 k$ for $\tau = 1, 2, 3, 4$]
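The coverage and width curves can be summarized from repeated simulations as sketched below. We assume, as our reading of the axis labels rather than something stated on the slide, that "coverage" is the empirical simultaneous coverage (the fraction of replications in which all $d$ intervals contain $\theta^*$) and "width" is the average interval width; the function names are hypothetical.

```python
import numpy as np

def simultaneous_coverage(lower, upper, theta_star):
    """Fraction of replications in which every coordinate of theta_star is covered.

    lower, upper : (R, d) interval endpoints from R simulated data sets.
    theta_star   : (d,) true parameter.
    """
    covered = (lower <= theta_star) & (theta_star <= upper)   # (R, d) coordinate-wise
    return covered.all(axis=1).mean()                        # all d coordinates at once

def mean_width(lower, upper):
    """Average interval width over replications and coordinates."""
    return (upper - lower).mean()
```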

SLIDE 38

Comparisons to BLB and SDB:

◮ Width (logistic regression, left: $d = 2^5$, right: $d = 2^7$)

[Figure: interval width against $\log_2 k$ for k-grad ($\tau = 1, 4$), n+k-1-grad ($\tau = 1, 4$), BLB, and SDB]

◮ Run time in seconds (linear regression, $d = 2^7$)

| Method | $k = 2^2$ | $k = 2^6$ | $k = 2^9$ |
|---|---|---|---|
| k-grad | 0.82 | 0.51 | 0.50 |
| n+k-1-grad | 1.49 | 0.67 | 0.64 |
| SDB | 3.44 | 3.83 | 12.66 |
| BLB | 981.17 | 842.50 | 1950.91 |

SLIDE 39

Extensions:

◮ To other models, e.g., graphical models
◮ To high-dimensional sparse models (in progress)

SLIDE 40

Thank you!