Short code analysis for a huge model
This is a brief code review about the Open Release of Grok-1 (link), whose code is found here. We’ll start at the entry point (the main function of the run.py file) and only go through the essential steps – delving into the details would take too much time and effort for a short post.
Intro to Grok-1
Grok is a generative artificial intelligence chatbot developed by xAI, based on a large language model (LLM). The engine powering Grok is Grok-1, a 314 billion parameter Mixture-of-Experts model trained from scratch by xAI, which became open source under the Apache-2.0 license on March 17, 2024, when xAI released the base model weights and network architecture.
The repository readme states: “This repository contains JAX example code for loading and running the Grok-1 open-weights model. Make sure to download the checkpoint and place the ckpt-0 directory in checkpoints …”. So placing the weights is relatively easy; the difficult part is that the weights file, the result of a very expensive training run, is about 318GB! In fact, the same page warns that, due to the large size of the model, a machine with enough GPU memory is required to test the model with the example code. Fine, but we are only interested in the code for now!
Some model details: a) base model trained on a large amount of text data, not fine-tuned for any particular task; b) 314B parameter Mixture-of-Experts model with 25% of the weights active on a given token; c) trained from scratch by xAI using a custom training stack on top of JAX and Rust in October 2023.
- Parameters: 314B
- Architecture: Mixture of 8 Experts (MoE)
- Experts Utilization: 2 experts used per token
- Layers: 64
- Attention Heads: 48 for queries, 8 for keys/values
- Embedding Size: 6,144
- Tokenization: SentencePiece tokenizer with 131,072 tokens
- Additional Features:
  - Rotary embeddings (RoPE)
  - Supports activation sharding and 8-bit quantization
- Maximum Sequence Length (context): 8,192 tokens
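To make the "2 experts used per token" line concrete, here is a minimal sketch of top-2 routing in plain Python. It is not taken from the Grok-1 code: the function names are hypothetical, the router logits are made up, and renormalizing the two selected weights so they sum to 1 is an assumption (one common choice among several).

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def route_top_k(router_logits, k=2):
    """Return [(expert_index, weight), ...] for the k most probable experts.

    Assumption: the selected weights are renormalized to sum to 1.
    """
    probs = softmax(router_logits)
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    norm = sum(probs[i] for i in top)
    return [(i, probs[i] / norm) for i in top]

NUM_EXPERTS = 8   # from the model details above
SELECTED = 2      # experts used per token

# Made-up router scores for a single token, one per expert.
logits = [0.1, 2.0, -1.0, 0.3, 1.5, -0.5, 0.0, 0.2]
chosen = route_top_k(logits, k=SELECTED)

# Only 2 of the 8 experts run for this token: a quarter of the expert weights.
active_fraction = SELECTED / NUM_EXPERTS
```

With these scores, experts 1 and 4 are selected, and the token's output would be the weighted sum of only those two experts' outputs.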
Grok-1’s performance does not surpass that of the leading models. On the Grok blog, xAI frames it this way: “It is only surpassed by models that were trained with a significantly larger amount of training data and compute resources like GPT-4. This showcases the rapid progress we are making at xAI in training LLMs with exceptional efficiency”. I would never have thought of producing an article on a model with a gargantuan number of parameters and unconvincing results but, at least, the code seems very understandable.
Code
The repository contains code that needs only 4 libraries besides Python itself: JAX, Haiku, SentencePiece and NumPy. The script run.py simply 1) loads the checkpoint (weights file) and 2) samples from the model on a test input, i.e. given some input text, the model returns a response. Clearly, we are only talking about inference; the training efforts are all wrapped up in the cumbersome checkpoint (weights) file.
Main function
The script run.py contains the main function, which is our entry point. A language model configuration (grok_1_model) is initialized using specific parameters. Inside, a Transformer model is defined with its parameters, together with MoE and sharding parameters (sharding is a technique used in distributed computing to partition data across multiple devices or processors, allowing for parallel processing). Then, an inference runner is set up using this model configuration (“runner” refers to an instance of a class responsible for executing the language model for inference; it encapsulates functionalities such as loading the model, tokenizing input text, performing inference, and generating output). The InferenceRunner is initialized with parameters such as pad sizes, the actual runner (an object of the ModelRunner class located in the runners.py file), name, load, tokenizer path, local mesh configuration (this is the configuration of ) and between-hosts configuration. Finally, the runner is initialized and executed (these two steps are dotted in the following picture and will be explored in the next sections) to generate text based on a given input prompt (inp).
Above, an essential view of the main function. There are two reduced parts (highlighted in grey) for the Grok-1 model config and the inference runner config; these parts are expanded below.
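For readers without the picture at hand, the shape of the model configuration can be approximated with two small dataclasses. This is only a sketch: the class and field names are hypothetical stand-ins rather than the repository's own definitions, and the values are the ones listed in the model details earlier in this post.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class TransformerSketchConfig:
    """Hypothetical stand-in for the Transformer part of the config."""
    emb_size: int
    num_q_heads: int
    num_kv_heads: int
    num_layers: int
    num_experts: int
    num_selected_experts: int

@dataclass(frozen=True)
class LanguageModelSketchConfig:
    """Hypothetical stand-in for the language model config."""
    vocab_size: int
    sequence_len: int
    model: TransformerSketchConfig

grok_1_sketch = LanguageModelSketchConfig(
    vocab_size=128 * 1024,      # 131,072 tokens
    sequence_len=8192,          # maximum context length
    model=TransformerSketchConfig(
        emb_size=48 * 128,      # 6,144
        num_q_heads=48,
        num_kv_heads=8,
        num_layers=64,
        num_experts=8,
        num_selected_experts=2,
    ),
)

# With 48 query heads and 8 key/value heads, each key/value head
# is shared by 6 query heads.
queries_per_kv_head = (
    grok_1_sketch.model.num_q_heads // grok_1_sketch.model.num_kv_heads
)
```

Nesting the Transformer settings inside the language model settings mirrors the description above, where the Transformer is defined "inside" the language model configuration.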
Inference Runner initialization
After setting all the necessary parameters, the initialize() function of the inference runner object is executed. This triggers a cascading initialization sequence, involving multiple regions of code, which is difficult to fully describe in a few lines – we’ll try!
inference_runner.initialize() calls the initialize function (1) of the InferenceRunner class. In turn, this initialization function calls another initialize function (2) from the ModelRunner class.
Overall, the initialize function (1) sets up the necessary components for the inference runner, including the model, tokenizer, and associated operations. It also handles distributed computation and compilation of model functions for efficient execution during inference. Here’s a breakdown of its components:
- Initialization of Runner and Data:
  - The function starts by extracting the runner attribute and initializing a dummy data dictionary (dummy_data) containing placeholders for inputs and targets (in essence, two arrays of zeros).
- Initialization of Model and Tokenizer:
  - The model is initialized using the runner.initialize() method, passing the dummy data and configuration parameters.
  - The SentencePiece tokenizer is initialized using the provided tokenizer_path.
- Extraction of Model Parameters:
  - Parameters such as max_len (maximum sequence length) and vocab_size are extracted from the model for further use.
- Padding Function:
  - Defines a pad_to_max_len function, which pads sequences to the maximum length expected by the model.
- Functions for Model Operations:
  - Defines several functions (hk_forward, hk_sample_step, hk_new_memory, hk_prefill_memory) to perform various model operations such as forward pass, sampling, memory initialization, and prefilling.
- Sharding and Compilation:
  - Prepares model sharding for distributed computation using jax.tree_util.tree_map_with_path.
  - Compiles the functions for sampling and prefilling memory using hk.without_apply_rng and pjit.pjit.
- Final Setup:
  - Sets up RNG (random number generator) key for model initialization.
  - Initializes dummy_tokens for evaluating shapes.
  - Computes shapes using jax.eval_shape.
- Parameter Sharding:
  - Computes parameter sharding using jax.tree_util.tree_map_with_path and apply_rules from the model’s partition rules.
- Compilation of Sampling and Prefill Memory Functions:
  - Uses pjit.pjit to compile the sampling and prefill memory functions with appropriate input and output sharding configurations.
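Two of the simpler steps in the breakdown above, the dummy data dictionary and padding to a fixed length, can be illustrated in a few lines of plain Python. This is a sketch with hypothetical helper names, not the repository's code: it uses lists instead of arrays, and truncating over-long inputs from the left is an assumption made for the example.

```python
def make_dummy_batch(batch_size, seq_len):
    """Placeholder inputs and targets: in essence, two arrays of zeros."""
    return {
        "inputs": [[0] * seq_len for _ in range(batch_size)],
        "targets": [[0] * seq_len for _ in range(batch_size)],
    }

def pad_to_size(tokens, size, pad_token=0):
    """Pad a token list on the right up to `size`.

    Assumption: inputs longer than `size` keep only their last `size` tokens.
    """
    if size <= 0:
        raise ValueError("size must be positive")
    tokens = list(tokens)[-size:]
    return tokens + [pad_token] * (size - len(tokens))

dummy_data = make_dummy_batch(batch_size=1, seq_len=4)
padded = pad_to_size([5, 6, 7], size=5)
truncated = pad_to_size([1, 2, 3, 4], size=2)
```

Fixed shapes matter here because compiled functions are specialized to the shapes of their inputs; padding every prompt to one of a few known sizes keeps the number of compilations small.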
Inference Runner running
After initialization, the inference can begin. The run() method generates text by sampling from the language model in response to prompts and yields the generated text to the caller when requested. It handles multiple requests concurrently and makes efficient use of resources through asynchronous data copying. Let’s break down its functionality:
- Initialization: The method initializes various parameters and settings required for generating text, such as random number generators (rngs), memory buffers (memory), and sample settings.
- Preparation: It prepares a prompt array and settings for sampling. The prompt is padded to a suitable length, and settings like temperature and nucleus probability are set.
- Compiling: The method compiles the model for sampling. This might involve precompiling the model for different prompt sizes and compiling the model for actual sampling.
- Sampling Loop: The method enters a loop where it continually samples tokens from the model in response to prompts. It yields generated text when requested by the caller.
- Asynchronous Copying: During the sampling loop, it asynchronously copies data between devices and hosts to avoid blocking.
- Handling Requests: It processes requests for text generation, updating the state accordingly.
- Continuation: It continues the sampling loop until interrupted or until all requests are fulfilled.
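The sampling settings and the "yield when requested" behaviour described above can be sketched with a generator and a small temperature plus nucleus (top-p) sampler. Everything here is a simplified illustration under stated assumptions: the function names, the toy logits function and the single-request loop are hypothetical, and the real method works on batched arrays across devices.

```python
import math
import random

def sample_top_p(logits, temperature=1.0, top_p=0.95, rng=random):
    """Sample a token index using temperature scaling and nucleus filtering."""
    scaled = [v / temperature for v in logits]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the most probable tokens until their cumulative mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]

def run_sketch(next_token_logits, max_new_tokens=3):
    """Generator-based loop: receives a prompt via send(), yields new tokens."""
    prompt = yield  # wait for the first request
    while True:
        tokens = list(prompt)
        for _ in range(max_new_tokens):
            logits = next_token_logits(tokens)
            tokens.append(sample_top_p(logits, temperature=0.01))
        # Hand back only the newly generated tokens, then wait for the next prompt.
        prompt = yield tokens[len(prompt):]

def toy_logits(tokens):
    """Made-up 'model' over a 3-token vocabulary."""
    return [0.0, 1.0, 5.0] if len(tokens) % 2 else [4.0, 0.0, 1.0]

server = run_sketch(toy_logits)
next(server)                  # advance to the first yield
generated = server.send([7])  # submit a one-token prompt
```

A very low temperature such as 0.01 makes the sampler almost greedy, which is why this toy run is reproducible even though it calls a random number generator.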
Model
The model.py file contains all the architectural features and various support functions. It is about 1400 lines of code, so we’ll only try to summarize it. We reiterate that the code uses JAX and Haiku, a library that simplifies the process of building and training neural networks in JAX, providing a high-level interface while maintaining the performance benefits of JAX’s low-level primitives. It’s commonly used in machine learning research and development where JAX is the preferred framework for its flexibility and performance.
The first lines deal with 8-bit quantization of the weights and with registering the quantized weights as Pytree nodes (enabling efficient processing and manipulation using JAX’s array-based operations and transformations). They also define sharding constraints, applied only when a distributed computing environment is present, and the partition rules: patterns that determine how the computations of a transformer model are distributed and parallelized across multiple devices or processors.
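As a rough idea of what 8-bit weight quantization means, here is a minimal sketch that stores integers in the int8 range together with a scale and multiplies them back at use time. It is not the repository's implementation: the class and function names are hypothetical, and a single symmetric scale per list of weights is an assumption chosen for simplicity.

```python
from dataclasses import dataclass

@dataclass
class QuantizedWeightSketch:
    """Hypothetical container: integer weights plus the scale to restore them."""
    weight: list   # integers in [-127, 127]
    scales: float  # one scale for the whole list in this sketch

def quantize(values):
    """Symmetric quantization: map the largest magnitude to 127."""
    max_abs = max(abs(v) for v in values) or 1.0  # avoid a zero scale
    scale = max_abs / 127.0
    return QuantizedWeightSketch([round(v / scale) for v in values], scale)

def dequantize(q):
    """Recover approximate float weights."""
    return [w * q.scales for w in q.weight]

original = [0.5, -1.0, 0.25]
q = quantize(original)
restored = dequantize(q)
```

Each restored value differs from the original by at most half a quantization step, which is the price paid for storing one byte per weight instead of two or four.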
Then come all the classes strictly related to the definition of MoE + Transformer architecture: Router, MoE layer, MultiHead Attention, Decoder, Transformer, Attention Mask, RMS Norm, Rotary Embedding.
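Of the building blocks just listed, RMS Norm is the easiest to show in isolation. The sketch below is a plain-Python version of the general technique, assuming a learned per-feature scale and a small epsilon for numerical stability; the function name and defaults are ours, not the repository's.

```python
import math

def rms_norm(x, scale, eps=1e-5):
    """Divide x by its root mean square, then apply a per-feature scale."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * s for v, s in zip(x, scale)]

# With a unit scale, the output has a mean square of (almost exactly) 1.
normalized = rms_norm([3.0, 4.0], scale=[1.0, 1.0])
mean_square = sum(v * v for v in normalized) / len(normalized)
```

Unlike layer normalization, no mean is subtracted: only the magnitude of the vector is normalized.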
Finally, the language model and its configuration are defined. These classes integrate embedding, positional encoding, transformer layers, and decoding logic to generate logits for next-token prediction. They ensure proper handling of padding, masking, and distributed computation if configured.
Conclusions
We briefly discussed the Grok-1 code, which is easy to understand and well written. The model should be tested on huge machines but it is probably still immature for a definitive and high-performance version. At the time of writing, a 1.5 version of Grok already appears to be on the way – again, not state of the art… size isn’t everything in artificial intelligence.
Useful links
xAI – Open Release of Grok-1 (link)
Open Release of Grok-1 – GitHub repository (link)
xAI – Announcing Grok (link)
Hugging Face – Mixture of Experts explained (link)
xAI – Announcing Grok-1.5 (link)
Also available on Substack