
Tradeoffs in Large Scale Learning: Statistical Accuracy vs. Numerical Precision

Sham M. Kakade

Microsoft Research, New England


Statistical estimation involves optimization

Problem:

Find the minimizer w∗ of

L(w) = E[loss(w, point)]

You only get n samples.

Example: Estimate a linear relationship with n points in d dimensions?

Costly on large problems: O(nd² + d³) runtime, O(d²) memory

How should we approximate our solution?
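As a rough illustration of that cost (a minimal NumPy sketch with made-up sizes, not part of the talk): solving least squares exactly via the normal equations forms a d × d Gram matrix (O(nd²) time, O(d²) memory) and then solves a d × d linear system (O(d³) time).

```python
import numpy as np

# Made-up problem sizes for illustration.
n, d = 10_000, 100
X = np.random.randn(n, d)                        # n points in d dimensions
y = X @ np.random.randn(d) + 0.1 * np.random.randn(n)

# Exact least squares via the normal equations:
# forming X^T X costs O(n d^2) time and O(d^2) memory; solving it costs O(d^3).
gram = X.T @ X
w_exact = np.linalg.solve(gram, X.T @ y)
```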


Statistics vs. Computation

Stochastic approximation

e.g. stochastic gradient descent
obtain poor accuracy, quickly?
simple to implement

Numerical analysis

e.g. (batch) gradient descent
obtain high accuracy, slowly?
more complicated

What would you do?


Machine learning at scale!

Can we provide libraries to precisely do statistical estimation at scale?

Analogous to what was done with our linear algebra libraries (LAPACK/BLAS)?


The Stochastic Optimization Problem

min_w L(w) where L(w) = E_{point ∼ D}[loss(w, point)]

With N sampled points from D,

p_1, p_2, . . . , p_N,

how do you estimate w∗, the minimizer of L? Your expected error / excess risk / regret is:

E[L(w_N) − L(w∗)]

Goal: Do well statistically. Do it quickly.


Without Computational Constraints...

What would you like to do?

Compute the empirical risk minimizer / M-estimator:

w_N^ERM ∈ argmin_w (1/N) ∑_{i=1}^N loss(w, p_i).

Consider the ratio:

E[L(w_N^ERM) − L(w∗)] / E[L(w_N) − L(w∗)].

Can you compete with the ERM on every problem efficiently?


This Talk

Theorem. For linear/logistic regression, generalized linear models, and M-estimation (i.e. assuming “strong convexity” + “smoothness”), we provide a streaming algorithm which:

Computationally:
single pass; memory is O(one sample)
trivially parallelizable

Statistically:
achieves the statistical rate of the best fit on every problem (even considering constant factors)
(super-)polynomially decreases the initial error

Related work: Juditsky & Polyak (1992); Dieuleveut & Bach (2014)


Outline

1. Statistics: the statistical rate of the ERM

2. Computation: optimizing sums of convex functions

3. Computation + Statistics: combine ideas


(just) Statistics

Precisely, what is the error of w_N^ERM?

σ² := (1/2) E[ ‖∇loss(w∗, p)‖²_{(∇²L(w∗))⁻¹} ]

Thm (e.g. van der Vaart (2000)): Under regularity conditions, e.g.

loss is convex (almost surely)
loss is smooth (almost surely)
∇²L(w∗) exists and is positive definite,

we have

lim_{N→∞} E[L(w_N^ERM) − L(w∗)] / (σ²/N) = 1
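For concreteness (a sanity check not on the slide, assuming a well-specified linear model y = ⟨w∗, x⟩ + ε with homoscedastic noise of variance σ_noise²), the definition of σ² recovers the familiar d/N rate for least squares:

```latex
% loss(w,(x,y)) = \tfrac12 (y - \langle w, x\rangle)^2, so
% \nabla \mathrm{loss}(w^*,(x,y)) = -\varepsilon x and \nabla^2 L(w^*) = E[x x^\top] =: H.
\sigma^2 \;=\; \tfrac12\, E\!\left[\varepsilon^2\, x^\top H^{-1} x\right]
        \;=\; \tfrac12\, \sigma_{\mathrm{noise}}^2\, \mathrm{tr}\!\left(H^{-1} E[x x^\top]\right)
        \;=\; \tfrac12\, \sigma_{\mathrm{noise}}^2\, d,
\qquad
\frac{\sigma^2}{N} \;=\; \frac{\sigma_{\mathrm{noise}}^2\, d}{2N}.
```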


(just) Computation

optimizing sums of convex functions

min_w L(w) where L(w) = (1/N) ∑_{i=1}^N loss(w, p_i)

Assume:
L(w) is µ-strongly convex
loss is L-smooth
κ = L/µ is the effective condition number

Stochastic Gradient Descent: (Robbins & Monro, ’51)
Linear convergence: Strohmer & Vershynin (2009), Yu & Nesterov (2010), Le Roux, Schmidt & Bach (2012), Shalev-Shwartz & Zhang (2013), (SVRG) Johnson & Zhang (2013)


Stochastic Gradient Descent (SGD)

SGD update rule: at each time t,

sample a point p
w ← w − η ∇loss(w, p)

Problem: even if w = w∗, the update changes w.

How do you fix this?
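A minimal sketch of this update for squared loss (the sample() oracle name, step size, and step count are illustrative, not from the talk):

```python
import numpy as np

def sgd(sample, w0, eta=0.01, steps=10_000):
    """Plain SGD on loss(w, (x, y)) = 0.5 * (y - x @ w)**2, one fresh point per step."""
    w = w0.copy()
    for _ in range(steps):
        x, y = sample()                # sample a point p = (x, y)
        grad = (x @ w - y) * x         # gradient of the squared loss at this point
        w -= eta * grad                # w <- w - eta * grad loss(w, p)
    return w
```

Note that even at w = w∗ the per-point gradient is generally nonzero, which is exactly the issue the slide raises.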


Stochastic Variance Reduced Gradient (SVRG)

1. exact gradient computation: at stage s, using w_s, compute:

∇L(w_s) = (1/N) ∑_{i=1}^N ∇loss(w_s, p_i)

2. corrected SGD: initialize w ← w_s. For m steps,

sample a point p
w ← w − η (∇loss(w, p) − ∇loss(w_s, p) + ∇L(w_s))

3. update and repeat: w_{s+1} ← w.

Two ideas:
If w = w∗, then no update.
unbiased updates: the correction term −∇loss(w_s, p) + ∇L(w_s) has mean 0.
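A minimal sketch of these three steps for squared loss on a stored dataset (the step size, m, and number of stages are illustrative):

```python
import numpy as np

def svrg(X, y, w0, eta=0.01, m=1_000, stages=20):
    """SVRG for least squares on a stored dataset (X, y)."""
    n = X.shape[0]
    point_grad = lambda w, i: (X[i] @ w - y[i]) * X[i]     # per-point gradient
    ws = w0.copy()
    for _ in range(stages):
        full_grad = X.T @ (X @ ws - y) / n                 # 1) exact gradient at w_s
        w = ws.copy()
        for _ in range(m):                                 # 2) corrected SGD
            i = np.random.randint(n)                       #    sample a stored point
            w -= eta * (point_grad(w, i) - point_grad(ws, i) + full_grad)
        ws = w                                             # 3) w_{s+1} <- w, repeat
    return ws
```

If w = w_s = w∗, the parenthesized term vanishes, and the correction has mean zero over the random index, matching the two ideas above.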


SVRG has linear convergence

Thm: (Johnson & Zhang, ’13) SVRG has linear convergence, for fixed η:

E[L(w_s) − L(w∗)] ≤ e^{−s} · (L(w_0) − L(w∗))

many recent algorithms with similar guarantees:
Yu & Nesterov ’10; Shalev-Shwartz & Zhang ’13

Issues: must store the dataset; requires many passes

What about the statistical rate?


Computation + Statistics

Our problem:

min_w L(w) where L(w) = E[loss(w, point)]

(Streaming model) We obtain one sample at a time.


Our algo: streaming SVRG

1. estimate the gradient: at stage s, using w_s,

with k_s fresh samples, form an estimate ĝ_s of ∇L(w_s)

2. corrected SGD: initialize w ← w_s. For m steps:

sample a point p
w ← w − η (∇loss(w, p) − ∇loss(w_s, p) + ĝ_s)

3. update and repeat: w_{s+1} ← w

single pass; memory of O(one parameter); parallelizable
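A minimal sketch of the streaming variant for squared loss (the sample() oracle, step size, m, and initial batch size are illustrative, not the paper’s exact settings); only the current iterate and the anchor w_s are kept in memory, and every sample is used once:

```python
import numpy as np

def streaming_svrg(sample, w0, eta=0.01, m=500, k0=100, stages=10):
    """Streaming SVRG for squared loss: one pass, each sample used exactly once."""
    point_grad = lambda w, p: (p[0] @ w - p[1]) * p[0]     # p = (x, y)
    ws, k = w0.copy(), k0
    for _ in range(stages):
        # 1) estimate grad L(w_s) from k_s fresh samples
        g_hat = sum(point_grad(ws, sample()) for _ in range(k)) / k
        # 2) corrected SGD on further fresh samples
        w = ws.copy()
        for _ in range(m):
            p = sample()
            w -= eta * (point_grad(w, p) - point_grad(ws, p) + g_hat)
        ws, k = w, 2 * k                                   # 3) w_{s+1} <- w; k_{s+1} = 2 k_s
    return ws
```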


Single Pass Least Squares Estimation

Theorem (Frostig, Ge, Kakade, & Sidford ’14)
κ: the effective condition number
choose p > 2
schedule: increasing batch sizes k_s = 2 k_{s−1}, fixed m, and η = 1/(2p).

If the total sample size N is larger than a multiple of κ (which depends on p), then

E[L(w_N) − L(w∗)] ≤ 1.5 σ²/N + (L(w_0) − L(w∗)) / (N/κ)^p

σ²/N is the ERM rate.

general case: use self-concordance
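A rough illustration of the doubling schedule’s sample accounting (the numbers below are made up, not from the theorem): with k_s = 2 k_{s−1} and a fixed m per stage, S stages consume about k_0(2^S − 1) + S·m samples, so the later, larger gradient-estimation batches dominate the budget N.

```python
def sample_budget(k0=100, m=500, stages=10):
    """Samples consumed by streaming SVRG with a doubling batch schedule."""
    ks = [k0 * 2**s for s in range(stages)]   # k_s = 2 * k_{s-1}
    return sum(ks) + stages * m               # gradient estimates + corrected SGD steps

# sample_budget() == 100 * (2**10 - 1) + 10 * 500 == 107_300
```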


Thanks!

We can obtain (nearly) the same rate as the ERM in a single pass.

Collaborators:

R. Frostig, R. Ge, A. Sidford
