TAMING THE CURSE OF DIMENSIONALITY: DISCRETE INTEGRATION BY HASHING AND OPTIMIZATION
Stefano Ermon*, Carla P. Gomes*, Ashish Sabharwal+, and Bart Selman*
*Cornell University, +IBM Watson Research Center
ICML 2013
Jan 31, 2016
High-dimensional integration
• High-dimensional integrals arise in statistics, ML, and physics:
  • Expectations / model averaging
  • Marginalization
  • Partition function / ranking models / parameter learning
• Curse of dimensionality:
  • Quadrature involves a weighted sum over an exponential number of items (e.g., units of volume)
[Figure: n-dimensional hypercube of side L; the number of unit volumes grows as L, L^2, L^3, L^4, …, L^n]
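To make the exponential cost of quadrature concrete, here is a minimal sketch of tensor-product midpoint quadrature on the unit hypercube [0,1]^n, assuming k grid points per dimension. The function name `midpoint_quadrature` and the test integrand are illustrative choices, not taken from the talk. Every grid point is one weighted term, so the sum has k^n terms.

```python
import itertools
import math
from typing import Callable, Sequence, Tuple

def midpoint_quadrature(f: Callable[[Sequence[float]], float], n: int, k: int) -> Tuple[float, int]:
    """Approximate the integral of f over [0,1]^n with the tensor-product
    midpoint rule, using k points per dimension (hypothetical helper).

    Returns (estimate, number_of_function_evaluations)."""
    nodes = [(j + 0.5) / k for j in range(k)]  # 1-D cell midpoints
    cell_volume = (1.0 / k) ** n               # weight attached to every grid point
    total, evals = 0.0, 0
    for point in itertools.product(nodes, repeat=n):  # k**n grid points in total
        total += cell_volume * f(point)
        evals += 1
    return total, evals

# The integral of x1*x2*...*xn over [0,1]^n is 2**-n; the midpoint rule is exact here.
est, evals = midpoint_quadrature(math.prod, n=3, k=4)
```

With only k = 10 points per dimension and n = 100 dimensions, the same loop would need 10^100 evaluations.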
Discrete Integration
• We are given:
  • A set of 2^n items
  • Non-negative weights w
• Goal: compute the total weight
• Compactly specified weight function:
  • factored form (Bayes net, factor graph, CNF, …)
  • potentially a Turing machine
• Example 1: n = 2 dimensions, sum over 4 items
• Example 2: n = 100 dimensions, sum over 2^100 ≈ 10^30 items (intractable)
Goal: compute 5 + 0 + 2 + 1 = 8
[Figure: the 2^n items, with size visually representing weight; Example 1 items have weights 5, 0, 2, 1]
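As a baseline, the total weight can be computed exactly by enumerating all 2^n items. The sketch below does this for Example 1, plus a small hypothetical factored weight (a product of two pairwise factors, standing in for a factor graph). The names `discrete_integral`, `TABLE`, and `factored_w`, and the assignment of weights to bit vectors, are illustrative assumptions. The loop body runs 2^n times, which is exactly what becomes intractable at n = 100.

```python
from itertools import product
from typing import Callable, Tuple

def discrete_integral(w: Callable[[Tuple[int, ...]], float], n: int) -> float:
    """Sum w(x) over all 2**n binary vectors x by explicit enumeration.
    Exact, but the loop runs 2**n times, so it is only usable for small n."""
    return sum(w(x) for x in product((0, 1), repeat=n))

# Example 1: n = 2, four items with weights 5, 0, 2, 1 (arbitrary bit-vector labels).
TABLE = {(0, 0): 5, (0, 1): 0, (1, 0): 2, (1, 1): 1}
total = discrete_integral(TABLE.__getitem__, n=2)  # 5 + 0 + 2 + 1

def factored_w(x: Tuple[int, ...]) -> float:
    """Hypothetical compactly specified weight: f1(x0, x1) * f2(x1, x2)."""
    f1 = 2.0 if x[0] == x[1] else 1.0
    f2 = 3.0 if x[1] == x[2] else 1.0
    return f1 * f2
```

For `factored_w`, each value of x1 contributes (2 + 1) * (3 + 1) = 12, so `discrete_integral(factored_w, 3)` is 24.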
Hardness
• 0/1 weights case:
  • Is there at least one “1”? SAT
  • How many “1”s? #SAT
  • NP-complete vs. #P-complete: counting is much harder
• General weights:
  • Find the heaviest item (combinatorial optimization)
  • Sum the weights (discrete integration)
• This work: approximate discrete integration via optimization
• Combinatorial optimization (MIP, Max-SAT, CP) is also often fast in practice:
  • Relaxations / bounds
  • Pruning
[Figure: example items with weights 0, 4, 3, 7, 0, 0, 1, 1; complexity classes P, NP, PH, P^#P, PSPACE, EXP arranged from Easy to Hard]
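For 0/1 weights, the two queries on the slide are a max and a sum over the same table. Taking w(x) = 1 when x satisfies a CNF formula, "max" answers SAT and "sum" answers #SAT. This is a brute-force sketch that borrows the DIMACS literal convention (+j for x_j, -j for NOT x_j). The example formula and the function names are hypothetical.

```python
from itertools import product
from typing import List, Tuple

Clause = List[int]  # DIMACS-style literals: +j means x_j, -j means NOT x_j (1-indexed)

def satisfies(x: Tuple[int, ...], cnf: List[Clause]) -> bool:
    """True if assignment x (x[j-1] is the value of variable j) satisfies every clause."""
    return all(any((x[abs(l) - 1] == 1) == (l > 0) for l in clause) for clause in cnf)

def sat_and_count(cnf: List[Clause], n: int) -> Tuple[bool, int]:
    """Brute force over all 2**n assignments with 0/1 weight w(x) = [x satisfies cnf].
    max(weights) answers SAT ("is there a 1?"); sum(weights) answers #SAT ("how many 1s?")."""
    weights = [int(satisfies(x, cnf)) for x in product((0, 1), repeat=n)]
    return max(weights) == 1, sum(weights)

# (x1 OR x2) AND (NOT x1 OR x3): satisfiable, by 4 of the 8 assignments.
is_sat, count = sat_and_count([[1, 2], [-1, 3]], n=3)
```

With general non-negative weights, the same max becomes combinatorial optimization and the same sum becomes discrete integration.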
Previous approaches: Sampling
Idea:
• Randomly select a region
• Count within this region
• Scale up appropriately
Advantage: quite fast
Drawbacks:
• Robustness: can easily under- or over-estimate
• Scalability in sparse spaces: e.g., if 10^60 items have non-zero weight out of 10^300, the region must be much larger than 10^240 items to “hit” one
• Can be partially mitigated using importance sampling
[Figure: 16 example items with weights 100, 70, 60, 9, 9, 9, 5, 5, 5, 5, 5, 5, 5, 2, 2, 2]
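The "select a region, count, scale up" idea in its simplest form is uniform Monte Carlo. This minimal sketch assumes items are encoded as n-bit integers and uses the 16 example weights read off the figure (total 298). The estimator name and the sparse test function are hypothetical. It behaves well on the dense example, but in a sparse space it almost always returns 0, which is the under-estimation failure the slide describes.

```python
import random
from typing import Callable

def uniform_sampling_estimate(w: Callable[[int], float], n: int, m: int, seed: int = 0) -> float:
    """Naive sampling estimator of sum_{x in {0,1}^n} w(x).

    Items are encoded as n-bit integers. Draw m items uniformly at random,
    average their weights, and scale up by the number of items, 2**n.
    Unbiased, but its variance can be huge when few items carry the weight."""
    rng = random.Random(seed)
    sample_sum = sum(w(rng.getrandbits(n)) for _ in range(m))
    return (2 ** n) * sample_sum / m

# Dense case: the 16 example weights (true total 298).
WEIGHTS = [100, 70, 60, 9, 9, 9, 5, 5, 5, 5, 5, 5, 5, 2, 2, 2]
dense = uniform_sampling_estimate(WEIGHTS.__getitem__, n=4, m=100_000)

# Sparse case: one item of weight 1 among 2**40; 1,000 samples almost surely miss it.
sparse = uniform_sampling_estimate(lambda x: 1.0 if x == 12345 else 0.0, n=40, m=1000)
```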
Previous approaches: Variational methods
Idea:
• For exponential families, use convexity
• Variational formulation (optimization)
• Solve approximately (using message-passing techniques)
Advantage: quite fast
Drawbacks:
• Objective function is defined indirectly
• Cannot represent the domain of optimization compactly
• Need to be approximated (BP, MF)
• Typically no guarantees
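The variational route can be made concrete with naive mean field (MF), one of the approximations named on the slide. This is a sketch on a hypothetical 4-spin Ising model, and the fields, couplings, and function names are made up for illustration. Coordinate ascent on the factorized objective yields a value that is always a lower bound on log Z. Nothing in the method bounds how loose that lower bound is, which is one concrete form of "typically no guarantees". Exact enumeration is included only for comparison.

```python
import itertools
import math
from typing import List

def exact_log_z(h: List[float], J: List[List[float]]) -> float:
    """log Z, Z = sum over x in {-1,+1}^n of exp(sum_i h_i x_i + sum_{i<j} J_ij x_i x_j)."""
    n = len(h)
    energies = []
    for x in itertools.product((-1, 1), repeat=n):
        e = sum(h[i] * x[i] for i in range(n))
        e += sum(J[i][j] * x[i] * x[j] for i in range(n) for j in range(i + 1, n))
        energies.append(e)
    top = max(energies)  # log-sum-exp for numerical stability
    return top + math.log(sum(math.exp(e - top) for e in energies))

def spin_entropy(m: float) -> float:
    """Entropy of a {-1,+1} variable with mean m."""
    p = (1.0 + m) / 2.0
    return -sum(q * math.log(q) for q in (p, 1.0 - p) if q > 0.0)

def mean_field_log_z(h: List[float], J: List[List[float]], sweeps: int = 200) -> float:
    """Naive mean-field estimate of log Z for a symmetric coupling matrix J.

    Coordinate-ascent updates m_i = tanh(h_i + sum_{j != i} J_ij m_j), then
    returns E_q[energy] + H(q) for the fully factorized q (a lower bound on log Z)."""
    n = len(h)
    m = [0.0] * n
    for _ in range(sweeps):
        for i in range(n):
            field = h[i] + sum(J[i][j] * m[j] for j in range(n) if j != i)
            m[i] = math.tanh(field)
    energy = sum(h[i] * m[i] for i in range(n))
    energy += sum(J[i][j] * m[i] * m[j] for i in range(n) for j in range(i + 1, n))
    return energy + sum(spin_entropy(mi) for mi in m)

# Hypothetical 4-spin model on a cycle.
H = [0.1, -0.2, 0.3, 0.0]
C = [[0.0, 0.8, 0.0, 0.8],
     [0.8, 0.0, 0.8, 0.0],
     [0.0, 0.8, 0.0, 0.8],
     [0.8, 0.0, 0.8, 0.0]]
exact, approx = exact_log_z(H, C), mean_field_log_z(H, C)
```

With all couplings set to zero the factorized q is exact, and both functions agree with sum_i log(2 cosh h_i).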
A new approach: WISH
• CDF-style plot: suppose items are sorted by weight
• Area under the curve equals the total weight we want to compute
• How to estimate it? Divide into slices and sum up
• Geometrically divide the y axis: b_i is the 2^i-th largest weight (quantile)
• Geometrically increasing bin sizes: 1, 1, 2, 4, 8, …
• Can bound the area in each slice within a factor of 2
• Given the endpoints b_i, we have a 2-approximation
• Also works if we have approximations M_i of b_i
• How to estimate the b_i?
[Figure: CDF-style plot of the 16 example items sorted by weight (100, 70, 60, 9, 9, 9, 5, 5, 5, 5, 5, 5, 5, 2, 2, 2); axis label “# items” (how many items have weight at least b); endpoints b0 = 100, b1 = 70, b2 = 9, b3 = 5, b4 = 2]
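The slicing argument can be checked numerically on the 16 example weights from the figure. In this sketch (the helper names are hypothetical), slice i holds the 2^i items ranked 2^i + 1 through 2^(i+1), and their weights lie between b_{i+1} and b_i. Summing the per-slice lower and upper ends, plus the top item b_0, brackets the total. The upper sum is at most twice the lower sum, which gives the 2-approximation. The lower sum has the same form as the quantity WISH returns, with M_i in place of b_i.

```python
from typing import List, Tuple

def quantile_endpoints(weights: List[float]) -> List[float]:
    """b_i = the 2**i-th largest weight, for i = 0..n, where len(weights) == 2**n."""
    n = len(weights).bit_length() - 1
    assert len(weights) == 2 ** n, "expects exactly 2**n items"
    ranked = sorted(weights, reverse=True)
    return [ranked[2 ** i - 1] for i in range(n + 1)]

def slice_bounds(b: List[float]) -> Tuple[float, float]:
    """Bound the area under the sorted-weight curve slice by slice.

    Slice i (i = 0..n-1) holds the 2**i items ranked 2**i+1 .. 2**(i+1);
    their weights lie in [b_{i+1}, b_i]. Adding the top item b_0 gives
    lower = b_0 + sum_i b_{i+1} 2**i  <=  total  <=  b_0 + sum_i b_i 2**i = upper."""
    n = len(b) - 1
    lower = b[0] + sum(b[i + 1] * 2 ** i for i in range(n))
    upper = b[0] + sum(b[i] * 2 ** i for i in range(n))
    return lower, upper

WEIGHTS = [100, 70, 60, 9, 9, 9, 5, 5, 5, 5, 5, 5, 5, 2, 2, 2]
b = quantile_endpoints(WEIGHTS)  # [100, 70, 9, 5, 2]
lower, upper = slice_bounds(b)   # (224, 416); the true total is 298
```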
Estimating the endpoints (quantiles) b_i
• Hash the 2^n items into 2^i buckets, then look at a single bucket
• Find the heaviest weight w_i in the bucket
• Example: for i = 2, hashing 16 items into 2^2 = 4 buckets; in the pictured bucket, w_i = 9
• INTUITION: repeat several times. With high probability, w_i is often found to be larger than w* if and only if there are at least 2^i items with weight larger than w*
[Figure: the 16 example items hashed into 4 buckets]
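The intuition can be quantified under a simplifying assumption: every item picks its bucket independently and uniformly, which is a fully random hash and not the parity hash used later. Then w_i ≥ w* exactly when at least one of the k items of weight at least w* lands in the chosen bucket, so the chance is 1 − (1 − 2^{−i})^k. If k ≥ 2^i, this is at least 1 − 1/e ≈ 0.63. If k is much smaller than 2^i, it is at most k·2^{−i}. The helper names below are hypothetical. The simulation uses the 16 example weights with i = 2 and w* = 9.

```python
import random
from typing import List

def p_bucket_hit(k: int, i: int) -> float:
    """Chance that one fixed bucket out of 2**i receives at least one of k given
    items, if each item picks its bucket independently and uniformly (a fully
    random hash, assumed here for simplicity)."""
    return 1.0 - (1.0 - 2.0 ** -i) ** k

def bucket_max(weights: List[float], i: int, rng: random.Random) -> float:
    """Hash every item into one of 2**i buckets at random and return the heaviest
    weight that lands in bucket 0 (0 if that bucket is empty)."""
    return max((w for w in weights if rng.randrange(2 ** i) == 0), default=0)

WEIGHTS = [100, 70, 60, 9, 9, 9, 5, 5, 5, 5, 5, 5, 5, 2, 2, 2]
# Six items weigh at least 9, so for i = 2 the chance that w_i >= 9 is 1 - (3/4)**6, about 0.82.
rng = random.Random(0)
trials = 20_000
freq = sum(bucket_max(WEIGHTS, 2, rng) >= 9 for _ in range(trials)) / trials
```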
Hashing and Optimization
• Hash into 2^i buckets, then look at a single bucket
• With probability > 0.5:
  • There is nothing from the small set, the 2^{i-2} = 2^i/4 heaviest items (vanishes)
  • There is something from the larger set, the 2^{i+2} = 4·2^i heaviest items (survives)
• Something from the larger set is likely to be in the bucket, so if we take a max, it will lie in the range between b_{i+2} and b_{i-2}
• Remember items are sorted, so the max picks the “rightmost” item…
[Figure: sorted items (increasing weight) with geometrically increasing bin sizes, marking b0, b_{i-2}, b_i, b_{i+2}; the larger set is 16 times larger than the small set]
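One way to see the "probability > 0.5" claim is sketched below. It assumes each item lands in the chosen bucket with probability 2^{−i}, that bucket membership is pairwise independent (which a uniform, pairwise independent hash provides), and that 2 ≤ i ≤ n − 2 so both sets exist. Let S count how many of the 2^{i−2} heaviest items fall in the bucket, and L how many of the 2^{i+2} heaviest items do. When S = 0 and L ≥ 1, the bucket maximum lies between b_{i+2} and b_{i−2}.

```latex
\begin{align*}
\mathbb{E}[S] &= 2^{i-2}\cdot 2^{-i} = \tfrac{1}{4}
  &&\Rightarrow\quad \Pr[S \ge 1] \le \tfrac{1}{4} \quad \text{(Markov)}\\
\mathbb{E}[L] &= 2^{i+2}\cdot 2^{-i} = 4,\;\;
  \operatorname{Var}[L] = 4\,(1 - 2^{-i}) < 4
  &&\Rightarrow\quad \Pr[L = 0] \le \frac{\operatorname{Var}[L]}{\mathbb{E}[L]^2} < \tfrac{1}{4} \quad \text{(Chebyshev)}\\
\Pr[S = 0 \text{ and } L \ge 1] &\ge 1 - \Pr[S \ge 1] - \Pr[L = 0] > \tfrac{1}{2}
\end{align*}
```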
Universal Hashing
• Represent each item as an n-bit vector x
• Randomly generate A in {0,1}^{i×n}, b in {0,1}^i
• Then A x + b (mod 2) is:
  • Uniform
  • Pairwise independent
• Bucket content is implicitly defined by the solutions of A x = b (mod 2) (parity constraints)
• Max w(x) subject to A x = b (mod 2) lands between b_{i+2} and b_{i-2} “frequently”
• Repeat several times: the median is in the desired range with high probability
[Figure: the i×n matrix A times x equals b (mod 2); sorted items with b_{i+2}, b_i, b_{i-2}, b0 marked]
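The two hashing properties can be checked exhaustively for small sizes. This sketch enumerates every (A, b) with A in {0,1}^{i×n} and b in {0,1}^i, hashes two distinct items with h(x) = A x + b (mod 2), and counts how often each pair of bucket labels occurs. A uniform, pairwise independent family makes all 2^i · 2^i pairs equally frequent. The bucket "A x = b (mod 2)" is the set of x with label all-zeros, since A x + b = 0 and A x = b coincide mod 2. The function names and the chosen vectors are hypothetical.

```python
from collections import Counter
from itertools import product
from typing import Iterator, List, Tuple

Bits = Tuple[int, ...]

def parity_hash(A: List[Bits], b: Bits, x: Bits) -> Bits:
    """h(x) = A x + b (mod 2): an i-bit bucket label for the n-bit item x."""
    return tuple((sum(a * xj for a, xj in zip(row, x)) + bi) % 2 for row, bi in zip(A, b))

def all_hashes(i: int, n: int) -> Iterator[Tuple[List[Bits], Bits]]:
    """Enumerate every (A, b) with A in {0,1}^(i x n) and b in {0,1}^i."""
    rows = list(product((0, 1), repeat=n))
    for A in product(rows, repeat=i):
        for b in product((0, 1), repeat=i):
            yield list(A), b

def pair_distribution(x: Bits, y: Bits, i: int) -> Counter:
    """Count how often each pair of labels (h(x), h(y)) occurs over all (A, b)."""
    return Counter((parity_hash(A, b, x), parity_hash(A, b, y)) for A, b in all_hashes(i, len(x)))

# n = 3, i = 2: 2**6 * 2**2 = 256 hash functions, 16 possible label pairs.
dist = pair_distribution((0, 0, 1), (1, 0, 1), i=2)
```

Each of the 16 label pairs occurs 256 / 16 = 16 times, as pairwise independence requires.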
WISH: Integration by Hashing and Optimization
WISH (Weighted Integrals Sums By Hashing)
• T = log(n/δ)
• For i = 0, …, n   (outer loop over the n+1 endpoints b_i of the n slices)
  • For t = 1, …, T   (repeat log(n) times)
    • Sample uniformly A in {0,1}^{i×n}, b in {0,1}^i   (hash into 2^i buckets)
    • w_i^t = max w(x) subject to A x = b (mod 2)   (find heaviest item)
  • M_i = Median(w_i^1, …, w_i^T)   (M_i estimates the 2^i-largest weight b_i)
• Return M_0 + Σ_{i=0}^{n-1} M_{i+1} 2^i   (sum up the estimated area in each vertical slice)
The algorithm requires only O(n log n) optimizations for a sum over 2^n items.
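The pseudocode above can be sketched as runnable Python for tiny n. Brute-force enumeration stands in for the optimization oracle; the talk uses a MAP solver instead. T is a parameter rather than being fixed to log(n/δ), an odd T keeps the median a single sample, and an empty bucket counts as weight 0. The names `max_under_parity` and `wish` and the weight encoding are hypothetical. The example uses the 16 weights from the earlier figures.

```python
import random
import statistics
from itertools import product
from typing import Callable, List, Tuple

Bits = Tuple[int, ...]

def max_under_parity(w: Callable[[Bits], float], n: int, A: List[List[int]], b: List[int]) -> float:
    """Brute-force stand-in for the optimization oracle:
    max w(x) over x in {0,1}^n subject to A x = b (mod 2); 0 if no x satisfies it."""
    best = 0.0
    for x in product((0, 1), repeat=n):
        if all(sum(a * xj for a, xj in zip(row, x)) % 2 == bj for row, bj in zip(A, b)):
            best = max(best, w(x))
    return best

def wish(w: Callable[[Bits], float], n: int, T: int, seed: int = 0) -> float:
    """WISH-style estimate of sum_x w(x); T (odd) repetitions per level."""
    rng = random.Random(seed)
    M = []
    for i in range(n + 1):                    # one level per endpoint b_0 .. b_n
        samples = []
        for _ in range(T):
            A = [[rng.randint(0, 1) for _ in range(n)] for _ in range(i)]
            b = [rng.randint(0, 1) for _ in range(i)]
            samples.append(max_under_parity(w, n, A, b))
        M.append(statistics.median(samples))  # M_i estimates b_i
    return M[0] + sum(M[i + 1] * 2 ** i for i in range(n))

WEIGHTS = [100, 70, 60, 9, 9, 9, 5, 5, 5, 5, 5, 5, 5, 2, 2, 2]

def w(x: Bits) -> float:
    # Hypothetical encoding: read x as a binary index into the example weights.
    return WEIGHTS[int("".join(map(str, x)), 2)]

estimate = wish(w, n=4, T=7)  # the true total is 298
```

The (n + 1)·T oracle calls are independent of one another, so they could run in parallel. For large n, only `max_under_parity` would need to be replaced by a real solver.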
Visual working of the algorithm
• How it works: start from the function to be integrated (CDF-style plot of # items by weight)
• Add 1, 2, 3, … random parity constraints (n times); repeat each level log(n) times and take the medians M1, M2, M3, …
• Combine: M0 (the mode) + M1 × 1 + M2 × 2 + M3 × 4 + …
Accuracy Guarantees
• Theorem 1: With probability at least 1 − δ (e.g., 99.9%), WISH computes a 16-approximation of a sum over 2^n items (discrete integral) by solving Θ(n log n) optimization instances.
  • Example: partition function by solving Θ(n log n) MAP queries
• Theorem 2: Can improve the approximation factor to (1+ε) by adding extra variables.
  • Example: factor 2 approximation with 4n variables
• Byproduct: we also obtain an 8-approximation of the tail distribution (CDF) with high probability
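A worked sketch of how extra variables can sharpen the factor follows. It assumes, as one possible construction and not necessarily the paper's exact one, that the extra variables are k independent copies of x, with product weight w'(x^(1), …, x^(k)) = Π_j w(x^(j)), so the sum over the kn variables equals W^k. This reading matches the slide's example: k = 4 copies give 4n variables and a factor of 16^{1/4} = 2.

```latex
\begin{align*}
W' &= \sum_{x^{(1)},\dots,x^{(k)}} \prod_{j=1}^{k} w\bigl(x^{(j)}\bigr) = W^{k},\\
\frac{W^{k}}{16} \le \widehat{W}' \le 16\,W^{k}
  &\;\Longrightarrow\;
  \frac{W}{16^{1/k}} \le \bigl(\widehat{W}'\bigr)^{1/k} \le 16^{1/k}\,W,\\
16^{1/k} \le 1+\varepsilon
  &\iff k \ge \frac{\log 16}{\log(1+\varepsilon)}, \qquad
  k = 4:\; 16^{1/4} = 2 \text{ with } 4n \text{ variables}.
\end{align*}
```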
Key features
• Strong accuracy guarantees
• Can plug in any combinatorial optimization tool
• Bounds on the optimization translate to bounds on the sum:
  • Stop early and get a lower bound (anytime)
  • (LP, SDP) relaxations give upper bounds
• Extra constraints can make the optimization harder or easier
• Massively parallel (independent optimizations)
• Remark: faster than enumeration only when combinatorial optimization is efficient (faster than brute force)
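Bounds transfer because the aggregation is monotone. Each M_i is a median, which does not decrease when any of its inputs increases, and the final sum uses non-negative coefficients. So replacing every optimum by a value below it lowers the estimate, and replacing it by a value above it raises the estimate. Values below come from a feasible solution found by a solver stopped early; values above come from a relaxation. In this minimal sketch, `combine` is a hypothetical name, and the per-level numbers (n = 2, T = 3) are made up for illustration.

```python
import statistics
from typing import List

def combine(samples: List[List[float]]) -> float:
    """WISH aggregation: samples[i][t] is the optimum found at level i, repetition t.
    Returns M_0 + sum_i M_{i+1} 2**i, with M_i the median of level i."""
    M = [statistics.median(level) for level in samples]
    return M[0] + sum(M[i + 1] * 2 ** i for i in range(len(M) - 1))

# Hypothetical true optima per level (n = 2, T = 3) ...
exact = [[100, 100, 100], [70, 60, 70], [9, 60, 5]]
# ... feasible values from a solver stopped early (each <= the true optimum)
early = [[100, 100, 100], [60, 60, 9], [5, 9, 2]]
# ... and relaxation bounds (each >= the true optimum)
relaxed = [[120, 110, 100], [90, 70, 80], [60, 70, 9]]

low, mid, high = combine(early), combine(exact), combine(relaxed)  # 170, 188, 310
```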
Experimental results
• Approximate the partition function of undirected graphical models by solving MAP queries (find the most likely state)
  • Normalization constant to evaluate probability, rank models
  • MAP inference on the graphical model augmented with random parity constraints
• Toulbar2 (branch & bound) solver for MAP inference
  • Augmented with Gauss-Jordan filtering to efficiently handle the parity constraints (linear equations over a field)
• Run in parallel using > 600 cores
[Figure: original graphical model augmented with parity check nodes enforcing A x = b (mod 2)]
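Parity constraints are linear equations over GF(2), so Gauss-Jordan elimination can decide whether A x = b (mod 2) is satisfiable and produce a solution. The slide does not show Toulbar2's filtering code. What follows is plain Gauss-Jordan elimination over GF(2), offered as an illustration of the underlying linear algebra and not as the solver's implementation; `gf2_solve` is a hypothetical name.

```python
from typing import List, Optional, Tuple

def gf2_solve(A: List[List[int]], b: List[int]) -> Tuple[Optional[List[int]], int]:
    """Gauss-Jordan elimination over GF(2) for A x = b (mod 2).

    Returns (x, rank): x is one solution (free variables set to 0),
    or None if the system is inconsistent."""
    n = len(A[0]) if A else 0
    rows = [list(row) + [bi] for row, bi in zip(A, b)]  # augmented matrix [A | b]
    pivot_cols: List[int] = []
    r = 0
    for c in range(n):
        # Find a row at or below r with a 1 in column c.
        p = next((k for k in range(r, len(rows)) if rows[k][c]), None)
        if p is None:
            continue  # no pivot: x_c is a free variable
        rows[r], rows[p] = rows[p], rows[r]
        # Jordan step: clear column c in every other row (row addition is XOR).
        for k in range(len(rows)):
            if k != r and rows[k][c]:
                rows[k] = [u ^ v for u, v in zip(rows[k], rows[r])]
        pivot_cols.append(c)
        r += 1
    # A row reading 0 = 1 means no x satisfies the parity constraints.
    if any(row[n] and not any(row[:n]) for row in rows):
        return None, r
    x = [0] * n
    for k, c in enumerate(pivot_cols):
        x[c] = rows[k][n]
    return x, r

A = [[1, 1, 0], [0, 1, 1]]
b = [1, 0]
x, rank = gf2_solve(A, b)  # one solution of x1 + x2 = 1, x2 + x3 = 0
```

When the system is feasible, it has exactly 2^(n − rank) solutions, so the elimination also reveals how many items fall in the bucket.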
Sudoku
• How many ways are there to fill a valid sudoku square?
• Sum over 9^81 ~ 10^77 possible squares (items)
• w(x) = 1 if it is a valid square, w(x) = 0 otherwise
• Accurate solution within seconds:
  • WISH estimate 1.634×10^21 vs. exact count 6.671×10^21
Random Cliques Ising Models
• Very small error band is the 16-approximation range
• Other methods fall way out of the error band
[Figure: partition function and MAP query plotted against the strength of the interactions]
Model ranking - MNIST
• Use the partition function estimate to rank models (data likelihood)
• WISH ranks them correctly. Mean-field and BP do not.
[Figure: visually, a better model for handwritten digits]
Conclusions
• Discrete integration reduced to a small number of optimization instances
• Strong (probabilistic) accuracy guarantees by universal hashing
• Can leverage fast combinatorial optimization packages
• Works well in practice
• Future work:
  • Extension to continuous integrals
  • Further approximations in the optimization [UAI-13]
  • Coding theory / Parity check codes / Max-Likelihood Decoding
  • LP relaxations
  • Sampling from high-dimensional probability distributions?