I recently implemented the 124M-parameter GPT-2 model in PyTorch and trained it from scratch. I used this project to learn how an LLM is trained. It follows Andrej Karpathy's Zero to Hero series on YouTube, specifically the video "Let's reproduce GPT-2 (124M)".
For anyone curious about my learning journey, I have made incremental Git commits in this GitHub repository (Github_repo), each reflecting what I learned in that iteration. I have also uploaded the trained model weights to this Hugging Face repository (Model).
If this feels too long and you would like me to walk you through it, let me know in the comments or send me a DM. We can plan a call, a YouTube live, or something similar.
Model settings, training process and other info
I trained the model for 1 epoch, which is 10B tokens of the FineWeb-Edu dataset's sample-10BT subset, or 19,073 steps. I followed the exact hyperparameters used in Karpathy's nanoGPT, which in turn come from the GPT-3 paper. I used 3 A100 GPUs with 40 GB of VRAM each, and the epoch took almost 6 hours. Over that epoch, the validation loss dropped from 10.9 to 4.2. I also ran the HellaSwag eval, on which the model performs terribly: about 25% accuracy.
I rented the GPUs from Jarvis Labs (JarvisLabs.ai), which cost me around 2700 (around 4500 if I include the time I spent learning distributed training across multiple GPUs :/).
I also saved the model's state_dict every 50 steps, which I feel is too much data (179 GB). I did this to visualize the transformer's learning process, especially how the attention layers evolve. It will take me time to visualize those checkpoints and understand them in depth.
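A quick back-of-envelope check of the numbers above, as a plain-Python sketch. The helper names are hypothetical, and two inputs are assumptions not stated in this post: a total batch of 524,288 (2^19) tokens per optimizer step, as in Karpathy's setup, and checkpoints that hold only fp32 weights for roughly 124M parameters, with no optimizer state. The sketch shows where 19,073 steps comes from, why a loss near 10.9 is the expected starting point (an untrained model predicts roughly uniformly over GPT-2's 50,257-token vocabulary, which gives a loss of ln(50257) ≈ 10.82), why 25% on HellaSwag is about chance level (each question has 4 candidate endings), and that saving weights every 50 steps lands in the same ballpark as the 179 GB mentioned above.

```python
import math

# Assumptions (not stated in the post): ~0.5M tokens per optimizer step, as in Karpathy's setup.
TOTAL_BATCH_TOKENS = 2**19          # 524,288 tokens per step
DATASET_TOKENS = 10_000_000_000     # ~10B tokens in the FineWeb-Edu sample
VOCAB_SIZE = 50257                  # GPT-2 BPE vocabulary size

def steps_per_epoch(dataset_tokens: int, tokens_per_step: int) -> int:
    """Number of full optimizer steps needed to see the dataset once."""
    return dataset_tokens // tokens_per_step

def uniform_init_loss(vocab_size: int) -> float:
    """Cross-entropy of a model that assigns equal probability to every token."""
    return -math.log(1.0 / vocab_size)

def chance_accuracy(num_choices: int) -> float:
    """Accuracy of random guessing on a multiple-choice benchmark."""
    return 1.0 / num_choices

def checkpoint_storage_gib(num_params: int, num_checkpoints: int, bytes_per_param: int = 4) -> float:
    """Disk space (GiB) for saving fp32 weights only (no optimizer state) num_checkpoints times."""
    return num_params * bytes_per_param * num_checkpoints / 2**30

n_ckpt = math.ceil(19073 / 50)   # one save every 50 steps -> 382 saves

print(steps_per_epoch(DATASET_TOKENS, TOTAL_BATCH_TOKENS))      # 19073
print(round(uniform_init_loss(VOCAB_SIZE), 2))                  # 10.82
print(chance_accuracy(4))                                       # 0.25
print(round(checkpoint_storage_gib(124_000_000, n_ckpt), 1))    # 176.5
```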
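One of the GPT-3-derived hyperparameters that matters most is the learning-rate schedule: a linear warmup followed by cosine decay down to 10% of the peak rate. Below is a minimal plain-Python sketch of that schedule. The specific values (max_lr = 6e-4, min_lr = 0.1 × max_lr, warmup_steps = 715) are assumptions based on Karpathy's build-nanogpt script rather than numbers stated in this post; only max_steps = 19,073 comes from the text above. get_lr is an illustrative helper name.

```python
import math

# Assumed values (based on Karpathy's build-nanogpt; not stated in the post).
MAX_LR = 6e-4
MIN_LR = MAX_LR * 0.1
WARMUP_STEPS = 715
MAX_STEPS = 19073   # one epoch, as in the post

def get_lr(step: int) -> float:
    """Linear warmup to MAX_LR, then cosine decay from MAX_LR down to MIN_LR."""
    if step < WARMUP_STEPS:
        # (step + 1) so that step 0 already gets a small nonzero learning rate
        return MAX_LR * (step + 1) / WARMUP_STEPS
    if step > MAX_STEPS:
        return MIN_LR
    # fraction of the decay phase completed, in [0, 1]
    decay_ratio = (step - WARMUP_STEPS) / (MAX_STEPS - WARMUP_STEPS)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))   # goes from 1 down to 0
    return MIN_LR + coeff * (MAX_LR - MIN_LR)

for s in (0, 714, 715, 9894, 19073):
    print(s, f"{get_lr(s):.2e}")
```

Warmup keeps the early updates small while the randomly initialized model's gradients are still large and noisy; the cosine tail then anneals the rate smoothly instead of dropping it in steps.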
Summary of the flow of things I did from the beginning
- If you have watched Karpathy's video, this part will feel like a rewind.
- Started by loading Hugging Face's GPT-2 model with the pretrained weights released by OpenAI, generating some text, and playing with the model a bit.
- The initial idea was to mirror Hugging Face's implementation and train it, so that the trained weights could be loaded into the HF model for comparison. (This was not achieved, though.)
- Wrote the GPT module as a PyTorch nn.Module, modeled on the HF GPT-2 implementation: it takes a config object holding the hyperparameters, and its parameters have the same names as in the HF model.
- Implemented a simple 50-step training loop on the Tiny Shakespeare dataset, and added code to measure how long each step takes and how many tokens are processed each step.
- Made some small changes to the naive model: tied the weights of the token embedding and the lm_head (the final layer of the transformer), and added a scaling factor to the residual pathways.
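- Then started the GPU gymnastics to improve performance:
- Enabled TF32 for matrix multiplications. PyTorch tensors are float32 by default, and TF32 is not a separate tensor dtype: the tensors stay float32, and setting torch.set_float32_matmul_precision("high") lets the A100's tensor cores run matmuls internally in TF32. TF32 is basically a float with a cut-down mantissa: it keeps float32's 8 exponent bits but only 10 of its 23 mantissa bits. We lose some precision, but matmuls get a lot faster because there are fewer bits to process. We can accept this because we are not computing scientific data where every digit of the mantissa matters a lot. (In simple words, think of the mantissa as the significant digits of a number: with fewer mantissa bits, 12.5234 might be stored as something like 12.52.)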
- To give some perspective, this cut the time per step from 1100 ms to 400 ms.
- Moved from TF32 to bfloat16 (in the video this is done with mixed precision via torch.autocast). bfloat16 is similar but cuts the mantissa even further, to 7 bits, while keeping the same 8-bit exponent, so it covers the same range as float32. The results are less precise but close, and computation is a bit faster.
- After this, one step went from 400 ms to 340 ms.
- Added torch.compile. Until this point, PyTorch runs in eager mode: the Python interpreter executes the model line by line, one operation at a time, and every operation reads its inputs from GPU memory (HBM) and writes its result back. This causes many round trips between GPU memory and the GPU's compute cores, which can be optimized because we already know which operations come next. torch.compile(model) does exactly this: it compiles the whole model and fuses operations into fewer GPU kernels, so data makes far fewer round trips.
- One step went from 340 ms to 150 ms.
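- Replaced the masked attention we had implemented with Flash Attention. Flash Attention is an algorithmically improved implementation of the attention mechanism, designed with parallel computing and the GPU memory hierarchy in mind: it computes attention block by block using an online softmax, so the full T×T attention matrix never has to be written to GPU memory.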
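For the per-step timing and token counting mentioned in the list above, here is a minimal plain-Python sketch. Throughput is tokens per step divided by step time. The micro-batch shape B = 16, T = 1024 is a hypothetical value, since this post does not state it, so only the speedup column comes straight from the reported step times. tokens_per_sec and timed are illustrative helpers. When timing real GPU work, call torch.cuda.synchronize() before reading the end time: CUDA kernels launch asynchronously, so without it the timer can stop before the GPU has finished the step.

```python
import time

def tokens_per_sec(batch_size: int, seq_len: int, step_seconds: float) -> float:
    """Throughput of one training step that processes batch_size * seq_len tokens."""
    return batch_size * seq_len / step_seconds

def timed(fn, *args):
    """Run fn(*args) and return (result, elapsed seconds).
    For GPU work, torch.cuda.synchronize() would be needed before reading t1."""
    t0 = time.perf_counter()
    out = fn(*args)
    t1 = time.perf_counter()
    return out, t1 - t0

# Step times reported in this post (ms) after each optimization.
step_ms = {"baseline fp32": 1100, "tf32": 400, "bf16": 340, "torch.compile": 150}

B, T = 16, 1024   # hypothetical micro-batch shape, for illustration only
for name, ms in step_ms.items():
    tps = tokens_per_sec(B, T, ms / 1000)
    speedup = step_ms["baseline fp32"] / ms
    print(f"{name:>14}: {tps:8.0f} tok/s, {speedup:.2f}x vs baseline")

result, dt = timed(sum, range(1_000_000))
print(f"toy timing: {dt * 1000:.2f} ms")
```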
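For the two model tweaks in the list above (weight tying and residual scaling), two quick calculations in plain Python. Weight tying shares one vocab_size × n_embd matrix between the token embedding and the lm_head; with GPT-2's 50,257-token vocabulary and n_embd = 768, that saves about 38.6M parameters, close to a third of the 124M total. For the residual scaling, the GPT-2 paper scales the init of residual layers by 1/√N, and Karpathy's video uses N = 2 × n_layer because each block adds to the residual stream twice (attention, then MLP). This keeps the residual stream's variance from growing with depth. The simulation treats each layer's contribution as an independent unit-variance random number. That is a simplification, and residual_stream_std is a hypothetical helper, not part of the model code.

```python
import math
import random

VOCAB_SIZE, N_EMBD, N_LAYER = 50257, 768, 12   # GPT-2 124M config

# Weight tying: wte and lm_head share one (VOCAB_SIZE x N_EMBD) matrix.
shared_params = VOCAB_SIZE * N_EMBD
print(f"parameters saved by tying: {shared_params:,}")   # 38,597,376

def residual_stream_std(num_adds: int, scale: float, trials: int = 5000, seed: int = 0) -> float:
    """Std of x = sum of num_adds independent N(0, 1) contributions, each multiplied by scale.
    Stand-in for a residual stream that every sub-layer adds into."""
    rng = random.Random(seed)
    samples = []
    for _ in range(trials):
        x = 0.0
        for _ in range(num_adds):
            x += scale * rng.gauss(0.0, 1.0)
        samples.append(x)
    mean = sum(samples) / trials
    return math.sqrt(sum((s - mean) ** 2 for s in samples) / trials)

n_adds = 2 * N_LAYER   # attention + MLP in each of the 12 blocks
print(residual_stream_std(n_adds, 1.0))              # roughly sqrt(24) ≈ 4.9
print(residual_stream_std(n_adds, n_adds ** -0.5))   # roughly 1.0
```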
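To make the TF32 and bfloat16 steps in the list above concrete: a float32 has 1 sign bit, 8 exponent bits, and 23 mantissa bits. TF32 keeps the 8 exponent bits but only 10 mantissa bits, and bfloat16 keeps the 8 exponent bits and only 7 mantissa bits. The sketch below emulates this in plain Python by zeroing the low mantissa bits of a float32 encoding. truncate_mantissa is a hypothetical helper, and it truncates rather than rounds (hardware conversion typically rounds to nearest), so treat it as an illustration of the precision loss, not an exact model of what the GPU does.

```python
import struct

FLOAT32_MANTISSA_BITS = 23

def truncate_mantissa(x: float, keep_bits: int) -> float:
    """Emulate a lower-precision float by zeroing the low mantissa bits of x's float32 encoding.
    Sign and the 8-bit exponent are untouched, so the representable range stays the same."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    drop = FLOAT32_MANTISSA_BITS - keep_bits
    bits &= ~((1 << drop) - 1) & 0xFFFFFFFF   # clear the lowest `drop` bits
    (y,) = struct.unpack("<f", struct.pack("<I", bits))
    return y

x = 12.5234
print(truncate_mantissa(x, 23))   # float32: ~12.5234
print(truncate_mantissa(x, 10))   # TF32-like: 12.515625
print(truncate_mantissa(x, 7))    # bfloat16-like: 12.5

# Same exponent bits as float32, so large values still fit (fp16 would overflow above ~65504).
print(truncate_mantissa(3.0e38, 7))
```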
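The torch.compile step in the list above gets most of its win from kernel fusion: instead of running each elementwise op as a separate pass that reads from and writes back to GPU memory, it generates fewer kernels that keep intermediates on-chip. The toy sketch below imitates that in plain Python on GPT-2's tanh-approximate GELU. The "eager" version materializes a new list for every op, standing in for a kernel that writes an intermediate tensor to memory, while the "fused" version computes each element in a single pass. gelu_eager and gelu_fused are hypothetical names, and counting passes over the data is only an analogy for memory round trips, not a description of how torch.compile is implemented.

```python
import math

C = math.sqrt(2.0 / math.pi)   # constant in GPT-2's tanh-approximate GELU

def gelu_eager(xs):
    """Unfused: each op is a separate pass that reads one list and writes a new one,
    like separate GPU kernels each writing an intermediate tensor to memory."""
    cube = [x * x * x for x in xs]                         # pass 1
    inner = [x + 0.044715 * c for x, c in zip(xs, cube)]   # pass 2
    scaled = [C * v for v in inner]                        # pass 3
    th = [math.tanh(v) for v in scaled]                    # pass 4
    out = [0.5 * x * (1.0 + t) for x, t in zip(xs, th)]    # pass 5
    return out, 5                                          # (result, passes over the data)

def gelu_fused(xs):
    """Fused: one pass; every intermediate lives only in local variables ("registers")."""
    out = [0.5 * x * (1.0 + math.tanh(C * (x + 0.044715 * (x * x * x)))) for x in xs]
    return out, 1

xs = [i / 10 for i in range(-30, 31)]
eager, eager_passes = gelu_eager(xs)
fused, fused_passes = gelu_fused(xs)
print(eager_passes, fused_passes)                      # 5 1
print(max(abs(a - b) for a, b in zip(eager, fused)))   # 0.0: same math in the same order
```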
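The core trick that lets Flash Attention avoid materializing the full T×T attention matrix is the online softmax: keys and values are processed in blocks, and a running max, a running denominator, and a running weighted sum are rescaled whenever a new block raises the max. Below is a minimal single-query, single-head sketch in plain Python that checks the blocked computation against naive softmax attention, with the causal mask applied by only passing keys up to the query's position. naive_attention and online_attention are hypothetical helpers; the block size and toy dimensions are arbitrary. It only illustrates the math. The real kernel does this in on-chip SRAM for many queries in parallel, and in PyTorch you would normally call torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True), which can dispatch to a Flash Attention kernel.

```python
import math
import random

def naive_attention(q, keys, values):
    """softmax(q·k / sqrt(d))-weighted sum of values, with all scores computed at once."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / z for j in range(len(values[0]))]

def online_attention(q, keys, values, block_size=4):
    """Same result, but keys/values are visited block by block with an online softmax."""
    d = len(q)
    m = -math.inf                   # running max of the scores seen so far
    z = 0.0                         # running softmax denominator, relative to m
    acc = [0.0] * len(values[0])    # running weighted sum of values, relative to m
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in k_blk]
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)   # rescale old statistics to the new max
        z *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - m_new)
            z += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        m = m_new
    return [a / z for a in acc]

rng = random.Random(0)
T, D = 10, 8
keys = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(T)]
values = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(T)]
q = [rng.gauss(0, 1) for _ in range(D)]
t = 6   # causal: the query at position t may only attend to positions 0..t
ref = naive_attention(q, keys[:t + 1], values[:t + 1])
out = online_attention(q, keys[:t + 1], values[:t + 1], block_size=3)
print(max(abs(a - b) for a, b in zip(ref, out)))   # tiny floating-point difference
```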