Build Asynchronous ML Inference with FastAPI and Celery

When you deploy a Large Language Model (LLM) or a heavy computer vision model, a single inference request can take anywhere from 2 to 30 seconds. In a standard synchronous web architecture, this blocks the worker process, leading to 504 Gateway Timeouts and a complete system freeze under even moderate load. To build a production-ready AI service, you must decouple the request reception from the heavy computation.

By using FastAPI as a high-performance gateway and Celery as a distributed task queue, you can handle thousands of concurrent requests while processing the heavy lifting in the background. This architecture ensures that your API remains responsive, providing users with a task ID they can use to poll for results later. You achieve horizontal scalability by simply adding more Celery workers as your traffic grows.

TL;DR — Decouple long-running ML tasks from the request-response cycle using FastAPI for the API layer, Redis as a message broker, and Celery for background workers. This prevents HTTP timeouts and allows for independent scaling of web and inference resources.

The Core Architecture Concept

💡 Analogy: Imagine a busy restaurant. In a synchronous setup, the waiter stands at your table and waits for the chef to cook your steak before moving to the next customer. The whole restaurant stops. In an asynchronous setup, the waiter takes your order, gives you a buzzer (Task ID), and moves to the next customer. When the chef finishes the meal in the kitchen (Celery Worker), your buzzer goes off.

The primary goal of this architecture is to separate the "IO-bound" task (receiving a request) from the "CPU/GPU-bound" task (running inference). FastAPI excels at handling many concurrent connections because of its asyncio foundation. However, asyncio does not help with CPU-heavy tasks like running a PyTorch model; those will still block the event loop if run directly inside the endpoint.
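The blocking behavior described above is easy to demonstrate with plain asyncio, no FastAPI required. In this minimal sketch, a CPU-bound `time.sleep` stands in for model inference; a second coroutine that is ready to run gets delayed for the entire duration because the event loop is frozen:

```python
import asyncio
import time

async def heartbeat(times):
    # Records when the event loop actually got around to running this coroutine.
    times.append(time.perf_counter())

async def handler_with_blocking_work():
    times = []
    start = time.perf_counter()
    task = asyncio.create_task(heartbeat(times))  # ready to run immediately
    time.sleep(0.5)  # CPU-bound work run inline: freezes the whole event loop
    await task
    return times[0] - start

delay = asyncio.run(handler_with_blocking_work())
print(f"other coroutine was delayed {delay:.2f}s")  # roughly the full 0.5s
```

The heartbeat coroutine cannot run until the blocking call returns, which is exactly what happens to every other request hitting an async endpoint that runs inference inline.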

By introducing Redis as a broker, FastAPI offloads the model input data to a queue. Celery workers, which live in separate processes or even separate containers with GPU access, consume these messages. This separation allows you to update your API without restarting your model workers and vice versa, significantly increasing the reliability of your MLOps pipeline.

When to Choose Asynchronous Inference

You should adopt this pattern when your model inference time exceeds 500ms. Most modern load balancers, such as AWS ALB or NGINX, have default timeout settings around 30 to 60 seconds. If your model experiences a spike in traffic, queueing delays can easily push response times past these limits, causing the client to receive an error even if the model eventually finishes its work.

Another critical scenario is cost management. GPU instances are expensive. By using an asynchronous queue, you can handle bursts of traffic without needing to provision enough GPUs to handle the absolute peak concurrency. The queue acts as a buffer, allowing a smaller pool of workers to process the backlog steadily without crashing the system or dropping requests. This is especially useful for non-real-time tasks like image generation, document parsing, or batch sentiment analysis.

System Design and Data Flow

The data flow follows a specific path to ensure data integrity and observability. First, the client sends a POST request with the input data (e.g., text prompt or image URL). FastAPI validates this input using Pydantic, generates a unique task_id via Celery, and pushes the payload into Redis. The client immediately receives a 202 Accepted status code.

[Client] -> (HTTP POST) -> [FastAPI Gateway]
                                |
                   (Push Task) -> [Redis Broker] <- (Fetch Task)
                                                        |
                                               [Celery Worker (ML Model)]
                                                        |
                   (Store Result) -> [Redis/Postgres Result Backend]

Once the Celery worker finishes the inference, it writes the result to a "Result Backend." This can be the same Redis instance or a persistent database like PostgreSQL. The client then polls a separate GET endpoint with their task_id to check the status. This decoupled flow means that if a worker crashes, the task remains in Redis (if configured for persistence) and can be retried by another worker, providing high availability.
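The client-side polling loop described above can be sketched as follows. Here `fetch_status` is an injected callable (a hypothetical wrapper around `requests.get` on the `/result/{task_id}` endpoint), which also makes the loop easy to test with a fake:

```python
import time

def poll_result(fetch_status, task_id, interval=1.0, timeout=60.0):
    """Poll the result endpoint until the task finishes or the timeout expires.

    fetch_status(task_id) is assumed to return a dict shaped like the
    GET /result/{task_id} response: {"status": ..., "result": ...}.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        data = fetch_status(task_id)
        if data["status"] in ("SUCCESS", "FAILURE"):
            return data
        time.sleep(interval)
    raise TimeoutError(f"task {task_id} did not finish within {timeout}s")

# Fake fetcher simulating a task that succeeds on the third poll:
calls = {"n": 0}
def fake_fetch(task_id):
    calls["n"] += 1
    if calls["n"] < 3:
        return {"status": "PENDING", "result": None}
    return {"status": "SUCCESS", "result": {"prediction": "success"}}

out = poll_result(fake_fetch, "abc123", interval=0.01)
```

In production you would cap the polling frequency (1-2 seconds is typical) and surface `FAILURE` results to the user rather than retrying forever.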

Step-by-Step Implementation

1. Configure the Celery Instance

First, define your Celery application and the inference task. Ensure you initialize the ML model inside the worker process so it stays in GPU memory rather than reloading on every request.

from celery import Celery
import os

# Initialize Celery
celery_app = Celery(
    "worker",
    broker=os.getenv("REDIS_URL", "redis://localhost:6379/0"),
    backend=os.getenv("REDIS_URL", "redis://localhost:6379/0")
)

@celery_app.task(name="run_inference")
def run_inference(data: dict):
    # In a real scenario, load the model globally to keep it in memory
    # result = model.predict(data)
    import time
    time.sleep(5)  # Simulating heavy ML work
    return {"prediction": "success", "score": 0.99}

2. Create the FastAPI Gateway

The FastAPI app serves as the interface. It doesn't need to know anything about the ML model code; it only needs to know the task name and how to talk to Redis.

import os

from celery import Celery
from fastapi import FastAPI
from pydantic import BaseModel

# A thin Celery client: it knows the broker URL and the task name, nothing else
celery_app = Celery(
    "client",
    broker=os.getenv("REDIS_URL", "redis://localhost:6379/0"),
    backend=os.getenv("REDIS_URL", "redis://localhost:6379/0"),
)

app = FastAPI()

class PredictRequest(BaseModel):
    text: str  # adjust fields to match your model's input

@app.post("/predict", status_code=202)
async def predict(payload: PredictRequest):
    # Enqueue by task name only; the API process never imports the model code
    task = celery_app.send_task("run_inference", args=[payload.model_dump()])
    return {"task_id": task.id}

@app.get("/result/{task_id}")
async def get_result(task_id: str):
    result = celery_app.AsyncResult(task_id)
    return {
        "task_id": task_id,
        "status": result.status,
        "result": result.result if result.ready() else None,
    }


3. Execution and Scaling

Run the FastAPI app using Uvicorn and start the Celery worker in a separate terminal. To scale, you would run multiple instances of the worker command, potentially across different machines.

# Start API
uvicorn main:app --reload

# Start Worker (Use --concurrency=1 for GPU tasks to avoid OOM)
celery -A worker.celery_app worker --loglevel=info --concurrency=1
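In containerized deployments, the API, worker, and broker map naturally onto separate services. A minimal docker-compose sketch might look like the following (service names, image tags, and the shared build context are assumptions, not part of the code above):

```yaml
version: "3.8"
services:
  redis:
    image: redis:7
  api:
    build: .
    command: uvicorn main:app --host 0.0.0.0 --port 8000
    environment:
      - REDIS_URL=redis://redis:6379/0
    ports:
      - "8000:8000"
    depends_on:
      - redis
  worker:
    build: .
    command: celery -A worker.celery_app worker --loglevel=info --concurrency=1
    environment:
      - REDIS_URL=redis://redis:6379/0
    depends_on:
      - redis
```

Scaling the inference tier then becomes `docker compose up --scale worker=4`, without touching the API service.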

Architecture Trade-offs and Comparisons

Choosing between synchronous and asynchronous architectures involves balancing user experience with operational complexity. While sync is easier to debug, it fails under the weight of modern ML requirements.

Feature        | Synchronous API              | Async (FastAPI + Celery)
---------------|------------------------------|----------------------------------
Max concurrency| Limited by worker processes  | High (limited by Redis memory)
Timeout risk   | Very high                    | Near-zero (for the HTTP request)
Implementation | Simple                       | Moderate (requires a broker)
Client handling| Wait for the response        | Polling or webhooks
The most important factor is Predictability. In a synchronous world, a slow request blocks the queue for everyone. In an asynchronous world, the "waiter" is always available to take new orders, even if the "kitchen" is backed up. This prevents your entire frontend from appearing "down" to new users just because a few users are running heavy tasks.

Operational Tips for Production

When running this in production, avoid the default prefork pool for Celery if your ML model uses a lot of RAM. Although fork is nominally copy-on-write, Python's reference counting touches memory pages, so each child process gradually ends up holding its own copy of the model, which can lead to Out-Of-Memory (OOM) kills. Instead, set --concurrency=1 per container and scale by adding more containers (pods in Kubernetes). This ensures strict resource isolation for your expensive GPU/RAM assets.

Furthermore, implement Result Expiration in Redis. If you don't set a result_expires value in your Celery configuration, Redis will store every inference result forever, eventually consuming all available RAM. For most ML tasks, a TTL (Time To Live) of 1 hour is sufficient for the client to poll the result and move on.
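The two production settings discussed in this section can be collected into a single configuration dict. `result_expires` and `task_acks_late` are real Celery configuration keys; the one-hour TTL is the example value from above:

```python
# Celery configuration sketch, applied with celery_app.conf.update(**CELERY_CONFIG)
CELERY_CONFIG = {
    "result_expires": 3600,  # seconds; stored results get a TTL instead of living forever
    "task_acks_late": True,  # acknowledge only after the task finishes, enabling re-delivery
}
```

Keeping these in one place also makes it easy to share the settings between the worker process and any thin Celery client in the API layer.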

📌 Key Takeaways:

  • FastAPI handles the gateway; Celery handles the heavy ML compute.
  • Use status_code=202 for the initial request to signal background processing.
  • Redis acts as the buffer, protecting your ML workers from traffic spikes.
  • Scale your workers independently of your API gateway for cost efficiency.

Frequently Asked Questions

Q. How does the client know when the ML inference is finished?

A. There are two main patterns: Polling and Webhooks. In Polling, the client calls the GET endpoint every 1-2 seconds. For better efficiency, you can use Webhooks, where the Celery worker sends an HTTP POST request to a URL provided by the client once the task is complete.
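The webhook variant can be sketched as a small helper the worker calls at the end of the task. The HTTP client is injected via `post` (in production this would be `requests.post`; here a fake records the call so the sketch is self-contained):

```python
def notify_client(post, callback_url, task_id, result):
    """Send the finished result to the client-provided callback URL."""
    payload = {"task_id": task_id, "status": "SUCCESS", "result": result}
    post(callback_url, json=payload)
    return payload

# Fake HTTP client standing in for requests.post:
sent = []
def fake_post(url, json=None):
    sent.append((url, json))

payload = notify_client(
    fake_post, "https://client.example/hook", "abc123", {"score": 0.99}
)
```

In a real deployment you would also retry failed webhook deliveries and sign the payload so the client can verify its origin.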

Q. Can I use RabbitMQ instead of Redis?

A. Yes. RabbitMQ is a more robust message broker for complex routing. However, Redis is often preferred for ML tasks because it can serve as both the message broker and the result backend, simplifying your infrastructure stack.

Q. What happens if a Celery worker crashes mid-inference?

A. If you enable task_acks_late=True, the task is only acknowledged (removed from the queue) after the worker finishes it. If the worker crashes mid-inference, the broker detects the lost connection and re-delivers the unacknowledged task to another worker; with Redis as the broker, re-delivery happens after the visibility timeout expires.
