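"""Download the latest GEFS wave control-member (c00) dataset.

Probes NOAA's Azure mirror and NOMADS for the most recent cycle, downloads
the 0.25-degree global GRIB2 files for every forecast hour, and packages
them into a tarball. Azure access uses a SAS token taken from the
environment or fetched from the Planetary Computer API.
"""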
import json
import logging
import os
import tarfile
import threading
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Dict, Iterable, List, Optional
import requests
NOMADS_URL = "https://nomads.ncep.noaa.gov/pub/data/nccf/com/gens/prod"
AZURE_URL = "https://noaagefs.blob.core.windows.net/gefs"
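# The Azure mirror is tried first; NOMADS is the fallback.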
BASE_URLS = [AZURE_URL, NOMADS_URL]
AZURE_SAS_ENDPOINT = "https://planetarycomputer.microsoft.com/api/sas/v1/token/noaagefs/gefs"
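# Forecast lead times: f000 through f384 in 3-hour steps (129 files per cycle).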
FORECAST_HOURS: List[int] = list(range(0, 387, 3))
STATE_FILENAME = "state.json"
TARBALL_FILENAME = "gefswave-wave0-latest.tar.gz"
logger = logging.getLogger("gefs_wave")
if not logger.handlers:
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.propagate = False
@dataclass
class Cycle:
date: str # YYYYMMDD
cycle: str # HH
@property
def label(self) -> str:
return f"{self.date} {self.cycle}Z"
@property
def directory(self) -> str:
return f"gefs.{self.date}/{self.cycle}"
class WaveDownloadError(Exception):
"""Raised when a download or discovery step fails."""
class WaveDownloader:
"""Handles discovery and download of the GEFS wave control-member dataset."""
def __init__(self, data_root: Path, session: Optional[requests.Session] = None) -> None:
self.data_root = data_root
self.data_root.mkdir(parents=True, exist_ok=True)
self.session = session or requests.Session()
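        # The Azure SAS token is resolved lazily and cached; _azure_token_checked
        # prevents repeated calls to the token endpoint after a failed fetch.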
self._azure_sas_token: Optional[str] = None
self._azure_token_checked = False
def ensure_latest_cycle(self) -> Dict:
"""Ensure the latest available control-member dataset is present locally."""
logger.info("Ensuring latest GEFS wave control-member cycle for %s", self.data_root)
cycle = self._find_latest_cycle()
if not cycle:
logger.error("No GEFS wave cycle found in the recent window.")
raise WaveDownloadError("Could not locate a recent GEFS wave cycle.")
state = self._read_state()
if state and state.get("cycle") == cycle.label and self._cycle_complete(cycle):
logger.info("Cycle %s already cached; using existing files.", cycle.label)
return state
logger.info("Downloading cycle %s", cycle.label)
files = self._download_cycle(cycle)
tarball_path = self._build_tarball(cycle, files)
logger.info("Packaged cycle %s into tarball %s", cycle.label, tarball_path.name)
state = {
"cycle": cycle.label,
"files": [str(path.relative_to(self.data_root)) for path in files],
"tarball": str(tarball_path.relative_to(self.data_root)),
"updated_at": datetime.now(timezone.utc).isoformat(),
}
self._write_state(state)
return state
def current_state(self) -> Optional[Dict]:
return self._read_state()
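    # Cycles are probed newest-first: today's date backwards, and 18Z before
    # 00Z within each day, so the first hit is the most recent available cycle.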
def _find_latest_cycle(self, max_days_back: int = 9) -> Optional[Cycle]:
now = datetime.now(timezone.utc)
candidate_days = [(now - timedelta(days=days)).strftime("%Y%m%d") for days in range(max_days_back + 1)]
cycles = ["18", "12", "06", "00"]
for day in candidate_days:
for cycle in cycles:
logger.info("Probing cycle %s %sz", day, cycle)
if self._cycle_available(day, cycle):
logger.info("Cycle %s %sz is available.", day, cycle)
return Cycle(date=day, cycle=cycle)
logger.warning("No GEFS wave cycle found in the last %d days.", max_days_back)
return None
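    # Availability is tested with a HEAD request for the control member's first
    # forecast file (f000); a 401/403 from Azure triggers a token refresh.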
def _cycle_available(self, day: str, cycle: str) -> bool:
filename = f"gefswave.t{cycle}z.c00.global.0p25.f000.grib2"
for base_url in BASE_URLS:
path = f"gefs.{day}/{cycle}/wave/gridded/{filename}"
url = self._build_url(base_url, path)
if not url:
logger.debug("Skipping %s for cycle %s %sz due to missing credentials.", base_url, day, cycle)
continue
try:
response = self.session.head(url, timeout=20)
if response.ok:
logger.info("Cycle %s %sz file found on %s", day, cycle, base_url)
return True
if response.status_code == 404:
logger.debug("Cycle %s %sz missing on %s (404).", day, cycle, base_url)
elif response.status_code in (401, 403):
logger.warning(
"Cycle %s %sz request to %s returned %s; will retry after refreshing credentials.",
day,
cycle,
base_url,
response.status_code,
)
if base_url == AZURE_URL:
self._invalidate_azure_token()
else:
logger.debug(
"Cycle %s %sz head request on %s returned HTTP %s.",
day,
cycle,
base_url,
response.status_code,
)
except requests.RequestException as exc:
logger.debug("Error probing cycle %s %sz on %s: %s", day, cycle, base_url, exc)
return False
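    # Fetches every forecast hour for the cycle; non-empty files already on
    # disk are reused instead of being re-downloaded.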
def _download_cycle(self, cycle: Cycle) -> List[Path]:
files: List[Path] = []
target_dir = self.data_root / cycle.date / cycle.cycle
target_dir.mkdir(parents=True, exist_ok=True)
for hour in FORECAST_HOURS:
filename = f"gefswave.t{cycle.cycle}z.c00.global.0p25.f{hour:03d}.grib2"
destination = target_dir / filename
if destination.exists() and destination.stat().st_size > 0:
logger.debug("File %s already exists; skipping download.", destination)
files.append(destination)
continue
logger.info("Downloading %s", filename)
self._download_with_fallback(cycle, filename, destination)
files.append(destination)
return files
def _build_tarball(self, cycle: Cycle, files: Iterable[Path]) -> Path:
tarball_path = self.data_root / TARBALL_FILENAME
with tarfile.open(tarball_path, "w:gz") as tar:
for file_path in files:
tar.add(file_path, arcname=file_path.relative_to(self.data_root))
return tarball_path
def _download_with_fallback(self, cycle: Cycle, filename: str, destination: Path) -> None:
        errors: List[str] = []
for base_url in BASE_URLS:
path = f"{cycle.directory}/wave/gridded/{filename}"
url = self._build_url(base_url, path)
if not url:
logger.debug("Skipping download from %s for %s due to missing URL.", base_url, filename)
continue
try:
self._stream_to_file(url, destination, base_url)
return
except WaveDownloadError as exc:
errors.append(str(exc))
logger.warning("Failed to download %s from %s (%s).", filename, base_url, exc)
raise WaveDownloadError(f"All download attempts failed for {filename}: {'; '.join(errors)}")
def _stream_to_file(self, url: str, destination: Path, base_url: str) -> None:
tmp_path = destination.with_suffix(destination.suffix + ".part")
try:
with self.session.get(url, stream=True, timeout=120) as response:
if response.status_code in (401, 403) and base_url == AZURE_URL:
self._invalidate_azure_token()
                response.raise_for_status()
with tmp_path.open("wb") as handle:
for chunk in response.iter_content(chunk_size=1 << 20):
if not chunk:
continue
handle.write(chunk)
tmp_path.rename(destination)
except requests.RequestException as exc:
if tmp_path.exists():
tmp_path.unlink()
logger.error("Failed to download %s: %s", url, exc)
raise WaveDownloadError(f"Failed to download {url}: {exc}") from exc
def _build_url(self, base_url: str, path: str) -> Optional[str]:
if base_url == AZURE_URL:
token = self._get_azure_sas_token()
if not token:
return None
token = token if token.startswith("?") else f"?{token}"
return f"{base_url}/{path}{token}"
return f"{base_url}/{path}"
def _get_azure_sas_token(self) -> Optional[str]:
if self._azure_sas_token:
return self._azure_sas_token
env_token = os.environ.get("GEFS_AZURE_SAS_TOKEN")
if env_token:
self._azure_sas_token = env_token.strip()
logger.info("Using Azure SAS token from environment variable.")
return self._azure_sas_token
if self._azure_token_checked:
return None
self._azure_token_checked = True
try:
response = self.session.get(AZURE_SAS_ENDPOINT, timeout=20)
response.raise_for_status()
token = response.json().get("token")
if token:
self._azure_sas_token = token.strip()
logger.info("Fetched Azure SAS token from Planetary Computer API.")
else:
logger.warning("Azure SAS endpoint response did not include a token.")
except requests.RequestException as exc:
logger.warning("Failed to fetch Azure SAS token: %s", exc)
return self._azure_sas_token
def _invalidate_azure_token(self) -> None:
if self._azure_sas_token:
logger.info("Invalidating cached Azure SAS token.")
self._azure_sas_token = None
self._azure_token_checked = False
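    # A cycle counts as complete only when every expected forecast hour exists
    # locally as a non-empty GRIB2 file.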
def _cycle_complete(self, cycle: Cycle) -> bool:
target_dir = self.data_root / cycle.date / cycle.cycle
if not target_dir.exists():
return False
expected = {f"gefswave.t{cycle.cycle}z.c00.global.0p25.f{hour:03d}.grib2" for hour in FORECAST_HOURS}
existing = {path.name for path in target_dir.glob("*.grib2") if path.stat().st_size > 0}
return expected.issubset(existing)
def _read_state(self) -> Optional[Dict]:
path = self.data_root / STATE_FILENAME
if not path.exists():
return None
try:
return json.loads(path.read_text())
except json.JSONDecodeError:
return None
def _write_state(self, state: Dict) -> None:
path = self.data_root / STATE_FILENAME
path.write_text(json.dumps(state, indent=2))
class WaveDownloadManager:
"""Simple threaded wrapper to keep download work off the request thread."""
def __init__(self, data_root: Path) -> None:
self.downloader = WaveDownloader(data_root)
self._lock = threading.Lock()
self._worker: Optional[threading.Thread] = None
self._status: Dict = {"status": "idle"}
def trigger_refresh(self) -> None:
with self._lock:
if self._worker and self._worker.is_alive():
logger.info("Refresh already in progress; ignoring new trigger.")
return
logger.info("Starting GEFS wave refresh thread.")
self._status = {"status": "running"}
self._worker = threading.Thread(target=self._run_refresh, daemon=True)
self._worker.start()
def status(self) -> Dict:
with self._lock:
status = dict(self._status)
state = self.downloader.current_state()
if state:
status["latest_state"] = state
return status
def _run_refresh(self) -> None:
try:
result = self.downloader.ensure_latest_cycle()
with self._lock:
self._status = {"status": "ready", "latest_state": result}
logger.info("GEFS wave refresh completed successfully.")
except Exception as exc:
with self._lock:
self._status = {"status": "error", "message": str(exc)}
logger.exception("GEFS wave refresh failed: %s", exc)
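
if __name__ == "__main__":
    # Minimal usage sketch. The data directory below is an illustrative
    # assumption, not a path the module prescribes. The download runs
    # synchronously via WaveDownloader rather than WaveDownloadManager so the
    # process does not exit before the manager's daemon thread finishes.
    downloader = WaveDownloader(Path("./gefs_wave_data"))
    state = downloader.ensure_latest_cycle()
    print(json.dumps(state, indent=2))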