Consider a dataset $\mathcal{D} = \{(\mathbf{x}_i, y_i)\}_{i=1}^N$. Let $\mathbf{x}_i \in \mathbb{R}^D$ and $y_i \in \mathbb{R}$. Our goal is to learn $f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x}$ s.t.

$$\mathbf{w} = \operatorname*{arg\,min}_{\mathbf{w}} \sum_{i=1}^N \left(y_i - \mathbf{w}^\top \mathbf{x}_i\right)^2.$$
In fact, this mean-squared error comes from the fact that we assume

$$y_i = \mathbf{w}^{*\top} \mathbf{x}_i + \epsilon_i,$$

where $\epsilon_i \sim \mathcal{N}(0, \sigma^2)$ and $\mathbf{w}^*$ is the oracle weight. With this assumption, we can see that

$$y_i \mid \mathbf{x}_i, \mathbf{w} \sim \mathcal{N}(\mathbf{w}^\top \mathbf{x}_i, \sigma^2).$$
Thus, we have

$$p(y_i \mid \mathbf{x}_i, \mathbf{w}) = \frac{1}{Z} \exp\left(-\frac{(y_i - \mathbf{w}^\top \mathbf{x}_i)^2}{2\sigma^2}\right),$$

where $Z = \sqrt{2\pi\sigma^2}$ is the normalizer of the Gaussian distribution. If we assume that the $y_i$'s are independent and identically distributed, we have

$$p(\mathbf{y} \mid \mathbf{X}, \mathbf{w}) = \prod_{i=1}^N p(y_i \mid \mathbf{x}_i, \mathbf{w}).$$
Taking the log and negating the result yields the objective we have just mentioned. The approach based on optimizing $p(\mathbf{y} \mid \mathbf{X}, \mathbf{w})$ is called maximum likelihood estimation, which yields a point estimate $\mathbf{w}_{\text{ML}}$.
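As a quick illustration (a sketch on hypothetical synthetic data; all names and values are my own), the maximum likelihood estimate is the ordinary least-squares solution:

```python
import numpy as np

# Hypothetical synthetic data: y = 2x + Gaussian noise
rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(50, 1))
y = 2.0 * X[:, 0] + rng.normal(0.0, 0.1, size=50)

# Minimizing the negative log-likelihood of the Gaussian model is
# equivalent to minimizing the sum of squared errors.
w_ml, *_ = np.linalg.lstsq(X, y, rcond=None)
print(w_ml)  # a point estimate close to the oracle weight 2.0
```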
However, in many applications, we do not want only $\mathbf{w}_{\text{ML}}$ but rather the distribution of $\mathbf{w}$ given the data, to estimate the (epistemic) uncertainty of the prediction. This distribution is called the posterior distribution, and it can be computed via

$$p(\mathbf{w} \mid \mathbf{X}, \mathbf{y}) = \frac{p(\mathbf{y} \mid \mathbf{X}, \mathbf{w})\, p(\mathbf{w})}{p(\mathbf{y} \mid \mathbf{X})}.$$
This is essentially Bayes' rule. One can see that the evidence $p(\mathbf{y} \mid \mathbf{X})$ can be hard to compute; however, this can be done (analytically) with proper assumptions. In particular, we might choose a prior $p(\mathbf{w})$ that is structurally comparable with the likelihood; the technical term is conjugate. For example, if our likelihood is a Gaussian, we might assume that $p(\mathbf{w}) = \mathcal{N}(\mathbf{0}, s^2\mathbf{I}_D)$, where $s^2$ is the prior variance. With this choice of prior, we have

$$p(\mathbf{w} \mid \mathbf{X}, \mathbf{y}) = \frac{1}{C} \exp\left(-\frac{1}{2\sigma^2}\sum_{i=1}^N \left(y_i - \mathbf{w}^\top\mathbf{x}_i\right)^2 - \frac{1}{2s^2}\mathbf{w}^\top\mathbf{w}\right),$$
where $C$ is a constant when marginalizing out $\mathbf{w}$. From this, the posterior distribution is a multivariate Gaussian whose covariance is

$$\boldsymbol{\Sigma} = \left(\frac{1}{\sigma^2}\mathbf{X}^\top\mathbf{X} + \frac{1}{s^2}\mathbf{I}_D\right)^{-1}.$$

When expanding the density of a multivariate Gaussian $\mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma})$, we know that

$$\mathcal{N}(\mathbf{w}; \boldsymbol{\mu}, \boldsymbol{\Sigma}) \propto \exp\left(-\frac{1}{2}\mathbf{w}^\top\boldsymbol{\Sigma}^{-1}\mathbf{w} + \mathbf{w}^\top\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu}\right),$$

so matching the term linear in $\mathbf{w}$ gives $\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu} = \frac{1}{\sigma^2}\mathbf{X}^\top\mathbf{y}$. Therefore, we can conclude that

$$p(\mathbf{w} \mid \mathbf{X}, \mathbf{y}) = \mathcal{N}\left(\boldsymbol{\mu}, \boldsymbol{\Sigma}\right), \qquad \boldsymbol{\mu} = \frac{1}{\sigma^2}\boldsymbol{\Sigma}\mathbf{X}^\top\mathbf{y}.$$
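The closed-form posterior above can be sketched in a few lines of NumPy (the function name is my own, and the noise variance `sigma2` and prior variance `s2` are assumed known):

```python
import numpy as np

def posterior(X, y, sigma2, s2):
    """Gaussian posterior N(mu, Sigma) over the weights.

    Sigma = (X^T X / sigma2 + I / s2)^{-1}
    mu    = Sigma @ X^T y / sigma2
    """
    D = X.shape[1]
    Sigma = np.linalg.inv(X.T @ X / sigma2 + np.eye(D) / s2)
    mu = Sigma @ X.T @ y / sigma2
    return mu, Sigma
```

With enough data, the posterior mean approaches the oracle weight and the covariance shrinks toward zero.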
Predictive Mean and Variance
Consider a new sample $\mathbf{x}^*$. We would like to know its prediction according to the posterior (i.e., averaging across all possible $\mathbf{w}$) and its variance, which is the uncertainty of the prediction. In particular, we know that

$$p(y^* \mid \mathbf{x}^*, \mathbf{X}, \mathbf{y}) = \int p(y^* \mid \mathbf{x}^*, \mathbf{w})\, p(\mathbf{w} \mid \mathbf{X}, \mathbf{y})\, d\mathbf{w},$$

where the equality is based on the assumptions that (1) the new prediction $y^*$ and the data are conditionally independent given the model's parameters $\mathbf{w}$, and (2) our model's parameters are independent of the new sample $\mathbf{x}^*$ given the data $(\mathbf{X}, \mathbf{y})$. Because both terms in the integral are Gaussian, the distribution of the prediction is also a Gaussian. Let us assume $p(\mathbf{w} \mid \mathbf{X}, \mathbf{y}) = \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma})$. Writing the two terms together, we get

$$p(y^* \mid \mathbf{x}^*, \mathbf{X}, \mathbf{y}) \propto \int \exp\left(-\frac{(y^* - \mathbf{w}^\top\mathbf{x}^*)^2}{2\sigma^2} - \frac{1}{2}(\mathbf{w} - \boldsymbol{\mu})^\top\boldsymbol{\Sigma}^{-1}(\mathbf{w} - \boldsymbol{\mu})\right) d\mathbf{w}.$$
Expanding the term inside the exponent yields

$$-\frac{1}{2}\left[\mathbf{w}^\top\left(\boldsymbol{\Sigma}^{-1} + \frac{\mathbf{x}^*\mathbf{x}^{*\top}}{\sigma^2}\right)\mathbf{w} - 2\,\mathbf{w}^\top\left(\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu} + \frac{y^*\mathbf{x}^*}{\sigma^2}\right) + \frac{y^{*2}}{\sigma^2} + \boldsymbol{\mu}^\top\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu}\right].$$

Thus, denoting $\mathbf{A} = \boldsymbol{\Sigma}^{-1} + \frac{\mathbf{x}^*\mathbf{x}^{*\top}}{\sigma^2}$ and $\mathbf{b} = \boldsymbol{\Sigma}^{-1}\boldsymbol{\mu} + \frac{y^*\mathbf{x}^*}{\sigma^2}$, we have

$$p(y^* \mid \mathbf{x}^*, \mathbf{X}, \mathbf{y}) \propto \exp\left(-\frac{1}{2}\left[\frac{y^{*2}}{\sigma^2} + \boldsymbol{\mu}^\top\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu}\right]\right) \int \exp\left(-\frac{1}{2}\left[\mathbf{w}^\top\mathbf{A}\mathbf{w} - 2\,\mathbf{w}^\top\mathbf{b}\right]\right) d\mathbf{w}.$$

We can see that the exponent term inside the integral has the form of a Gaussian; in particular,

$$\mathbf{w}^\top\mathbf{A}\mathbf{w} - 2\,\mathbf{w}^\top\mathbf{b} = (\mathbf{w} - \mathbf{A}^{-1}\mathbf{b})^\top\mathbf{A}\,(\mathbf{w} - \mathbf{A}^{-1}\mathbf{b}) - \mathbf{b}^\top\mathbf{A}^{-1}\mathbf{b},$$

with an assumption that $\mathbf{A}^{-1}$ exists. Because of the Gaussian form, the integral over $\mathbf{w}$ contributes only a constant, and we thus have

$$p(y^* \mid \mathbf{x}^*, \mathbf{X}, \mathbf{y}) \propto \exp\left(-\frac{1}{2}\left[\frac{y^{*2}}{\sigma^2} + \boldsymbol{\mu}^\top\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu} - \mathbf{b}^\top\mathbf{A}^{-1}\mathbf{b}\right]\right).$$
Because $\mathbf{A}^{-1}$ is symmetric, i.e. $(\mathbf{A}^{-1})^\top = \mathbf{A}^{-1}$, we can expand $\mathbf{b}^\top\mathbf{A}^{-1}\mathbf{b}$, yielding

$$\mathbf{b}^\top\mathbf{A}^{-1}\mathbf{b} = \frac{y^{*2}}{\sigma^4}\,\mathbf{x}^{*\top}\mathbf{A}^{-1}\mathbf{x}^* + \frac{2y^*}{\sigma^2}\,\mathbf{x}^{*\top}\mathbf{A}^{-1}\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu} + \boldsymbol{\mu}^\top\boldsymbol{\Sigma}^{-1}\mathbf{A}^{-1}\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu}.$$

Combining all the results together, we have

$$p(y^* \mid \mathbf{x}^*, \mathbf{X}, \mathbf{y}) \propto \exp\left(-\frac{(y^* - \mu^*)^2}{2\sigma^{*2}}\right),$$

where we have

$$\frac{1}{\sigma^{*2}} = \frac{1}{\sigma^2}\left(1 - \frac{\mathbf{x}^{*\top}\mathbf{A}^{-1}\mathbf{x}^*}{\sigma^2}\right), \qquad \frac{\mu^*}{\sigma^{*2}} = \frac{1}{\sigma^2}\,\mathbf{x}^{*\top}\mathbf{A}^{-1}\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu}.$$
These results can be further simplified by using the Sherman–Morrison formula. Denote $v = \mathbf{x}^{*\top}\boldsymbol{\Sigma}\mathbf{x}^*$. We have

$$\mathbf{A}^{-1} = \left(\boldsymbol{\Sigma}^{-1} + \frac{\mathbf{x}^*\mathbf{x}^{*\top}}{\sigma^2}\right)^{-1} = \boldsymbol{\Sigma} - \frac{\boldsymbol{\Sigma}\mathbf{x}^*\mathbf{x}^{*\top}\boldsymbol{\Sigma}}{\sigma^2 + v},$$

so

$$\mathbf{x}^{*\top}\mathbf{A}^{-1}\mathbf{x}^* = v - \frac{v^2}{\sigma^2 + v} = \frac{\sigma^2 v}{\sigma^2 + v}, \qquad \frac{1}{\sigma^{*2}} = \frac{1}{\sigma^2}\left(1 - \frac{v}{\sigma^2 + v}\right) = \frac{1}{\sigma^2 + v},$$

which gives

$$\sigma^{*2} = \sigma^2 + \mathbf{x}^{*\top}\boldsymbol{\Sigma}\mathbf{x}^*.$$

Similarly, we can follow the same steps for $\mu^*$:

$$\mu^* = \sigma^{*2} \cdot \frac{1}{\sigma^2}\,\mathbf{x}^{*\top}\mathbf{A}^{-1}\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu} = \boldsymbol{\mu}^\top\mathbf{x}^*.$$
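A quick numerical sanity check of this rank-one inverse update (the dimensions and random values below are my own choices):

```python
import numpy as np

rng = np.random.default_rng(1)
D = 4
sigma2 = 0.25

# A random symmetric positive-definite Sigma and a test point x
M = rng.normal(size=(D, D))
Sigma = M @ M.T + np.eye(D)
x = rng.normal(size=(D, 1))

A = np.linalg.inv(Sigma) + (x @ x.T) / sigma2
# Sherman-Morrison: A^{-1} = Sigma - Sigma x x^T Sigma / (sigma2 + x^T Sigma x)
A_inv_sm = Sigma - (Sigma @ x @ x.T @ Sigma) / (sigma2 + (x.T @ Sigma @ x).item())

assert np.allclose(np.linalg.inv(A), A_inv_sm)
```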
These are the predictive variance and predictive mean for linear regression with a Gaussian prior, i.e. ridge regression.
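Putting the two formulas together, a sketch of the predictive distribution (the function and variable names are mine):

```python
import numpy as np

def predictive(x_star, mu, Sigma, sigma2):
    """Predictive mean mu^T x* and variance sigma2 + x*^T Sigma x*."""
    mean = mu @ x_star
    # aleatoric (sigma2) plus epistemic (x*^T Sigma x*) uncertainty
    var = sigma2 + x_star @ Sigma @ x_star
    return mean, var

mean, var = predictive(np.array([1.0, 1.0]), np.array([1.0, 2.0]),
                       0.5 * np.eye(2), sigma2=0.1)
print(mean, var)  # 3.0 and 0.1 + 1.0 = 1.1
```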
We can look closer at $\sigma^{*2} = \sigma^2 + \mathbf{x}^{*\top}\boldsymbol{\Sigma}\mathbf{x}^*$. Here the first term is a constant we assume; more precisely, it tells us about aleatoric uncertainty, which is the uncertainty due to noise in measurement. On the other hand, the second term is what we are interested in when we make a prediction. It captures epistemic uncertainty, which indicates the level of knowledge one does not have about the problem or the model. Therefore, if one is interested in the uncertainty of one's model's prediction, one can determine the uncertainty by

$$\mathbf{x}^{*\top}\boldsymbol{\Sigma}\mathbf{x}^*.$$
Now, it is time to put things together. We take a dataset and train a linear regression model on four different subsets. We assume that all training samples lie within a fixed interval, while test samples also fall outside it.
Fig. 1: Ridge regression trained on datasets of different sizes; the more training data, the more certain the predictions become, especially in the extrapolation regime.
From Fig. 1, we see the effect of extrapolation in the range that our training data does not cover. Without the posterior distribution, we would not know how much uncertainty we had if we relied on the point estimate of the solution (i.e., the ML or MAP solutions). Nevertheless, as we have more training samples, our model not only does well in the interpolation regime (the range covered by the training data) but also in the extrapolation regime.
Conclusion and References
Perhaps, as a next step, we shall look at classification tasks and how to estimate the uncertainty of predictions in such settings.
This post is my recap of Yarin Gal's tutorial at SMILES 2019. I also consulted mathematicalmok's YouTube channel.
The figure is generated from this Google Colab.