Building Machine Learning Systems for a Trillion Trillion Floating Point Operations

Over the last 10 years we’ve seen Machine Learning consume everything, from the tech industry to the Nobel Prize, and yes, even the ML acronym. This rise in ML has also come along with an unprecedented buildout of infra, with Llama 3 now reaching 4e25 floating point operations, or 40 yottaflops, or 40 trillion trillion floating point operations.

Learn more about machine learning at Jane Street.

Transcript

So it’s my great pleasure to introduce our speaker tonight, Horace He. Horace He graduated from Cornell in 2020, and he works at Meta on the PyTorch team, specifically at the intersection of compilers and machine learning. If you’ve used things like torch.compile, which is a thing in PyTorch that makes your model go like 2-4x faster in just one line of code, or if you’ve used FlexAttention, which lets researchers design fast kernels for attention without leaving the world of Python, he’s the person responsible for both of those things. You should also check out his blog, which is pretty awesome, and in particular the blog post “Making Deep Learning Go Brrrr from First Principles.” And without further ado, I’ll give it to Horace.

Thanks for the intro. Today I’m gonna give a talk about building machine learning systems for a trillion, trillion floating point operations. My name is Horace He, and I’m on the PyTorch compilers team at Meta. So, you know, I think we live in pretty unprecedented times in terms of an infrastructure build-out, which is, I think, you know, nicely reflected in NVIDIA’s stock price. I feel like, you know, basically every month, we see like a different headline about like a new nuclear power plant from like Microsoft, or, you know, like a massive 300K GPU data cluster from xAI. You know, people like, I feel like when Amazon first like built or like bought a nuclear power center, it was like big news. But now practically like, you know, every cool AI company has their own nuclear data center. And you know, of course, this like kind of really massive build-out has kind of also resulted in pretty ludicrous fundraisers from startups, I think. Like, I remember back in like 2016, like, you know, if you, like, made it as a startup, it would be like if you were worth a billion dollars, right? It was like a unicorn. It was like the mark of really making it as a startup beyond your wildest dreams. But you know, in 2024, you know, you gotta raise a billion dollars just to get started. You know, it’s like just to play the game, you need like a billion dollars, you know, of which most of it goes to NVIDIA.

And so I think it’s kind of crazy to think that all of these billions of dollars are really just going toward absolutely insane amounts of floating point operations. Here’s a nice chart from Epoch AI showing the growth of compute over time. And these floating point operations are really just big matmuls done over and over, for millions of iterations over months. Nowadays the leading-edge models are trained with about 1e26 floating point operations, which is approximately 100 trillion trillion floating point operations. So a trillion trillion is like a trillion teraFLOPs’ worth of floating point operations.

Another effect that you might’ve noticed is that prior to 2016, if you searched for “ML” on Hacker News, you’d often get a lot of people asking about the ML family of languages. But nowadays you search ML on Hacker News and you get a very different type of article. And I think one of the things that’s missed here is that, with all these billions of dollars and yottaflops of operations, it’s easy to forget that these operations need to actually run somehow on the machines. A modern stack might involve calling an LLM API, which then calls PyTorch, NCCL, Triton, CUDA, NVCC, all these different layers in the stack that somebody was involved in writing.

And I’m often reminded of this XKCD comic showing the state of modern digital infrastructure. You often forget about all the infrastructure that was involved for you to get where you are, but really we’re just building layers on top of layers. And so if you work in systems, like I do and I suspect many of you do, I think you oftentimes think about your work a little bit like being a benevolent dictator of sorts, where you imagine that a lot of people build on top of your work. And so if you can give a small improvement to millions of people, your small improvements can lead to significant impacts on the world.

And so for example, I kind of imagine Guido, with Python, sitting on top of this cloud and eventually giving us nogil or faster CPython or the walrus operator, and I feel like this is oftentimes how I imagine infrastructure work. And it’s a lot of why I got into infra work in the first place. On the other hand, I feel like if you work in ML, things can sometimes feel a little bit different. I originally came across this post on threads where this person said, my main gripe with working on top of LLM APIs is that no one is really engineering anything. You’re just chanting a prayer to this man-made demon deity to do your bidding. And this is very similar to the Shoggoth metaphor that has become very popular in deep learning circles, where the idea is that we put all these matmuls and computation into producing this very weird alien intelligence, and then we RLHF it and provide it to people through a nice convenient interface like ChatGPT or something like that. But if you think about it, I think we’re working with Shoggoths all the way down. Even if you’re working in systems and you’re not calling ML models, you still have this massive amount of infrastructure that you presumably don’t really understand, written by people you’ve probably never talked to. And the end result they try to expose is some kind of simple interface where you can just import torch and then hopefully run your code across 100K GPUs.

And so as a result, I think there are some interesting ways in which ML models feel different from regular systems. One of those ways is that ML models are extremely simple, and as a result we have very high expectations for the performance of the code. I think this traces all the way back to this very nice article called “The Bitter Lesson,” which I’d really recommend reading if you haven’t come across it before. Its main observation was that clever ideas in machine learning have, throughout its 50-year history, basically always lost out to simple ideas that scaled really well with Moore’s law. And I’m not really joking here when I say that machine learning logic is exceedingly simple. There’s this cool project from Andrej Karpathy called llama2.c, where he’s implemented Llama 2 in about 973 lines of C with no other dependencies. Every single loop, every single matmul is just implemented from scratch. With 973 lines you can’t make it run very fast, but it does run, and I think it indicates just how fundamentally simple the models that we’re spending all this compute on end up being.

And so the end result is that, although the problems themselves are extremely simple and very easy to optimize in some sense, the expectations are very high for how well we can optimize these matmuls. One example is that the predominant metric for measuring your model’s performance in deep learning is called model FLOP utilization. This is basically the percentage of the theoretical max FLOPS of your GPU that your model actually achieves. And if you think about it, this is actually a very absurd metric to hit. On a CPU, if you measured any of your code by this metric, you could only hit 100% if at every single moment every single core of your CPU was issuing max-width SIMD instructions. That’s the only way you can hit 100% utilization. And if you take a look at any of the code that you presumably write, almost no CPU code is anywhere near 100% FLOP utilization. It’s probably way under 1% most of the time. On the other hand, in machine learning, for large-scale training we’re typically hitting around 50% of the peak FLOPS. And I think this is indicative that even though the overall problem is very simple, the corresponding difficulty just goes into making your models hit this very high perf bar.
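As a rough illustration of the metric described above, here is a minimal sketch of a model FLOP utilization calculation. The function name and the numbers in the example are hypothetical, chosen only so the result lands at the roughly 50% figure mentioned in the talk.

```python
def model_flop_utilization(model_flops: float,
                           seconds: float,
                           peak_flops_per_second: float) -> float:
    """Fraction of the hardware's theoretical peak that the model achieved.

    model_flops: floating point operations the model needed for the work.
    seconds: wall-clock time the work took.
    peak_flops_per_second: the hardware's advertised peak throughput.
    """
    if seconds <= 0 or peak_flops_per_second <= 0:
        raise ValueError("seconds and peak_flops_per_second must be positive")
    achieved_flops_per_second = model_flops / seconds
    return achieved_flops_per_second / peak_flops_per_second


# Hypothetical numbers: a training step that needs 4e14 FLOPs takes 0.8 s
# on a device whose peak is 1e15 FLOP/s.
mfu = model_flop_utilization(4e14, 0.8, 1e15)
print(f"MFU: {mfu:.0%}")
```

Note that the numerator counts only the FLOPs the model mathematically requires, so time the device spends idle or moving data lowers the result.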

Another interesting observation in machine learning, one that has exacerbated this a bit, is that the field has consolidated significantly over the last five to 10 years. One way is that maybe 10 years ago you had a lot more architectures, a lot more variance in the different things that people were trying. But nowadays transformers are really the dominant architecture for everything. You have transformers for vision, you have transformers for language, you have transformers for audio, it’s kind of all transformers.

And the other way things have changed is that instead of many different people training SOTA models, you oftentimes just have a few companies training SOTA models. And so we’ve kind of gone from a bit of a monopoly, where previously there was one person providing the infra and many people using the infra, to something that in some ways feels a little more like a monopsony, where you have many people trying to provide infra and only one person actually training the job at the end of the day. And as a result, I think there are two general ways to think about getting performance from your systems.

One of the ways is basically optimizations: you have a compiler, you make the compiler faster, and you improve the performance of everybody using your compiler. The other way is programming models. The analogy is that with optimizations you have a system whose responsibility is to chop down a tree, and you’re just optimizing how fast it can chop down the tree. The alternative is that you’re providing tools for people to cut down the trees themselves.

And so to talk a little bit about programming models, I think it’s illustrative to talk about how ML frameworks have evolved in terms of what programming model they’ve exposed to users. Originally, I think around 2010, 2011, 2012, the first ML framework that got a lot of popularity was this framework called Caffe. And the way you expressed neural networks in Caffe was very declarative in nature. By that I mean you edited a protobuf file, I think, where the protobuf specified all the things you needed to care about in your neural network.

And so as you can imagine, programming protobufs is not very fun, and there’s a lot of things you might want to do that you can’t do in protobufs. And so a natural next thing that people did was these graph builder type APIs. This is kind of how TensorFlow 1 looked, where the idea is that no human should need to write protobufs by hand. Ideally, you should just write in a DSL of sorts that generates the protobufs for you.

However, even this DSL still has a lot of confusion in it. It is not super clear, given this DSL, how code actually executes on your GPU. And so finally, around 2016 or 2017, PyTorch started to become really successful. And the feature that was most emblematic of PyTorch was what was called imperative or eager execution. I think it’s worth talking about why eager execution was so successful, and I think the main reason just comes down to what the programming model of the execution looked like. In imperative or eager execution, you call a function, the GPU runs the function, and then the function finishes, and that’s it. That’s basically all you do. You call torch.matmul and this is basically the sequence of operations that happens. But with a graph mode type approach, where a compiler interjects, you first define the function, the function gets converted into some intermediate IR, a bunch of who knows what happens to your function, and then eventually the function executes on the GPU. And the top one was a very simple execution model for people to understand.

And I think another thing that’s kind of interesting to notice about this is that this kind of also just, like the top half also kind of describes how Python executes. And I think this is kind of pretty illustrative to me of like why Python has been so successful in machine learning. It’s basically that like, like a funny statement that you can make about Python is that like if you tried to train a model today, like you know, I gave you like a day to train a model, it would run faster if you ran it in Python compared to doing in C++. And you might argue that this is like an unfair comparison. Because you know, like you know all the infra, like, you know, all the frameworks that people built are in Python. But I think the reason why so many of these frameworks have been built in Python is that Python is an exceedingly simple language. And so it’s also like a very global language. And what I mean by like global is that like it’s very easy for people to build their own infrastructure on top of Python without really needing to fight with like anything the Python language does because the Python language itself does basically nothing. Like it, you know, it doesn’t do any optimizations for you. It just takes your function and runs it.

And so I think that PyTorch historically is at a very similar point in the design space, where PyTorch’s execution model is very simple. And although this doesn’t automatically do a lot of things for you, it does mean that it’s very easy for people to build their own infrastructure and frameworks on top of PyTorch.

I think another important detail to realize, especially about PyTorch when it first came out, is that this unoptimized execution didn’t actually sacrifice any performance at all. When people benchmarked PyTorch versus TensorFlow or Caffe, PyTorch oftentimes wasn’t even slower than those frameworks. And I think there are two main reasons for this. The first reason is that back in the day, about 90% plus of your time was spent in matmuls, and so there was basically nothing else you needed to optimize. Matmuls here are matrix multiplications, and they’re often provided by vendor libraries like cuBLAS or cuDNN. They’re provided by NVIDIA and they’re very hand-optimized, and so if 90% of your time is spent in matmuls, then what else can you even do to optimize the performance of your neural network?

And I think another important piece of why PyTorch’s performance was quite good is that it had this async execution model. The idea here is that you have a parallel work queue on your GPU. The CPU is only responsible for scheduling work on the work queue, and then the GPU executes work from the work queue. This GIF is what usually comes to mind: you can imagine that Gromit is Python, trying to put down the train track in front of the train, which is the GPU. And so as long as Gromit is able to put down the train tracks faster than the train actually rolls along them, you can view Python as having zero overhead. It doesn’t add any extra cost compared to if your program was written in a more efficient language like C++.
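The train-track picture can be made concrete with a toy simulation. This is only a sketch under simple assumptions: a single CPU thread that pays a fixed launch cost per kernel, a single GPU queue, and made-up timings in microseconds. The function name `simulate` is hypothetical.

```python
def simulate(num_kernels: int, launch_overhead: int, kernel_time: int):
    """Toy model of an async work queue.

    The CPU enqueues kernels one after another, paying `launch_overhead`
    per kernel. The GPU runs each kernel for `kernel_time` as soon as the
    kernel has been enqueued and the previous kernel has finished.
    Returns (total_time, gpu_idle_time) in the same time unit.
    """
    gpu_free_at = 0
    gpu_idle = 0
    for i in range(num_kernels):
        enqueued_at = (i + 1) * launch_overhead   # CPU finishes launching kernel i
        start = max(enqueued_at, gpu_free_at)     # GPU may have to wait for the CPU
        gpu_idle += start - gpu_free_at
        gpu_free_at = start + kernel_time
    return gpu_free_at, gpu_idle


# Big kernels: the CPU stays ahead, so the GPU only waits for the first launch.
print(simulate(num_kernels=100, launch_overhead=10, kernel_time=100))
# Tiny kernels: the GPU keeps catching up with the CPU and sits idle.
print(simulate(num_kernels=100, launch_overhead=10, kernel_time=2))
```

In the first case the launch cost is hidden behind GPU work; in the second the GPU spends most of its time waiting, which is the overhead-bound regime.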

And so in this way, like you know, eager execution not only had like a much easier to understand program model for users, it was also like basically just as fast as non eager execution.

Unfortunately, good things never last. In 2017, NVIDIA introduced what are called tensor cores. If you’re unfamiliar with tensor cores, they’re basically hardware units on the GPU that only do matmul operations. And I don’t mean this figuratively, in the sense that people often say that GPUs are well suited for matmuls. I mean this very literally, in that there’s actually an assembly instruction that just does a mini matmul, and this is how you interact with the tensor cores. And so if you look at this plot of matmul FLOPS versus non-matmul FLOPS, you can really see when NVIDIA realized that deep learning was a big deal. Because all of a sudden you had this massive 10x gap, and this is a log scale, between how fast matmuls were on the GPU and how fast literally anything else you wanted to run on the GPU was.

And so the end result is, previously we said that matmuls took 90% of the time, and so if NVIDIA has sped up matmuls by 10x but everything else stayed the same speed, then all of a sudden you’re spending a lot more of your time doing non-matmul operations. And as a result, we’ve gotten ML compilers, I think largely due to this change. One of the important details about ML compilers, in terms of how they differ from the frameworks that came before, is that they still keep the eager programming model. Logically, the programming model exposed to users is that you’re still just writing Python code and it executes line by line. The only difference now is that instead of actually executing line by line, we capture it into a graph in some manner. And torch.compile I think does this in a pretty interesting way, in that torch.compile actually intercepts at the Python bytecode interpreter level, where Python exposes these APIs that let you insert your own frame interpreter. And so this looks very much like a traditional JIT for any other VM, except this JIT is only meant for PyTorch programs.
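The shift in where time goes, described at the start of this passage, is a straightforward Amdahl-style calculation. A minimal sketch, using the talk’s round numbers (matmuls at 90% of runtime, then sped up 10x); the function name is hypothetical:

```python
def non_matmul_share(matmul_fraction: float, matmul_speedup: float) -> float:
    """Share of runtime spent outside matmuls after matmuls get faster.

    matmul_fraction: fraction of the original runtime spent in matmuls.
    matmul_speedup: factor by which only the matmul portion is accelerated.
    """
    matmul_time = matmul_fraction / matmul_speedup
    other_time = 1.0 - matmul_fraction          # unchanged by the speedup
    return other_time / (matmul_time + other_time)


before = non_matmul_share(0.9, 1.0)    # no speedup: non-matmul work is 10%
after = non_matmul_share(0.9, 10.0)    # 10x faster matmuls
print(f"before: {before:.1%}, after: {after:.1%}")
```

With these numbers, the non-matmul work goes from a tenth of the runtime to slightly more than half of it, even though nothing about that work changed.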

And so if you look at how things have evolved over time, originally you had frameworks like TensorFlow, or Caffe before that, where the programming model that users wrote was a graph builder type abstraction, and the execution model was also a graph execution type abstraction. After that, you had PyTorch 1.0-type stuff, where the user programming model switched to eager-style execution, and the execution model was also eager, where each operator is executed one at a time.

But now, finally, pretty much all ML frameworks use an imperative, eager programming model, and almost all of them also have some way to capture that program into a graph of some kind so they can perform optimizations. So this is jax.jit, tinygrad, MLX. They all have their different approaches for capturing the graph to optimize.
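To give a feel for what “capturing into a graph” can mean, here is a toy sketch of one possible approach: tracing through operator overloading, where the user’s function runs on placeholder objects that record each operation instead of computing it. This is not how torch.compile works (the talk describes it as intercepting at the bytecode level), and every name below (`Node`, `trace`, `model`) is hypothetical.

```python
class Node:
    """Placeholder value that records operations instead of computing them."""

    def __init__(self, graph, name):
        self.graph = graph
        self.name = name

    def _record(self, op, other):
        other_name = other.name if isinstance(other, Node) else repr(other)
        out = Node(self.graph, f"v{len(self.graph)}")
        # Each graph entry is (output, operation, left input, right input).
        self.graph.append((out.name, op, self.name, other_name))
        return out

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)


def trace(fn, *arg_names):
    """Run `fn` once on placeholders and return the recorded op list."""
    graph = []
    fn(*(Node(graph, name) for name in arg_names))
    return graph


def model(x, w):
    # Written like ordinary eager code, line by line.
    return (x * w + 1) * x


graph = trace(model, "x", "w")
for entry in graph:
    print(entry)
```

The user still writes plain imperative code, while the framework ends up holding a whole-program view that it can optimize before anything runs.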

And so I think next I kind of want to talk about, you know, we’ve kind of discussed how we’ve gotten to ML compilers in the first place. And so I think next I want to talk about what ML compilers are actually doing for you and what kind of optimizations they’re performing.

And so generally speaking, the way I think about deep learning performance, or performance on GPUs in general, is that there are basically three things you can be spending your time on. The first one is compute, which is time your GPU spends computing actual floating point operations. The next one is memory, which is time spent transferring your tensors within a GPU, across the various memory subsystems in your GPU. And finally there’s overhead, which is everything else, like time your GPU spends idle and so on.

And so first we’re gonna talk about compute. To a first approximation, you can say that all runtime on your GPU is either compute or shuffling data. And data movement is not a real operation, right? It’s a no-op from the theoretical point of view. All it’s doing is moving data from one place where it’s convenient to another place where it’s convenient. So a floating point operation is basically the only real thing a GPU can do. But I think you can simplify this even more and say that nowadays all runtime is either matmuls or essentially shuffling data. This is because if you look at the actual FLOPS chart for an H100 GPU, you only have 67 teraFLOPS of FP32 compute, but you have about 1,000 teraFLOPS of TF32 compute, which is basically matrix multiplication compute. What this means is that if you’re not doing matmuls on your GPU, you’re really only getting about 7% of your peak FLOP utilization. So by the metric I mentioned before, model FLOP utilization, even if your GPU was fully occupied doing stuff that wasn’t a matmul, you could only ever get about 7% FLOP utilization, which is much lower than the theoretical peak.

I did have a brief interlude about like, I think an interesting case where these kind of abstractions do break down even more. And so I do have like a kind of a fun question. Which is like, do the matrix contents affect your matrix multiplication performance? And so I think if you kind of like, you know, are familiar with like general performance, there are like a lot of things that… A lot of ways where data can impact your performance, but in this case, matmuls actually avoid a lot of them.

So for example, they have identical memory access patterns regardless of what data is in your tensor, and there’s no control flow in the matmul. And the GPU also doesn’t have denormals, so that’s not a possibility either. So in this case we’re taking three tensors: one initialized with all zeros, one initialized from a Gaussian distribution, and one initialized from a uniform distribution from zero to one. And funnily enough, if you benchmark this, you actually find that there is a performance difference depending on the actual data that is within your tensors. There’s this tweet from a long time ago that some of you might know, and when I first saw this I was very much reminded of it, where I thought I knew how a GPU worked, but then I was very confused about what could possibly be causing these performance differences between the different data. And the actual cause here is something called dynamic power. I think most of you are probably familiar with the fact that when a CPU or GPU is under load, it uses more power, and at some point it can throttle because it’s using the maximum amount of power it can use or producing the max amount of heat it’s allowed. But this power doesn’t just come from nowhere. It largely comes from what’s called dynamic or switching power. What this means is that every time a transistor on the GPU switches from zero to one, high to low or low to high, it loses a little bit of power. And so the actual power usage of your GPU is kind of a sum across the total amount of switching that goes on in your GPU.

And so this is why, if you’re multiplying with all zeros, you can imagine that a lot of transistors on your GPU don’t end up switching at all. And so it doesn’t actually consume that much power, and it’s much less likely to throttle. If you actually look at this, you can get very different performance for all these different kinds of fun distributions, whether it’s the normal distribution, or a checkerboard type pattern, or sparse, or ternary. And the reason why is this kind of abstract thing where these different patterns lead to more or fewer transistor flips, which leads to more or less power throttling, which leads to more or less performance.
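As a very crude illustration of why the data can matter, the sketch below counts how many bits differ between consecutive float32 values in a buffer. This is only a proxy chosen for illustration: real switching activity depends on the hardware’s datapaths and is not something this script measures. The helper names are hypothetical.

```python
import random
import struct


def float32_bits(value: float) -> int:
    """Return the 32-bit pattern of `value` stored as an IEEE 754 float32."""
    return struct.unpack("<I", struct.pack("<f", value))[0]


def count_toggles(values) -> int:
    """Count bits that flip between consecutive float32 values.

    A rough stand-in for switching activity: more flipped bits means more
    wires changing state as the data streams past.
    """
    bits = [float32_bits(v) for v in values]
    return sum(bin(a ^ b).count("1") for a, b in zip(bits, bits[1:]))


rng = random.Random(0)  # fixed seed so runs are repeatable
zeros = [0.0] * 1024
gaussian = [rng.gauss(0.0, 1.0) for _ in range(1024)]
uniform = [rng.random() for _ in range(1024)]

for name, data in [("zeros", zeros), ("gaussian", gaussian), ("uniform", uniform)]:
    print(name, count_toggles(data))
```

The all-zeros buffer produces no toggles at all, while the random buffers flip many bits per element, which matches the intuition that all-zero inputs draw less dynamic power.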

And I remember one time somebody told me a funny story where they were training a machine learning model and benchmarking their performance. And at some point their model would NaN, and then they’d be like, wow, my performance just got way better. And so I wrote an article about this, and they messaged me and said, oh, that was very illustrative, because I was really confused why my performance would be getting better. But that’s why: if all your tensors are NaN, your transistors also don’t need to do a lot of flipping, and so you’ll measure better perf.

So that’s compute. The next thing that your GPU can be spending a lot of its time on is memory, which is essentially time spent transferring your tensors within a GPU. One thing to observe from this empirical plot, from a paper on data movement, is that the paper breaks down the operations in, I think, a BERT-type model into what it calls tensor contractions, i.e. matrix multiplications, and then normalization operations and element-wise operations. And you can see that although matmuls are responsible for 99.8% of your FLOPS, they’re only responsible for 61% of your runtime. So why are we spending 40% of our runtime doing operations that cumulatively only take 0.2% of our FLOPS? The key thing here is what’s called the memory bandwidth cost. Here I’m assuming all of your data already lives on the GPU, occupying your GPU’s VRAM. But the thing is that your GPU’s VRAM is not where the compute units are located. In order to actually do operations on a GPU, you need to move your data from the VRAM to where the compute units are located, which is your SRAM or compute units. I usually think of this as a factory: you have a factory with not that much space, and then you have a warehouse located much further away with a lot more space, but in order to do any operations on those supplies, you need to move them from your warehouse to your factory and then back. This cost of moving data around is called the memory bandwidth cost, and it’s responsible for a lot of what your GPU spends its time doing.
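One way to see why operations with almost no FLOPs can still take a lot of time is a back-of-the-envelope estimate that takes the larger of the compute time and the data movement time. The sketch below is only illustrative: the peak FLOPS value is the FP32 figure quoted in the talk, the bandwidth is an assumed round number rather than a measured one, and the byte counts assume each float32 tensor crosses the memory bus exactly once.

```python
PEAK_FLOPS = 67e12   # FP32 FLOP/s figure quoted in the talk
BANDWIDTH = 3e12     # bytes/s; an assumed round number for illustration


def estimate(flops: float, bytes_moved: float):
    """Return (limiting resource, estimated seconds) for one operation."""
    compute_seconds = flops / PEAK_FLOPS
    memory_seconds = bytes_moved / BANDWIDTH
    limiter = "compute" if compute_seconds >= memory_seconds else "memory"
    return limiter, max(compute_seconds, memory_seconds)


n = 4096
# Element-wise op on an n x n float32 tensor: 1 FLOP per element,
# one read and one write of 4 bytes each.
elementwise = estimate(flops=n * n, bytes_moved=2 * 4 * n * n)
# n x n matmul: 2 * n^3 FLOPs, two inputs read and one output written.
matmul = estimate(flops=2 * n**3, bytes_moved=3 * 4 * n * n)

print("elementwise:", elementwise)
print("matmul:", matmul)
```

Under these assumptions the element-wise operation is limited by data movement, while the matmul is limited by compute, which is the pattern the plot in the talk reflects.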

So imagine we do three operations on a GPU, say an add, then a ReLU, and then a sine operation. You can imagine what happens when you do these operations. First the GPU sends the data from the memory to the compute units, turns it from a square to a triangle, and then sends it all the way back. Then it sends the triangle from the memory units to the compute units again, where it does another operation, and sends it all the way back. And then finally, you get the idea, it sends the circle from the memory units to the compute units again and then sends it all the way back. By default, whenever you run any operations in PyTorch, let’s say you ran add and then multiply and then cosine, this is exactly what would be happening on your GPU.

And so you might think that this is a very dumb thing to do, and you would be correct. Why are we sending our triangle back from the factory to the warehouse just to send it from the warehouse back to the factory again? And so a very common optimization for GPU compilers, and what I’d actually call the most important optimization in a deep learning compiler by far, is called operator fusion. What operator fusion does is that instead of sending the data back and forth so much, we run a single GPU kernel where you send the data once to the factory, you do all of the operations, and then you send the data back. Notably, this optimization is not really something you can do in eager mode, right? Because in eager mode, as I was mentioning, the execution model is very simple: you run an operation and then it executes the operation. If we want to do this optimization, that programming model is no longer sufficient.
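The saving from fusion can be sketched by counting trips to “the warehouse.” The toy below is plain Python, not a GPU kernel: `CountingMemory` is a hypothetical stand-in for VRAM that counts elements loaded and stored, and the three operations are taken to be add, ReLU, and sine.

```python
import math


class CountingMemory:
    """Stand-in for VRAM that counts how many elements are read and written."""

    def __init__(self, data):
        self.data = list(data)
        self.reads = 0
        self.writes = 0

    def load(self):
        self.reads += len(self.data)
        return list(self.data)

    def store(self, values):
        self.writes += len(values)
        self.data = list(values)


def unfused(mem):
    # Three separate "kernels": each one loads everything and stores everything.
    for op in (lambda v: v + 1.0, lambda v: max(v, 0.0), math.sin):
        mem.store([op(v) for v in mem.load()])


def fused(mem):
    # One "kernel": load once, apply all three ops per element, store once.
    mem.store([math.sin(max(v + 1.0, 0.0)) for v in mem.load()])


a = CountingMemory(range(-4, 4))
b = CountingMemory(range(-4, 4))
unfused(a)
fused(b)
print("unfused reads/writes:", a.reads, a.writes)
print("fused reads/writes:  ", b.reads, b.writes)
```

Both versions produce the same values, but the fused one touches memory a third as often, and the gap grows with the number of operations in the chain.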

And so there are actually a lot of different ways to minimize memory movement. Although at the end of the day operator fusion is the most important thing you can do in an ML compiler, there are a lot of decisions that go into operator fusion that make it more or less effective. One example is recomputation versus reuse trade-offs. If you’re familiar with register allocation type settings, you often have a similar issue: if you have a value in a register, you can choose either to store it in global memory and load it later, or to recompute that value from values that are already in registers. You have a similar idea here, where you often have cases in which doing some amount of recomputation can significantly reduce your number of memory accesses. And so in this way, recomputation can not only reduce your peak memory usage, it can also often improve your actual runtime performance.

And I think one thing to mention is why this observation about recomputation versus reuse ends up being quite important for deep learning performance. The reason is that the shape of machine learning programs actually looks quite unusual compared to your typical program. It’s generally a bit of an axiom that most intermediates in a program are very short-lived: your program consists of a lot of short-lived intermediates that are created and very shortly afterwards destroyed. But in machine learning this is not the case. The typical model that you’ll execute will first run the model forward, layer zero to layer one, to layer two, to layer three, to layer four, and then run what’s called the backwards pass, which runs the layers in reverse. So then it’ll go from layer four to layer three, to layer two, to layer one, to layer zero. And in between the forward pass and the backwards pass, you need to save what are called intermediates or activations. You have a lot of them, and they’re oftentimes largely responsible for running into out-of-memory errors. And so I think this is a pretty unusual program structure in machine learning that’s caused by backpropagation and gradient descent.
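The forward/backward structure and the recomputation trade-off can be shown on a toy scalar “model.” This is a sketch under strong simplifications: every layer is just `sin`, and the recompute strategy is the most extreme one (keep only the input and rebuild every activation from scratch), which is an assumption for illustration rather than what a real compiler would pick.

```python
import math


def grad_save_all(x: float, num_layers: int):
    """Backprop through num_layers of sin, saving every layer's input."""
    saved = []
    for _ in range(num_layers):
        saved.append(x)              # activation kept alive until backward
        x = math.sin(x)
    grad = 1.0
    for layer_input in reversed(saved):
        grad *= math.cos(layer_input)    # d/dx sin(x) = cos(x)
    # gradient, activations stored, forward evaluations performed
    return grad, len(saved), num_layers


def grad_recompute(x: float, num_layers: int):
    """Same gradient, but only the original input is kept in memory."""
    forward_evals = num_layers       # the initial forward pass
    grad = 1.0
    for layer in reversed(range(num_layers)):
        layer_input = x
        for _ in range(layer):       # rebuild this layer's input from x
            layer_input = math.sin(layer_input)
            forward_evals += 1
        grad *= math.cos(layer_input)
    return grad, 1, forward_evals


print(grad_save_all(0.5, 5))
print(grad_recompute(0.5, 5))
```

Both functions return the same gradient; one keeps five activations alive across the backward pass, the other keeps one value but does three times the forward work. Practical schemes sit between these extremes, and inside a fused kernel the recomputed values can be cheap because they avoid extra memory traffic.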

And so finally, you know, the last thing you can be spending your time on is overhead. You can imagine that if poor Gromit is not able to put down the train tracks faster than the train can go on them, then sometimes the train is gonna be just stuck waiting for him to put down the next track. And so here I have a profile trace where you can see that the bottom line, which is the GPU trace, is largely idle, and it’s mostly just idle waiting for the CPU to schedule the next operation. And so there are a lot of ways to actually address this nowadays. One of the most powerful is called CUDA Graphs, which is an NVIDIA-provided API, but you also have other approaches, like having a compiler codegen a lower-overhead wrapper or something like this.

So you know, I’ve talked about ML compilers, what they can do to your program and how they can be useful. But I think there’s an interesting question here. I talked a lot about how we have a super massive infra build-out and the programs are super simple, and we’ve seen a lot of consolidation in terms of what the architecture looks like. And so I think a reasonable question is: if you only have one architecture and you’re spending billions of dollars to train it, why do you even need a compiler? You know, why can’t you just assign some group of people to optimize it by hand instead of leveraging a compiler? And the other thing I’ll say about this is that in practice a lot of times people do not use compilers, for kind of this reason. And so this section is gonna be talking a little bit about why I think that’s the case, and what some of the challenges are when it comes to using compilers in this setting.

And yeah, so disclaimer, you know, I do really like compilers. I’m gonna say some kind of mean things about compilers in a bit, but you know, to establish my credibility, I work on a team called PyTorch Compilers. And so there are, to be clear, a lot of reasons why compilers can be very useful. In particular this notion of leverage, like being able to do the optimization once in the compiler and then having everybody be able to take advantage of it. And also, you know, compilers are just very fun to work on.

That being said, yeah, I’m gonna introduce you to my new exciting library, Horace’s Exciting Library, abbreviated as HEL. And so it has a couple of cool features that you might be interested in. So the first feature it has is that it doesn’t always work. And you know, to address that, it also has no documentation about why or when it will work, except by reading my library’s implementation. And in exchange for that, when you update the library, it may totally change what code works and what code doesn’t work, i.e. no backwards compatibility or guarantees on whether your code works. And so are you interested in using my library? I guess, you know, most people would probably say no.

And so I think one thing to note here is that if “work” means “has the desired performance and applies the desired optimizations,” this is largely describing how compiler optimizations work. Compiler optimizations don’t always work. Often there’s no real documentation on when a compiler optimization will trigger, and when you update your compiler it may completely change when it does or does not apply these optimizations. And so there’s a very influential article about a compiler called ISPC, and it has this note in it that auto-vectorization is not a programming model. The overall framing of the article is that the author wrote this compiler called ISPC, which you can think of as CUDA for Intel SIMD instructions. And throughout the article he is constantly fighting against the Intel compiler team, which is where he worked at the time. The Intel compiler team wanted to leverage auto-vectorization to get vectorization done instead of introducing a new programming model like ISPC. And so I think he elucidates what the problem with the auto-vectorizer is: as long as vectorization can fail (and it will), then if you’re a programmer who actually cares about what code the compiler generates for your program, you need to deeply understand the auto-vectorizer. Then when it fails to vectorize code you want to be vectorized, you need to either poke it in the right way or change your program in the right way so that it works for you again. And so this is a very horrible way to program.

And, you know, if any of you here are very into using SIMD instructions, you probably also do not trust the auto-vectorizer at all and you’re mostly just writing intrinsics. And so with a proper programming model, ideally the user is able to learn what does and does not work without needing to be tied to the implementation. Once a compiler implements it, the user can learn to reliably rely on this optimization without needing to understand the compiler, only needing to understand the programming model. And so one way you can phrase this is that a compiler optimization that always works is just part of the programming model. Like for example, when you’re writing SIMD intrinsics, you know that a SIMD intrinsic will always get mapped to a SIMD instruction. And so that’s part of your programming model and not really an optimization that the compiler is doing.

So I think, you know, when I see people complaining about Shoggoths and working with them, I’m often reminded of a compiler. You can imagine that a compiler is oftentimes a large piece of code that has been worked on by a lot of very smart people and has a lot of tricky implementation details in it. And so ideally when you’re using a compiler, only the programming model ends up being exposed to the user. The actual compiler implementation ends up being completely hidden, and so the user only needs to deal with the nice programming model, which is usually just the language that the compiler is compiling from.

Unfortunately, you know, when compilers fail and they don’t apply the optimizations that you want them to apply, the entire thing becomes exposed. And so this applies anytime that you’re wrestling with a compiler and you’re trying to understand why the compiler didn’t inline your code or things like this. And so to give some examples of cases where very nuanced details can lead to compilers struggling a lot, one of them is numerics in machine learning. In general, floating point arithmetic is a very cursed thing to deal with. And it’s just gotten even worse with the fact that everybody in the industry keeps on pushing data types to lower and lower bits: on the V100, they introduced 16-bit operations as kind of the default, on the H100 they introduced eight-bit operations, and on the B100 they’re now pushing for four-bit operations, like operations on four-bit floating point numbers. I think it’s reasonable to question how this is even a floating point number at this point. But this is what we’re typically dealing with. And so as a result of this low precision and the fact that numerics end up being so subtle, the algorithms have very annoying numerical problems to deal with.

And I think a good example for me that was very frustrating was this NaN in a FlashAttention implementation, where the underlying cause is something called an FMA, or fused multiply-add, which can actually be disastrously bad when it comes to your numerics. So for folks who are unfamiliar with FMA, it takes A, B, and C and it’s a single operation that does A times B plus C. And one of the ways that FMA differs from normal operations is that in addition to being faster, it’s also usually computed in what people call infinite precision internally. What that means is that the result of your FMA is the closest possible representation to the true value of A times B plus C. And so this is different from doing the operations separately, where typically after your multiply you would have a rounding step, so the result becomes a bit different. And so in this particular case, we’re computing exp of a_i times b minus max scale, where max scale is a single scalar value that’s the result of taking the maximum across a_i times b. The main idea here is that exp is a very numerically unstable operation, and so we want to make sure that we’re not taking the exp of anything very large. So we’re subtracting out the maximum value so that the largest exponent that you can get here is zero. And you know, the compiler in its wisdom helpfully rewrites this as an FMA of a_i, b, and negative max scale. It’s smart here, it actually realizes that it can rewrite the minus as a plus of a negative. But then, if you actually look at the numerics that are going on here, max scale is a rounding of a_i times b, and so we actually end up computing exp of the exact a_i times b minus the rounded value of a_i times b, instead of exp of the rounded value minus the same rounded value, which can be disastrously off. And so the end result here is that we NaN with FMAs turned on and we don’t NaN with FMAs turned off, even though FMAs are theoretically a precision-improving optimization.

The underlying cause here, to summarize, is basically that our numerical properties relied on two separate computations of the same value being exactly identical. And with FMAs, the compiler can apply them to only one branch but not the other branch, and this leads to NaNs in this case.
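The mismatch between a fused and an unfused computation of the same product can be reproduced on the CPU. The sketch below emulates an FMA with exact rational arithmetic (`fractions.Fraction`) followed by a single rounding; the helper name `fused_multiply_add` and the sample values are assumptions for illustration, and this is not the FlashAttention code from the talk.

```python
from fractions import Fraction

def fused_multiply_add(a, b, c):
    """Emulate an FMA: compute a*b + c exactly, then round once to a float."""
    return float(Fraction(a) * Fraction(b) + Fraction(c))

a, b = 0.1, 0.3
max_scale = a * b        # the product, rounded to a float (as the max reduction would see it)

# Unfused: the product is rounded before the subtraction, so both sides match exactly.
unfused = a * b - max_scale

# Fused: the exact product minus the rounded product leaves the rounding error behind.
fused = fused_multiply_add(a, b, -max_scale)

print(unfused)           # exactly zero
print(fused)             # a tiny nonzero value: the rounding error of a * b
```

The fused result is individually more accurate, but the code that computed `max_scale` assumed the subtraction would see the same rounded product, and that assumption no longer holds.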

Another thing that compilers oftentimes struggle with is what are called algebraic rewrites. So if you guys are familiar with this optimization called FlashAttention, which fuses the attention operators together, it actually relies on this subtle rewrite that people oftentimes call online softmax. Typically softmax is implemented with a couple of global synchronizations, but you can actually rewrite softmax in a way that removes one of those global synchronizations. This requires some amount of math. It’s not very hard math, but compilers are generally very bad at math, and so it’s difficult for them to come up with this rewrite by themselves.
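As a rough illustration of the rewrite, here is a pure-Python sketch of softmax computed the usual way (one global pass for the max, one for the sum) next to an online version that keeps a running max and rescales a running sum in a single pass. The function names are made up for this example, and a real fused kernel works on tiles of a matrix rather than a Python list.

```python
import math

def softmax_two_pass(xs):
    m = max(xs)                                  # global reduction 1: the max
    total = sum(math.exp(x - m) for x in xs)     # global reduction 2: the sum
    return [math.exp(x - m) / total for x in xs]

def softmax_online(xs):
    running_max = float("-inf")
    running_sum = 0.0
    for x in xs:                                 # single pass over the data
        new_max = max(running_max, x)
        # Rescale the sum accumulated under the old max to the new max,
        # then add the current element's contribution.
        running_sum = running_sum * math.exp(running_max - new_max) + math.exp(x - new_max)
        running_max = new_max
    return [math.exp(x - running_max) / running_sum for x in xs]

scores = [1.0, 3.0, -2.0, 3.5, 0.25]
print(softmax_two_pass(scores))
print(softmax_online(scores))
```

The two results agree up to floating point rounding; the algebraic identity being used is that a sum of exp(x - old_max) terms can be converted to the new max by multiplying by exp(old_max - new_max).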

And so I think FlashAttention is a very good example of a case where you need to think about the programming model you expose. You know, before this fused attention kernel came about, people just typically wrote attention like this: you did one matmul, then you did a softmax, and then you did another matmul. And you know, this was not super efficient, but it was as efficient as you could do, so people were happy with it.

But unfortunately with FlashAttention we now want to fuse these three operations into a single kernel. And so the difficulty is what API do we expose for FlashAttention? One way you might tackle this is pattern matching. So you can use a compiler to try to find the sequence of matmul, softmax, matmul, and then pattern match it into a single FlashAttention operator. And so this is a reasonable option. But the issue here is that it becomes very frustrating to debug. Like for example, the user might change how they write softmax, like they might reimplement softmax with their own implementation instead of using your vendor-provided softmax implementation. And then all of a sudden your pattern no longer applies, and they’re sad because the memory suddenly blows up and their code is 3x slower. And so this is very frustrating for users to deal with.

And so instead, you know, what PyTorch did in this case is that we just said, okay, these three operators now need to be fused into one operator, so we’re just going to directly provide you that single operator as a single API that you can call. And so this is one way you can typically deal with programming models: you can introduce a slightly higher-level thing that does one very specific thing for the users. But this is also still kind of frustrating, and you might have some questions about whether this is actually good enough. Because when you consolidate multiple more primitive APIs into a single more monolithic API, you oftentimes run into issues where the monolithic API is no longer able to represent all the things that users want to do. And so with attention, we do see that people kept on coming out with new attention variants like Sliding Window Attention, ALiBi, PagedAttention, Neighborhood Attention, all this stuff. Like, you know, you look on Twitter and four new attention papers come out every single week. And as a result, the fused attention kernels that people have keep accumulating new kwargs. Like, this one’s from the FlashAttention repo, and it now has dropout, a softmax scale, a causal flag, a window size, a soft cap, ALiBi slopes and so on.

And you know, some of these were added in the past couple of months, and they just keep on adding them. And once you’ve added a kwarg to an API or to a programming model, you can’t remove the kwarg anymore, because now that’s breaking what users rely upon. And even worse, even though people have been kind of aggressive about adding new kwargs, it still doesn’t end up being enough. You have all sorts of users who are constantly complaining that nobody has implemented FlashAttention for their pet attention variant. Like there’s no FlashAttention for PrefixLM. This bottom one was from a blog post from somebody who used to be at Google who was complaining about the ecosystem outside of Google.

And so basically the point here is that a single monolithic operator is not actually always sufficient. And so we’re in a situation where we don’t really want to rely on a compiler to generate the kernel from scratch, but it’s also painful to do modifications by hand. And so this might seem like a no-win situation, where you must choose one of them: you must choose either unpredictability by doing it as a compiler optimization, or a single monolithic API that’s difficult for users to modify. But you know, this is where you can be clever and come up with a new programming model that wasn’t either of the programming models that users had before.

So one thing to notice here is that this kind of custom kernel that’s difficult for a compiler to generate from scratch can actually be decomposed into a handwritten, complicated FlashAttention kernel and a bunch of trivial modifications from users that can be very mechanically generated.

And so here we have an API that we’ve introduced recently called FlexAttention, and I really like FlexAttention. I think we’ve seen very positive reception from the community, including from users who traditionally don’t actually use compilers that much. And one of the reasons that FlexAttention is liked, even though it relies on torch.compile to work, is that it’s guaranteed to always result in a single fused attention kernel, and it’s always guaranteed to have the same memory properties as a fused attention kernel. And so this means that the user now has a programming model that they can rely upon, where they can program against this programming model and try a bunch of different variants that all fit within it. And the user does not need to understand how the actual API is implemented under the hood.

So to give you guys a bit of a peek at what this API looks like. If you guys are familiar with sliding window attention or causal attention, here you can implement causal attention by checking whether your query position is greater than or equal to your KV position, and then also checking whether the distance between them is less than your sliding window size. And then you can just “and” these masks together to get your sliding window causal attention. And so this is I think a good example of where compilers can provide a lot of value for users even in these super handcrafted scenarios. Prior to this attention API, you had this situation where you had this big cube of masks, positional biases, and whether it was supported in training or in inference. And you know, a bunch of these dots were filled in by users who had manually implemented attention kernels. But now with FlexAttention every single one of these dots is filled in. And so users can now consistently rely upon fused attention kernels regardless of the attention variant that they’re using. And to give an example of some stuff that users have been doing with FlexAttention, you have on the left a bit of a fun attention mask from some guy on Twitter, where basically in this case he had a bunch of molecular graphs of different sizes, and he was able to convert this into an attention mask that FlexAttention supported. And so this is a very weird mask that, you know, I definitely did not think about when developing FlexAttention. But when you’ve developed a good programming model for users, users are oftentimes able to do things that you never considered when you developed the abstraction.
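The mask composition described here can be sketched in plain Python. The `(b, h, q_idx, kv_idx)` argument order is modeled on FlexAttention’s mask functions, but this example deliberately does not call PyTorch: the window size, the function names, and the `materialize` helper are assumptions for illustration, and the real API operates on tensors (combining masks with `&`) rather than Python booleans.

```python
SLIDING_WINDOW = 3  # hypothetical window size, chosen only for this example

def causal_mask(b, h, q_idx, kv_idx):
    # A query may attend to itself and to earlier positions.
    return q_idx >= kv_idx

def sliding_window_mask(b, h, q_idx, kv_idx):
    # The distance between query and key must be less than the window size.
    return q_idx - kv_idx < SLIDING_WINDOW

def sliding_window_causal(b, h, q_idx, kv_idx):
    # "And" the two masks together.
    return causal_mask(b, h, q_idx, kv_idx) and sliding_window_mask(b, h, q_idx, kv_idx)

def materialize(mask_fn, seq_len):
    """Evaluate a mask function on every (query, key) pair; for visualization only."""
    return [[mask_fn(0, 0, q, kv) for kv in range(seq_len)] for q in range(seq_len)]

for row in materialize(sliding_window_causal, 5):
    print("".join("x" if allowed else "." for allowed in row))
```

A fused kernel would never build this dense grid; the point of the programming model is that the user only writes the small predicate and the system decides how to evaluate it inside the kernel.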

Another analogy that I sometimes think about when it comes to optimizations versus programming models: there’s a famous quote from Grothendieck about tackling math problems, where he said that you can imagine a math problem as a walnut. And when you’re trying to open the walnut, you can either tackle it by just hitting the walnut a bunch until it opens up, or you can soak the walnut in water, you know, establish new mental models for how to think about the problem, until eventually the walnut just opens by itself after being soaked. And so I think about programming models very similarly. If you think about auto-vectorization from the Intel compiler folks’ perspective, the author of that ISPC article had this fun anecdote where he said that the Intel people kept asking, what happens when the CUDA compiler fails to vectorize? And he was absolutely baffled about what they meant. They felt like this was just a thing they needed to understand in order to know how the Intel people needed to change their compiler to be more competitive. But the misunderstanding here is that it is almost a nonsense question to even ask when a CUDA compiler fails to make your code parallel. Because if you’ve written your code in CUDA, in the CUDA programming model, it must be parallel. It is axiomatically part of the programming model. There’s no way you can write a CUDA program that does not execute across multiple cores. You can write CUDA programs that are incorrect or that deadlock, but they’re still guaranteed to always run across multiple cores. And this is because parallelism is inherent to the programming model in CUDA, while parallelism is not inherent to the programming model of auto-vectorization.

I think, yeah, ML compilers and how they apply to just generally optimizing GPUs is already quite difficult, but you have a lot more issues when it comes to making GPUs run at scale. And so when it comes to getting GPUs to run at scale, I do think there are a couple of interesting differences between distributed ML programs and traditional distributed systems. One of the differences is that traditional distributed systems oftentimes are just trying to scale QPS. And when you’re trying to scale QPS, you have a lot of small interchangeable queries, and the performance per hardware is not that critical. And then oftentimes you’re willing to double your amount of hardware in order to get fault tolerance or higher uptime or things like this. On the other hand, in ML systems, we basically just have a single query. Like, you know, you’re just training your ML job. We have frequent global synchronizations across all of our hardware, and performance is extremely important, so much so that there’s no way we could tolerate a 2x loss in performance for basically any reason.

And so here’s one way to think about how you can parallelize programs. And this is actually not specific to ML, this is a general way to think about parallelizing programs. You can think of it as having this cube of computation to do, where one of the dimensions is the batch dimension, so the different tasks that you can perform. You also have the within-task dimension, so the different operations within a single task. And then finally you have the time dimension, which is what your GPU is doing at any given point in time. And so you can think of data parallelism as basically just splitting along the batch dimension. You can think of task parallelism as splitting along the task dimension. And then you can think of pipeline parallelism as splitting along the time dimension.

And so for data parallelism, the basic idea I think is pretty simple, which is just that we have a bunch of different tasks, so we have a bunch of parallelism, and each GPU just handles a different task. We just put one task on each GPU and then run it. And this seems very nice and very trivial, but the issue is that we have a massive synchronization of parameters after every step, because after every step the gradients need to be synchronized across all of your hardware. And you can’t just avoid synchronizing your parameters, because you also have this mathy constraint from the ML training side, which is that you simply can’t train with too large of a batch size or your model will not converge properly. There’s also another detail here where you can’t just naively replicate your parameters, so people oftentimes use what’s called Fully Sharded Data Parallel.
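The per-step synchronization can be illustrated with a toy simulation. This is a sketch under simplifying assumptions: a one-parameter linear model, equal-sized shards, and an `all_reduce_mean` function that merely stands in for a real collective; none of these names come from the talk.

```python
def local_gradient(w, shard):
    # Gradient of mean((w*x - y)^2) with respect to w, over one worker's shard.
    return sum(2.0 * (w * x - y) * x for x, y in shard) / len(shard)

def all_reduce_mean(values):
    # Stand-in for a collective: every worker ends up with the same average.
    mean = sum(values) / len(values)
    return [mean] * len(values)

data = [(float(x), 2.0 * x) for x in range(1, 9)]   # 8 samples of y = 2x
num_workers = 4
per_worker = len(data) // num_workers
shards = [data[i * per_worker:(i + 1) * per_worker] for i in range(num_workers)]

w = 0.5
local_grads = [local_gradient(w, shard) for shard in shards]  # computed independently
synced_grads = all_reduce_mean(local_grads)                   # the per-step global sync

print(local_grads)       # each worker sees a different gradient
print(synced_grads[0])   # after the sync, all workers agree
```

With equal-sized shards the averaged gradient equals the gradient over the full batch, which is why every worker can apply the same update and keep its copy of the parameters identical.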

The second kind of parallelism that you have is task parallelism, which is commonly known as tensor parallelism. And in this case, you basically just have two GPUs that split the same task. And so the main problem that you run into here that’s kind of specific to ML is that task parallelism, because you’re running the same task, oftentimes does not have obvious ways to overlap your communication with your computation. But nevertheless, because the performance is so important, we still really want to overlap the communication. And so there ends up being a lot of involved things that people do to try to improve the overlap. One of them here is called Async Tensor Parallelism. The general idea here is that you have a communication op followed by a computation op. And so usually, you know, while you’re doing the communication, the GPU can’t do any computation, and so this is wasted idle time on our GPU. But the observation behind Async Tensor Parallelism is that you can actually kind of mini-pipeline your communication and your matmul, where even within a single tensor operation, there’s still oftentimes a batch dimension that you can parallelize along. And so by doing this kind of micro-pipelining of your communication and computation, you can actually still enable overlap even though we’re doing tensor parallelism.
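A back-of-the-envelope timing model shows why the micro-pipelining helps. It assumes communication and computation can overlap perfectly, that splitting into chunks adds no overhead, and that each chunk’s compute can start as soon as its own communication finishes; the function names and the time units are made up for this sketch.

```python
def sequential_time(comm, compute):
    # Communicate everything, then compute everything.
    return comm + compute

def micro_pipelined_time(comm, compute, chunks):
    c = comm / chunks       # communication time per chunk
    m = compute / chunks    # compute time per chunk
    # The first chunk's communication and the last chunk's compute cannot be
    # hidden; in between, each step costs whichever of the two is slower.
    return c + (chunks - 1) * max(c, m) + m

print(sequential_time(4.0, 4.0))
print(micro_pipelined_time(4.0, 4.0, 4))
print(micro_pipelined_time(4.0, 4.0, 16))
```

In this model more chunks always help, approaching the larger of the two costs; in practice, very small chunks make both the communication and the matmul less efficient, so the chunk count is a tuning decision.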

Finally, the last kind of parallelism that we have is pipeline parallelism. The general idea here is that you assign the first part of the task to the first GPU and the second part of the task to the second GPU. This differs from tensor parallelism, or task parallelism, in that in this case you do not run on the same task with both GPUs at the same time. And so in this way, you can think of the task as actually being sharded across the time dimension.

And so there are a couple of additional wrinkles here about pipeline parallelism that are kind of unique to ML, I think. One of them is that, you know, once again the frequent massive synchronizations prevent us from filling up the pipeline. But the second issue is that backpropagation, the forwards and backwards pattern that I mentioned before, also adds a lot of very fun wrinkles to your pipeline parallelism schedule. And in this case, this is a pipeline schedule that people have designed where the blue boxes are the forwards pass, the cyan boxes are the backwards pass, and then the green boxes are the backwards pass split up even more. And the main nuance here is that you can see there are various points where your pipeline can actually choose to either run a forwards micro-batch or a backwards micro-batch. And so this kind of choice is not something you need to deal with in a traditional pipeline parallelism setting, and it ends up leading to a lot of different optimizations that people do with pipeline parallelism.

And so putting this all together, this is a diagram from the Llama 3 paper, where they showed what’s being done to train Llama 3. And in this case, you can see that we’re combining tensor parallelism (task parallelism) with data parallelism and with pipeline parallelism. And there’s also a fourth one thrown in here called context parallelism, which you can think of as just another form of task or tensor parallelism.

And so I think again here, compilers have historically struggled quite a bit when it comes to keeping up with the distributed optimizations that people do. And the main reason is that compilers are dumb and humans are smart. What I mean by this is that for any given set of parallelism schemes or parallelism configs, it often is pretty feasible to create an automatic model for the compiler to determine how to automatically parallelize your program. But the issue here is that most of the innovation in parallelism doesn’t come from searching within your existing search space. It usually comes from expanding your search space along a new dimension. And so this is something that compilers are not super great at, and it’s been one of the struggles when people try to automate parallelism with compilers or general systems.

For example, one of the additional wrinkles that has shown up at certain scales is fault tolerance. You know, when you’re running on 10,000 or 20,000 GPUs, where we’re doing these global synchronizations across all of our GPUs, we have this issue where our GPUs can fail for all sorts of reasons. Some of them might be user-related, some of them might be hardware-related, some of them might be networking-related, but basically, a single failure takes down your entire run, and this is, you know, quite problematic. And so this is a table from a paper that Meta published about training Llama 3 and fault tolerance. And I think one example of how this ends up being an interesting issue is that when you’re training on 16,000 GPUs, you’re only getting a failure once every 1.8 hours. But now if you scale it up to 131,000 GPUs, you get a failure every 15 minutes or so. And so this turns something that might not be that problematic at a 16,000 GPU scale or smaller into something very problematic at a 131,000 GPU scale or even higher, where you can imagine that you now have a situation where you might not be able to make even a single step before a single GPU in your entire fleet fails.
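The scaling of the failure interval can be checked with simple arithmetic. The sketch below assumes failures are independent and that the fleet-wide failure rate grows in proportion to the number of GPUs, which is a simplification and not necessarily the model used in the paper; the function name is made up for this example. The 16,000 GPU and 1.8 hour figures are the ones quoted in the talk.

```python
def hours_between_failures(base_gpus, base_hours, gpus):
    # If each GPU fails independently at the same rate, the expected time
    # between failures anywhere in the fleet shrinks as 1 / (number of GPUs).
    return base_hours * base_gpus / gpus

minutes_at_131k = hours_between_failures(16_000, 1.8, 131_000) * 60
print(round(minutes_at_131k, 1))   # roughly 13 minutes under this model
```

Under this linear assumption the result lands a little below the "every 15 minutes or so" quoted in the talk, which is the same order of magnitude.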

So to kind of conclude, I think ML has basically become the single most important computational workload in the world over the last decade. Maybe I’m a bit biased, but by the amount of FLOPS and investment, I think you can make a strong case for it. But the characteristics of the workload, both from a social POV, like the massive infrastructure required, as well as from the technical POV, often mean that a lot of the traditional approaches to building systems or compilers don’t directly apply. Like, you can’t just build a compiler and then have people optimize that compiler for five years without thinking about how the workloads evolving on top of it affect the compiler. And so to me the most interesting question about building systems here isn’t really about building the right optimizations for systems. It’s instead about coming up with the right programming models for expressing your systems, programming models that enable people to do their own optimizations and build the next, you know, 100,000 GPU model.

Thanks for coming to my talk. Hope you guys liked it.
