A contemplation of $\text{logsumexp}$


$\text{logsumexp}$ is an interesting little function that shows up surprisingly often in machine learning. Join me in this post to shed some light on $\text{logsumexp}$: where it lives, how it behaves, and how to interpret it.

What is $\text{logsumexp}$?

Let $\mathbf{x} \in \mathbb{R}^n$. $\text{logsumexp}(\mathbf{x})$ is defined as:

$$\begin{eqnarray*} \text{logsumexp}(\mathbf{x}) = \text{log} \left( \sum_{i} \text{exp}(x_i) \right). \end{eqnarray*}$$

Numerically, $\text{logsumexp}$ is similar to $\text{max}$: in fact, it’s sometimes called the “smooth maximum” function. For example:

$$\begin{eqnarray*} \text{max}([1,2,3]) = 3 \end{eqnarray*}$$
$$\begin{eqnarray*} \text{logsumexp}([1,2,3]) = 3.4076 \end{eqnarray*}$$
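We can check this quickly in numpy (using the naive formula here, which is fine for small inputs; numerical stability is covered later in the post):

import numpy as np

x = np.array([1.0, 2.0, 3.0])
print(np.max(x))                  # 3.0
print(np.log(np.sum(np.exp(x))))  # 3.4076...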

Examples of $\text{logsumexp}$

Here are some places in machine learning where $\text{logsumexp}$ is used.

Softmax classifiers

In a softmax classifier, the likelihood of label $i$ is defined as:

$$\begin{eqnarray*} p_{\theta}(i | \mathbf{l}) = \text{softmax}(\mathbf{l})_i, \end{eqnarray*}$$

where $\mathbf{l}$ is the vector of logits (unnormalized scores for each label).

Softmax classifiers are trained by minimizing the negative log-likelihood loss:

$$\begin{align*} -\text{log } p_{\theta}(i | \mathbf{l}) &= -\text{log} \left( \text{softmax}(\mathbf{l})_i \right) \\ &= -\text{log} \left( \text{exp}(l_i) / \sum_j \text{exp}(l_j) \right) \\ &= -\text{log} \left( \text{exp}(l_i)\right) +\text{log} \left(\sum_j \text{exp}(l_j) \right) \\ &= -l_i + \text{logsumexp}(\mathbf{l}). \end{align*}$$

Our friend $\text{logsumexp}$ appears in the last line.
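As a quick sanity check, here is the same identity in numpy, with made-up logits for a three-class problem (the logsumexp helper uses the stable form described at the end of this post):

import numpy as np

def logsumexp(x):
    # numerically stable logsumexp (the "logsumexp trick" described below)
    m = np.max(x)
    return np.log(np.sum(np.exp(x - m))) + m

logits = np.array([2.0, -1.0, 0.5])  # made-up logits
target = 0                           # index of the correct label

# negative log-likelihood written as -l_i + logsumexp(l)
loss = -logits[target] + logsumexp(logits)

# the same loss computed directly from the softmax definition
softmax = np.exp(logits) / np.sum(np.exp(logits))
assert np.isclose(loss, -np.log(softmax[target]))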

Global pooling

For sequence classification tasks, it is usually necessary to map a variable-length sequence of feature vectors to a single feature vector to be able to use something like a softmax classifier. To obtain a single vector, global pooling operations like max pooling or mean pooling can be used. Another aggregation method, which is less commonly used but has some of the advantages of both mean and max pooling, is $\text{logsumexp}$ pooling. (See this paper for an example.)
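As a rough sketch, here is what the three pooling variants might look like on a (time, features) array in numpy (the shapes are arbitrary, and the naive $\text{logsumexp}$ formula is fine here because the random features are small):

import numpy as np

features = np.random.randn(17, 64)  # a variable-length sequence: (time, features)

mean_pooled = features.mean(axis=0)                    # mean pooling
max_pooled = features.max(axis=0)                      # max pooling
lse_pooled = np.log(np.sum(np.exp(features), axis=0))  # logsumexp pooling

print(mean_pooled.shape, max_pooled.shape, lse_pooled.shape)  # (64,) (64,) (64,)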

Latent alignment models

In latent alignment models, like connectionist temporal classification (CTC), the model is trained by using dynamic programming to add up the probabilities of all possible alignments of an input sequence $\mathbf{x}$ to an output sequence $\mathbf{y}$.

The dynamic programming algorithm in CTC uses the following recursion:

$$\begin{eqnarray*} \alpha_{s,t} = (\alpha_{s,t-1} + \alpha_{s-1,t-1} + \alpha_{s-2,t-1}) \cdot p_{\theta}(y_s | x_t)\end{eqnarray*}$$

This algorithm multiplies a long chain of probabilities, and so will underflow when aligning long sequences (like speech). Instead, we can run the algorithm in the log domain, in which case the recursion becomes:

$$\begin{align*} \text{log}(\alpha_{s,t}) &= \text{log}(\alpha_{s,t-1} + \alpha_{s-1,t-1} + \alpha_{s-2,t-1}) + \text{log } p_{\theta}(y_s | x_t)\\ &= \text{log}(\text{exp}(\text{log }\alpha_{s,t-1}) + \text{exp}(\text{log }\alpha_{s-1,t-1}) + \text{exp}(\text{log }\alpha_{s-2,t-1})) + \text{log } p_{\theta}(y_s | x_t)\\ &= \text{logsumexp}([\text{log }\alpha_{s,t-1},\text{log }\alpha_{s-1,t-1}, \text{log }\alpha_{s-2,t-1}]) + \text{log } p_{\theta}(y_s | x_t)\\ \end{align*}$$

Similar recursions using $\text{logsumexp}$ can be derived for the forward-backward algorithm used in Hidden Markov Models and Transducer models.

Fun fact: if we replace the $\text{logsumexp}$ with $\text{max}$, we get the Viterbi algorithm, which gives us the score of the single most likely alignment (and if we backpropagate, the alignment itself).
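To make this concrete, here is one step of the log-domain recursion in numpy, with hypothetical values for the previous log-alphas and the log-probability:

import numpy as np

def logsumexp(x):
    # numerically stable logsumexp
    m = np.max(x)
    return np.log(np.sum(np.exp(x - m))) + m

# hypothetical values of log α_{s,t-1}, log α_{s-1,t-1}, log α_{s-2,t-1}
log_alpha_prev = np.array([-3.2, -1.7, -4.5])
log_p = -0.4  # hypothetical log p_θ(y_s | x_t)

# one step of the log-domain recursion
log_alpha_s_t = logsumexp(log_alpha_prev) + log_p

# swapping logsumexp for max gives one step of the Viterbi recursion instead
viterbi_score = np.max(log_alpha_prev) + log_p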

Some properties of $\text{logsumexp}$

$\text{logsumexp}$ has some useful properties. It is:

  • Convex. If you can pose your machine learning problem as a convex optimization problem, you can solve it quickly and reliably.
  • Differentiable everywhere. This is nice to have if your optimization algorithm is picky and doesn’t like functions with non-differentiable points, like $\text{max}$. For those picky algorithms, we can approximate $\text{max}$ using $\text{logsumexp}$.
  • Associative. So:
$$\begin{align*} \text{logsumexp}([a,b,c,d]) &= \text{logsumexp}([ \\ & \text{logsumexp}([a,b]), \\ & \text{logsumexp}([c,d]) \\ ]). \\ \end{align*}$$
    (Hence, it can be computed in just $\text{log}_2(n)$ timesteps using a parallel reduction, where $n$ is the length of the vector we’re $\text{logsumexp}$ing.)
  • Close to $\text{max}$. In what sense “close”? It is bounded as follows (a bound we check numerically in the sketch after this list):
$$\begin{eqnarray*} \text{max}(\mathbf{x}) \leq \text{logsumexp}(\mathbf{x}) \leq \text{max}(\mathbf{x}) + \text{log}(n) \end{eqnarray*}$$
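Here is a quick numerical check of both the associativity identity and the bound, using SciPy’s numerically stable $\text{logsumexp}$ implementation:

import numpy as np
from scipy.special import logsumexp

x = np.random.randn(8)

# associativity: reduce the two halves separately, then combine
halves = logsumexp([logsumexp(x[:4]), logsumexp(x[4:])])
assert np.isclose(halves, logsumexp(x))

# close to max: bounded between max(x) and max(x) + log(n)
assert np.max(x) <= logsumexp(x) <= np.max(x) + np.log(len(x))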

The $\text{logsumexp}$ trick

Computing $\text{log} \left( \sum_{i} \text{exp}(x_i) \right)$ directly is numerically unstable because of the $\text{exp}$. Try the following in numpy:

import numpy as np

x = np.array([7000, 8000, 9000])
np.log(np.sum(np.exp(x)))  # overflows: the exp produces inf

From the approximation $\text{max}(\mathbf{x}) \approx \text{logsumexp}(\mathbf{x})$, we know that the result should be a little over 9000—but if you run this code, the result will be infinity because of overflow.

Instead, to compute $\text{logsumexp}$, use the following trick:

$$\begin{eqnarray*} \text{logsumexp}(\mathbf{x}) = \text{log} \left( \sum_{i} \text{exp}(x_i - \text{max}(\mathbf{x})) \right) + \text{max}(\mathbf{x}). \end{eqnarray*}$$

(See this post for the proof that the trick works.)
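(In case you don’t want to click through: the proof is just factoring $\text{exp}(\text{max}(\mathbf{x}))$ out of the sum.)

$$\begin{align*} \text{log} \left( \sum_{i} \text{exp}(x_i) \right) &= \text{log} \left( \text{exp}(\text{max}(\mathbf{x})) \sum_{i} \text{exp}(x_i - \text{max}(\mathbf{x})) \right) \\ &= \text{log} \left( \sum_{i} \text{exp}(x_i - \text{max}(\mathbf{x})) \right) + \text{max}(\mathbf{x}). \end{align*}$$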

Now we won’t get an overflow because we’re taking the $\text{exp}$ of $[-2000,-1000,0]$ instead of $[7000,8000,9000]$. If we now run this instead:

x = np.array([7000, 8000, 9000])
np.log(np.sum(np.exp(x - x.max()))) + x.max()  # ≈ 9000, as expected

we’ll get what we expect.

Takeaways

  • $\text{logsumexp}$ is everywhere!
  • You can make certain dense equations easier to digest by identifying instances of $\text{logsumexp}$ and mentally replacing them with $\text{max}$.

    For example, from the discussion of softmax classifiers above, you now know that the loss for a classifier is just the difference between the maximum of the logits (roughly) and the logit for the right answer!

  • When computing $\text{logsumexp}$, use the “logsumexp trick”. (Or just check if your library already has a numerically stable $\text{logsumexp}$ function, as PyTorch does; see the snippet below.)
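For example, assuming you have PyTorch installed, the built-in torch.logsumexp handles the 9000s example from above without overflowing:

import torch

x = torch.tensor([7000.0, 8000.0, 9000.0])
print(torch.logsumexp(x, dim=0))  # tensor(9000.)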