Updating Statistics in Batches

probability
machine learning
streaming
How to calculate statistics on data streams or out-of-core datasets in a parallelizable way.
Published: April 13, 2025

# Simple Statistics

Simple statistics are common in a number of software and data domains because they are easy to calculate and easy to understand. The most common are the arithmetic mean and variance. For a given set of values, \(X\), the mean, \(\mu\), is a measure of central tendency and locates a central value of the distribution. The variance, \(\sigma^2\), measures the dispersion of the observed values \(x_i\) about the central value \(\mu\) (Schott 2016, 20–25). \[ \begin{alignat*}{2} \text{E}(X) &= \mu &&= \frac{1}{N} \sum^N_{i=1} x_i \tag{mean} \\ \text{Var}(X) &= \sigma^2 &&= \frac{1}{N} \sum^N_{i=1} (x_i - \mu)^2 \tag{population variance} \end{alignat*} \]

Suppose we have a set of \(N=6\) values, \(X = [1, 2, 1, 2, 4, 5]\); then: \[ \begin{aligned} \mu &= \frac{1 + 2 + 1 + 2 + 4 + 5}{6} = 2.5 \\[10pt] \sigma^2 &= \frac{(1-2.5)^2 + (2-2.5)^2 + (1-2.5)^2 + (2-2.5)^2 + (4-2.5)^2 + (5-2.5)^2}{6} = 2.25 \end{aligned} \]

```python
import torch

X = torch.tensor([1.0, 2.0, 1.0, 2.0, 4.0, 5.0])

print(f"mean = {X.mean()}")
print(f"var  = {X.var(correction=0)}")
print(f"sample_var = {X.var()}")
```

```
mean = 2.5
var  = 2.25
sample_var = 2.700000047683716
```

This should be no surprise if you’ve worked with data before, and in theory the mean and variance are straightforward to calculate. In practice, however, datasets in the age of “Big Data” have become very large. Not only are datasets now too big to fit into memory, they are also ever-growing and constantly being updated. How do we calculate these measures if the data is too large to process in memory? How can we update them when new data is added to the dataset? Do we have to reprocess every value? What if we’re working with streaming data?

Storing large datasets can be costly, but reprocessing a large dataset just to recalculate a simple statistic can be prohibitive. Many datasets have a temporal component, and simple statistics can be calculated on a moving or sliding window because the most recent data is the most important. For many other datasets, a sliding window won’t do. For example, deep learning models are now trained on very large datasets, e.g. the whole of the internet. Since the entire internet can’t fit into the memory of even the largest compute instance, these models are trained in batches. Typically, the mean is computed over a batch (BatchNorm), but when evaluating the model on a hold-out dataset, we want to compute statistics over the whole of the hold-out dataset.

We could just keep a running sum for the numerator and a running count for the denominator, but there are some drawbacks. Accumulating the numerator of the mean can result in arithmetic overflow when dealing with large numbers. The numerator of the variance, \(\sum_i (x_i - \mu)^2\), depends on the mean, so every value would need to be revisited whenever the mean is updated; the one-pass alternative of keeping a running sum of squares can lead to numerical instability.
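
To make the instability concrete, here is a minimal sketch (the values are made up for illustration) of the sum-of-squares shortcut \(\sigma^2 = \frac{1}{n}\sum_i x_i^2 - \mu^2\) failing even in double precision when the values are large relative to their spread:

```python
import torch

# keep running sums of x and x^2, then var = E[x^2] - E[x]^2;
# the two terms are nearly equal, so the subtraction cancels
x = torch.tensor([1e8 + 1, 1e8 + 2, 1e8 + 3], dtype=torch.float64)
n = x.numel()

running_sum = x.sum()            # numerator of the mean
running_sum_sq = (x ** 2).sum()  # running sum of squares

naive_var = running_sum_sq / n - (running_sum / n) ** 2
print(f"naive var = {naive_var}")            # 0.0 -- the signal is lost to rounding
print(f"torch var = {x.var(correction=0)}")  # ~0.6667, the true variance
```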

Suppose our dataset arrives in batches of 2: \(X = [[1, 2], [1, 2], [4, 5]]\). We can update the mean, population variance, and sample variance as:

\[ \boxed{ \begin{align} \mu &= \frac{a}{a+b} \mu_a + \frac{b}{a+b} \mu_b \\[10pt] \sigma^2 &= \underbrace{\frac{a}{a+b} \sigma_a^2 + \frac{b}{a+b} \sigma_b^2}_{\text{within group variance}} + \underbrace{\frac{ab}{(a+b)^2} \big( \mu_a - \mu_b \big)^2}_{\text{between group variance}} \\ s^2 &= \underbrace{\frac{1}{a+b-1}}_{\text{Bessel's correction}} \Bigg[ \underbrace{(a-1) s_a^2 + (b-1) s_b^2}_{\text{within sample variance}} + \underbrace{\frac{ab}{a+b}(\mu_a - \mu_b)^2}_{\text{between sample variance}} \Bigg] \end{align} } \]

where:

  • \(a\) is the running count of all values seen so far
  • \(b\) is the count of the values in the new batch of data
  • \(\mu_a\) is the mean of all values seen so far
  • \(\mu_b\) is the mean of the new batch of data
  • \(\sigma_a^2\) / \(s_a^2\) is the variance of all the values seen so far
  • \(\sigma_b^2\) / \(s_b^2\) is the variance of the new batch of data

The updated mean is a linear combination of the mean over the seen data and the mean over the new data, each weighted by its relative sample size.

The updated variance is a weighted combination of the variance within each group (the seen data and the new data) plus a between-group term that accounts for the difference between the two means.

```python
import torch

def update_mean(a: int, b: int, mean_a: float, mean_b: float) -> float:
    # eq 4: means weighted by their relative sample sizes
    return (a * mean_a + b * mean_b) / (a + b)

def update_variance(a: int, b: int, mean_a: float, mean_b: float, var_a: float, var_b: float) -> float:
    # within-group variance plus between-group mean correction
    within_var = a / (a + b) * var_a + b / (a + b) * var_b
    between_var = a * b * (mean_a - mean_b) ** 2 / (a + b) ** 2
    return within_var + between_var

def update_sample_variance(a: int, b: int, mean_a: float, mean_b: float, var_a: float, var_b: float) -> float:
    within_sample_var = (a - 1) * var_a + (b - 1) * var_b
    between_sample_var = a * b / (a + b) * (mean_a - mean_b) ** 2
    return (within_sample_var + between_sample_var) / (a + b - 1)

X = torch.tensor([[1.0, 2.0], [1.0, 2.0], [4.0, 5.0]])

# while not strictly correct, we can initialize the mean and variances to 0
# rather than NaN: the a-weighted terms vanish when count == 0
count = 0
mean = 0.0
var = 0.0
sample_var = 0.0

for batch in X:
    mean_updated = update_mean(count, batch.numel(), mean, batch.mean())
    var_updated = update_variance(count, batch.numel(), mean, batch.mean(), var, batch.var(correction=0))
    sample_var_updated = update_sample_variance(count, batch.numel(), mean, batch.mean(), sample_var, batch.var())

    count += batch.numel()
    mean = mean_updated
    var = var_updated
    sample_var = sample_var_updated

print(f"mean = {mean}")
print(f"var = {var}")
print(f"sample_var = {sample_var}")
```

```
mean = 2.5
var = 2.25
sample_var = 2.700000047683716
```

# Derivation

## Batch Updating the Mean

\[ \begin{align} \mu &= \frac{1}{n} \sum_{i=1}^n x_i \tag{1} \\ n \mu &= \sum_{i=1}^n x_i \tag{2} \end{align} \]

We can think of combining the two batches as concatenating two index ranges. If batch A spans \([1, a]\) and batch B spans \([1, b]\), then appending batch B to batch A shifts B’s indices to \([a+1, a+b]\), and the combined range is \([1, a+b]\), where the total number of samples observed is \(n=a+b\).

\[ \sum_{i=1}^{a+b} x_i = \sum_{i=1}^{a} x_i + \sum_{i=a+1}^{a+b} x_i \tag{3} \]

Knowing this we can express the combined mean as:

\[ \begin{align} \mu &= \frac{1}{a+b} \sum_{i=1}^{a+b} x_i \\ &= \frac{1}{a+b} \left( \sum_{i=1}^{a} x_i + \sum_{i=a+1}^{a+b} x_i \right) \tag*{split the sum using eq 3} \\ &= \frac{1}{a+b} \Big( a\mu_a + b\mu_b \Big) \tag*{substitute with eq 2} \\ \Aboxed{ \mu &= \frac{a}{a+b} \mu_a + \frac{b}{a+b} \mu_b } \tag{4} \end{align} \]
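
As a quick numeric check of eq 4, take \(A = [1, 2, 1, 2]\) (the first two batches of the running example) as the values seen so far and \(B = [4, 5]\) as the new batch:

```python
import torch

A = torch.tensor([1.0, 2.0, 1.0, 2.0])  # values seen so far
B = torch.tensor([4.0, 5.0])            # new batch
a, b = A.numel(), B.numel()

combined_mean = a / (a + b) * A.mean() + b / (a + b) * B.mean()  # eq 4
print(combined_mean)             # tensor(2.5000)
print(torch.cat([A, B]).mean())  # tensor(2.5000)
```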


## Batch Updating the Variance

For the variance, we’ll first derive how to combine two population variances and from there work out how to combine two sample variances.

### Population Variance

Let’s start by simplifying the population variance equation.

\[ \begin{align} \sigma^2 &= \frac{1}{n} \sum_{i=1}^n (x_i - \mu)^2 \\ &= \frac{1}{n} \sum_{i=1}^n \left( x_i^2 - 2\mu x_i + \mu^2 \right) \tag*{expand the square} \\ &= \frac{1}{n} \left[ \sum_{i=1}^n x_i^2 - 2\mu \sum_{i=1}^n x_i + n\mu^2 \right] \tag*{distribute the sum} \\ &= \frac{1}{n} \sum_{i=1}^n x_i^2 - 2\mu \frac{1}{n} \sum_{i=1}^n x_i + \mu^2 \tag*{distribute the 1/n} \\ &= \frac{1}{n} \sum_{i=1}^n x_i^2 - 2\mu^2 + \mu^2 \tag*{substitute eq 1} \\ \sigma^2 &= \frac{1}{n} \sum_{i=1}^n x_i^2 - \mu^2 \tag{5} \\ \sigma^2 + \mu^2 &= \frac{1}{n} \sum_{i=1}^n x_i^2 \tag{6} \end{align} \]
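
Eq 5 is easy to verify numerically on the example data (the values here are small, so the shortcut form is exact; as shown earlier, it can be unstable for large values):

```python
import torch

X = torch.tensor([1.0, 2.0, 1.0, 2.0, 4.0, 5.0])

# eq 5: variance as mean of squares minus square of the mean
print((X ** 2).mean() - X.mean() ** 2)  # tensor(2.2500)
print(X.var(correction=0))              # tensor(2.2500)
```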

We can follow a similar process as we did when working out the combined mean, starting from equation 5.

\[ \begin{align} \sigma^2 &= \frac{1}{a+b} \sum_{i=1}^{a+b} x_i^2 - \mu^2 \\ &= \frac{1}{a+b} \left( \sum_{i=1}^{a} x_i^2 + \sum_{i=a+1}^{a+b} x_i^2 \right) - \mu^2 \tag*{split the sum using eq 3} \\ &= \frac{1}{a+b} \sum_{i=1}^{a} x_i^2 + \frac{1}{a+b} \sum_{i=a+1}^{a+b} x_i^2 - \mu^2 \tag*{distribute} \\ &= \frac{1}{a+b} \frac{a}{1} \underbrace{\frac{1}{a} \sum_{i=1}^{a} x_i^2}_{\sigma_a^2 + \mu_a^2} + \frac{1}{a+b} \frac{b}{1} \underbrace{\frac{1}{b} \sum_{i=a+1}^{a+b} x_i^2}_{\sigma_b^2 + \mu_b^2} - \mu^2 \tag*{multiply by one} \\ &= \frac{a}{a+b} \left( \sigma_a^2 + \mu_a^2 \right) + \frac{b}{a+b} \left( \sigma_b^2 + \mu_b^2 \right) - \mu^2 \tag*{substitute using eq 6} \\ &= \frac{a}{a+b} \left( \sigma_a^2 + \mu_a^2 \right) + \frac{b}{a+b} \left( \sigma_b^2 + \mu_b^2 \right) - \left( \frac{a}{a+b} \mu_a + \frac{b}{a+b} \mu_b \right)^2 \tag*{substitute with eq 4} \\ \Aboxed{ \sigma^2 &= \underbrace{\frac{a}{a+b} \sigma_a^2 + \frac{b}{a+b} \sigma_b^2}_{\text{within group variance}} + \underbrace{\frac{ab}{(a+b)^2} \big( \mu_a - \mu_b \big)^2}_{\text{between group variance}} } \tag*{simplify} \end{align} \]
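
A quick numeric check of the boxed result, again treating the first four values of the running example as the seen data \(A\) and the last two as the new batch \(B\):

```python
import torch

A = torch.tensor([1.0, 2.0, 1.0, 2.0])
B = torch.tensor([4.0, 5.0])
a, b = A.numel(), B.numel()

within = a / (a + b) * A.var(correction=0) + b / (a + b) * B.var(correction=0)
between = a * b / (a + b) ** 2 * (A.mean() - B.mean()) ** 2
print(within + between)                     # tensor(2.2500)
print(torch.cat([A, B]).var(correction=0))  # tensor(2.2500)
```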

### Sample Variance

To find the sample variance, \(s^2\), we can use Bessel’s correction, \(\frac{n}{n-1}\):

\[ \begin{align} s^2 &= \frac{n}{n-1}\sigma^2 \\ \frac{n-1}{n} s^2 &= \sigma^2 \tag{7} \end{align} \]
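
Eq 7 is just a rescaling between the two variances, which we can verify on the full dataset:

```python
import torch

X = torch.tensor([1.0, 2.0, 1.0, 2.0, 4.0, 5.0])
n = X.numel()

# eq 7: sigma^2 = (n - 1) / n * s^2
print((n - 1) / n * X.var())  # tensor(2.2500)
print(X.var(correction=0))    # tensor(2.2500)
```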

Applying eq 7 to the combined data and to each batch gives:

\[ \begin{align} \sigma^2 &= \frac{a+b-1}{a+b} s^2 \\ \sigma_a^2 &= \frac{a-1}{a} s_a^2 \\ \sigma_b^2 &= \frac{b-1}{b} s_b^2 \end{align} \]

We can then substitute these into the combined population variance to find the sample variance:

\[ \begin{align} \sigma^2 &= \frac{a}{a+b} \sigma_a^2 + \frac{b}{a+b} \sigma_b^2 + \frac{ab}{(a+b)^2} \big( \mu_a - \mu_b \big)^2 \\ &= \frac{1}{a+b} \left[ a\sigma_a^2 + b \sigma_b^2 + \frac{ab}{a+b} \big( \mu_a - \mu_b \big)^2 \right] \tag*{factor out} \\ \frac{a+b-1}{a+b} s^2 &= \frac{1}{a+b} \left[ a \left(\frac{a-1}{a} s_a^2 \right) + b \left( \frac{b-1}{b} s_b^2 \right) + \frac{ab}{a+b} \big( \mu_a - \mu_b \big)^2 \right] \tag*{substitute with eq 7} \\ \Aboxed{ s^2 &= \frac{1}{a+b-1} \left[ (a-1) s_a^2 + (b-1) s_b^2 + \frac{ab}{a+b} \big( \mu_a - \mu_b \big)^2 \right] } \tag*{simplify} \end{align} \]
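
And a final numeric check of the boxed sample variance against the running example:

```python
import torch

A = torch.tensor([1.0, 2.0, 1.0, 2.0])
B = torch.tensor([4.0, 5.0])
a, b = A.numel(), B.numel()

s2 = ((a - 1) * A.var() + (b - 1) * B.var()
      + a * b / (a + b) * (A.mean() - B.mean()) ** 2) / (a + b - 1)
print(s2)                       # tensor(2.7000)
print(torch.cat([A, B]).var())  # tensor(2.7000)
```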


These formulas for updating the mean and variance are particularly useful for streaming or out-of-core datasets: only the summary statistics (count, mean, variance) need to be maintained rather than every value, making the computation memory-efficient, numerically stable, and parallelizable.
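
Parallelizable because each worker can summarize its batch independently and the summaries can be folded together afterwards. Below is a minimal sketch of that map-reduce pattern; `summarize` and `merge` are illustrative helpers, not a library API.

```python
import torch
from functools import reduce

def summarize(batch):
    # the "map" step: each worker computes (count, mean, var) for its batch
    return batch.numel(), batch.mean(), batch.var(correction=0)

def merge(s, t):
    # the "reduce" step: fold two summaries together using the boxed formulas
    a, mean_a, var_a = s
    b, mean_b, var_b = t
    n = a + b
    mean = (a * mean_a + b * mean_b) / n
    var = a / n * var_a + b / n * var_b + a * b / n**2 * (mean_a - mean_b) ** 2
    return n, mean, var

batches = [torch.tensor(t) for t in ([1.0, 2.0], [1.0, 2.0], [4.0, 5.0])]
count, mean, var = reduce(merge, map(summarize, batches))
print(count, mean, var)  # 6 tensor(2.5000) tensor(2.2500)
```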

# Exercises Left to the Reader

Now that you know how to derive batch updates for the mean and variance, you should be well-equipped to derive them for the covariance, \(\text{Cov}\), and the Pearson correlation, \(\rho\):

\[ \begin{align} \text{Cov}(X, Y) &= \frac{1}{n} \sum_{i=1}^n (x_i - \bar{x})(y_i - \bar{y}) \\[10pt] \rho(X, Y) &= \frac{\sum_{i=1}^n (x_i - \bar{x})(y_i - \bar{y})} {\sqrt{ \sum_{i=1}^n (x_i - \bar{x})^2 \sum_{i=1}^n (y_i - \bar{y})^2 }} \\[10pt] \rho(X, Y) &= \frac{\text{Cov}(X, Y)}{\sqrt{\text{Var}(X) \text{Var}(Y)}} \end{align} \]

# Further Reading

  • See this LeetCode problem about calculating the median from a data stream.
  • torcheval is a library for calculating metrics on out-of-core datasets.

# References

Schott, James R. 2016. Matrix Analysis for Statistics. 3rd ed. Wiley.