How GPT-5, Claude, and Gemini are actually trained and served – Reiner Pope

Transcript

0:00

Today, I'm interviewing Reiner Pope, who is the CEO of MatX,

0:03

which is a new chip startup.

0:04

Previously, he was doing TPU architecture and many other things at Google.

0:09

This is a very different format from my usual interviews.

0:10

This is going to be a blackboard lecture.

0:12

We're going to get up in a second.

0:13

We in fact built this whole new studio with specifically this

0:16

format in mind, so it's a pleasure to get to inaugurate it with you.

0:21

We're going to be talking about model architecture, ML

0:23

infra, and many other things.

0:25

The reason I think it's an important topic is because once you understand how

0:29

training and inference work in a cluster, a lot of things—about why AI is the way

0:35

it is, why AI architectures are the way they are, why API prices are the way they

0:40

are, and fundamentally why AI progress is the way it is—start making sense.

0:45

You need to understand the details to get there, and you need a

0:47

blackboard to understand the details.

0:48

Reiner, thank you so much for doing this.

0:50

Very happy to be here.

0:52

Full disclosure, I am an angel investor in MatX, but that's

0:56

unrelated to this podcast.

0:56

Reiner, to kick us off I'll ask this question.

1:01

We have a couple of companies like Claude and Codex and Cursor

1:05

offering something like Fast Mode, where for 6x the price, they'll

1:09

stream you tokens at 2.5x the speed.

1:12

Mechanically, I'm curious what's going on here.

1:14

Why is it the case that you can pay more to get faster latency?

1:18

Two, could you keep going?

1:19

Could you pay 100x more and somehow get much faster speeds?

1:25

Three, could you go the other way?

1:27

Could you have something like Claude Code "Slow Mode", where if you are

1:31

willing to wait for minutes on end, you could get even cheaper prices?

1:36

Maybe this will help motivate the analysis that you'll be doing through the lecture.

1:39

Great.

1:40

To jump to the conclusion a little bit, the big effect is batch size.

1:44

What we're going to do now is quantify exactly what that looks like and what

1:47

its implications are on latency and cost.

1:50

There's another effect, which you can call speculative decoding

1:53

or multi-token prediction.

1:55

We can maybe come back to that later, but the first thing that

1:58

we'll talk through is batch size.

2:00

What I'd like to introduce is the two principles of analysis.

2:04

First, we're going to look at a roofline analysis of how we run a

2:07

transformer model on a cluster of chips.

2:10

We'll take a

2:13

Blackwell NVL72 cluster, so a rack of 72 GPUs.

2:19

The roofline analysis means we look at memory bandwidth and compute performance.

2:25

The other side of that is that we're going to look at just two simple

2:27

factors of the model: the time to operate on the weights, and the time to

2:32

operate on the context, the KV cache.

2:36

Let's jump in.

2:37

We're going to try and estimate the time that it takes to run

2:42

an inference of a certain shape.

2:45

We're not perfect here.

2:46

We can't exactly predict the time, so instead we're going to approximate.

2:50

We're going to say that the time must be greater than or

2:52

equal to a certain quantity.

2:55

We're going to consider two different aspects: the time it takes to

3:01

do the memory fetches, and the time it takes to do the compute.

3:07

It will turn out that this gives us very strong predictive

3:09

power, even with a simple model.

3:12

One by one, what is the time that it takes to do the compute?

3:19

There are really two things I need to do in the compute.

3:21

I need to multiply by all of the active parameters, and then I need

3:25

to do some work on the attention.

3:28

Multiplying by all the active parameters, I have a certain batch size that

3:30

I'm running, and I've got a number of active parameters in my model.

3:38

Then I'm just going to divide this by the compute throughput,

3:41

which is the FLOPs of the chip.

3:45

This is a hardware concern.

3:48

This accounts for all of the compute time for all of the weight matrix multiplies.

3:54

There's a little caveat here.

3:56

We've ignored the time to do any of the attention computation, but that in general

4:00

will be quite small in comparison to this.

4:02

So we'll ignore this.

4:03

I'll just interrupt from time to time to ask some very naive questions

4:05

or to clarify some basic points.

4:09

For the audience, you're not serving one user at a time.

4:11

The batch refers to the fact that you're serving many different users at the

4:14

same time, and that's a whole batch.

4:17

I can motivate the batch at least a little bit.

4:20

We will see exactly why batch is such a favorable optimization.

4:23

What will turn out to be the case is that if you do not batch together many

4:28

users, the cost and the economics you get can be a thousand times worse than

4:34

if you do batch many users together.

4:36

We'll be able to see that quite explicitly.

4:38

Then, number of active parameters.

4:40

If I look at, for example, a DeepSeek model, the DeepSeek V3 model has

4:44

about 37 billion active parameters, and 700 billion total parameters.

4:52

We're focusing on just the ones that are active for a single token.

4:56

We're modeling compute performance.

4:58

I'm going to keep writing equals, but in all of these cases, you can think of this

5:00

time as being at least this much, and maybe there will be some terms we ignored.

5:05

On the memory side, what do we need to do with memory?

5:09

We need to fetch all of the weights, so there is some time to fetch

5:16

the total number of parameters, not just the active parameters.

5:21

There's weight fetch time, and then in addition, there's a KV cache fetch time.

5:27

This actually depends on batch size.

5:30

For every element of the batch, we have to fetch an entire context

5:36

length's worth of tokens, and there's a certain size per token, the KV bytes

5:44

for one token.

5:46

This is a model parameter.

5:47

Maybe just backing up, let's explain what the KV cache is real quick.

5:52

When I do a forward pass… Let me draw how the autoregressive inference works.

5:58

This is during decode.

6:01

If I have a bunch of text tokens… I'm drawing a tensor because ultimately

6:06

the tokens are represented as

6:09

a tensor in some embedding dimension.

6:12

In this direction, I have the sequence length.

6:18

The work of running a decode is that I have to run each token through

6:22

a whole bunch of matrix multiplies over a bunch of different layers.

6:29

In general, I'm going to have to do that work over all of these tokens.

6:35

But one step of decode is to produce just this one additional token up here.

6:42

What I'm going to do there is run a full forward pass of multiplying by all of

6:47

the weight matrices in the entire model.

6:50

But then I've got this attention mechanism where this token is looking

6:55

at all of the past tokens, and what is it looking at specifically?

7:01

It is looking at some internal representation that the model

7:03

has produced of the tokens, and we call that the KV cache.

7:08

This process of this single token attending to all of the

7:11

history of tokens is attention.

7:13

It is mostly dominated by memory fetches rather than matrix multiplies.
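
A minimal sketch of what "KV cache bytes per token" means for a dense-attention model. The layer count, KV-head count, and head dimension below are made-up illustrative values, not the configuration of any model discussed here.

```python
# Per token, each layer caches a key vector and a value vector per KV head.
# All sizes below are illustrative assumptions, not a real model's config.
def kv_bytes_per_token(n_layers=60, n_kv_heads=4, head_dim=128, bytes_per_value=2):
    return n_layers * n_kv_heads * head_dim * 2 * bytes_per_value  # 2 = key + value

print(kv_bytes_per_token())                  # ~123 KB per token in this example
print(kv_bytes_per_token() * 100_000 / 1e9)  # ~12 GB of KV cache for a 100K-token sequence
```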

7:19

So we've got the amount of memory that we're fetching shown over here,

7:23

and then this is of course just divided by the memory bandwidth,

7:27

so the memory bytes per second.
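
Putting the two terms together, here is a minimal sketch of the roofline just described. The parameter counts are the DeepSeek-style numbers mentioned earlier; the rack-level FLOPs, memory bandwidth, bytes per weight, and per-token KV size are rough assumed figures for illustration, not exact Blackwell specs.

```python
# Roofline sketch: one decode step is lower-bounded by the slower of compute
# (batch x active-parameter multiplies) and memory (weight fetch + KV fetch).
def decode_step_time(batch, context_len,
                     active_params=37e9, total_params=700e9,
                     bytes_per_param=1.0,        # assumed ~1 byte per weight
                     kv_bytes_per_token=120e3,   # assumed, as in the sketch above
                     flops=2e17,                 # assumed rack multiplies/s
                     mem_bw=6e14):               # assumed rack HBM bytes/s
    t_compute = batch * active_params / flops
    weight_fetch = total_params * bytes_per_param / mem_bw
    kv_fetch = batch * context_len * kv_bytes_per_token / mem_bw
    return max(t_compute, weight_fetch + kv_fetch)

print(decode_step_time(batch=1, context_len=32_000))     # dominated by the weight fetch
print(decode_step_time(batch=4096, context_len=32_000))  # long context: KV fetch dominates
```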

7:35

In fact, these equations here are enough for us to now draw some fit lines.

7:41

The things that we'd like to look at are sensitivity to batch, and then also,

7:46

which we'll draw separately, to context

7:51

length.

7:52

We said that the big effect you can get is some trade-off in

7:55

latency versus cost in batch size.

7:58

Let's draw them out.

8:00

I think there are just really two graphs that we want to draw.

8:02

We'll first draw batch size versus time here.

8:11

When we look at the shape of this, we've got a maximum of

8:15

the sum of the two memory terms and the compute term.

8:19

Let's look at these terms one by one and how they scale: the time for compute

8:24

and memory, and how they show up.

8:27

Let's first look at this compute time.

8:31

This is just purely linear in batch size with no offset, so

8:35

it is some curve like this.

8:37

This is t compute.

8:44

On the memory side, we've got some portion here that is just this constant

8:51

base offset here, which is the weight fetch.

9:00

Finally, we have this term here, which is the KV fetch, which is pretty

9:14

linear in batch size, and so it looks like that.

9:17

The sum of this plus this maxed with this… Let's at least first draw the sum.

9:29

The two memory times in conjunction end up looking on this curved slope like this.

9:35

Then the overall maximum is—I'll draw a little thicker here—the

9:41

maximum of these two curves.

9:44

What does this mean?

9:47

This is a latency plot.

9:54

If I grow my batch size, initially I get some not very strong dependence

9:59

on batch size, so there is some lower bound on latency here.

10:11

This already partially answers the question.

10:13

For a given hardware configuration—and we can talk about varying the hardware

10:18

configuration—there is a lower bound on latency.

10:20

It is simply that I need to read all of my total parameters

10:27

from memory into the chips, and that takes a certain amount of time.

10:31

If I use all of my memory bandwidth, I can't do any better than that.

10:34

It seems like the way you've drawn the slopes for compute time and how

10:40

the KV grows—and what implication the KV has on memory time—

10:42

What if

10:48

this were above or below?

10:49

Yeah, is that necessarily the case?

10:52

If this is always true, then as batch size grows compute always

10:56

dominates KV, which suggests that if you have a big enough batch size,

11:00

maybe memory is never an issue.

11:02

This is really sensitive to the context length, so I think we

11:05

should come back and explore this.

11:08

As you vary the context length, the KV fetch time will go up and up, and

11:11

that will cause a transition from compute-limited to memory-limited.

11:15

Is there something especially significant about the slope being

11:19

exactly the slope of the compute time?

11:25

Whenever we have balance points, it says that you're getting it exactly right.

11:29

For the particular context length where the slopes match, that says I am

11:34

equally memory-bound and compute-bound, which is a really desirable place to

11:39

be.

11:40

This is a very simple algebra problem, but suppose the optimal is 100K context

11:48

length, and you go to 200K context length.

11:52

Does your MFU go down to 50%?

11:54

Does it have a humongous impact on MFU to be slightly outside of the optimal

11:58

context length range, the Goldilocks zone?

12:01

That's right.

12:02

That is true as modeled here.

12:04

There is a key point here that

12:08

I'm modeling the memory fetch as linear in context length.

12:11

That depends on model architecture.

12:13

It is true for all of the model architectures with dense attention.

12:20

Sparse attention actually scales much better than that.

12:22

Got it.

12:23

Is sparse attention what everybody uses in practice?

12:25

I'm pretty excited about sparse attention.

12:27

It's hard to know what the labs are using.

12:29

DeepSeek has published a sparse attention mechanism.

12:31

I'll just put a plug in that some of the DeepSeek papers that have

12:36

published sparse attention end up putting a square root in this term.

12:40

So far, we've looked at the latency.

12:42

It's hard to read off cost from this.

12:45

If I think about what cost means…

12:49

To run this inference, I'm going to use the GPU for a certain number of seconds,

12:52

like one millisecond or 20 milliseconds.

12:56

I have to pay the rental time for that time.

12:59

So it's $2/hour per GPU or something like that.

13:05

That's the cost of this inference, but how many tokens have I

13:08

processed during that inference?

13:10

That is the batch size.

13:12

What we actually want to plot is the cost versus batch size, which

13:18

is t over B versus batch size.

13:23

This is the cost per token.

13:30

We have to imagine dividing each of these three curves by B, so

13:34

multiplying by this reciprocal.

13:38

What we end up with there is… The compute curve

13:44

was linear.

13:44

We divide by B, and that makes it a constant here.

13:48

This is t compute.

13:52

The KV fetch was linear, and now it becomes a constant as well.

14:00

Then the weight fetch

14:10

was constant, and now we've divided by B, so it becomes this

14:22

hyperbola.

14:22

Again, we're going to compute the max of the sum.

14:28

The sum of these two terms shifts the hyperbola up.

14:33

The sum of the KV fetch and the weight fetch gives us

14:39

a higher hyperbola that's like this.

14:41

Then we're going to take the max with the compute

14:45

here.

14:45

We end up with this being the overall shape that we care about.

14:52

Again, we see some limiting behavior.

14:54

The cost initially starts very high at a batch size of one.

14:59

It almost goes to infinity because we've got so many weight fetches that are

15:04

not amortized over a large batch size.

15:07

But as we increase the batch size, the weight fetches become amortized over so

15:11

many different batch elements that their cost grows very small, and eventually the

15:15

compute time ends up driving the cost.

15:18

So there is a limiting

15:23

lower bound on cost,

15:29

which is this line here.
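
A sketch of that cost-per-token curve, using the same rough roofline as above and an assumed rental rate of about $2/hour per GPU across a 72-GPU rack. Everything here is illustrative, not an actual price.

```python
# Cost per token = (step time) x (rental rate) / (batch size), since one decode
# step emits one token per batch element. Numbers are assumptions, as above.
def cost_per_token(batch, context_len=8_000, rack_dollars_per_hour=72 * 2.0,
                   active=37e9, total=700e9, kv_bytes=120e3,
                   flops=2e17, mem_bw=6e14):
    t_compute = batch * active / flops
    t_memory = (total * 1.0 + batch * context_len * kv_bytes) / mem_bw  # ~1 byte/weight assumed
    step_time = max(t_compute, t_memory)
    return step_time * (rack_dollars_per_hour / 3600.0) / batch

for b in (1, 64, 1024, 8192):
    print(b, cost_per_token(b))
# Falls roughly as 1/B while the weight fetch dominates, then flattens once
# compute and the per-element KV fetch set the floor.
```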

15:31

So Claude Code Slow or Codex Slow or whatever would just live on this line.

15:35

It wouldn't help much because you're not able to amortize the KV

15:40

values over a much bigger batch.

15:44

They're unique per batch element.

15:45

The compute is also unique per batch element.

15:46

So what is the minimum work you can do per batch after

15:49

amortizing everything else away?

15:52

This point where you are no longer memory bandwidth bound, practically

16:00

how big a batch do you need?

16:04

How big are the batches practically for frontier models?

16:07

You can just solve for that.

16:09

It's not even particularly sensitive to model architecture.

16:13

Let's go ahead and do that.

16:15

What we're talking about is when the memory time is equal to the compute time.

16:19

That’s what that question is.

16:26

Because we're focused on what the batch size is—and really there's a question

16:30

of when the weights are amortized over the multiplies—I'm going to

16:34

focus on comparing the weight fetch time to the weight multiply time.

16:38

I'm going to disregard the KV fetch term just to simplify the analysis

16:43

so we can get a clean answer out.

16:46

We're going to equate

16:50

this portion with these two times.

16:58

Writing that out, we get N, number of total parameters,

17:04

over memory bandwidth,

17:09

is equal to

17:12

batch size times number of active parameters

17:17

divided by the compute performance.

17:22

Looking over here, everything on the top are model parameters.

17:26

Everything on the bottom are hardware parameters.

17:28

It turns out to be nice to rearrange them such that we have

17:31

the hardware parameters on one side.

17:33

This is equivalent to

17:40

FLOPs over memory bandwidth

17:44

being equal to batch size times number of active parameters,

17:52

divided by the number of total parameters.

17:56

This hardware

17:59

parameter ends up being a dimensionless constant.

18:01

If you look in terms of FLOPs… What are the dimensions of this?

18:04

This is multiplies per second.

18:06

This is bytes per second.

18:07

So that's not quite dimensionless.

18:09

But what you do is you take the FP4 multiplies per second, times

18:20

the fact that each FP4 is half a byte.

18:24

I can actually make this end up being dimensionless.

18:25

On most GPUs, this ends up being

18:32

somewhere around 300.

18:37

Has that ratio changed over time as we've gone from model

18:39

generation to model generation, where the FLOPs keep increasing?

18:41

This is a hardware parameter.

18:43

To what extent has the hardware changed?

18:46

From A100 to H100 to B100, the FLOPs have increased substantially,

18:51

the memory bandwidth has also increased substantially, and the ratio

18:53

has remained reasonably stable.

18:56

We can express this one as well.

18:57

This is a sparsity parameter.

19:00

I might even phrase this slightly differently.

19:01

Let's solve for batch size in total.

19:05

Moving this back over to the other side, we end up with batch size needs

19:09

to be bigger than approximately 300

19:13

times sparsity.

19:16

For example, in DeepSeek I activate 32

19:21

out of 256 experts, so this would be 8 for DeepSeek.

19:27

This actually gives you a ballpark which is remarkably accurate to practice.

19:31

Generally, people will go a little bit larger than this.

19:33

They don't really want to be exactly at the balance point because

19:37

real-world efficiencies aren't as good as a roofline analysis would say.

19:41

But take this and maybe double or triple it.

19:44

Okay, so it's two to three thousand tokens per batch.
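
A sketch of the balance-point arithmetic just worked through, taking as given the ~300 hardware ratio and the 32-of-256 expert sparsity quoted above.

```python
# Ignore the KV term and equate weight-fetch time with weight-multiply time:
#   N_total / mem_bw = B * N_active / flops  =>  B* = (flops/mem_bw ratio) * sparsity.
def balance_batch(sparsity, hw_ratio=300.0):
    # hw_ratio: dimensionless FLOPs-to-bandwidth figure, ~300 on recent GPUs (quoted above)
    # sparsity: total-to-active ratio, e.g. 256/32 = 8 for the MoE described here
    return hw_ratio * sparsity

b_star = balance_batch(sparsity=256 / 32)
print(b_star)                    # ~2,400 tokens in flight per decode step
print(2 * b_star, 3 * b_star)    # in practice, run somewhat above the balance point
```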

19:49

But then if you included the KV cache, the implication would be

19:55

that the optimal batch size...

19:57

Should grow larger.

20:00

We solved for the equivalence between when compute time is equal to memory time.

20:06

If I add in something that consumes more memory bandwidth, then I have

20:10

less available for the weight loads.

20:13

I need to grow the memory bandwidth more, and therefore the batch size more.

20:17

This seems incredibly small.

20:19

This would be less than one sequence, right?

20:24

Keep in mind that I'm talking about the number of tokens that

20:27

I'm generating one more token for.

20:30

It's actually 2,000 unique sequences.

20:33

Got it.

20:33

We're just talking about a single forward pass on these sequences.

20:39

You think of the batch as the number of sequences.

20:41

That’s right.

20:43

When I'm prepping for interviews, I often talk to experts in the field.

20:45

So for Reiner, I chatted with two of Jane Street's engineers, Clark and Axel.

20:50

Clark, who works on low latency trading systems, walked me through why Jane

20:53

Street uses FPGAs to make sure that they have predictable nanosecond latencies.

20:57

“You can just build these like giant grids of compute very easily that do

21:01

exactly what you need that touch a hundred megabytes of SRAM and then

21:04

get your response back in tens of nanoseconds very easily. And that's

21:08

basically impossible on a CPU.”

21:09

He then went on to explain why CPUs just wouldn't work for this kind of thing.

21:13

“And so if you have a clock that's going every three nanoseconds, you actually

21:16

have several bytes of information at a time to make your decision.

21:21

That's as opposed to a CPU where you'll just collect up a whole packet, you

21:24

know, let's say a 1500-byte packet, and then you say, okay, this packet's ready.

21:26

Here you go, CPU.

21:27

You can start thinking about it now.”

21:29

FPGAs allow you to react to the earliest part of the packet as it arrives, rather

21:33

than having to wait for the full thing.

21:34

We also talked about liquid cooling, network design, and many other things.

21:37

If you're interested in this stuff, Jane Street is hiring.

21:40

You can check out their open roles at JaneStreet.com/Dwarkesh.

21:46

And if you want to watch the full prep conversation, we posted it there too.

21:49

If you've got a frontier model and you are actually doing inference, surely you must

21:56

have more than 2,000 concurrent users.

21:58

Is there any added latency from the fact that you need to

22:01

have the whole batch fill up?

22:02

Or if you have a reasonable amount of users, is it so unlikely that

22:08

it would take you 100 milliseconds to fill up the next 2,000 slots?

22:13

The way to think about this is: when does the train depart, as a mental model?

22:18

Let's say I've picked a batch size that I'm going to run at.

22:25

By the way, this intersection point is the same intersection point here.

22:30

I pick this batch size, and I know that it's going to take, for

22:32

example, 20 milliseconds, which is a common place this ends up landing.

22:36

This is a timeline of what

22:42

is running on the GPU.

22:43

It's going to start a new batch every 20 milliseconds regardless.

22:56

You can think of this as a schedule for the train.

22:58

A new train departs every 20 milliseconds.

23:00

Any passengers who are ready board the train.

23:02

If the train is full, they wait until the next train.

23:05

If the train is not full, the train is going to go anyway.

23:07

In terms of what that means for queuing latency, the worst case is that a request

23:15

arrives just after the train departed.

23:17

It has to wait for the next train, so that's up to 20 milliseconds, and then it

23:21

has to wait for that train to complete.

23:25

So the worst-case latency is 40 milliseconds.
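
The queuing bound in the train analogy is just two step times; a trivial sketch with the 20 ms figure used here.

```python
# Worst case: a request arrives just after a step ("train") departs, waits one
# full step for the next departure, then rides one step to completion.
step_time_ms = 20
print(2 * step_time_ms)  # 40 ms worst-case latency: queuing plus execution
```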

23:27

How is the 20 milliseconds derived?

23:28

It's a rule of thumb, but where it comes from is not fully explained yet.

23:36

So far we've focused on memory bandwidth and compute time.

23:40

When we look at memory, the other consideration is that we want to use

23:43

all of the memory capacity we have.

23:47

Generally, we're going to use all of that memory capacity to

23:50

store the weights or the KVs.

23:55

In the time of doing a forward pass, we want to read all of the

23:58

memory capacity into the chip.

24:01

That is capacity divided by bandwidth.

24:03

That tends to be 20 milliseconds on many different generations of HBM.

24:07

The units make sense.

24:08

You would have

24:11

a byte divided by bytes per second.

24:13

For example, on the Rubin generation, it is something like 288 gigabytes

24:18

divided by 20 terabytes per second.

24:22

This

24:28

comes out to about 15

24:32

milliseconds.
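
That 15 to 20 ms figure is just HBM capacity over HBM bandwidth, the time to read everything in memory once. The Rubin-generation numbers below are the rough ones quoted here, not confirmed specs.

```python
hbm_capacity_bytes = 288e9       # per GPU, as quoted above
hbm_bandwidth_bytes_s = 20e12    # per GPU, as quoted above
print(hbm_capacity_bytes / hbm_bandwidth_bytes_s)  # ~0.0144 s, about 15 ms
```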

24:33

Let me make sure I understand what this is saying.

24:36

I understand the unit analysis.

24:38

What it's saying is

24:43

we can evacuate and replace the HBM in this amount of time.

24:50

So we don't want to be in a situation where the HBM is not big enough that we're

24:56

not actually able to write everything we want to it or take everything out of it.

25:02

Or we don't want to be in a situation where our ability to write back

25:05

and forth is so small compared...

25:08

There are sort of two scenarios.

25:09

Why don't we pick a latency that is bigger than 15 milliseconds?

25:14

If I think about what that means, it means I actually have

25:16

time to read the HBM twice.

25:19

By the way, most HBM accesses are reads, not writes.

25:21

It's almost all reads because the weight matrices are read-only, and

25:25

almost all of the KV cache accesses

25:30

are reads.

25:30

In around 30 milliseconds, I can read all of HBM twice,

25:32

but what's the point of that?

25:35

I don't want to read the weight matrices twice.

25:37

I don't want to read the KVs twice.

25:38

Makes a ton of sense.

25:40

A couple of quick questions.

25:43

If it is the case that the optimal batch size is something like 2,000,

25:49

it's totally dependent on the sparsity, not dependent on

25:51

the model size or anything.

25:52

Sparsity shows up in model size, but beyond that, it only depends

25:55

on sparsity, not on scale.

25:57

That's a very interesting result.

26:02

One question is how much of a push towards centralization is it that

26:07

you would have these economies of scale from inference for batching?

26:10

But it seems like it's not that big a deal.

26:13

Is 2,000 users at the same time a lot?

26:14

It doesn't seem like a lot.

26:15

We can do a bit of analysis on this.

26:18

You can think of it in terms of number of users, but a more productive way to think

26:21

of it is in terms of tokens per second.

26:25

What does this batch size mean in terms of tokens

26:30

per second of the system?

26:32

Tokens per second is going to be equal to the batch size.

26:34

We run a batch of tokens, and we do that

26:40

every time interval, which is

26:44

equal to the 15-millisecond or 20-millisecond number.

26:48

This ends up being batch size times about 60, so 64

26:56

x B.

26:58

This ends up being around 2,000 x 64, so 128,000 tokens per second.

27:09

This is in more digestible units.
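
The throughput implied by that batch size and step time, as a quick sketch with the round numbers used here.

```python
batch = 2_000             # tokens decoded per step, from the balance-point estimate
step_time = 0.015         # seconds per step, roughly HBM capacity / bandwidth
print(batch / step_time)  # ~130,000 tokens/s for one rack, in line with the ~128K figure above
```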

27:11

It's hard to reason about concurrent users, but what is

27:14

the global traffic for a system?

27:20

When you look at some of the announcements, sometimes the

27:23

API providers will brag about how much traffic they have.

27:28

The numbers I remember from some announcements of Gemini last year

27:31

were in the hundreds of millions of tokens per second worldwide.

27:34

This

27:37

is one-thousandth of that.

27:40

Gemini is big.

27:42

One-thousandth of Gemini is a lot.

27:44

To actually be competitive at scale, you need to be able to serve at

27:49

least one-thousandth of Gemini.

27:50

That's interesting.

27:57

The more sparsity you have, the less compute you need.

28:04

It does seem that as batch sizes get bigger, compute ends up being the

28:09

bottleneck, according to this analysis.

28:11

Then the question is, how far can you take sparsity?

28:14

As the sparsity ratio increases, as you have fewer active parameters relative

28:19

to total parameters, how much is the performance of the model degrading?

28:23

Is it degrading faster than you're saving compute by increasing the sparsity factor?

28:31

You mean the quality of the model, rather than the speed of the model.

28:36

Unfortunately, we're not able to answer that analytically.

28:40

That is an empirical question of model quality.

28:43

The best I can do is pull up a paper and answer that empirically.

28:46

Should

28:50

we pull up the paper now?

28:50

This paper is "Unified Scaling Laws for Routed Language Models."

28:53

It's a somewhat old paper by this stage, but one of the things they looked

28:57

at is if I keep increasing sparsity, what is the model quality impact?

29:01

This answer is very sensitive to the actual choice of mixture of experts.

29:05

Mixture of experts has been around for a really long time, maybe even back in 2017,

29:11

but the techniques have changed a lot.

29:13

DeepSeek's mixture of experts was a big change in how it worked.

29:17

There have been older papers, like "GShard" and "Switch Transformer".

29:21

The actual empirical results are going to depend on all of that.

29:24

On one of the older techniques shown here, you can see if I hold constant

29:29

the number of active parameters at a certain size, and then I increase

29:32

the sparsity, which they call expert count, the quality keeps increasing.

29:37

If you imagine drawing a horizontal line from 1.3B dense across, you end up

29:43

seeing that, in this case, the 64-expert 370-million activated parameter model

29:49

is as good as a dense 1.3-billion model.

29:52

So in some sense, it's actually not amazing returns where you need

29:55

to increase total parameters a hundredfold to get the equivalent

30:00

of 10x as many active parameters.

30:04

Actually even more so.

30:05

It's a huge increase in parameter count for a modest increase in efficiency.

30:10

So in this case, actually it's 4x?

30:11

64x for 4x.

30:13

So while it is true that you get this benefit of being able to economize

30:24

on your compute time if you increase sparsity, naively it would seem

30:29

like a trade-off worth making.

30:32

But if you're decreasing

30:35

this by 2x and then having this go up by 8x every time you double sparsity...

30:42

Is that good or bad, actually?

30:44

Even from a memory point of view… Keep in mind you are doubling this

30:48

portion of the memory fetches, which is amortized by batch.

30:52

So just keep running a larger batch size.

30:56

From the point of view of the analysis we've done here, this is a pure win.

31:00

Keep doing it until you run out of available users, basically.

31:08

There's

31:12

this equivalence where if I have a lot of users, I can go to a much sparser model.

31:16

From that point of view, it's a reasonable trade-off.

31:18

The other trade-off that shows up here is that it also consumes memory capacity.

31:23

We've only reasoned about memory bandwidth here, but it

31:25

also consumes memory capacity.

31:26

I see.

31:27

Let me make sure I understood.

31:29

You're saying we want

31:35

to spend less time computing, therefore we do more sparsity.

31:40

To make that work, we need bigger batch sizes.

31:42

Which means we need more memory capacity

31:48

to have more sparsity.

31:49

Maybe this would be a good point to talk about how a mixture of experts layer is

31:54

typically laid out on a rack of GPUs.

31:58

Cool.

31:58

Makes sense.

32:00

Where were we?

32:01

Sparse mixture of experts.

32:03

Maybe how we lay that out on a GPU.

32:08

Let's zoom in on the mixture of experts layer first and draw what that looks like.

32:15

Typically, we'll have some kind of a router layer, which is making the

32:20

decision of where we route the tokens to.

32:23

We get tokens coming in here, they go through a router layer, and then

32:27

we have a bunch of different experts.

32:34

I'll draw a few more to line some up.

32:38

The router will make a decision of which experts it's going to

32:41

route to, and it will be a small fraction of them, maybe 1 in 32.

32:45

Maybe it will make a decision to route to this one,

32:49

maybe this one, and maybe this one.

32:56

Each expert itself is a normal MLP.

32:59

It has an up projection and then a down projection with a nonlinearity in between.

33:04

Then finally, we do the inverse operation.

33:07

Where we were broadcasting things out here, we're going to bring

33:10

them back in and sum them up.

33:16

Bringing them in like

33:19

this.

33:19

Then finally, we have our residual connections.

33:21

The token is also passed through here, and it gets added to

33:26

the result of the MoE layer.

33:28

This is a normal MoE layer.
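
A minimal sketch of the MoE layer drawn here, in plain Python with NumPy. The router, the top-k choice, the expert up/down MLPs, the weighted recombination, and the residual connection all appear; the shapes, top-k value, and softmax-style weighting are illustrative assumptions rather than any particular model's design.

```python
import numpy as np

def moe_layer(x, router_w, experts, top_k=2):
    """x: [tokens, d_model]; router_w: [d_model, n_experts];
    experts: list of (w_up [d_model, d_ff], w_down [d_ff, d_model])."""
    scores = x @ router_w                              # router logits per expert
    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        chosen = np.argsort(scores[t])[-top_k:]        # route to the top-k experts
        weights = np.exp(scores[t, chosen])
        weights /= weights.sum()                       # softmax over the chosen experts
        for w, e in zip(weights, chosen):
            w_up, w_down = experts[e]
            h = np.maximum(x[t] @ w_up, 0.0)           # up-projection + nonlinearity
            out[t] += w * (h @ w_down)                 # down-projection, weighted sum back in
    return x + out                                     # residual connection

d_model, d_ff, n_experts = 8, 32, 4
rng = np.random.default_rng(0)
experts = [(rng.normal(size=(d_model, d_ff)), rng.normal(size=(d_ff, d_model)))
           for _ in range(n_experts)]
y = moe_layer(rng.normal(size=(5, d_model)), rng.normal(size=(d_model, n_experts)), experts)
print(y.shape)  # (5, 8)
```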

33:31

What I want to talk through is how this is mapped to a GPU rack and what

33:37

this means for communication, because I think this will start to show some

33:41

of the limits of how sparse we can go.

33:46

The standard practice here, and it is the best solution,

33:48

is to use expert parallelism.

33:51

That means different experts go on different GPUs.

33:54

If we take something like a DeepSeek model, they have 256 experts.

34:00

Let's say we want to run that on a Blackwell rack.

34:04

There are 72 GPUs.

34:07

We have a divisibility problem.

34:09

This is not a power of two.

34:11

We'll just simplify and say we're only going to use 64 of them.

34:16

Just ignore the other eight.

34:17

It's not a big deal.

34:18

So we have four experts per GPU.

34:23

Very simple.

34:24

For the sake of the diagram, actually let's just say we

34:26

have two experts per GPU.

34:28

We end up just putting these GPU boundaries.

34:34

Every pair of experts is on its own GPU.

34:39

Then we can look at the communication cost.

34:40

We had some tokens stored centrally here.

34:44

They get routed to all of these experts,

34:48

and there is some communication cost paid here.

34:51

There's the same communication cost paid on the output.

34:55

The hope is that this does not become communication limited.

34:58

Now

35:01

what is the traffic pattern here?

35:03

The traffic pattern here is that any GPU will be talking to

35:06

any other GPU, depending on the decisions made by the model.

35:11

This is an all-to-all traffic pattern.

35:14

When you say any GPU, in this first stage, the router is more than one GPU?

35:20

I drew this as one router.

35:22

In reality, you would actually have many copies of the router, and you would

35:25

have as many routers as GPUs, in fact.

35:30

As the incoming traffic.

35:32

Yeah.

35:33

These are 64 GPUs and these are 64 GPUs.

35:37

It's actually the same GPUs, we just draw them as separate because

35:40

they're serving different purposes.

35:42

So at this point, any GPU can be sending to any other GPU.

35:46

This all-to-all pattern of communication shows up, and how the Blackwell

35:52

racks are configured is a perfect fit for the communication pattern that the MoE

35:59

actually wants to do.

36:01

However, if you think maybe one rack is too slow and I want to do two

36:06

racks, then I have this challenge that maybe I've got some sort of rack

36:11

boundary drawn outside here like this,

36:17

and I no longer have all-to-all communication between all

36:21

the GPUs in two racks.

36:24

The rack-to-rack communication ends up being a substantial bottleneck.

36:30

The fundamental thing here is that one rack bounds the size

36:33

of an expert layer you can do.

36:36

This has been part of what's been driving towards larger and

36:40

larger interconnect domains.

36:42

Before we continue, it may be worth you explaining what exactly a rack is.

36:47

The differences in bandwidth between a rack and within a rack, and the

36:52

all-to-all versus not all-to-all nature of communication within versus outside.

36:56

This is a place where it starts to be very different between Nvidia, for example,

37:00

and Google, and then others, including us.

37:04

Generally, a rack

37:09

is a physical structure.

37:11

It's a few meters tall, a meter or two wide, depending on configuration,

37:16

and it stores some number of GPUs or XPUs, which is typically about

37:24

64.

37:24

What constrains it being a certain size is power delivery,

37:27

weight, and cooling ability.

37:31

It ends up being about this size in many cases because of

37:34

these physical constraints.

37:38

When I deploy a data center, a data center may have thousands of these racks.

37:42

I've got one of these tall racks, it's got a bunch of GPUs in it, and so on.

37:46

And then I put another rack next to it.

37:50

You make it sound so easy.

37:51

Right.

37:52

I just drop them in.

37:55

In Nvidia's case, the communication topology…

38:02

They actually put the GPUs on the outside of the rack, and then they put

38:07

these switches on the inside of the rack.

38:09

What this ends up being is that there's a set of switches in here.

38:13

These are the NV switches.

38:17

Then they run a bunch of cables.

38:19

Every single GPU has cables going to the switches in the middle.

38:33

The switches have connections to all the GPUs.

38:35

All of the GPUs can talk to all the other GPUs in just two hops: going to

38:39

the switch, going to the other GPU.

38:41

Now, when I want to leave the rack, I end up going via a different path.

38:47

The GPUs also have a much slower connectivity, which is typically

38:52

about eight times slower.

38:56

The green that I drew here in the GPU cases is the NVLink.

38:59

More generally, it's called the scale-up network.

39:06

You will typically also have a scale-out network, which allows you

39:10

to connect to some data center switch.

39:13

All

39:19

of the GPUs will have some connectivity up to some data center switch somewhere.

39:23

This is

39:26

the scale-out, and

39:31

it tends to be about 8x slower

39:35

in bandwidth.

39:39

The challenge, if you want to lay out a mixture of experts layer across two racks,

39:44

is that half of the GPUs here are going to be wanting to talk to the GPUs here.

39:54

On average, when I look at where the tokens on these GPUs want to go, half of

39:59

the tokens want to go inside the rack.

40:00

That's great.

40:00

They can use the fast scale-up network.

40:03

But half the tokens are going to want to leave the rack and go to the

40:06

other rack, and that's not as good.

40:07

They need to use a much slower network, and so that becomes the

40:10

bottleneck on the all-to-all pattern.

40:13

A different choice would be, why don't I have a big switch

40:18

here and connect everything to

40:24

a much bigger switch that actually combines the two racks together?

40:27

There are many ideas in this direction, but in general, the

40:31

reason you have this hierarchy of switches rather than one big switch

40:34

is to manage the cabling congestion.

40:35

You just need to run a large number of cables.

40:39

Sorry, is that question you just asked basically, why isn't it a bigger scale-up?

40:43

Exactly.

40:44

Why not just have a million chips in scale-up or a thousand chips?

40:47

What has changed that has allowed Nvidia to go from Hopper, which was 8, then

40:53

Blackwell is 72, and now Rubin will be...

41:00

is it 500 something?

41:00

Yeah, 500 and something.

41:01

What has allowed that to happen?

41:02

From Hopper to Blackwell is mostly just the decision to switch from

41:10

trays as the form factor to switching to racks as the form factor.

41:15

That's a product decision.

41:16

There wasn't a substantial technical barrier there.

41:21

Switching from 64 to 500 or

41:27

so, there's a bit of Jensen math there, but there is at least a genuine 4x

41:33

increase, which is coming from a much more complicated and difficult rack design.

41:38

That is actually a new physical design to run more cables.

41:42

The cable complication is just the cost of figuring out which cable hops to which,

41:49

or which signal goes from what to what?

41:51

Let's zoom in on this and look at the wire density.

41:57

I'll draw this diagram just once more so we have a bit of a cleaner

41:59

and larger version to work with.

42:03

Let's say I have some switches in the middle.

42:04

Initially, I'm going to start with just two GPUs on each side

42:09

or two trays of GPUs on each side.

42:12

Let's say maybe each tray wants to have two cables coming out of it.

42:21

I physically run vertical cables that look like this running out to the switches.

42:25

Now if I want to double the number of GPUs in a rack,

42:31

I need to run literally twice the density of cables.

42:35

I need to run

42:38

these as well.

42:42

Extremely naive question.

42:43

But if you look at a physical data center, it seems like there's

42:47

a lot of space within a rack.

42:49

I don't know.

42:49

The cables are really big and...

42:52

There is space outside the rack.

42:54

Inside the rack… As they become more optimized, these racks are

42:59

very tight.

42:59

There's

43:02

connector density going from

43:07

the tray into the rack and the rack's backplane, and the backplane

43:10

itself has a really high density.

43:13

There are other physical constraints including the bend radius of cables.

43:16

You don't want to snap them and so on.

43:19

Okay, so it's literally the physical space to put a cable that's constraining it.

43:22

I had no idea.

43:23

Interesting.

43:24

That seems surprising.

43:25

The

43:27

rack is so big and we can't just stuff more cables in there.

43:31

Rack design is not my expertise, but when I talk to folks on what

43:34

constraints they're up against, it's a combination of things.

43:39

What are the big physical things you're optimizing for?

43:42

Space, weight of the rack.

43:45

It's actually really heavy, so you need enough metal to not sag and fall.

43:50

But then you add more metal, and it's heavier.

43:52

Then power and cooling.

43:53

All of those are competing.

43:56

Modern racks are pushing all of those to very extreme physical limits.

44:00

Deep work is by its nature quite aversive, so even things which seem

44:03

like work, like Slack and email, can be easy ways to distract yourself.

44:07

So I often wish that I could just turn the internet off.

44:11

But if I'm prepping for an interview, even if I have the papers and books on hand,

44:14

it's still super useful to be able to do a back and forth with an LLM so I can break

44:18

down concepts and research follow-ups.

44:20

Google's new Gemma 4 is the first open model that allows me to have this kind

44:24

of fully disconnected focus machine.

44:26

It's small enough to run on my laptop, but good enough to actually be useful.

44:29

So to prep for this episode, I downloaded Reiner's scaling

44:32

book and shut off the internet.

44:33

I was able to have Gemma help me understand the material

44:35

and answer my questions.

44:36

If you want an LLM that you can run locally on your laptop or even your

44:39

phone, you should check out Gemma 4.

44:45

When was GPT-4 released again?

44:46

Was it 2022 or 2023?

44:48

2023.

44:48

Okay.

44:49

And it was rumored to be over one trillion parameters.

44:53

It seems like only now, within the last six months, have models been getting

44:58

released that have significantly more parameters than the model released three

45:00

years ago, when supposedly there should have been this scaling in the meantime.

45:07

Is the reason that we were just waiting for racks with enough memory

45:11

to hold a five-trillion parameter model, along with its KV cache for

45:18

enough users for a lot of sequences?

45:21

Or if you're doing RL, a similar consideration of actually

45:25

holding the KV cache for

45:28

the batch of problems you're trying to solve.

45:30

If you look at Hopper, you had eight Hoppers, and I think

45:35

that's 640 gigabytes as of 2022.

45:39

With Blackwell finally, which was deployed in…?

45:42

Very recently.

45:43

Maybe last year.

45:44

Last year.

45:44

You finally have a scale-up on the order of 10-20 terabytes, which is

45:49

enough for a 5T model plus KV cache.

45:53

Deploying in larger scale-up domains is a huge unlock.

45:58

I've drawn here the Nvidia Blackwell deployment.

46:01

The Google deployment has actually had very large scale-up

46:04

domains for a long time.

46:05

That also explains why Gemini seemed to be ahead.

46:08

It

46:11

just seems like Gemini has had successful pre-training for longer

46:14

than some of the other labs.

46:15

Not having been there at the time, I'm not sure how much is coming

46:17

from successfully deploying higher sparsity ratios, which it could be.

46:22

It could also be a whole bunch of actual modeling things,

46:27

specifically how you do the mixture of experts.

46:29

We've seen

46:33

the DeepSeek mixture of experts activate more experts, but finer-grained experts.

46:38

That was a big innovation.

46:39

I'm sure there are many other innovations on the model architecture

46:43

as well as on the training data.

46:44

It's hard to disentangle all of them, but what shows up in terms

46:48

of the limits of what you can do

46:52

is that the active parameters, as we saw, are limited by the compute

46:57

cost, and the total parameters are limited by the scale-up size.

47:02

When you're operating within a single scale-up domain, is that

47:06

a consideration specifically for either forward or backward, or

47:12

specifically for prefill versus decode?

47:17

Or is it preferred to always be within a scale-up whatever kind

47:23

of workload you have, whether you're doing a pre-training run, RL

47:29

generation, or inference for users?

47:32

Really interesting.

47:37

To answer that question, we're going to need to talk about

47:38

the communication patterns.

47:40

We've talked about the mixture of experts communication pattern.

47:43

That is this all-to-all.

47:51

All-to-all very strongly favors full connectivity,

47:57

which is what we've just shown here, and it favors being within one rack.

48:03

There are other kinds of parallelism besides expert parallelism,

48:06

which we just showed here.

48:08

Another one in the literature is

48:12

tensor parallelism.

48:12

With the trend towards smaller experts, this has become much less

48:15

relevant, so we can ignore that.

48:17

But the other two things we have available are data parallelism

48:20

and pipeline parallelism.

48:24

They can be a much better fit for using multiple racks.

48:28

Let's focus on pipeline parallelism specifically.

48:32

This is one layer of MoE.

48:34

I'm going to have a hundred more layers up above.

48:39

I could decide at this point, for example, to move to a different rack, change rack.

48:50

Now, is that going to become a communication bottleneck?

48:54

We can actually solve for when this becomes a communication bottleneck.

48:57

Before we do that algebraically, let's

49:00

visualize it out and sketch the path.

49:01

We're going to have another MoE layer, and another MoE layer here, and so on.

49:09

Let's say I change rack here, and then some number of layers

49:12

later, I change rack here as

49:21

well.

49:21

The methodology we're going to use to determine whether we have

49:24

a communication bottleneck at the point where we change rack is we're

49:28

going to compare the scale-out

49:35

bandwidth requirements to the scale-up bandwidth requirements.

49:43

Let's write this.

49:44

The hint is going to be that there's a lot more sends here.

49:49

We're sending many things here, whereas we're only sending one thing here, and

49:52

we're also maybe doing it many times.

49:54

That's

49:56

what makes the difference.

49:58

Can I try to guess?

49:59

Just out of curiosity, to see if I'm actually understanding, it seems like

50:03

you're sending batch size into the rack.

50:07

In here?

50:08

Yes.

50:09

But the communication within the rack is batch size times number of GPUs.

50:18

Number of activated GPUs.

50:21

I don't send to this GPU at all.

50:23

There's an expansion from 1x to 3x here in this diagram.

50:29

The key thing is that I didn't even need to send to this GPU at

50:32

all, and so that's a big saving.

50:35

We're going to talk through

50:40

to what extent scale-up is a bottleneck over scale-out.

50:48

We will directly jump to the ratio of the time spent on scale-up

51:00

over the time spent on scale-out.

51:04

This is the quantity we're talking about.

51:09

The first consideration is that scale-up is

51:15

8x faster than scale-out generally.

51:18

At a baseline, if the bandwidths were the same, we would have this

51:21

1/8, which is coming from bandwidth.

51:28

But then we have some amount of expansion in how much data we're sending.

51:34

If one token comes in here, then this one token gets routed to, in the DeepSeek

51:40

case maybe 32 experts or 16 experts.

51:44

It gets routed to some number of experts.

51:47

So this is the number of

51:51

activated

51:54

experts.

52:03

This same thing applies on multiple different layers, so

52:05

maybe I'm going to run two layers.

52:08

So the ratio is also multiplied by the number of layers

52:16

per stage.

52:19

Don't you need to multiply the whole thing by two for the all-to-all?

52:22

For the up and down.

52:23

Yes, there's a factor of two.

52:28

Thank you.

52:29

What we would like is for the scale-up time to be greater than the scale-out

52:33

time, because the scale-up time is the more important and precious resource.

52:38

We would like this number to be greater than or equal to one.

52:43

This really doesn't seem hard.

52:44

There's just a factor of 8 that we need to overcome.

52:46

So we need the product of these three things to be bigger than 8.

52:50

Typically we have a fairly large number of activated experts.

52:53

It could be 8 by itself.

52:55

Then we can increase the number of layers per stage a lot until we satisfy this.
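
A sketch of that bandwidth check. The factor of 2 is the up-and-down all-to-all, the 8x is the assumed scale-up versus scale-out bandwidth gap mentioned earlier, and the expert and layer counts are illustrative.

```python
# Ratio of time spent on scale-up traffic to time spent on scale-out traffic.
# We want this >= 1 so that crossing the rack boundary is not the bottleneck.
def scaleup_over_scaleout(activated_experts, layers_per_stage, scaleout_slowdown=8.0):
    return (2 * activated_experts * layers_per_stage) / scaleout_slowdown

print(scaleup_over_scaleout(activated_experts=8, layers_per_stage=1))  # 2.0: already fine
print(scaleup_over_scaleout(activated_experts=8, layers_per_stage=4))  # 8.0: plenty of headroom
```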

53:01

What this ends up looking like is that I can have an entire pipeline of racks

53:05

where one rack does one layer, and then I move on to the next rack and

53:08

do another layer, and then I move on to the next rack and do another layer.

53:11

It's interesting to me that the best

53:15

parallelism strategy in practice ends up being one which physically

53:20

resembles the actual architecture.

53:23

It's not some galaxy brain thing.

53:25

It's like, "Oh, we have experts, we're going to put them on different GPUs,

53:27

or we have different layers, we're just going to put them on different

53:29

racks." I feel that's interesting.

53:31

The cutting matches the model architecture.

53:37

Exactly.

53:38

It could have been something wackier with tensor parallelism and whatever.

53:45

The galaxy brain way to think of it is,

53:49

what are all the different dimensions in which a model is scaled up?

53:54

It is scaled up by layers,

53:58

it is scaled up by the model dimension, it is scaled up by the DFF dimension, it

54:00

is scaled up by the number of experts.

54:02

Every single one of those numbers you can choose to cut along.

54:06

If those numbers are big enough, it eventually becomes

54:08

profitable to cut along there.

54:11

We have selected two of them.

54:12

The other two, in the way models are typically sized, are not profitable.

54:16

So there's a talk by Ilya where he says, "Today we know not

54:21

to do pipeline parallelism."

54:23

And Horace He gave my friends and me… I hate that it sounds

54:29

like a Dr. Seuss quote.

54:33

But he gave us a lecture on these different kinds of parallelisms.

54:36

He said the problem with pipeline parallelism is that, other than

54:39

the bubbles, it creates these architectural constraints.

54:42

Kimi,

54:45

for example, has these residuals where attention attends to layers a few back, so

54:52

it becomes hard to implement in this way.

54:56

I guess we didn't fully articulate even what is the

54:59

benefit that we're getting from

55:05

pipelining.

55:05

These complexities are real.

55:06

Pipelining is a massive hassle, but it does give you some benefits.

55:15

You can then decide whether those benefits are worth the costs.

55:22

It has some benefits in inference, maybe bigger benefits in training.

55:25

In inference, what are we saving on?

55:27

Are we saving on memory time or compute time?

55:31

Not really.

55:32

We're just moving the memory time from one chip to another chip,

55:35

or one rack to a different rack.

55:37

There's no actual benefit in runtime.

55:41

However, what we are saving on is memory capacity.

55:45

If we think that the memory in a rack is a bottleneck, then there's

55:51

a constraint on how fast we can go.

55:55

Pipelining allows us to massively reduce that bottleneck.

55:59

The opposite connotation to this… Before this interview, I was

56:06

chatting with Axel, who's a GPU performance engineer at Jane Street.

56:11

He was explaining that to do pipelining, you have to do

56:13

micro-batches rather than full batches.

56:16

If you do micro-batches, then you're by definition not able to

56:23

amortize loading the weights across all the users or all the sequences.

56:30

The positive connotation of that is you don't have to use as much memory.

56:32

The negative connotation is that we can't amortize loading the

56:36

weights across all those users.

56:37

Maybe it's worth explaining why you have to do micro-batches.

56:40

Shall we draw the pipeline bubble?

56:46

What is this micro-batching that shows up in pipeline

56:53

parallelism?

56:53

I'll focus on inference first.

56:55

It's a slightly simpler problem.

56:56

I'm going to draw time, and then which rack

57:06

we're on.

57:07

The idea is that maybe I'll have four racks.

57:10

I've got an inference that is going to step through these four

57:14

racks in some time like this.

57:19

This is inference number zero.

57:20

It

57:23

runs at a certain batch size and steps through all the pipeline stages like this.

57:28

Now, if we were to say, "Well, we're going to run inference number one

57:30

here," this is clearly a massive waste.

57:34

Like three-quarters of the time each of the racks is doing

57:40

nothing.

57:40

We don't actually run inference one here, we run it as soon as we can, which is

57:44

immediately after inference zero finishes.

57:47

And

57:50

then we keep going.

57:53

If we hadn't filled this in, we would call this the pipeline bubble.

57:56

When I've drawn it in this inference context where we're only going

57:58

in a forwards pass, it's obvious.

58:00

Why would you do this stupid thing?

58:02

In a training context, it's maybe less obvious.

58:05

But in the inference context, it's really natural to make this change.
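
A toy schedule for the pipelined-inference picture just drawn: micro-batch i runs on stage s in slot i plus s, so once the pipeline fills, every rack is busy in every slot. Four stages and six micro-batches are arbitrary illustrative choices.

```python
n_stages, n_microbatches = 4, 6
for slot in range(n_stages + n_microbatches - 1):
    row = ["mb%d" % (slot - s) if 0 <= slot - s < n_microbatches else "idle"
           for s in range(n_stages)]
    print("t=%d:" % slot, row)
# The leading and trailing "idle" slots are the fill and drain; in steady-state
# inference there is no bubble, which is the point made above.
```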

58:09

Oh, interesting.

58:14

This is sort of obvious, but the difference between micro-batch and batch

58:16

doesn't matter at all in inference because you can just call it whatever you want.

58:22

It only matters in training because there is an optimal batch size.

58:27

Yes.

58:28

Before you do a full backward step, you want to have accumulated

58:33

all the sequences in that batch.

58:33

If you want to do

58:38

pipelining in training, in order to avoid that bubble, you need to—

58:43

Should we draw the training diagram

58:48

with that?

58:48

Let’s do that.

58:48

This is the inference diagram, and I'll call this forward so we don't

58:51

have the wrong thing showing up there.

58:53

Let's do the same thing for training now.

58:55

We've got a forwards pass, but at some stage we're going to have to

58:57

transition to a backwards pass.

59:01

We'll do some number of batches in the forwards pass,

59:11

and then we're going to transition to the backwards pass for everyone all

59:24

in one go.

59:24

The inference part is the same here, but then we do a hard stop at this point

59:28

and transition everyone to the backwards pass, with similar numbering like this.

59:33

It may be worth clarifying the reason there is that hard stop

59:35

is because you want to do a whole batch at once for the backward step.

59:40

And then there is an optimal size for how big that batch should be.

59:45

Smaller is always better, actually, is a way to put it.

59:48

From an ML convergence rate perspective, smaller is always better

59:53

because you're getting the freshest information from the gradient descent.

59:55

But from a total training time perspective?

59:58

From a total training time perspective, smaller is worse

60:01

from a systems perspective.

60:02

The optimum is the trade-off between those two.

60:05

So you pick a batch size, and

60:10

for that batch size, you do some amount forwards and then some amount backwards.

60:14

You asked why there is even a hard stop there.

60:16

With pipeline parallelism, because

60:21

you've got this idle time here which is the bubble, there are so many

60:26

techniques in the literature for how to lay this out differently and avoid that.

60:31

There are more complicated schemes called zero bubble or one-forward-one-backward,

60:35

which interweave the forwards and the backwards in complicated ways.

60:40

You can mine Bitcoin in that bubble.

60:42

Right.

60:42

More usefully, you can do the weight gradient step, but

60:46

you can also mine Bitcoin.

60:49

In inference, the effect of pipelining on anything you care about, like

60:55

batch size or latency, is neutral.

60:58

It doesn't improve it, it doesn't make it worse.

61:00

If you look at the latency of this inference, running it if it were pipelined

61:03

versus if it were all on one rack… If it were all on one rack, we would just slide

61:07

all the boxes down and still put them in a row, and the latency would be the same.

61:11

Pipelining

61:14

is neither better nor worse for latency.

61:17

It does mean that you just use less memory capacity per rack.

61:23

Because now instead of needing the whole model, you only need a quarter

61:25

of the model, and you can expand.

61:26

Makes a ton of sense.

61:27

So it's a no-brainer to use pipelining during inference, but there's this

61:33

harder trade-off during training.

61:36

Even in inference, in fact, it is not used a ton.

61:39

It reduces your memory capacity requirements, but there's

61:42

actually a huge surplus.

61:44

I think you were saying that a rack of Blackwell has many tens of terabytes.

61:52

That's much bigger than a trillion parameter model.

61:58

A trillion-parameter model only needs one terabyte at one byte per parameter, so it already fits.

61:59

There's not a huge benefit from pipelining because you're reducing a

62:05

number that's already pretty small.

62:07

But it does say that theoretically, maybe you had too much memory there.

62:11

You could have

62:14

built different hardware that has less memory.

62:16

If you were designing your hardware, you could say, "I didn't need that

62:19

much memory because I don't need the weights to fit in one rack.

62:22

I can fit the weights in eight racks, then I could have built hardware that

62:28

didn't have so much HBM per GPU."

62:30

Last week, Horace He was kind enough to give me and my friends a great lecture

62:34

on large-scale pre-training systems.

62:36

And there were some concepts that I wanted to animate for a write-up

62:39

on my blog, like how weights shard and gradients flow depending on

62:43

the parallelism that you're using.

62:45

So I gave Cursor my lecture notes and a sketch that I made during the lecture.

62:49

And I asked it to visualize a specific hierarchical collective

62:53

that Horace had explained.

62:55

The first version was already pretty good, and then I was able to use

62:57

design mode to select and tweak any specific components from there.

63:01

I was able to do all of this without a clear end state in mind.

63:03

Cursor's Composer 2 Fast model was quick enough that I was able to

63:06

iterate almost instantaneously.

63:08

I could try an idea, test the results in the built-in browser,

63:11

and immediately make any changes.

63:13

I went through 10 different versions in under 20 minutes.

63:15

If you want to check out this animation, I published it along with

63:18

the lecture notes in a blog post.

63:20

The link is in the description.

63:21

And if you want to try out this kind of iterative design flow for yourself, go

63:24

to cursor.com/dwarkesh to get started.

63:31

Everybody's talking about the memory wall right now.

63:33

Memory is getting super expensive.

63:34

There's not enough memory.

63:36

Smartphone volume will go down 30% because there's not enough memory.

63:41

This is shocking: Dylan said hyperscalers are spending 50% of

63:45

their CapEx this year on memory.

63:47

That’s believable.

63:50

What is hyperscaler CapEx?

63:51

That's high hundreds of billions, maybe a trillion, and they're

63:54

spending half of that on memory?

63:57

That is a huge constraint.

63:58

That's why we're not going to get new laptops and phones this year.

64:00

But at the same time, we have too much memory?

64:04

People are willing to put too much memory into these systems.

64:06

Why is Jensen shoving all this memory into these racks if you don't need it?

64:14

In the equations we had here before we erased them, we were doing memory time,

64:18

memory bandwidth and compute bandwidth.

64:20

Let's now start looking at memory capacity.

64:23

We'll start off with memory capacity without even thinking

64:26

about a parallelism scheme.

64:35

The demand on memory is the number of total parameters.

64:43

This is what we need to fit the weights in some system that we are using.

64:48

Then we need to fit the KVs as well.

64:51

KVs go as batch size times the length of the context

64:56

times the bytes per

65:06

token.

65:07

What I was arguing about in this context, and the case I was making

65:10

for pipelining, is that there are some techniques that allow us to solve this.

65:18

Let's consider running this on some number of GPUs.

65:23

We're going to have one extent, which

65:28

is E, the expert

65:32

parallelism.

65:33

When we had this sharding of an expert layer across many GPUs,

65:38

to what extent do we do that?

65:39

How many GPUs?

65:42

We're going to say that this is, for example 64.

65:46

Then P is going to be the extent of

65:52

pipelining.

65:52

This is the number of racks,

65:56

maybe we'll pick 4 or something

66:01

like that.

66:03

This is the total memory requirement across the system,

66:07

but now I'm going to calculate a

66:11

memory requirement per GPU.

66:20

I'll use a lowercase

66:26

cmem.

66:26

Obviously, we just take all of these numbers and divide

66:27

it by E and P. Really easy.

66:29

It's

66:32

this Ntotal, plus the batch times length of context times bytes per

66:42

token, all divided by E times P.

66:49

Why is this correct as divided this way?

66:54

We knew that the parameters were perfectly divided amongst all the GPUs in a rack.

67:00

The layers are perfectly divided amongst the different racks.

67:04

So that works here.

67:05

Somehow we're going to arrange—I'll hand-wave exactly how—the same

67:10

perfect sharding of the contexts across GPUs in a rack, and then

67:15

based on layer across racks.

67:17

Sorry, 4 is the number of racks?

67:18

Yeah, for

67:25

example.
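To keep the arithmetic concrete, here is a minimal sketch of that per-GPU memory estimate. Every input value below is an illustrative assumption, not a reported figure.

```python
# Per-GPU memory demand, following the sharding just described:
#   c_mem = (weight bytes + batch * context length * KV bytes per token) / (E * P)
def per_gpu_memory_bytes(n_total_params, bytes_per_param,
                         batch, ctx_len, kv_bytes_per_token,
                         expert_parallel, pipeline_stages):
    weights = n_total_params * bytes_per_param        # total weight bytes
    kv_cache = batch * ctx_len * kv_bytes_per_token   # total KV-cache bytes
    return (weights + kv_cache) / (expert_parallel * pipeline_stages)

# Illustrative numbers: 1T params at 1 byte each, batch 2048, 100k context,
# ~2 kB of KV per token, E = 64 GPUs per rack, P = 4 racks.
print(per_gpu_memory_bytes(1e12, 1, 2048, 100_000, 2048, 64, 4) / 1e9, "GB per GPU")
```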

67:26

This is the place where we actually need to go back and analyze this batch size B.

67:30

You were making this comment that there's micro-batching versus global

67:35

batching.

67:35

Let's come back to this pipelining diagram here.

67:38

We've got one batch going forward here, and then as I drew it,

67:42

it kind of just disappeared.

67:44

That's not really correct.

67:45

If you think about how decode is working, I have a bunch of tokens

67:50

that I have generated already.

67:52

I do one forwards pass where I generate a new token,

67:57

and then I write that to my KV cache.

68:00

Then I do another forwards pass that generates the next token.

68:04

I'm actually going to be running this batch zero in a loop.

68:06

In fact, I go forwards.

68:10

Once I finish, I can start the next iteration of the loop up here.

68:17

We'll just

68:29

fill this in.

68:29

We've got the two,

68:36

three, two and three, and two and three.

68:36

Let's split this batch.

68:38

This batch will be the global batch size.

68:41

B is going to be the number of micro-batches times

68:51

the batch size per micro-batch.

68:53

How many micro-batches do we need?

68:55

The number of micro-batches in this diagram is 4: zero, one, two, three.

69:03

The micro-batch size is still this 2000-ish number.

69:08

Sorry, no, this is the

69:14

300 times sparsity.

69:16

This is

69:22

how big the train that takes off every 20 milliseconds is.

69:23

Right.

69:23

This is going to be the 20-millisecond train.

69:29

The global batch size is the number of micro-batches times the local batch size.

69:33

Local batch size is set by this hardware parameter.

69:35

The number of micro-batches

69:39

is as small as possible, such that we can wrap around and not leave any idle time.

69:47

If we had fewer, we would have this idle time when we wrap around.

69:51

You can visually see that it is equal to the number of pipeline stages.

69:55

It's a proof by visual here.

69:57

It is 4, and it's 4 this way as well.

69:59

You can look and see that it goes along here, and then it wraps around

70:03

to the number of pipeline stages.

70:05

Sorry, very basic question.

70:06

Is this what is actually done?

70:10

A frontier model today will have pipelining during inference?

70:15

For sure during massive scale training this is done.

70:19

It can be done for inference.

70:21

I'm actually going to make the case for why it is less attractive.

70:24

It is useful for weights, but not so useful for KVs.

70:28

The

70:30

big challenge is... Let's fill this in.

70:33

The micro-batch size here ends up being equal to the number of pipeline stages.

70:40

When we go back and substitute all of that into here,

70:49

we get

70:52

a number of pipeline stages times this little b showing up in here.

70:59

When we factor this out, I'm going to split this plus into two terms.

71:04

We

71:08

get the full division by E times P over here.

71:12

We still have division by E times P over here, but the Ps cancel.

71:22

What we find is that if you increase the number of pipeline stages, the

71:26

memory footprint for the number of weights keeps going down and down and

71:29

down, but the memory footprint for the number of activations stays constant.

71:34

So it doesn't actually work.

71:37

Most of your memory…

71:40

Once you do enough pipelining—and it's really not much, even two is often

71:44

enough—this term becomes very small.

71:48

The KV cache becomes the dominant term.
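A minimal sketch of that substitution, with illustrative numbers: once the global batch has to scale as P times the per-stage micro-batch, the P cancels out of the KV-cache term, so only the weights term keeps shrinking as you add pipeline stages.

```python
# c_mem = N_total*bytes_param/(E*P) + (P*b*L*kv_bytes)/(E*P)
#       = N_total*bytes_param/(E*P) +  b*L*kv_bytes/E
# The weights term shrinks with P; the KV term does not.
def per_gpu_terms(n_total, bytes_param, b_micro, ctx_len, kv_bytes, E, P):
    weights_term = n_total * bytes_param / (E * P)   # shrinks as P grows
    kv_term = b_micro * ctx_len * kv_bytes / E       # independent of P
    return weights_term, kv_term

for P in (1, 2, 4, 8):
    w, kv = per_gpu_terms(1e12, 1, 512, 100_000, 2048, 64, P)
    print(P, f"weights {w/1e9:.1f} GB, KV {kv/1e9:.1f} GB per GPU")
```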

71:52

I know this is wrong.

71:53

I'm just trying to think about why my train of logic here is wrong.

71:56

If

71:59

you're pipelining through many different stages, the KV values

72:02

are not shared between layers.

72:03

Why would it not help to be pipelining across multiple layers?

72:06

Because then you don't have to store...

72:08

You only need to store one layer rather than two layers of KVs.

72:12

It helps from that perspective, you're right.

72:16

What's competing with that, though, is that you need to be keeping all

72:19

of the racks usefully busy at a time, so the number of sequences that are

72:24

in flight simultaneously has gone up.

72:27

Ah, that makes sense.

72:28

Those exactly cancel, and you end up not getting a saving per GPU.

72:31

Right.

72:32

This is going back fundamentally to the point of how you're not

72:34

able to amortize across KV caches.

72:38

First, we established you can’t amortize KV caches across batch size.

72:41

Now we're saying you also can't shard it across pipeline stages.

72:48

It sucks from both of those points of view.

72:49

Interesting.

72:50

So then what is done during inference?

72:54

The DeepSeek paper reports what they do, which is that they just

72:58

do a lot of expert parallelism.

73:00

In effect, you should increase your expert parallelism up to your scale-up domain

73:04

size, and then do very little pipelining.

73:08

Maybe none at all, maybe two, just enough to make the weight

73:12

storage not too big of an issue.

73:15

Those are the only two parallelisms that really make sense.

73:17

In the past, there was tensor parallelism, which was cutting up within an expert,

73:24

but the experts are so small now that that is not a profitable optimization.

73:30

Does that mean that frontier labs, when they're doing inference, are

73:33

just within a single scale-up?

73:35

Yes.

73:36

You can look at how it depends on model size.

73:41

You could have a very large model,

73:46

one that exceeds the memory of a rack.

73:49

There you should be doing a bit of pipelining.

73:52

Maybe it's extremely sparse, for example, and that would be a reason to do it.

73:56

This goes back to the promise at the beginning of the lecture,

74:00

which was this will actually tell you about AI progress as well.

74:03

To the extent it is the case that model size scaling has been slow until recently…

74:10

Let me make sure I understand the claim.

74:12

The claim would not be you could have trained across more racks.

74:17

It was just that it would not have made sense before, we didn't have the ability

74:20

to do inference for a bigger model easily.

74:24

Actually, pipelining doesn't help with context length.

74:29

It totally helps with model size.

74:31

Because of the ability to do pipelining, a rack at least should

74:36

not be a constraint on your ability to fit the model parameters.

74:40

The other consideration you're asking is, why hasn't it scaled up more, and

74:43

why did bigger scale-up domains help?

74:46

We talked through one aspect of that, which is that it's

74:49

not because of memory capacity.

74:52

We have a solution to the memory capacity at least with respect to model size,

74:55

not with respect to KV cache size but at least with respect to model size.

75:03

The other issue that shows up is latency.

75:06

I was just about to ask, going from rack to rack, what is the latency cost per hop?

75:13

This is very much dependent on the hardware.

75:20

I can't say with a lot of authority.

75:21

I think it's probably on the order of a few milliseconds, but it could be

75:24

off by an order of magnitude there.

75:26

Is 4 a realistic number of how many pipelining stages you might have?

75:28

Yes.

75:29

So that's not that much.

75:31

On a small number of pipelining stages, this is not a huge latency impact.

75:35

But I guess it's 10 milliseconds per token.

75:39

That's right.

75:39

A couple of milliseconds times 4-ish stages, or I don't know how many you said… 10 milliseconds

75:45

per token is actually a lot.

75:46

If it goes from 20 to 30, or something

75:50

like that…

75:50

Just to chart the path that it goes through, here you're going from your

75:56

GPU or TPU to a network card, which

76:04

then goes to a top-of-rack switch,

76:08

and then hops over to the other rack and does the same thing in reverse.

76:12

You have to sum up the latencies of these different things.

76:15

Sorry, is this the same thing as the data center switch?

76:18

It may in fact go up to a data center switch and back.

76:21

It depends on deployment configuration.

76:22

Got it.

76:24

And because it's decode and sequential,

76:30

they stack up across the stages.

76:32

You can't do them at the same time.

76:34

That’s right.

76:36

This brings us back to the question then, is the size of the scale-up at

76:39

all relevant to why AI model sizes have been what they have been over

76:44

the last few years, whether through training or through inference?

76:48

We talked about latency of the hop.

76:53

There is also just the tmem

76:56

latency.

76:57

The memory time latency is actually massively improved

77:02

by larger scale-up domains.

77:06

I'll recall tmem down here.

77:07

tmem for the weights

77:18

was equal to the number of total parameters

77:24

divided by the memory bandwidth.

77:28

Which memory bandwidth are we talking about here?

77:30

Is it just one GPU?

77:32

It

77:34

is the number of GPUs that I can use in parallel to load these weights.

77:40

I can't use different pipeline stages in parallel because they're not

77:43

running at the same time, but I can use all the GPUs in my scale-up domain

77:46

in parallel to load the weights.

77:50

This is actually extremely effective.

77:54

Basically, I end up with a term here, this memory bandwidth term

77:57

itself is equal to scale-up size...

78:03

Times memory bandwidth per GPU.

78:05

Yeah.

78:05

Times GPU

78:09

bandwidth.

78:10

This term doesn't increase a lot.

78:11

It maybe increases 1.5 or 2x per generation, but this one increased

78:14

by a factor of 8 from Hopper.
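A minimal sketch of why the scale-up domain matters for the weight-load time; the bandwidth figures and the parameter count here are rough illustrative assumptions.

```python
# t_mem for the weights = weight bytes / (scale-up size * HBM bandwidth per GPU)
def weight_load_time_ms(n_total_params, bytes_per_param, scale_up_gpus, hbm_bw_per_gpu):
    return n_total_params * bytes_per_param / (scale_up_gpus * hbm_bw_per_gpu) * 1e3

# An 8-GPU Hopper-class box at ~3.35 TB/s each vs. a 72-GPU rack at ~8 TB/s each
# (illustrative figures), for a 1T-parameter model at 1 byte per parameter:
print(weight_load_time_ms(1e12, 1, 8, 3.35e12))   # ~37 ms per decode step
print(weight_load_time_ms(1e12, 1, 72, 8e12))     # ~1.7 ms per decode step
```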

78:16

So the reason the bigger scale-up matters, it's not the memory capacity of the whole

78:19

scale-up, but really the memory bandwidth.

78:21

Yeah.

78:22

Pipelining totally solves the capacity problem, but

78:27

scale-up size helps solve the bandwidth problem.

78:30

And the bandwidth problem helps you do longer context lengths,

78:34

which is more and more relevant as these models get more agentic.

78:37

It lets you just run the model at lower latency as a first thing.

78:41

If I just do a very sparse model and it's on a little H100 box,

78:46

the latency will be really high.

78:49

A super tangential question.

78:53

There's Chinchilla scaling, which tells you how big a model should

78:57

be relative to the amount of data you're going to train it on.

79:01

But now, obviously, you're not just trying to optimize for the highest quality model

79:07

you could get with training compute.

79:09

You want the best results a user can get with a mixture of

79:11

training and inference compute.

79:14

So there's a question of how much you should over-train a model

79:18

such that compute amortized over training and inference is minimized

79:23

to get a certain performance.

79:24

But now with RL, there's another consideration which is, you're going

79:30

to do some amount of pre-training.

79:32

That pre-training will be used both for RL generation and then

79:36

for inference for the final user.

79:39

By over-training here I mean that while it would have been more efficient just from

79:42

a training compute perspective to have a bigger model that you train for less time

79:46

because it can learn faster, maybe you get a smaller model, spend more compute

79:50

training it than you otherwise would have, but now it's cheaper to give it to users.

79:55

Let me make the question more concrete.

79:56

Basically, how much more than Chinchilla optimal are models over-trained?

80:00

And has that changed as a result of RL generation?

80:03

This is a place where we have to do a bit of guesswork because the updated

80:07

scaling laws and the model traffic are not reported, so we have to guess there.

80:14

One way to look at it…

80:19

Let me first just make a general heuristic claim.

80:23

If I have some cost, and I've got a total cost which is a sum of cost A

80:30

and cost B, like maybe this is the training cost and this is the inference

80:34

cost, and I want to minimize this sum…

80:39

For many

80:42

curves, the minimum tends to be where the costs are equalized.

80:47

That's something of a heuristic claim, but

80:52

there are many examples where it's true.

80:54

Where one is 1/x and the other one is x, for example, they tend to be minimized

80:59

at the point where they equal each other.

81:03

It's also true for e^x and e^(-x) and all kinds of other things.

81:10

Basically, I've got some curve that's going down, some other curve

81:14

that's going up, and they tend to be minimized at this equal point.
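As one worked instance of that heuristic (the 1/x-plus-x case mentioned above, not a general proof):

```latex
\text{Minimize } C(x) = \frac{a}{x} + b\,x:\quad
C'(x) = -\frac{a}{x^{2}} + b = 0
\;\Rightarrow\; x^{*} = \sqrt{a/b},
\quad \frac{a}{x^{*}} = b\,x^{*} = \sqrt{ab}.
```

So at the minimum the two terms are exactly equal, which is the pattern being appealed to here.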

81:17

Heuristically,

81:21

I will conjecture that that is true for the setup you described as

81:27

well.

81:28

Actually showing that would be true would require looking at the

81:30

scaling laws and fitting these weird exponents, but things that follow

81:37

power laws tend to have this property.

81:39

So I'll just make that claim and move on.

81:43

We're going to say that we want to equalize the cost of training

81:47

and the cost of inference.

81:56

We can do all of it in general.

81:58

The cost of pre-training, that's the number of

82:05

active params times the data on pre-training.

82:13

There's a factor of 6 out here, which is the number of FLOPs per parameter per token.

82:16

There's the famous 6ND formula.

82:18

Then

82:20

in RL, we have approximately the same thing.

82:24

We've got the same number of active parameters, but now the

82:28

amount of data is the RL data.

82:31

There is this extra efficiency multiplier, or inefficiency...

82:42

Which is the fact that you're not training on all your rollouts.

82:45

Well, there's that, and then the other, perhaps even bigger

82:49

inefficiency is that this involves a substantial amount of decode.

82:54

Often decode runs at less MFU than training.

82:58

Okay.

82:59

So if you're doing a backward pass on every single generation

83:03

in RL, it would be 6ND.

83:06

So this could be a smaller number, right?

83:07

It

83:09

would at least be two, because that's the lower...

83:11

Somewhere in the range of two to six.

83:12

We'll say somewhere in the range of two to six and leave it

83:18

at that.

83:18

Then we can add in the inference cost.

83:20

The inference cost is two, the number of active parameters

83:24

times the data in inference.

83:28

Sorry, I think the way I said it was super garbled.

83:30

Just for the audience,

83:33

forward plus backwards per parameter is 6.

83:37

Forward alone is 2.

83:39

That's why RL, where you're definitely going to generate all the trajectories

83:43

but you might or might not train all the trajectories, is 2 to 6.

83:46

Yes.

83:48

Thank you.

83:48

And then inference is just 2.
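Putting the three terms side by side in a minimal sketch. The 2-to-6 factor for RL, the MFU values, and the peak-FLOPs figure are the hedged assumptions from this discussion, not measured numbers.

```python
# GPU-time (hence dollar) cost per phase, using the 6ND / 2ND rule of thumb.
# RL generation and serving are dominated by decode, which runs at lower MFU
# than pre-training, so their effective cost per FLOP is higher.
def phase_gpu_seconds(n_active, d_pretrain, d_rl, d_inference,
                      peak_flops=1e15, mfu_train=0.4, mfu_decode=0.15,
                      rl_flops_per_param_token=4):    # somewhere between 2 and 6
    pretrain = 6 * n_active * d_pretrain / (peak_flops * mfu_train)
    rl = rl_flops_per_param_token * n_active * d_rl / (peak_flops * mfu_decode)
    inference = 2 * n_active * d_inference / (peak_flops * mfu_decode)
    return pretrain, rl, inference

# e.g. phase_gpu_seconds(100e9, 150e12, 100e12, 200e12) gives three roughly
# comparable magnitudes, which is the "equalize the costs" condition.
```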

83:51

We're going to solve for essentially equality of all three of these terms.

83:54

That is the ballpark of where people are going to be.

83:58

Labs have more information on what is productive in doing more RL, for

84:03

example, versus doing more pre-training.

84:04

I don't have that information, but I think a good ballpark is a

84:09

33% split between each of them.

84:11

I'm not sure I understand the intuition for that.

84:15

Another naive model could have been that RL plus pre-training would

84:17

be 50% and inference would be 50%.

84:20

That's also a valid answer.

84:24

Because this is heuristic, I can't really argue for one versus the other.

84:27

They don't differ by that much.

84:28

Thirty-three versus twenty-five is only a small factor off.

84:36

Let's pick one of them.

84:38

All equal seems simple enough,

84:42

so we're just going to solve for equality of them.

84:44

It's pretty straightforward.

84:45

We can immediately see that the number of activated parameters totally

84:47

disappears, so let's factor that out.

84:49

We're going to just say that data in pre-training—I decided to do it your

84:55

way, it's a little bit nicer—plus...

84:59

Oh,

85:02

I didn't have the inefficiency over here either.

85:04

Data

85:08

in pre-training plus some multiple of α times the data in RL is

85:17

going to end up equal to some β times the data in inference.

85:28

Let's just roughly size the α.

85:30

This α

85:37

is maybe somewhere in the range of 2 to 6.

85:40

Over 6, from this term compared to this term.

85:44

And then we've got an inefficiency term, which I would say is

85:47

maybe in the range of 30%.

85:50

So this alpha is going to be something like

85:59

1/10.

86:00

And this β here is actually the same.

86:02

It's a third.

86:03

It's one third times 30%.

86:05

So it also equals 1/10.

86:11

If both of them are one in ten, that kind of implies that there's

86:13

never a backward pass on RL?

86:15

Yeah.

86:15

Okay, we can make this 2/10.

86:17

Make it a bit bigger.

86:20

Just write it out once more, this is 2/10, this is 1/10.

86:27

The number of inference tokens you have is just a function of hundreds

86:32

of millions of tokens per second times my model is deployed for two months

86:37

before I ship to the next version.

86:40

That should determine

86:45

the number of tokens in RL and pre-training.

86:48

I guess we didn't do the equivalence between pre-training

86:50

and RL, so we'll do that here.

86:52

Data in pre-training should be equal to 2/10 data in RL for

86:57

them to be cost equivalent.

87:03

Sorry, 1/10.

87:04

I got it backwards.

87:06

We pay more cost when it's inefficient, so this needs to be 1/10.

87:15

Tracing this back… This thing ends up actually being, as written here…

87:21

This is like 1.5, and this is one.

87:28

Billions of dollars worth of compute just flowed in the other direction.

87:31

Right?

87:33

I think if you do it with a spreadsheet and actually model

87:35

it out, you might notice when the money’s going down the drain.

87:42

All of these end up being close to one, as modeled here.

87:45

This 30% may have been a little bit too generous.

87:47

So let's say something like 1.5 here, and leave this as a one here.

87:53

I think at this point, you can almost read it off.

87:56

The number of inference tokens should be about the same as the number of

87:58

pre-training tokens, which should be about the same as the number

88:00

of RL tokens, within factors that we're not able to reason about.

88:08

Sorry for making a basic algebra mistake.

88:09

It seems like there should be fewer RL tokens than pre-training tokens?

88:12

That's in general right.

88:13

Because RL is less efficient in terms of machine time,

88:22

if you're trying to equalize the RL and pre-training time, then

88:24

you should have fewer tokens in order to have the same wall time.

88:28

This is all quite interesting.

88:31

I never thought about it in terms of equalizing data.

88:35

I think starting with equalizing in cost is right, but depending

88:39

on how you model the cost, this comes close to equalizing in data.

88:42

So for GPT to be trained optimally, every single user who uses GPT-5,

88:51

the total amount of tokens that they stream should equal the total amount

88:53

that has gone into pre-training.

88:54

And the total amount of tokens that have gone into pre-training

88:58

is the sum of all human knowledge.

89:01

Each model should generate the sum of human knowledge on the

89:04

output that it gets on the input.

89:06

Yeah.

89:07

Which way are people going to err?

89:08

If you think that people's power of prediction is not perfect, and also you

89:14

run the risk that you make a model that is not a frontier model and then you

89:19

just throw it away, then that changes the cost trade-off because there's some

89:26

probability that applies to the inference.

89:28

And you should derate the inference tokens by some amount.

89:30

Right.

89:31

Can we back out how much more compute than Chinchilla optimal

89:37

for a given sized model?

89:40

I think we just have to make some real-world assumptions here in order to do

89:45

that.

89:46

The inference tokens, we should totally be able to count, right?

89:49

Let's say a few hundred million.

89:51

Maybe it's five hundred million tokens a second now, I don't really know.

89:56

Five hundred million tokens a second times.

89:58

A model is deployed for two months before it becomes obsolete?

90:02

I

90:05

can't do this in my head.

90:06

Can you type it into a computer?

90:08

2.6 × 10^15.

90:15

Okay.

90:15

2.6 × 10^15.

90:20

This number is probably too large because this is going to

90:23

be multiple models in a family.

90:25

Let's make it

90:30

5x smaller or 10x smaller or something

90:33

like that.

90:35

So we're estimating maybe fifty million tokens per second, per specific model.

90:41

The model is live for two months.

90:46

This comes out to around two hundred trillion tokens.

90:50

And then we want to compare that to active parameters on a frontier model.

90:55

I don't actually know the latest rumors.

90:57

Do

91:00

you know?

91:01

Somebody told me a hundred and fifty trillion.

91:03

Active parameters?

91:04

Sorry, I meant tokens.

91:06

Trained on a hundred and fifty trillion tokens.

91:07

Interesting.

91:08

Which is similar.

91:09

That's actually similar.

91:11

So data on pre-training.

91:12

This is not well-cited but it’s fine.

91:17

I think often the number of active parameters

91:21

could be in the range of

91:30

a hundred billion, something like that.

91:31

Maybe a bit larger.

91:31

So multiply by 20 to get the Chinchilla token count.

91:34

So Chinchilla, D_Chinchilla, would be around two trillion.

91:43

We see we're about a hundred times larger than that.

91:47

What does D_Chinchilla actually mean?

91:48

The token count for pre-training

91:53

that the Chinchilla scaling law would recommend, I guess.

91:56

Oh, I see.

91:57

So how much is it over-trained?

91:59

Got it.

92:00

The ratio of this two hundred trillion or a hundred trillion tokens over the

92:07

Chinchilla optimal of two trillion, that's the amount it's over-trained.

92:10

Which is a factor of a hundred over-trained.

92:12

A hundred.
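Collecting that back-of-envelope arithmetic in one place; the traffic, deployment window, rumored token count, and assumed parameter count are all the rough guesses used above.

```python
tokens_per_second = 50e6                    # ~500M tok/s for the family, /10 per model
deployment_seconds = 2 * 30 * 24 * 3600     # roughly two months
inference_tokens = tokens_per_second * deployment_seconds     # ~2.6e14, ~260T

pretrain_tokens = 150e12                    # rumored pre-training token count
active_params = 100e9                       # assumed active parameter count
chinchilla_tokens = 20 * active_params      # Chinchilla-optimal data for that size, ~2T

print(f"inference tokens ≈ {inference_tokens:.1e}")
print(f"over-trained by  ≈ {pretrain_tokens / chinchilla_tokens:.0f}x")   # ~75x
```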

92:14

So if you consider this right here, to the extent this is in the right

92:16

ballpark, just by thinking about how you want everything to be equal in terms of

92:22

compute… If OpenAI also realizes that and they're serving a certain amount of

92:28

tokens per second, that tells you how much data went into the pre-training of GPT-5.

92:34

Even if it's 50% off or something, it is wild that you can first-principles

92:40

these kinds of numbers.

92:41

This is why you should just approximate everywhere, because

92:44

there are big error bars on this.

92:45

But it's kind of empowering to just set A equal to B and figure it out.

92:49

That's super cool.

92:51

Okay, so in the spirit of trying to deduce things, we can publicly look

92:56

up the API prices of these models, and maybe we can learn something from that.

93:03

First, with longer context, Gemini 3.1 is 50% more expensive if you go over 200k

93:15

tokens than if you're below 200k tokens.

93:21

At a high level, I understand why that might be, but why specifically

93:26

50%?

93:27

Why specifically 50%?

93:30

The high level, even in the first place, is that there is some amount of

93:36

increasing cost with context length.

93:42

We can bring that back up.

93:43

That was

93:46

the memory time versus the compute time.

93:50

We've put up these same equations from before, of the time for memory

93:54

fetches which is the weights and the KV cache, and then the time for

93:58

the compute which is just the matrix multiplications for the weights.

94:03

I will also draw the cost curve, but this time I'll do it as a function of

94:05

context length instead of batch size.

94:05

So this is the cost curve as

94:13

a function of context length.

94:26

We'll draw the compute.

94:28

The cost of the compute is actually constant as a function of context length.

94:31

There's no dependence here on context length.

94:33

In reality, there is some dependence, but it is very mild, so we'll ignore it.

94:38

So this is the time for the compute.

94:48

Then we'll also draw the dependence of the memory fetch on context length.

94:53

This starts at a large number for the weights and then grows

94:56

gradually with the context length.

95:00

Maybe starting here, and then grow gradually with context length.

95:04

And

95:09

so, you take the maximum and you see there is this inflection point here.

95:13

So this is the cost that Gemini might be paying.

95:18

And then you think, how might you put a pricing structure on top of that?

95:23

You would like to ensure that no matter what the context length

95:25

is, you are still profitable.

95:30

So we've got a two-tier pricing structure.

95:31

Maybe we've got something that looks like this up to some extent.

95:36

I think it says something about, given that the bump is at 200k, it

95:41

probably means that this is somewhat aligned with this crossover point.

95:44

Maybe not exactly aligned with it.

95:47

We can actually probably even complete that calculation just

95:50

to see where it lands out.

95:53

We can solve for the number of bytes per token if we make some assumptions

95:58

about the number of active parameters.

96:01

So solving for the number of bytes per token, we're going to assume

96:05

the point where we equalize the time of memory and the time of compute

96:08

is at, let's say, 200k tokens.

96:12

So we equalize these two.

96:14

We're also going to assume that the batch size is large enough that the memory

96:20

time spent on weights is negligible.

96:22

So we'll forget about this, and we'll focus on the actual

96:25

memory time spent on KV cache.

96:29

That ends up saying, copying this term over, batch times length of context times

96:36

bytes per token over memory bandwidth

96:44

is going to be equal to the number of activated

96:49

params over FLOPs.

96:54

And then we're going to solve for bytes per token.

97:18

Batch size was missing here.

97:20

It shows up here, and then it cancels out by the time we get to here.

97:28

And I dropped the length of context.

97:35

So we can plug in numbers.

97:36

This is the reciprocal of the number that we saw before.

97:40

This is 1/300, which is reasonably stable across many different hardware platforms.

97:47

We conjecturally said that maybe the number of activated

97:50

parameters is a hundred billion.

97:54

The length of the context we said was 200k.

97:59

Something is wrong here, though.

98:01

Length of the context should

98:20

be on the denominator, not the numerator.

98:22

1,667 bytes. Almost two kilobytes.

98:23

That is plausible, actually.

98:27

You said around two kilobytes.
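The same calculation as a minimal sketch; the active parameter count and the 200k crossover are the assumptions made just above, and the 300 is the compute-to-bandwidth ratio from earlier in the lecture.

```python
# Equalize  L_ctx * bytes_per_token / mem_bw  =  N_active / flops  (batch size cancels),
# then solve for bytes_per_token:
n_active = 100e9          # assumed active parameters
flops_per_byte = 300      # flops / mem_bw, fairly stable across hardware
crossover_ctx = 200_000   # where the pricing tier changes

bytes_per_token = n_active / (flops_per_byte * crossover_ctx)
print(f"{bytes_per_token:.0f} bytes of KV per token")   # ≈ 1,667 bytes, about 2 kB
```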

98:35

Let's just do a sanity check for what this could be.

98:38

There are two mechanisms that people do attention with a

98:42

small number of bytes per token.

98:44

One is dense attention with a lot of reuse across layers.

98:50

Character AI has a blog post talking about that, alternating long and short context.

98:56

In the Character AI kind of model, which also showed up in the Gemma

98:59

models, the global context—which is really what we're talking about

99:03

here—was shared across all the layers.

99:06

To get this to kilobytes, you could get that, for example, as

99:09

a dhead of 128, which is typical.

99:14

Then

99:16

the number of bytes is typically the number of attention layers

99:26

times two times dhead times the number of KV heads.

99:39

This is the number of unique contexts per layer.

99:43

Do you share the context across many layers, or do you use it only once?

99:49

In the Character AI-like models, this number is one.

99:54

We said this is 128.

100:00

This is a choice which typically ranges from one...

100:03

Sorry, this is KV heads, I meant.

100:06

The difference between a head and a KV head is that…?

100:08

The KV heads are the heads that are stored in memory, store the

100:13

contents of the previous tokens.

100:14

The Q heads are the retrieval heads.

100:17

They're only used temporarily and they’re used by the attending token.

100:23

In this autoregressive context, I've got KV heads associated with all

100:26

of the contexts, and then Q heads associated with this new token here.

100:30

But this head, the

100:36

128.

100:37

Oh, sorry.

100:37

This d-head is the dimension of the vector.

100:39

The

100:41

number of KV heads is typically in the range of 1 to 8.

100:47

It is totally plausible to get this by, for example, having 8

100:50

KV heads and a d-head of 128.

100:52

That gives you exactly this number.

100:54

Or you could have fewer KV heads, but more layers.

101:00

This is one way to get there via dense attention.

101:02

There's also a way to get there via sparse attention, where you

101:04

increase all of these numbers, but then you have a 1/sparsity term.

101:12

I think this number is plausible, if maybe a little bit small.
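A minimal sketch of that sanity check for the dense-attention-with-reuse case; the head dimension, KV-head count, and one-byte (FP8) KV values are illustrative choices, not anyone's reported architecture.

```python
# KV bytes per token ≈ unique KV layers * 2 (K and V) * KV heads * d_head * bytes/value
def kv_bytes_per_token(unique_kv_layers, n_kv_heads, d_head, bytes_per_value=1):
    return unique_kv_layers * 2 * n_kv_heads * d_head * bytes_per_value

# One shared global-attention KV set, 8 KV heads, d_head = 128, FP8 values:
print(kv_bytes_per_token(1, 8, 128))   # 2048 bytes, right around the 2 kB estimate
```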

101:15

It's funny that they would leak so much information through their API pricing.

101:18

I mean, you are incentivized to price close to your costs because

101:22

otherwise someone could scoop you.

101:24

Maybe we can learn something about the difference in input versus output

101:26

prices, and what that tells us about decode versus prefill in these models.

101:33

I think last I checked it's 50% more expensive or something like that?

101:38

I don't remember.

101:39

What I've seen in the past is 3-5x more expensive.

101:42

Okay, that makes more sense.

101:42

So let's say it's 5x more expensive.

101:45

This is the compute to process the next token in decode.

101:50

Suppose you're doing prefill, where you're not just processing the most

101:54

recent token, you're processing all the tokens in parallel.

101:57

I want to say that it would be this times length prefill?

102:05

Or length of the pass in general.

102:10

If we can think of decode as being a pass with one token, and then

102:13

prefill being a pass with many tokens.

102:14

Okay.

102:16

So maybe prefix?

102:17

Okay,

102:20

memory.

102:22

You're not storing the KV cache for the tokens that are the prefill tokens.

102:28

Let's actually draw how prefill shows up here, if I may clarify.

102:33

We do a bit of decode like this.

102:37

We may actually come back and do more prefill.

102:40

If you think this is a chat session, the user says something, the AI generates

102:44

a response, and then the user says something else and we prefill this.

102:48

Maybe this is the general case, rather than this.

102:52

In fact, this is like you read a file or something.

102:54

Read a file or the AI is responding to a user input, tool call, or

102:58

anything that's not AI-generated.

103:01

Okay, suppose we're here.

103:11

You will have calculated all of this previously.

103:14

So just the KV of everything that came before.

103:19

But what is the memory cost of this?

103:22

Well, the

103:26

memory bandwidth cost of this.

103:28

If you're doing flash attention, it would—

103:31

It's basically temporary.

103:33

It doesn't even go to main memory.

103:34

Just ignore that.

103:35

Exactly.

103:35

So then it would just be everything that came before.

103:39

Is it not just that then?

103:41

There's actually no adjustment at all to the memory time.

103:42

Okay.

103:43

Great.

103:43

So it's a very trivial change to accommodate.

103:47

This

103:50

term is making it 5x more expensive.

103:52

Now, why would that be?

103:53

What

103:57

does that actually tell us?

103:58

What variable does this help us clamp?

104:00

The

104:05

only thing that could have changed is that the compute is

104:06

5x more expensive as a result.

104:09

This is the time for one pass, but actually the amount of

104:12

tokens is that much larger.

104:14

We want the cost per token, in fact, or the time per token.

104:19

I'm not sure I understood.

104:20

This

104:24

is for processing the next token in prefix?

104:27

Well, actually for processing the entire batch.

104:31

At this cost, we have processed this many tokens, the length of prefill.

104:34

Or I guess the length of the pass.

104:34

Not this prefix, but it's this cost.

104:34

Okay.

104:34

Let's just do this pass.

104:34

So this is 5x more expensive.

104:34

Input is 5x more expensive.

104:34

Output is more expensive, in fact.

104:34

Output is 5x more expensive.

104:34

The result we want to work towards is that prefill is compute-limited and

104:34

decode is memory bandwidth-limited.

104:34

Why don't we do this?

104:34

Why don't we just chart it with len-pass

105:11

on the X-axis and t on the Y-axis.

105:17

We want the cost per token, so it'll be t over

105:22

length of the pass.

105:28

That'll be

105:32

right.

105:46

I

105:49

guess I’m getting confused by this.

105:50

Len-pass is... It seems like this should be higher when you're doing prefill.

105:56

Prefill has a bigger length pass.

105:57

Yeah.

105:59

But then why is it cheaper?

106:01

Why is the cost higher?

106:06

It's this division by length pass.

106:12

This is going to divide out, but then all of this is going to divide

106:16

by length of pass, and it's going to make the memory costs cheaper.

106:19

Okay.

106:21

Let me think about this then.

106:21

Basically we'll have four different lines.

106:22

Let's do

106:31

prefill first...

106:34

Actually, let's do decode first.

106:39

Length of the pass, when it's one, that is decode.

106:42

When it is bigger, that is prefill.

106:44

Oh, okay.

106:45

I see.

106:46

That makes sense.

106:47

Getting back to it.

106:48

So tcompute, if you have basically just this divided by

106:52

len-pass, so just this amount.

106:55

This actually does not vary with the length of the pass, so it'll just be some flat value like this.

107:03

And this is tcompute.

107:09

And

107:12

this is—

107:12

That's decode.

107:13

Decode.

107:13

Right.

107:15

Now tmem, we have this whole thing divided by len-pass.

107:18

Well, it doesn't really matter what's up there, it'll just be

107:21

something that looks like this.

107:25

Let's say this is tmem.

107:31

This is decode again.

107:33

So as the length of the prefix goes up, or pass, your memory bandwidth time

107:46

declines, and that means that to the extent that you were bottlenecked on

107:51

memory bandwidth before, you can avoid being bottlenecked on memory bandwidth.

107:56

The fact that they are charging 5x less for prefill than decode does suggest that

108:04

they are bottlenecked on memory bandwidth to quite a degree, such that for them at

108:08

least—because t is equivalent to cost, it's the cost of renting a compute—this

108:16

would be at 1, and this would be at 5.

108:18

That's right.

108:20

So it is, in fact, tremendously memory bandwidth bottlenecked.

108:23

The real graph looks something like

108:29

that.

108:30

It still crosses, but yeah.

108:32

Exactly.

108:33

Let me do it this way.

108:35

This is

108:44

the gap on decode between the memory and the compute time.

108:50

Okay, interesting.
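Here is the shape of those two curves as a tiny sketch. The constants are arbitrary units chosen only to show that per-token compute is flat while per-token memory time falls as one over the pass length; they are not calibrated to any real deployment.

```python
T_MEM_PER_PASS = 10.0       # arbitrary units: weight + KV-cache fetch time per pass
T_COMPUTE_PER_TOKEN = 2.0   # arbitrary units: matmul time per token

def cost_per_token(len_pass):
    t_compute = T_COMPUTE_PER_TOKEN          # flat in len_pass
    t_mem = T_MEM_PER_PASS / len_pass        # falls as 1/len_pass
    return max(t_compute, t_mem)             # roofline: whichever bound dominates

print(cost_per_token(1))      # decode: memory-bound  -> 10.0
print(cost_per_token(4096))   # prefill: compute-bound -> 2.0
# A 5x decode-vs-prefill price ratio corresponds to the per-pass memory time
# being ~5x the per-token compute time at len_pass = 1, i.e. decode sitting
# well into the memory-bandwidth-bound regime.
```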

108:52

Another interesting one would be why cache hits are so much cheaper.

108:58

If I remember correctly, cache hits are like 10x… It's more expensive

109:02

to write to cache according to the pricing on all these models.

109:06

But if you do hit a cache, it's

109:13

10x cheaper.

109:15

Presumably, this is the cost of keeping something in HBM

109:19

rather than just evacuating it.

109:22

But if you do keep it in HBM, then it's cheaper to load again?

109:25

Right.

109:26

There are two ways you can produce the KV

109:30

cache for a token.

109:31

You can just produce it from scratch by computing it from the

109:34

underlying token IDs, which are tiny.

109:37

Or

109:40

you can previously have produced it and stored it in a memory

109:45

somewhere.

109:45

The cost ratio is really talking about the ratio between those

109:48

two mechanisms of producing it.

109:49

A cache miss means you've deleted it from all your memories, and you have to

109:53

recompute it from the tokens directly.

109:55

You can even take that a step further and think about which

109:59

memory tier you store it in.

110:01

You could store it in HBM.

110:03

There are other slower and cheaper memories than HBM, like DDR

110:07

on your host or flash as well.

110:11

One of the things you can do is a calculation of where it makes sense to be

110:17

in each memory tier, and this is related to how long you're going to store it for.

110:24

We want to look at the cost of storage in a few different memory tiers and

110:27

also the cost of rematerialization.

110:32

Remat means the cost to rebuild all of the KV cache from scratch after you

110:38

deleted it, so we rematerialize it.

110:42

Basically, this is going to cost the length of the context.

110:48

Actually, we'll look at the cost per token, so we don't need to carry around

110:52

this length of context everywhere.

110:54

To rematerialize one token of KV cache, I just need to run a

111:02

forward pass on the whole model.

111:07

This is going to be the compute time.

111:08

I have to rerun the compute at whatever speed my GPU does it, and then I

111:13

multiply it by my GPU dollars per second.

111:19

Sorry, excuse a naive question.

111:21

Why is there not a quadratic term?

111:24

There is a quadratic term.

111:27

It shows up in the compute.

111:35

As an approximation, I chose to remove it.

111:39

I'll just show you quickly what that looks like.

111:42

If you look at the

111:47

cost per token, or the number of FLOPs per token, there are the FLOPs

111:52

that are coming from doing the weight matrix multiplies as a function of—

111:56

Which is flat.

111:56

...context length.

111:58

And then there is the number of multiplies that comes from doing the

112:01

KV cache, which goes up linearly with the amount of stuff you attend to.

112:07

The slope on this is so low that when you draw it like this, it's very

112:11

well approximated by a flat line.

112:15

You start to notice the effect of the quadratic or the linear term

112:18

up in the millions of tokens or so.

112:20

So it's just not super relevant.

112:22

So what is the reason that there's no company which has over a million

112:26

token context length, if this is true?

112:30

There are two costs of long context.

112:31

One is the memory bandwidth cost, which we've spent a lot of time analyzing.

112:34

That's this thing.

112:37

The other one is the compute cost.

112:39

The compute cost is almost always forced by

112:45

fundamental principles to be a much smaller slope than

112:49

the memory bandwidth cost.

112:52

The primary things that limit you to really large contexts are memory

112:56

bandwidth and memory capacity, which is exactly this effect.

113:01

There's this idea that Dario said on the podcast, and others have said, which

113:04

is, "We don't need continual learning for AGI, in-context learning is enough."

113:09

If you believe that, then you have to think that we have to get to a

113:12

hundred-million-token context length to have an employee that is the equivalent

113:17

of working with you for a month.

113:19

Now, maybe that's no longer true with sparse attention or something.

113:25

But if you think that, then some ML infra thing would have to change to

113:29

allow for a hundred million, like the memory bandwidth, to allow for a

113:33

hundred-million-token context lengths.

113:36

Sparse attention gives you a get-out for sure, because you get this square root.

113:40

It gives you a big improvement.

113:46

But if you look at the history of context lengths of models,

113:54

from earlier models like GPT-3, maybe to GPT-4—I don't remember when the

113:59

transition happened exactly—they shot up from about 8K to 100-200K.

114:04

And then for the last year or two, they've all been hovering around there.

114:08

I think that indicates that this is the reasonably balanced cost

114:13

point, and going massively beyond that would be cost-prohibitive.

114:17

Not because of the compute cost, because of the memory bandwidth...

114:19

Because of memory bandwidth cost, yeah.

114:24

I actually don't see a very good path to solving that.

114:29

The HBM is where it is.

114:34

It's not getting hugely better.

114:35

And why doesn't sparse attention solve it?

114:38

Sparse attention is a big improvement.

114:39

Maybe that is priced in already, perhaps.

114:44

It's not an infinite improvement because if you go too sparse,

114:47

you lose too much quality.

114:49

The empirical result is that the context lengths haven't been increasing that much.

114:53

I think it's because there is no solution to the memory wall here.

115:00

Going too sparse just means you're attending to a very small subset of the

115:03

tokens, and the quality will get worse.

115:05

Makes sense.

115:05

What is the cost of these different ways of

115:10

resynthesizing the KV cache?

115:13

Computing it from scratch is based on my GPU time.

115:15

I have to do a certain amount of multiplies, of GPU time

115:18

that I spend in order to

115:22

produce it.

115:25

Storing in

115:30

HBM.

115:33

This really goes as my

115:36

bytes per token.

115:39

I need to just have some number of bytes per token, and then I

115:44

need to store this in the HBM.

115:46

It's going to use up some of my HBM capacity.

115:50

A way to think of this is that if I have too many of these things

115:55

sitting in my HBM, if I fill up my HBM with just KV caches that I'm

115:59

not using, I can't use that GPU.

116:02

How do I price that?

116:03

Maybe I say that the cost of it is proportional to the

116:06

fraction of the HBM I'm using.

116:08

There's also times GPU dollars.

116:14

Let's just do one more memory tier and say

116:17

store in DDR instead.

116:23

The same kind of thing goes up for flash and for DDR.

116:27

I put these in the wrong columns.

116:29

I meant to make two columns.

116:32

The distinction I want to make is that there is the cost to

116:35

retrieve, and then there's a cost to

116:46

hold on.

116:49

This is a cost per second, whereas this is an instantaneous cost.

116:55

Rematerialization has a cost to retrieve and has zero cost to

116:58

store it because we've deleted it.

117:02

This is the one that I put in the wrong location.

117:04

This is actually the cost just to hold on, so I will rewrite it.

117:27

If we're just storing it in HBM, it has this sort of cost profile.

117:30

If

117:34

we store in DDR, it's actually going to take some time.

117:38

We get the same thing here:

117:41

bytes per token over DDR capacity

117:47

times DDR cost per

117:53

second.

117:53

But now this has a cost to retrieve that is higher than the HBM because

117:58

we need to copy it into the HBM.

118:00

So this is bytes per token

118:06

over

118:09

DDR bandwidth.

118:11

And then this consumes some amount of the DDR as well.
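A minimal sketch of those two columns, cost to retrieve and cost to hold, per token of KV cache. Every constant below (GPU price, capacities, bandwidths, MFU) is a placeholder assumption, included only to show how the trade-off is structured.

```python
KV_BYTES = 2048                       # KV-cache bytes per token, from the earlier estimate
GPU_DOLLARS_PER_SEC = 3.0 / 3600      # assume ~$3/hr for the GPU and its attached memories

# Rematerialize: nothing to hold, retrieval is one forward pass of compute per token.
remat_retrieve = 2 * 100e9 / (1e15 * 0.4) * GPU_DOLLARS_PER_SEC

# HBM: retrieval is effectively free (it is already where inference needs it);
# holding costs you the fraction of HBM capacity occupied, priced at the GPU rate.
hbm_hold_per_s = KV_BYTES / 192e9 * GPU_DOLLARS_PER_SEC           # ~192 GB HBM assumed

# DDR: cheaper to hold (larger capacity), but you pay to copy it back into HBM.
ddr_retrieve = KV_BYTES / 0.5e12 * GPU_DOLLARS_PER_SEC            # ~0.5 TB/s assumed
ddr_hold_per_s = KV_BYTES / 2e12 * GPU_DOLLARS_PER_SEC            # ~2 TB host DDR assumed

# Total cost of a token that sits idle for t_hold seconds and is then reused:
def total_cost(retrieve, hold_per_s, t_hold):
    return retrieve + hold_per_s * t_hold
```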

118:14

And every scale-up has DDR and flash?

118:17

This is really a deployment question, so you can choose that.

118:20

Nvidia does deploy in this form.

118:23

It has both.

118:24

Why isn't the cost to retrieve HBM

118:28

the bytes divided by memory bandwidth?

118:30

It depends what you define a retrieve to be.

118:32

Here, I'm defining retrieve to be, move it into HBM so that you can

118:37

start actually doing inference on it.

118:40

Because if it's already in HBM, you can be doing compute while

118:43

you're getting it from HBM to SRAM?

118:44

Interesting.

118:44

Yeah, for example.

118:47

These are three things, and I guess I ordered them wrong.

118:50

In general, if you're balancing two costs and you've got different

118:54

tiers in the memory hierarchy, you should expect as this cost

118:58

goes up, this cost should go down.

119:01

You can kind of see where the zeros are.

119:06

I should have ordered them this one first, this one second, and this one third.

119:12

If you're going to hold onto it for a very short amount of time, then all of

119:18

this is multiplied by the hold time.

119:24

This one is, and so is this

119:29

one.

119:29

Interestingly, they have different prices to write for.

119:32

Do you specify this in the API for five minutes versus an hour?

119:38

Which suggests that the five minutes is HBM and the hour is DDR.

119:41

I think that's a pretty good assumption.

119:44

If you look at the numbers, it might also turn out that it's one tier

119:47

down, and it's DDR versus flash.

119:50

Interesting.

119:50

I'll look up the price difference.

119:59

The base input tokens is $5 per million tokens.

120:04

Base, which means remat.

120:04

This is $5.

120:05

That's $5

120:08

to "retrieve".

120:12

And then to write,

120:21

presumably HBM, for five minutes is $6.25.

120:25

We might be able to determine which memory tier it is by the durations.

120:35

Five minutes versus one hour.

120:37

Exactly.

120:37

I think this will probably end up being

120:42

the drain time of the memory tier that you're in.

120:45

What that means is,

120:49

given that I know I'm going to be holding something for

120:51

five minutes, I would like to

120:55

pick a memory that I can read every five minutes.

120:58

I can read the whole memory once per five minutes, ballpark.

121:01

That is the drain time of the memory.

121:02

So if I take the storage capacity over storage bandwidth,

121:11

I would like this to be equal to five minutes.

121:16

We did this calculation for HBM.

121:17

For HBM, we know that this number is 20 milliseconds.

121:21

So HBM is much

121:26

too small.

121:27

DDR could be about an order of magnitude or two off from this, so

121:30

this is probably on the order of

121:34

seconds, like 1 to 10 seconds.

121:40

I don't have these numbers memorized, but generally, as you go to

121:42

slower tiers, flash is plausibly on the order of one minute.

121:46

And then spinning disk, which is massively different, is on the order of one hour.

121:52

So this might actually identify the tiers of flash and spinning disk.
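A small sketch of that drain-time comparison; the capacities and bandwidths are rough single-device figures picked to land near the ballparks above, not measurements.

```python
# Drain time = capacity / bandwidth: roughly how long you'd hold KV in a tier
# before it makes sense to read it all back out once.
tiers = {
    "HBM (per GPU)": (192e9, 8e12),     # -> ~24 ms
    "host DDR":      (2e12,  0.5e12),   # -> ~4 s
    "flash (NVMe)":  (1e12,  15e9),     # -> ~1 minute
    "spinning disk": (1e12,  0.3e9),    # -> ~1 hour
}
for name, (capacity, bandwidth) in tiers.items():
    print(f"{name}: drain time ≈ {capacity / bandwidth:.0f} s")
```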

121:57

Sorry, why is this the calculation?

121:58

This is the storage capacity divided by the bandwidth?

122:02

You've got a bunch of different memory tiers, we've listed four of them.

122:08

Your choice of which memory tier is about minimizing the cost.

122:15

What fraction of the device are you using?

122:20

You're using some fraction of the device for holding onto it, and

122:22

then you're using some fraction of the device to retrieve it.

122:27

Let's say I'm using 10% of the device.

122:31

And I want to equalize those two fractions.

122:33

That's a sign that I've hit the right thing.

122:36

Let's say I've got some runtime here.

122:39

I'm going to hold on for all of this time,

122:43

so this is the time-hold.

122:47

And then there's going to be some amount of time here, which is time-retrieve.

122:55

Basically to equalize these two costs, I want the retrieval time

123:00

to be equal to the hold time

123:06

times the fraction of capacity.

123:13

Because this is the retrieval time, this is how many other

123:18

things I can hold simultaneously.

123:20

Basically,

123:22

you want to store things in there for so long such that the amount of

123:28

time it's in there is the time to get all your things in there and out.

123:32

Yeah basically.

123:33

I think that probably indicates that the two tiers are flash and spinning disk.

123:38

I'm kind of shocked to see spinning disk being used at all, because

123:41

it's such an old technology.

123:43

Interesting.

123:44

It’s also crazy that it’s so slow that it takes an hour to

123:46

load its full capacity in.

123:48

It’s a really unattractive technology but it’s useful in some places.

123:52

We're sitting down because I want to ask you some questions

123:54

that don't need a blackboard.

123:56

You have this extremely interesting blog post where you talk about how,

124:01

at a high level, the architecture of different cryptographic protocols

124:05

looks a lot like neural networks.

124:08

There's this convergent evolution where they both need to jumble

124:11

information across all their inputs.

124:13

For cryptographic protocols, it's to make sure that each new

124:17

input into a hash function will totally scramble what happens.

124:20

For neural networks, of course, they need to consider how this piece of

124:25

information changes what you should make of this other piece of information.

124:29

I thought that was an extremely interesting point.

124:32

At a high level, in some sense they're trying to do the inverse thing.

124:38

Cryptographic protocols are trying to take information which has structure and make

124:43

it look indistinguishable from randomness.

124:45

Neural networks are trying to take things which look random—protein

124:51

sequences, DNA, garbled text—and extract higher-level structure from it.

124:58

They have similar high-level mechanisms, but they're actually

125:01

trying to do the opposite things.

125:04

I wonder what you make of that.

125:10

I try to look for other examples where mixing and scrambling shows up as well.

125:14

There's almost a physical example where you're making a cake and

125:19

you want to stir the batter.

125:21

Literally the idea to first stir it this way and then stir it this

125:23

way is not too bad of an approach.

125:26

Beyond that, back to the digital world,

125:31

there are some differences, and the one you call out is

125:34

a pretty strong difference.

125:37

The way it shows up,

125:43

if you just randomly initialize a neural network, maybe it's a reasonable

125:48

cipher as well because the random initialization is going to jumble

125:51

stuff in a complicated way.

125:52

It may even do what you want.

125:53

Who knows?

125:56

The thing that makes it interpretable is the gradient descent.

125:59

You can differentiate a neural network and get a meaningful derivative.

126:04

We do a lot of work to not overcomplicate the derivative, so the residual

126:10

connection keeps it contained and simple.

126:14

And so does the LayerNorm stuff that we do.

126:18

One of the biggest attacks against cryptographic ciphers is also

126:21

to differentiate the cipher.

126:22

Ciphers

126:26

run in a different number field.

126:27

They run in the field of two elements, so just binary, whereas neural nets run,

126:34

in theory, in the field of real numbers.

126:38

You have to differentiate with respect to binary numbers, but you

126:43

can absolutely differentiate a cipher.

126:46

This is called differential cryptanalysis.

126:50

Basically, what it says is that if you take a small difference of the

126:52

input, it's quite difficult to make the difference of the output be small.

126:58

The whole job of a well-designed cipher is to make the

127:01

difference in output very large.
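
To make that avalanche property concrete, here is a quick sketch (my own illustration, not something from the lecture): flip a single input bit going into SHA-256 and roughly half of the 256 output bits change.

```python
import hashlib

def sha256_bits(data: bytes) -> str:
    """SHA-256 digest of `data`, rendered as a 256-character bit string."""
    return bin(int.from_bytes(hashlib.sha256(data).digest(), "big"))[2:].zfill(256)

msg = b"hello world"
flipped = msg[:-1] + bytes([msg[-1] ^ 0x01])  # flip one bit of the last byte

diff = sum(a != b for a, b in zip(sha256_bits(msg), sha256_bits(flipped)))
print(f"{diff} of 256 output bits changed")   # typically close to 128
```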

127:04

The distinction is that the optimization goals at that

127:07

point are about complexifying.

127:10

They don't have the same residual connections or LayerNorms.

127:14

I guess a place where the two merge is

127:22

backdoors.

127:22

With a backdoor in an LLM, you're trying to hide… Would you consider it an

127:27

input?

127:28

It’s not an input into the forward pass but it’s an input into the backward pass.

127:31

You’re trying to hide an input into the backward pass.

127:34

This is an adversarial

127:39

context?

127:39

This is actually a place where you get exactly the avalanche

127:44

property that ciphers have as well.

127:49

Adversarial attacks on image classification models are about finding

127:56

a very small perturbation of the image that totally changes the classification,

127:59

totally changes the output.

128:01

That is the common case in ciphers, whereas that's the

128:02

undesired case in neural nets.
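
As a sketch of that undesired case (my own illustration in PyTorch, with a made-up toy model and arbitrary shapes, not anything from the lecture): a gradient-aligned nudge of a few hundredths per pixel can be enough to change the prediction.

```python
import torch
import torch.nn.functional as F

# A toy differentiable "classifier" standing in for a real image model.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 10))

def fgsm_perturb(model, x, y, epsilon=0.03):
    """Nudge every pixel slightly in the direction that most increases the loss."""
    x = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    return (x + epsilon * x.grad.sign()).detach()

x = torch.rand(1, 3, 8, 8)     # a tiny random "image"
y = torch.tensor([0])          # an arbitrary label
x_adv = fgsm_perturb(model, x, y)
print((x_adv - x).abs().max())                    # the per-pixel change is tiny...
print(model(x).argmax(), model(x_adv).argmax())   # ...but can flip the prediction
```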

128:02

Interesting.

128:02

Has it at all been a successful field to actually use neural networks as ciphers?

128:02

Almost anything you do in trying to create a cipher, if it doesn't have 10

128:02

years of scrutiny, it's probably broken.

128:02

So in that direction, it's a little dangerous.

128:03

In the other direction, there has been at least one very

128:03

clear adoption of technology.

128:03

There is a construction where you take a function, an f[x] function,

128:03

which is not invertible, and use that to build an invertible function.

128:03

That started in ciphers.

128:03

It's called a Feistel cipher or Feistel network.

128:04

You apply the function f—I want to write on the blackboard but I won’t—remember

128:05

the input, and then you swap the two.

128:06

That allows you to construct invertible layers.

128:06

There is a paper from 2017 called Reversible Networks, RevNets,

128:06

which does exactly this construction.

128:06

In addition to your residual connection, you also remember the

128:06

input from the previous layer.

128:06

That actually makes the entire layer reversible and almost

128:06

completely eliminates your memory footprint during training.

128:06

Instead of needing to save activations for the backwards pass, you can

128:06

run the entire network backwards and rematerialize the activations.

128:07

Ok, so I was asking you,

128:10

have neural networks actually been used for cryptography?

128:13

And we realized it may be better to just do this on the blackboard.

128:18

Are they actually being used for cryptography?

128:20

Using neural nets for cryptography… In general, creating a new cipher

128:26

is a very dangerous proposition.

128:27

Almost all of them are broken.

128:29

99% of them are broken, so it’s probably a bad place to start.

128:34

But the other direction has been, in at least one very

128:38

clear case, quite productive.

128:41

There's a construction

128:44

that exists in ciphers and then was imported into neural nets called a

128:48

Feistel cipher, or Feistel network.

128:51

The idea is that you may have some function f which is not invertible,

129:00

but you like the function because it does interesting things; it's

129:03

an MLP, for example.

129:06

Or it mixes it in an interesting way.

129:08

You'd like to build something out of this that is invertible.

129:11

The construction we're going to make is going to be a two-input function

129:13

rather than a one-input function.

129:15

We're

129:19

going to apply

129:22

f[x].

129:25

We need to actually remember what x was, so we're going to stick x over

129:28

here so that we can work backwards, and then we also can't drop y.

129:33

We're going to remember y, and we're going to add them together to form this tuple.

129:36

The way to invert this, if you think I have

129:43

this output and I want to recover x and y, I can easily recover x.

129:47

That's right there, I just read it off.

129:49

To recover y, if this thing was called z, I can recover y by z minus f[x],

129:58

because I've already recovered x.

130:01

That means this construction is invertible.
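
Written out as code, the step is (x, y) → (x, y + f[x]), and it inverts cleanly even when f does not. A minimal sketch (my own, with a deliberately non-invertible f):

```python
import numpy as np

def feistel_forward(f, x, y):
    """One Feistel-style step: pass x through unchanged, add f(x) into y."""
    return x, y + f(x)

def feistel_inverse(f, x, z):
    """Undo the step: x is available directly, so y = z - f(x)."""
    return x, z - f(x)

f = np.abs   # non-invertible on its own (it throws away the sign)

x, y = np.array([1.0, -2.0]), np.array([3.0, 4.0])
x_out, z = feistel_forward(f, x, y)
x_rec, y_rec = feistel_inverse(f, x_out, z)
assert np.allclose(x_rec, x) and np.allclose(y_rec, y)
```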

130:06

This was used in ciphers a ton and still is used.

130:08

It's one of the main mechanisms of constructing ciphers.

130:11

Often you want ciphers to be invertible, especially the layers of ciphers, because

130:16

that has better cryptographic properties.

130:16

This has actually been

130:20

ported over

130:24

into neural nets.

130:25

There's a 2017 paper called RevNets, reversible networks.

130:32

What it does is make the entire network invertible.

130:34

You can apply it to any network, like a transformer network.

130:38

I do a forwards pass, but then I can run the entire pass backwards as well.

130:42

The whole neural network is invertible with exactly this construction.

130:48

This paper applied it to some layer, like a transformer layer, for example.

130:53

We've got this function f, which is our transformer layer.

130:57

Normally we would have just an input and then a residual connection

131:01

coming out, and it gets added

131:06

over here.

131:08

Now, the variation of this is going to be we've got two inputs, x and y.

131:20

x goes through the function, gets added to y,

131:28

and then this becomes the new x, output x.

131:34

Then this x becomes the output y.

131:40

Really what this is doing, if

131:44

you think of two layers back, is the thing you mentioned before.

131:48

It's doing the residual connection from two layers back.

131:52

This y came from the previous layer and was the residual connection there.

131:56

Because of this construction, the whole thing is invertible.

132:00

Why do I care?

132:00

What does invertible matter for?

132:02

The big thing that it can be interesting for is training.

132:05

If I think of a forward pass of training… Let's say I have four layers and I run

132:10

them in zero, one, two, three order.

132:13

I have to write all of the activations to HBM.

132:19

I get an HBM footprint here that is kind of linear

132:24

in the number of layers.

132:30

This can actually be the largest memory footprint during training.

132:35

This is normal training, and then I run the backwards pass and read it in reverse.

132:39

The

132:41

forward pass goes forward, and the backward pass goes backwards.

132:43

I have to read them back out.

132:46

The idea of this RevNets paper is that because it's invertible, I

132:51

don't need to store this at all.

132:52

I can completely rematerialize it.

132:56

I run my forwards pass, and then when I'm running my backwards pass,

132:59

I'm simultaneously in lockstep undoing all of the forwards pass

133:03

steps that I did in order to have the activations that I need here.

133:08

This ends up being memory saving, which is a nice idea.
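
A minimal sketch of that trick (my own illustration with stand-in layer functions, not the RevNets code): the forward pass keeps only the final (x, y) pair, and the backward sweep reconstructs every layer's inputs in lockstep instead of reading stored activations.

```python
import numpy as np

# Stand-in layer functions; each reversible block computes x_new = y + f(x), y_new = x.
fs = [lambda v, s=s: np.tanh(v + s) for s in range(4)]

def forward(x, y):
    for f in fs:
        x, y = y + f(x), x          # nothing is written out to HBM
    return x, y

def rematerialize(x, y):
    """Walk the layers in reverse, recovering each layer's inputs from its outputs."""
    inputs = []
    for f in reversed(fs):
        x, y = y, x - f(y)          # previous x was y; previous y is x - f(previous x)
        inputs.append((x, y))
    return inputs[::-1]             # inputs to layers 0, 1, 2, 3 in order

x0, y0 = np.ones(3), np.zeros(3)
x_out, y_out = forward(x0, y0)
assert np.allclose(rematerialize(x_out, y_out)[0][0], x0)
```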

133:11

Interesting.

133:12

In some sense you're spending more compute to save memory.

133:15

That's right.

133:16

Interesting.

133:17

It's the opposite of what you're doing with the KV cache.

133:20

With the KV cache, you're spending more memory to save compute.

133:23

Yeah.

133:25

Spending more memory to save compute is generally profitable

133:27

given where hardware is.

133:29

Interesting.

133:30

That was super fun.

133:32

Reiner, thank you so much for doing it.

133:33

I feel like it really vindicated the vision behind

133:36

the studio and the blackboard.

133:38

Yeah.

133:38

Cool, thanks so much for doing it.

133:39

Thanks.
