nakas committed on
Commit
1f9c688
·
1 Parent(s): 9a0a838

Handle Azure SAS tokens and credential refresh

Browse files
Files changed (1) hide show
  1. gefs_wave.py +71 -4
gefs_wave.py CHANGED
@@ -1,5 +1,6 @@
1
  import json
2
  import logging
 
3
  import tarfile
4
  import threading
5
  from dataclasses import dataclass
@@ -13,6 +14,7 @@ import requests
13
  NOMADS_URL = "https://nomads.ncep.noaa.gov/pub/data/nccf/com/gens/prod"
14
  AZURE_URL = "https://noaagefs.blob.core.windows.net/gefs"
15
  BASE_URLS = [AZURE_URL, NOMADS_URL]
 
16
  FORECAST_HOURS: List[int] = list(range(0, 387, 3))
17
  STATE_FILENAME = "state.json"
18
  TARBALL_FILENAME = "gefswave-wave0-latest.tar.gz"
@@ -51,6 +53,8 @@ class WaveDownloader:
51
  self.data_root = data_root
52
  self.data_root.mkdir(parents=True, exist_ok=True)
53
  self.session = session or requests.Session()
 
 
54
 
55
  def ensure_latest_cycle(self) -> Dict:
56
  """Ensure the latest available control-member dataset is present locally."""
@@ -98,7 +102,11 @@ class WaveDownloader:
98
  def _cycle_available(self, day: str, cycle: str) -> bool:
99
  filename = f"gefswave.t{cycle}z.c00.global.0p25.f000.grib2"
100
  for base_url in BASE_URLS:
101
- url = f"{base_url}/gefs.{day}/{cycle}/wave/gridded/{filename}"
 
 
 
 
102
  try:
103
  response = self.session.head(url, timeout=20)
104
  if response.ok:
@@ -106,6 +114,16 @@ class WaveDownloader:
106
  return True
107
  if response.status_code == 404:
108
  logger.debug("Cycle %s %sz missing on %s (404).", day, cycle, base_url)
 
 
 
 
 
 
 
 
 
 
109
  else:
110
  logger.debug(
111
  "Cycle %s %sz head request on %s returned HTTP %s.",
@@ -145,19 +163,26 @@ class WaveDownloader:
145
  def _download_with_fallback(self, cycle: Cycle, filename: str, destination: Path) -> None:
146
  errors = []
147
  for base_url in BASE_URLS:
148
- url = f"{base_url}/{cycle.directory}/wave/gridded/{filename}"
 
 
 
 
149
  try:
150
- self._stream_to_file(url, destination)
151
  return
152
  except WaveDownloadError as exc:
153
  errors.append(str(exc))
154
  logger.warning("Failed to download %s from %s (%s).", filename, base_url, exc)
155
  raise WaveDownloadError(f"All download attempts failed for {filename}: {'; '.join(errors)}")
156
 
157
- def _stream_to_file(self, url: str, destination: Path) -> None:
158
  tmp_path = destination.with_suffix(destination.suffix + ".part")
159
  try:
160
  with self.session.get(url, stream=True, timeout=120) as response:
 
 
 
161
  response.raise_for_status()
162
  with tmp_path.open("wb") as handle:
163
  for chunk in response.iter_content(chunk_size=1 << 20):
@@ -171,6 +196,48 @@ class WaveDownloader:
171
  logger.error("Failed to download %s: %s", url, exc)
172
  raise WaveDownloadError(f"Failed to download {url}: {exc}") from exc
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  def _cycle_complete(self, cycle: Cycle) -> bool:
175
  target_dir = self.data_root / cycle.date / cycle.cycle
176
  if not target_dir.exists():
 
1
  import json
2
  import logging
3
+ import os
4
  import tarfile
5
  import threading
6
  from dataclasses import dataclass
 
14
  NOMADS_URL = "https://nomads.ncep.noaa.gov/pub/data/nccf/com/gens/prod"
15
  AZURE_URL = "https://noaagefs.blob.core.windows.net/gefs"
16
  BASE_URLS = [AZURE_URL, NOMADS_URL]
17
+ AZURE_SAS_ENDPOINT = "https://planetarycomputer.microsoft.com/api/sas/v1/token/noaagefs/gefs"
18
  FORECAST_HOURS: List[int] = list(range(0, 387, 3))
19
  STATE_FILENAME = "state.json"
20
  TARBALL_FILENAME = "gefswave-wave0-latest.tar.gz"
 
53
  self.data_root = data_root
54
  self.data_root.mkdir(parents=True, exist_ok=True)
55
  self.session = session or requests.Session()
56
+ self._azure_sas_token: Optional[str] = None
57
+ self._azure_token_checked = False
58
 
59
  def ensure_latest_cycle(self) -> Dict:
60
  """Ensure the latest available control-member dataset is present locally."""
 
102
  def _cycle_available(self, day: str, cycle: str) -> bool:
103
  filename = f"gefswave.t{cycle}z.c00.global.0p25.f000.grib2"
104
  for base_url in BASE_URLS:
105
+ path = f"gefs.{day}/{cycle}/wave/gridded/{filename}"
106
+ url = self._build_url(base_url, path)
107
+ if not url:
108
+ logger.debug("Skipping %s for cycle %s %sz due to missing credentials.", base_url, day, cycle)
109
+ continue
110
  try:
111
  response = self.session.head(url, timeout=20)
112
  if response.ok:
 
114
  return True
115
  if response.status_code == 404:
116
  logger.debug("Cycle %s %sz missing on %s (404).", day, cycle, base_url)
117
+ elif response.status_code in (401, 403):
118
+ logger.warning(
119
+ "Cycle %s %sz request to %s returned %s; will retry after refreshing credentials.",
120
+ day,
121
+ cycle,
122
+ base_url,
123
+ response.status_code,
124
+ )
125
+ if base_url == AZURE_URL:
126
+ self._invalidate_azure_token()
127
  else:
128
  logger.debug(
129
  "Cycle %s %sz head request on %s returned HTTP %s.",
 
163
  def _download_with_fallback(self, cycle: Cycle, filename: str, destination: Path) -> None:
164
  errors = []
165
  for base_url in BASE_URLS:
166
+ path = f"{cycle.directory}/wave/gridded/{filename}"
167
+ url = self._build_url(base_url, path)
168
+ if not url:
169
+ logger.debug("Skipping download from %s for %s due to missing URL.", base_url, filename)
170
+ continue
171
  try:
172
+ self._stream_to_file(url, destination, base_url)
173
  return
174
  except WaveDownloadError as exc:
175
  errors.append(str(exc))
176
  logger.warning("Failed to download %s from %s (%s).", filename, base_url, exc)
177
  raise WaveDownloadError(f"All download attempts failed for {filename}: {'; '.join(errors)}")
178
 
179
+ def _stream_to_file(self, url: str, destination: Path, base_url: str) -> None:
180
  tmp_path = destination.with_suffix(destination.suffix + ".part")
181
  try:
182
  with self.session.get(url, stream=True, timeout=120) as response:
183
+ if response.status_code in (401, 403) and base_url == AZURE_URL:
184
+ self._invalidate_azure_token()
185
+ response.raise_for_status()
186
  response.raise_for_status()
187
  with tmp_path.open("wb") as handle:
188
  for chunk in response.iter_content(chunk_size=1 << 20):
 
196
  logger.error("Failed to download %s: %s", url, exc)
197
  raise WaveDownloadError(f"Failed to download {url}: {exc}") from exc
198
 
199
def _build_url(self, base_url: str, path: str) -> Optional[str]:
    """Compose the request URL for ``path`` under ``base_url``.

    For the Azure base URL the SAS token is appended as the query string;
    any other base URL is simply joined with the path.

    Returns:
        The full URL, or ``None`` when the Azure base is requested but
        ``_get_azure_sas_token`` yields no token.
    """
    plain_url = f"{base_url}/{path}"
    if base_url != AZURE_URL:
        return plain_url

    sas_token = self._get_azure_sas_token()
    if not sas_token:
        return None
    # The token may or may not already carry its leading "?".
    query = sas_token if sas_token.startswith("?") else f"?{sas_token}"
    return f"{plain_url}{query}"
207
+
208
def _get_azure_sas_token(self) -> Optional[str]:
    """Return a SAS token for the Azure container, fetching one if needed.

    Lookup order:
      1. the token cached on this instance;
      2. the ``GEFS_AZURE_SAS_TOKEN`` environment variable;
      3. a GET to ``AZURE_SAS_ENDPOINT``, attempted only while
         ``_azure_token_checked`` is false.

    Returns:
        The token with surrounding whitespace stripped, or ``None`` when no
        usable token could be obtained.
    """
    if self._azure_sas_token:
        return self._azure_sas_token

    # A blank or whitespace-only variable counts as unset so that it does
    # not block the API fallback below.
    env_token = (os.environ.get("GEFS_AZURE_SAS_TOKEN") or "").strip()
    if env_token:
        self._azure_sas_token = env_token
        logger.info("Using Azure SAS token from environment variable.")
        return env_token

    if self._azure_token_checked:
        return None

    # Set the flag before the request so a failed fetch is not repeated on
    # every subsequent URL build.
    self._azure_token_checked = True
    try:
        response = self.session.get(AZURE_SAS_ENDPOINT, timeout=20)
        response.raise_for_status()
        payload = response.json()
    except (requests.RequestException, ValueError) as exc:
        # ValueError covers a non-JSON body on requests releases whose JSON
        # decode error is not a RequestException subclass.
        logger.warning("Failed to fetch Azure SAS token: %s", exc)
        return None

    # Guard against a payload that is not an object or a token that is not text.
    token = payload.get("token") if isinstance(payload, dict) else None
    token = token.strip() if isinstance(token, str) else ""
    if not token:
        logger.warning("Azure SAS endpoint response did not include a token.")
        return None

    self._azure_sas_token = token
    logger.info("Fetched Azure SAS token from Planetary Computer API.")
    return token
234
+
235
def _invalidate_azure_token(self) -> None:
    """Drop the cached Azure SAS token and clear the "already checked" flag.

    An info message is logged only when a token was actually cached.
    """
    had_token = bool(self._azure_sas_token)
    self._azure_sas_token = None
    self._azure_token_checked = False
    if had_token:
        logger.info("Invalidating cached Azure SAS token.")
240
+
241
  def _cycle_complete(self, cycle: Cycle) -> bool:
242
  target_dir = self.data_root / cycle.date / cycle.cycle
243
  if not target_dir.exists():