How to kill leaked Ray Actors

Under some circumstances, leaked Ray actors can linger in a cluster, reserving resources and preventing it from scaling down. Here is a script that kills such actors by PID, actor name, or class name, no matter which nodes they occupy:

#!/usr/bin/env python3
"""
Kill specific PIDs or actors (by name or class) on all alive nodes in a Ray cluster.

Usage:
  - kill_leaked_actors.py 10119 my_actor NCCLUniqueIDStore 10834

Notes:
  - Connects to the existing Ray cluster with address="auto".
  - Schedules a small task on each alive node that attempts SIGKILL on each PID.
  - Reports per-node results: "killed", "notfound", "permission", or "error:<detail>".
"""

import argparse
import os
import signal
from dataclasses import dataclass, field
from typing import Dict, List, Sequence, Set

import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy


@ray.remote(num_cpus=0)
def kill_pids_on_this_node(target_pids: List[int]) -> list[tuple[int, str]]:
    """Attempt SIGKILL for each PID local to this node.

    Args:
        target_pids: Process IDs to signal on the node where this task runs.

    Returns:
        One ``(pid, status)`` pair per requested PID, in input order. ``status``
        is ``"killed"``, ``"notfound"``, ``"permission"``, or
        ``"error:<detail>"`` for anything else (including rejected PIDs).
    """

    results: list[tuple[int, str]] = []
    for pid in target_pids:
        try:
            if pid <= 0:
                # POSIX kill() treats non-positive PIDs as broadcasts: 0 hits
                # the caller's whole process group and -1 hits every process
                # the user may signal. Never forward those to os.kill().
                raise ValueError("invalid pid (must be a positive integer)")
            os.kill(pid, signal.SIGKILL)
            results.append((pid, "killed"))
        except ProcessLookupError:
            results.append((pid, "notfound"))
        except PermissionError:
            results.append((pid, "permission"))
        except Exception as exc:
            # Best-effort tool: report the failure for this PID and keep going.
            results.append((pid, f"error:{exc}"))
    return results


def _split_numeric_and_named_targets(raw_targets: Sequence[str]) -> tuple[List[int], Set[str]]:
    numeric: List[int] = []
    names: Set[str] = set()
    for target in raw_targets:
        try:
            numeric.append(int(target, 10))
        except ValueError:
            names.add(target)
    return numeric, names


@dataclass
class ActorTargetInfo:
    """PIDs and match provenance collected for one requested actor label."""

    # PIDs of the alive actors whose name or class name equals the label.
    pids: Set[int] = field(default_factory=set)
    # Which actor attributes matched the label: "name" and/or "class".
    matched_fields: Set[str] = field(default_factory=set)


def _resolve_actor_targets(labels: Set[str]) -> Dict[str, ActorTargetInfo]:
    """Map each requested label to the PIDs of the alive actors it matches.

    A label matches an actor when it equals the actor's name or its class
    name; a single label may match by both and may match several actors.

    Args:
        labels: Actor names and/or class names to look up.

    Returns:
        A dict with one entry per label (possibly with no PIDs when nothing
        matched). Empty when ``labels`` is empty, in which case the Ray state
        API is not queried at all.
    """
    if not labels:
        return {}

    # The state API's supported home is ``ray.util.state``; fall back to the
    # older experimental module so the script still runs on earlier releases.
    try:
        from ray.util.state import list_actors
    except ImportError:
        from ray.experimental.state.api import list_actors

    # list_actors() returns only 100 entries unless told otherwise, which would
    # silently hide leaked actors on a busy cluster. 10_000 is believed to be
    # the API server's default upper bound - confirm against your Ray version.
    max_actors = 10_000

    resolved: Dict[str, ActorTargetInfo] = {
        label: ActorTargetInfo() for label in labels
    }
    alive_actors = list_actors(filters=[("state", "=", "ALIVE")], limit=max_actors)
    for actor in alive_actors:
        # Skip actors without a usable PID. 0 is never a valid kill target
        # (os.kill(0, ...) would signal the caller's whole process group).
        if not actor.pid:
            continue
        if actor.name in labels:
            resolved[actor.name].pids.add(actor.pid)
            resolved[actor.name].matched_fields.add("name")
        if actor.class_name in labels:
            resolved[actor.class_name].pids.add(actor.pid)
            resolved[actor.class_name].matched_fields.add("class")
    return resolved


def main() -> None:
    """Parse targets, resolve actor labels to PIDs, and SIGKILL them cluster-wide.

    Integer arguments are taken as PIDs; every other argument is matched
    against the names and class names of alive actors. The combined PID list
    is then sent to one task pinned to each alive node, and each node's
    per-PID outcome is printed as ``[node_ip] pid:status``.

    Note that every PID is attempted on every alive node, so an unrelated
    process that happens to share a PID on another node is signalled too.

    Raises:
        SystemExit: with status 2 when a non-positive PID is given, or with
            status 1 when the result from at least one node could not be
            collected.
    """
    parser = argparse.ArgumentParser(
        description="Send SIGKILL for PIDs or actors (by name or class) on every alive Ray node"
    )
    parser.add_argument(
        "targets",
        metavar="TARGET",
        nargs="+",
        help=(
            "Pid(s) or actor identifiers to kill. Integers are treated as PIDs; other strings"
            " match actor names or class names."
        ),
    )
    args = parser.parse_args()

    numeric_targets, actor_names = _split_numeric_and_named_targets(args.targets)

    # Refuse non-positive PIDs before touching the cluster: kill(0) signals the
    # caller's process group and kill(-1) signals every process the user may
    # signal, which here would happen on every node.
    invalid_pids = [pid for pid in numeric_targets if pid <= 0]
    if invalid_pids:
        rendered = ", ".join(str(pid) for pid in invalid_pids)
        parser.error(f"PIDs must be positive integers, got: {rendered}")

    ray.init(address="auto", ignore_reinit_error=True)

    actor_pid_map = _resolve_actor_targets(actor_names)

    for label in sorted(actor_pid_map):
        info = actor_pid_map[label]
        pids = sorted(info.pids)
        if pids:
            pid_list = ", ".join(str(pid) for pid in pids)
            match_desc = "/".join(sorted(info.matched_fields)) or "name"
            print(f"Resolved actor target '{label}' (matched by {match_desc}) to PID(s): {pid_list}")
        else:
            print(f"No alive actors found matching '{label}'")

    target_pids = set(numeric_targets)
    for info in actor_pid_map.values():
        target_pids.update(info.pids)

    if not target_pids:
        print("No PIDs to kill.")
        return

    pid_list = sorted(target_pids)

    # Launch every per-node task first so they run concurrently, then collect.
    nodes = [node for node in ray.nodes() if node.get("Alive")]
    tasks = []
    for node in nodes:
        node_id = node["NodeID"]
        node_ip = node["NodeManagerAddress"]
        # soft=False pins the task to exactly this node.
        task = kill_pids_on_this_node.options(
            scheduling_strategy=NodeAffinitySchedulingStrategy(node_id, soft=False)
        ).remote(pid_list)
        tasks.append((node_ip, task))

    any_node_failed = False
    for node_ip, task in tasks:
        try:
            node_results = ray.get(task, timeout=30)
        except Exception as exc:
            # One slow or dead node must not hide the results from the others;
            # report it and keep collecting.
            print(f"[{node_ip}] error:{type(exc).__name__}: {exc}")
            any_node_failed = True
            continue
        for pid, status in node_results:
            print(f"[{node_ip}] {pid}:{status}")

    if any_node_failed:
        # Keep a non-zero exit status so callers can still detect the failure.
        raise SystemExit(1)

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()

Copyright Ricardo Decal. ricardodecal.com