1 How to Call an External LLM API at Scale
Looping over rows and calling an LLM API one at a time works on a hundred rows. On ten million, it falls apart: timeouts accumulate, 429s crash the run, and you end up with 3.7 million rows and no clear picture of what failed.
The fix is to treat this as a streaming data engineering problem. This post shows a PySpark Structured Streaming pattern built on three pillars: provider portability, layered throughput, and built-in resilience. Together they produce a pipeline that is fast, fault-tolerant, and trivially swappable across any OpenAI-compatible endpoint.
1.1 Pillar 1: Decouple From Any Single Provider
All major LLM providers (OpenAI, Mistral, Anthropic, OpenRouter) expose an OpenAI-compatible REST API, either natively or through a compatibility layer. Only three values change between them: the base URL, the model name, and the API key. Represent those as top-level constants, and switching providers becomes a three-line change. In production, store the API key in a secret manager such as Databricks Secret Scopes rather than in code.
```python
API_BASE_URL = "https://openrouter.ai/api/v1"
MODEL = "ministral-3b-2512"
API_KEY = "your-api-key"
```

1.2 Pillar 2: Maximise Throughput at Every Layer
Three mechanisms stack on top of each other:
1.2.1 Streaming reads + checkpointing
readStream processes data as incremental micro-batches. Spark writes the offset of every completed batch to the checkpoint directory, so a restarted job resumes exactly where it stopped: no reprocessing, no skipped rows. availableNow=True drains all available input and stops cleanly, making it suitable for scheduled jobs.
1.2.2 Pandas UDFs + ThreadPoolExecutor
LLM calls are I/O-bound, so threading gives near-linear speedup. A Pandas UDF fans out MAX_WORKERS concurrent HTTP calls per partition. The UDF and its imports must be defined inside the foreachBatch handler.
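The fan-out itself is plain concurrent.futures; here is a minimal, Spark-free sketch of the pattern, with `fake_call` as a stand-in for the real HTTP request:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

import pandas as pd

MAX_WORKERS = 8


def fake_call(prompt: str) -> str:
    # Stand-in for the real HTTP call; the I/O wait is what makes threads pay off
    return f"summary of: {prompt}"


def fan_out(prompts: pd.Series) -> pd.Series:
    # Pre-allocate so results keep input order even though futures
    # complete in arbitrary order
    results = [None] * len(prompts)
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        future_to_idx = {
            executor.submit(fake_call, p): i for i, p in enumerate(prompts)
        }
        for future in as_completed(future_to_idx):
            results[future_to_idx[future]] = future.result()
    return pd.Series(results)
```

The real UDF in the full listing below follows exactly this shape, with the stub replaced by a retrying HTTP call.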
1.2.3 Connection pooling + coalesce()
One requests.Session per partition, pool-sized to MAX_WORKERS, means threads share TCP connections rather than opening a new handshake per call. coalesce(n) caps total cluster-wide concurrency to n × MAX_WORKERS; tune this against your provider's rate limit.
max_concurrent_calls = coalesce(n) × MAX_WORKERS
| Parameter | Controls | Start here |
|---|---|---|
| `maxBytesPerTrigger` | Rows per micro-batch | target_rows × avg_row_size |
| `coalesce(n)` | Parallel partitions | 4–8 |
| `MAX_WORKERS` | Threads per partition | 8–20 |
Rule of thumb: Start at 50% of your provider’s rate limit. Increase until you see 429s, then back off 20%.
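As a worked example of the concurrency formula and the 50% rule of thumb (the 64-request provider limit is hypothetical, for illustration only):

```python
# Hypothetical provider limit -- check your provider's published rate limits
PROVIDER_CONCURRENCY_LIMIT = 64

n_partitions = 4  # the n in coalesce(n)
MAX_WORKERS = 8   # threads per partition

# Cluster-wide concurrency is the product of the two knobs
max_concurrent_calls = n_partitions * MAX_WORKERS  # 4 x 8 = 32

# Rule of thumb: start at or below 50% of the provider limit
assert max_concurrent_calls <= PROVIDER_CONCURRENCY_LIMIT // 2
```

If you later see 429s at these settings, reduce either knob; they multiply, so halving one halves the total.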
1.3 Pillar 3: Resilience at Every Layer
1.3.1 Handling retries with tenacity
tenacity retries handle transient failures (connection errors, 5xx responses, 429s) with exponential backoff, capped at 5 attempts and 60 seconds between retries. RateLimitError is a custom exception raised on 429s, giving tenacity a typed handle rather than relying on generic HTTP error detection. It must be defined inside the UDF closure for the same serialisation reason as the imports.
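To see what that policy does without pulling in the library, here is a rough stdlib-only stand-in for `stop_after_attempt(5)` + `wait_exponential(multiplier=2, min=2, max=60)` with `reraise=True` (the pipeline itself uses tenacity, not this sketch):

```python
import time


class RateLimitError(Exception):
    """Typed signal for HTTP 429 so the retry policy can target it."""


def retry_with_backoff(fn, attempts=5, multiplier=2, min_wait=2, max_wait=60,
                       sleep=time.sleep):
    # Hand-rolled approximation of the tenacity policy used in the pipeline
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except RateLimitError:
            if attempt == attempts:
                raise  # reraise=True: surface the final error to the caller
            # Exponential backoff: 2s, 4s, 8s, ... clamped to [min_wait, max_wait]
            wait = min(max(multiplier * 2 ** (attempt - 1), min_wait), max_wait)
            sleep(wait)
```

The `sleep` parameter is injectable so the backoff schedule can be observed in tests without actually waiting.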
1.3.2 Wrapping _call_api
safe_call wraps _call_api in a broad except and returns f"ERROR: {e}" on failure. One bad row becomes a tagged string in the output column, not a partition crash that forces Spark to retry the entire micro-batch.
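The wrapper is only a few lines; a self-contained sketch (the `call` parameter injects the underlying API function, an illustrative refactor of the pipeline's closure):

```python
def safe_call(prompt, call):
    # Any per-row failure becomes a tagged string in the output column,
    # so one bad row cannot fail the whole partition
    try:
        return call(prompt)
    except Exception as e:
        return f"ERROR: {e}"
```

The "ERROR:" prefix is the contract that makes post-run verification a simple filter.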
1.3.3 Idempotent Delta writes
Idempotent Delta writes use txnAppId (stable across all runs) + txnVersion (the micro-batch ID) as a deduplication key. If a batch is retried after an interruption, Delta silently skips the duplicate write. APP_ID must be a constant; if it changes between runs, the key breaks.
1.3.4 Post-run verification
Post-run verification requires only a filter:
```python
df_errors = spark.read.table(TARGET_TABLE).filter(F.col("response").startswith("ERROR:"))
```

Failures are queryable rows, not buried log lines.
1.4 Complete Pipeline
```python
import pandas as pd
import requests
from pyspark.sql import functions as F  # noqa: N812
from requests.adapters import HTTPAdapter
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

# ---------------------------------------------------------------------------
# Configuration: swap these for any OpenAI-compatible endpoint
# ---------------------------------------------------------------------------
API_BASE_URL = "https://openrouter.ai/api/v1"
MODEL = "ministral-3b-2512"
API_KEY = "your-api-key"

MAX_WORKERS = 8
INPUT_COLUMN = "review"
INSTRUCTIONS = "Summarize the following review in one sentence."

# APP_ID + batch_id together form an idempotency key: if a batch is reprocessed
# after a failure, Delta skips the duplicate write instead of creating duplicates
APP_ID = "streaming_api_call"

SOURCE_TABLE = "catalog.schema.source_table"
TARGET_TABLE = "catalog.schema.target_table"
CHECKPOINT_PATH = "/path/to/checkpoint/streaming_api_call"

# ---------------------------------------------------------------------------
# Step 1: Open a streaming read. Structured Streaming tracks progress via
# checkpoints, so only new data is processed on each run.
# maxBytesPerTrigger is a source option, so it goes on the reader: a soft cap
# on bytes read per micro-batch, which controls how many rows hit the API at
# once (target_rows x avg_row_size ~ byte limit, e.g. 50 000 rows x ~1 KB ~ 50 MB)
# ---------------------------------------------------------------------------
df_stream = (
    spark.readStream  # noqa: F821
    .option("maxBytesPerTrigger", "50mb")
    .table(SOURCE_TABLE)
)


# ---------------------------------------------------------------------------
# Step 2: foreachBatch processes each micro-batch as a static DataFrame so we
# can use Pandas UDFs and idempotent Delta writes
# ---------------------------------------------------------------------------
def process_batch(batch_df, batch_id):
    # Local PySpark imports keep them out of the module-level pickle graph.
    # cloudpickle serialises process_batch + every nested code object; any
    # module-level pyspark reference (F, T, DataFrame) can carry Spark Connect
    # session state and trigger STREAMING_CONNECT_SERIALIZATION_ERROR.
    from pyspark.sql import functions as F  # noqa: N812
    from pyspark.sql import types as T  # noqa: N812

    # Structured Streaming can dispatch empty micro-batches; bail early to
    # avoid unnecessary API session setup
    if batch_df.isEmpty():
        return

    @F.pandas_udf(T.StringType())
    def call_api(prompts: pd.Series) -> pd.Series:
        """Pandas UDF: the vectorised interface lets us batch rows per partition
        and fan out with threads, avoiding the overhead of one API call per row."""
        from concurrent.futures import ThreadPoolExecutor, as_completed

        # Defined inside the UDF so it lives in the serialised closure sent to workers
        class RateLimitError(Exception):
            pass

        # Session is created once per partition and reuses TCP connections
        session = requests.Session()
        # Match pool size to thread count so no thread blocks waiting for a connection
        adapter = HTTPAdapter(pool_connections=MAX_WORKERS, pool_maxsize=MAX_WORKERS)
        session.mount("https://", adapter)

        # Nested inside call_api so it closes over `session` and gets serialised
        # into the same closure Spark sends to workers; no globals needed
        @retry(
            stop=stop_after_attempt(5),
            wait=wait_exponential(multiplier=2, min=2, max=60),
            retry=retry_if_exception_type(
                (
                    requests.ConnectionError,
                    requests.Timeout,
                    requests.HTTPError,
                    RateLimitError,
                )
            ),
            reraise=True,
        )
        def _call_api(prompt: str) -> str:
            resp = session.post(
                f"{API_BASE_URL}/chat/completions",
                headers={"Authorization": f"Bearer {API_KEY}"},
                json={
                    "model": MODEL,
                    "messages": [
                        {"role": "user", "content": f"{INSTRUCTIONS}\n\n{prompt}"},
                    ],
                    "max_tokens": 512,
                    "temperature": 0.3,
                },
                timeout=600,
            )
            # 429 gets a custom exception for targeted retry; 5xx errors bubble
            # up as HTTPError, also retried by tenacity
            if resp.status_code == 429:
                raise RateLimitError(f"Rate limited (429): {resp.text}")
            resp.raise_for_status()
            data = resp.json()
            try:
                return data["choices"][0]["message"]["content"]
            except (KeyError, IndexError) as e:
                raise ValueError(f"Unexpected response structure: {data}") from e

        # Catch-all wrapper: converts exceptions to error strings so one failed
        # row doesn't crash the entire partition
        def safe_call(prompt: str) -> str:
            try:
                return _call_api(prompt)
            except Exception as e:
                return f"ERROR: {e}"

        # Pre-allocate to preserve input ordering; as_completed returns
        # futures in arbitrary finish order
        results = [None] * len(prompts)
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
            future_to_idx = {
                executor.submit(safe_call, prompt): idx
                for idx, prompt in enumerate(prompts)
            }
            for future in as_completed(future_to_idx):
                idx = future_to_idx[future]
                results[idx] = future.result()
        return pd.Series(results)

    # Reduce partitions to limit concurrent API connections across the cluster;
    # tune this to stay within rate limits
    enriched_df = batch_df.coalesce(4).withColumn(
        "response", call_api(F.col(INPUT_COLUMN))
    )

    (
        enriched_df.write.format("delta")
        .mode("append")
        # Idempotent write: Delta uses (txnAppId, txnVersion) to deduplicate
        # if a batch is retried
        .option("txnVersion", batch_id)
        .option("txnAppId", APP_ID)
        .saveAsTable(TARGET_TABLE)
    )
    print(f"Batch {batch_id}: complete")


# ---------------------------------------------------------------------------
# Step 3: Start the streaming query
# ---------------------------------------------------------------------------
query = (
    df_stream.writeStream.foreachBatch(process_batch)
    .option("checkpointLocation", CHECKPOINT_PATH)
    .trigger(availableNow=True)  # processes all available input, then stops
    .start()
)
query.awaitTermination()

# ---------------------------------------------------------------------------
# Step 4: Verify results. Read back the target table after the stream finishes
# to confirm row counts and surface any API failures
# ---------------------------------------------------------------------------
df_result = spark.read.table(TARGET_TABLE)  # noqa: F821
print(f"Total enriched rows: {df_result.count()}")
df_result.show(5, truncate=False)

df_errors = df_result.filter(F.col("response").startswith("ERROR:"))
error_count = df_errors.count()
if error_count > 0:
    print(f"WARNING: {error_count} rows with errors")
```

1.5 The Three Pillars Work as a Unit
- Use this pattern when enriching hundreds of thousands of rows or more, when multi-provider portability matters, or when the output needs to integrate with an existing Spark/Delta ecosystem.
- Use a simpler approach (plain `asyncio` + `httpx`) when your data fits in memory and you need results quickly.
- Use a managed batch API (OpenAI Batch, Anthropic Message Batches) when cost-per-token matters most: managed batching is cheaper and requires zero infrastructure.
1.6 Conclusion
The three pillars are a single composable design: portability means you can swap providers by changing three constants; layered throughput means you run at your quota limit rather than loop speed; built-in resilience means failures are recoverable rows, not lost data. Remove any one pillar and the other two become significantly less useful. Build all three in from the start.
Tested on PySpark 3.5+, Delta Lake 3.x, requests 2.x, tenacity 8.x, Databricks serverless env 4. The local-import + closure-scoped UDF pattern is required for Spark Connect; on legacy cluster modes, module-level definitions also work.