In this post I have a look at the JointDistributionCoroutineAutoBatched class in TensorFlow Probability. The title is based on this C++ article.

I'm going to attempt to show the difference between JointDistributionCoroutineAutoBatched and JointDistributionCoroutine by displaying the code differences and runtime speed.

For the official release notes about this feature, see here.

1 Data

As usual, I use some football data from football-data.co.uk; you can see some code to download and process it here.

num_teams = len(team_names)
home_team = soccer_data["home_team"].to_numpy()
away_team = soccer_data["away_team"].to_numpy()
observed_home_goals = soccer_data["home_goals"].to_numpy()
observed_away_goals = soccer_data["away_goals"].to_numpy()

print(soccer_data.head())
  home_team_name away_team_name  home_goals  away_goals  home_team  away_team
0    Aston Villa       West Ham           3           0          1         32
1      Blackburn        Everton           1           0          3         12
2         Bolton         Fulham           0           0          5         13
3        Chelsea      West Brom           6           0         10         31
4     Sunderland     Birmingham           2           2         27          2

2 Model

Nothing new here: just a simple Poisson model for the observed home and away goals, using an intercept and a home advantage term, along with a measure of the attacking and defensive capabilities of each team in the data.

Written out, the model is:

\begin{align*} \sigma &\sim HalfNormal(1) \\ \alpha &\sim Normal(0, \sigma^2) \\ \beta &\sim Normal(0, \sigma^2) \\ \psi &\sim Normal(0, 1^2) \\ \gamma &\sim Normal(0, 1^2) \\ H &\sim Poisson(\exp(\psi + \gamma + \alpha_i - \beta_j)) \\ A &\sim Poisson(\exp(\psi - \gamma + \alpha_j - \beta_i)) \end{align*}

where team \(i\) plays at home to team \(j\). This is a hierarchical model since it has "parameters which depend on parameters" - the attack (\(\alpha\)) and defence (\(\beta\)) depend on the scale (\(\sigma\)). For completeness, \(H\) and \(A\) are the home and away goals respectively, \(\psi\) is an intercept term, and \(\gamma\) is a home advantage term.
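To make the notation concrete, here is a minimal sketch of the log-likelihood contribution of a single match, using only the Python standard library. It uses the same signs as the model code later in this post (attack minus defence, with the home advantage added for the home side and subtracted for the away side). The parameter values and the helper name `poisson_log_pmf` are made up for illustration; they are not fitted values.

```python
import math


def poisson_log_pmf(goals, log_rate):
    """Log-probability of `goals` under a Poisson with rate exp(log_rate)."""
    return goals * log_rate - math.exp(log_rate) - math.lgamma(goals + 1)


# Hypothetical parameter values, chosen only for illustration.
intercept = 0.1
home_advantage = 0.2
attack = {"home": 0.3, "away": -0.1}
defence = {"home": 0.2, "away": -0.2}

# Log rates for one fixture, following the model code's sign convention.
home_log_rate = intercept + home_advantage + attack["home"] - defence["away"]
away_log_rate = intercept - home_advantage + attack["away"] - defence["home"]

# A 3-0 home win, like the first row of the data above.
match_log_prob = poisson_log_pmf(3, home_log_rate) + poisson_log_pmf(0, away_log_rate)

print(f"home rate: {math.exp(home_log_rate):.3f}")
print(f"away rate: {math.exp(away_log_rate):.3f}")
print(f"log-probability of 3-0: {match_log_prob:.3f}")
```

The full model's log-probability is this quantity summed over every match, plus the log-densities of the priors.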

I use Greek letters for this sort of thing because it is common practice and takes up less space, but these days I prefer to look at the code to see what is going on, and there I try to give the parameters descriptive names.

3 Two ways to write the model in code

Previously, I'd write the above model in code using something like this:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfl = tf.linalg

Root = tfd.JointDistributionCoroutine.Root


@tfd.JointDistributionCoroutine
def joint_dist():
    team_ability_scale = yield Root(tfd.HalfNormal(scale=1.0))
    unit_attack = yield Root(
        tfd.MultivariateNormalLinearOperator(
            loc=0.0,
            scale=tfl.LinearOperatorIdentity(num_rows=num_teams),
        )
    )
    unit_defence = yield Root(
        tfd.MultivariateNormalLinearOperator(
            loc=0.0,
            scale=tfl.LinearOperatorIdentity(num_rows=num_teams),
        )
    )
    # Priors with no parents need Root too, so that sample shapes propagate.
    intercept = yield Root(tfd.Normal(loc=0.0, scale=1.0))
    home_advantage = yield Root(tfd.Normal(loc=0.0, scale=1.0))

    attack = team_ability_scale[..., tf.newaxis] * unit_attack
    defence = team_ability_scale[..., tf.newaxis] * unit_defence

    home_log_rate = (
        intercept
        + home_advantage
        + tf.gather(attack, home_team, axis=-1)
        - tf.gather(defence, away_team, axis=-1)
    )

    away_log_rate = (
        intercept
        - home_advantage
        + tf.gather(attack, away_team, axis=-1)
        - tf.gather(defence, home_team, axis=-1)
    )

    yield tfd.Independent(
        tfd.Poisson(log_rate=home_log_rate), reinterpreted_batch_ndims=1
    )

    yield tfd.Independent(
        tfd.Poisson(log_rate=away_log_rate), reinterpreted_batch_ndims=1
    )

which has a bunch of stuff in it I'd rather not need to deal with:

  1. Root
  2. MultivariateNormalLinearOperator
  3. [..., tf.newaxis]
  4. tfd.Independent
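To see where items 3 and 4 come from, here is a small NumPy sketch of the shape bookkeeping (NumPy broadcasting follows the same rules as TensorFlow for these operations). The sizes and array names are made up for illustration. With a batch of draws, the scale has shape (num_draws,) and the unit abilities have shape (num_draws, num_teams), so the scale needs a trailing axis before the two can be multiplied. Likewise, the per-match log-probabilities have to be summed over the last axis to get one value per draw, which is the job tfd.Independent does in the model above.

```python
import numpy as np

num_draws, num_teams, num_matches = 4, 3, 5
rng = np.random.default_rng(0)

scale = np.abs(rng.normal(size=num_draws))             # shape (4,)
unit_attack = rng.normal(size=(num_draws, num_teams))  # shape (4, 3)

# `scale * unit_attack` would fail here: shapes (4,) and (4, 3) do not broadcast.
attack = scale[..., np.newaxis] * unit_attack          # (4, 1) * (4, 3) -> (4, 3)

# Stand-in for per-match log-probabilities, one row per draw.
per_match_log_prob = rng.normal(size=(num_draws, num_matches))

# Summing over the last axis gives one log-probability per draw.
total_log_prob = per_match_log_prob.sum(axis=-1)       # shape (4,)

print(attack.shape, total_log_prob.shape)
```

None of this is hard, but it is easy to get wrong, and the errors only tend to show up once a batch of draws is involved.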

Step forward JointDistributionCoroutineAutoBatched:

@tfd.JointDistributionCoroutineAutoBatched
def auto_joint_dist():
    team_ability_scale = yield tfd.HalfNormal(scale=1.0)
    unit_attack = yield tfd.Normal(loc=tf.zeros(num_teams), scale=1.0)
    unit_defence = yield tfd.Normal(loc=tf.zeros(num_teams), scale=1.0)
    intercept = yield tfd.Normal(loc=0.0, scale=1.0)
    home_advantage = yield tfd.Normal(loc=0.0, scale=1.0)

    attack = team_ability_scale * unit_attack
    defence = team_ability_scale * unit_defence

    home_log_rate = (
        intercept
        + home_advantage
        + tf.gather(attack, home_team, axis=-1)
        - tf.gather(defence, away_team, axis=-1)
    )

    away_log_rate = (
        intercept
        - home_advantage
        + tf.gather(attack, away_team, axis=-1)
        - tf.gather(defence, home_team, axis=-1)
    )

    yield tfd.Poisson(log_rate=home_log_rate)
    yield tfd.Poisson(log_rate=away_log_rate)

Now isn't that a lot nicer to read? But I have a question - what's the downside?

4 Do we get the same log-probabilities?

When running MCMC or similar, we spend most of the time evaluating the log-probability (and its gradient) of some parameters conditioned on the observed data, so this is the function I test in this post. Really I should test the gradients as well, but I don't have loads of time at the moment.

observed_goals = [observed_home_goals, observed_away_goals]

@tf.function(autograph=False)
def target_log_prob_fn(*state):
    return joint_dist.log_prob(list(state) + observed_goals)


@tf.function(autograph=False)
def auto_target_log_prob_fn(*state):
    return auto_joint_dist.log_prob(list(state) + observed_goals)


state = joint_dist.sample()[:-2]

print(f"old: {target_log_prob_fn(*state):.6f}")
print(f"new: {auto_target_log_prob_fn(*state):0.6f}")
old: -16304.628906
new: -16304.628906

So no evidence of anything going wrong.
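The two printed values agree to every digit shown. In a script, a tolerance-based comparison is a more robust check than reading printed output, since two implementations can legitimately differ in the last few floating-point digits. A minimal sketch using only the standard library, where the helper name `log_probs_agree` and the tolerance are hypothetical choices for illustration:

```python
import math


def log_probs_agree(old, new, rel_tol=1e-6):
    """True if two scalar log-probabilities match to within a relative tolerance."""
    return math.isclose(float(old), float(new), rel_tol=rel_tol)


# The two values printed above.
print(log_probs_agree(-16304.628906, -16304.628906))

# A difference of 1.0 in the log-probability is well outside the tolerance.
print(log_probs_agree(-16304.628906, -16305.628906))
```

The `float()` calls are there on the assumption that the inputs may be scalar tensors rather than plain Python floats.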

5 Do we get the same speeds?

import time as tm

warmup = 5
iterations = 10_000

def benchmark(fn):
    for _ in range(warmup):
        _ = fn()

    # perf_counter is a monotonic, high-resolution clock meant for timing code.
    start = tm.perf_counter()

    for _ in range(iterations):
        result = fn()

    end = tm.perf_counter()

    return end - start, result


old_time, old_result = benchmark(lambda: target_log_prob_fn(*state))
new_time, new_result = benchmark(lambda: auto_target_log_prob_fn(*state))

print(f"old: {old_time:0.2f} seconds, result: {old_result:.6f}")
print(f"new: {new_time:0.2f} seconds, result: {new_result:.6f}")
old: 5.32 seconds, result: -16304.628906
new: 5.16 seconds, result: -16304.628906

Hmmm! In the run shown above the new way is marginally quicker, but I ran this a few times and it looks like the old way could be a little faster. Either way this is not really conclusive, and the speed of the gradient calculation is probably the bottleneck anyway.
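One way to make the comparison less noisy is to repeat the timing several times and compare the best or median time per call, rather than a single total. The sketch below uses only the standard library. `benchmark_repeated` and the two toy functions are hypothetical stand-ins; in the setting above one would pass in the same lambdas given to `benchmark`.

```python
import statistics
import timeit


def benchmark_repeated(fn, repeats=5, iterations=1_000, warmup=5):
    """Return (best, median) seconds per call over `repeats` timing runs."""
    for _ in range(warmup):
        fn()

    # Each entry is the total time for `iterations` calls of fn.
    totals = timeit.repeat(fn, repeat=repeats, number=iterations)
    per_call = [total / iterations for total in totals]
    return min(per_call), statistics.median(per_call)


# Toy stand-ins for the two log-probability functions.
def old_way():
    return sum(x * x for x in range(100))


def new_way():
    return sum([x * x for x in range(100)])


for name, fn in [("old", old_way), ("new", new_way)]:
    best, median = benchmark_repeated(fn)
    print(f"{name}: best {best * 1e6:.2f} us, median {median * 1e6:.2f} us per call")
```

If the best and median times of the two versions overlap from run to run, the difference is probably within the noise.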

6 Conclusion

I really like this new way of writing models and I would use it wherever possible. The downsides are that it might be slower and it relies on tf.vectorized_map - so some operations might not be supported. There may also be memory issues but I didn't look at that here.

I'd even happily give up some runtime performance if it reduced the number of shape errors I get in my code.