I want to fine-tune some LLMs on my own dataset, which contains very long examples (a little over 2048 tokens). VRAM usage jumps by several GB just from increasing the Cutoff Length from 512 to 1024.
Is there a way to feed those long examples into the models without increasing vRAM significantly?
With standard attention, VRAM scales roughly quadratically with sequence length. I'm not aware of a real fix. Even efficient long-context fine-tuning methods such as LongLoRA mainly improve speed and quality but leave memory usage about the same as plain LoRA.
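To put rough numbers on that, here is a back-of-the-envelope estimate of the attention-score activations alone, assuming a 7B-class model with 32 layers and 32 heads, fp16 scores, batch size 1, and naive (non-Flash) attention that materializes the full score matrices:

```python
# Rough attention-score memory for a 7B-class model
# (assumed: 32 layers, 32 heads, fp16, batch size 1, naive attention).
def attn_score_bytes(seq_len, layers=32, heads=32, bytes_per=2, batch=1):
    return batch * layers * heads * seq_len * seq_len * bytes_per

for s in (512, 1024, 2048):
    print(f"{s:>4} tokens -> {attn_score_bytes(s) / 2**30:.1f} GiB")
# 512 -> 0.5 GiB, 1024 -> 2.0 GiB, 2048 -> 8.0 GiB
```

That quadratic jump is consistent with the several-GB increase you saw going from 512 to 1024. FlashAttention avoids materializing these matrices, which is why it's on the list below.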
I recommend making sure you're reducing memory in every other way you can (a combined sketch follows the list):
Ensure you’re using 4-bit QLoRA
Ensure batch size is 1
Ensure you’re using FlashAttention-2
Ensure your optimizer state is kept in CPU memory by using a paged optimizer.
Use gradient checkpointing.
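Something like the following pulls all five of those settings together. This is a sketch using the Hugging Face transformers/peft stack; the model id, LoRA rank, and target modules are placeholders you'd swap for your own setup:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder; use your base model

# 4-bit QLoRA quantization (NF4 + double quantization)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",  # requires the flash-attn package
    torch_dtype=torch.bfloat16,
)
model = prepare_model_for_kbit_training(model)  # sets up the quantized model for training

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # placeholder selection
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

training_args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,   # batch size 1
    gradient_accumulation_steps=16,  # recover an effective batch size without extra VRAM
    gradient_checkpointing=True,     # recompute activations instead of storing them
    optim="paged_adamw_8bit",        # paged optimizer: state can spill to CPU RAM under pressure
    bf16=True,
)
```

Gradient accumulation isn't on the list above, but it's the usual companion to a batch size of 1 if you still want stable gradients.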
You could also try something more experimental, like using Mistral with a sliding window of 1024 tokens so the model sees 2048 tokens of context while keeping attention memory closer to that of a 1024-token window.
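If you go that route, a minimal sketch with the Hugging Face Mistral implementation would look like the snippet below. The model id is an assumption, and note that whether the window is actually enforced depends on your transformers version and attention backend; it is honored with FlashAttention-2:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "mistralai/Mistral-7B-v0.1"  # assumed base model

# Shrink the sliding attention window from the default so each token
# attends to at most the previous 1024 tokens, even in 2048-token examples.
config = AutoConfig.from_pretrained(model_id)
config.sliding_window = 1024

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # window is enforced by the flash-attn backend
)
```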
Or you could just summarize or prune your long examples.