Working Notes: a commonplace notebook for recording & exploring ideas.
— Kunal
2025-12-28
train a 70b parameter model on 15t tokens on 1024 h100s
total flops needed: 6 * 70b * 15t (6 is the magic constant: 2 FLOPs forward + 4 backward per param per token)
an h100 does 1979e12 FLOP/s (/2 because the listed number assumes sparsity)
mfu is .5
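A quick sanity check of the arithmetic above (back-of-envelope, not a schedule):

```python
# Back-of-envelope training time for 70B params / 15T tokens / 1024 H100s.
params = 70e9
tokens = 15e12
total_flops = 6 * params * tokens        # ~6.3e24 FLOPs

peak_per_gpu = 1979e12 / 2               # spec sheet quotes the sparse number
mfu = 0.5
n_gpus = 1024
achieved = n_gpus * peak_per_gpu * mfu   # effective FLOP/s across the cluster

seconds = total_flops / achieved
days = seconds / 86400
print(f"{days:.0f} days")                # roughly 144 days
```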
memory use
digging into transformers [I need to do this explicitly]
memory usage
tensors
float32 / fp32 / single precision
exponent 8 bit, fraction 23 bit, sign 1 bit
gpt3 one matrix in feedforward: 2.3gb
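Where the 2.3gb comes from (assuming GPT-3 175B's d_model of 12288 and the usual 4x FFN expansion):

```python
# One feedforward weight matrix of GPT-3 175B, stored in fp32.
d_model = 12288
d_ff = 4 * d_model                 # 49152
bytes_fp32 = 4
size = d_model * d_ff * bytes_fp32
print(f"{size / 2**30:.2f} GiB")   # ~2.25 GiB
```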
float16 / half precision
exponent 5 bit / 10 bit fraction / 1 bit sign
dynamic range is pretty bad
large models can have instability with under/overflow with this
bfloat16 (2018, Google)
brain float, by google brain
8 bit exponent, 7 bit fraction, sign
same dynamic range as float32
torch.finfo
typically used for computations as it's good enough
optimizer states and params still need float32
fp8 (2022, nvidia)
very crude: e4m3 and e5m2 options
e4m3 range is [-448, 448]; e5m2 range is [-57344, 57344]
supported by h100s
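The dynamic ranges above fall out of the bit layouts; a sketch deriving the max finite value from (exponent bits, mantissa bits) with the standard IEEE-style bias (torch.finfo reports the same numbers):

```python
# Largest finite value for an IEEE-style float with the given bit layout.
def max_finite(exp_bits, man_bits):
    bias = 2 ** (exp_bits - 1) - 1
    max_exp = 2 ** exp_bits - 2 - bias        # top exponent code reserved for inf/nan
    return (2 - 2 ** -man_bits) * 2.0 ** max_exp

print(max_finite(5, 10))   # fp16: 65504.0
print(max_finite(8, 7))    # bf16: ~3.39e38, same range as fp32
print(max_finite(5, 2))    # e5m2: 57344.0
# e4m3 breaks this convention: it reuses the all-ones exponent for
# ordinary values (only one NaN pattern), so its max is 448, not 240.
```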
float32 is too expensive
generally use mixed precision
want higher precision for anything that's accumulated over time
test with torch.cuda.memory_allocated
tensors are pointers into memory, plus strides for indexing into the matrix
be aware of tensor views
untyped_storage().data_ptr()
is_contiguous()
transpose: not contiguous anymore, cannot take more views
making it contiguous() will force a copy
elementwise operations create new tensors
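The view/contiguity behavior is the same in numpy, which makes for an easy sketch (same ideas carry over to torch tensors):

```python
import numpy as np

a = np.arange(12, dtype=np.float32).reshape(3, 4)
t = a.T                                # transpose is a view: no copy
assert np.shares_memory(t, a)
assert not t.flags["C_CONTIGUOUS"]     # no longer contiguous

c = np.ascontiguousarray(t)            # like torch's .contiguous(): forces a copy
assert c.flags["C_CONTIGUOUS"]
assert not np.shares_memory(c, a)

b = a + 1                              # elementwise ops allocate new memory
assert not np.shares_memory(b, a)
```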
triu is good for causal attention mask
matmul
when multiplying matrices with mismatched shapes, broadcasting just iterates over the missing dimensions
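A small numpy sketch of that broadcasting behavior (torch matmul follows the same rules):

```python
import numpy as np

# Batched matmul broadcasts over the leading (missing) dimensions.
x = np.ones((8, 16, 32))    # (batch, n, d)
w = np.ones((32, 4))        # (d, k): no batch dim
y = x @ w                   # w is reused for every batch element
assert y.shape == (8, 16, 4)
```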
einops
jaxtyping specifies dimensions in types: Float[torch.Tensor, "batch seq head hidden"] -- just documentation
reduce, rearrange
flops
intuition
n points, each with d dims
map to a k dim vector
matmul: every i,j,k triple: one multiplication and one addition
2 times product of all dimensions
crude estimate for order of magnitude
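The triple-counting rule as one line of code (crude, as noted, but good for orders of magnitude):

```python
# FLOPs for (n, d) @ (d, k): one multiply + one add per (i, j, k) triple.
def matmul_flops(n, d, k):
    return 2 * n * d * k

# e.g. projecting 1024 points of dim 512 down to 64 dims:
print(matmul_flops(1024, 512, 64))   # 67108864
```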
hardware is designed for large matrix multiplication
generally only consider regimes where models are dominant
wall clock time
MFU: model flops utilization -- actual flop per sec / promised flop per second
0.5 is quite good
dominated by matmul
ignores all communication / overhead
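MFU from observed throughput, as a sketch (the 1.2e6 tokens/sec here is a made-up number for illustration, not a measurement):

```python
# MFU = achieved FLOP/s / promised FLOP/s, using 6 * params FLOPs per token.
def mfu(tokens_per_sec, params, n_gpus, peak_flops_per_gpu):
    achieved = 6 * params * tokens_per_sec
    return achieved / (n_gpus * peak_flops_per_gpu)

# hypothetical throughput for the 70B / 1024 H100 setup above
print(round(mfu(1.2e6, 70e9, 1024, 989.5e12), 2))   # ~0.5
```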
gradients (backward pass): 4 * total parameters FLOPs per token
forward is 2 * params, which is why an NN's total is roughly 6 * total params per token
this is the bulk of the computation for many models
works for most standard models
initialization
randomness
torch.manual_seed, np.random.seed, random.seed
data loading
np.memmap mapped to a file, loading on demand
optimizer
optimizer memory
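A common per-parameter accounting for mixed-precision training (assumes AdamW with bf16 compute and fp32 master weights; activations come on top of this):

```python
# bf16 params (2) + bf16 grads (2) + fp32 master copy (4) + Adam m (4) + Adam v (4)
bytes_per_param = 2 + 2 + 4 + 4 + 4      # = 16 bytes
params = 70e9
total = params * bytes_per_param
print(f"{total / 1e12:.2f} TB of state")  # ~1.12 TB, before activations
```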
checkpointing