Efficient Reinforcement Learning – Rhythm Garg & Linden Li, Applied Compute

Transcript

0:13

[music]

0:20

Hey everyone, it's great to meet you

0:22

all. Really great to be here today. My

0:24

name is Rhythm. This is my co-founder

0:25

Linden. Our third co-founder, Yash,

0:27

couldn't make it today, but we're all

0:29

very excited to be here. The three of us

0:31

were previously researchers at OpenAI,

0:33

and now we're bringing frontier AI

0:35

inside enterprises at Applied Compute.

0:38

Today, we're going to be talking about

0:39

efficient reinforcement learning.

0:43

As some context on Applied Compute, we

0:45

help enterprises build their own

0:46

intelligence to power real work in their

0:48

company. We think a lot about how do we

0:51

push AI beyond productivity into real

0:54

automations that deliver ROI that's

0:56

quantitative for the company. Once we

0:58

build a system that's specialized to the

1:01

way that a company operates for a

1:03

particular use case, we deploy it with a

1:05

data flywheel so that it gets better

1:07

over time the more and more that you use

1:08

it. Picture an in-house expert at a

1:11

company that's always at the forefront

1:13

of their field.

1:16

RL, mechanically, is the tool that

1:19

we use in order to bring these

1:21

out-of-distribution datasets in distribution

1:23

for the models today. Yash, Linden, and I

1:26

all worked on the RL effort at OpenAI in

1:28

its early days and we saw firsthand the

1:31

power of RL in going and maximizing

1:33

these public benchmarks. Now we're

1:35

taking that a step further and helping

1:38

enterprises go solve the problems they

1:39

care the most about sort of their

1:41

private benchmarks.

1:44

So here's a very high-level overview of how

1:47

high-compute RL helps LLMs acquire these

1:49

reasoning and intelligence capabilities.

1:53

Let's say that you have a data set of

1:54

math problems and we pick four of them

1:57

for an RL training step.

2:01

Then we'll take an open source model,

2:02

say one of the GPT-OSS models or one of

2:04

the Llama models, and we have the model

2:06

attempt each of those four problems 100

2:08

times. So each of these 100 attempts is

2:12

the model thinking through how it would

2:14

get to the final answer and then ending

2:15

off with the final answer itself.

2:17

And these are many many reasoning tokens

2:19

in its thinking trajectory.

2:21

We can grade all of these answers

2:24

and when the model is correct, we can

2:26

bias the model's weights to reinforce

2:29

its thinking trace in that attempt. When

2:31

it's incorrect, we can discourage the

2:32

model from having that kind of behavior

2:34

again. So in this fashion, as we

2:36

do more and more training steps with

2:38

batches of four problems, 100 attempts

2:40

each, the model learns to reason and

2:42

solve math problems, and it becomes

2:43

really, really good at math. Of course,

2:45

at Applied Compute, we're not really

2:46

helping enterprises solve math problems,

2:48

but this is kind of the mechanism by

2:50

which we're able to teach the models to

2:51

get really, really good at tasks that

2:53

they care about.
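A minimal sketch of the grading step just described, under stated assumptions: the talk only says that correct attempts are reinforced and incorrect ones discouraged, so the binary grader, the mean-reward baseline, and the `sample_answer` callable below are illustrative stand-ins, not the speakers' implementation.

```python
from statistics import mean


def grade(answer: str, reference: str) -> float:
    """Binary reward: 1.0 for a correct final answer, 0.0 otherwise."""
    return 1.0 if answer.strip() == reference.strip() else 0.0


def advantages(rewards: list[float]) -> list[float]:
    """Center rewards on the per-problem mean (an assumed baseline).

    Positive values mark attempts to reinforce, negative values mark
    attempts to discourage.
    """
    baseline = mean(rewards)
    return [r - baseline for r in rewards]


def training_signals(problems, sample_answer, attempts_per_problem=100):
    """Collect one signal per attempt for a batch of problems.

    problems: list of (question, reference_answer) pairs.
    sample_answer: hypothetical callable standing in for the model.
    """
    signals = {}
    for question, reference in problems:
        rewards = [
            grade(sample_answer(question), reference)
            for _ in range(attempts_per_problem)
        ]
        signals[question] = advantages(rewards)
    return signals
```

With four problems and 100 attempts each, one training step would consume 400 such signals.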

2:56

So, as we mentioned, the type of RL work

2:58

that we do at Applied Compute is

3:00

actually quite different from the labs.

3:01

So these are some real-life photos

3:02

from the labs and a photo we took

3:05

at the Applied Compute office the

3:06

other day. The labs

3:09

do these big training runs over several

3:11

weeks. We do more specialized runs.

3:15

And there are a couple of

3:16

aspects of RL training that are

3:18

particularly important to us.

3:21

We need our runs to be fast so that we

3:23

can train a model and deliver it to a

3:24

customer very quickly on the order of

3:26

days.

3:27

They have to be cheap so that our unit

3:29

costs work and we're able to scale the

3:31

business sustainably.

3:33

And importantly, and this is a point

3:35

that I think is easy

3:37

to miss, we need our estimates for how

3:39

long these training jobs will be to be

3:41

very low variance because we don't want

3:43

to just be generally fast. We want to be

3:44

reliably fast when we work with

3:46

customers.

3:48

And so the research problem for us that

3:50

is very business critical is can we

3:53

build an RL stack that is so efficient

3:56

that, in conjunction with our agent-building

3:58

platform, we are really able to

4:00

scale up this use-case-specific training

4:03

motion.

4:06

So let's start with an inefficient form

4:08

of RL which is synchronous RL. In

4:11

synchronous RL, sampling and training

4:13

happen in lockstep. So there are some

4:15

simplifications here, but let's say

4:16

that we want to train on batches of

4:18

eight samples. That means we're going to

4:20

wait for all eight samples to finish and

4:23

basically run to completion before we

4:25

start training. And then we're going to

4:26

repeat this process again. As a result,

4:29

we have a lot of idle GPUs that are

4:31

waiting on that third straggler sample

4:33

to complete.

4:36

So in other words, in synchronous RL,

4:38

our step times are dictated by whatever

4:40

sample takes the longest time in order

4:41

to complete.

4:43

To illustrate why this is bad, we took

4:45

40 arithmetic problems, requested 32

4:48

samples for each of them with Qwen

4:50

30B, and we measured how long it would

4:52

take for these samples to

4:54

complete.

4:56

It turns out that 99% of the samples

4:58

completed in about 40 seconds. It took

5:00

another 80 seconds to get that last

5:01

percent of samples to complete. It

5:03

really has a long tail.
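To make the straggler effect concrete, here is a small sketch. It assumes each sample occupies one generation slot for its whole duration and that finished slots sit idle until the slowest sample ends; the function names and the example durations are hypothetical, not measurements from the talk.

```python
def sync_sampling_time(durations):
    """In synchronous RL the sampling phase lasts as long as the slowest sample."""
    return max(durations)


def idle_fraction(durations):
    """Fraction of slot-time spent waiting on stragglers.

    Each sample is assumed to hold one slot; a slot that finishes early
    stays idle until the longest sample completes.
    """
    step = sync_sampling_time(durations)
    busy = sum(durations)
    return 1.0 - busy / (len(durations) * step)
```

For example, `idle_fraction([10, 20, 30, 120])` evaluates to 0.625: one long sample leaves most of the slot-time unused.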

5:06

So, as you'd expect, if you look at the

5:08

throughput chart, the GPUs are doing a

5:10

lot of work at the beginning when all of

5:11

the sampling requests are launched, but

5:13

by the end, they're very very

5:14

underutilized because they're waiting on

5:16

those last samples to complete. The

5:18

technical term we use at Applied Compute

5:19

is the GPUs are slacking. So

5:21

synchronous RL is not an efficient way

5:23

to use these GPUs.

5:27

In order to solve this problem, we need

5:28

to break the condition that sampling and

5:31

training need to happen in lockstep. In

5:34

other words, we need to allow training

5:35

while we're sampling. This is called

5:37

asynchronous RL. And there are many

5:39

approaches to doing asynchronous RL. One

5:41

that we particularly like is pipeline RL

5:43

from Piché et al.

5:46

We're going to make some

5:47

simplifications here, but in

5:49

asynchronous pipeline RL, we dedicate

5:51

some GPUs to sampling and some GPUs to

5:53

training. The sampling workers never

5:55

stop. They're constantly doing inference

5:57

with high batch size. As samples

5:59

complete, they get added to a queue for

6:01

training and the training workers pull a

6:03

batch from the queue to train on. After

6:06

a batch has been trained on, the

6:09

training workers propagate the new model

6:11

weights to all of the sampling workers

6:13

for what's called an in-flight weight

6:14

update. And this is really what

6:16

differentiates pipeline RL. The sampling

6:18

workers might be in the middle of a

6:19

sample, but their weights will still get

6:21

updated if a training step just

6:24

completed.
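A toy sketch of what an in-flight weight update does to a single sample: every token is tagged with the policy version that produced it, and the version changes partway through generation. The function and its parameters are hypothetical and only illustrate the bookkeeping, not the actual pipeline RL implementation.

```python
def generate_with_inflight_updates(n_tokens, start_version, update_before_token):
    """Return the policy version that produced each token of one sample.

    update_before_token: set of token indices; just before generating each
    of these tokens, a training step finishes and new weights are swapped in.
    """
    version = start_version
    versions = []
    for i in range(n_tokens):
        if i in update_before_token:
            version += 1  # in-flight weight update, the sample keeps going
        versions.append(version)
    return versions
```

Calling `generate_with_inflight_updates(7, 0, {3, 5})` returns `[0, 0, 0, 1, 1, 2, 2]`, a sample written by three policy versions.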

6:27

As a result, we end up with samples that

6:29

had multiple versions of the policy that

6:31

contributed to the sample in order to

6:32

generate it. In other words, there are

6:34

stale tokens in some of

6:36

these samples. Let's take a look at one

6:39

sample to make this a bit more clear.

6:42

As you can see, there are three versions

6:43

of the policy at time steps t, t+1, and

6:46

t+2 that were used to generate this

6:48

sample since there were two completed

6:50

train steps and in turn two in-flight

6:52

weight updates while this sample was

6:54

being generated.

6:56

So when this sample gets trained on in

6:58

the t+3 to t+4 training batch, we will

7:00

have some tokens that came from a policy

7:03

three steps behind, some that came from

7:04

policy two steps behind, and those last

7:06

two tokens that came from a policy that

7:08

was one step behind.
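The per-token staleness in this example can be computed as below. This is a sketch with hypothetical names: versions are counted relative to t (so t is 0, t+1 is 1, and so on), and the policy being trained on this sample is at version t+3.

```python
def token_staleness(token_versions, training_version):
    """Staleness of each token: how many steps behind its sampling policy was."""
    return [training_version - v for v in token_versions]


def within_tolerance(token_versions, training_version, max_staleness):
    """True if no token in the sample exceeds the staleness budget."""
    return max(token_staleness(token_versions, training_version)) <= max_staleness
```

For tokens produced by versions `[0, 0, 0, 1, 1, 2, 2]` and trained on at version 3, the staleness is `[3, 3, 3, 2, 2, 1, 1]`, so the sample passes a budget of 3 but not a budget of 2.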

7:11

Now, let's say that we only tolerate

7:13

staleness up to two. That means we're

7:16

not going to allow the in-flight weight

7:18

update after the t+1 to t+2 training

7:20

batch completes. And that means the

7:22

training workers are just going to be

7:23

idle waiting for this sample to complete

7:25

before they can propagate that in-flight

7:27

weight update and start training on the

7:29

next batch. Because if they were to do

7:30

the in-flight weight update, that would

7:31

cause this sample to have staleness 3 as

7:33

we just saw.

7:35

And if we only tolerate staleness one,

7:37

the training workers are going to be

7:39

idle for even longer,

7:42

which is bad. So as you increase how

7:44

much staleness you tolerate, you have fewer

7:46

idle GPUs in general. But as we all

7:48

know, there's no free lunch. This is

7:50

the standard policy gradient with an

7:52

importance ratio to adjust for the fact

7:54

that we're sampling from a policy at

7:56

time step t and training with the policy

7:59

at time step t+k, given that

8:00

there's k staleness. The importance

8:03

ratio is what makes this policy gradient

8:05

unbiased. But the variance of that ratio

8:08

increases as you increase staleness. And

8:10

so this is kind of the big issue here

8:12

because now, with a higher-variance

8:14

importance ratio, learning can become

8:16

unstable and diverge.
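The slide itself is not part of the transcript, so the following is a reconstruction of the standard importance-weighted policy gradient that the description matches, not a copy of the speakers' formula. Here x is a prompt, y a sampled response, and A(x, y) an advantage term; the symbol names are assumptions.

```latex
\nabla_\theta J(\theta_{t+k})
  = \mathbb{E}_{y \sim \pi_{\theta_t}(\cdot \mid x)}
    \left[
      \frac{\pi_{\theta_{t+k}}(y \mid x)}{\pi_{\theta_t}(y \mid x)}
      \, A(x, y) \,
      \nabla_\theta \log \pi_{\theta_{t+k}}(y \mid x)
    \right]
```

When k = 0 the ratio is exactly 1 and this reduces to the ordinary on-policy gradient. As k grows, the two policies drift apart and the ratio can swing far from 1, which is the variance problem described above.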

8:19

The concrete trade-off is we want a lot

8:20

of staleness for fast RL runs, but a lot

8:23

of staleness makes learning unstable,

8:25

which then requires innovating on the

8:27

algorithm and the science. And this is

8:29

one of the primary research problems

8:30

that we focus on here at Applied

8:32

Compute. And as I was talking about

8:33

earlier, it directly flows back into our

8:35

core business.

8:38

For the purpose of this talk, we're

8:39

going to focus on a simpler subproblem.

8:41

Let's assume that we have good science

8:43

and algorithmic innovations that allow

8:45

us to tolerate staleness up to some

8:47

fixed threshold and we have some fixed

8:49

compute budget as usually exists in the

8:51

world. What is the highest-throughput way for us to

8:54

do RL in this setting?

8:57

Cool. Thanks Rhythm.

9:00

So we posed this as a modeling problem

9:02

of our end-to-end system, which

9:04

admittedly is a little bit complicated

9:05

at first, but we did find that we can get

9:07

surprisingly far with some first-principles

9:09

systems modeling. As with

9:11

any modeling problem let's figure out

9:13

the cast of characters that describe the

9:15

system and then we'll think about how

9:17

they all fit together to model it. So

9:19

the first cast member is some proxy of

9:21

compute budget, which in this case we

9:23

have as the number of GPUs. In the

9:25

synchronous setting, like Rhythm just

9:27

explained all the GPUs will either be

9:29

used for training or sampling since they

9:31

happen one after the other. But in the

9:33

asynchronous setting it's a little bit

9:35

trickier because we can choose to allocate

9:37

that pool of GPU compute as much as

9:40

we want for training or as much as we

9:41

want for sampling, and that leads to

9:43

some design decisions.

9:45

The next is the training batch size

9:47

which is some proxy of the workload that

9:49

we have on the overall system

9:52

and this is kind of an ML decision but

9:54

in short what we have is a batch of

9:56

problems which is a subset of our data

9:58

set. Let's say we have N math problems

10:00

that we want to train on, and for each of

10:02

these problems we're going to sample n

10:04

responses in parallel. So if the problems

10:06

are really difficult, we might sample

10:08

more to encourage some diversity in the

10:10

samples to encourage the model to learn

10:12

some potentially divergent

10:13

strategies.

10:16

The next thing we need is some proxy of

10:18

sampling throughput. And to get some

10:19

intuition of what we should choose here

10:21

as a modeling decision, let's look at

10:23

how modern inference engines serve

10:25

requests. So in GPU memory, we have the

10:27

model weights, the activations, and some

10:29

runtime state called the KV cache in

10:31

memory. And given this trained model,

10:33

we're going to run the forward pass

10:35

several times where each forward pass

10:37

samples the next token and then we'll

10:39

write to the KV cache. And so what this

10:43

model shows is that a principled estimate

10:45

that we should do is we should find some

10:47

way to measure the latency per GPU of

10:50

the forward pass. And this ends up being

10:52

a pretty good choice in practice because

10:54

from the systems angle, the inference

10:56

throughput that we choose is largely

10:58

determined by the batch size that we

11:00

perform sampling with. So what I've

11:02

shown here in the red square is a batch

11:04

of tokens that are all forwarded at the

11:06

same time. And this sampling forward

11:08

pass needs to be as large as possible to

11:10

efficiently utilize the GPUs subject to

11:13

the runtime constraint that we don't

11:15

actually run out of memory in the KV

11:17

cache.

11:19

So what we can then do is we can fit a

11:21

latency curve as a function of batch

11:23

size and that latency curve will look

11:25

something like this. You'll have some

11:27

regime where it's memory-bound, and as

11:28

the batch size increases it becomes compute-bound, and

11:30

there's some functional form below. And

11:33

to explain the details of why we made

11:34

this choice, what we have here is an

11:36

equation that's based on the roofline

11:38

model from systems. At lower batch

11:40

sizes, which I've highlighted in yellow

11:42

here, we don't have that much work to do

11:45

because there isn't that much compute to

11:46

do on the processor and there's so many

11:48

parameters you need to load in at the

11:50

same time. And so, as a result, when you

11:52

add incremental work, it doesn't really

11:55

add that much latency to the overall

11:56

system since the processor is so fast at

11:59

doing math that we're just waiting on

12:01

memory to stream parameters in from the

12:03

memory to the processor. But as

12:06

the batch sizes begin to get larger, we

12:08

then get bottlenecked by the processor.

12:10

And the more we add to our batch, the

12:12

longer the forward pass takes. And just

12:14

for good measure, we have this sigmoid

12:16

here that just sort of modulates the

12:17

smooth transition at this hinge point

12:19

here to show that there's a subtle

12:21

transition from a memory bound

12:22

computation to one that's more

12:24

compute-bound and bottlenecked by the

12:26

processor.
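The exact functional form on the slide is not in the transcript. The sketch below is one hypothetical way to write a roofline-style latency curve with a sigmoid smoothing the hinge: a flat memory-bound floor, a compute-bound cost that grows linearly with batch size, and a gate that blends between them. The constants are made-up placeholders, not fitted values.

```python
import math


def forward_latency(batch_size, t_mem=0.020, t_per_seq=0.0001, width=16.0):
    """Seconds per forward pass as a function of generation batch size.

    t_mem: time to stream the weights from memory (memory-bound floor).
    t_per_seq: compute time added per sequence in the batch.
    width: how many sequences wide the smooth transition is.
    All three values are hypothetical placeholders.
    """
    hinge = t_mem / t_per_seq  # batch size where both costs are equal
    gate = 1.0 / (1.0 + math.exp(-(batch_size - hinge) / width))
    memory_bound = t_mem
    compute_bound = t_per_seq * batch_size
    return (1.0 - gate) * memory_bound + gate * compute_bound


def sampling_throughput(batch_size):
    """Tokens per second per GPU: one token per sequence per forward pass."""
    return batch_size / forward_latency(batch_size)
```

In the memory-bound regime, adding sequences barely changes latency, so throughput rises almost linearly with batch size; past the hinge, latency grows with the batch and throughput flattens.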

12:28

The final cast member is some proxy of

12:30

training throughput and we chose to

12:32

measure this on a per-GPU basis. So in

12:35

this case the model takes in the

12:37

training batch size. So the parameter we

12:38

saw earlier and we typically do this by

12:41

fitting a proxy of our empirical

12:43

workloads. The units here are how many

12:45

tokens per second

12:47

each training GPU processes. So it needs

12:49

to do the forward pass, the backward pass, and some

12:51

optimizer steps.

12:53

So given these four cast members, we can

12:55

then begin modeling the system. And the

12:57

first idea we had, although Rhythm

12:59

suggested that this might not be a

13:01

great idea, was to use

13:03

a synchronous setup. And this might be a

13:05

good idea from first principles because

13:06

we definitely meet the staleness

13:08

constraint because we don't train on

13:09

stale data and we always use the entire

13:12

GPU fleet for either training or

13:14

sampling, making sure that we're making

13:16

efficient use of the hardware. Let's

13:18

think about how to actually model this.

13:20

There are two things we need to know. We

13:21

need to know the batch size at which

13:23

generation runs. And we also need to

13:25

know the response length distribution to

13:27

figure out how our training workload's

13:29

going to work and also how long the

13:30

sampling's going to take. And so what

13:33

I'm showing here in this simulation is a

13:35

couple of engines. Each square is a

13:37

request being processed and they get

13:38

darker and darker as we make progress

13:40

throughout the batch. And as they finish

13:42

samples, they write to the queue. And on

13:45

the right hand side is a time series

13:46

metric, maybe something that you'd see

13:47

in Grafana if you're monitoring

13:49

production metrics. And what you can see

13:50

is that the batch size begins very high,

13:52

but it slowly goes down over time as it

13:55

eventually goes to zero and all the

13:56

samples complete. And we can finally run

13:59

an optimization step. After the step

14:01

completes, we run this in a loop and we

14:02

move on to the next step. And so as a

14:05

result, we can have the following

14:06

sampling procedure. We do maximum tokens

14:09

inference forward passes where maximum

14:12

tokens is the total number of forward

14:14

passes we do for the longest request. We

14:16

use the fitted latency estimator to

14:18

figure out how long that forward pass

14:20

will take. And then the response length

14:22

distribution will tell us how many

14:23

responses to drop. And so what we're

14:25

showing in this video here is this

14:27

entire thing of the response length

14:28

distribution that we feed into the

14:30

latency estimator. At training time, we

14:32

can compute the total number of tokens

14:34

that we just sampled in the batch and

14:36

divide by the total training

14:37

throughput, which is just the number

14:39

of GPUs multiplied by the per GPU

14:41

training throughput. And so what we have

14:43

here is a simulation of what this

14:45

latency curve looks like. So we have the

14:47

CDF of the response length distribution

14:49

that tells us how many responses we

14:50

should drop on the left and the latency

14:52

curve on the right. And this roughly

14:54

kind of tracks because as we add more

14:55

GPUs, we'd expect the latency per step

14:57

to go down.
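The synchronous model described here can be sketched as follows. It assumes requests are spread evenly across GPUs and that `latency_fn` is a fitted per-GPU forward-pass latency estimator supplied by the caller; both the function name and that even-split assumption are illustrative, not taken from the talk.

```python
import math


def sync_step_time(response_lengths, n_gpus, latency_fn, train_tok_per_gpu_s):
    """Estimated seconds for one synchronous RL step (sampling, then training).

    response_lengths: token count of every response in the batch.
    latency_fn: maps a per-GPU batch size to seconds per forward pass.
    train_tok_per_gpu_s: training throughput of one GPU in tokens per second.
    """
    max_tokens = max(response_lengths)
    sampling_s = 0.0
    for i in range(1, max_tokens + 1):
        # Responses shorter than i tokens have already dropped out of the batch.
        active = sum(1 for n in response_lengths if n >= i)
        per_gpu_batch = math.ceil(active / n_gpus)
        sampling_s += latency_fn(per_gpu_batch)
    # Every GPU switches to training once sampling is done.
    training_s = sum(response_lengths) / (n_gpus * train_tok_per_gpu_s)
    return sampling_s + training_s
```

The loop runs one forward pass per token of the longest response, which is why a single long straggler dominates the sampling time even when the batch has nearly emptied.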

15:00

The next idea, given that the

15:01

synchronous setup might not be the most

15:03

principled choice, as Rhythm showed, is

15:05

an asynchronous setup. But it's not just

15:07

as easy as just sort of provisioning the

15:09

compute between training and inference

15:11

because if we don't do this carefully,

15:13

we might actually run into the idle GPU

15:15

problem again. And to show this, let's

15:17

illustrate two extremes of what this

15:18

allocation problem looks like. Let's

15:21

first look at one end of the

15:22

spectrum where we provision way too many

15:24

training GPUs and not that many

15:26

samplers. In this case, we're consuming

15:29

from a queue much faster than we're

15:31

actually producing into it because the

15:33

sampling workers are producing work

15:35

significantly

15:37

slower than we can actually consume

15:39

them. When the red square grays out, it

15:41

shows that they're idle. And what this

15:42

diagram should hopefully illustrate is

15:44

that for a lot of the time we're

15:46

actually not using them, and that has the

15:47

same problem of low GPU utilization as in

15:50

the synchronous case as shown earlier.

15:53

At the other extreme, we can

15:55

provision way too many sampling GPUs in

15:58

which case our production rate is way

16:00

faster than the rate that we actually

16:02

consume them in. So here we've doubled

16:04

the number of overall sampling GPUs and

16:06

halved the number of training GPUs. As you

16:08

can see, they produce samples at a much

16:10

more rapid rate. But this index

16:12

here in each yellow square, which is the

16:14

staleness count of each sample, goes up.

16:17

And as time moves on, we get more and

16:19

more stale. And so the samples get

16:21

more and more

16:23

transparent as a result. And we learn

16:25

less from each individual sample. So

16:28

let's think about how we can actually

16:29

model this workload then to determine

16:31

an optimal async workload. In this case,

16:34

the picture looks a little bit different

16:36

because in steady state, the batch size

16:37

is relatively consistent compared to the

16:39

synchronous setup where it kind of goes

16:41

down over time. So on the right hand

16:43

side here, we have the same time series

16:44

metrics. But in this case, it's a little

16:46

bit different because the yellow squares

16:48

are always full because every time we

16:50

complete a sample, a new sample goes in

16:52

and we can continue writing to the

16:54

queue. And so that batch size with a

16:56

little bit of wiggles just for good

16:57

measure, is pretty consistent

16:59

over the course of a run. Now obviously

17:02

the caveat here is that this batch size

17:04

will certainly go down as

17:06

response lengths go up because we run

17:07

out of KV cache, but that's kind

17:09

of a separate story, and actually our

17:11

model accounts for that because

17:12

we're actually accounting for a

17:13

response length distribution.

17:17

We can then begin to figure out the

17:18

optimal layout, and there are two kinds of

17:20

constraints that we have to satisfy now

17:22

that we know that the generation batch

17:23

size is roughly consistent throughout

17:25

the course of a run. The first invariant

17:27

that we need to have is that the

17:28

production and consumption rates are roughly

17:30

equal. So on the left hand side of this

17:32

equality we have the training throughput

17:34

which is the number of training GPUs

17:35

multiplied by the per-GPU throughput,

17:38

and then also we have the number of

17:39

sampling GPUs multiplied by the sampling

17:42

throughput, which is just the batch size

17:43

divided by the latency to actually do

17:45

a forward pass on that batch size. And

17:48

the next thing is that, given that Rhythm

17:50

indicated that if we have too much

17:52

staleness that can be bad from an ML

17:54

perspective, we want to make sure that

17:56

our max theoretical staleness or

17:58

simulated staleness doesn't exceed what

18:00

our ML can handle. And so here we have

18:03

the max staleness on the left, which is

18:05

equal to on the top how much time the

18:08

longest request took in the batch which

18:10

is just the maximum number of tokens

18:12

multiplied by the

18:14

amount of time each token forward pass

18:15

takes. And on the bottom here we have

18:17

the length of a training step which is

18:19

the training batch size multiplied by

18:21

the mean

18:23

sequence length.
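Written out, the two constraints look roughly like this. The symbols are assumptions introduced here, since the slide is not in the transcript: N_train and N_sample are GPU counts, τ_train is per-GPU training throughput in tokens per second, B_gen is the steady-state generation batch size, ℓ(B_gen) is the fitted forward-pass latency, L_max and L̄ are the maximum and mean response lengths, B_train is the training batch size, and k is the staleness budget. Dividing the training-step tokens by the aggregate training throughput to get a duration is also an assumption; the talk only names the token count.

```latex
% Invariant 1: consumption and production rates balance (tokens per second)
N_{\text{train}} \, \tau_{\text{train}}
  \;\approx\; N_{\text{sample}} \, \frac{B_{\text{gen}}}{\ell(B_{\text{gen}})}

% Invariant 2: the longest request spans at most k training steps
\text{staleness}_{\max}
  = \frac{L_{\max} \, \ell(B_{\text{gen}})}{T_{\text{step}}} \;\le\; k,
\qquad
T_{\text{step}} = \frac{B_{\text{train}} \, \bar{L}}{N_{\text{train}} \, \tau_{\text{train}}}
```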

18:25

So the simulation here then will sweep

18:27

through multiple different values of the

18:29

number of training GPUs. And since we

18:31

have a fixed pool of compute that then

18:33

implies a certain number of GPUs used

18:35

for sampling. And for this number of

18:37

sampling GPUs, we can compute the

18:39

minimum steady-state generation batch

18:41

size to make sure that we don't run out

18:43

of memory, subject to our KV cache

18:45

memory constraints, and also such that we

18:47

have maximum throughput on the

18:49

sampling side. And the final thing is we

18:51

want to prune out all simulations where

18:53

the sampling throughput brings us over

18:55

the maximum tolerable staleness. When we

18:58

look at that simulation, we can run it

18:59

end to end, similarly parameterized by

19:01

the response length. We see that this

19:02

roughly simulates a 60% speedup

19:05

relative to our synchronous baseline,

19:07

assuming that the GPU compute is

19:09

optimally allocated between training and

19:11

sampling.
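A simplified sketch of such a sweep, under several assumptions that go beyond the talk: the generation batch size and its forward-pass latency are fixed inputs rather than recomputed per layout, samplers are throttled to the trainers' pace instead of letting the queue grow, and a training step's duration is its token count divided by the slower of the two aggregate rates. All names are hypothetical.

```python
def sweep_layouts(total_gpus, gen_batch_size, forward_latency_s,
                  train_tok_per_gpu_s, train_batch_size,
                  mean_len, max_len, max_staleness):
    """Return (step_seconds, n_train, n_sample, staleness) for the fastest
    layout within the staleness budget, or None if no layout qualifies."""
    feasible = []
    for n_train in range(1, total_gpus):
        n_sample = total_gpus - n_train
        # Production rate: tokens per second across all sampling GPUs.
        sample_tok_s = n_sample * gen_batch_size / forward_latency_s
        # Consumption rate: tokens per second across all training GPUs.
        train_tok_s = n_train * train_tok_per_gpu_s
        # The slower side sets the pace of each training step.
        step_s = train_batch_size * mean_len / min(sample_tok_s, train_tok_s)
        # Number of training steps that elapse while the longest request runs.
        staleness = max_len * forward_latency_s / step_s
        if staleness <= max_staleness:
            feasible.append((step_s, n_train, n_sample, staleness))
    return min(feasible) if feasible else None
```

Tightening `max_staleness` prunes the fastest layouts first, so the sweep makes the trade between speed and stability explicit before any GPU time is spent.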

19:12

As a result, when we sweep layouts

19:14

within these constraints, this allows us

19:16

to limit staleness, but also make sure

19:18

that we have our runs running at maximal

19:20

throughput without actually doing the

19:22

run itself. And so this gives us insight

19:24

to simulate different workloads before

19:26

actually running them on the GPU because

19:28

these runs can actually be fairly

19:29

expensive. And so this allows us to

19:31

answer scientific questions from first

19:33

principles like what is the optimal

19:35

configuration that we should have of

19:37

our GPU compute if we made response

19:39

lengths very long, because oftentimes

19:41

when models learn via reinforcement

19:43

learning they begin to think for much

19:44

longer and also what empirical

19:46

throughputs we should target during our

19:48

performance optimization. So this has

19:50

been a really useful piece of technology

19:51

for simulation that has informed a lot of the

19:53

systems and research design decisions

19:55

that we make. Cool. Thanks for your time

19:57

and find us afterwards to jam on some

19:59

more RL research engineering together

20:00

later. Thank you. [music]

20:18

>> [music]

Interactive Summary

The presentation discusses efficient reinforcement learning (RL) techniques tailored for enterprise applications by the company Applied Compute. The speakers, Rhythm and Linden, explain the transition from inefficient synchronous RL, where sampling and training are locked, to asynchronous pipeline RL. They describe the engineering challenges of managing stale weights and model stability, ultimately demonstrating how they use systems modeling from first principles to optimize GPU allocation, ensuring both high throughput and controlled staleness to achieve efficient, cost-effective enterprise AI model training.
