In this post I have a look at the JointDistributionCoroutineAutoBatched class in TensorFlow Probability. The title is based on this C++ article.

I'm going to try to show the difference between JointDistributionCoroutineAutoBatched and JointDistributionCoroutine by comparing their code and their runtime speed.

## 1 Data

As usual, I use some football data from football-data.co.uk; you can see some code to download and process it here.

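```python
# `soccer_data` (a pandas DataFrame) and `team_names` come from the processing code linked above.
num_teams = len(team_names)
home_team = soccer_data["home_team"].to_numpy()  # integer index of the home side in each match
away_team = soccer_data["away_team"].to_numpy()  # integer index of the away side in each match
observed_home_goals = soccer_data["home_goals"].to_numpy()
observed_away_goals = soccer_data["away_goals"].to_numpy()
```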


```
  home_team_name away_team_name  home_goals  away_goals  home_team  away_team
0    Aston Villa       West Ham           3           0          1         32
1      Blackburn        Everton           1           0          3         12
2         Bolton         Fulham           0           0          5         13
3        Chelsea      West Brom           6           0         10         31
4     Sunderland     Birmingham           2           2         27          2
```
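The linked post has the actual processing code. What follows is only a minimal, hypothetical sketch of how integer team indices like the `home_team` and `away_team` columns could be built: sort the unique team names and number them from zero. The helper name `build_team_index` and the toy rows are my own, and the real processing may do this differently.

```python
def build_team_index(matches):
    """Map each team name to an integer index (alphabetical order, starting at 0).

    `matches` is a list of (home_team_name, away_team_name) pairs.
    Hypothetical helper for illustration, not the linked processing code.
    """
    names = sorted({name for pair in matches for name in pair})
    return names, {name: idx for idx, name in enumerate(names)}


# Toy rows copied from the first lines of the table above.
matches = [
    ("Aston Villa", "West Ham"),
    ("Blackburn", "Everton"),
    ("Bolton", "Fulham"),
    ("Chelsea", "West Brom"),
    ("Sunderland", "Birmingham"),
]

team_names, team_index = build_team_index(matches)
home_team = [team_index[home] for home, _ in matches]
away_team = [team_index[away] for _, away in matches]
num_teams = len(team_names)
print(num_teams, home_team, away_team)
```

The full data set contains many more teams, and each one gets its own slot, so these toy indices will not match the ones in the table.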


## 2 Model

Nothing new here: just a simple Poisson model for the observed home and away goals, using an intercept and a home advantage term, along with a measure of attacking and defensive capability for each team in the data.

Written out, the model is:

$$
\begin{align*}
\sigma &\sim \text{HalfNormal}(1) \\
\alpha &\sim \text{Normal}(0, \sigma^2) \\
\beta &\sim \text{Normal}(0, \sigma^2) \\
\psi &\sim \text{Normal}(0, 1^2) \\
\gamma &\sim \text{Normal}(0, 1^2) \\
H &\sim \text{Poisson}(\exp(\psi + \gamma + \alpha_i + \beta_j)) \\
A &\sim \text{Poisson}(\exp(\psi + \alpha_j + \beta_i))
\end{align*}
$$

where team $$i$$ plays at home to team $$j$$. This is a hierarchical model because it has "parameters which depend on parameters": the attack ($$\alpha$$) and defence ($$\beta$$) depend on the scale ($$\sigma$$). For completeness, $$H$$ and $$A$$ are the home and away goals respectively, $$\psi$$ is an intercept term, and $$\gamma$$ is a home advantage term.

I use Greek letters for this sort of thing because it is common practice and takes up less space, but these days I prefer to look at the code to see what is going on, since there I try to use descriptive names for the parameters.
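To make the equations concrete, here is a minimal plain-Python sketch of the log-likelihood of a single match. It follows the equations above term by term, including the home advantage $$\gamma$$ and the $$+\beta$$ terms. The function names are my own. The TensorFlow code later in this post uses a slightly different parameterisation: it has no separate home advantage term, and it subtracts the defence instead of adding it.

```python
import math


def poisson_log_pmf(k, log_rate):
    """log P(K = k) for K ~ Poisson(exp(log_rate))."""
    return k * log_rate - math.exp(log_rate) - math.lgamma(k + 1)


def match_log_lik(home_goals, away_goals, i, j, psi, gamma, alpha, beta):
    """Log-likelihood of one match in which team i plays at home to team j.

    Mirrors the model as written:
      log rate_H = psi + gamma + alpha[i] + beta[j]
      log rate_A = psi + alpha[j] + beta[i]
    """
    home_log_rate = psi + gamma + alpha[i] + beta[j]
    away_log_rate = psi + alpha[j] + beta[i]
    return poisson_log_pmf(home_goals, home_log_rate) + poisson_log_pmf(
        away_goals, away_log_rate
    )


# Illustrative parameter values for two teams (made up, not fitted).
print(match_log_lik(3, 0, 0, 1, psi=0.1, gamma=0.2, alpha=[0.3, -0.1], beta=[0.0, 0.2]))
```

Summing `match_log_lik` over all matches, plus the prior log-densities, gives the kind of target log-probability that the joint distributions below compute.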

## 3 Two ways to write the model in code

Previously, I'd write the above model in code using something like this:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfl = tf.linalg

Root = tfd.JointDistributionCoroutine.Root


@tfd.JointDistributionCoroutine
def joint_dist():
    # Scale shared by the attack and defence parameters (the hierarchical part).
    team_ability_scale = yield Root(tfd.HalfNormal(scale=1.0))
    # Non-centred parameterisation: standard-normal vectors, rescaled below.
    unit_attack = yield Root(
        tfd.MultivariateNormalLinearOperator(
            loc=0.0,
            scale=tfl.LinearOperatorIdentity(num_rows=num_teams),
        )
    )
    unit_defence = yield Root(
        tfd.MultivariateNormalLinearOperator(
            loc=0.0,
            scale=tfl.LinearOperatorIdentity(num_rows=num_teams),
        )
    )
    # The intercept has no parents, so it is a Root as well.
    intercept = yield Root(tfd.Normal(loc=0.0, scale=1.0))

    # Trailing axes line each draw's scale up with its own vector of teams.
    attack = team_ability_scale[..., tf.newaxis] * unit_attack
    defence = team_ability_scale[..., tf.newaxis] * unit_defence

    # Likewise, each draw's intercept is added to its own vector of matches.
    home_log_rate = (
        intercept[..., tf.newaxis]
        + tf.gather(attack, home_team, axis=-1)
        - tf.gather(defence, away_team, axis=-1)
    )

    away_log_rate = (
        intercept[..., tf.newaxis]
        + tf.gather(attack, away_team, axis=-1)
        - tf.gather(defence, home_team, axis=-1)
    )

    # Treat the per-match Poissons as a single event (sum their log-probs).
    yield tfd.Independent(
        tfd.Poisson(log_rate=home_log_rate), reinterpreted_batch_ndims=1
    )

    yield tfd.Independent(
        tfd.Poisson(log_rate=away_log_rate), reinterpreted_batch_ndims=1
    )
```


which has a bunch of things in it I'd rather not have to deal with:

1. Root
2. MultivariateNormalLinearOperator
3. [..., tf.newaxis]
4. tfd.Independent
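Items 3 and 4 both come down to shapes. Here is a rough illustration in plain NumPy, with NumPy standing in for TensorFlow and made-up sizes. It shows what happens when several joint draws are taken at once, assuming (as in TFP) that the draws sit on the leading axis.

```python
import numpy as np

rng = np.random.default_rng(0)
num_draws, num_teams = 3, 4  # hypothetical sizes for illustration

# Shapes when several joint draws are taken at once:
team_ability_scale = np.abs(rng.normal(size=num_draws))  # (3,): one scale per draw
unit_attack = rng.normal(size=(num_draws, num_teams))  # (3, 4): one vector per draw

# Item 3: without a trailing axis, shapes (3,) and (3, 4) do not broadcast.
try:
    team_ability_scale * unit_attack
    broadcast_ok = True
except ValueError:
    broadcast_ok = False

# With it, (3, 1) * (3, 4) -> (3, 4), pairing each draw's scale with its own vector.
attack = team_ability_scale[..., np.newaxis] * unit_attack

# Item 4: a standard-normal log-density for each team gives shape (3, 4) ...
per_team_log_prob = -0.5 * unit_attack**2 - 0.5 * np.log(2.0 * np.pi)
# ... and summing the rightmost axis gives one log-probability per draw, shape (3,).
# This is the reduction that tfd.Independent(..., reinterpreted_batch_ndims=1) performs.
per_draw_log_prob = per_team_log_prob.sum(axis=-1)

print(broadcast_ok, attack.shape, per_draw_log_prob.shape)
```

A nastier failure mode: if `num_draws` happened to equal `num_teams`, the version without the new axis would broadcast without any error. Each team's values would then be multiplied by the wrong draw's scale.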

Step forward, JointDistributionCoroutineAutoBatched:

```python
@tfd.JointDistributionCoroutineAutoBatched
def auto_joint_dist():
    # No Root markers: the auto-batched version works out the batching itself.
    team_ability_scale = yield tfd.HalfNormal(scale=1.0)
    # A vector-valued loc gives one normal per team.
    unit_attack = yield tfd.Normal(loc=tf.zeros(num_teams), scale=1.0)
    unit_defence = yield tfd.Normal(loc=tf.zeros(num_teams), scale=1.0)
    intercept = yield tfd.Normal(loc=0.0, scale=1.0)

    # Written as if for a single draw, so no [..., tf.newaxis] is needed.
    attack = team_ability_scale * unit_attack
    defence = team_ability_scale * unit_defence

    home_log_rate = (
        intercept
        + tf.gather(attack, home_team, axis=-1)
        - tf.gather(defence, away_team, axis=-1)
    )

    away_log_rate = (
        intercept
        + tf.gather(attack, away_team, axis=-1)
        - tf.gather(defence, home_team, axis=-1)
    )

    # No tfd.Independent: batch dimensions of each component are treated as event dimensions.
    yield tfd.Poisson(log_rate=home_log_rate)
    yield tfd.Poisson(log_rate=away_log_rate)
```


Now isn't that a lot nicer to read? But I have a question - what's the downside?

## 4 Do we get the same log-probabilities?

When running MCMC or similar, we spend most of the time evaluating the log-probability (and its gradient) of some parameters conditioned on the observed data, so this is the function I test in this post. Really I should test the gradients as well, but I don't have loads of time at the moment.

```python
observed_goals = [observed_home_goals, observed_away_goals]


@tf.function(autograph=False)
def target_log_prob_fn(*state):
    # Log-probability of the parameters, with the goals fixed at their observed values.
    return joint_dist.log_prob(list(state) + observed_goals)


@tf.function(autograph=False)
def auto_target_log_prob_fn(*state):
    return auto_joint_dist.log_prob(list(state) + observed_goals)


# One draw of the parameters, dropping the two simulated goal vectors at the end.
state = joint_dist.sample()[:-2]

print(f"old: {target_log_prob_fn(*state):.6f}")
print(f"new: {auto_target_log_prob_fn(*state):.6f}")
```

```
old: -16304.628906
new: -16304.628906
```


So no evidence of anything going wrong: both versions give the same log-probability to six decimal places.
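The gradients were not checked above. As a minimal sketch of what such a check involves, the snippet below compares an analytic derivative with a central finite difference on a single Poisson term, in plain Python. For $$\log p(k \mid \eta) = k\eta - e^{\eta} - \log k!$$, the derivative with respect to the log-rate $$\eta$$ is $$k - e^{\eta}$$. For the two joint distributions themselves, the analogous check would compare the gradients TensorFlow returns for each, for example via `tfp.math.value_and_gradient`. That part is not shown here.

```python
import math


def poisson_log_pmf(k, log_rate):
    """log P(K = k) for K ~ Poisson(exp(log_rate))."""
    return k * log_rate - math.exp(log_rate) - math.lgamma(k + 1)


def analytic_grad(k, log_rate):
    # d/d(log_rate) of [k * log_rate - exp(log_rate) - log k!] is k - exp(log_rate).
    return k - math.exp(log_rate)


def finite_diff_grad(f, x, h=1e-6):
    # Central difference approximation of f'(x).
    return (f(x + h) - f(x - h)) / (2 * h)


k, log_rate = 3, 0.4
numeric = finite_diff_grad(lambda eta: poisson_log_pmf(k, eta), log_rate)
exact = analytic_grad(k, log_rate)
print(f"analytic: {exact:.6f}, finite difference: {numeric:.6f}")
```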

## 5 Do we get the same speeds?

```python
import time as tm

warmup = 5
iterations = 10_000


def benchmark(fn):
    # A few untimed calls first, so tf.function tracing is not included in the timing.
    for _ in range(warmup):
        _ = fn()

    start = tm.time()

    for _ in range(iterations):
        result = fn()

    end = tm.time()

    return end - start, result


old_time, old_result = benchmark(lambda: target_log_prob_fn(*state))
new_time, new_result = benchmark(lambda: auto_target_log_prob_fn(*state))

print(f"old: {old_time:0.2f} seconds, result: {old_result:.6f}")
print(f"new: {new_time:0.2f} seconds, result: {new_result:.6f}")
```

```
old: 5.32 seconds, result: -16304.628906
new: 5.16 seconds, result: -16304.628906
```


Hmmm! In this run the new version came out slightly ahead, but I ran this a few times and it looks like the old way could be a little faster. Either way, the difference is small and not conclusive, and the gradient calculation is probably the bottleneck anyway.
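One way to make "I ran this a few times" more systematic is to time several rounds and look at the spread. The sketch below is an alternative to the `benchmark` function above, using only the standard library `timeit` module. The name `compare` and its defaults are my own choices. It interleaves the two candidates and reports the minimum and median round time. The Python `timeit` documentation notes that the lowest value is usually the most informative, because higher values tend to reflect interference from other processes rather than the code itself.

```python
import statistics
import timeit


def compare(fn_a, fn_b, number=1_000, repeat=5):
    """Time two zero-argument callables over several interleaved rounds.

    Returns {"a": {...}, "b": {...}} with the min and median seconds per round,
    where one round is `number` calls.
    """
    # Warm-up calls, e.g. so tf.function tracing is not timed.
    fn_a()
    fn_b()
    times = {"a": [], "b": []}
    for _ in range(repeat):
        # Alternate the callables so slow drift in machine load affects both similarly.
        times["a"].append(timeit.timeit(fn_a, number=number))
        times["b"].append(timeit.timeit(fn_b, number=number))
    return {
        name: {"min": min(ts), "median": statistics.median(ts)}
        for name, ts in times.items()
    }


# Toy usage with stand-in workloads. With the post's functions, one might pass
# lambda: target_log_prob_fn(*state) and lambda: auto_target_log_prob_fn(*state).
summary = compare(lambda: sum(range(1_000)), lambda: sum(range(100)), number=200, repeat=3)
print(summary)
```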

## 6 Conclusion

I really like this new way of writing models, and I would use it wherever possible. The possible downsides are that it might be slower, and that it relies on tf.vectorized_map, so some operations might not be supported. There may also be memory issues, but I didn't look at those here.

I'd even happily give up some runtime performance if it reduced the number of shape errors I get in my code.
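The auto-batched model can be written for a single draw because the library handles the batching, relying on `tf.vectorized_map`. That function converts the per-draw computation into vectorised operations rather than running a Python loop. By default it can also fall back to a slower while loop for ops it cannot vectorise, via its `fallback_to_while_loop` argument.

The NumPy sketch below illustrates only the contract, not the mechanism. An unbatched log-rate function, shaped like the model code above, is applied draw by draw in a loop. The result must match a hand-batched version that carries the extra `[:, np.newaxis]` bookkeeping. All names and sizes here are hypothetical.

```python
import numpy as np


def unbatched_log_rate(scale, unit_attack, unit_defence, intercept, home, away):
    """Home log-rates for ONE parameter draw: scalar scale, vectors of length num_teams."""
    attack = scale * unit_attack
    defence = scale * unit_defence
    return intercept + attack[home] - defence[away]


def batched_log_rate(scale, unit_attack, unit_defence, intercept, home, away):
    """Hand-batched version: the leading axis indexes draws, so shapes need care."""
    attack = scale[:, np.newaxis] * unit_attack
    defence = scale[:, np.newaxis] * unit_defence
    return intercept[:, np.newaxis] + attack[:, home] - defence[:, away]


rng = np.random.default_rng(1)
num_draws, num_teams = 5, 6
home = np.array([0, 2, 4])  # hypothetical fixtures: team 0 hosts 1, 2 hosts 3, 4 hosts 5
away = np.array([1, 3, 5])
scale = np.abs(rng.normal(size=num_draws))
unit_attack = rng.normal(size=(num_draws, num_teams))
unit_defence = rng.normal(size=(num_draws, num_teams))
intercept = rng.normal(size=num_draws)

# Conceptual stand-in for auto-batching: apply the unbatched function draw by draw.
looped = np.stack([
    unbatched_log_rate(scale[b], unit_attack[b], unit_defence[b], intercept[b], home, away)
    for b in range(num_draws)
])
batched = batched_log_rate(scale, unit_attack, unit_defence, intercept, home, away)
print(looped.shape, np.allclose(looped, batched))
```

The loop is the easy version to write and the batched one is the fast version. Auto-batching aims to give the second while letting you write the first.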