Upload SoftPool.py with huggingface_hub
Browse files- SoftPool.py +80 -0
SoftPool.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pure PyTorch implementation of SoftPool.
|
| 3 |
+
This is a fallback that doesn't require CUDA kernel compilation.
|
| 4 |
+
SoftPool: https://arxiv.org/abs/2101.00440
|
| 5 |
+
"""
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def soft_pool2d(x, kernel_size=(2, 2), stride=None, force_inplace=False):
|
| 12 |
+
"""
|
| 13 |
+
Apply soft pooling on 2D input tensor.
|
| 14 |
+
|
| 15 |
+
SoftPool approximates max pooling while maintaining differentiability
|
| 16 |
+
by using exponential weighting: y = sum(x * exp(x)) / sum(exp(x))
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
x: Input tensor of shape (N, C, H, W)
|
| 20 |
+
kernel_size: Pooling kernel size
|
| 21 |
+
stride: Stride (defaults to kernel_size)
|
| 22 |
+
force_inplace: Unused, for API compatibility
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
Pooled tensor
|
| 26 |
+
"""
|
| 27 |
+
if stride is None:
|
| 28 |
+
stride = kernel_size
|
| 29 |
+
|
| 30 |
+
if isinstance(kernel_size, int):
|
| 31 |
+
kernel_size = (kernel_size, kernel_size)
|
| 32 |
+
if isinstance(stride, int):
|
| 33 |
+
stride = (stride, stride)
|
| 34 |
+
|
| 35 |
+
# Use unfold to extract patches
|
| 36 |
+
batch, channels, height, width = x.shape
|
| 37 |
+
kh, kw = kernel_size
|
| 38 |
+
sh, sw = stride
|
| 39 |
+
|
| 40 |
+
# Calculate output dimensions
|
| 41 |
+
out_h = (height - kh) // sh + 1
|
| 42 |
+
out_w = (width - kw) // sw + 1
|
| 43 |
+
|
| 44 |
+
# Apply exponential weighting
|
| 45 |
+
# For numerical stability, subtract max before exp
|
| 46 |
+
x_unfold = F.unfold(x, kernel_size=kernel_size, stride=stride) # (N, C*kh*kw, out_h*out_w)
|
| 47 |
+
x_unfold = x_unfold.view(batch, channels, kh * kw, out_h * out_w)
|
| 48 |
+
|
| 49 |
+
# Softmax-style weighting for soft pooling
|
| 50 |
+
x_max = x_unfold.max(dim=2, keepdim=True)[0]
|
| 51 |
+
exp_x = torch.exp(x_unfold - x_max) # Numerical stability
|
| 52 |
+
|
| 53 |
+
# Weighted sum: sum(x * exp(x)) / sum(exp(x))
|
| 54 |
+
softpool = (x_unfold * exp_x).sum(dim=2) / (exp_x.sum(dim=2) + 1e-8)
|
| 55 |
+
|
| 56 |
+
# Reshape to output format
|
| 57 |
+
softpool = softpool.view(batch, channels, out_h, out_w)
|
| 58 |
+
|
| 59 |
+
return softpool
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class SoftPool2d(nn.Module):
|
| 63 |
+
"""
|
| 64 |
+
SoftPool 2D Layer.
|
| 65 |
+
|
| 66 |
+
A differentiable pooling operation that approximates max pooling
|
| 67 |
+
using exponential weighting.
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
def __init__(self, kernel_size=(2, 2), stride=None, force_inplace=False):
|
| 71 |
+
super(SoftPool2d, self).__init__()
|
| 72 |
+
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
|
| 73 |
+
self.stride = stride if stride is not None else self.kernel_size
|
| 74 |
+
self.force_inplace = force_inplace
|
| 75 |
+
|
| 76 |
+
def forward(self, x):
|
| 77 |
+
return soft_pool2d(x, self.kernel_size, self.stride, self.force_inplace)
|
| 78 |
+
|
| 79 |
+
def extra_repr(self):
|
| 80 |
+
return f'kernel_size={self.kernel_size}, stride={self.stride}'
|