[BUG] Using fp16 uses more memory than using fp32 #1349
Comments
I tried looking at the internal code for loading the model, and it seems the model is moved to the GPU first and only then converted to fp16. Wouldn't that consume more memory while the model is being loaded? It probably has nothing to do with the steady-state memory usage, but still... Megatron-LM/megatron/training/training.py, line 535:

```python
# GPU allocation.
for model_module in model:
    model_module.cuda(torch.cuda.current_device())

# Fp16 conversion.
if args.fp16 or args.bf16:
    model = [Float16Module(model_module, args) for model_module in model]
```
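To illustrate the ordering concern, here is a minimal sketch (my own illustration, not Megatron code; `build_model`, the layer sizes, and the use of `.half()` as a stand-in for `Float16Module` are all assumptions) comparing peak GPU memory when the module is moved to the GPU in fp32 and then cast to half, versus casting to half on the CPU first:

```python
import torch
import torch.nn as nn

def build_model():
    # Hypothetical stand-in for a real model; layer count and sizes are arbitrary.
    return nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)])

def peak_mib(load_fn):
    # Measure the peak GPU memory allocated while loading the model one way.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    model = load_fn()
    torch.cuda.synchronize()
    peak = torch.cuda.max_memory_allocated() / 2**20
    del model
    return peak

# Ordering used in training.py: move to GPU in fp32, then convert to half.
print(f"cuda() then half(): {peak_mib(lambda: build_model().cuda().half()):.0f} MiB peak")
# Alternative ordering: convert to half on the CPU, then move to GPU.
print(f"half() then cuda(): {peak_mib(lambda: build_model().half().cuda()):.0f} MiB peak")
```

The first ordering transiently holds the full fp32 copy on the GPU before the fp16 copy replaces it, so its peak is higher even though the final allocation is the same.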
I am still looking through the code, but the main difference is that the fp16 optimizer has parameter groups with both fp32 and fp16 parameters, so duplicate memory is probably being used somewhere. I will try to investigate a bit more, but some feedback on this would be appreciated, especially if someone can confirm that their memory usage also increases with fp16.
Maybe the cause of the increased memory is the parameters being detached and cloned in the initialization of the FP16Optimizer class.
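For context, the standard mixed-precision recipe keeps an fp32 "master" copy of every fp16 parameter for the optimizer to update. A minimal sketch of that pattern (my illustration, not the actual Megatron-LM FP16Optimizer code) looks like this, and it is the detach-and-clone referred to above:

```python
import torch

def build_fp32_master_params(fp16_params):
    """Create fp32 master copies of fp16 parameters, as mixed-precision
    optimizers typically do. The fp16 copies are still needed for the
    forward/backward pass, so parameter memory is held twice
    (2 + 4 bytes per element instead of just 4)."""
    master_params = []
    for p in fp16_params:
        master = p.detach().clone().float()  # fp32 copy the optimizer will step
        master.requires_grad_(True)
        master_params.append(master)
    return master_params
```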
The Distributed Optimizer section in the README explains that, when not using zero-1, the per-parameter model state for fp16 is 4 bytes larger (20 - 16 = 4) than that for fp32. It appears that your parameter script does not enable zero-1, and dp_size is set to 1 (with tp_size set to 8).
I am using TP=8 because I was trying to reduce memory usage for a model that barely fits on a single node, so that I could increase the batch size. I am sorry, I don't understand what the 20 and the 16 being subtracted refer to. I would be grateful if you could explain.
The example mentioned above calculates the memory usage of the optimizer state using the Adam optimizer. Your script uses the SGD optimizer, so the calculation method may be different. SGD optimizer:
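To unpack where the 20 and 16 come from, here is my own back-of-the-envelope reading of that accounting (not an authoritative breakdown; the SGD numbers in particular are my assumption, not taken from the README): with mixed precision plus Adam, each parameter carries an fp16 copy, an fp16 gradient, an fp32 main gradient, an fp32 master copy, and fp32 momentum/variance, while pure fp32 training carries the parameter, gradient, momentum, and variance in fp32.

```python
# Rough persistent bytes per parameter, with no optimizer-state sharding.
fp16_adam = 2 + 2 + 4 + 4 + 4 + 4  # fp16 param + fp16 grad + fp32 main grad
                                   # + fp32 master param + fp32 momentum + fp32 variance
fp32_adam = 4 + 4 + 4 + 4          # fp32 param + grad + momentum + variance
print(fp16_adam, fp32_adam, fp16_adam - fp32_adam)  # 20 16 4

# Plain SGD has no momentum/variance state, but the fp16 path still keeps
# the fp32 master copy (and main grad), so it can still come out larger.
fp16_sgd = 2 + 2 + 4 + 4           # fp16 param + fp16 grad + fp32 main grad + fp32 master param
fp32_sgd = 4 + 4                   # fp32 param + fp32 grad
print(fp16_sgd, fp32_sgd)          # 12 8
```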
Describe the bug
Using fp16 or bf16 uses more memory than using fp32
To Reproduce
Here are the training parameters I am using to train the model. When I comment out `--fp16`, the memory usage increases. My setup is 8xH100.
Expected behavior
FP16 should use less memory than FP32.
Stack trace/logs
FP16 MEMORY USAGE
FP32 MEMORY USAGE
Environment (please complete the following information):