TL;DR: Why does GPU memory usage spike during the gradient update step (I can’t account for ~10 GB) and then drop back down?

I’ve been working on fine-tuning some of the larger LMs available on HuggingFace (e.g. Falcon40B and Llama-2-70B), and so far none of my memory estimates add up. I have access to 4 A100-80GB GPUs and was fairly confident that I would have enough VRAM to fine-tune Falcon40B with LoRA, but I keep getting CUDA OOM errors. I have found ways to get things running, but this made me realize I don’t really understand how memory is allocated during training.

Here’s my understanding of where memory goes when you want to train a model:

Setting:

-> I define TOTAL_MEMORY = 0 (MB) and update it as I move through each step that adds memory.

-> I check memory usage with `watch nvidia-smi`, which refreshes every 2 seconds.

-> Model is loaded in fp16

-> Using Falcon7B with ~7B parameters (it’s actually about 6.9B, but close enough)

-> Running on a single A100-80GB GPU in a Jupyter notebook
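Because `watch` only samples nvidia-smi every 2 seconds, it is easy to miss a short spike or attribute it to the wrong step. A minimal sketch, assuming `nvidia-smi` is on PATH, that queries its CSV interface so the numbers can be logged right after each notebook cell; `parse_memory_csv` and `read_gpu_memory` are hypothetical helper names. Keep in mind that this figure covers the whole process (CUDA context plus everything PyTorch’s caching allocator has reserved), not just live tensors.

```python
import subprocess


def parse_memory_csv(text):
    """Parse 'used, total' lines (MiB) from nvidia-smi CSV output, one line per GPU."""
    rows = []
    for line in text.strip().splitlines():
        used, total = (int(field.strip()) for field in line.split(","))
        rows.append((used, total))
    return rows


def read_gpu_memory():
    """Return [(used_mib, total_mib), ...] for every visible GPU via nvidia-smi."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.used,memory.total",
         "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_memory_csv(out)
```

Calling `read_gpu_memory()` before and after each step gives per-GPU `(used, total)` pairs in MiB that can be diffed directly.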

Loading the model:

  • CUDA context and kernels for torch and so on (on my machine I’m seeing about 900 MB per GPU). TOTAL_MEMORY + 900 -> TOTAL_MEMORY = 900
  • Model weights (duh). A 7B-parameter model loaded in float16 takes 2 bytes * 7B parameters = 14B bytes ≈ 14 GB of GPU VRAM. TOTAL_MEMORY + 14_000 -> TOTAL_MEMORY = 15_000 (rounding)

With that, the model should load on a single GPU.
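The weight arithmetic above can be written down as a small sketch (the function names are made up for illustration). One detail worth keeping in mind when comparing against tools: the estimate uses decimal GB (10^9 bytes), while nvidia-smi reports MiB, so the same 14B bytes show up as roughly 13.04 GiB (about 13,351 MiB).

```python
# Bytes per stored value for common dtypes.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1}


def weight_bytes(num_params, dtype="fp16"):
    """Memory for the raw weights only (no CUDA context, no buffers)."""
    return num_params * BYTES_PER_PARAM[dtype]


def to_gb(num_bytes):
    """Decimal gigabytes (10**9 bytes), as used in the estimates above."""
    return num_bytes / 10**9


def to_gib(num_bytes):
    """Binary gibibytes (2**30 bytes); nvidia-smi's MiB figures scale to this unit."""
    return num_bytes / 2**30


fp16_weights = weight_bytes(7_000_000_000, "fp16")
print(to_gb(fp16_weights))             # 14.0
print(round(to_gib(fp16_weights), 2))  # 13.04
```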

Training (I am emulating a single forward and backward step by running each part separately):

  • The data. I am passing in a single small batch of dummy input (random ints), so I will assume it does not contribute substantially to memory usage.
  • Forward pass. For some reason memory jumps by about 1_000 MB. Perhaps this is due to cached intermediate activations? Though I feel like that should be way larger. TOTAL_MEMORY + 1_000 -> TOTAL_MEMORY = 16_000
  • Compute the cross-entropy loss. The loss tensor uses some memory, but not much, so I assume it does not contribute.
  • Compute gradients with respect to the parameters by calling `loss.backward()`. This results in a substantial memory spike (up by 15_000 MB). I imagine this comes from storing a gradient value for every parameter in the model? TOTAL_MEMORY + 15_000 -> TOTAL_MEMORY = 31_000
  • Update the model parameters by calling `optimizer.step()`. This results in yet another memory spike: GPU memory usage goes up by more than 38_000 MB. Not really sure why. My best guess is that this is where AdamW starts storing its two state values (first- and second-moment estimates) for each parameter. If we do the math (assuming the optimizer states are in fp16): 2 bytes * 2 states * 7B = 28B bytes ≈ 28 GB. TOTAL_MEMORY + 38_000 -> TOTAL_MEMORY = 69_000
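The running tally above can be turned into a small calculator. This is a sketch under stated assumptions: gradients and both AdamW states (`exp_avg` and `exp_avg_sq` in PyTorch’s implementation) share the fp16 dtype of the weights, the context and activation sizes are plugged in from the measurements, and `foreach_temp` models a transient buffer of about one parameter-sized copy during the step. The PyTorch optimizer documentation notes that the `foreach` implementation, used by default on CUDA when `foreach` is left unset, needs roughly sizeof(params) of extra peak memory compared with the for-loop version; that is one candidate for a spike that disappears once `optimizer.step()` returns, though nothing here confirms it is the cause in this setup (passing `foreach=False` to `torch.optim.AdamW` would be one way to test it).

```python
GB = 10**9


def training_step_budget(num_params, bytes_per_value=2, context_gb=0.9,
                         activations_gb=1.0, adam_states=2, foreach_temp=True):
    """Running tally (label, cumulative GB) for one full fine-tuning step.

    Assumes gradients and AdamW states share the weights' dtype. Activations are
    kept in the tally as in the post, although autograd frees most saved
    activations as backward proceeds.
    """
    param_gb = num_params * bytes_per_value / GB
    steps = [
        ("CUDA context", context_gb),
        ("weights", param_gb),
        ("forward activations", activations_gb),
        ("gradients", param_gb),
        ("AdamW states", adam_states * param_gb),
    ]
    tally, total = [], 0.0
    for label, gb in steps:
        total += gb
        tally.append((label, round(total, 1)))
    # Assumed transient of ~one parameter-sized copy during optimizer.step(),
    # freed again once the step returns (the foreach case described above).
    peak = total + (param_gb if foreach_temp else 0.0)
    return tally, round(peak, 1)


tally, peak = training_step_budget(7_000_000_000)
for label, cumulative in tally:
    print(f"{label:>20}: {cumulative} GB")
print("peak during step:", peak, "GB")
```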

LoRA would reduce this number by shrinking the memory needed for gradients and the optimizer step, but I have not yet run any tests on that, so I don’t have any numbers.
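To estimate how much LoRA shrinks the gradient and optimizer terms, it helps to count trainable parameters: a rank-r adapter on a `d_in x d_out` linear layer adds `r * (d_in + d_out)` parameters (its A and B matrices), and only these get gradients and AdamW states. The frozen fp16 weights (~14 GB) and the activations flowing through them are still needed. The layer shapes below are assumptions meant to resemble Falcon-7B’s fused `query_key_value` projection (hidden size 4544, 32 layers), and fp32 gradients and states for the adapters are also an assumption; check both against the actual model config and training setup.

```python
def lora_trainable_params(d_in, d_out, rank, num_layers):
    """Params added by one rank-`rank` LoRA adapter (A: d_in x r, B: r x d_out) per layer."""
    return rank * (d_in + d_out) * num_layers


def grads_and_adam_gb(trainable, grad_bytes=4, state_bytes=4, adam_states=2):
    """Gradient + AdamW-state memory (GB) for the trainable parameters only."""
    return trainable * (grad_bytes + adam_states * state_bytes) / 10**9


# Hypothetical shapes meant to resemble Falcon-7B's fused query_key_value layer;
# read the real values from the model config.
trainable = lora_trainable_params(d_in=4544, d_out=4672, rank=8, num_layers=32)
print(trainable)                               # 2359296
print(round(grads_and_adam_gb(trainable), 3))  # 0.028
```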

I believe that’s all the major components.

So where do the extra 10 GB come from? Maybe it’s one of those “torch reserved that memory but isn’t actually using it” cases. To check, I inspected the output of `torch.cuda.memory_allocated()` and `torch.cuda.max_memory_allocated()` to see if there’s something there.

memory allocated (after backward step): 53 GB

max memory allocated: 66 GB

Meaning that at some point an extra 13 GB were needed and then freed.
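The gap between allocated memory, peak allocated memory and what nvidia-smi shows can be pictured with a toy caching allocator. This is a deliberately simplified model, not PyTorch’s real allocator (which manages segments and blocks of different sizes and can fragment): freed memory goes back to a cache that stays reserved, so nvidia-smi keeps reporting it even though no tensor uses it. The PyTorch counterparts of these counters are `torch.cuda.memory_allocated()`, `torch.cuda.max_memory_allocated()` and `torch.cuda.memory_reserved()`, and `torch.cuda.empty_cache()` releases unused cached memory. The numbers in the usage lines are borrowed from the post purely for illustration.

```python
class ToyCachingAllocator:
    """Toy model: freed memory stays cached (reserved) until empty_cache()."""

    def __init__(self):
        self.allocated = 0       # memory held by live tensors
        self.reserved = 0        # memory obtained from the driver (what nvidia-smi sees)
        self.max_allocated = 0   # high-water mark of self.allocated

    def malloc(self, size):
        free_cached = self.reserved - self.allocated
        if size > free_cached:
            # Not enough cached memory: grow the reservation by the shortfall.
            self.reserved += size - free_cached
        self.allocated += size
        self.max_allocated = max(self.max_allocated, self.allocated)

    def free(self, size):
        # Memory returns to the cache, so `reserved` does not shrink.
        self.allocated -= size

    def empty_cache(self):
        # Toy version: hand every unused cached byte back to the driver.
        self.reserved = self.allocated


alloc = ToyCachingAllocator()
alloc.malloc(53)  # long-lived state (weights, grads, optimizer states)
alloc.malloc(13)  # short-lived temporaries
alloc.free(13)
print(alloc.allocated, alloc.max_allocated, alloc.reserved)  # 53 66 66
```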

My question for you folks: does anybody know where the extra 10 GB I can’t find in my math come from? What happens so that 13 GB are freed up after the backward pass? Are there any additional steps that require memory that I missed?

This has been bothering me for a while and I’d love to get a better sense of it, so any expert input, resources or other suggestions will be greatly appreciated!

Edit: I also know that when you train with the `Trainer` class you can enable gradient checkpointing, which reduces memory usage by recomputing some of the intermediate activations during the backward pass instead of storing them. At which part of the whole process would this reduce memory usage?
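Gradient checkpointing targets only the activation term; weights, gradients and optimizer states are unchanged. A rough sketch of the trade-off, assuming each layer’s saved intermediates are `k` times the size of its input hidden state (`k` is a hypothetical constant; the real value depends on the architecture and implementation): without checkpointing, every layer keeps its intermediates until backward reaches it, while with per-layer checkpointing only each layer’s input is kept and one layer’s intermediates are rebuilt at a time during backward. The batch size, sequence length and hidden size in the test values are assumptions, not taken from the post. Both variants scale linearly with batch size and sequence length, which is consistent with a single small batch adding relatively little.

```python
def hidden_state_bytes(batch, seq_len, hidden, bytes_per_value=2):
    """Size of one layer's input hidden state (batch x seq_len x hidden)."""
    return batch * seq_len * hidden * bytes_per_value


def activation_bytes(batch, seq_len, hidden, num_layers, k=16, checkpointing=False):
    """Rough saved-activation memory; `k` is a hypothetical per-layer multiplier."""
    h = hidden_state_bytes(batch, seq_len, hidden)
    if not checkpointing:
        # Every layer keeps ~k hidden-state-sized tensors until backward reaches it.
        return num_layers * k * h
    # Keep only each layer's input; recompute one layer's intermediates at a time.
    return num_layers * h + k * h


no_ckpt = activation_bytes(1, 128, 4544, 32)
ckpt = activation_bytes(1, 128, 4544, 32, checkpointing=True)
print(no_ckpt / 10**9, ckpt / 10**9)  # GB, under the assumed k=16
```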