Multi-LoRA inference: Serve thousands of fine-tuned LLMs on a single GPU

7 min read
Akshat Patil, Rohit Aggarwal, Harpreet Singh

Introduction 

In our previous articles, we explored the limitations of in-context learning and the motivations for fine-tuning large language models (LLMs). We highlighted the growing need for task-specific optimization without the computational overhead of full fine-tuning. That led to "LoRA Demystified: Optimizing Language Models with Low-Rank Adaptation," where we examined Low-Rank Adaptation (LoRA) in depth, a technique that has dramatically reduced the cost of fine-tuning LLMs.

Now, we take the next logical step in our journey: scaling LoRA for production environments. As organizations increasingly rely on specialized language models for a variety of tasks, a new challenge emerges: how can we serve multiple LoRA-adapted models simultaneously without sacrificing performance or breaking the bank? This article answers that question by introducing cutting-edge techniques for multi-tenant LoRA serving, enabling the deployment of thousands of fine-tuned LLMs on a single GPU.

We'll explore the evolution from basic LoRA implementation to advanced serving strategies, focusing on:

  • The challenges of batching different task types in a multi-tenant environment
  • Innovative solutions like Segmented Gather Matrix-Vector Multiplication (SGMV)
  • The concept and implementation of heterogeneous continuous batching
  • Practical examples using state-of-the-art tools like LoRAX

By the end of this article, you'll have a comprehensive understanding of how to leverage LoRA at scale, opening new possibilities for efficient, cost-effective deployment of multiple specialized language models in production environments. 

A Brief Recap of LoRA

Before we dive into the complexities of multi-tenant serving, let's quickly recap the key idea behind LoRA:

  1. Keep the pretrained model's weights frozen.
  2. Add small, trainable matrices alongside selected weight matrices of the Transformer (for example, the attention projections).
  3. Factor each weight update as the product of two low-rank matrices, so only a small number of parameters is trained.
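In equation form, the three steps above amount to h = W0·x + (alpha/r)·B·A·x, where the frozen weight W0 is d×k, A is r×k, B is d×r, and the rank r is much smaller than d and k (Hu et al., 2022). The pure-Python sketch below illustrates that forward pass with tiny, arbitrary matrices instead of a tensor library. Following the LoRA paper, B starts at zero, so before training the adapted layer reproduces the base layer exactly; alpha is the paper's scaling hyperparameter.

```python
def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def lora_forward(w0, a, b, x, alpha, r):
    """h = W0 x + (alpha / r) * B (A x); W0 stays frozen, only A and B are trained."""
    base = matvec(w0, x)             # frozen pretrained path
    delta = matvec(b, matvec(a, x))  # low-rank path: project down to r dims, then back up
    scale = alpha / r
    return [h + scale * d for h, d in zip(base, delta)]


# Toy layer: d = k = 3, rank r = 1 (values are arbitrary)
w0 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
a = [[0.5, 0.5, 0.5]]            # r x k
b_init = [[0.0], [0.0], [0.0]]   # d x r, zero-initialised as in the LoRA paper
b_trained = [[1.0], [0.0], [-1.0]]
x = [1.0, 2.0, 3.0]

print(lora_forward(w0, a, b_init, x, alpha=1, r=1))     # [1.0, 2.0, 3.0], same as the base layer
print(lora_forward(w0, a, b_trained, x, alpha=1, r=1))  # [4.0, 2.0, 0.0]
```

Because W0 is never modified, swapping adapters only means swapping the small A and B matrices, which is what makes serving many adapters on one base model possible.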

This approach offers several significant advantages:

  • Drastically Reduced Parameter Count: By focusing on low-rank updates, LoRA significantly reduces the number of parameters that need to be trained. This makes fine-tuning more efficient and less resource-intensive.
  • Preserved Base Model: Since the original model weights remain unchanged, you can easily switch between different LoRA adaptations or revert to the base model without any loss of information.
  • Cost-Effective Customization: The reduced computational requirements make it feasible to create multiple customized LoRA models tailored to specific needs, even with limited resources.
  • Competitive Performance: Despite its simplicity, LoRA often achieves performance comparable to full fine-tuning across a wide range of tasks.
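To make the parameter savings concrete: fully updating a d×k weight matrix trains d·k values, while a rank-r LoRA update trains only r·(d + k). The helper below is a simple illustration; the 4096×4096 projection and rank 8 are example values chosen here, not figures from any particular model.

```python
def lora_param_counts(d, k, r):
    """Trainable parameters for one d x k weight: full fine-tuning vs. a rank-r LoRA update."""
    full = d * k        # every entry of the update is trained
    lora = r * (d + k)  # B is d x r, A is r x k
    return full, lora


full, lora = lora_param_counts(d=4096, k=4096, r=8)
print(full, lora, f"{lora / full:.2%}")  # 16777216 65536 0.39%
```

The totals for a whole model depend on how many weight matrices receive an adapter, but the per-matrix ratio shows why a single GPU can hold many adapters next to one copy of the base model.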

LoRA's efficiency and effectiveness have made it a cornerstone of modern LLM deployment strategies. However, to truly leverage its power in production environments, we need to address the challenges of serving multiple LoRA-adapted models simultaneously. This is where batching strategies come into play.

Challenges in Batching Different Task Types

As we move towards deploying multiple LoRA-adapted models in production, we encounter a new set of challenges, particularly when it comes to batching requests efficiently. Let's explore these challenges and why traditional batching approaches fall short.

The GPU Utilization Imperative

Graphics Processing Units (GPUs) are expensive and limited resources. Efficient GPU utilization is crucial for cost-effective deployment of LLMs. As highlighted by Yu et al. in their 2022 study, batching is one of the most effective methods for consolidating workloads to enhance performance and GPU utilization.

The Naive Approach: Separate Queues

A straightforward approach to handling multiple LoRA-adapted models would be to batch workloads separately for each task type or adapter. This method involves:

  • Segregating tasks into queues based on their type or associated adapter.
  • Waiting for each queue to reach a specific size (batch size) before processing.

However, this approach leads to several significant drawbacks:

  • Resource Underutilization: The system might have idle resources even when there are enough tasks of different types for a batch, simply because it's waiting for individual queues to fill. This significantly reduces overall throughput.
  • Unpredictable Performance: Performance becomes highly dependent on the arrival rate of each task type. Less frequent tasks can cause long delays in their respective queues, potentially holding up dependent tasks waiting for completion.
  • Scalability Issues: Adding new task types or adapters requires creating new queues, increasing management complexity and potentially leading to more idle periods with less frequent queues.
  • Latency Spikes: Tasks might experience high latency if they arrive when their queue is nearly empty, as they'll have to wait for the queue to fill before being processed.
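The latency-spike problem can be quantified with a back-of-the-envelope model. Assume, purely for illustration, that requests for one adapter arrive evenly at a rate of λ per second and that its queue is flushed only once it holds b requests. The first request into an empty queue then waits for b − 1 further arrivals, about (b − 1)/λ seconds, and the average request in that batch waits about half of that. The function name and arrival rates below are hypothetical.

```python
def naive_queue_wait(batch_size, arrival_rate):
    """Waiting time before a per-adapter queue of size `batch_size` fills.

    Assumes requests arrive evenly at `arrival_rate` per second (an idealisation).
    Returns (wait of the first request, mean wait over the whole batch) in seconds.
    """
    first = (batch_size - 1) / arrival_rate
    mean = first / 2  # waits fall linearly from `first` down to 0 for the last arrival
    return first, mean


# A popular adapter vs. a rarely used one, both with batch_size=32 (illustrative rates)
for name, rate in [("popular", 50.0), ("rare", 0.5)]:
    first, mean = naive_queue_wait(32, rate)
    print(f"{name}: first request waits {first:.2f}s, mean wait {mean:.2f}s")
```

Under these assumptions the popular adapter's first request waits 0.62 s, while the rare adapter's first request waits 62 s, even if the GPU is otherwise idle.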

Here's a simplified Python example illustrating the challenges of this naive approach: 

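import queue
import time

class NaiveBatchingSystem:
    """One FIFO queue per task type; a queue is processed only when it is full."""

    def __init__(self, batch_size=32):
        self.queues = {}  # task_type -> queue.Queue of pending tasks
        self.batch_size = batch_size

    def add_task(self, task_type, task):
        # Create a dedicated queue the first time a task type is seen
        if task_type not in self.queues:
            self.queues[task_type] = queue.Queue()
        self.queues[task_type].put(task)

    def process_batches(self, max_rounds=3):
        # Poll every queue; bounded so the demo terminates (a real server would loop forever)
        for _ in range(max_rounds):
            for task_type, task_queue in self.queues.items():
                if task_queue.qsize() >= self.batch_size:
                    batch = [task_queue.get() for _ in range(self.batch_size)]
                    print(f"Processing batch of {len(batch)} {task_type} tasks")
                    # Process the batch...
                else:
                    print(f"Waiting for more {task_type} tasks "
                          f"({task_queue.qsize()}/{self.batch_size})...")
            time.sleep(1)  # Avoid busy-waiting

# Usage: two tasks are pending (enough for one batch of 2), but neither queue ever fills
batcher = NaiveBatchingSystem(batch_size=2)
batcher.add_task("math", "2 + 2")
batcher.add_task("translation", "Hello in French")
batcher.process_batches()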

This example demonstrates how tasks of different types might be stuck waiting for their respective queues to fill, even if there are enough total tasks to form a batch.
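For contrast, here is a minimal sketch of the alternative this article builds toward: a single queue shared by all task types, where a batch is formed as soon as enough requests of any type are waiting. The class name MixedBatchingSystem is hypothetical, and the sketch only covers the scheduling idea; actually executing a batch whose rows need different adapters is the job of kernels such as SGMV.

```python
from collections import deque


class MixedBatchingSystem:
    """One shared FIFO queue; each batch may mix task types (adapters)."""

    def __init__(self, batch_size=32):
        self.pending = deque()  # (task_type, task) pairs in arrival order
        self.batch_size = batch_size

    def add_task(self, task_type, task):
        self.pending.append((task_type, task))

    def next_batch(self):
        """Return a batch as soon as enough tasks of *any* type are waiting, else None."""
        if len(self.pending) < self.batch_size:
            return None
        return [self.pending.popleft() for _ in range(self.batch_size)]


batcher = MixedBatchingSystem(batch_size=2)
batcher.add_task("math", "2 + 2")
batcher.add_task("translation", "Hello in French")
print(batcher.next_batch())  # [('math', '2 + 2'), ('translation', 'Hello in French')]
```

With the same two tasks as in the naive example and a batch size of 2, this version dispatches them together immediately instead of leaving each in its own half-empty queue.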

These challenges highlight the need for a more sophisticated approach to batching, one that can efficiently consolidate multi-tenant LoRA serving workloads onto a small number of GPUs while maximizing overall utilization. To address these challenges, researchers have developed innovative techniques like Segmented Gather Matrix-Vector Multiplication (SGMV).

Segmented Gather Matrix-Vector Multiplication (SGMV)

Chen et al. (2023) introduced SGMV in Punica, their multi-tenant LoRA serving system, as a CUDA kernel designed specifically for this workload. SGMV lets a single batch contain requests for many distinct LoRA adapters, so their adapter computations run concurrently in one batched GPU operation instead of one adapter at a time.

How SGMV Works

At its core, SGMV optimizes the extra matrix multiplications that LoRA adapters add to each layer. The base model's weights are shared by every request, so that part of the computation is batched as usual; only the adapter part differs from request to request. Here's a simplified explanation of how SGMV handles that part:

  1. Segmentation: Requests in the batch are grouped so that those using the same adapter form a contiguous segment.
  2. Gather: For each segment, the kernel gathers the weights of that segment's adapter from GPU memory.
  3. Batched Multiplication: Each segment's inputs are multiplied by its gathered adapter weights, and all segments are computed in a single kernel launch, leveraging the GPU's parallel processing capabilities.
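To make the three steps concrete, here is a pure-Python reference of *what* an SGMV-style operation computes, not *how* the CUDA kernel computes it. Requests are first reordered so that rows using the same adapter are contiguous; each segment then gathers its own adapter's weight matrix and multiplies only its rows. The function names and the toy "math" and "sql" adapters are hypothetical. In a LoRA layer this operation would be applied once with the A matrices and once with the B matrices, and the result added to the base model's output; on a GPU, all segments run in parallel inside one kernel launch, whereas this loop visits them one by one.

```python
from itertools import groupby


def build_segments(request_adapters):
    """Order requests so rows sharing an adapter are contiguous; return segment info."""
    # sorted() is stable, so requests keep their arrival order within an adapter
    order = sorted(range(len(request_adapters)), key=lambda i: request_adapters[i])
    seg_starts, adapter_ids = [0], []
    for adapter, group in groupby(order, key=lambda i: request_adapters[i]):
        adapter_ids.append(adapter)
        seg_starts.append(seg_starts[-1] + len(list(group)))
    return order, seg_starts, adapter_ids


def sgmv(x, weights, seg_starts, adapter_ids):
    """Reference semantics: y[s_i:s_(i+1)] = x[s_i:s_(i+1)] @ weights[adapter_ids[i]]."""
    y = []
    for seg, adapter in enumerate(adapter_ids):
        w = weights[adapter]  # "gather": fetch this segment's adapter matrix (k x r)
        for row in x[seg_starts[seg]:seg_starts[seg + 1]]:
            # one matrix-vector product per row of the segment
            y.append([sum(row[i] * w[i][j] for i in range(len(row))) for j in range(len(w[0]))])
    return y


# Three requests, two adapters (toy k = 2, r = 1 weights)
weights = {"math": [[1.0], [2.0]], "sql": [[3.0], [4.0]]}
request_adapters = ["sql", "math", "sql"]
x = [[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]]

order, seg_starts, adapter_ids = build_segments(request_adapters)
y = sgmv([x[i] for i in order], weights, seg_starts, adapter_ids)
print(order, seg_starts, adapter_ids)  # [1, 0, 2] [0, 1, 3] ['math', 'sql']
print(y)                               # [[1.0], [7.0], [4.0]]
```

The outputs come back in segment order; a serving system would scatter them back to the original request positions using `order`.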

Benefits of SGMV

By leveraging SGMV, we can:

  1. Process Multiple Adapters Concurrently: Different LoRA models can be executed in parallel, improving overall system performance and resource utilization.
  2. Eliminate Queue-Based Bottlenecks: SGMV allows for grouping requests for different adapters together, avoiding the need for separate queues for each adapter or task type.
  3. Maintain Continuous Processing: The system can process tasks constantly, regardless of type, keeping the processing flow continuous and avoiding delays from waiting for specific task types to accumulate.
  4. Improve Throughput and Consistency: This kind of heterogeneous batching improves overall throughput and keeps performance consistent as the number of distinct tasks or adapters grows.

While the actual implementation of SGMV is complex and involves low-level GPU programming, its effects can be observed at the system level.

Heterogeneous Continuous Batching in LoRAX 

LoRAX, an open-source Multi-LoRA inference server, represents a significant leap forward in the efficient deployment of multiple fine-tuned language models. At its core, LoRAX leverages the power of SGMV to achieve heterogeneous continuous batching, optimizing overall system throughput while maintaining low latency.

Key Components of LoRAX

LoRAX's architecture is built around three fundamental components that enable its powerful heterogeneous batching capabilities:

  1. Dynamic Adapter Loading: LoRAX doesn't require all adapters to be pre-loaded into GPU memory. Instead, it dynamically downloads and loads adapters onto the GPU as requests arrive. This on-demand loading ensures efficient use of GPU memory and allows the system to handle a large number of different adapters without blocking other requests.
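  2. Continuous Batching: Unlike traditional batching systems that wait for a fixed batch size, LoRAX batches at the level of decoding iterations: newly arrived requests can join the running batch between decoding steps, and finished requests leave it. The batch is sized by a token budget tied to available GPU memory rather than a fixed request count, and requests for different adapters share the same batch, keeping processing continuous.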
  3. Asynchronous Adapter Scheduling: A background thread in LoRAX efficiently manages adapter offloading and loading, minimizing the performance impact of swapping adapters in and out of GPU memory.
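The toy model below shows how these three components can interact. It is a sketch under simplifying assumptions, not LoRAX's actual scheduler: the names ToyScheduler, max_batch, and adapter_slots are hypothetical, a request count stands in for LoRAX's token and memory budget, "loading" an adapter is reduced to inserting a key into an OrderedDict used as an LRU cache, and loading happens inline rather than in a background thread.

```python
from collections import OrderedDict, deque
from dataclasses import dataclass


@dataclass
class Request:
    rid: str
    adapter: str
    tokens_left: int  # new tokens still to generate


class ToyScheduler:
    """Toy heterogeneous continuous batching loop (hypothetical, not LoRAX's code)."""

    def __init__(self, max_batch=4, adapter_slots=2):
        self.waiting = deque()              # requests not yet admitted, FIFO
        self.running = []                   # requests currently being decoded
        self.max_batch = max_batch          # stand-in for a token / GPU-memory budget
        self.adapter_slots = adapter_slots  # how many adapters fit on the "GPU"
        self.gpu_adapters = OrderedDict()   # resident adapters, least recently used first

    def submit(self, req):
        self.waiting.append(req)

    def _ensure_loaded(self, adapter):
        """Load an adapter on demand, evicting the least recently used idle one."""
        if adapter in self.gpu_adapters:
            self.gpu_adapters.move_to_end(adapter)  # mark as recently used
            return True
        if len(self.gpu_adapters) >= self.adapter_slots:
            busy = {r.adapter for r in self.running}
            idle = [a for a in self.gpu_adapters if a not in busy]
            if not idle:
                return False  # every resident adapter is in use; request keeps waiting
            del self.gpu_adapters[idle[0]]
        self.gpu_adapters[adapter] = True  # stands in for downloading + copying weights
        return True

    def step(self):
        """Run one decoding iteration; return the ids of requests that finished."""
        # 1. Admit waiting requests, whatever their adapter, while the batch has room
        while self.waiting and len(self.running) < self.max_batch:
            if not self._ensure_loaded(self.waiting[0].adapter):
                break
            self.running.append(self.waiting.popleft())
        # 2. Decode one token for every running request in the same (mixed) batch
        for req in self.running:
            req.tokens_left -= 1
        # 3. Retire finished requests so new ones can join on the next iteration
        done = [r.rid for r in self.running if r.tokens_left <= 0]
        self.running = [r for r in self.running if r.tokens_left > 0]
        return done


sched = ToyScheduler(max_batch=3, adapter_slots=2)
for rid, adapter, n in [("a", "math", 2), ("b", "sql", 1), ("c", "math", 3), ("d", "french", 1)]:
    sched.submit(Request(rid, adapter, n))

history = []  # (running ids, resident adapters, finished ids) after each iteration
while sched.waiting or sched.running:
    finished = sched.step()
    history.append(([r.rid for r in sched.running], list(sched.gpu_adapters), finished))
    print(len(history), *history[-1])
```

In this run, the first iteration decodes "math" and "sql" requests in one batch. In the second, request "d" joins mid-flight, and the now-idle "sql" adapter is evicted to make room for "french", without waiting for the math requests to finish.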

Implementation Example

Let's look at a simplified example of how LoRAX handles a batch of tasks using the lorax-client with Flask:

from concurrent.futures import ThreadPoolExecutor

from flask import Flask, jsonify, request
from lorax import Client
import requests

app = Flask(__name__)

# Configuration
LORAX_ENDPOINT = "http://127.0.0.1:8080"  # Replace with your LoRAX server endpoint
CALLBACK_URL = "http://localhost:5001/uploadresponse/"  # Replace with your callback endpoint
MAX_CONCURRENT_REQUESTS = 16  # Illustrative value: prompts kept in flight at once

# Initialize the LoRAX client
lorax_client = Client(LORAX_ENDPOINT)


def generate_one(prompt_data):
    """Send one prompt to LoRAX; adapter_id=None falls back to the base model."""
    response = lorax_client.generate(
        prompt_data["prompt"],
        adapter_id=prompt_data.get("adapter_id"),
        max_new_tokens=prompt_data.get("max_new_tokens", 64),  # illustrative default
        # ... other parameters
    )
    return response.dict()


@app.route("/lorax/upload", methods=["POST"])
def upload_batch():
    """
    Handles batch upload requests.
    """
    # Parse the request body (silent=True returns None instead of raising on bad JSON)
    data = request.get_json(silent=True) or {}
    batch_id = data.get("batchId")
    prompts = data.get("data")

    if not batch_id or not prompts:
        return jsonify({"message": "Missing batchId or data"}), 400

    try:
        # Submit all prompts concurrently so LoRAX can batch them together,
        # even when they target different adapters (map preserves input order)
        with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_REQUESTS) as pool:
            responses = list(pool.map(generate_one, prompts))

        # Trigger the callback
        callback_data = {"batchId": batch_id, "response": responses}
        requests.post(CALLBACK_URL, json=callback_data, timeout=30)

        return jsonify({"message": "Batch processed successfully"}), 200

    except Exception as e:
        print(f"Error processing batch: {e}")
        return jsonify({"message": "Error processing batch"}), 500


if __name__ == "__main__":
    app.run(debug=True, port=5001)

This implementation showcases several key aspects of LoRAX's heterogeneous continuous batching:

  1. Batch of Tasks: The Flask server receives a batch of tasks as a JSON payload. Each task includes a prompt, an optional adapter ID, and the maximum number of tokens to generate.
  2. LoRAX Client: The server uses the lorax-client library to communicate with the LoRAX server, abstracting away the complexities of heterogeneous batching.
  3. Heterogeneous Batching: Notice that the server doesn't need to filter or sort prompts by adapter ID. LoRAX handles this internally, dynamically grouping tasks based on available resources and efficiently managing adapter loading.
  4. Dynamic Adapter Loading: If an adapter specified in a request isn't already loaded, LoRAX will download and load it on-demand, allowing for efficient use of GPU memory.
  5. Concurrency: LoRAX can only batch requests that are in flight at the same time, so the prompts in a batch should be submitted concurrently (for example, from a thread pool or an async client) rather than waiting for each response before sending the next.

Testing with curl

To test this implementation, you can use a curl command like this:

curl -X POST -H "Content-Type: application/json" \
    -d '{"batchId": "10001", "data": [
        {
            "prompt": "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]",
            "adapter_id": "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k",
            "max_new_tokens": 64
        },
        {
            "prompt": "[INST] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/INST]",
            "adapter_id": "ai2sql/ai2sql_mistral_7b",
            "max_new_tokens": 128
        },
        {
            "prompt": "[INST] What is the capital of France? Provide a brief history. [/INST]",
            "adapter_id": "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k",
            "max_new_tokens": 128
        }
    ]}' \
    http://localhost:5001/lorax/upload

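This curl command sends a POST request to the Flask server's /lorax/upload endpoint with a batch of three prompts: a math word problem, a SQL-generation task, and a general-knowledge question. The math and general-knowledge prompts both use the GSM8K math adapter, while the SQL prompt uses a separate SQL adapter, so a single batch mixes two different LoRA adapters.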

LoRAX's heterogeneous continuous batching shines in this scenario. It efficiently handles the diverse set of tasks, potentially loading different adapters as needed, and processes them concurrently. This approach significantly improves throughput and maintains low latency, even when dealing with a mix of task types and adapters.

By leveraging LoRAX and its implementation of heterogeneous continuous batching, we can efficiently serve multiple fine-tuned LLMs in production, overcoming the challenges of traditional batching methods and maximizing GPU utilization.

References

  1. Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., ... & Chen, W. (2022). LoRA: Low-rank adaptation of large language models. arXiv preprint arXiv:2106.09685.
  2. Yu, G.-I., Jeong, J. S., Kim, G.-W., Kim, S., & Chun, B.-G. (2022). Orca: A Distributed Serving System for Transformer-Based Generative Models. In 16th USENIX Symposium on Operating Systems Design and Implementation (OSDI 22).
  3. Sheng, Y., Cao, S., Li, D., Hooper, C., Lee, N., Yang, S., ... & Stoica, I. (2023). S-LoRA: Serving Thousands of Concurrent LoRA Adapters. arXiv preprint arXiv:2311.03285.
  4. Brown, T. B., Mann, B., Ryder, N., Subbiah, M., Kaplan, J., Dhariwal, P., ... & Amodei, D. (2020). Language models are few-shot learners. arXiv preprint arXiv:2005.14165.
  5. Touvron, H., Lavril, T., Izacard, G., Martinet, X., Lachaux, M. A., Lacroix, T., ... & Lample, G. (2023). LLaMA: Open and Efficient Foundation Language Models. arXiv preprint arXiv:2302.13971.
  6. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. In Advances in neural information processing systems (pp. 5998-6008).
  7. Raffel, C., Shazeer, N., Roberts, A., Lee, K., Narang, S., Matena, M., ... & Liu, P. J. (2020). Exploring the limits of transfer learning with a unified text-to-text transformer. Journal of Machine Learning Research, 21(140), 1-67.
  8. Dettmers, T., Pagnoni, A., Holtzman, A., & Zettlemoyer, L. (2023). QLoRA: Efficient Finetuning of Quantized LLMs. arXiv preprint arXiv:2305.14314.
  9. Zhang, S., Roller, S., Goyal, N., Artetxe, M., Chen, M., Chen, S., ... & Pasunuru, R. (2022). OPT: Open pre-trained transformer language models. arXiv preprint arXiv:2205.01068.
  10. Hoffmann, J., Borgeaud, S., Mensch, A., Buchatskaya, E., Cai, T., Rutherford, E., ... & Sifre, L. (2022). Training compute-optimal large language models. arXiv preprint arXiv:2203.15556.
  11. Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). Punica: Multi-Tenant LoRA Serving. arXiv preprint arXiv:2310.18547.
  12. Zhao, J., Wang, T., Abid, W., Angus, G., Garg, A., Kinnison, J., Sherstinsky, A., Molino, P., Addair, T., & Rishi, D. (2024). LoRA Land: 310 Fine-tuned LLMs that Rival GPT-4, A Technical Report. arXiv preprint arXiv:2405.00732.