
Boo! Matmuls and Why They are Scary. Chapter One.

“On Halloween, I’ll put on an emotional intimacy costume - it scares everyone”

Niiiice, but! The scariest thing in my life has been linear algebra - particularly matrix multiplication; I almost cried over it at university. Math in general has never been my thing, but solving matrix equations was pure torture and my worst nightmare. And here I am: trying to understand how it works again, this time through the lens of AI infrastructure. Teaser: something as ephemeral as matrix multiplication directly influences even the hardware you use.

It’s not that I need to use this daily - I don’t write code or do research - but decomposing how large models work, down to the individual infrastructure building blocks, turns out to be crucial for doing my job as well. Simply because I need to speak more or less the same language as the people we at @AIFoundry.org deal with.

For those who are puzzled by why I am talking about matrix multiplication - and what the heck it is in general - let’s take a small dive into the details. Hold your breath ;)

 

Matrices are an essential part of large models, both in training and inference, as a way to represent data efficiently — such as input data (e.g., pixels in an image) or the intermediate results passed between model layers. Multiplying these matrices takes up an enormous share of the overall compute in deep learning models. In fact, matmuls make up roughly 45-60% of the total runtime of many popular transformer models - both during training and inference. Therefore, matmul algorithms have to be fast enough to deliver state-of-the-art performance across models, frameworks, hardware, datasets, and inference engines throughout the AI industry. Now imagine you need to push all of this to production - a distributed network of nodes that are expected to work together in a fast, efficient, and compliant way.
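To make that less abstract, here is a minimal Python/NumPy sketch of what a matmul actually is: the textbook triple loop, doing one multiply-add per step. The 64×64 shapes are just illustrative - the point is that the work grows as m·n·k, which is why this single operation dominates so much of a model’s runtime.

```python
import numpy as np

def naive_matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Textbook matmul: C[i, j] = sum over k of A[i, k] * B[k, j]."""
    m, k = a.shape
    k2, n = b.shape
    assert k == k2, "inner dimensions must match"
    c = np.zeros((m, n), dtype=a.dtype)
    for i in range(m):
        for j in range(n):
            for p in range(k):
                c[i, j] += a[i, p] * b[p, j]
    return c

a = np.random.rand(64, 64).astype(np.float32)
b = np.random.rand(64, 64).astype(np.float32)

# Sanity check against NumPy's optimized implementation.
assert np.allclose(naive_matmul(a, b), a @ b, atol=1e-4)
```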

 

The data you use can vary quite a bit. Many people know about data diversity in terms of structured and unstructured data, but here we're talking about data types ("dtype"). AI models have traditionally used a data type called FP32, but the industry is increasingly using lower-precision types like BFloat16 and Int8, and even more unusual ones like FP4 and Int4, to make models smaller and faster. This means matrix multiplication algorithms need to handle many different data types, depending on the specific use case.
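A quick illustration of why lower precision matters (NumPy has no BFloat16, so FP16 and Int8 stand in for the narrower types here; real Int8 models also need quantization scaling, which this sketch skips): the exact same matrix shrinks with every drop in precision, and so does every byte that has to move through the memory system during a matmul.

```python
import numpy as np

# Same 1024 x 1024 matrix stored in progressively narrower dtypes.
x = np.random.rand(1024, 1024)

for dtype in (np.float64, np.float32, np.float16, np.int8):
    y = x.astype(dtype)  # int8 would need proper quantization in practice
    print(f"{np.dtype(dtype).name:>8}: {y.nbytes / 1024 / 1024:.2f} MiB")

# float64: 8.00 MiB
# float32: 4.00 MiB
# float16: 2.00 MiB
#    int8: 1.00 MiB
```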

 

Each device that runs AI models has different types of memory and different units for multiplying and adding data. For example, CPUs have a memory hierarchy ranging from slower RAM to faster cache levels and CPU registers. Generally, the smaller the memory, the faster it is — for instance, accessing the level 1 cache takes just a few nanoseconds, while accessing RAM can take around 100 nanoseconds. So, to get the best performance for matrix multiplication, the algorithm must be designed to work efficiently with these different memory levels and sizes. Since full matrices are too large to fit in the fastest memory all at once, the challenge is to break them down into smaller tiles - to make the best use of the fastest memory, as the sketch below shows.
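Here is a rough sketch of that tiling idea in NumPy. The tile size of 64 is an arbitrary illustrative choice; real implementations pick tile sizes to match the cache sizes and register widths of the target hardware.

```python
import numpy as np

def blocked_matmul(a: np.ndarray, b: np.ndarray, tile: int = 64) -> np.ndarray:
    """Blocked (tiled) matmul: work on tile x tile sub-matrices so that each
    tile of A, B and C can stay resident in a fast cache level while reused."""
    m, k = a.shape
    _, n = b.shape
    c = np.zeros((m, n), dtype=a.dtype)
    for i in range(0, m, tile):
        for j in range(0, n, tile):
            for p in range(0, k, tile):
                # Multiply one pair of tiles and accumulate into the C tile.
                c[i:i+tile, j:j+tile] += a[i:i+tile, p:p+tile] @ b[p:p+tile, j:j+tile]
    return c

a = np.random.rand(256, 256).astype(np.float32)
b = np.random.rand(256, 256).astype(np.float32)
assert np.allclose(blocked_matmul(a, b), a @ b, atol=1e-3)
```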

 

The hardware units used for matrix operations vary depending on the device. In the past, CPUs were mostly scalar, which meant they processed one piece of data per instruction. Over the last 20+ years, though, CPU manufacturers have added vector (SIMD) units that process multiple data elements with a single instruction. GPUs take this further by running operations across many threads (Single Instruction, Multiple Threads), which makes them great for parallel, repetitive tasks like matrix multiplication. Specialized hardware like Google's TPUs can even operate directly on 2D matrices. While these advanced units boost performance, they also demand more adaptable algorithms that can target all of these processing styles - scalar, vector, or matrix.
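You can feel the gap between scalar-style and vector/parallel execution even from Python: `a @ b` dispatches to a BLAS library that uses the CPU's vector units (and usually all its cores), while the hand-written loops below grind through one scalar multiply-add at a time. Exact timings depend entirely on your machine; this sketch is only meant to show the orders-of-magnitude difference.

```python
import time
import numpy as np

a = np.random.rand(512, 512).astype(np.float32)
b = np.random.rand(512, 512).astype(np.float32)

def scalar_row(a: np.ndarray, b: np.ndarray, i: int) -> np.ndarray:
    """Compute a single row of A @ B with plain Python loops (scalar style)."""
    k, n = b.shape
    row = np.zeros(n, dtype=np.float32)
    for j in range(n):
        s = 0.0
        for p in range(k):
            s += a[i, p] * b[p, j]
        row[j] = s
    return row

t0 = time.perf_counter()
_ = scalar_row(a, b, 0)   # just one output row, and it is already slow
t1 = time.perf_counter()
_ = a @ b                 # the whole matrix via the vectorized BLAS path
t2 = time.perf_counter()

print(f"one row, scalar loops: {t1 - t0:.4f} s")
print(f"full matmul, BLAS    : {t2 - t1:.4f} s")
```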

 

AI models also come in many forms, and the sizes of the matrices inside them can vary a lot. For instance, models may have different input shapes (like sequences of different lengths), different shapes for internal layers (the matrices used in hidden layers), and even different batch sizes (which affect training and inference efficiency). Because of this, there are hundreds of different matmul shapes used in practice, making it challenging to break them down into smaller blocks that make the best use of memory.
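As a hedged illustration (the dimensions below are made up, but typical of a mid-sized transformer), here are the matmul shapes hiding in just one attention layer; change the batch size or sequence length - which inference servers do constantly - and every one of these shapes changes with it.

```python
# Hypothetical shapes for a single transformer attention layer, just to show
# how many distinct matmul problems one model can generate.
batch, seq_len, d_model, d_head, n_heads = 8, 512, 4096, 128, 32

matmul_shapes = {
    "QKV projection":    ((batch * seq_len, d_model), (d_model, 3 * d_model)),
    "attention scores":  ((batch * n_heads, seq_len, d_head), (batch * n_heads, d_head, seq_len)),
    "attention output":  ((batch * n_heads, seq_len, seq_len), (batch * n_heads, seq_len, d_head)),
    "output projection": ((batch * seq_len, d_model), (d_model, d_model)),
}

for name, (lhs, rhs) in matmul_shapes.items():
    print(f"{name:>17}: {lhs} x {rhs}")
```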

 

Now do you get why this is a nightmare? :) It is a nightmare for the whole AI industry these days, too, since matrix multiplication accounts for a huge share of compute expenses - and those are enormous when training a model from scratch. The same component also makes inference complicated and sometimes expensive, because it requires a lot of advanced, hardware-specific optimization. And this is a good point to dive a bit deeper into how you can actually do it...

 

Stay tuned for the second chapter!