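"""Download helpers for the GEFS wave control-member (c00) dataset.

Discovers the most recent available GEFS cycle, downloads the 0.25-degree
gridded wave GRIB2 file for every forecast hour, and packages the result
into a single tarball. Azure Blob Storage (authorized via a Planetary
Computer SAS token) is tried first, with NOMADS as the fallback source.
"""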
import json
import logging
import os
import tarfile
import threading
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Dict, Iterable, List, Optional

import requests


NOMADS_URL = "https://nomads.ncep.noaa.gov/pub/data/nccf/com/gens/prod"
AZURE_URL = "https://noaagefs.blob.core.windows.net/gefs"
BASE_URLS = [AZURE_URL, NOMADS_URL]
AZURE_SAS_ENDPOINT = "https://planetarycomputer.microsoft.com/api/sas/v1/token/noaagefs/gefs"
FORECAST_HOURS: List[int] = list(range(0, 387, 3))  # f000 through f384, every 3 hours (129 files)
STATE_FILENAME = "state.json"
TARBALL_FILENAME = "gefswave-wave0-latest.tar.gz"

logger = logging.getLogger("gefs_wave")
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
    logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.propagate = False


@dataclass
class Cycle:
    date: str  # YYYYMMDD
    cycle: str  # HH

    @property
    def label(self) -> str:
        return f"{self.date} {self.cycle}Z"

    @property
    def directory(self) -> str:
        return f"gefs.{self.date}/{self.cycle}"


class WaveDownloadError(Exception):
    """Raised when a download or discovery step fails."""


class WaveDownloader:
    """Handles discovery and download of the GEFS wave control-member dataset."""

    def __init__(self, data_root: Path, session: Optional[requests.Session] = None) -> None:
        self.data_root = data_root
        self.data_root.mkdir(parents=True, exist_ok=True)
        self.session = session or requests.Session()
        self._azure_sas_token: Optional[str] = None
        self._azure_token_checked = False

    def ensure_latest_cycle(self) -> Dict:
        """Ensure the latest available control-member dataset is present locally."""
        logger.info("Ensuring latest GEFS wave control-member cycle for %s", self.data_root)
        cycle = self._find_latest_cycle()
        if not cycle:
            logger.error("No GEFS wave cycle found in the recent window.")
            raise WaveDownloadError("Could not locate a recent GEFS wave cycle.")

        state = self._read_state()
        if (
            state
            and state.get("cycle") == cycle.label
            and self._cycle_complete(cycle)
            and (self.data_root / TARBALL_FILENAME).exists()
        ):
            logger.info("Cycle %s already cached; using existing files.", cycle.label)
            return state

        logger.info("Downloading cycle %s", cycle.label)
        files = self._download_cycle(cycle)
        tarball_path = self._build_tarball(cycle, files)
        logger.info("Packaged cycle %s into tarball %s", cycle.label, tarball_path.name)
        state = {
            "cycle": cycle.label,
            "files": [str(path.relative_to(self.data_root)) for path in files],
            "tarball": str(tarball_path.relative_to(self.data_root)),
            "updated_at": datetime.now(timezone.utc).isoformat(),
        }
        self._write_state(state)
        return state

    def current_state(self) -> Optional[Dict]:
        return self._read_state()

    def _find_latest_cycle(self, max_days_back: int = 9) -> Optional[Cycle]:
        now = datetime.now(timezone.utc)
        candidate_days = [(now - timedelta(days=days)).strftime("%Y%m%d") for days in range(max_days_back + 1)]
        cycles = ["18", "12", "06", "00"]

        for day in candidate_days:
            for cycle in cycles:
                logger.info("Probing cycle %s %sz", day, cycle)
                if self._cycle_available(day, cycle):
                    logger.info("Cycle %s %sz is available.", day, cycle)
                    return Cycle(date=day, cycle=cycle)
        logger.warning("No GEFS wave cycle found in the past %d days.", max_days_back + 1)
        return None

    def _cycle_available(self, day: str, cycle: str) -> bool:
        filename = f"gefswave.t{cycle}z.c00.global.0p25.f000.grib2"
        for base_url in BASE_URLS:
            path = f"gefs.{day}/{cycle}/wave/gridded/{filename}"
            url = self._build_url(base_url, path)
            if not url:
                logger.debug("Skipping %s for cycle %s %sz due to missing credentials.", base_url, day, cycle)
                continue
            try:
                response = self.session.head(url, timeout=20)
                if response.ok:
                    logger.info("Cycle %s %sz file found on %s", day, cycle, base_url)
                    return True
                if response.status_code == 404:
                    logger.debug("Cycle %s %sz missing on %s (404).", day, cycle, base_url)
                elif response.status_code in (401, 403):
                    logger.warning(
                        "Cycle %s %sz request to %s returned %s; will retry after refreshing credentials.",
                        day,
                        cycle,
                        base_url,
                        response.status_code,
                    )
                    if base_url == AZURE_URL:
                        self._invalidate_azure_token()
                else:
                    logger.debug(
                        "Cycle %s %sz head request on %s returned HTTP %s.",
                        day,
                        cycle,
                        base_url,
                        response.status_code,
                    )
            except requests.RequestException as exc:
                logger.debug("Error probing cycle %s %sz on %s: %s", day, cycle, base_url, exc)
        return False

    def _download_cycle(self, cycle: Cycle) -> List[Path]:
        files: List[Path] = []
        target_dir = self.data_root / cycle.date / cycle.cycle
        target_dir.mkdir(parents=True, exist_ok=True)

        for hour in FORECAST_HOURS:
            filename = f"gefswave.t{cycle.cycle}z.c00.global.0p25.f{hour:03d}.grib2"
            destination = target_dir / filename
            if destination.exists() and destination.stat().st_size > 0:
                logger.debug("File %s already exists; skipping download.", destination)
                files.append(destination)
                continue
            logger.info("Downloading %s", filename)
            self._download_with_fallback(cycle, filename, destination)
            files.append(destination)
        return files

    def _build_tarball(self, cycle: Cycle, files: Iterable[Path]) -> Path:
        tarball_path = self.data_root / TARBALL_FILENAME
        # Write to a temporary file first so an interrupted run never leaves
        # a truncated tarball at the published path.
        tmp_path = tarball_path.with_suffix(".tmp")
        with tarfile.open(tmp_path, "w:gz") as tar:
            for file_path in files:
                tar.add(file_path, arcname=file_path.relative_to(self.data_root))
        tmp_path.rename(tarball_path)
        return tarball_path

    def _download_with_fallback(self, cycle: Cycle, filename: str, destination: Path) -> None:
        errors = []
        for base_url in BASE_URLS:
            path = f"{cycle.directory}/wave/gridded/{filename}"
            url = self._build_url(base_url, path)
            if not url:
                logger.debug("Skipping download from %s for %s due to missing URL.", base_url, filename)
                continue
            try:
                self._stream_to_file(url, destination, base_url)
                return
            except WaveDownloadError as exc:
                errors.append(str(exc))
                logger.warning("Failed to download %s from %s (%s).", filename, base_url, exc)
        raise WaveDownloadError(f"All download attempts failed for {filename}: {'; '.join(errors)}")

    def _stream_to_file(self, url: str, destination: Path, base_url: str) -> None:
        tmp_path = destination.with_suffix(destination.suffix + ".part")
        try:
            with self.session.get(url, stream=True, timeout=120) as response:
                # A 401/403 from Azure usually means the SAS token expired;
                # drop the cached token so the next attempt fetches a fresh one.
                if response.status_code in (401, 403) and base_url == AZURE_URL:
                    self._invalidate_azure_token()
                response.raise_for_status()
                with tmp_path.open("wb") as handle:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        if not chunk:
                            continue
                        handle.write(chunk)
            tmp_path.rename(destination)
        except requests.RequestException as exc:
            if tmp_path.exists():
                tmp_path.unlink()
            logger.error("Failed to download %s: %s", url, exc)
            raise WaveDownloadError(f"Failed to download {url}: {exc}") from exc

    def _build_url(self, base_url: str, path: str) -> Optional[str]:
        if base_url == AZURE_URL:
            token = self._get_azure_sas_token()
            if not token:
                return None
            token = token if token.startswith("?") else f"?{token}"
            return f"{base_url}/{path}{token}"
        return f"{base_url}/{path}"

    def _get_azure_sas_token(self) -> Optional[str]:
        if self._azure_sas_token:
            return self._azure_sas_token

        env_token = os.environ.get("GEFS_AZURE_SAS_TOKEN")
        if env_token:
            self._azure_sas_token = env_token.strip()
            logger.info("Using Azure SAS token from environment variable.")
            return self._azure_sas_token

        if self._azure_token_checked:
            return None

        self._azure_token_checked = True
        try:
            response = self.session.get(AZURE_SAS_ENDPOINT, timeout=20)
            response.raise_for_status()
            token = response.json().get("token")
            if token:
                self._azure_sas_token = token.strip()
                logger.info("Fetched Azure SAS token from Planetary Computer API.")
            else:
                logger.warning("Azure SAS endpoint response did not include a token.")
        except requests.RequestException as exc:
            logger.warning("Failed to fetch Azure SAS token: %s", exc)
        return self._azure_sas_token

    def _invalidate_azure_token(self) -> None:
        if self._azure_sas_token:
            logger.info("Invalidating cached Azure SAS token.")
        self._azure_sas_token = None
        self._azure_token_checked = False

    def _cycle_complete(self, cycle: Cycle) -> bool:
        target_dir = self.data_root / cycle.date / cycle.cycle
        if not target_dir.exists():
            return False
        expected = {f"gefswave.t{cycle.cycle}z.c00.global.0p25.f{hour:03d}.grib2" for hour in FORECAST_HOURS}
        existing = {path.name for path in target_dir.glob("*.grib2") if path.stat().st_size > 0}
        return expected.issubset(existing)

    def _read_state(self) -> Optional[Dict]:
        path = self.data_root / STATE_FILENAME
        if not path.exists():
            return None
        try:
            return json.loads(path.read_text())
        except json.JSONDecodeError:
            return None

    def _write_state(self, state: Dict) -> None:
        path = self.data_root / STATE_FILENAME
        path.write_text(json.dumps(state, indent=2))


class WaveDownloadManager:
    """Simple threaded wrapper to keep download work off the request thread."""

    def __init__(self, data_root: Path) -> None:
        self.downloader = WaveDownloader(data_root)
        self._lock = threading.Lock()
        self._worker: Optional[threading.Thread] = None
        self._status: Dict = {"status": "idle"}

    def trigger_refresh(self) -> None:
        with self._lock:
            if self._worker and self._worker.is_alive():
                logger.info("Refresh already in progress; ignoring new trigger.")
                return
            logger.info("Starting GEFS wave refresh thread.")
            self._status = {"status": "running"}
            self._worker = threading.Thread(target=self._run_refresh, daemon=True)
            self._worker.start()

    def status(self) -> Dict:
        with self._lock:
            status = dict(self._status)
        state = self.downloader.current_state()
        if state:
            status["latest_state"] = state
        return status

    def _run_refresh(self) -> None:
        try:
            result = self.downloader.ensure_latest_cycle()
            with self._lock:
                self._status = {"status": "ready", "latest_state": result}
            logger.info("GEFS wave refresh completed successfully.")
        except Exception as exc:
            with self._lock:
                self._status = {"status": "error", "message": str(exc)}
            logger.exception("GEFS wave refresh failed: %s", exc)
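

# Minimal usage sketch (illustrative only, not part of the module's API):
# start a background refresh via WaveDownloadManager and poll until it
# settles. The data directory and poll interval here are assumptions for
# the example; for a blocking one-shot download, call
# WaveDownloader(...).ensure_latest_cycle() directly instead.
if __name__ == "__main__":
    import time

    manager = WaveDownloadManager(Path("./gefs_wave_data"))
    manager.trigger_refresh()
    while manager.status().get("status") == "running":
        time.sleep(5)
    print(json.dumps(manager.status(), indent=2))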