One thing I've always wanted to do is find the global minimum of each of 1,000 Rosenbrock functions using BFGS, really quickly — and now I can do it easily!

import tensorflow as tf
import tensorflow_probability as tfp
import time as tm

def rosenbrock(x, a, b):
    """Evaluate the Rosenbrock function on the last axis of ``x``.

    Computes ``(a - x0)**2 + b * (x1 - x0**2)**2`` element-wise, where
    ``x0 = x[..., 0]`` and ``x1 = x[..., 1]``.  Its global minimum is at
    ``(x0, x1) = (a, a**2)``.

    Args:
        x: tensor whose last axis holds the coordinates; only entries 0 and
            1 of that axis are read, and any leading axes act as batch axes.
        a: first Rosenbrock parameter, combined element-wise with ``x0``.
        b: second Rosenbrock parameter, scaling the curved-valley term.

    Returns:
        Tensor of function values, one per point described by ``x``.
    """
    first_coord = x[..., 0]
    second_coord = x[..., 1]
    # (a - x0)^2: pulls x0 towards a.
    offset_term = tf.square(a - first_coord)
    # b * (x1 - x0^2)^2: the narrow curved valley that makes the problem hard.
    valley_term = b * tf.square(second_coord - tf.square(first_coord))
    return offset_term + valley_term

def optimize(init, a, b, max_iterations=100):
    """Minimise (a batch of) Rosenbrock functions with BFGS.

    Args:
        init: initial position handed to ``tfp.optimizer.bfgs_minimize``;
            with a leading batch axis (e.g. shape ``[batch, 2]``) each row is
            treated as its own starting point.
        a: first Rosenbrock parameter, forwarded to ``rosenbrock``.
        b: second Rosenbrock parameter, forwarded to ``rosenbrock``.
        max_iterations: maximum number of BFGS iterations.  Defaults to 100,
            the value that used to be hard-coded, so existing callers are
            unaffected.

    Returns:
        The results namedtuple returned by ``tfp.optimizer.bfgs_minimize``
        (exposing e.g. ``position``, ``converged``, ``failed`` and
        ``num_iterations``).
    """
    def value_and_grad(position):
        # bfgs_minimize needs a callable returning (value, gradient); a
        # distinct lambda parameter name avoids shadowing ``position``.
        return tfp.math.value_and_gradient(
            lambda point: rosenbrock(point, a, b), position)

    return tfp.optimizer.bfgs_minimize(
        value_and_grad, init, max_iterations=max_iterations)


batch_size = 1_000
# Random starting points and random Rosenbrock parameters, one per problem.
init = tf.random.normal([batch_size, 2], dtype=tf.float64)
a = tf.random.normal([batch_size], mean=5.0, dtype=tf.float64)
b = tf.random.normal([batch_size], mean=100.0, dtype=tf.float64)

# perf_counter is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() follows the wall clock and can jump.
start = tm.perf_counter()
opt = optimize(init, a, b)
end = tm.perf_counter()

# The Rosenbrock function's global minimum is at (a, a**2): the exact answer.
solution = tf.stack([a, tf.square(a)], axis=-1)
estimate = opt.position
print(f"{batch_size:,} optimizations took {end - start:.2f} seconds")
print(f"iterations: {opt.num_iterations}")
print(f"all converged: {tf.reduce_all(opt.converged)}")
# "any" must use reduce_any: reduce_all would only be True if *every*
# line search failed, hiding individual failures.
print(f"any line searches failed: {tf.reduce_any(opt.failed)}")
print(f"max abs error: {tf.reduce_max(tf.abs(solution - estimate))}")
1,000 optimizations took 3.66 seconds
iterations: 72
all converged: True
any line searches failed: False
max abs error: 3.966131600918743e-08

Now, obviously this isn't useful in itself — it's just an example of how to use the batching support recently introduced into TensorFlow Probability, which lets a single `bfgs_minimize` call solve many independent problems at once.