Server-side batching: Scaling inference throughput in machine learning

Model serving meets micro batching

It is often the case in production machine learning that we need to serve predictions in a responsive manner, but not necessarily with realtime latency.

To illustrate, think of a computer vision task involving video data. Companies like Wildlife Protection Solutions use object detection to analyze a network of live-streaming cameras placed throughout wildlife preserves, detecting poacher activity and alerting authorities.

In this situation, inference clearly needs to run in response to a request (in this case, the video being sent to the inference API). If the video is simply stored offline and then processed in a once-a-day batch prediction, the system may only recognize poachers a full 24 hours after they’ve been recorded.

At the same time, we’re talking about large amounts of data being transferred and processed, which incurs a serious cost. Additionally, the inference doesn’t necessarily need to be realtime. A few minutes is unlikely to be a deciding factor when we’re talking about poachers who have to spend hours trekking through a wildlife preserve.

In other words, we want responsive inference, but we are willing to trade some latency for higher throughput.

In Cortex 0.25, we’ve developed a way to achieve this in a new feature we call server-side batching.

But what exactly is server-side batching? Is it just normal realtime inference, but slower? To answer this, it’s helpful to set some context around realtime inference in Cortex.

How we serve models with realtime latency

In a realtime prediction API, the kind you might deploy with Cortex, you can think of your prediction service like any other microservice. Ignoring all the cloud infrastructure involved, the service itself is just some request handling code (the predictor) sat on top of a web server (FastAPI, in Cortex’s case).

To maintain low latency, platforms like Cortex autoscale replicas of prediction APIs as traffic increases. In Cortex’s case, we actually implement request-based autoscaling, in which you define a concurrency limit for a particular API, which Cortex uses to autoscale according to request queue length.
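
As a rough illustration of request-based autoscaling, the sketch below derives a replica count from the number of in-flight requests and a per-replica concurrency target. The function name, its parameters, and the replica bounds are hypothetical. This is not Cortex’s actual autoscaling algorithm, which takes more inputs into account than shown here.

```python
import math


def desired_replicas(in_flight_requests, target_concurrency, min_replicas=1, max_replicas=10):
    """Return a replica count so each replica handles about target_concurrency requests.

    Hypothetical helper for illustration only; not Cortex's implementation.
    """
    if target_concurrency <= 0:
        raise ValueError("target_concurrency must be positive")
    wanted = math.ceil(in_flight_requests / target_concurrency)
    # Clamp the result to the allowed replica range.
    return max(min_replicas, min(max_replicas, wanted))


print(desired_replicas(in_flight_requests=35, target_concurrency=8))  # 5
```

As the queue grows, the replica count grows with it, which is what keeps latency low in this paradigm.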

Latency, in this paradigm, can be thought of as a function of the number of available replicas and compute resources. The more that are available, the less time requests wait in the queue, and the faster inference is performed.

Reading that, you might think “Okay, but why do we need a new feature? If we don’t care about latency, can’t we just let the request queue fill up?” The answer is that we can (and we do, kind of), but once we assume that latency is not a primary concern, we give ourselves a lot of extra room to improve throughput with features like server-side batching.

Why we need server-side batching for production machine learning

Let’s return to the video streaming example. This is a high-level overview of Wildlife Protection Solutions’ monitoring system:

We won’t get into the specifics of catching poachers, but we can use this as a general schematic for how a responsive video monitoring solution might look and to illustrate why a “just let the queue fill up” approach is suboptimal.

Imagine our system regularly captures long videos. How do we process them? Does the camera wait until it’s done capturing, and then upload the entire video in one request? Probably not. That’s fragile and inefficient. Instead, our cameras probably upload the video in small chunks as it’s captured.

That second option is essentially a batching scheme. We’re breaking our data into smaller pieces and placing them in a queue, which is responsible for distributing those pieces across our compute resources. It just isn’t batching in a particularly efficient way.

This is where server-side batching comes in. By enabling server-side batching in a realtime API, we can aggregate those requests into a single batch inference, but on demand.

This means that we get the throughput benefits of batch inference—processing multiple inputs in a single inference—but we also get the on-demand interface of a realtime API.
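
To see why batching raises throughput, consider a toy cost model in which every inference call pays a fixed overhead plus a per-input cost. The function and all of the numbers below are made up for illustration; they are not Cortex benchmarks.

```python
def total_time_ms(num_inputs, batch_size, overhead_ms=20.0, per_input_ms=5.0):
    """Time to process num_inputs when each call handles up to batch_size inputs.

    Assumes a fixed per-call overhead and a constant per-input cost (hypothetical values).
    """
    num_calls = -(-num_inputs // batch_size)  # ceiling division
    return num_calls * overhead_ms + num_inputs * per_input_ms


unbatched = total_time_ms(num_inputs=64, batch_size=1)
batched = total_time_ms(num_inputs=64, batch_size=16)
print(unbatched, batched)  # 1600.0 400.0
```

Under these assumptions the per-call overhead is paid 4 times instead of 64, which is where the throughput gain comes from.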

How to use server-side batching in Cortex

Setting up server-side batching is relatively straightforward. It is just an addition to regular realtime Cortex APIs. For example, to define an API with server-side batching, you simply add a “server_side_batching” field to your API configuration.

Here it is in YAML:

And of course, the fields are identical if you are using the Python client to define APIs.
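
As a sketch, the same configuration expressed as a Python dictionary might look like the following. Only the server_side_batching, max_batch_size, and batch_interval names come from this article. The API name, the predictor path, the placement of the field under the predictor, and the values themselves are assumptions, so check the Cortex docs for the exact schema.

```python
# Hypothetical API spec; other required predictor fields are omitted for brevity.
api_spec = {
    "name": "video-detector",  # assumed API name
    "kind": "RealtimeAPI",  # assumed kind for a realtime API
    "predictor": {
        "path": "predictor.py",  # assumed location of the request handling code
        "server_side_batching": {
            "max_batch_size": 32,  # assumed value
            "batch_interval": "0.5s",  # assumed value and duration format
        },
    },
}

print(api_spec["predictor"]["server_side_batching"])
```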

The "max_batch_size" field sets how many requests your API aggregates into a single batch, that is, the queue size at which a batch inference is triggered. The "batch_interval" field is a guard against latency: it sets the maximum time to wait for "max_batch_size" to be reached before processing whatever is already in the queue.
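
The interaction between the two fields can be sketched with a small in-process batcher. This is a simplified illustration of the general micro-batching technique, not Cortex’s implementation. The collect_batch function, its use of queue.Queue, and the choice to start the interval timer when the first request arrives are all assumptions made for the example.

```python
import queue
import time


def collect_batch(requests, max_batch_size, batch_interval):
    """Pull up to max_batch_size items, waiting at most batch_interval seconds.

    Blocks until the first item arrives, then starts the interval timer.
    """
    batch = [requests.get()]  # wait for the first request
    deadline = time.monotonic() + batch_interval
    while len(batch) < max_batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break  # interval elapsed
        try:
            batch.append(requests.get(timeout=remaining))
        except queue.Empty:
            break  # interval elapsed before the batch filled up
    return batch


pending = queue.Queue()
for frame_id in range(5):
    pending.put(frame_id)

full_batch = collect_batch(pending, max_batch_size=4, batch_interval=0.05)
partial_batch = collect_batch(pending, max_batch_size=4, batch_interval=0.05)
print(full_batch, partial_batch)  # [0, 1, 2, 3] [4]
```

The first call returns as soon as four requests are available. The second call finds only one request, waits out the interval, and then returns a partial batch rather than blocking indefinitely.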

In addition, you’ll need to make some slight changes to your request handling code. Namely, the arguments you pass to your "predict()" method will now be lists, as we’re processing multiple values:
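
A batched predictor might look roughly like the sketch below. The class name, the constructor, the payload shape, and the stand-in model are hypothetical and are kept self-contained so the example runs on its own; consult the Cortex docs for the real predictor interface. The point it illustrates is that predict() receives a list of payloads and, presumably, returns one result per payload in the same order.

```python
class BatchedPredictor:
    """Hypothetical predictor whose predict() handles a batch of payloads."""

    def __init__(self, model):
        # Any callable that maps a list of inputs to a list of outputs.
        self.model = model

    def predict(self, payload):
        # With server-side batching enabled, payload is a list of request
        # bodies rather than a single request body.
        inputs = [item["frame"] for item in payload]
        outputs = self.model(inputs)  # one inference call for the whole batch
        # Return one result per request, in the same order as the inputs.
        return [{"label": label} for label in outputs]


def fake_model(frames):
    """Stand-in for a real model: labels a frame by a single numeric score."""
    return ["person" if frame > 0.5 else "empty" for frame in frames]


predictor = BatchedPredictor(fake_model)
results = predictor.predict([{"frame": 0.9}, {"frame": 0.1}])
print(results)  # [{'label': 'person'}, {'label': 'empty'}]
```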

Once that is set up, Cortex will automatically begin batching requests. For more information on optimizing your server-side batching behavior, see the docs.
