lasagnakanada committed on
Commit
d6282d0
·
1 Parent(s): 1ead44d

fix(llm): retry on 429 with backoff; raise LLMProviderRateLimit; surface as HTTP 429 at endpoint; add tests

Browse files
app/api/v1/endpoints/core.py CHANGED
@@ -9,6 +9,7 @@ from uuid import UUID
9
 
10
  import structlog
11
  from fastapi import APIRouter, Depends, HTTPException, status
 
12
  from sqlalchemy.ext.asyncio import AsyncSession
13
 
14
  from app.core.services import (
@@ -198,6 +199,16 @@ async def moderate_text(
198
  status_code=status.HTTP_400_BAD_REQUEST,
199
  detail={"error": "ValidationError", "message": str(e)},
200
  )
 
 
 
 
 
 
 
 
 
 
201
  except Exception as e:
202
  logger.error("moderate_request_failed", error=str(e), exc_info=True)
203
  raise HTTPException(
 
9
 
10
  import structlog
11
  from fastapi import APIRouter, Depends, HTTPException, status
12
+ from app.core.exceptions import LLMProviderRateLimit
13
  from sqlalchemy.ext.asyncio import AsyncSession
14
 
15
  from app.core.services import (
 
199
  status_code=status.HTTP_400_BAD_REQUEST,
200
  detail={"error": "ValidationError", "message": str(e)},
201
  )
202
+ except LLMProviderRateLimit as e:
203
+ logger.warning("generate_request_provider_rate_limited", error=str(e))
204
+ raise HTTPException(
205
+ status_code=status.HTTP_429_TOO_MANY_REQUESTS,
206
+ detail={
207
+ "error": "UpstreamProviderRateLimited",
208
+ "message": str(e),
209
+ "provider_details": getattr(e, "details", {}),
210
+ },
211
+ )
212
  except Exception as e:
213
  logger.error("moderate_request_failed", error=str(e), exc_info=True)
214
  raise HTTPException(
app/core/exceptions.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class LLMProviderError(Exception):
2
+ """Base exception for upstream LLM provider failures."""
3
+
4
+
5
+ class LLMProviderRateLimit(LLMProviderError):
6
+ """Raised when upstream LLM provider is rate limited or out of credits."""
7
+
8
+ def __init__(
9
+ self,
10
+ message: str = "Upstream LLM provider rate limit or exhausted credits",
11
+ details: dict | None = None,
12
+ ):
13
+ super().__init__(message)
14
+ self.details = details or {}
app/core/services/generation_service.py CHANGED
@@ -11,6 +11,7 @@ from typing import Dict, List, Optional
11
  import structlog
12
  from sqlalchemy.ext.asyncio import AsyncSession
13
  from app.config import settings
 
14
 
15
  from app.core.services.llm_service import get_llm_service
16
  from app.core.services.moderation_service import get_moderation_service
@@ -182,6 +183,9 @@ class GenerationService:
182
  }
183
 
184
  except Exception as e:
 
 
 
185
  if self.session is not None:
186
  await self.session.rollback()
187
  logger.error(
 
11
  import structlog
12
  from sqlalchemy.ext.asyncio import AsyncSession
13
  from app.config import settings
14
+ from app.core.exceptions import LLMProviderRateLimit
15
 
16
  from app.core.services.llm_service import get_llm_service
17
  from app.core.services.moderation_service import get_moderation_service
 
183
  }
184
 
185
  except Exception as e:
186
+ # propagate upstream provider rate-limit errors so API layer can return 429
187
+ if isinstance(e, LLMProviderRateLimit):
188
+ raise
189
  if self.session is not None:
190
  await self.session.rollback()
191
  logger.error(
app/core/services/llm_service.py CHANGED
@@ -11,7 +11,10 @@ import os
11
  from typing import Dict, List, Optional
12
 
13
  import structlog
14
- from httpx import AsyncClient
 
 
 
15
 
16
  from app.config import get_settings, settings
17
  from app.schemas.core import ChatMessage, TaskType
@@ -413,22 +416,96 @@ class GrokLLMService:
413
  "max_tokens": max_length,
414
  }
415
 
416
- async with AsyncClient(timeout=settings.model_inference_timeout) as client:
417
- resp = await client.post(
418
- f"{self.base_url}/chat/completions", json=payload, headers=headers
419
- )
420
- if resp.status_code >= 400:
421
- # Log full response text to aid debugging (xAI returns useful error JSON)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
  try:
423
- err_json = resp.json()
424
- except Exception:
425
- err_json = {"raw": resp.text}
426
- logger.error(
427
- "grok_api_error",
428
- status_code=resp.status_code,
429
- error=err_json,
430
- )
431
- resp.raise_for_status()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  data = resp.json()
433
  # Assuming OpenAI-like response structure
434
  text = data["choices"][0]["message"]["content"].strip()
 
11
  from typing import Dict, List, Optional
12
 
13
  import structlog
14
+ from httpx import AsyncClient, Timeout, ReadTimeout, RequestError
15
+ import asyncio
16
+
17
+ from app.core.exceptions import LLMProviderRateLimit
18
 
19
  from app.config import get_settings, settings
20
  from app.schemas.core import ChatMessage, TaskType
 
416
  "max_tokens": max_length,
417
  }
418
 
419
+ # Use an explicit Timeout so we control connect vs read timeouts.
420
+ # model_inference_timeout is an integer (seconds) from settings.
421
+ # httpx.Timeout requires either a `default` or all four values set.
422
+ # Set connect/read/write/pool explicitly to avoid the library error.
423
+ read_timeout = float(settings.model_inference_timeout)
424
+ timeout = Timeout(connect=5.0, read=read_timeout, write=read_timeout, pool=5.0)
425
+ max_attempts = int(getattr(settings, "grok_retry_attempts", 3))
426
+ backoff_base = float(getattr(settings, "grok_retry_backoff_base", 1.0))
427
+
428
+ logger.info(
429
+ "grok_request_payload",
430
+ base_url=self.base_url,
431
+ endpoint="/chat/completions",
432
+ model=self.model_name,
433
+ messages=[
434
+ {
435
+ "role": m["role"],
436
+ "content": (
437
+ m["content"][:50] + "..." if len(m["content"]) > 50 else m["content"]
438
+ ),
439
+ }
440
+ for m in messages
441
+ ],
442
+ temperature=temperature,
443
+ top_p=top_p,
444
+ max_tokens=max_length,
445
+ )
446
+ async with AsyncClient(timeout=timeout) as client:
447
+ for attempt in range(1, max_attempts + 1):
448
+
449
  try:
450
+ resp = await client.post(
451
+ f"{self.base_url}/chat/completions", json=payload, headers=headers
452
+ )
453
+ logger.info(
454
+ "grok_response_received",
455
+ status_code=resp.status_code,
456
+ headers=dict(resp.headers),
457
+ )
458
+ except ReadTimeout:
459
+ logger.error(
460
+ "grok_request_timeout",
461
+ base_url=self.base_url,
462
+ timeout_seconds=settings.model_inference_timeout,
463
+ )
464
+ raise
465
+ except RequestError as exc:
466
+ logger.error(
467
+ "grok_request_error",
468
+ base_url=self.base_url,
469
+ error=str(exc),
470
+ )
471
+ raise
472
+
473
+ if resp.status_code == 429:
474
+ # rate limit — retry with exponential backoff up to max_attempts
475
+ try:
476
+ err_json = resp.json()
477
+ except Exception:
478
+ err_json = {"raw": resp.text}
479
+
480
+ logger.warning(
481
+ "grok_api_rate_limited",
482
+ attempt=attempt,
483
+ max_attempts=max_attempts,
484
+ error=err_json,
485
+ )
486
+
487
+ if attempt < max_attempts:
488
+ sleep_seconds = backoff_base * (2 ** (attempt - 1))
489
+ await asyncio.sleep(sleep_seconds)
490
+ continue
491
+ else:
492
+ logger.error("grok_api_rate_limit_exhausted", error=err_json)
493
+ raise LLMProviderRateLimit(
494
+ "Upstream LLM provider rate limit or exhausted credits",
495
+ details=err_json,
496
+ )
497
+
498
+ if resp.status_code >= 400:
499
+ try:
500
+ err_json = resp.json()
501
+ except Exception:
502
+ err_json = {"raw": resp.text}
503
+ logger.error(
504
+ "grok_api_error",
505
+ status_code=resp.status_code,
506
+ error=err_json,
507
+ )
508
+ resp.raise_for_status()
509
  data = resp.json()
510
  # Assuming OpenAI-like response structure
511
  text = data["choices"][0]["message"]["content"].strip()
tests/test_llm_rate_limit.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from app.core.services.llm_service import GrokLLMService
4
+ from app.core.exceptions import LLMProviderRateLimit
5
+
6
+
7
class Fake429Resp:
    """Fake httpx response that always reports HTTP 429.

    Mirrors the pieces of ``httpx.Response`` the Grok client touches:
    ``status_code``, ``text``, ``json()`` and ``headers``.
    """

    def __init__(self, text="rate limit", payload=None):
        self.status_code = 429
        self.text = text
        # The production code logs dict(resp.headers) for every response
        # (grok_response_received); without this attribute the fake raises
        # AttributeError instead of exercising the rate-limit path.
        self.headers = {}
        self._payload = payload or {
            "code": "Some resource has been exhausted",
            "error": "out of credits",
        }

    def json(self):
        # Structured error body, like xAI's real 429 responses.
        return self._payload
18
+
19
+
20
class FakeAsyncClient:
    """Minimal async-context-manager double for ``httpx.AsyncClient``.

    Every ``post`` call yields a fresh 429 response, so the Grok retry
    loop is guaranteed to exhaust all of its attempts.
    """

    def __init__(self, *args, **kwargs):
        self._calls = 0  # number of POSTs attempted; handy when debugging

    async def __aenter__(self):
        return self

    async def __aexit__(self, *exc_info):
        # Never suppress exceptions raised inside the context.
        return False

    async def post(self, *args, **kwargs):
        self._calls = self._calls + 1
        return Fake429Resp()
33
+
34
+
35
@pytest.mark.asyncio
async def test_grok_service_raises_rate_limit(monkeypatch):
    """generate() must surface LLMProviderRateLimit after exhausting retries."""
    monkeypatch.setattr("app.core.services.llm_service.AsyncClient", FakeAsyncClient)

    # The retry loop sleeps with exponential backoff between attempts
    # (1s + 2s at the default settings). Stub the sleep so the test
    # finishes in milliseconds instead of ~3 seconds.
    async def _no_sleep(_seconds):
        return None

    monkeypatch.setattr("app.core.services.llm_service.asyncio.sleep", _no_sleep)

    svc = GrokLLMService()
    svc.api_key = "fake"
    svc.base_url = "https://fake"
    svc.model_name = "grok-test"

    with pytest.raises(LLMProviderRateLimit):
        await svc.generate("hi there")