LLM training can be much cheaper than people generally thought
JetMoE is a recent Large Language Model (LLM) that supposedly outperforms LLaMA2-7B from Meta AI and was trained for two weeks on a 96×H100 GPU cluster, spending only ~$80,000…
But how much does it actually cost to train an LLM?
Training costs
A first oddity is that the JetMoE article does not explicitly mention the training costs of any other model for comparison with its own. Moreover, according to this page, the Llama2-7B model requires less than $85,000 to train – if that were the case, what would be the big economic benefit of JetMoE? Where did the JetMoE team get their training-cost figures for Llama2-7B, and why didn't they publish them for a direct comparison?
In any case, the JetMoE article reports training costs in GPU hours (specifically, Nvidia H100 GPU hours): JetMoE's training cost 30,000 H100 GPU hours. For comparison, a Microsoft "optimized version of the Llama 2 model" reports the figures in the table below, expressed in A100 GPU hours (the H100's overall performance is better than that of the previous-generation A100)…
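As a sanity check on the numbers quoted so far: dividing the reported ~$80,000 budget by 30,000 H100 GPU hours gives an effective rate of roughly $2.7 per GPU hour, plausible for discounted or reserved cloud H100 capacity; and two weeks on 96 GPUs (14 × 24 × 96 ≈ 32,000 GPU hours) is consistent with the 30,000-hour figure.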
Meta's largest LLaMA model, as of March 2023, used 2,048 Nvidia A100 GPUs to train on 1.4 trillion tokens (750 words is about 1,000 tokens), taking about 21 days: the cost was over $2.4 million. Analysts and technologists estimate that the critical process of training a large language model such as OpenAI's GPT-3 could cost more than $4 million. You can find these numbers here.
Training GPT-4 reportedly cost over $100 million (here).
JetMoE-8B was trained for less than $0.1 million, yet it outperforms LLaMA2-7B from Meta AI, a company with multi-billion-dollar training resources. LLM training can be much cheaper than people generally thought.
JetMoE-8B is very open and academia-friendly because:
It only uses public datasets for training, and the code is open-sourced. No proprietary resource is needed.
It can be fine-tuned with a very limited compute budget (e.g., a consumer-grade GPU) that most labs can afford.
JetMoE-8B has only 2.2B active parameters during inference, which drastically lowers the computational cost. Compared to a model with similar inference computation, like Gemma-2B, JetMoE-8B achieves consistently better performance.
How JetMoE works
The JetMoE architecture is illustrated in the following figure.
The JetMoE architecture takes advantage of sparse activation in both the attention and feed-forward layers, significantly reducing training and inference costs.
Let x be the input vector and consider a learnable matrix Wr that controls the routing. The routing output s is then
s = Wr x .
The Sparse Mixture of Experts (SMoE) output y is represented by a relation of the type
y = g1 · f1(x) + g2 · f2(x) + · · · + gn · fn(x) .
It is just a weighted combination of n experts (these are normally 2-layer MLPs or, in the case of Mixture of Attention, constructs of the type illustrated below), represented by the functions fi with i = 1, 2, …, n. The "weights" gi are obtained by taking the softmax of the top-k logits of s and setting the rest to 0.
In essence, s is a vector whose larger components have a greater influence on the combination defining the output y. The usefulness of this approach lies in the fact that if gi = 0 for several indices i, the corresponding fi(x) are never evaluated, reducing computation cost during both training and inference.
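To make the routing concrete, here is a minimal PyTorch sketch of top-k gating over MLP experts (our own illustrative code, not JetMoE's implementation; the expert sizes and gating details are simplified):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKRouter(nn.Module):
    """s = W_r x; keep the k largest logits, softmax them, zero the rest."""
    def __init__(self, d_model: int, n_experts: int, k: int = 2):
        super().__init__()
        self.w_r = nn.Linear(d_model, n_experts, bias=False)  # W_r
        self.k = k

    def forward(self, x):
        s = self.w_r(x)                                   # routing logits
        topk_vals, topk_idx = torch.topk(s, self.k, dim=-1)
        gates = torch.zeros_like(s)
        gates.scatter_(-1, topk_idx, F.softmax(topk_vals, dim=-1))
        return gates, topk_idx

class SparseMoE(nn.Module):
    """y = sum_i g_i * f_i(x), evaluating only experts with g_i > 0."""
    def __init__(self, d_model: int, n_experts: int = 8, k: int = 2):
        super().__init__()
        self.router = TopKRouter(d_model, n_experts, k)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, 4 * d_model),
                          nn.GELU(),
                          nn.Linear(4 * d_model, d_model))
            for _ in range(n_experts)
        )

    def forward(self, x):                 # x: (batch, d_model)
        gates, topk_idx = self.router(x)
        y = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = (topk_idx == i).any(dim=-1)  # tokens routed to expert i
            if mask.any():                       # unused experts are skipped
                y[mask] += gates[mask, i:i + 1] * expert(x[mask])
        return y

moe = SparseMoE(d_model=64)
print(moe(torch.randn(4, 64)).shape)  # torch.Size([4, 64])
```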
The mechanism of a single attention expert is illustrated in the following figure. The matrices Wk and Wv are shared across experts to improve training and inference efficiency, while the matrices Wq and Wo (in orange) vary from one expert to another. ae is obtained by applying standard multi-head attention with RoPE to k, v and qe .
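As an illustrative sketch (again, not the JetMoE source), a single-head version of such an attention expert could look like this, with Wk and Wv built once and shared; multi-head attention and RoPE are omitted for brevity:

```python
import torch
import torch.nn as nn

class AttentionExpert(nn.Module):
    """One attention expert: per-expert W_q and W_o, shared W_k and W_v."""
    def __init__(self, d_model, shared_wk, shared_wv):
        super().__init__()
        self.wq = nn.Linear(d_model, d_model, bias=False)  # per-expert W_q
        self.wo = nn.Linear(d_model, d_model, bias=False)  # per-expert W_o
        self.wk = shared_wk                                # shared W_k
        self.wv = shared_wv                                # shared W_v

    def forward(self, x):                # x: (batch, seq, d_model)
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        # RoPE would be applied to q and k here (omitted)
        attn = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        return self.wo(attn @ v)

d = 64
wk = nn.Linear(d, d, bias=False)  # created once, shared by all experts
wv = nn.Linear(d, d, bias=False)
experts = nn.ModuleList(AttentionExpert(d, wk, wv) for _ in range(8))
print(experts[0](torch.randn(2, 10, d)).shape)  # torch.Size([2, 10, 64])
```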
A little coding
A very concise and quick PyTorch test Jupyter notebook for JetMoE can be found here (warning: you’ll need a lot of GPU memory). Alternatively, you can test the model directly using the Online Demo on Lepton AI (link).
What follows is a brief code review of the Open Release of Grok-1 (link), whose code can be found here. We'll start at the entry point (the main function of the run.py file) and only go through the essential steps – delving into the details would take too much time and effort for a short post.
Intro to Grok-1
Grok is a generative artificial intelligence chatbot developed by xAI, based on a large language model (LLM). The engine powering Grok is Grok-1, a 314 billion parameter Mixture-of-Experts model trained from scratch by xAI, which became open source under the Apache-2.0 license on March 17, 2024, when xAI released the base model weights and network architecture.
The repository README states: "This repository contains JAX example code for loading and running the Grok-1 open-weights model. Make sure to download the checkpoint and place the ckpt-0 directory in checkpoints…". So placing the weights is relatively easy; the difficult part is that the weights file, the result of a very expensive training run, is about 318GB! In fact, the same page warns that, due to the large size of the model, a machine with enough GPU memory is required to test it with the example code… fine, we are only interested in the code for now!
Some model details: a) a base model trained on a large amount of text data, not fine-tuned for any particular task; b) a 314B-parameter Mixture-of-Experts model with 25% of the weights active on a given token; c) trained from scratch by xAI using a custom training stack on top of JAX and Rust in October 2023.
Parameters: 314B
Architecture: Mixture of 8 Experts (MoE)
Experts Utilization: 2 experts used per token
Layers: 64
Attention Heads: 48 for queries, 8 for keys/values
Embedding Size: 6,144
Tokenization: SentencePiece tokenizer with 131,072 tokens
Additional Features:
Rotary embeddings (RoPE)
Supports activation sharding and 8-bit quantization
Maximum Sequence Length (context): 8,192 tokens
Grok's performance is not superior to that of other comparable models. On the Grok blog, xAI justifies this as follows: "It is only surpassed by models that were trained with a significantly larger amount of training data and compute resources like GPT-4. This showcases the rapid progress we are making at xAI in training LLMs with exceptional efficiency". I would never have thought of writing an article about a model with a gargantuan number of parameters and unconvincing results, but at least the code seems very understandable.
Code
The repository contains code that needs only four libraries besides Python: JAX, Haiku, SentencePiece and NumPy. The script run.py simply 1) loads the checkpoint (weights file) and 2) samples from the model on a test input, i.e., after inserting some input text, the model returns a response. Clearly, we are only talking about inference; the training effort is all wrapped up in the cumbersome checkpoint (weights) file.
Main function
The script run.py contains the main function, which is our entry point. A language model configuration (grok_1_model) is initialized with specific parameters: it defines a Transformer model together with its MoE and sharding parameters (sharding is a technique used in distributed computing to partition data across multiple devices or processors, allowing for parallel processing). Then an inference runner is set up using this model configuration ("runner" refers to an instance of a class responsible for executing the language model for inference; it encapsulates functionalities such as loading the model, tokenizing input text, performing inference, and generating output). The InferenceRunner is initialized with parameters such as the pad sizes, the actual runner (an object of the ModelRunner class located in the runners.py file), a name, the checkpoint to load, the tokenizer path, the local mesh configuration (the layout of the device mesh used to shard computation across the local accelerators) and the between-hosts configuration. Finally, the runner is initialized and executed (these two steps are dotted in the following picture and will be explored in the next sections) to generate text based on a given input prompt (inp).
Above is an essential view of the main function. Two parts (highlighted in grey), the Grok-1 model config and the inference runner config, are shown reduced and are expanded below.
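For reference, here is an abridged sketch of that flow. It is heavily trimmed and not verbatim; the parameter values follow the model card listed above, but consult the repository's run.py for the authoritative version:

```python
def main():
    grok_1_model = LanguageModelConfig(
        vocab_size=131072,                 # SentencePiece vocabulary
        model=TransformerConfig(
            emb_size=6144,
            num_layers=64,
            num_q_heads=48,
            num_kv_heads=8,
            num_experts=8,                 # Mixture of 8 Experts
            num_selected_experts=2,        # 2 experts used per token
            # ... sharding and other parameters elided ...
        ),
    )
    inference_runner = InferenceRunner(
        pad_sizes=(1024,),
        runner=ModelRunner(model=grok_1_model, checkpoint_path=CKPT_PATH),
        name="local",
        load=CKPT_PATH,
        tokenizer_path="./tokenizers/tokenizer.model",
        local_mesh_config=(1, 8),          # device mesh on a single host
        between_hosts_config=(1, 1),       # single-host setup
    )
    inference_runner.initialize()          # explored in the next section
    gen = inference_runner.run()           # explored afterwards
    inp = "..."                            # the input prompt
    print(sample_from_model(gen, inp, max_len=100, temperature=0.01))
```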
Inference Runner initialization
After all the necessary parameters are set, the initialize() function of the inference runner object is executed. This triggers a cascading initialization sequence involving multiple regions of code, which is difficult to fully describe in a few lines – we'll try!
inference_runner.initialize() calls the initialize function (1) of the InferenceRunner class, which in turn calls another initialize function (2) of the ModelRunner class.
Overall, the initialize function (1) sets up the necessary components for the inference runner, including the model, tokenizer, and associated operations. It also handles distributed computation and compilation of model functions for efficient execution during inference. Here’s a breakdown of its components:
Initialization of Runner and Data:
The function starts by extracting the runner attribute and initializing a dummy data dictionary (dummy_data) containing placeholders for inputs and targets (in essence, two arrays of zeros).
Initialization of Model and Tokenizer:
The model is initialized using the runner.initialize() method, passing the dummy data and configuration parameters.
The SentencePiece tokenizer is initialized using the provided tokenizer_path.
Extraction of Model Parameters:
Parameters such as max_len (maximum sequence length) and vocab_size are extracted from the model for further use.
Padding Function:
Defines a pad_to_max_len function, which pads sequences to the maximum length expected by the model.
Functions for Model Operations:
Defines several functions (hk_forward, hk_sample_step, hk_new_memory, hk_prefill_memory) to perform various model operations such as forward pass, sampling, memory initialization, and prefilling.
Sharding and Compilation:
Prepares model sharding for distributed computation using jax.tree_util.tree_map_with_path.
Compiles the functions for sampling and prefilling memory using hk.without_apply_rng and pjit.pjit.
Final Setup:
Sets up RNG (random number generator) key for model initialization.
Initializes dummy_tokens for evaluating shapes.
Computes shapes using jax.eval_shape.
Parameter Sharding:
Computes parameter sharding using jax.tree_util.tree_map_with_path and apply_rules from the model’s partition rules.
Compilation of Sampling and Prefill Memory Functions:
Uses pjit.pjit to compile the sampling and prefill memory functions with appropriate input and output sharding configurations.
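To give a flavor of the padding and compilation steps above, here is a minimal, self-contained sketch (ours, not the Grok-1 source) of the pad-then-compile pattern; jax.jit stands in for pjit.pjit, which additionally takes input/output sharding specifications for a device mesh:

```python
import jax
import jax.numpy as jnp
import numpy as np

MAX_LEN = 8192   # maximum context length from the model card
PAD_ID = 0       # illustrative pad token id

def pad_to_max_len(tokens: np.ndarray) -> np.ndarray:
    """Right-pad a 1-D token array to the maximum length the model expects."""
    pad = MAX_LEN - tokens.shape[0]
    return np.pad(tokens, (0, pad), constant_values=PAD_ID)

@jax.jit  # stand-in for pjit.pjit with sharding configurations
def forward(params, tokens):
    # placeholder forward pass; the real model applies the Transformer here
    return jnp.take(params["embedding"], tokens, axis=0).sum(axis=-1)

params = {"embedding": jnp.ones((131072, 16))}
tokens = pad_to_max_len(np.array([1, 2, 3], dtype=np.int32))
print(forward(params, jnp.asarray(tokens)).shape)  # (8192,)
```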
Inference Runner execution
After initialization, inference can begin. The run() method efficiently generates text by sampling from the language model in response to prompts and yields the generated text to the caller when requested. It handles multiple requests concurrently and uses resources efficiently through asynchronous data copying. Let's break down its functionality:
Initialization: The method initializes various parameters and settings required for generating text, such as random number generators (rngs), memory buffers (memory), and sample settings.
Preparation: It prepares a prompt array and settings for sampling. The prompt is padded to a suitable length, and settings like temperature and nucleus probability are set.
Compiling: The method compiles the model for sampling. This might involve precompiling the model for different prompt sizes and compiling the model for actual sampling.
Sampling Loop: The method enters a loop where it continually samples tokens from the model in response to prompts. It yields generated text when requested by the caller.
Asynchronous Copying: During the sampling loop, it asynchronously copies data between devices and hosts to avoid blocking.
Handling Requests: It processes requests for text generation, updating the state accordingly.
Continuation: It continues the sampling loop until interrupted or until all requests are fulfilled.
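The overall shape of this request/yield protocol can be illustrated with a toy Python generator (purely illustrative; the real run() additionally manages batching, memory buffers and device transfers):

```python
def run():
    """Toy generator protocol: the caller primes with next(), then drives
    the loop with send(request); each request yields generated text."""
    while True:
        request = yield                  # wait to receive a prompt
        response = f"echo: {request}"    # stand-in for the real sampling
        yield response                   # hand the result back to the caller

def sample_from_model(server, prompt):
    next(server)             # advance the generator to its first `yield`
    return server.send(prompt)

server = run()
print(sample_from_model(server, "Hello Grok"))  # echo: Hello Grok
```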
Model
The model.py file contains all the architectural features and various support functions. It is about 1,400 lines of code, so we'll try to summarize. We reiterate that the code uses JAX and Haiku, a library that simplifies the process of building and training neural networks in JAX, providing a high-level interface while maintaining the performance benefits of JAX's low-level primitives. It is commonly used in machine-learning research and development where JAX is the preferred framework for its flexibility and performance.
The first lines deal with 8-bit quantized weights: registering them as pytree nodes (enabling efficient processing and manipulation using JAX's array-based operations and transformations), applying sharding constraints depending on the presence of a distributed computing environment, and defining partition rules, specific patterns that enable efficient distribution and parallelization of the transformer's computations across multiple devices or processors.
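A minimal sketch of that idea (not the Grok-1 source; the class name and fields are illustrative) is to store 8-bit weights plus dequantization scales in a container and register it as a JAX pytree so that tree_map and pjit can traverse it:

```python
from dataclasses import dataclass
import jax
import jax.numpy as jnp

@dataclass
class QuantizedWeight8bit:
    weight: jnp.ndarray   # int8 values
    scales: jnp.ndarray   # float scales used to dequantize

# Register the container as a pytree node so JAX transformations see it.
jax.tree_util.register_pytree_node(
    QuantizedWeight8bit,
    lambda qw: ((qw.weight, qw.scales), None),           # flatten
    lambda _, children: QuantizedWeight8bit(*children),  # unflatten
)

def dequantize(qw: QuantizedWeight8bit) -> jnp.ndarray:
    return qw.weight.astype(jnp.float32) * qw.scales

w = QuantizedWeight8bit(
    weight=jnp.array([[127, -128]], dtype=jnp.int8),
    scales=jnp.array([[0.01]], dtype=jnp.float32),
)
print(dequantize(w))  # [[ 1.27 -1.28]]
```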
Then come all the classes strictly related to the definition of the MoE + Transformer architecture: Router, MoE layer, Multi-Head Attention, Decoder, Transformer, Attention Mask, RMS Norm, Rotary Embedding.
Finally, the language model and its configuration are defined. These classes integrate embedding, positional encoding, transformer layers, and decoding logic to generate logits for next-token prediction. They ensure proper handling of padding, masking, and distributed computation when configured.
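As one small example from that list, RMS normalization (used in place of standard LayerNorm) can be sketched in a few lines of JAX (illustrative, not the Grok-1 implementation):

```python
import jax.numpy as jnp

def rms_norm(x, scale, eps=1e-6):
    """Normalize by root mean square over the last axis, then rescale."""
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return x / rms * scale

x = jnp.arange(6.0).reshape(2, 3)
print(rms_norm(x, scale=jnp.ones(3)))
```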
Conclusions
We briefly discussed the Grok-1 code, which is well written and easy to understand. The model requires huge machines to test and is probably still immature for a definitive, high-performance version. At the time of writing, a 1.5 version of Grok already appears to be on the way – again, not state of the art… size isn't everything in artificial intelligence.