Using DSPy with Ray Data

DSPy is the best framework I know for auto-tuning LLM prompts. Instead of hand-crafting prompt strings, you define input/output specs that you chain together. You optimize these using metric functions against a train/test set, just as you would a regular ML algorithm. Pretty neat!

One limitation with DSPy is scalability. In this post, I'll show how to adapt DSPy modules so they work inside Ray Data pipelines, enabling you to scale DSPy to hundreds of thousands of documents. I'll assume you are familiar with DSPy and Ray Data already.

To demonstrate, I'll build a pipeline that triages 911 calls. First, some toy data:

calls = [
    "There's a raccoon in my kitchen and it's eating my cereal. I think it has a gun.",
    "Ahhhhhh I'm stuck in the toilet",
    "My roommate will not stop playing bagpipes at 3am.",
    "Help, I'm chronically stinky.",
]

DSPy forces you into its own class hierarchy (signatures, modules, predictors), and these classes aren't directly compatible with Ray Data. You want to reuse the same module classes for both production inference (in Ray Data) and prompt optimization (in DSPy's optimizers), so that the prompts you optimize are exactly the prompts you run. So I keep the DSPy modules pure and put the Ray-specific code in adapter classes.

Here's a DSPy signature and module to classify the 911 transcripts:

import dspy

# Signatures are used to define the input/output structure
class TriageCall(dspy.Signature):
    """Triage an incoming 911 call."""
    transcript: str = dspy.InputField(desc="the caller's transcript")
    priority: str = dspy.OutputField(desc="one of: urgent, non_urgent, not_an_emergency")
    category: str = dspy.OutputField(desc="one of: crime, medical, fire, animal, noise, other")
    summary: str = dspy.OutputField(desc="one-sentence summary for the dispatcher")

# The module does the actual extraction/classification
class CallTriager(dspy.Module):
    def __init__(self):
        super().__init__()
        self.triage = dspy.ChainOfThought(TriageCall)

    def forward(self, transcript: str):
        return self.triage(transcript=transcript)

These are pure DSPy classes, so you can use them for optimization without any problems.
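
Since the modules are plain DSPy, the optimization side needs nothing Ray-specific. As a minimal sketch, here is a metric you might hand to an optimizer. The labeled fields it compares against and the half-credit scoring are my assumptions for illustration. It follows DSPy's metric convention of (example, pred, trace=None) and is stricter when trace is set, which is how DSPy signals that the metric is being used to accept bootstrapped demos.

```python
from types import SimpleNamespace


def triage_metric(example, pred, trace=None):
    """Hypothetical metric: half credit each for matching priority and category."""

    def norm(value):
        # Tolerate casing and whitespace differences in LLM output
        return (value or "").strip().lower()

    score = 0.0
    if norm(pred.priority) == norm(example.priority):
        score += 0.5
    if norm(pred.category) == norm(example.category):
        score += 0.5
    # During optimization (trace is not None), only accept exact matches
    if trace is not None:
        return score == 1.0
    return score


# Stand-ins for a labeled dspy.Example and a model prediction
gold = SimpleNamespace(priority="urgent", category="animal")
pred = SimpleNamespace(priority="Urgent ", category="crime")
print(triage_metric(gold, pred))            # 0.5
print(triage_metric(gold, pred, trace=[]))  # False
```

With real data you would build the trainset from dspy.Example(...).with_inputs("transcript") objects, pass the metric to an optimizer such as dspy.BootstrapFewShot(metric=triage_metric), and save the compiled module with .save("optimized/triager.json") so the Ray stage can load it.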

Now for the Ray Data adapter. DSPy configures the language model globally:

lm = dspy.LM("claude-sonnet-4-6")
dspy.configure(lm=lm)

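This is fine in a single process, but the configuration doesn't propagate to Ray workers, which run in separate processes. If you use a DSPy module inside a plain map() function, the workers have no LM configured and the call fails. You could call dspy.configure() inside that function, but then you'd repeat the LM setup on every call. Instead, put dspy.configure() in the constructor of a callable class. Ray Data runs the class as a pool of actors, constructs each actor once, and reuses it for every row that actor processes: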

class TriageStage:
    def __init__(self, model: str, optimized_path: str | None = None):
        # Set up the LM only once per actor
        dspy.configure(lm=dspy.LM(model, max_tokens=1000))
        self.module = CallTriager()
        # Load the optimized prompts saved by a previous DSPy
        # optimization run. Without one, the module falls back
        # to DSPy's default prompts.
        if optimized_path:
            self.module.load(optimized_path)

    def __call__(self, row: dict) -> dict:
        prediction = self.module(transcript=row["item"])
        row["priority"] = prediction.priority
        row["category"] = prediction.category
        row["summary"] = prediction.summary
        return row
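
To see why the callable class matters, here is a toy, Ray-free simulation. The hypothetical ExpensiveStage counts how often its constructor (standing in for dspy.configure()) runs when a pool of instances is reused across rows. The round-robin dispatch is a simplification; Ray Data assigns work to actors dynamically.

```python
from itertools import cycle


class ExpensiveStage:
    """Stand-in for TriageStage; __init__ plays the role of LM setup."""

    setup_calls = 0  # class-level counter shared by all instances

    def __init__(self, model: str):
        ExpensiveStage.setup_calls += 1
        self.model = model

    def __call__(self, row: dict) -> dict:
        row["handled_by"] = id(self)  # which "actor" processed this row
        return row


def simulate_actor_pool(stage_cls, rows, pool_size, **ctor_kwargs):
    """Construct pool_size instances once, then reuse them round-robin."""
    actors = [stage_cls(**ctor_kwargs) for _ in range(pool_size)]
    return [actor(dict(row)) for actor, row in zip(cycle(actors), rows)]


rows = [{"item": f"call {i}"} for i in range(100)]
out = simulate_actor_pool(ExpensiveStage, rows, pool_size=4, model="any-model")
print(ExpensiveStage.setup_calls)  # 4: one setup per actor, not per row
```

With a plain function, the setup would run on every call instead; the class-based map() is what gives you the pooled behavior.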

Now the LM setup runs only once per actor. You can wire TriageStage into a Ray Data pipeline, using ActorPoolStrategy to control parallelism:

import ray
from ray.data import ActorPoolStrategy

ray.init()

ds = ray.data.from_items(calls)

ds = ds.map(
    TriageStage,
    fn_constructor_kwargs={
        "model": "claude-sonnet-4-6",
        "optimized_path": "optimized/triager.json",
    },
    compute=ActorPoolStrategy(size=4),
)

ds.write_parquet("output/triaged_calls/")

ActorPoolStrategy(size=4) tells Ray to spin up 4 worker actors, each with its own LM instance. For LLM workloads you typically want many actors, since the bottleneck is API latency rather than CPU. On the other hand, too many concurrent clients will hit the provider's rate limits.
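
A rough way to pick size is Little's law: the number of busy actors you need is target throughput times time per row, capped by what your rate limit allows. The hypothetical choose_pool_size helper below is a back-of-the-envelope sketch. It assumes each actor handles one row at a time and that you know your average latency and your requests-per-minute limit. Refine retries and OfferFeedback calls push requests_per_row above 1.

```python
import math


def choose_pool_size(
    target_rows_per_min: float,
    seconds_per_row: float,
    rpm_limit: float,
    requests_per_row: float = 1.0,
) -> int:
    """Heuristic actor-pool size for an API-bound stage (a sketch, not a rule)."""
    # One actor processes one row at a time.
    rows_per_actor_per_min = 60.0 / seconds_per_row
    # Little's law: actors needed = target throughput * time per row.
    wanted = math.ceil(target_rows_per_min / rows_per_actor_per_min)
    # Each actor sends this many requests per minute to the provider.
    requests_per_actor_per_min = rows_per_actor_per_min * requests_per_row
    allowed = math.floor(rpm_limit / requests_per_actor_per_min)
    return max(1, min(wanted, allowed))


# Illustrative numbers: 4 s per call, 500 requests/min limit.
print(choose_pool_size(600, 4.0, 500))  # 33: capped by the rate limit
print(choose_pool_size(60, 4.0, 500))   # 4: enough for the target throughput
```

You would then pass the result as ActorPoolStrategy(size=...). The inputs are averages, so treat the output as a starting point to tune against observed 429s.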

To chain multiple extraction stages, call map() once per stage. Each call gets its own actor pool:

ds = ds.map(FooStage, ...)
ds = ds.map(BazStage, ...)

Each stage adds fields to the row dict as it flows through the pipeline, so later stages can read the outputs of earlier ones. By the time a row reaches write_parquet(), it carries the results of every stage.
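
Because every stage is just a callable that takes and returns a row dict, you can check the wiring locally with stub stages before spending API calls. Below is a minimal Ray-free sketch with hypothetical stubs and a loop that mimics chained map() calls sequentially. Ray Data itself streams rows through the stages concurrently.

```python
class StubTriageStage:
    """Hypothetical stand-in for TriageStage that returns fixed labels."""

    def __call__(self, row: dict) -> dict:
        row["priority"] = "non_urgent"
        row["category"] = "noise" if "bagpipes" in row["item"] else "other"
        return row


class StubRoutingStage:
    """Hypothetical later stage that reads a field set by an earlier one."""

    def __call__(self, row: dict) -> dict:
        row["route_to_noise_team"] = row["category"] == "noise"
        return row


def run_stages(rows, stages):
    # Apply each stage to every row, in order, like chained ds.map() calls.
    for stage in stages:
        rows = [stage(dict(row)) for row in rows]
    return rows


rows = [{"item": "My roommate will not stop playing bagpipes at 3am."}]
result = run_stages(rows, [StubTriageStage(), StubRoutingStage()])
print(sorted(result[0]))  # ['category', 'item', 'priority', 'route_to_noise_team']
```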

If you have multiple LLM stages, the wrapper classes get repetitive, especially once you add error handling and rate limiting. To reduce boilerplate, I use a base class:

from pathlib import Path
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception

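def _is_rate_limit(e: Exception) -> bool:
    # Match rate-limit errors specifically; a bare "rate" substring would
    # also match unrelated messages such as "failed to generate".
    msg = f"{type(e).__name__} {e}".lower()
    return (
        any(s in msg for s in ("ratelimit", "rate limit", "rate_limit"))
        or getattr(e, "status_code", None) == 429
    )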

class _BaseStage:
    _module_cls = None           # subclass sets this to a dspy.Module class
    _optimized_filename = None
    _output_fields = ()          # prediction fields to copy onto the row

    def __init__(self, model: str, optimized_dir: str | None = None):
        dspy.configure(lm=dspy.LM(model, max_tokens=1000))
        module = self._module_cls()
        if optimized_dir and self._optimized_filename:
            path = Path(optimized_dir) / self._optimized_filename
            if path.exists():
                module.load(str(path))

        # If the output fails validation (see _reward_output), rerun the
        # module with LLM-generated feedback, up to 3 attempts in total.
        self.module = dspy.Refine(
            module=module,
            N=3,
            reward_fn=self._reward_output,
            threshold=1.0,
        )

    def _reward_output(self, args: dict, pred) -> float:
        """Return 1.0 if all output fields parse, 0.0 otherwise.

        Subclasses can override this for custom validation.
        """
        for field in self._output_fields:
            raw = getattr(pred, field, None)
            if raw is None or raw.strip() == "":
                return 0.0
        return 1.0

    # On rate-limit errors, retry with exponential back-off (3 attempts in total)
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=30),
        retry=retry_if_exception(_is_rate_limit),
        reraise=True,
    )
    def _call_llm(self, transcript: str):
        return self.module(transcript=transcript)

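    def __call__(self, row: dict) -> dict:
        # Give every row the same columns, even on failure: Ray Data can
        # reject rows whose columns differ from earlier rows in a block.
        row.setdefault("errors", "")
        for field in self._output_fields:
            row[field] = None
        try:
            pred = self._call_llm(row["item"])
            for field in self._output_fields:
                row[field] = getattr(pred, field)
        except Exception as e:
            # Accumulate error messages in an 'errors' field
            prev = row["errors"]
            row["errors"] = f"{prev}; {e}" if prev else str(e)
        return row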

There are two layers of retry here. tenacity retries on rate-limit errors (HTTP 429), as detected by _is_rate_limit. dspy.Refine retries when the LLM output doesn't pass validation (malformed output, missing fields). If _reward_output returns a score below threshold, Refine makes a separate OfferFeedback LLM call to analyze what went wrong. OfferFeedback is given the execution trace, the reward function's source code, and the score. Its feedback is passed as a hint to the main LLM on the next attempt, which makes it more likely to fix the problem. The trade-off is cost: a row that keeps failing validation can consume several extra LLM calls.
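
The tenacity settings in _BaseStage produce shorter waits than max=30 might suggest. Here is a stdlib sketch of the schedule, assuming tenacity's wait_exponential formula: multiplier * 2 ** (attempt_number - 1), clamped to [min, max], with no jitter. The helper name exponential_waits is hypothetical.

```python
def exponential_waits(attempts, multiplier=1.0, min_wait=2.0, max_wait=30.0, exp_base=2):
    """Seconds slept after each failed attempt except the last."""
    waits = []
    for attempt_number in range(1, attempts):  # no sleep after the final attempt
        raw = multiplier * exp_base ** (attempt_number - 1)
        waits.append(max(min_wait, min(raw, max_wait)))
    return waits


print(exponential_waits(3))  # [2.0, 2.0]: the settings used in _BaseStage
print(exponential_waits(8))  # [2.0, 2.0, 4.0, 8.0, 16.0, 30.0, 30.0]
```

So with stop_after_attempt(3), a rate-limited row waits roughly 4 seconds in total before its error is recorded. If your provider's rate-limit windows are longer than that, raising the attempt count is what lets the back-off actually grow toward max.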

Subclasses declare which module to use, which optimized file to load, and which fields to extract:

class TriageStage(_BaseStage):
    _module_cls = CallTriager
    _optimized_filename = "triager.json"
    _output_fields = ("priority", "category", "summary")
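
The default _reward_output only checks that fields are non-empty. A stricter, hypothetical check for TriageStage could also verify that the labels come from the allowed sets in the TriageCall signature. It's shown here as a standalone function so it runs without DSPy; in TriageStage you would override _reward_output(self, args, pred) with the same body. Returning partial credit rather than 0 or 1 lets Refine tell a nearly-right attempt from a useless one, since, as I understand it, Refine returns the best-scoring attempt when none reaches the threshold.

```python
from types import SimpleNamespace

# Allowed labels, copied from the TriageCall output field descriptions
ALLOWED_LABELS = {
    "priority": {"urgent", "non_urgent", "not_an_emergency"},
    "category": {"crime", "medical", "fire", "animal", "noise", "other"},
}


def triage_reward(args: dict, pred) -> float:
    """Fraction of checks passed; 1.0 means the prediction is fully valid."""
    summary = getattr(pred, "summary", None)
    checks = [isinstance(summary, str) and summary.strip() != ""]
    for field, allowed in ALLOWED_LABELS.items():
        value = getattr(pred, field, None)
        checks.append(isinstance(value, str) and value.strip().lower() in allowed)
    return sum(checks) / len(checks)


good = SimpleNamespace(priority="urgent", category="animal", summary="Armed raccoon in kitchen.")
bad = SimpleNamespace(priority="URGENT!!", category="animal", summary="Armed raccoon in kitchen.")
print(triage_reward({}, good))  # 1.0
print(triage_reward({}, bad))   # about 0.67
```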

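Gotchas

API keys loaded from a .env file on the driver aren't automatically available to Ray workers (for example, on a remote cluster). Pass them explicitly through runtime_env. dotenv_values() returns None for keys declared without a value, and Ray expects env var values to be strings, so filter those out:
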
from dotenv import dotenv_values

env_vars = {k: v for k, v in dotenv_values(".env").items() if v is not None}
ray.init(runtime_env={"env_vars": env_vars})

Putting it together

Here's a full, runnable pipeline with two DSPy modules, the base adapter, and a two-stage Ray Data pipeline:

import ray
import dspy
from pathlib import Path
from dotenv import dotenv_values
from ray.data import ActorPoolStrategy
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception

MODEL = "claude-sonnet-4-6"

# --- DSPy modules (no Ray awareness) ---

class TriageCall(dspy.Signature):
    """Triage an incoming 911 call."""
    transcript: str = dspy.InputField(desc="the caller's transcript")
    priority: str = dspy.OutputField(desc="one of: urgent, non_urgent, not_an_emergency")
    category: str = dspy.OutputField(desc="one of: crime, medical, fire, animal, noise, other")
    summary: str = dspy.OutputField(desc="one-sentence summary for the dispatcher")

class CallTriager(dspy.Module):
    def __init__(self):
        super().__init__()
        self.triage = dspy.ChainOfThought(TriageCall)
    def forward(self, transcript: str):
        return self.triage(transcript=transcript)

class ExtractDetails(dspy.Signature):
    """Extract actionable details from a 911 call."""
    transcript: str = dspy.InputField()
    location: str = dspy.OutputField(desc="reported location, if mentioned")
    subject_description: str = dspy.OutputField(desc="description of person or animal involved")

class DetailExtractor(dspy.Module):
    def __init__(self):
        super().__init__()
        self.extract = dspy.ChainOfThought(ExtractDetails)
    def forward(self, transcript: str):
        return self.extract(transcript=transcript)

# --- Base adapter (handles Ray Data + retry plumbing) ---

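def _is_rate_limit(e: Exception) -> bool:
    # Match rate-limit errors specifically, not any message containing "rate"
    msg = f"{type(e).__name__} {e}".lower()
    return (
        any(s in msg for s in ("ratelimit", "rate limit", "rate_limit"))
        or getattr(e, "status_code", None) == 429
    )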

class _BaseStage:
    _module_cls = None
    _optimized_filename = None
    _output_fields = ()

    def __init__(self, model: str, optimized_dir: str | None = None):
        dspy.configure(lm=dspy.LM(model, max_tokens=1000))
        module = self._module_cls()
        if optimized_dir and self._optimized_filename:
            path = Path(optimized_dir) / self._optimized_filename
            if path.exists():
                module.load(str(path))
        self.module = dspy.Refine(
            module=module, N=3,
            reward_fn=self._reward_output, threshold=1.0,
        )

    def _reward_output(self, args: dict, pred) -> float:
        for field in self._output_fields:
            raw = getattr(pred, field, None)
            if raw is None or raw.strip() == "":
                return 0.0
        return 1.0

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=30),
        retry=retry_if_exception(_is_rate_limit),
        reraise=True,
    )
    def _call_llm(self, transcript: str):
        return self.module(transcript=transcript)

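    def __call__(self, row: dict) -> dict:
        # Give every row the same columns, even on failure: Ray Data can
        # reject rows whose columns differ from earlier rows in a block.
        row.setdefault("errors", "")
        for field in self._output_fields:
            row[field] = None
        try:
            pred = self._call_llm(row["item"])
            for field in self._output_fields:
                row[field] = getattr(pred, field)
        except Exception as e:
            prev = row["errors"]
            row["errors"] = f"{prev}; {e}" if prev else str(e)
        return row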

# --- Stage subclasses (just declarations) ---

class TriageStage(_BaseStage):
    _module_cls = CallTriager
    _optimized_filename = "triager.json"
    _output_fields = ("priority", "category", "summary")

class ExtractDetailsStage(_BaseStage):
    _module_cls = DetailExtractor
    _optimized_filename = "detail_extractor.json"
    _output_fields = ("location", "subject_description")

# --- Pipeline ---

env_vars = {k: v for k, v in dotenv_values(".env").items() if v is not None}
ray.init(runtime_env={"env_vars": env_vars})

calls = [
    "There's a raccoon in my kitchen and it's eating my cereal. I think it has a gun.",
    "Ahhhhhh I'm stuck in the toilet",
    "My roommate will not stop playing bagpipes at 3am.",
    "Help, I'm chronically stinky.",
]

stage_kwargs = {"model": MODEL, "optimized_dir": "optimized/"}

(
    ray.data.from_items(calls)
    .map(TriageStage, fn_constructor_kwargs=stage_kwargs,
         compute=ActorPoolStrategy(size=4))
    .map(ExtractDetailsStage, fn_constructor_kwargs=stage_kwargs,
         compute=ActorPoolStrategy(size=4))
    .write_parquet("output/triaged_calls/")
)

The DSPy modules contain no Ray code, so you can reuse them as-is for optimization. The adapter base class handles LM setup, optimized prompt loading, Refine-based retry, rate-limit retry, and error accumulation. Each new stage is a few lines of declarations.

Copyright Ricardo Decal. ricardodecal.com