📌 内容摘要

  • AI 网关解决的核心问题:当多个业务系统都需要调用 Claude,如何统一管理路由、鉴权、限流、成本和可观测性。
  • 本文设计一套完整的微服务 AI 网关:智能路由(按任务复杂度自动选模型)、多租户隔离、成本归因、故障熔断。
  • 架构分四层:接入层(认证限流)、路由层(模型调度)、代理层(Claude API 封装)、观测层(指标监控)。
  • 附完整 Python 实现,可直接部署为独立微服务,通过 HTTP API 供其他服务调用。

一、为什么需要 AI 网关?

当一个组织里有多个业务系统都需要调用 Claude 时——客服系统、内容生成系统、代码助手、数据分析平台——如果每个系统各自直接调用 Claude API,会出现一系列问题:

  • 成本失控:没有统一的 token 用量归因,不知道哪个系统在烧钱
  • 重复建设:每个系统都要实现限流、重试、会话管理,代码冗余
  • 模型选型分散:各系统自行选模型,无法全局优化成本和质量平衡
  • 缺乏可观测性:没有统一的请求日志和指标,排障困难

AI 网关把所有 Claude 调用的公共关注点集中到一处,业务系统只需调用网关接口,不直接接触 Claude API。

二、整体架构设计

┌─────────────────────────────────────────────────────────┐
│                    业务服务层                             │
│   客服系统    内容生成    代码助手    数据分析             │
└──────────────────────┬──────────────────────────────────┘
                       │ HTTP/gRPC
┌──────────────────────▼──────────────────────────────────┐
│                   AI 网关(本文实现)                     │
│                                                          │
│  ┌─────────────┐  ┌──────────────┐  ┌────────────────┐ │
│  │  接入层      │  │   路由层      │  │    代理层       │ │
│  │ · 认证鉴权   │→ │ · 模型调度   │→ │ · Claude API   │ │
│  │ · 速率限制   │  │ · 负载均衡   │  │ · 重试熔断     │ │
│  │ · 租户隔离   │  │ · 智能路由   │  │ · 缓存         │ │
│  └─────────────┘  └──────────────┘  └────────────────┘ │
│                                                          │
│  ┌──────────────────────────────────────────────────┐   │
│  │                   观测层                          │   │
│  │  · 请求日志  · Token 计量  · 延迟追踪  · 告警     │   │
│  └──────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────┘
                       │
┌──────────────────────▼──────────────────────────────────┐
│               Claude API(Anthropic)                     │
│   Haiku 4.5    Sonnet 4.6    Opus 4.6                   │
└─────────────────────────────────────────────────────────┘

三、项目结构

ai-gateway/
├── gateway/
│   ├── main.py              # FastAPI 入口
│   ├── config.py            # 配置(支持多租户)
│   ├── auth.py              # 认证与租户管理
│   ├── router.py            # 智能路由引擎
│   ├── proxy.py             # Claude API 代理层
│   ├── circuit_breaker.py   # 熔断器
│   ├── cache.py             # 语义缓存
│   ├── metrics.py           # 指标收集
│   └── middleware.py        # 中间件链
├── tests/
├── .env
├── Dockerfile
└── docker-compose.yml

四、配置与租户管理

# gateway/config.py
from dataclasses import dataclass, field
from typing import Optional
import os, json

@dataclass
class ModelConfig:
    """Per-model settings: rate limits, pricing and availability."""

    model_id: str
    max_tokens: int = 4096
    # Requests-per-minute cap for this model.
    rpm_limit: int = 60
    # Tokens-per-minute cap for this model.
    tpm_limit: int = 100_000
    # Pricing in USD per million tokens (input / output).
    cost_per_mtok_in: float = 3.0
    cost_per_mtok_out: float = 15.0
    enabled: bool = True

@dataclass
class TenantConfig:
    """Configuration for a single tenant (one business system)."""

    tenant_id: str
    api_key: str
    display_name: str
    # Daily spend ceiling in USD.
    daily_budget_usd: float = 10.0
    # Models this tenant may call; defaults to the two cheaper tiers.
    allowed_models: list[str] = field(
        default_factory=lambda: [
            "claude-haiku-4-5-20251001",
            "claude-sonnet-4-6",
        ]
    )
    default_model: str = "claude-sonnet-4-6"
    rate_limit_rpm: int = 30
    # Optional tenant-wide default system prompt.
    system_prompt: Optional[str] = None

@dataclass
class GatewayConfig:
    """Top-level gateway configuration: API credentials, Redis connection,
    cache and circuit-breaker settings, the model catalogue and the tenant
    registry."""
    anthropic_api_key: str
    redis_url:         str = "redis://localhost:6379"
    enable_cache:      bool = True
    cache_ttl:         int  = 3600       # cache entries live for 1 hour
    circuit_breaker_threshold: int = 5   # trip after N consecutive failures
    circuit_breaker_timeout:   int = 60  # retry N seconds after tripping

    # Model catalogue; prices are USD per million tokens (input / output).
    models: dict[str, ModelConfig] = field(default_factory=lambda: {
        "claude-haiku-4-5-20251001": ModelConfig(
            model_id="claude-haiku-4-5-20251001",
            cost_per_mtok_in=1.0, cost_per_mtok_out=5.0,
        ),
        "claude-sonnet-4-6": ModelConfig(
            model_id="claude-sonnet-4-6",
            cost_per_mtok_in=3.0, cost_per_mtok_out=15.0,
        ),
        "claude-opus-4-6": ModelConfig(
            model_id="claude-opus-4-6",
            cost_per_mtok_in=5.0, cost_per_mtok_out=25.0,
        ),
    })

    # Tenant registry (demo only — production should load this from a
    # database; the API keys below are plaintext placeholders, not secrets).
    tenants: dict[str, TenantConfig] = field(default_factory=lambda: {
        "cs-system": TenantConfig(
            tenant_id="cs-system",
            api_key="cs-key-abc123",
            display_name="客服系统",
            daily_budget_usd=20.0,
            default_model="claude-haiku-4-5-20251001",
            system_prompt="你是一名专业的客服代表,回答简洁友善。",
        ),
        "content-gen": TenantConfig(
            tenant_id="content-gen",
            api_key="cg-key-def456",
            display_name="内容生成系统",
            daily_budget_usd=50.0,
            allowed_models=["claude-sonnet-4-6", "claude-opus-4-6"],
            default_model="claude-sonnet-4-6",
        ),
        "code-assistant": TenantConfig(
            tenant_id="code-assistant",
            api_key="ca-key-ghi789",
            display_name="代码助手",
            daily_budget_usd=30.0,
            default_model="claude-sonnet-4-6",
            system_prompt="你是一名资深软件工程师,代码要有类型注解和注释。",
        ),
    })


def load_config() -> GatewayConfig:
    """Build the gateway configuration from environment variables.

    ANTHROPIC_API_KEY is mandatory; REDIS_URL and ENABLE_CACHE fall back to
    sensible defaults.
    """
    cache_flag = os.getenv("ENABLE_CACHE", "true").lower()
    return GatewayConfig(
        anthropic_api_key=os.environ["ANTHROPIC_API_KEY"],
        redis_url=os.getenv("REDIS_URL", "redis://localhost:6379"),
        enable_cache=(cache_flag == "true"),
    )

config = load_config()

五、智能路由引擎

# gateway/router.py
import re
from dataclasses import dataclass
from gateway.config import GatewayConfig, TenantConfig

@dataclass
class RoutingDecision:
    model_id:   str
    reason:     str
    estimated_cost_usd: float = 0.0

class IntelligentRouter:
    """
    智能路由引擎:根据请求特征自动选择最合适的模型
    目标:在保证质量的前提下最小化成本
    """

    # 触发升级到 Opus 的信号词
    OPUS_SIGNALS = [
        "架构设计", "系统设计", "深度分析", "综合评估", "战略",
        "复杂推理", "research", "architect", "analyze thoroughly",
    ]

    # 可以降级到 Haiku 的信号(简单任务)
    HAIKU_SIGNALS = [
        "翻译", "摘要", "分类", "提取", "是否", "判断",
        "translate", "summarize", "classify", "yes or no",
    ]

    def __init__(self, config: GatewayConfig):
        self.config = config

    def route(
        self,
        message:         str,
        tenant:          TenantConfig,
        requested_model: str | None = None,
        context_length:  int = 0,
    ) -> RoutingDecision:
        """
        综合多个因素决定使用哪个模型

        优先级:
        1. 请求中明确指定的模型(如果租户有权限)
        2. 智能路由规则(基于消息内容分析)
        3. 租户默认模型
        """

        # 1. 显式指定的模型
        if requested_model:
            if requested_model in tenant.allowed_models:
                return RoutingDecision(
                    model_id=requested_model,
                    reason="显式指定",
                )
            else:
                # 无权限使用该模型,降级到租户默认
                return RoutingDecision(
                    model_id=tenant.default_model,
                    reason=f"无权限使用 {requested_model},降级到默认模型",
                )

        # 2. 基于内容的智能路由
        decision = self._analyze_content(message, tenant, context_length)
        return decision

    def _analyze_content(
        self,
        message:        str,
        tenant:         TenantConfig,
        context_length: int,
    ) -> RoutingDecision:
        message_lower = message.lower()
        msg_len       = len(message)

        # 规则1:超长上下文,需要 Sonnet/Opus 的大窗口
        if context_length > 50_000:
            target = "claude-opus-4-6" if "claude-opus-4-6" in tenant.allowed_models else "claude-sonnet-4-6"
            return RoutingDecision(
                model_id=target,
                reason=f"上下文长度 {context_length} tokens,需要大窗口模型",
            )

        # 规则2:包含复杂任务信号,升级到 Opus
        if any(sig in message_lower for sig in self.OPUS_SIGNALS):
            if "claude-opus-4-6" in tenant.allowed_models:
                return RoutingDecision(
                    model_id="claude-opus-4-6",
                    reason="检测到复杂推理任务,使用 Opus",
                )

        # 规则3:短消息 + 简单任务信号,降级到 Haiku
        if (msg_len < 200
                and any(sig in message_lower for sig in self.HAIKU_SIGNALS)
                and "claude-haiku-4-5-20251001" in tenant.allowed_models):
            return RoutingDecision(
                model_id="claude-haiku-4-5-20251001",
                reason="简单任务,使用 Haiku 节省成本",
            )

        # 规则4:代码相关 + 较长输入,使用 Sonnet
        code_signals = ["```", "def ", "function ", "class ", "import ", "代码", "bug"]
        if any(sig in message for sig in code_signals):
            if "claude-sonnet-4-6" in tenant.allowed_models:
                return RoutingDecision(
                    model_id="claude-sonnet-4-6",
                    reason="代码相关任务,使用 Sonnet",
                )

        # 默认:使用租户配置的默认模型
        return RoutingDecision(
            model_id=tenant.default_model,
            reason="默认路由",
        )

六、熔断器

# gateway/circuit_breaker.py
import time
from enum import Enum
from dataclasses import dataclass, field

class CircuitState(Enum):
    """Lifecycle states of a circuit breaker."""
    CLOSED   = "closed"    # normal: requests pass through
    OPEN     = "open"      # tripped: requests are rejected
    HALF_OPEN= "half_open" # probing: a limited number of requests allowed

@dataclass
class CircuitBreaker:
    """
    Circuit breaker protecting the downstream Claude API from cascading
    failures.

    State machine:
    CLOSED -> (N consecutive failures) -> OPEN -> (wait T seconds) -> HALF_OPEN
    HALF_OPEN -> (success) -> CLOSED
    HALF_OPEN -> (failure) -> OPEN
    """
    name:              str
    failure_threshold: int   = 5      # failures required to trip the breaker
    recovery_timeout:  int   = 60     # seconds to wait before probing again
    half_open_max:     int   = 3      # max probe requests while HALF_OPEN

    _state:            CircuitState = field(default=CircuitState.CLOSED, init=False)
    _failure_count:    int = field(default=0, init=False)
    _last_failure_time:float = field(default=0.0, init=False)
    _half_open_count:  int = field(default=0, init=False)

    @property
    def state(self) -> CircuitState:
        """Current state; lazily transitions OPEN -> HALF_OPEN once the
        recovery timeout has elapsed."""
        if self._state == CircuitState.OPEN:
            if time.time() - self._last_failure_time >= self.recovery_timeout:
                self._state = CircuitState.HALF_OPEN
                self._half_open_count = 0
        return self._state

    def is_allowed(self) -> bool:
        """Return True if a request may be sent to the protected resource."""
        state = self.state
        if state == CircuitState.CLOSED:
            return True
        if state == CircuitState.OPEN:
            return False
        # HALF_OPEN: admit a limited number of probe requests.
        if self._half_open_count < self.half_open_max:
            self._half_open_count += 1
            return True
        return False

    def record_success(self):
        """Report a successful call; a successful probe closes the breaker."""
        if self._state == CircuitState.HALF_OPEN:
            self._state         = CircuitState.CLOSED
            self._failure_count = 0
        elif self._state == CircuitState.CLOSED:
            # Successes decay the failure count so sporadic errors don't trip.
            self._failure_count = max(0, self._failure_count - 1)

    def record_failure(self):
        """Report a failed call; may trip (or re-trip) the breaker."""
        self._failure_count    += 1
        self._last_failure_time = time.time()
        # Fix: a failed probe in HALF_OPEN must re-open the breaker
        # unconditionally — previously it only re-opened via the threshold
        # comparison, which breaks if failure_threshold is raised at runtime.
        if (self._state == CircuitState.HALF_OPEN
                or self._failure_count >= self.failure_threshold):
            self._state = CircuitState.OPEN

    def get_status(self) -> dict:
        """Snapshot of this breaker for health/monitoring endpoints."""
        return {
            "name":           self.name,
            "state":          self.state.value,
            "failure_count":  self._failure_count,
            "last_failure":   self._last_failure_time,
        }


class CircuitBreakerRegistry:
    """Holds one CircuitBreaker per model, created lazily on first access."""

    def __init__(self, failure_threshold: int = 5, recovery_timeout: int = 60):
        self._breakers: dict[str, CircuitBreaker] = {}
        self._threshold = failure_threshold
        self._timeout   = recovery_timeout

    def get(self, name: str) -> CircuitBreaker:
        """Return the breaker registered under *name*, creating it if absent."""
        try:
            return self._breakers[name]
        except KeyError:
            breaker = CircuitBreaker(
                name=name,
                failure_threshold=self._threshold,
                recovery_timeout=self._timeout,
            )
            self._breakers[name] = breaker
            return breaker

    def all_status(self) -> list[dict]:
        """Status snapshots of every registered breaker."""
        return [breaker.get_status() for breaker in self._breakers.values()]

七、语义缓存

# gateway/cache.py
import hashlib
import json
from typing import Optional
from redis.asyncio import Redis

class SemanticCache:
    """
    Content-hash based response cache.

    Identical (model, system, messages) combinations are served straight
    from Redis without calling the API.

    Good fits:
    - FAQ answering (the same question asked repeatedly)
    - content classification (same text, same label)
    - batch processing (many identically-shaped requests)

    NOTE: despite the name this is an exact-match cache keyed on a content
    hash, not embedding-based semantic matching.
    """

    def __init__(self, redis: Redis, ttl: int = 3600):
        self.redis = redis
        self.ttl   = ttl

    def _make_key(
        self,
        model:    str,
        system:   str,
        messages: list[dict],
    ) -> str:
        """Build the cache key: SHA-256 over the canonicalised request."""
        content = json.dumps({
            "model":    model,
            "system":   system,
            "messages": messages,
        }, sort_keys=True, ensure_ascii=False)
        hash_val = hashlib.sha256(content.encode()).hexdigest()[:24]
        return f"cache:{hash_val}"

    async def get(
        self,
        model:    str,
        system:   str,
        messages: list[dict],
    ) -> Optional[dict]:
        """Return the cached response (marked cached=True) or None on miss."""
        key  = self._make_key(model, system, messages)
        data = await self.redis.get(key)
        if data:
            result           = json.loads(data)
            result["cached"] = True
            return result
        return None

    async def set(
        self,
        model:    str,
        system:   str,
        messages: list[dict],
        response: dict,
    ):
        """Store a response with the configured TTL."""
        key = self._make_key(model, system, messages)
        await self.redis.setex(
            key,
            self.ttl,
            json.dumps(response, ensure_ascii=False),
        )

    async def invalidate_pattern(self, pattern: str):
        """Bulk-invalidate matching entries (e.g. after a model update).

        Fix: iterate with SCAN instead of KEYS — KEYS blocks the whole Redis
        server while it walks the keyspace, which is unsafe in production.
        """
        keys = [key async for key in self.redis.scan_iter(match=f"cache:{pattern}*")]
        if keys:
            await self.redis.delete(*keys)

八、指标收集与成本归因

# gateway/metrics.py
import time
import json
from datetime import date, datetime
from redis.asyncio import Redis
from gateway.config import GatewayConfig

class MetricsCollector:
    """
    Collects and stores per-request metrics, supporting:
    - per-tenant token usage and cost attribution
    - per-model request volume and latency
    - cache hit-rate statistics
    - error-rate statistics
    """

    def __init__(self, redis: Redis, config: GatewayConfig):
        self.redis  = redis
        self.config = config

    async def record_request(
        self,
        tenant_id:     str,
        model_id:      str,
        input_tokens:  int,
        output_tokens: int,
        latency_ms:    float,
        cached:        bool  = False,
        error:         str | None = None,
    ):
        """Record the full metrics of one request (counters + detail log)."""
        from datetime import timezone
        today  = date.today().isoformat()
        hour   = datetime.now().strftime("%Y-%m-%d-%H")

        # Cost of this request; cache hits cost nothing.
        model_cfg = self.config.models.get(model_id)
        cost_usd  = 0.0
        if model_cfg and not cached:
            cost_usd = (
                input_tokens  / 1_000_000 * model_cfg.cost_per_mtok_in +
                output_tokens / 1_000_000 * model_cfg.cost_per_mtok_out
            )

        pipe = self.redis.pipeline()

        # Per-tenant counters (daily buckets).
        tenant_day_key = f"metrics:tenant:{tenant_id}:{today}"
        pipe.hincrbyfloat(tenant_day_key, "cost_usd",       cost_usd)
        pipe.hincrby(tenant_day_key,      "input_tokens",   input_tokens)
        pipe.hincrby(tenant_day_key,      "output_tokens",  output_tokens)
        pipe.hincrby(tenant_day_key,      "request_count",  1)
        if cached:
            pipe.hincrby(tenant_day_key,  "cache_hits",     1)
        if error:
            pipe.hincrby(tenant_day_key,  "error_count",    1)
        pipe.expire(tenant_day_key, 86400 * 30)  # keep 30 days

        # Per-model counters (hourly buckets).
        model_hour_key = f"metrics:model:{model_id}:{hour}"
        pipe.hincrby(model_hour_key,      "request_count",  1)
        pipe.hincrby(model_hour_key,      "input_tokens",   input_tokens)
        pipe.hincrby(model_hour_key,      "output_tokens",  output_tokens)
        pipe.hincrbyfloat(model_hour_key, "total_latency",  latency_ms)
        pipe.expire(model_hour_key, 86400 * 7)   # keep 7 days

        await pipe.execute()

        # Detailed request log for debugging.
        # Fix: datetime.utcnow() is naive and deprecated since Python 3.12;
        # use an explicit timezone-aware UTC timestamp instead.
        log_entry = {
            "ts":            datetime.now(timezone.utc).isoformat(),
            "tenant_id":     tenant_id,
            "model_id":      model_id,
            "input_tokens":  input_tokens,
            "output_tokens": output_tokens,
            "cost_usd":      round(cost_usd, 6),
            "latency_ms":    round(latency_ms, 1),
            "cached":        cached,
            "error":         error,
        }
        await self.redis.lpush("metrics:request_log", json.dumps(log_entry))
        await self.redis.ltrim("metrics:request_log", 0, 9999)  # keep last 10k

    async def get_tenant_stats(self, tenant_id: str, days: int = 7) -> list[dict]:
        """Return daily stats for the last *days* days (most recent first)."""
        from datetime import timedelta

        def _field(data: dict, name: str, cast):
            # Fix: tolerate both bytes and str hash keys so this works
            # regardless of the Redis client's decode_responses setting.
            raw = data.get(name.encode())
            if raw is None:
                raw = data.get(name, 0)
            return cast(raw)

        stats = []
        for i in range(days):
            day = (date.today() - timedelta(days=i)).isoformat()
            key = f"metrics:tenant:{tenant_id}:{day}"
            data = await self.redis.hgetall(key)
            if data:
                stats.append({
                    "date":          day,
                    "cost_usd":      _field(data, "cost_usd",      float),
                    "input_tokens":  _field(data, "input_tokens",  int),
                    "output_tokens": _field(data, "output_tokens", int),
                    "request_count": _field(data, "request_count", int),
                    "cache_hits":    _field(data, "cache_hits",    int),
                    "error_count":   _field(data, "error_count",   int),
                })
        return stats

    async def check_budget(self, tenant_id: str, estimated_cost: float) -> bool:
        """
        Return True if the tenant still has budget headroom today.

        NOTE(review): read-then-decide is not atomic — concurrent requests
        can jointly overshoot the daily budget by a small amount.
        """
        today       = date.today().isoformat()
        key         = f"metrics:tenant:{tenant_id}:{today}"
        spent_raw   = await self.redis.hget(key, "cost_usd")
        spent       = float(spent_raw) if spent_raw else 0.0

        tenant_cfg  = self.config.tenants.get(tenant_id)
        if not tenant_cfg:
            return True   # unknown tenant: do not block

        return (spent + estimated_cost) <= tenant_cfg.daily_budget_usd

九、代理层:Claude API 封装

# gateway/proxy.py
import anthropic
import time
from typing import AsyncGenerator
from gateway.config import GatewayConfig
from gateway.circuit_breaker import CircuitBreakerRegistry
from gateway.cache import SemanticCache
from gateway.metrics import MetricsCollector

class ClaudeProxy:
    """
    Claude API proxy layer, wrapping:
    - circuit-breaker protection
    - semantic caching
    - metrics collection
    - error normalisation
    """

    def __init__(
        self,
        config:   GatewayConfig,
        cache:    SemanticCache,
        metrics:  MetricsCollector,
        breakers: CircuitBreakerRegistry,
    ):
        self.config   = config
        self.cache    = cache
        self.metrics  = metrics
        self.breakers = breakers
        self._client  = anthropic.AsyncAnthropic(
            api_key=config.anthropic_api_key,
            max_retries=2,     # SDK-level retries on transient errors
            timeout=120.0,     # seconds
        )

    async def complete(
        self,
        tenant_id:   str,
        model_id:    str,
        messages:    list[dict],
        system:      str = "",
        max_tokens:  int = 4096,
        temperature: float | None = None,
    ) -> dict:
        """
        Non-streaming chat proxy.

        Returns:
            {"content": str, "input_tokens": int, "output_tokens": int,
             "model": str, "cached": bool, "cost_usd": float}
        """
        start = time.time()

        # 1. Cache lookup.
        # NOTE(review): the cache key covers (model, system, messages) only —
        # temperature and max_tokens are ignored, so requests that differ
        # only in those parameters share one cache entry. Confirm intended.
        if self.config.enable_cache:
            cached = await self.cache.get(model_id, system, messages)
            if cached:
                # Cache hits are recorded with zero tokens (zero cost).
                await self.metrics.record_request(
                    tenant_id=tenant_id, model_id=model_id,
                    input_tokens=0, output_tokens=0,
                    latency_ms=0, cached=True,
                )
                return cached

        # 2. Circuit-breaker check.
        breaker = self.breakers.get(model_id)
        if not breaker.is_allowed():
            raise RuntimeError(f"模型 {model_id} 熔断中,请稍后重试")

        # 3. Call the Claude API.
        kwargs = dict(
            model      = model_id,
            max_tokens = max_tokens,
            system     = system,
            messages   = messages,
        )
        if temperature is not None:
            kwargs["temperature"] = temperature

        error_msg = None
        try:
            response     = await self._client.messages.create(**kwargs)
            content      = response.content[0].text
            input_tok    = response.usage.input_tokens
            output_tok   = response.usage.output_tokens
            breaker.record_success()

        except anthropic.RateLimitError:
            breaker.record_failure()
            error_msg = "API 速率限制"
            raise
        except anthropic.APIStatusError as e:
            breaker.record_failure()
            error_msg = f"API 错误 {e.status_code}"
            raise
        except Exception as e:
            breaker.record_failure()
            error_msg = str(e)
            raise
        finally:
            # On failure, still record the attempt (zero tokens) before the
            # exception propagates to the caller.
            latency_ms = (time.time() - start) * 1000
            if error_msg:
                await self.metrics.record_request(
                    tenant_id=tenant_id, model_id=model_id,
                    input_tokens=0, output_tokens=0,
                    latency_ms=latency_ms, error=error_msg,
                )

        # Compute this request's cost from the model's pricing table.
        model_cfg = self.config.models.get(model_id)
        cost_usd  = 0.0
        if model_cfg:
            cost_usd = (
                input_tok  / 1_000_000 * model_cfg.cost_per_mtok_in +
                output_tok / 1_000_000 * model_cfg.cost_per_mtok_out
            )

        result = {
            "content":       content,
            "input_tokens":  input_tok,
            "output_tokens": output_tok,
            "model":         model_id,
            "cached":        False,
            "cost_usd":      round(cost_usd, 6),
            "latency_ms":    round((time.time() - start) * 1000, 1),
        }

        # 4. Write the response to the cache.
        if self.config.enable_cache:
            await self.cache.set(model_id, system, messages, result)

        # 5. Record success metrics.
        await self.metrics.record_request(
            tenant_id=tenant_id, model_id=model_id,
            input_tokens=input_tok, output_tokens=output_tok,
            latency_ms=result["latency_ms"],
        )

        return result

    async def stream(
        self,
        tenant_id:   str,
        model_id:    str,
        messages:    list[dict],
        system:      str = "",
        max_tokens:  int = 4096,
    ) -> AsyncGenerator[dict, None]:
        """Streaming proxy: yields text/error/usage/done event dicts."""
        breaker = self.breakers.get(model_id)
        if not breaker.is_allowed():
            yield {"type": "error", "data": f"模型 {model_id} 熔断中"}
            return

        start     = time.time()
        input_tok = output_tok = 0

        try:
            async with self._client.messages.stream(
                model=model_id, max_tokens=max_tokens,
                system=system, messages=messages,
            ) as stream:
                async for text in stream.text_stream:
                    yield {"type": "text", "data": text}
                msg        = await stream.get_final_message()
                input_tok  = msg.usage.input_tokens
                output_tok = msg.usage.output_tokens
                breaker.record_success()

        except Exception as e:
            breaker.record_failure()
            yield {"type": "error", "data": str(e)}
            return

        latency_ms = (time.time() - start) * 1000
        await self.metrics.record_request(
            tenant_id=tenant_id, model_id=model_id,
            input_tokens=input_tok, output_tokens=output_tok,
            latency_ms=latency_ms,
        )
        yield {"type": "usage", "data": {"input": input_tok, "output": output_tok}}
        yield {"type": "done",  "data": None}

十、网关主入口

# gateway/main.py
import uuid, json, time
from fastapi import FastAPI, HTTPException, Header, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional
from redis.asyncio import from_url, Redis

from gateway.config import config, TenantConfig
from gateway.router import IntelligentRouter
from gateway.proxy import ClaudeProxy
from gateway.circuit_breaker import CircuitBreakerRegistry
from gateway.cache import SemanticCache
from gateway.metrics import MetricsCollector

# Application singletons shared across all requests: the FastAPI app, the
# routing engine and the per-model circuit-breaker registry.
app     = FastAPI(title="Claude AI Gateway", version="1.0.0")
router  = IntelligentRouter(config)
breakers= CircuitBreakerRegistry(
    failure_threshold=config.circuit_breaker_threshold,
    recovery_timeout=config.circuit_breaker_timeout,
)

# Lazily-initialised module-wide Redis connection (see get_redis()).
_redis: Redis | None = None

async def get_redis() -> Redis:
    """Return the shared Redis connection, creating it on first use."""
    global _redis
    if _redis is not None:
        return _redis
    _redis = from_url(config.redis_url, decode_responses=False)
    return _redis

async def get_proxy(redis: Redis = Depends(get_redis)) -> ClaudeProxy:
    """Assemble a request-scoped ClaudeProxy with its collaborators."""
    cache   = SemanticCache(redis, config.cache_ttl)
    metrics = MetricsCollector(redis, config)
    return ClaudeProxy(config=config, cache=cache, metrics=metrics, breakers=breakers)

async def authenticate(x_api_key: str = Header(...)) -> TenantConfig:
    """Resolve the tenant from the X-Api-Key header.

    Fix: use a constant-time comparison so an attacker cannot narrow down a
    valid key byte-by-byte via response-timing differences.
    """
    import secrets  # local import: keeps the module header untouched
    for tenant in config.tenants.values():
        if secrets.compare_digest(tenant.api_key.encode(), x_api_key.encode()):
            return tenant
    raise HTTPException(status_code=401, detail="无效的 API Key")


class GatewayRequest(BaseModel):
    """Request body for POST /v1/chat.

    ``model`` is optional — when omitted the intelligent router picks one.
    ``system`` overrides the tenant-level default system prompt.
    """
    messages:    list[dict]
    model:       Optional[str]   = None
    system:      Optional[str]   = None
    max_tokens:  Optional[int]   = 4096
    temperature: Optional[float] = None
    stream:      bool            = False


@app.post("/v1/chat")
async def chat(
    req:    GatewayRequest,
    tenant: TenantConfig = Depends(authenticate),
    proxy:  ClaudeProxy  = Depends(get_proxy),
    redis:  Redis        = Depends(get_redis),
):
    """Unified chat entry point: route, budget-check, then proxy the call."""
    # Intelligent routing based on the last message's content.
    # NOTE(review): assumes "content" is a plain string; the Anthropic API
    # also allows a list of content blocks, which would break len()/lower()
    # inside the router — confirm callers only send string content.
    decision = router.route(
        message        = req.messages[-1].get("content", "") if req.messages else "",
        tenant         = tenant,
        requested_model= req.model,
    )

    # Merge system prompts: request-level overrides the tenant default.
    system = req.system or tenant.system_prompt or ""

    # Budget pre-check with a flat $0.01 estimate; the exact cost is
    # recorded after the call completes.
    metrics = MetricsCollector(redis, config)
    if not await metrics.check_budget(tenant.tenant_id, estimated_cost=0.01):
        raise HTTPException(status_code=429, detail="今日预算已用完")

    if req.stream:
        # Server-sent events: first a routing-info frame, then proxy events.
        async def event_gen():
            yield f"data: {json.dumps({'model': decision.model_id, 'reason': decision.reason})}\n\n"
            async for event in proxy.stream(
                tenant_id=tenant.tenant_id, model_id=decision.model_id,
                messages=req.messages, system=system, max_tokens=req.max_tokens or 4096,
            ):
                yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"

        return StreamingResponse(event_gen(), media_type="text/event-stream",
                                 headers={"Cache-Control": "no-cache"})

    result = await proxy.complete(
        tenant_id=tenant.tenant_id, model_id=decision.model_id,
        messages=req.messages, system=system,
        max_tokens=req.max_tokens or 4096, temperature=req.temperature,
    )
    return {
        "id":          str(uuid.uuid4()),
        "model":       decision.model_id,
        "routing":     {"reason": decision.reason},
        **result,
    }


@app.get("/v1/stats/{tenant_id}")
async def get_stats(
    tenant_id: str,
    days:      int   = 7,
    tenant:    TenantConfig = Depends(authenticate),
    redis:     Redis        = Depends(get_redis),
):
    """Return usage statistics for the authenticated tenant."""
    # Tenants may only inspect their own data.
    if tenant.tenant_id != tenant_id:
        raise HTTPException(status_code=403, detail="无权查看其他租户数据")
    collector = MetricsCollector(redis, config)
    daily_stats = await collector.get_tenant_stats(tenant_id, days)
    return {"tenant_id": tenant_id, "stats": daily_stats}


@app.get("/v1/health")
async def health():
    """Liveness probe; also reports every circuit breaker's state."""
    breaker_states = breakers.all_status()
    return {"status": "ok", "circuit_breakers": breaker_states}


if __name__ == "__main__":
    # Dev entry point; in production run via a process manager without reload.
    import uvicorn
    uvicorn.run("gateway.main:app", host="0.0.0.0", port=9000, reload=True)

十一、调用示例

import httpx, json

GATEWAY = "http://localhost:9000"

# Customer-service system call (auto-routed to Haiku).
resp = httpx.post(f"{GATEWAY}/v1/chat",
    headers={"X-Api-Key": "cs-key-abc123"},
    json={"messages": [{"role": "user", "content": "我的订单什么时候发货?"}]},
)
print(resp.json())
# {"model": "claude-haiku-4-5-20251001", "routing": {"reason": "简单任务,使用 Haiku 节省成本"}, ...}

# Code-assistant call (auto-routed to Sonnet).
resp = httpx.post(f"{GATEWAY}/v1/chat",
    headers={"X-Api-Key": "ca-key-ghi789"},
    json={"messages": [{"role": "user", "content": "帮我写一个 Python 二分查找函数"}]},
)
print(resp.json())
# {"model": "claude-sonnet-4-6", "routing": {"reason": "代码相关任务,使用 Sonnet"}, ...}

# Content generation (explicitly pinned to Opus).
resp = httpx.post(f"{GATEWAY}/v1/chat",
    headers={"X-Api-Key": "cg-key-def456"},
    json={
        "messages": [{"role": "user", "content": "写一篇3000字的市场分析报告"}],
        "model":    "claude-opus-4-6",   # explicit model override
    },
)

# Fetch usage statistics for a tenant.
stats = httpx.get(f"{GATEWAY}/v1/stats/cs-system",
    headers={"X-Api-Key": "cs-key-abc123"},
)
print(stats.json())

常见问题

Q:网关本身成为单点故障怎么解决?
网关是无状态的(状态全在 Redis),可以水平扩展——用 Kubernetes 部署多副本,前面加负载均衡器(Nginx/HAProxy/Envoy)。Redis 用主从 + Sentinel 或者 Redis Cluster 保证高可用。这样网关层没有单点,Redis 层也有冗余。

Q:智能路由的规则不符合业务需求怎么定制?
IntelligentRouter 里增加或修改规则。规则评估顺序和权重都在代码里,可以按需调整。更灵活的方案是把路由规则外置到配置文件或数据库,支持运行时热更新而不需要重部署网关。

Q:生产环境的租户配置应该存在哪里?
本文用代码内嵌的方式简化了演示,生产环境应该存数据库(PostgreSQL/MySQL)。GatewayConfig 里的 tenants 改为启动时从数据库加载,并定期刷新(或用 Redis 缓存)。API Key 做哈希后存储,不明文存库。

总结

AI 网关把所有 Claude 调用的公共关注点收拢到一处,核心价值体现在四个方面:成本可见(每个租户的用量和花费实时可查)、成本可控(预算上限 + 智能路由降成本)、可靠性(熔断器 + 缓存 + 重试)、可观测(统一日志 + 延迟指标)。本文的实现可以直接作为生产起点,按需替换其中的存储层(Redis → PostgreSQL)或扩展路由规则,不需要从零重写。