Torch Tensor: broadcasting
Intuitions for how to prepare broadcastable tensors.
I’ve found existing write-ups, like Philippe Adjiman’s blog, unsatisfying for building intuition about torch tensor broadcasting, so I decided to write up this post.
I think most broadcasting problems fall into two high-level buckets:
- Matching dimensions: the output tensor preserves the input axes. Broadcasting here fills in missing axes.
- Pairwise expansion: the output tensor has a new axis (e.g., outer product, pairwise difference). Broadcasting here computes pairwise interactions.
Broadcasting 101
Broadcasting is a fundamental concept in PyTorch that allows element-wise operations between tensors of different shapes, under one condition: when the shapes are aligned from their trailing dimensions, each pair of dimensions must either match or include a size-1 dimension that PyTorch can stretch to make them match.
import torch
a = torch.tensor([[1], [2], [3]]) # shape: (3, 1)
b = torch.tensor([10, 20, 30]) # shape: (3,) → broadcast to (1, 3) → (3, 3)
result = a + b
# tensor([[11, 21, 31],
# [12, 22, 32],
# [13, 23, 33]])
New Dimension
Oftentimes, we need to prepare “broadcastable” tensors for an operation. We can use None to introduce a new size-1 dimension, which leads to either a wrap-around or a wrap-within outcome.
X = [x, y]
# operation 1: wrap within (each element is now an array)
X[:, None] -> [[x], [y]] # shape (2, 1)
# operation 2: wrap around (the whole array is wrapped by another array)
X[None, :] -> [[x, y]] # shape: (1, 2)
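A quick runnable check of both wrappings, using an illustrative 2-element tensor:

```python
import torch

X = torch.tensor([1, 2])    # shape: (2,)

wrap_within = X[:, None]    # each element wrapped: [[1], [2]]
wrap_around = X[None, :]    # whole array wrapped: [[1, 2]]

print(wrap_within.shape)    # torch.Size([2, 1])
print(wrap_around.shape)    # torch.Size([1, 2])
```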
Broadcasting Usage A: Matching dimensions
This is the common case for LLMs. You have a “main” tensor and a smaller tensor that should repeat across one or more axes. Examples include adding a bias to logits, applying a key padding mask, etc.
Here is an example of adding bias to every token in the corresponding batch item.
# input
x.shape = [B, T, Z]
bias.shape = [B, Z]
# expected output.shape
[B, T, Z] # answer is x + bias[:, None, :]
So within each batch item (along dimension B), we want to add the length-Z bias vector to every token’s length-Z feature vector. The problem is that bias has no token dimension T, so PyTorch can’t align it against x as-is. We need a wrap-within after the B dimension to introduce a size-1 T axis:
bias[:, None, :] # shape: [B, 1, Z]
Now the shapes are broadcastable and the answer is therefore x + bias[:, None, :].
x: B × T × Z
bias[:, None, :]: B × 1 × Z
result: B × T × Z ✓
and this is what happens to bias:
# x has shape [2, 2, 3]
x = torch.tensor([
[[tok0x0, tok0x1, tok0x2], [tok1x0, tok1x1, tok1x2]], # batch 0
[[tok0y0, tok0y1, tok0y2], [tok1y0, tok1y1, tok1y2]], # batch 1
])
# before wrapping, bias has shape [2, 3]
bias = torch.tensor([
[0, 1, 2], # batch 0
[4, 5, 6,], # batch 1
])
# after wrapping (wrap-within: a new size-1 `T` axis)
# bias now has shape [2, 1, 3]
new_bias = bias[:, None, :]
new_bias = torch.tensor([
[[0, 1, 2]], # batch 0
[[4, 5, 6]], # batch 1
])
# during `x + new_bias`, under-the-hood broadcasting
# conceptually stretches new_bias along T into
new_bias = torch.tensor([
[[0, 1, 2], [0, 1, 2]], # batch 0
[[4, 5, 6], [4, 5, 6],], # batch 1
])
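The same walkthrough with concrete numbers, as a minimal runnable sketch (the sizes B=2, T=2, Z=3 and values are illustrative, matching the symbolic example above):

```python
import torch

B, T, Z = 2, 2, 3                       # illustrative sizes
x = torch.arange(B * T * Z).reshape(B, T, Z).float()
bias = torch.tensor([[0., 1., 2.],
                     [4., 5., 6.]])     # shape: (B, Z)

out = x + bias[:, None, :]              # bias stretched along T

print(out.shape)                        # torch.Size([2, 2, 3])
```

Every token in batch item b ends up with bias[b] added, which is exactly the stretched new_bias shown above.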
Broadcasting Usage B: Pairwise Expansion
This is the case where the output has a new axis that didn’t exist in either input. Instead of repeating one tensor across an existing axis, we’re computing every interaction between two sets of elements — think pairwise distance, attention scores, or outer products.
Here is an example of computing pairwise differences between two sets of vectors.
# input
A.shape = [N, Z] # N vectors
B.shape = [M, Z] # M vectors
# expected output.shape
[N, M, Z] # answer is A[:, None, :] - B[None, :, :]
For each of the N vectors in A, we want to subtract every one of the M vectors in B. The output has a brand new M axis that came from the interaction — neither input had both N and M at the same time.
The trick is to use both None operations together: wrap-within A to give it a size-1 M slot, and wrap-around B to give it a size-1 N slot:
A[:, None, :]: N × 1 × Z
B[None, :, :]: 1 × M × Z
result: N × M × Z ✓
and this is what happens to each tensor:
# A has shape [2, 3]: 2 vectors of dim 3
A = torch.tensor([
[a0, a1, a2], # vector 0
[b0, b1, b2], # vector 1
])
# B has shape [3, 3]: 3 vectors of dim 3
B = torch.tensor([
[x0, x1, x2], # vector 0
[y0, y1, y2], # vector 1
[z0, z1, z2], # vector 2
])
# wrap-within A: introduces a size-1 M slot
# A[:, None, :] has shape [2, 1, 3]
A_expanded = torch.tensor([
[[a0, a1, a2]], # vector 0, one slot
[[b0, b1, b2]], # vector 1, one slot
])
# wrap-around B: introduces a size-1 N slot
# B[None, :, :] has shape [1, 3, 3]
B_expanded = torch.tensor([
[[x0, x1, x2],
[y0, y1, y2],
[z0, z1, z2]],
])
# under-the-hood broadcasting stretches both before subtracting
# A_expanded becomes shape [2, 3, 3]
A_expanded = torch.tensor([
[[a0, a1, a2], [a0, a1, a2], [a0, a1, a2]], # vector 0 repeated for each B
[[b0, b1, b2], [b0, b1, b2], [b0, b1, b2]], # vector 1 repeated for each B
])
# B_expanded becomes shape [2, 3, 3]
B_expanded = torch.tensor([
[[x0, x1, x2], [y0, y1, y2], [z0, z1, z2]], # all B vectors, for A vector 0
[[x0, x1, x2], [y0, y1, y2], [z0, z1, z2]], # all B vectors, for A vector 1
])
# result[i, j] = A[i] - B[j], shape [2, 3, 3]
result = A_expanded - B_expanded
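A runnable version of the walkthrough with concrete numbers (the values are illustrative):

```python
import torch

A = torch.tensor([[0.,  0.,  0.],
                  [10., 10., 10.]])     # shape: (2, 3), N = 2 vectors
B = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]])        # shape: (3, 3), M = 3 vectors

result = A[:, None, :] - B[None, :, :]  # shape: (2, 3, 3)

# result[i, j] == A[i] - B[j]
print(result[0, 1])                     # tensor([-4., -5., -6.])
```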
Key contrast with Usage A: In matching dimensions, only one tensor needs a None; the smaller one gets stretched into the existing axes of the larger one. In pairwise expansion, both tensors get a None, each surrendering a size-1 slot to the other, producing a new axis that neither originally had.
Complex use of pairwise expansion: Attention
In attention, the score matrix is computed as $QK^T$, where:
Q.shape = [B, H, T, Z] # T query vectors
K.shape = [B, H, T, Z] # T key vectors
scores.shape = [B, H, T, T] # every query paired with every key
The output [B, H, T, T] has a new T×T block that didn’t exist in either input — for each of the T query positions, you’re computing a score against every one of the T key positions. That query-key cross term is exactly the pairwise expansion pattern.
# Q: [B, H, T, Z] → [B, H, T, 1, Z]
# K: [B, H, T, Z] → [B, H, 1, T, Z]
# scores = (Q[:,:,:,None,:] * K[:,:,None,:,:]).sum(-1)
# shape: [B, H, T, T]
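As a sanity check, the broadcasting formulation can be verified against the usual matmul form of $QK^T$ (the sizes below are arbitrary illustrative choices):

```python
import torch

B, H, T, Z = 2, 4, 5, 8                 # illustrative sizes
Q = torch.randn(B, H, T, Z)
K = torch.randn(B, H, T, Z)

# pairwise expansion: a size-1 key slot in Q, a size-1 query slot in K
scores_bc = (Q[:, :, :, None, :] * K[:, :, None, :, :]).sum(-1)

# equivalent matmul form
scores_mm = Q @ K.transpose(-2, -1)

print(torch.allclose(scores_bc, scores_mm, atol=1e-5))  # True
```

The matmul form is what you'd use in practice (it avoids materializing the [B, H, T, T, Z] intermediate), but the broadcasting form makes the pairwise-expansion structure explicit.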