One thing I've always wanted to do was find the global minima of 1,000 Rosenbrock functions using BFGS really quickly — and now I can do it easily!
import tensorflow as tf
import tensorflow_probability as tfp
import time as tm


def rosenbrock(x, a, b):
    """Evaluate the Rosenbrock function ``(a - x0)^2 + b * (x1 - x0^2)^2``.

    The two coordinates ``x0`` and ``x1`` are read from the last axis of
    ``x``; leading axes are kept, so ``a`` and ``b`` must broadcast against
    ``x[..., 0]``.
    """
    x0 = x[..., 0]
    x1 = x[..., 1]
    first = tf.math.squared_difference(a, x0)
    second = b * tf.math.squared_difference(x1, tf.square(x0))
    return first + second


def optimize(init, a, b):
    """Minimize ``rosenbrock(., a, b)`` with BFGS, starting from ``init``.

    Returns the result of ``tfp.optimizer.bfgs_minimize`` (run for at most
    100 iterations) unchanged.
    """

    @tf.function
    def fn(x):
        # bfgs_minimize is handed a callable that yields both the objective
        # value and its gradient at x.
        return tfp.math.value_and_gradient(lambda x: rosenbrock(x, a, b), x)

    return tfp.optimizer.bfgs_minimize(fn, init, max_iterations=100)


tf.random.set_seed(42)

# One independent problem per row: a random start point and random (a, b).
batch_size = 1_000
init = tf.random.normal([batch_size, 2], dtype=tf.float64)
a = tf.random.normal([batch_size], mean=5.0, dtype=tf.float64)
b = tf.random.normal([batch_size], mean=100.0, dtype=tf.float64)

# perf_counter is a monotonic clock intended for measuring intervals; the
# wall clock returned by time() can be adjusted while the run is in progress.
start = tm.perf_counter()
opt = optimize(init, a, b)
end = tm.perf_counter()

# Reference answer used for the error report: the minimum sits at (a, a^2).
solution = tf.stack([a, tf.square(a)], axis=-1)
estimate = opt.position
# Summarize the batched run: timing, iteration count, and per-batch status.
print(f"{batch_size:,} optimizations took {end - start:.2f} seconds")
print(f"iterations: {opt.num_iterations}")
print(f"all converged: {tf.reduce_all(opt.converged)}")
# reduce_any, not reduce_all: the label asks whether ANY batch member failed,
# whereas reduce_all would only report True if every single one had failed.
print(f"any line searches failed: {tf.reduce_any(opt.failed)}")
print(f"max abs error: {tf.reduce_max(tf.abs(solution - estimate))}")
1,000 optimizations took 3.66 seconds iterations: 72 all converged: True any line searches failed: False max abs error: 3.966131600918743e-08
Now, obviously this isn't useful in itself — it's just an example of how to use the batching support recently introduced into TensorFlow Probability.