Ayoub committed on
Commit
ce5153c
·
1 Parent(s): 579cea9

add metrics computation

Browse files
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
 
3
  import gradio as gr
4
  import numpy as np
 
5
  import torch
6
  from patchify import patchify, unpatchify
7
  from phasepack import phasecong
@@ -12,6 +13,8 @@ from skimage.feature import canny
12
  from skimage.filters import sato
13
 
14
  from src.unet import UNet
 
 
15
 
16
  # ------------------------------------------------------------
17
  # Device
@@ -103,6 +106,31 @@ def sato_fn(img, x, sigmas):
103
  return np.float64(sato(gray, sato_sigmas_list[sigmas]) < x)
104
 
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  # ------------------------------------------------------------
107
  # Deep learning model loading
108
  # ------------------------------------------------------------
@@ -309,7 +337,7 @@ with gr.Blocks(title="Fractex2D Segmentation") as demo:
309
  gr.Markdown(
310
  """
311
  ## Canny edge detection
312
- Canny edge detection (scikit-image) with normalized thresholds https://doi.org/10.1109/TPAMI.1986.4767851.
313
  - **sigma** controls Gaussian smoothing
314
  - **lt / ht** are low/high thresholds in the range 0–1
315
  """
@@ -395,6 +423,59 @@ with gr.Blocks(title="Fractex2D Segmentation") as demo:
395
  outputs=pc_out,
396
  )
397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
  # ------------------------------------------------------------
399
  # Extra reference
400
  # ------------------------------------------------------------
 
2
 
3
  import gradio as gr
4
  import numpy as np
5
+ import pandas as pd
6
  import torch
7
  from patchify import patchify, unpatchify
8
  from phasepack import phasecong
 
13
  from skimage.filters import sato
14
 
15
  from src.unet import UNet
16
+ from src.train import eval_single
17
+ from src.dataset_benchm import expand_wide_fractures_gt, dilate_labels
18
 
19
  # ------------------------------------------------------------
20
  # Device
 
106
  return np.float64(sato(gray, sato_sigmas_list[sigmas]) < x)
107
 
108
 
109
# ------------------------------------------------------------
# Compute metrics
# ------------------------------------------------------------
def compute_metrics_ui(gt_img, pred_img, threshold):
    """Compute segmentation metrics between a GT mask and a prediction.

    Args:
        gt_img: ground-truth image (numpy array from gr.Image) or None.
        pred_img: predicted mask (numpy array from gr.Image) or None.
        threshold: binarisation threshold in [0, 1] applied to the prediction.

    Returns:
        One-row pandas DataFrame of metrics rounded to 3 decimals, or
        None when either input image is missing.
    """
    if gt_img is None or pred_img is None:
        return None

    gt = np.array(gt_img, dtype=np.uint8)
    pred = np.array(pred_img, dtype=np.uint8)

    # Keep a single channel if the inputs are RGB(A).
    if gt.ndim == 3:
        gt = gt[..., 0]
    if pred.ndim == 3:
        pred = pred[..., 0]

    # Thicken the 1px-wide GT annotation to tolerate small misalignments.
    gt = dilate_labels(gt)

    # Normalise to [0,1]: the threshold slider is in [0,1] and the
    # regression metrics (MSE/PSNR/SSIM) assume data_range=1.0, so the
    # raw uint8 [0,255] images must be rescaled first (previously this
    # step was documented but never performed, which made the threshold
    # compare 0..255 pixel values against a 0..1 slider value).
    gt = gt.astype(np.float64) / 255.0
    pred = pred.astype(np.float64) / 255.0

    metrics = eval_single(gt, pred, threshold=threshold, device=device)

    return pd.DataFrame([metrics]).round(3)
134
  # ------------------------------------------------------------
135
  # Deep learning model loading
136
  # ------------------------------------------------------------
 
337
  gr.Markdown(
338
  """
339
  ## Canny edge detection
340
+ Canny edge detection (scikit-image) with normalised thresholds https://doi.org/10.1109/TPAMI.1986.4767851.
341
  - **sigma** controls Gaussian smoothing
342
  - **lt / ht** are low/high thresholds in the range 0–1
343
  """
 
423
  outputs=pc_out,
424
  )
425
 
426
    # ------------------------------------------------------------
    # TAB 5 — METRICS
    # ------------------------------------------------------------
    with gr.Tab("Metrics computation"):
        # Intro text shown at the top of the tab.
        gr.Markdown(
            """
            ## Segmentation Metrics
            Compute quantitative metrics between a **prediction** and a **ground-truth** (1px wide annotation).
            Both images must be aligned and have the same resolution.
            """
        )

        # Input images: GT annotation and predicted mask (numpy arrays).
        with gr.Row():
            gt_input = gr.Image(label="Ground truth", type="numpy")
            pred_input = gr.Image(label="Prediction", type="numpy")

        # Threshold used to binarise the (possibly soft) prediction.
        with gr.Row():
            thresh = gr.Slider(
                0, 1,
                value=0.1,
                step=0.01,
                label="Binarisation threshold"
            )

        # The two empty flanking columns centre the button in the row.
        with gr.Row():
            with gr.Column(scale=1):
                pass
            metric_btn = gr.Button("Compute metrics")
            with gr.Column(scale=1):
                pass

        # Results table; compute_metrics_ui returns one row of metrics.
        metric_table = gr.Dataframe(
            headers=[
                "mse", "psnr", "ssim", "ae",
                "acc", "prec", "rec", "spec",
                "f1", "dice", "iou", "ck", "roc_auc"
            ],
            label="Metrics (single image pair)"
        )

        # Wire the button to the metric computation.
        metric_btn.click(
            fn=compute_metrics_ui,
            inputs=[gt_input, pred_input, thresh],
            outputs=metric_table,
        )

        # Pre-filled example image pair for quick testing.
        gr.Examples(
            examples=[
                ["examples/kl5-s3_1-gt.png", "examples/unet-p1_pred_kl5-s3_1.png", 0.1],
            ],
            inputs=[gt_input, pred_input, thresh],
        )
479
  # ------------------------------------------------------------
480
  # Extra reference
481
  # ------------------------------------------------------------
examples/kl5-s3_1-gt.png ADDED

Git LFS Details

  • SHA256: 3571d3b335a15fe5c7686eced6e03b4245a26c927d1565eb8743190fd4dfc96c
  • Pointer size: 130 Bytes
  • Size of remote file: 21.2 kB
examples/unet-p1_pred_kl5-s3_1.png ADDED

Git LFS Details

  • SHA256: c4a1a16497d5ac2b559c4aa0032006657abca3d273b67d90fde75bfc0e56189c
  • Pointer size: 130 Bytes
  • Size of remote file: 66.2 kB
src/dataset_benchm.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import random
3
+ from typing import List, Optional, Sequence, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchvision.transforms.v2 as t
8
+ import torchvision.transforms.v2.functional as TF
9
+ from skimage import io
10
+ from skimage.filters.rank import maximum
11
+ from skimage.measure import label
12
+ from skimage.morphology import binary_dilation, dilation, disk
13
+ from skimage.segmentation import expand_labels
14
+ from torch.utils.data import ConcatDataset, DataLoader, Dataset
15
+
16
+
17
# -------------------------
# Label pre-processing
# -------------------------
def expand_wide_fractures_gt(
    img: np.ndarray,
    gt: np.ndarray,
    disk_size: int = 2,
    thresh: int = 30,
    gt_thresh: int = 100,
    gt_ext: str = "png",
) -> np.ndarray:
    """Grow a ground-truth fracture mask to cover adjacent wide/dark fractures.

    The green channel is max-filtered to emphasise large dark regions,
    then thresholded and dilated into a candidate mask. Candidate
    components that touch the original GT are merged into it; all other
    candidate components are discarded.

    Args:
        img: HxWxC image; the green channel (index 1) is the grayscale proxy.
        gt: HxW ground-truth mask ([0..1] for tif GT, [0..255] otherwise).
        disk_size: radius of the morphological structuring element.
        thresh: darkness threshold on the max-filtered green channel.
        gt_thresh: threshold deciding which pixels belong to the original GT.
        gt_ext: GT file extension; a "tif" GT is rescaled by 255 when combined.

    Returns:
        Expanded GT mask as np.uint8 (values 0 or 255).

    Raises:
        ValueError: if `img` has fewer than 2 channels.
    """
    if img.ndim < 3 or img.shape[2] < 2:
        raise ValueError("img must have at least 2 channels (uses green channel).")

    green = img[..., 1].astype(np.uint8)

    # Emphasise wide dark features, then threshold and thicken them.
    filtered = maximum(green, disk(disk_size))
    dark_candidates = binary_dilation(filtered < thresh, disk(disk_size))

    original = gt > gt_thresh
    merged = np.logical_or(dark_candidates, original)

    # Drop every connected component that never touches the original GT.
    components, count = label(merged, connectivity=1, return_num=True)
    for component in range(1, count + 1):
        pixels = components == component
        if not original[pixels].any():
            merged[pixels] = False

    # Combine with the GT at uint8 [0,255]; tif GT is stored in [0,1].
    merged_u8 = np.array(merged * 255, dtype=np.uint8)
    if "tif" in gt_ext:
        base = np.array(gt * 255, dtype=np.uint8)
    else:
        base = np.array(gt, dtype=np.uint8)
    return base | merged_u8
def dilate_labels(image: np.ndarray) -> np.ndarray:
    """Soften label boundaries via multi-scale dilation rings.

    First fills tiny gaps with ``expand_labels``, then builds three
    exclusive dilation rings (radii 2, 5, 7) and blends them onto the
    labels with decreasing weights (1/3, 1/5, 1/9).

    Args:
        image: integer-labeled image or binary mask (HxW).

    Returns:
        np.uint8 array (HxW) with blended/smoothed label boundaries.
    """
    grown = expand_labels(image, distance=2)

    # Exclusive rings around the labels at increasing radii.
    ring_near = dilation(grown, disk(2)) ^ grown
    ring_mid = dilation(grown, disk(5)) ^ ring_near ^ grown
    ring_far = dilation(grown, disk(7)) ^ ring_mid ^ ring_near ^ grown

    # Blend rings onto the labels with decreasing weight.
    smoothed = grown + ring_near / 3.0 + ring_mid / 5.0 + ring_far / 9.0
    return np.array(smoothed, dtype=np.uint8)
107
# -------------------------
# Augmentation helpers
# -------------------------
def _apply_random_flips(image: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply the same random horizontal/vertical flip (p=0.5 each) to image and mask."""
    # Draw order matches the original so seeded runs stay reproducible.
    if random.random() > 0.5:
        image = TF.hflip(image)
        mask = TF.hflip(mask)
    if random.random() > 0.5:
        image = TF.vflip(image)
        mask = TF.vflip(mask)
    return image, mask
118
def _apply_random_photometric_augmentations(image: torch.Tensor, prob_config: Optional[dict] = None) -> torch.Tensor:
    """
    Randomly apply blur / brightness / contrast / saturation jitter.

    Each augmentation fires independently with the probability listed in
    `prob_config`. When the image has 4 channels, the extra channel
    (e.g. DEM) is left untouched: only the first three (RGB) channels
    are augmented and the extra channel is re-attached afterwards.
    """
    if prob_config is None:
        prob_config = {
            "gaussian_blur": 0.05,
            "darken_low": 0.05,
            "brighten": 0.15,
            "contrast": 0.05,
            "saturation": 0.05,
        }

    keep_extra = image.shape[0] == 4
    color = image[:3] if keep_extra else image

    # NOTE: the sequence of random draws mirrors the original function so
    # that seeded pipelines produce identical augmentations.
    if random.random() < prob_config["gaussian_blur"]:
        blur_sigma = random.uniform(0.1, 2.0)
        color = TF.gaussian_blur(color, kernel_size=5, sigma=blur_sigma)

    if random.random() < prob_config["darken_low"]:
        color = TF.adjust_brightness(color, random.uniform(0.7, 0.9))

    if random.random() < prob_config["brighten"]:
        color = TF.adjust_brightness(color, random.uniform(1.1, 1.7))

    if random.random() < prob_config["contrast"]:
        color = TF.adjust_contrast(color, random.uniform(0.7, 1.5))

    if random.random() < prob_config["saturation"]:
        color = TF.adjust_saturation(color, random.uniform(0.7, 1.5))

    if keep_extra:
        return torch.cat([color, image[3:]], dim=0)
    return color
# -------------------------
# Base dataset utilities
# -------------------------
def _read_image(path: Path) -> np.ndarray:
    """Read an image with skimage.io and return it as uint8.

    Float images stored in [0, 1] are rescaled to [0, 255] before the
    cast — a plain ``astype(np.uint8)`` would truncate them to all
    zeros. Other non-uint8 dtypes are cast directly. This mirrors the
    conversion performed by `_read_mask` for consistency.
    """
    arr = io.imread(str(path))
    if arr.dtype != np.uint8:
        arr = (arr * 255).astype(np.uint8) if arr.max() <= 1.0 else arr.astype(np.uint8)
    return arr
def _read_mask(path: Path) -> np.ndarray:
    """Read a mask image and return it as uint8 in the 0..255 range."""
    arr = io.imread(str(path))
    if arr.dtype == np.uint8:
        return arr
    # Floats in [0,1] are rescaled; everything else is cast directly.
    if arr.max() <= 1.0:
        return (arr * 255).astype(np.uint8)
    return arr.astype(np.uint8)
# -------------------------
# Dataset classes
# -------------------------
class BaseCrackDataset(Dataset):
    """
    Minimal common functionality for the specific dataset wrappers used downstream.

    Subclasses must provide:
        - self.images (list[Path])
        - self.masks (list[Path])
        - optional self.dems (list[Path]) when in_channels==4
    """

    def __init__(
        self,
        images: Sequence[Path],
        masks: Sequence[Path],
        dem_paths: Optional[Sequence[Path]] = None,
        topo: bool = False,
        transform: bool = False,
        expand: bool = True,
        dilate: bool = True,
        in_channels: int = 3,
    ):
        # Image/mask path lists; DEM paths are only consulted when
        # in_channels == 4 and the image itself has no 4th channel.
        self.images = list(images)
        self.masks = list(masks)
        self.dems = list(dem_paths) if dem_paths is not None else None

        self.topo = topo
        self.transform = transform  # enable photometric augmentations
        self.expand = expand        # expand_wide_fractures_gt preprocessing
        self.dilate = dilate        # dilate_labels preprocessing
        self.in_channels = in_channels

    def __len__(self) -> int:
        return len(self.images)

    def _load_pair(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Load image/mask pair, apply optional expand/dilate and channel handling,
        then perform flips and photometric augmentations.
        """
        img_np = _read_image(Path(self.images[idx]))
        gt_np = _read_mask(Path(self.masks[idx]))

        # expand wide fractures (if requested)
        if self.expand:
            gt_np = expand_wide_fractures_gt(img_np[:, :, :3].astype(np.uint8), gt_np)

        # dilate labels (if requested)
        if self.dilate:
            gt_np = dilate_labels(gt_np)

        # RGB channels, scaled to [0,1] up front. Scaling here (rather
        # than after the DEM concat, as the original code did) fixes two
        # defects: torch.cat of a uint8 RGB tensor with a float32 DEM
        # tensor raises a dtype-mismatch error, and dividing after the
        # concat would also divide the already-[0,1] DEM channel by 255.
        img_tensor = torch.from_numpy(img_np[:, :, :3]).float() / 255.0
        if self.in_channels == 4:
            # DEM may live inside the image array or in a separate file.
            if img_np.shape[2] >= 4:
                dem_np = img_np[:, :, 3].astype(np.float32)
            elif self.dems is not None:
                dem_np = _read_image(Path(self.dems[idx])).astype(np.float32)
            else:
                raise RuntimeError("Requested 4 input channels but no DEM found.")
            # min-max normalize DEM to [0,1]; epsilon guards flat tiles
            dem_tensor = torch.from_numpy(dem_np).float()
            dem_tensor = (dem_tensor - dem_tensor.min()) / (dem_tensor.max() - dem_tensor.min() + 1e-8)
            img_tensor = torch.cat((img_tensor, dem_tensor.unsqueeze(2)), dim=2)

        # reformat H,W,C -> C,H,W
        img_tensor = img_tensor.permute(2, 0, 1).float()

        mask_tensor = torch.from_numpy(gt_np).unsqueeze(0).float() / 255.0

        # random flips (applied identically to image and mask)
        img_tensor, mask_tensor = _apply_random_flips(img_tensor, mask_tensor)

        # photometric augmentations (image only)
        if self.transform:
            img_tensor = _apply_random_photometric_augmentations(img_tensor)

        return img_tensor.float(), mask_tensor.float()

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # wrap-around indexing lets samplers request indices past len()
        idx = index % len(self.images)
        return self._load_pair(idx)
# -------------------------
# Concrete dataset wrappers
# -------------------------
def _read_list_file(list_path: Path) -> List[str]:
    """Return the stripped, non-empty lines of *list_path*."""
    entries: List[str] = []
    with list_path.open("r") as handle:
        for raw_line in handle:
            cleaned = raw_line.strip()
            if cleaned:
                entries.append(cleaned)
    return entries
class OVAS(BaseCrackDataset):
    """OVAS dataset wrapper. Expects directory structure: <root>/<subset>/{image,gt,dem}."""

    def __init__(
        self,
        subset: str,
        list_file: Optional[str] = "list.txt",
        topo: bool = False,
        transform: bool = False,
        expand: bool = True,
        dilate: bool = True,
        in_channels: int = 3,
    ):
        # NOTE(review): paths are relative to the working directory —
        # confirm callers always run from the repository root.
        root = Path("data/ovaskainen23_") / subset
        ext_img = "png"
        ext_gt = "tif"

        names = []
        if list_file:
            names = _read_list_file(root / list_file)
            # list entries carry the GT (.tif) extension; image paths
            # swap it for the png image extension
            images = [
                (root / "image" / n).with_suffix("." + ext_img)
                for n in names
                if n.endswith("." + ext_gt)
            ]
            masks = [root / "gt" / n for n in names if n.endswith("." + ext_gt)]
            dems = [root / "dem" / n for n in names if n.endswith("." + ext_gt)]
        else:
            # no list file: take every matching file in each subdirectory
            images = sorted(path for path in (root / "image").iterdir() if path.suffix.lower().lstrip(".") == ext_img)
            masks = sorted(path for path in (root / "gt").iterdir() if path.suffix.lower().lstrip(".") == ext_gt)
            dems = sorted(path for path in (root / "dem").iterdir() if path.suffix.lower().lstrip(".") == ext_gt)

        super().__init__(images=images, masks=masks, dem_paths=dems, topo=topo, transform=transform,
                         expand=expand, dilate=dilate, in_channels=in_channels)
class MATTEO(BaseCrackDataset):
    """MATTEO dataset wrapper. Expects .tif files; includes DEM channel inside the image."""

    def __init__(
        self,
        subset: str,
        list_file: Optional[str] = "list.txt",
        topo: bool = False,
        transform: bool = False,
        expand: bool = True,
        dilate: bool = True,
        in_channels: int = 3,
    ):
        root = Path("data/matteo21") / subset
        ext = "tif"

        # File names come either from an explicit list file or a directory scan.
        if list_file:
            names = _read_list_file(root / list_file)
        else:
            names = [entry.name for entry in (root / "image").iterdir() if entry.suffix.lstrip(".") == ext]

        image_paths = sorted(root / "image" / entry for entry in names)
        mask_paths = sorted(root / "gt" / entry for entry in names)

        super().__init__(images=image_paths, masks=mask_paths, dem_paths=None, topo=topo,
                         transform=transform, expand=expand, dilate=dilate, in_channels=in_channels)
class SAMSU(BaseCrackDataset):
    """SAMSU dataset wrapper. Similar layout to OVAS."""

    def __init__(
        self,
        subset: str,
        list_file: Optional[str] = "list.txt",
        topo: bool = False,
        transform: bool = False,
        expand: bool = True,
        dilate: bool = True,
        in_channels: int = 3,
    ):
        # NOTE(review): path is relative to the working directory —
        # confirm callers always run from the repository root.
        root = Path("data/samsu19") / subset
        ext_img = "png"
        ext_gt = "tif"

        names = []
        if list_file:
            names = _read_list_file(root / list_file)
            # list entries carry the GT (.tif) extension; image paths
            # swap it for the png image extension
            images = [
                (root / "image" / n).with_suffix("." + ext_img)
                for n in names
                if n.endswith("." + ext_gt)
            ]
            masks = [root / "gt" / n for n in names if n.endswith("." + ext_gt)]
            dems = [root / "dem" / n for n in names if n.endswith("." + ext_gt)]
        else:
            # no list file: take every matching file in each subdirectory
            images = sorted(p for p in (root / "image").iterdir() if p.suffix.lstrip(".") == ext_img)
            masks = sorted(p for p in (root / "gt").iterdir() if p.suffix.lstrip(".") == ext_gt)
            dems = sorted(p for p in (root / "dem").iterdir() if p.suffix.lstrip(".") == ext_gt)

        super().__init__(images=images, masks=masks, dem_paths=dems, topo=topo, transform=transform,
                         expand=expand, dilate=dilate, in_channels=in_channels)
class GeoCrack(BaseCrackDataset):
    """GeoCrack dataset wrapper (simple PNG images)."""

    def __init__(
        self,
        subset: str,
        topo: bool = False,
        transform: bool = False,
        expand: bool = True,
        dilate: bool = True,
        in_channels: int = 3,
    ):
        root = Path("data/GeoCrack_") / subset
        extension = "png"

        image_paths = sorted(entry for entry in (root / "image").iterdir() if entry.suffix.lstrip(".") == extension)
        mask_paths = sorted(entry for entry in (root / "gt").iterdir() if entry.suffix.lstrip(".") == extension)

        super().__init__(images=image_paths, masks=mask_paths, dem_paths=None, topo=topo,
                         transform=transform, expand=expand, dilate=dilate, in_channels=in_channels)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        image, mask = super().__getitem__(index)
        # Resize both to the 256px training resolution used originally.
        resize = t.Resize(256)
        return resize(image).float(), resize(mask).float()
class DIC(BaseCrackDataset):
    """DIC dataset wrapper: single-channel images and PNG masks."""

    def __init__(
        self,
        subset: str,
        topo: bool = False,
        transform: bool = False,
        expand: bool = False,
        dilate: bool = False,
        in_channels: int = 1,
    ):
        root = Path("data/DIC") / subset
        ext_img = "tif"
        ext_mask = "png"

        image_paths = sorted(entry for entry in (root / "image").iterdir() if entry.suffix.lstrip(".") == ext_img)
        mask_paths = sorted(entry for entry in (root / "gt").iterdir() if entry.suffix.lstrip(".") == ext_mask)

        super().__init__(images=image_paths, masks=mask_paths, dem_paths=None, topo=topo,
                         transform=transform, expand=expand, dilate=dilate, in_channels=in_channels)

    def _load_pair(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Override to handle the single-channel image format (the base
        implementation expects >= 3 channels).
        """
        img_np = _read_image(Path(self.images[idx]))
        gt_np = _read_mask(Path(self.masks[idx]))

        # collapse RGB to a single channel if needed
        if img_np.ndim == 3:
            img_np = img_np[..., 0]

        image = torch.from_numpy(img_np).unsqueeze(0).float() / 255.0
        mask = torch.from_numpy(gt_np).unsqueeze(0).float() / 255.0

        image, mask = _apply_random_flips(image, mask)

        if self.transform:
            image = _apply_random_photometric_augmentations(image)

        # resize to the common 256px training resolution
        resize = t.Resize(256)
        return resize(image).float(), resize(mask).float()
# -------------------------
# Dataset registry & loader builder
# -------------------------
DATASETS = {
    "ovaskainen23": OVAS,
    "matteo21": MATTEO,
    "samsu19": SAMSU,
    "geocrack": GeoCrack,
    "dic": DIC,
}


def all_datasets(
    batch_size: int = 32,
    datasets: str = "samsu19-matteo21-ovaskainen23",
    in_channels: int = 4,
    out_channels: int = 1,
    shape: int = 256,
    expand: bool = True,
    dilate: bool = True,
    shuffle_train: bool = True,
    do_transform: bool = True,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Create concatenated train/val/test DataLoaders from multiple dataset names.

    Args:
        batch_size: batch size for the DataLoaders.
        datasets: dash-separated dataset keys from the DATASETS registry.
        in_channels: number of input channels requested (3 or 4).
        out_channels: number of output channels (kept for API compatibility).
        shape: target shape (not used directly here; datasets may resize internally).
        expand, dilate: whether to apply expand/dilate preprocessing.
        shuffle_train: whether to shuffle the training DataLoader.
        do_transform: whether to enable augmentations on the training split.

    Returns:
        Tuple(train_loader, val_loader, test_loader)
    """
    selected = [token.strip() for token in datasets.split("-") if token.strip()]

    train_parts, val_parts, test_parts = [], [], []
    for name in selected:
        if name not in DATASETS:
            raise KeyError(f"Unknown dataset key: {name}")
        dataset_cls = DATASETS[name]
        # Augmentations are enabled only on the training split.
        train_parts.append(dataset_cls(subset="train", transform=do_transform, expand=expand,
                                       dilate=dilate, in_channels=in_channels))
        val_parts.append(dataset_cls(subset="valid", transform=False, expand=expand,
                                     dilate=dilate, in_channels=in_channels))
        test_parts.append(dataset_cls(subset="test", transform=False, expand=expand,
                                      dilate=dilate, in_channels=in_channels))

    def _loader(parts, shuffle):
        # One DataLoader over the concatenation of the selected datasets.
        return DataLoader(ConcatDataset(parts), batch_size=batch_size, shuffle=shuffle)

    return _loader(train_parts, shuffle_train), _loader(val_parts, False), _loader(test_parts, False)
src/train.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import torch
4
+ from skimage.morphology import label, skeletonize
5
+ from skimage.util import view_as_windows
6
+ from torchmetrics import MeanAbsoluteError, MeanSquaredError
7
+ from torchmetrics.classification import (
8
+ BinaryAccuracy, BinaryAUROC, BinaryCohenKappa, BinaryF1Score,
9
+ BinaryJaccardIndex, BinaryPrecision, BinaryRecall, BinarySpecificity
10
+ )
11
+ from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
12
+ from torchmetrics.segmentation import DiceScore
13
+ from tqdm.auto import tqdm
14
+
15
def remove_junctions(skel: np.ndarray) -> np.ndarray:
    """Zero out 3x3 neighbourhoods around junction points of a binary skeleton.

    A 3x3 window containing more than 4 skeleton pixels is treated as a
    junction; the whole window is erased from the skeleton so that only
    simple (non-branching) segments remain.

    Args:
        skel: 2-D binary skeleton (any integer/bool dtype), at least 3x3.

    Returns:
        uint8 skeleton with junction neighbourhoods removed.
    """
    skel = skel.astype(np.uint8)
    mask = np.zeros_like(skel)
    # Vectorized replacement for the original per-window Python double
    # loop over skimage's view_as_windows: sliding_window_view is the
    # NumPy equivalent (stride 1), and summing over the window axes
    # finds all junction windows in a single C-level pass.
    windows = np.lib.stride_tricks.sliding_window_view(skel, (3, 3))
    junctions = windows.sum(axis=(2, 3)) > 4
    for i, j in zip(*np.nonzero(junctions)):
        mask[i:i + 3, j:j + 3] = 1
    return skel * (1 - mask)
27
+
28
+ def fracture_similarity(pred_mask: torch.Tensor, true_mask: torch.Tensor) -> float:
29
+ """Compute similarity score between predicted and true fracture masks."""
30
+ pred_skel = skeletonize((pred_mask > 0.1).cpu().numpy())
31
+ true_skel = skeletonize((true_mask > 0.1).cpu().numpy())
32
+ pred_clean = remove_junctions(pred_skel)
33
+ true_clean = remove_junctions(true_skel)
34
+ pred_labeled = label(pred_clean)
35
+ true_labeled = label(true_clean)
36
+ pred_lengths = np.bincount(pred_labeled.ravel())[1:]
37
+ true_lengths = np.bincount(true_labeled.ravel())[1:]
38
+ bins = np.linspace(0, 260, 20)
39
+ pred_hist, _ = np.histogram(pred_lengths, bins=bins)
40
+ true_hist, _ = np.histogram(true_lengths, bins=bins)
41
+ pred_hist = pred_hist + 1e-6
42
+ true_hist = true_hist + 1e-6
43
+ chi_dist = 0.5 * np.sum((pred_hist - true_hist)**2 / (pred_hist + true_hist))
44
+ return chi_dist
45
+
46
+
47
+ def train_loop(model, optimizer, criterion, train_loader, device='cpu', mdl=None):
48
+ """Train the model for one epoch."""
49
+ running_loss = 0
50
+ model = model.to(device)
51
+ model.train()
52
+ pbar = tqdm(train_loader, desc="Iterating over train data")
53
+
54
+ for images, labels in pbar:
55
+ images, labels = images.to(device), labels.to(device)
56
+ out = model(images)['out'] if mdl == 'fcn_resnet101' else model(images)
57
+ loss = criterion(out, labels)
58
+ running_loss += loss.item() * images.shape[0]
59
+ optimizer.zero_grad()
60
+ loss.backward()
61
+ optimizer.step()
62
+
63
+ running_loss /= len(train_loader.sampler)
64
+ return running_loss
65
+
66
+
67
+ def eval_loop(model, scheduler, criterion, eval_loader, threshold=0.5, device='cpu',
68
+ mdl=None, ignore_index=None):
69
+ """Evaluate the model on a validation or test dataset."""
70
+ running_loss = 0
71
+ model.eval()
72
+ if ignore_index not in [0, 1]:
73
+ ignore_index = None
74
+
75
+ with torch.no_grad():
76
+ # Metrics
77
+ acc_metric = BinaryAccuracy(ignore_index=ignore_index).to(device)
78
+ f1_metric = BinaryF1Score(ignore_index=ignore_index).to(device)
79
+ prec_metric = BinaryPrecision(ignore_index=ignore_index).to(device)
80
+ rec_metric = BinaryRecall(ignore_index=ignore_index).to(device)
81
+ spec_metric = BinarySpecificity(ignore_index=ignore_index).to(device)
82
+ auroc_metric = BinaryAUROC(ignore_index=ignore_index).to(device)
83
+ iou_metric = BinaryJaccardIndex(ignore_index=ignore_index).to(device)
84
+ dice_metric = DiceScore(num_classes=1, average="micro",
85
+ aggregation_level='global').to(device)
86
+ ck_metric = BinaryCohenKappa().to(device)
87
+ mse_metric = MeanSquaredError().to(device)
88
+ ae_metric = MeanAbsoluteError().to(device)
89
+ psnr_metric = PeakSignalNoiseRatio(data_range=1.0).to(device)
90
+ ssim_metric = StructuralSimilarityIndexMeasure().to(device)
91
+ fracture_sim_scores = []
92
+
93
+ pbar = tqdm(eval_loader, desc='Iterating over evaluation/test data')
94
+ for imgs, labels in pbar:
95
+ imgs, labels = imgs.to(device), labels.to(device)
96
+ out = model(imgs)['out'] if mdl == 'fcn_resnet101' else model(imgs)
97
+ loss = criterion(out, labels)
98
+ running_loss += loss.item() * imgs.shape[0]
99
+
100
+ predicted = out
101
+ if mdl == 'Segformer':
102
+ predicted[predicted > 0.99] = 0.
103
+ predicted_clf = (out > threshold).float()
104
+ labels_clf = (labels > 0.).float()
105
+ labels = labels.float()
106
+
107
+ # Compute metrics
108
+ acc_metric(predicted_clf, labels_clf)
109
+ f1_metric(predicted_clf, labels_clf)
110
+ prec_metric(predicted_clf, labels_clf)
111
+ rec_metric(predicted_clf, labels_clf)
112
+ spec_metric(predicted_clf, labels_clf)
113
+ if labels_clf.numel() > 0 and labels_clf.min() != labels_clf.max():
114
+ auroc_metric(predicted_clf, labels_clf)
115
+ dice_metric(predicted_clf, labels_clf)
116
+ iou_metric(predicted_clf, labels_clf)
117
+ ck_metric(predicted_clf, labels_clf)
118
+ mse_metric(predicted, labels)
119
+ psnr_metric(predicted, labels)
120
+ ssim_metric(predicted, labels)
121
+ ae_metric(predicted, labels)
122
+
123
+ for i in range(imgs.shape[0]):
124
+ pred_mask = predicted_clf[i, 0].detach().cpu()
125
+ true_mask = labels_clf[i, 0].detach().cpu()
126
+ fracture_sim_scores.append(fracture_similarity(pred_mask, true_mask))
127
+
128
+ avg_fracture_sim = float(np.mean(fracture_sim_scores)) if fracture_sim_scores else float('nan')
129
+
130
+ return {
131
+ 'mse': mse_metric.compute().item(),
132
+ 'psnr': psnr_metric.compute().item(),
133
+ 'ssim': ssim_metric.compute().item(),
134
+ 'ae': ae_metric.compute().item(),
135
+ 'acc': acc_metric.compute().item(),
136
+ 'f1': f1_metric.compute().item(),
137
+ 'prec': prec_metric.compute().item(),
138
+ 'rec': rec_metric.compute().item(),
139
+ 'spec': spec_metric.compute().item(),
140
+ 'dice': dice_metric.compute().item(),
141
+ 'iou': iou_metric.compute().item(),
142
+ 'ck': ck_metric.compute().item(),
143
+ 'roc_auc': auroc_metric.compute().item(),
144
+ 'loss': running_loss / len(eval_loader.sampler),
145
+ 'frac_sim': avg_fracture_sim,
146
+ }
147
+
148
+
149
+ def eval_single(gt, pred, threshold=0.5, device="cpu", ignore_index=None):
150
+ """Evaluate metrics for a single prediction and ground truth pair."""
151
+ gt = torch.from_numpy(gt).to(device).float().unsqueeze(0).unsqueeze(0)
152
+ pred = torch.from_numpy(pred).to(device).float().unsqueeze(0).unsqueeze(0)
153
+
154
+ pred_clf = (pred > threshold).long()
155
+ gt_clf = (gt > 0).long()
156
+ if ignore_index not in [0, 1]:
157
+ ignore_index = None
158
+
159
+ # Metrics
160
+ acc_metric = BinaryAccuracy(ignore_index=ignore_index).to(device)
161
+ f1_metric = BinaryF1Score(ignore_index=ignore_index).to(device)
162
+ prec_metric = BinaryPrecision(ignore_index=ignore_index).to(device)
163
+ rec_metric = BinaryRecall(ignore_index=ignore_index).to(device)
164
+ spec_metric = BinarySpecificity(ignore_index=ignore_index).to(device)
165
+ auroc_metric = BinaryAUROC(ignore_index=ignore_index).to(device)
166
+ iou_metric = BinaryJaccardIndex(ignore_index=ignore_index).to(device)
167
+ dice_metric = DiceScore(num_classes=1, average="micro").to(device)
168
+ ck_metric = BinaryCohenKappa().to(device)
169
+ mse_metric = MeanSquaredError().to(device)
170
+ ae_metric = MeanAbsoluteError().to(device)
171
+ psnr_metric = PeakSignalNoiseRatio(data_range=1.0).to(device)
172
+ ssim_metric = StructuralSimilarityIndexMeasure().to(device)
173
+
174
+ # Compute metrics
175
+ acc_metric(pred_clf, gt_clf)
176
+ f1_metric(pred_clf, gt_clf)
177
+ prec_metric(pred_clf, gt_clf)
178
+ rec_metric(pred_clf, gt_clf)
179
+ spec_metric(pred_clf, gt_clf)
180
+ if gt_clf.numel() > 0 and gt_clf.min() != gt_clf.max():
181
+ auroc_metric(pred, gt_clf.int())
182
+ dice_metric(pred_clf, gt_clf)
183
+ iou_metric(pred_clf, gt_clf)
184
+ ck_metric(pred_clf, gt_clf)
185
+ mse_metric(pred, gt)
186
+ psnr_metric(pred, gt)
187
+ ssim_metric(pred, gt)
188
+ ae_metric(pred, gt)
189
+
190
+ return {
191
+ 'mse': mse_metric.compute().item(),
192
+ 'psnr': psnr_metric.compute().item(),
193
+ 'ssim': ssim_metric.compute().item(),
194
+ 'ae': ae_metric.compute().item(),
195
+ 'acc': acc_metric.compute().item(),
196
+ 'f1': f1_metric.compute().item(),
197
+ 'prec': prec_metric.compute().item(),
198
+ 'rec': rec_metric.compute().item(),
199
+ 'spec': spec_metric.compute().item(),
200
+ 'dice': dice_metric.compute().item(),
201
+ 'iou': iou_metric.compute().item(),
202
+ 'ck': ck_metric.compute().item(),
203
+ 'roc_auc': auroc_metric.compute().item(),
204
+ }
205
+
206
+
207
+ def save_metrics(metrics: dict, kind: str, writer, epoch: int):
208
+ """Log metrics to a TensorBoard writer."""
209
+ writer.add_scalar(f"Loss/{kind}", metrics['loss'], epoch)
210
+ writer.add_scalar(f"ACC/{kind}", metrics['acc'], epoch)
211
+ writer.add_scalar(f"F1/{kind}", metrics['f1'], epoch)
212
+ writer.add_scalar(f"PREC/{kind}", metrics['prec'], epoch)
213
+ writer.add_scalar(f"REC/{kind}", metrics['rec'], epoch)
214
+ writer.add_scalar(f"ROC_AUC/{kind}", metrics['roc_auc'], epoch)
215
+ writer.add_scalar(f"MSE/{kind}", metrics['mse'], epoch)
216
+ writer.add_scalar(f"PSNR/{kind}", metrics['psnr'], epoch)
217
+ writer.add_scalar(f"SSIM/{kind}", metrics['ssim'], epoch)
218
+ writer.add_scalar(f"SPEC/{kind}", metrics['spec'], epoch)
219
+ writer.add_scalar(f"DICE/{kind}", metrics['dice'], epoch)
220
+ writer.add_scalar(f"AE/{kind}", metrics['ae'], epoch)
221
+ writer.add_scalar(f"IoU/{kind}", metrics['iou'], epoch)
222
+ writer.flush()