LLM Router
A high-performance, Rust-based load balancer and router specifically designed for Large Language Model (LLM) APIs. It intelligently distributes requests across multiple backend LLM API instances based on configurable strategies, health checks, and model capabilities.
Core Concepts
- Router: The central component that manages backend instances and routing logic.
- Instance: Represents a backend LLM API endpoint (e.g., https://api.openai.com/v1). Each instance has an ID, base URL, health status, and associated models.
- ModelInstanceConfig: Defines the specific models (e.g., "gpt-4", "text-embedding-ada-002") and their capabilities (Chat, Embedding, Completion) supported by an instance (see the sketch after this list).
- Routing Strategy: Determines how the router selects the next instance for a request (Round Robin or Load Based).
- Health Checks: Periodically checks the availability of backend instances. Unhealthy instances are temporarily removed from rotation.
- Request Tracking: Automatically manages the count of active requests for each instance when using the LoadBased strategy. The RequestTracker utility simplifies this.
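As a quick orientation, ModelInstanceConfig is a plain struct; a minimal construction sketch (mirroring the Quick Start below) looks like this:

use llm_router_core::types::{ModelCapability, ModelInstanceConfig};

// Describes one model an instance can serve and what it can be used for.
fn gpt4_chat_config() -> ModelInstanceConfig {
    ModelInstanceConfig {
        model_name: "gpt-4".to_string(),
        capabilities: vec![ModelCapability::Chat],
    }
}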
Features
- Multiple Routing Strategies: Load-based or Round Robin distribution.
- Automatic Health Checks: Continuously monitors backend health via configurable endpoints.
- Model Capability Support: Route requests based on the specific model and capability (chat, embedding, completion) required.
- Instance Timeout: Automatically quarantine instances that return errors for a configurable period.
- High Throughput: Efficiently handles thousands of instance selections per second.
- Low Overhead: Adds minimal latency (microseconds) to the instance selection process.
- Dynamic Instance Management: Add or remove backend instances at runtime without service interruption.
- Resilient Error Handling: Gracefully handles backend failures and timeouts.
Installation
Add the dependency to your Cargo.toml:
[dependencies]
llm_router_core = "0.1.0"
# Add other dependencies for your application (e.g., tokio, reqwest, axum)
tokio = { version = "1", features = ["full"] }
reqwest = { version = "0.11", features = ["json"] }
axum = "0.7"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
Quick Start
use llm_router_core::{
Router, RequestTracker,
types::{ModelCapability, ModelInstanceConfig, RoutingStrategy},
};
use std::time::Duration;
use std::sync::Arc; // Required for sharing Router
#[tokio::main]
async fn main() {
// Create the router configuration
let router_config = Router::builder()
.strategy(RoutingStrategy::RoundRobin)
.instance_with_models(
"openai_instance1",
"https://api.openai.com/v1", // Replace with your actual endpoint if different
vec![
ModelInstanceConfig {
model_name: "gpt-4".to_string(),
capabilities: vec![ModelCapability::Chat],
},
ModelInstanceConfig {
model_name: "text-embedding-ada-002".to_string(),
capabilities: vec![ModelCapability::Embedding],
},
]
)
.instance_with_models(
"openai_instance2",
"https://api.openai.com/v1", // Replace with your actual endpoint if different
vec![
ModelInstanceConfig {
model_name: "gpt-3.5-turbo".to_string(),
capabilities: vec![ModelCapability::Chat],
}
]
)
.health_check_path("/health") // Optional: Define if your API has a health endpoint
.health_check_interval(Duration::from_secs(30))
.instance_timeout_duration(Duration::from_secs(60)) // Timeout unhealthy instances for 60s
.build();
// Wrap the router in an Arc for sharing across threads/tasks
let router = Arc::new(router_config);
// --- Example: Selecting an instance ---
let model_name = "gpt-4";
let capability = ModelCapability::Chat;
match router.select_instance_for_model(model_name, capability).await {
Ok(instance) => {
println!(
"Selected instance for '{}' ({}): {} ({})",
model_name, capability, instance.id, instance.base_url
);
// Use RequestTracker for automatic request counting (especially with LoadBased strategy)
let _tracker = RequestTracker::new(Arc::clone(&router), instance.id.clone());
// ---> Place your API call logic here <---
// Example: Construct the URL
let api_url = format!("{}/chat/completions", instance.base_url);
println!("Constructed API URL: {}", api_url);
// Use your HTTP client (e.g., reqwest) to send the request to api_url
// Remember to handle potential errors from the API call itself
// If an error occurs, consider calling router.timeout_instance(&instance.id).await
}
Err(e) => eprintln!(
"Error selecting instance for '{}' ({}): {}",
model_name, capability, e
),
}
}
Authentication
Most LLM APIs require authentication (e.g., API keys). The llm-router itself doesn't handle authentication headers directly during routing or health checks by default. You need to manage authentication in your application's HTTP client when making the actual API calls after selecting an instance.
If your health checks require authentication, you can provide a pre-configured reqwest::Client to the Router::builder.
use llm_router_core::{Router, types::{ModelInstanceConfig, ModelCapability, RoutingStrategy}};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use std::time::Duration;
use std::sync::Arc;
async fn setup_router_with_auth() -> Result<Arc<Router>, Box<dyn std::error::Error>> {
let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set");
// Configure a reqwest client with default auth headers (useful for health checks)
let mut headers = HeaderMap::new();
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {}", api_key))?,
);
let client = reqwest::Client::builder()
.default_headers(headers)
.timeout(Duration::from_secs(10)) // Set a timeout for HTTP requests
.build()?;
let router_config = Router::builder()
.strategy(RoutingStrategy::LoadBased)
.instance_with_models(
"authed_instance",
"https://api.openai.com/v1",
vec![ModelInstanceConfig {
model_name: "gpt-4".to_string(),
capabilities: vec![ModelCapability::Chat],
}],
)
.http_client(client) // Provide the pre-configured client
.health_check_interval(Duration::from_secs(60))
.build();
Ok(Arc::new(router_config))
}
// Remember: When making the actual API call *after* selecting an instance,
// you still need to ensure your request includes the necessary authentication.
// The client passed to the builder is primarily for health checks.
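For the actual API calls, attach the credentials in your own HTTP client. A minimal sketch with reqwest (the request body and the gpt-4 model name here are just placeholders):

use llm_router_core::{Router, RequestTracker, types::ModelCapability};
use std::sync::Arc;

async fn call_chat_api(
    router: Arc<Router>,
    client: &reqwest::Client,
    api_key: &str,
) -> Result<(), Box<dyn std::error::Error>> {
    let instance = router
        .select_instance_for_model("gpt-4", ModelCapability::Chat)
        .await?;
    let _tracker = RequestTracker::new(Arc::clone(&router), instance.id.clone());

    // The router only picks the backend; the bearer token must be added here.
    let response = client
        .post(format!("{}/chat/completions", instance.base_url))
        .bearer_auth(api_key)
        .json(&serde_json::json!({
            "model": "gpt-4",
            "messages": [{ "role": "user", "content": "Hello" }]
        }))
        .send()
        .await?;

    println!("Backend responded with status {}", response.status());
    Ok(())
}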
Detailed Usage Guide
Choosing a Routing Strategy
The router supports two strategies, set via Router::builder().strategy(...):
- Round Robin (RoutingStrategy::RoundRobin):
  - Distributes requests sequentially across all healthy, capable instances.
  - Simple, predictable, and often provides the best raw throughput if backends are homogeneous.
  - Default strategy if not specified.
- Load Based (RoutingStrategy::LoadBased):
  - Routes to the healthy, capable instance with the fewest currently active requests.
  - Requires using RequestTracker (recommended) or manually calling increment_request_count and decrement_request_count to be effective.
  - Helps balance load when backend processing times vary significantly or when requests might be long-running.
  - Can lead to more consistent latency across requests.
Managing Backend Instances (Dynamic Updates)
You can modify the router's instance pool after it has been built. This is useful for scaling or maintenance.
# use llm_router_core::{Router, types::{ModelCapability, ModelInstanceConfig, RoutingStrategy}};
# use std::sync::Arc;
# use std::time::Duration;
#
# async fn example(router: Arc<Router>) -> Result<(), Box<dyn std::error::Error>> {
// Add a new instance dynamically
router.add_instance_with_models(
"new_instance",
"https://api.another-provider.com/v1",
vec![
ModelInstanceConfig {
model_name: "claude-3".to_string(),
capabilities: vec![ModelCapability::Chat],
}
]
).await?;
println!("Added new_instance");
// Add a new model/capability to an existing instance
router.add_model_to_instance(
"new_instance",
"claude-3-opus".to_string(),
vec![ModelCapability::Chat]
).await?;
println!("Added claude-3-opus to new_instance");
// Remove an instance
match router.remove_instance("openai_instance1").await {
Ok(_) => println!("Removed openai_instance1"),
Err(e) => eprintln!("Failed to remove instance: {}", e),
}
// Get status of all instances
let instances_status = router.get_instances().await;
println!("
Current Instance Status:");
for instance_info in instances_status {
println!(
"- ID: {}, URL: {}, Status: {:?}, Active Requests: {}, Models: {:?}",
instance_info.id,
instance_info.base_url,
instance_info.status,
instance_info.active_requests, // Only relevant for LoadBased
instance_info.models.keys().collect::<Vec<_>>()
);
}
# Ok(())
# }
Selecting an Instance
The primary way to get a suitable backend URL is by requesting an instance for a specific model and capability.
# use llm_router_core::{Router, RequestTracker, types::{ModelCapability, RoutingStrategy}};
# use std::sync::Arc;
# use std::time::Duration;
#
# async fn example(router: Arc<Router>) -> Result<(), Box<dyn std::error::Error>> {
let model_name = "gpt-3.5-turbo";
let capability = ModelCapability::Chat;
match router.select_instance_for_model(model_name, capability).await {
Ok(instance) => {
println!("Selected instance for {} ({}): {} ({})", model_name, capability, instance.id, instance.base_url);
// Use RequestTracker to ensure load balancing works correctly if using LoadBased strategy
let _tracker = RequestTracker::new(router.clone(), instance.id.clone());
// Now, make the API call to instance.base_url using your HTTP client...
}
Err(e) => {
eprintln!("Could not find a healthy instance for {} ({}): {}", model_name, capability, e);
// Handle the error (e.g., return an error response to the user)
}
}
# Ok(())
# }
Alternatively, if you don't need a specific model and just want the next instance according to the strategy:
# use llm_router_core::{Router, RequestTracker, types::RoutingStrategy};
# use std::sync::Arc;
# use std::time::Duration;
#
# async fn example(router: Arc<Router>) -> Result<(), Box<dyn std::error::Error>> {
match router.select_next_instance().await {
Ok(instance) => {
println!("Selected next instance (any model/capability): {} ({})", instance.id, instance.base_url);
let _tracker = RequestTracker::new(router.clone(), instance.id.clone());
// Make API call... (Be aware this instance might not support the specific model you need)
}
Err(e) => {
eprintln!("Could not select next instance: {}", e);
}
}
# Ok(())
# }
Using RequestTracker (Important for Load Balancing)
When using the LoadBased
strategy, the router needs to know how many requests are currently in flight to each instance. The RequestTracker
utility handles this automatically using RAII (Resource Acquisition Is Initialization).
# use llm_router_core::{Router, RequestTracker, types::{ModelCapability, RoutingStrategy}};
# use std::sync::Arc;
# use std::time::Duration;
# async fn make_api_call(url: &str) -> Result<(), &'static str> { /* ... */ Ok(()) }
#
# async fn example(router: Arc<Router>) -> Result<(), Box<dyn std::error::Error>> {
let instance = router.select_instance_for_model("gpt-4", ModelCapability::Chat).await?;
// Create the tracker immediately after selecting the instance
let tracker = RequestTracker::new(router.clone(), instance.id.clone());
// Perform the API call or other work associated with this instance
println!("Making API call to {}", instance.base_url);
match make_api_call(&instance.base_url).await {
Ok(_) => println!("API call successful"),
Err(e) => {
eprintln!("API call failed: {}", e);
// If the call fails, consider putting the instance in timeout
router.timeout_instance(&instance.id).await?;
println!("Instance {} put into timeout due to error.", instance.id);
}
}
// When `tracker` goes out of scope here (end of function, or earlier block),
// it automatically decrements the request count for the instance.
println!("Request finished, tracker dropped.");
# Ok(())
# }
If you don't use RequestTracker, you must manually call router.increment_request_count(&instance.id) before the request and router.decrement_request_count(&instance.id) after the request (including in error cases) for LoadBased routing to function correctly. RequestTracker is strongly recommended.
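A rough sketch of the manual approach (the exact signatures of the counting methods aren't shown in this README; they are assumed here to be async Router methods like timeout_instance):

use llm_router_core::{Router, types::ModelCapability};
use std::sync::Arc;

async fn manual_counting(router: Arc<Router>) -> Result<(), Box<dyn std::error::Error>> {
    let instance = router
        .select_instance_for_model("gpt-4", ModelCapability::Chat)
        .await?;

    // Tell the router a request is now in flight on this instance.
    router.increment_request_count(&instance.id).await;

    let result = fake_api_call(&instance.base_url).await;

    // Always decrement, even on error, or LoadBased selection will see phantom load.
    router.decrement_request_count(&instance.id).await;

    result
}

// Stand-in for your real HTTP call.
async fn fake_api_call(_base_url: &str) -> Result<(), Box<dyn std::error::Error>> {
    Ok(())
}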
Handling Errors and Timeouts
If an API call to a selected instance fails, you might want to temporarily mark that instance as unhealthy to prevent routing further requests to it for a while.
# use llm_router_core::{Router, RequestTracker, types::{ModelCapability, RoutingStrategy}};
# use std::sync::Arc;
# use std::time::Duration;
# async fn make_api_call(url: &str) -> Result<(), &'static str> { Err("Simulated API Error") }
#
# async fn example(router: Arc<Router>) -> Result<(), Box<dyn std::error::Error>> {
let instance = router.select_instance_for_model("gpt-4", ModelCapability::Chat).await?;
let tracker = RequestTracker::new(router.clone(), instance.id.clone());
match make_api_call(&instance.base_url).await {
Ok(_) => { /* Process success */ },
Err(api_error) => {
eprintln!("API Error from instance {}: {}", instance.id, api_error);
// Put the instance in timeout
match router.timeout_instance(&instance.id).await {
Ok(_) => println!("Instance {} placed in timeout.", instance.id),
Err(e) => eprintln!("Error placing instance {} in timeout: {}", instance.id, e),
}
// Return an appropriate error to the caller
return Err(Box::new(std::io::Error::new(std::io::ErrorKind::Other, "API call failed")));
}
}
# Ok(())
# }
The instance will remain in the Timeout
state for the duration specified by instance_timeout_duration
in the builder, after which the health checker will attempt to bring it back online.
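To watch these transitions from your application, you can poll get_instances() (shown earlier) and log the status field; a small monitoring sketch:

use llm_router_core::Router;
use std::sync::Arc;
use std::time::Duration;

async fn watch_instance_status(router: Arc<Router>) {
    loop {
        for info in router.get_instances().await {
            // Status is one of Healthy, Unhealthy, or Timeout, as described above.
            println!("{} ({}) -> {:?}", info.id, info.base_url, info.status);
        }
        tokio::time::sleep(Duration::from_secs(10)).await;
    }
}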
Health Checks Configuration
Configure health checks using the builder:
# use llm_router_core::{Router, types::RoutingStrategy};
# use std::time::Duration;
#
let router = Router::builder()
// ... other configurations ...
.health_check_path("/health") // The endpoint path for the health check (e.g., GET <base_url>/health)
.health_check_interval(Duration::from_secs(15)) // Check health every 15 seconds
.health_check_timeout(Duration::from_secs(5)) // Timeout for the health check request itself (5 seconds)
.instance_timeout_duration(Duration::from_secs(60)) // How long an instance stays in Timeout state (60 seconds)
.build();
- If health_check_path is not set, instances are initially considered Healthy and only move to Timeout if timeout_instance is called.
- The health checker sends a GET request to <instance.base_url><health_check_path>. A 2xx status code marks the instance as Healthy. Any other status or a timeout marks it as Unhealthy. (A minimal /health route is sketched after this list.)
- Instances in Timeout state are not checked until the timeout duration expires.
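If your backend doesn't already expose such an endpoint, any lightweight route returning 2xx will do. A minimal sketch using axum (the same framework as the integration example below):

use axum::{routing::get, Router as AxumRouter};

// Trivial health endpoint: GET /health returns 200 OK with a short body.
fn health_routes() -> AxumRouter {
    AxumRouter::new().route("/health", get(|| async { "ok" }))
}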
Axum Web Server Integration Example
Here's how to integrate llm-router
into an Axum web server to act as a proxy/gateway to your LLM backends.
use axum::{
extract::{Json, State},
http::{StatusCode, Uri},
response::{IntoResponse, Response},
routing::post,
Router as AxumRouter,
};
use llm_router_core::{
Router, RequestTracker,
types::{ModelCapability, ModelInstanceConfig, RoutingStrategy, RouterError},
};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
// Define your request/response structures (matching the target LLM API)
#[derive(Serialize, Deserialize, Debug)]
struct ChatRequest {
model: String,
messages: Vec<serde_json::Value>, // Example structure
// ... other fields like temperature, max_tokens, etc.
}
#[derive(Serialize, Deserialize, Debug)]
struct ChatResponse {
id: String,
object: String,
created: u64,
model: String,
choices: Vec<serde_json::Value>,
usage: serde_json::Value,
}
// Application state
struct AppState {
router: Arc<Router>,
http_client: Client, // Use a shared reqwest client
}
#[tokio::main]
async fn main() {
// --- Router Configuration ---
let router_config = Router::builder()
.strategy(RoutingStrategy::LoadBased) // Example: Use LoadBased
.instance_with_models(
"openai_1",
"https://api.openai.com/v1", // Replace with actual URL
vec![
ModelInstanceConfig::new("gpt-4", vec![ModelCapability::Chat]),
ModelInstanceConfig::new("gpt-3.5-turbo", vec![ModelCapability::Chat]),
ModelInstanceConfig::new("text-embedding-ada-002", vec![ModelCapability::Embedding]),
],
)
.instance_with_models(
"openai_2", // Perhaps using a different key or region
"https://api.openai.com/v1", // Replace with actual URL
vec![
ModelInstanceConfig::new("gpt-4", vec![ModelCapability::Chat]),
],
)
.health_check_path("/v1/models") // OpenAI's model list endpoint can serve as a basic health check
.health_check_interval(Duration::from_secs(60))
.instance_timeout_duration(Duration::from_secs(120))
.build();
let shared_router = Arc::new(router_config);
let shared_http_client = Client::new(); // Create a single reqwest client
let app_state = Arc::new(AppState {
router: shared_router,
http_client: shared_http_client,
});
// --- Axum Setup ---
let app = AxumRouter::new()
.route("/v1/chat/completions", post(chat_completions_handler))
// Add other routes for embeddings, completions etc.
.with_state(app_state);
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
println!("🚀 LLM Router Gateway listening on {}", addr);
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
}
async fn chat_completions_handler(
State(state): State<Arc<AppState>>,
Json(payload): Json<ChatRequest>,
) -> Result<impl IntoResponse, AppError> {
println!("Received chat request for model: {}", payload.model);
// 1. Select an instance capable of handling the request
let instance = state
.router
.select_instance_for_model(&payload.model, ModelCapability::Chat)
.await?; // Use ? to convert RouterError into AppError
println!("Selected instance: {} ({})", instance.id, instance.base_url);
// 2. Use RequestTracker for load balancing (if using LoadBased) and timeout handling
let _tracker = RequestTracker::new(state.router.clone(), instance.id.clone());
// 3. Construct the target URL
let target_url_str = format!("{}/chat/completions", instance.base_url); // base_url already includes /v1 for OpenAI-style endpoints
let target_url = target_url_str.parse::<Uri>().map_err(|_| {
AppError::Internal(format!("Failed to parse target URL: {}", target_url_str))
})?;
// --- Authentication ---
// IMPORTANT: Add authentication headers here. Get the key from secure storage/config.
let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
AppError::Internal("OPENAI_API_KEY environment variable not set".to_string())
})?;
let auth_header_value = format!("Bearer {}", api_key);
// 4. Proxy the request using the shared http_client
let response = state.http_client
.post(target_url.to_string())
.header("Authorization", auth_header_value) // Add the auth header
.json(&payload) // Forward the original payload
.send()
.await
.map_err(|e| AppError::BackendError(instance.id.clone(), e.to_string()))?;
// 5. Handle the response from the backend
let backend_status = response.status();
// Capture the Content-Type header now; reading the body below consumes the response
let content_type = response.headers().get(reqwest::header::CONTENT_TYPE).cloned();
let response_bytes = response.bytes().await.map_err(|e| AppError::BackendError(instance.id.clone(), e.to_string()))?;
if !backend_status.is_success() {
eprintln!(
"Backend error from instance {}: Status: {}, Body: {:?}",
instance.id,
backend_status,
String::from_utf8_lossy(&response_bytes)
);
// If the backend failed, put the instance into timeout
let _ = state.router.timeout_instance(&instance.id).await; // Ignore error during timeout
return Err(AppError::BackendError(
instance.id.clone(),
format!("Status: {}, Body: {:?}", backend_status, String::from_utf8_lossy(&response_bytes)),
));
}
// 6. Forward the successful response (potentially deserialize/re-serialize if needed)
// Here, we forward the raw bytes and original status code/headers
// Convert the status through u16 so this works regardless of which `http` crate version reqwest and axum link against
let mut response_builder = Response::builder().status(backend_status.as_u16());
// Copy relevant headers if necessary (e.g., Content-Type)
if let Some(content_type) = content_type {
response_builder = response_builder.header(reqwest::header::CONTENT_TYPE.as_str(), content_type.as_bytes());
}
let response = response_builder
.body(axum::body::Body::from(response_bytes))
.map_err(|e| AppError::Internal(format!("Failed to build response: {}", e)))?;
Ok(response)
}
// Custom Error type for Axum handler
enum AppError {
RouterError(RouterError),
BackendError(String, String), // instance_id, error message
Internal(String),
}
impl From<RouterError> for AppError {
fn from(err: RouterError) -> Self {
AppError::RouterError(err)
}
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
let (status, error_message) = match self {
AppError::RouterError(e) => {
eprintln!("Router error: {}", e);
// Handle specific RouterErrors differently if needed
match e {
RouterError::NoHealthyInstances(_) => (
StatusCode::SERVICE_UNAVAILABLE,
format!("No healthy backend instances available: {}", e),
),
_ => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Internal routing error: {}", e),
),
}
},
AppError::BackendError(instance_id, msg) => {
eprintln!("Backend error from instance {}: {}", instance_id, msg);
(
StatusCode::BAD_GATEWAY, // 502 suggests an issue with the upstream server
format!("Error from backend instance {}: {}", instance_id, msg),
)
},
AppError::Internal(msg) => {
eprintln!("Internal server error: {}", msg);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Internal server error: {}", msg),
)
}
};
(status, Json(serde_json::json!({ "error": error_message }))).into_response()
}
}
To run this Axum example:
- Make sure you have the dependencies in your Cargo.toml (axum, tokio, reqwest, serde, serde_json, llm_router_core).
- Save the code as a Rust file (e.g., src/main.rs).
- Set the necessary environment variables (like OPENAI_API_KEY).
- Run the application: cargo run
- Send a POST request (e.g., using curl or Postman) to http://127.0.0.1:3000/v1/chat/completions with a JSON body matching the ChatRequest structure (a small Rust test client is sketched below).
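To exercise the gateway without curl, a throwaway Rust client along these lines should work (the message content is arbitrary):

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Send one chat request through the local gateway started above.
    let response = reqwest::Client::new()
        .post("http://127.0.0.1:3000/v1/chat/completions")
        .json(&json!({
            "model": "gpt-3.5-turbo",
            "messages": [{ "role": "user", "content": "Say hello" }]
        }))
        .send()
        .await?;

    println!("Status: {}", response.status());
    println!("Body: {}", response.text().await?);
    Ok(())
}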
Benchmarking
The crate includes benchmarks to measure the performance of the instance selection logic.
1. Run the Benchmarks:
Execute the standard Cargo benchmark command. This will run the selection logic repeatedly for different numbers of instances and routing strategies.
cargo bench
This command compiles the code in release mode with benchmarking enabled and runs the crate's benchmark functions. The output shows time-per-iteration results, but it's easier to analyze with the reporter.
2. Generate the Report:
After running cargo bench, the raw results are typically stored in target/criterion/. A helper binary is provided to parse these results and generate a user-friendly report and a plot.
cargo run --bin bench_reporter
This command runs the bench_reporter binary located in src/bin/bench_reporter.rs. It will:
- Parse the benchmark results generated by cargo bench.
- Print a summary table to the console showing the median selection time (in microseconds or nanoseconds) for each strategy and instance count combination.
- Generate simple text-based plots in the console.
- Save a graphical plot comparing the scaling of the RoundRobin and LoadBased strategies to a file named benchmark_scaling.png in the project's root directory.
Example Output:
--- LLM Router Benchmark Report ---
Found result for RoundRobin/10: 1745.87 ns
Found result for RoundRobin/25: 3960.23 ns
... (more results) ...
Found result for LoadBased/100: 14648.41 ns
--- Summary Table ---
+------------+---------------+---------------------------+
| Strategy | Instances (N) | Median Time per Selection |
+------------+---------------+---------------------------+
| RoundRobin | 10 | 1.75 µs |
| LoadBased | 10 | 1.75 µs |
... (more rows) ...
| RoundRobin | 100 | 15.15 µs |
| LoadBased | 100 | 14.65 µs |
+------------+---------------+---------------------------+
--- Performance Scaling Plot (RoundRobin) ---
Time (Median)
N=10 | 1.75 µs |
...
N=100 | 15.15 µs |
+------------------------------------------+
Instances (N) -->
... (LoadBased Plot) ...
--- Plot saved to benchmark_scaling.png ---
The benchmark_scaling.png
file provides a visual comparison of how the selection time increases as the number of backend instances grows for both routing strategies. This helps understand the minimal overhead added by the router.
Performance Considerations
- Selection Overhead: Benchmarks show that the core instance selection logic is extremely fast, typically taking only a few microseconds even with a hundred instances. This overhead is negligible compared to the network latency and processing time of actual LLM API calls.
- Throughput: The router itself is not typically the bottleneck. Throughput is limited by the capacity of your backend LLM instances and network conditions.
- RoundRobin vs. LoadBased: RoundRobin has slightly lower overhead as it doesn't need to check active request counts. LoadBased provides better load distribution if backend performance varies, potentially leading to more consistent end-to-end latency, at the cost of slightly higher selection overhead (though still in microseconds).
- Health Checks: Health checks run in the background and do not block request routing. Ensure your health check endpoint is lightweight. Frequent or slow health checks can consume resources.
License
MIT