File size: 5,020 Bytes
d258641
 
 
 
 
c514002
 
d258641
c514002
 
3fb9cf2
d258641
c514002
 
555e888
d258641
c514002
 
555e888
d258641
c514002
 
d258641
2912c27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b5f9ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c514002
 
d258641
 
c514002
d258641
 
 
2912c27
 
 
 
c514002
 
 
 
 
 
 
 
d258641
 
 
 
 
 
 
 
73722a9
 
c886c6f
 
 
 
 
 
98b524c
73722a9
2912c27
c886c6f
2912c27
3b5f9ac
d258641
c514002
 
 
73722a9
c514002
 
 
 
 
73722a9
c886c6f
 
 
 
 
 
c514002
 
73722a9
 
c514002
 
 
 
d258641
c514002
 
 
2912c27
c514002
d258641
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import gradio as gr
from PIL import Image
import torch
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution

# ---- Device: use the GPU whenever torch can see one ----
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---- Checkpoints ----
# Hub checkpoint, "crystal clear" output.
PRETRAINED_ID = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
# Fine-tuned checkpoint, "smooth" output; a local folder in the repo, hence
# local_files_only=True on every load below.
FINETUNED_ID = "swin2sr_div2k_finetuned_x4_1000steps"

# ---- Processors (image pre-processing for each checkpoint) ----
processor_pre = AutoImageProcessor.from_pretrained(PRETRAINED_ID)
processor_ft = AutoImageProcessor.from_pretrained(
    FINETUNED_ID, local_files_only=True
)

# ---- Models: moved to `device` and put into evaluation mode (eval() returns the module) ----
model_pre = Swin2SRForImageSuperResolution.from_pretrained(PRETRAINED_ID).to(device).eval()
model_ft = (
    Swin2SRForImageSuperResolution.from_pretrained(FINETUNED_ID, local_files_only=True)
    .to(device)
    .eval()
)


# Markdown help text shown in the UI via gr.Markdown(ABOUT_TEXT).
# The two lines ending in two spaces use them as Markdown hard line breaks; keep them.
# NOTE(review): every body line is indented by 4 spaces; this relies on gr.Markdown
# stripping the common indentation — confirm, otherwise it renders as a code block.
ABOUT_TEXT = """
    **About this tool**

    This is a free image enhancement tool that uses Swin2SR to reconstruct and refine image detail at the pixel level (“pixel refabrication”).  
    I built it because many of my own photos are distorted or only exist as low‑quality copies on social media.

    **Tip for social media images**

    If your only copy is on Instagram, open Instagram on a computer, view the photo in full‑screen, take a screenshot, and crop it tightly around the image before uploading here.

    **Usage notes**

    Processing and queue times can occasionally exceed ten minutes, especially during heavy use.  
    If this tool gets regular traffic, I plan to upgrade to a stronger GPU backend.

    Please reach out if you need any help or have questions.
    """  


def _reconstruction_to_uint8(reconstruction, height: int, width: int):
    """Convert a batched float reconstruction into an ``(h, w, C)`` uint8 array.

    Takes the first item of the batch, crops it to at most ``height`` x ``width``
    (top-left origin; a no-op when the tensor is already that small), clamps to
    [0, 1] and rounds -- rather than truncates -- to the nearest 8-bit level.
    """
    sr = reconstruction[0, :, :height, :width]
    sr = sr.float().clamp(0, 1).mul(255).round().to(torch.uint8)
    return sr.permute(1, 2, 0).cpu().numpy()


def swin2sr_upscale(input_image: "Image.Image", mode: str):
    """Super-resolve ``input_image`` with the Swin2SR model selected by ``mode``.

    Args:
        input_image: PIL image from the upload widget, or ``None`` when nothing
            has been uploaded.
        mode: Dropdown label. ``"Smooth (Over Smoothing | tuned_1)"`` selects the
            fine-tuned checkpoint; any other value selects the pretrained one.

    Returns:
        ``(image, message)``: the RGB super-resolved PIL image (``None`` when no
        image was given or inference failed) and a resolution / error summary
        for the info box.
    """
    if input_image is None:
        return None, "No image uploaded."

    try:
        input_image = input_image.convert("RGB")
        w_lr, h_lr = input_image.size

        if mode == "Smooth (Over Smoothing | tuned_1)":
            model, processor = model_ft, processor_ft
        else:
            model, processor = model_pre, processor_pre

        inputs = processor(images=input_image, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)

        # The Swin2SR image processor pads the input (bottom/right) before
        # inference, so the raw reconstruction can be larger than
        # scale x original. Crop the padded border back off so the output is
        # exactly the upscaled original.
        scale = int(getattr(model.config, "upscale", 4))
        sr_array = _reconstruction_to_uint8(
            outputs.reconstruction, h_lr * scale, w_lr * scale
        )
        sr_image = Image.fromarray(sr_array).convert("RGB")  # RGB keeps JPEG output valid

        w_sr, h_sr = sr_image.size
        msg = (
            f"Original resolution: {w_lr}×{h_lr} pixels\n"
            f"Super-resolved: {w_sr}×{h_sr} pixels "
            f"(scale factors: {w_sr / w_lr:.1f}×, {h_sr / h_lr:.1f}×)"
        )
        return sr_image, msg

    except Exception as e:
        # UI boundary: report the failure in the info box instead of crashing the app.
        return None, f"Error: {type(e).__name__}: {e}"
# ---- Gradio UI ----
# Layout: title, style description, input/output images, resolution info,
# help text, style dropdown, run button. The dropdown labels below must match
# the label compared against in swin2sr_upscale.
with gr.Blocks() as demo:
    gr.Markdown("# Image Super-Resolution (Swin2SR x4)")
    gr.Markdown(
        "Choose **Clear (OFtB)** for the original pretrained Swin2SR model, "
        "or **Smooth (Over Smoothing | tuned_1)** for the Swin2SR version we fine-tuned on DIV2K patches."
    )

    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload low-res image")
        output_image = gr.Image(type="pil", format="jpeg", label="4x Super-resolved image")

    # Filled with the resolution / error message returned by swin2sr_upscale.
    res_info = gr.Markdown("Resolution info.")
    # ---- info / help box ----
    gr.Markdown(ABOUT_TEXT)

    mode_dropdown = gr.Dropdown(
        label="Style",
        choices=["Clear (OFtB)", "Smooth (Over Smoothing | tuned_1)"],
        value="Clear (OFtB)",
        interactive=True,
    )

    run_btn = gr.Button("Upscale")

    run_btn.click(
        fn=swin2sr_upscale,
        inputs=[input_image, mode_dropdown],
        outputs=[output_image, res_info],
    )

if __name__ == "__main__":
    demo.launch()