"""
Axion: SAR-to-Optical Translation - HuggingFace Space

Fixed for ZeroGPU with lazy loading.
"""

import tempfile
import time

import gradio as gr
import numpy as np
from PIL import Image, ImageEnhance

print("[Axion] Starting app...")
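
# ZeroGPU support: the `spaces` SDK decorates GPU-bound functions on
# Hugging Face Spaces; if it is not installed, run without it.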
try:
    import spaces
    GPU_AVAILABLE = True
    print("[Axion] ZeroGPU available")
except ImportError:
    GPU_AVAILABLE = False
    spaces = None
    print("[Axion] Running without ZeroGPU")
_torch = None
_model_modules = None


def get_torch():
    global _torch
    if _torch is None:
        print("[Axion] Importing torch...")
        import torch
        _torch = torch
        print(f"[Axion] PyTorch {torch.__version__} loaded")
    return _torch


def get_model_modules():
    global _model_modules
    if _model_modules is None:
        print("[Axion] Importing model modules...")
        from unet import UNet
        from diffusion import GaussianDiffusion
        _model_modules = (UNet, GaussianDiffusion)
        print("[Axion] Model modules loaded")
    return _model_modules


def load_sar_image(filepath):
    """Load SAR image from various formats."""
    try:
        import rasterio
        with rasterio.open(filepath) as src:
            data = src.read(1)
            # Percentile-stretch float and uint16 data to displayable 8-bit.
            if data.dtype in (np.float32, np.float64):
                valid = data[np.isfinite(data)]
                if len(valid) > 0:
                    p2, p98 = np.percentile(valid, [2, 98])
                    data = np.clip(data, p2, p98)
                    data = ((data - p2) / (p98 - p2 + 1e-8) * 255).astype(np.uint8)
            elif data.dtype == np.uint16:
                p2, p98 = np.percentile(data, [2, 98])
                data = np.clip(data, p2, p98)
                data = ((data - p2) / (p98 - p2 + 1e-8) * 255).astype(np.uint8)
            return Image.fromarray(data).convert('RGB')
    except Exception:
        # rasterio unavailable or the file is not a GeoTIFF; use PIL instead.
        pass

    return Image.open(filepath).convert('RGB')


def create_blend_weights(tile_size, overlap):
    """Create smooth blending weights for seamless output."""
    # Linear ramp across the overlap; the endpoints are excluded so border
    # pixels never receive exactly zero weight.
    ramp = np.linspace(0, 1, overlap + 2)[1:-1]
    weight = np.ones((tile_size, tile_size))
    weight[:overlap, :] *= ramp[:, np.newaxis]
    weight[-overlap:, :] *= ramp[::-1, np.newaxis]
    weight[:, :overlap] *= ramp[np.newaxis, :]
    weight[:, -overlap:] *= ramp[np.newaxis, ::-1]
    return weight[:, :, np.newaxis]
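
# Across an overlap, one tile's trailing ramp and the next tile's leading
# ramp sum to one, so the weighted accumulation in process_image() blends
# tiles seamlessly.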


def build_model(device):
    """Build and load the Axion model."""
    torch = get_torch()
    UNet, GaussianDiffusion = get_model_modules()
    from huggingface_hub import hf_hub_download

    print("[Axion] Building model architecture...")

    image_size = 256
    num_inference_steps = 1

    unet = UNet(
        in_channel=3,
        out_channel=3,
        norm_groups=16,
        inner_channel=64,
        channel_mults=[1, 2, 4, 8, 16],
        attn_res=[],
        res_blocks=1,
        dropout=0,
        image_size=image_size,
        condition_ch=3
    )

    # Single-step schedule: validation/inference runs in one denoising step.
    schedule_opt = {
        'schedule': 'linear',
        'n_timestep': num_inference_steps,
        'linear_start': 1e-6,
        'linear_end': 1e-2,
        'ddim': 1,
        'lq_noiselevel': 0
    }

    opt = {
        'stage': 2,
        'ddim_steps': num_inference_steps,
        'model': {
            'beta_schedule': {
                'train': {'n_timestep': 1000},
                'val': schedule_opt
            }
        }
    }

    model = GaussianDiffusion(
        denoise_fn=unet,
        image_size=image_size,
        channels=3,
        loss_type='l1',
        conditional=True,
        schedule_opt=schedule_opt,
        xT_noise_r=0,
        seed=1,
        opt=opt
    )

    model = model.to(device)

    print("[Axion] Downloading weights...")
    weights_path = hf_hub_download(
        repo_id="Dhenenjay/Axion-S2O",
        filename="I700000_E719_gen.pth"
    )

    print(f"[Axion] Loading weights from: {weights_path}")
    state_dict = torch.load(weights_path, map_location=device, weights_only=False)
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    print("[Axion] Model ready!")
    return model
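

# preprocess() and postprocess() are inverse mappings between 8-bit RGB
# images and the model's [-1, 1] tensor range.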
def preprocess(image, device, image_size=256):
    """Preprocess input SAR image."""
    torch = get_torch()

    if image.mode != 'RGB':
        image = image.convert('RGB')

    if image.size != (image_size, image_size):
        image = image.resize((image_size, image_size), Image.LANCZOS)

    img_np = np.array(image).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1)
    img_tensor = img_tensor * 2.0 - 1.0

    return img_tensor.unsqueeze(0).to(device)


def postprocess(tensor):
    """Postprocess output tensor to PIL Image."""
    torch = get_torch()

    tensor = tensor.squeeze(0).cpu()
    tensor = torch.clamp(tensor, -1, 1)
    tensor = (tensor + 1.0) / 2.0

    img_np = (tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    return Image.fromarray(img_np)


def translate_tile(model, sar_pil, device, seed=42):
    """Translate a single tile."""
    torch = get_torch()

    # A fixed seed keeps the sampled noise identical across tiles, avoiding
    # visible tile-to-tile variation in the stitched output.
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    sar_tensor = preprocess(sar_pil, device)

    # Must match the inference schedule configured in build_model().
    model.set_new_noise_schedule(
        {
            'schedule': 'linear',
            'n_timestep': 1,
            'linear_start': 1e-6,
            'linear_end': 1e-2,
            'ddim': 1,
            'lq_noiselevel': 0
        },
        device,
        num_train_timesteps=1000
    )

    with torch.no_grad():
        output, _ = model.super_resolution(
            sar_tensor,
            continous=False,  # sic: parameter name as defined by the model code
            seed=seed if seed is not None else 1,
            img_s1=sar_tensor
        )

    return postprocess(output)
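
# Optional cosmetic pass; enhancement factors of 1.0 leave the image unchanged.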
def enhance_image(image, contrast=1.1, sharpness=1.2, color=1.1):
    """Light post-processing: contrast, sharpness, and color."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    image = ImageEnhance.Contrast(image).enhance(contrast)
    image = ImageEnhance.Sharpness(image).enhance(sharpness)
    image = ImageEnhance.Color(image).enhance(color)

    return image


def process_image(image, model, device, overlap=64):
    """Process image at full resolution with seamless tiling."""
    if isinstance(image, Image.Image):
        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_np = np.array(image).astype(np.float32) / 255.0
    else:
        img_np = image

    h, w = img_np.shape[:2]
    tile_size = 256
    step = tile_size - overlap

    # Reflect-pad so that (padded size - overlap) is a multiple of step and
    # the tile grid covers the whole image. Assumes inputs of at least
    # tile_size pixels per side.
    pad_h = (step - (h - overlap) % step) % step
    pad_w = (step - (w - overlap) % step) % step
    img_padded = np.pad(img_np, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')

    h_pad, w_pad = img_padded.shape[:2]

    # Accumulate weighted tile outputs, then normalize by the summed weights.
    output = np.zeros((h_pad, w_pad, 3), dtype=np.float32)
    weights = np.zeros((h_pad, w_pad, 1), dtype=np.float32)
    blend_weight = create_blend_weights(tile_size, overlap)

    y_positions = list(range(0, h_pad - tile_size + 1, step))
    x_positions = list(range(0, w_pad - tile_size + 1, step))
    total_tiles = len(y_positions) * len(x_positions)

    print(f"[Axion] Processing {total_tiles} tiles ({len(x_positions)}x{len(y_positions)}) at {w}x{h}...")

    tile_idx = 0
    for y in y_positions:
        for x in x_positions:
            tile = img_padded[y:y+tile_size, x:x+tile_size]
            tile_pil = Image.fromarray((tile * 255).astype(np.uint8))

            result_pil = translate_tile(model, tile_pil, device, seed=42)
            result = np.array(result_pil).astype(np.float32) / 255.0

            output[y:y+tile_size, x:x+tile_size] += result * blend_weight
            weights[y:y+tile_size, x:x+tile_size] += blend_weight

            tile_idx += 1
            if tile_idx % 10 == 0 or tile_idx == total_tiles:
                print(f"[Axion] Tile {tile_idx}/{total_tiles}")

    # Normalize by accumulated weights and crop the padding back off.
    output = output / (weights + 1e-8)
    output = output[:h, :w]

    return (output * 255).astype(np.uint8)
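

# Cache the model at module level so repeat requests in a warm worker skip
# the rebuild and weight download.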
_cached_model = None


def _translate_impl(file, overlap, enhance_output):
    """Main translation function - runs on GPU."""
    global _cached_model

    if file is None:
        return None, None, "Please upload a SAR image"

    torch = get_torch()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Axion] Using device: {device}")

    if _cached_model is None:
        _cached_model = build_model(device)

    model = _cached_model

    # gr.File may hand us a tempfile-like object or a plain path string.
    filepath = file.name if hasattr(file, 'name') else file
    print(f"[Axion] Loading: {filepath}")
    image = load_sar_image(filepath)

    w, h = image.size
    print(f"[Axion] Input size: {w}x{h}")

    start = time.time()
    result = process_image(image, model, device, overlap=int(overlap))
    elapsed = time.time() - start

    result_pil = Image.fromarray(result)

    if enhance_output:
        result_pil = enhance_image(result_pil)

    # Write to a named temp file so Gradio can serve it for download.
    with tempfile.NamedTemporaryFile(suffix='.tiff', delete=False) as tmp:
        tiff_path = tmp.name
    result_pil.save(tiff_path, format='TIFF', compression='lzw')

    print(f"[Axion] Complete in {elapsed:.1f}s!")

    info = f"Processed in {elapsed:.1f}s | Output: {result_pil.size[0]}x{result_pil.size[1]}"

    return result_pil, tiff_path, info
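

# On ZeroGPU, @spaces.GPU requests a GPU slice for up to `duration` seconds
# per call; without the SDK, translate_sar is just the plain implementation.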
if GPU_AVAILABLE and spaces is not None:
    @spaces.GPU(duration=300)
    def translate_sar(file, overlap, enhance_output):
        return _translate_impl(file, overlap, enhance_output)
else:
    translate_sar = _translate_impl


print("[Axion] Building Gradio interface...")

with gr.Blocks(title="Axion - SAR to Optical") as demo:
    gr.HTML("""
    <style>
    .gradio-container { background: linear-gradient(180deg, #0a0a0a 0%, #1a1a1a 100%) !important; }
    </style>
    <div style="text-align: center; padding: 40px 20px 20px 20px;">
        <h1 style="font-family: 'Helvetica Neue', Arial, sans-serif; font-size: 3.2rem; font-weight: 200; color: #ffffff; margin-bottom: 0.5rem; letter-spacing: -0.02em;">SAR to Optical Image Translation</h1>
        <p style="font-family: 'Helvetica Neue', Arial, sans-serif; font-size: 1.1rem; font-weight: 300; color: #888888;">Transform radar imagery into crystal-clear optical views using our foundation model</p>
    </div>
    """)

    with gr.Row():
        with gr.Column():
            input_file = gr.File(label="Upload SAR Image", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])
            gr.HTML("""
            <div style="font-size: 0.8rem; color: #666; padding: 8px 12px; background: rgba(255,255,255,0.03); border-radius: 6px; margin: 8px 0;">
                <strong style="color: #888;">Input Guidelines:</strong><br>
                • Use raw SAR imagery (single-band grayscale)<br>
                • VV polarization preferred, VH also supported<br>
                • Any resolution supported (processed in 256×256 tiles)
            </div>
            """)
            with gr.Row():
                overlap = gr.Slider(16, 128, value=64, step=16, label="Tile Overlap")
                enhance = gr.Checkbox(value=True, label="Enhance Output")
            submit_btn = gr.Button("Translate", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Optical Output")
            output_file = gr.File(label="Download")
            info_text = gr.Textbox(label="Info", show_label=False)

    submit_btn.click(
        fn=translate_sar,
        inputs=[input_file, overlap, enhance],
        outputs=[output_image, output_file, info_text]
    )

    gr.HTML("""
    <div style="text-align: center; padding: 20px; color: #555; font-size: 0.85rem;">
        Powered by <strong style="color: #888;">Axion</strong>
    </div>
    """)

print("[Axion] Launching app...")

if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)