Dhenenjay commited on
Commit
8870824
·
verified ·
1 Parent(s): 0e3f667

Upload SoftPool.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. SoftPool.py +80 -0
SoftPool.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pure PyTorch implementation of SoftPool.
3
+ This is a fallback that doesn't require CUDA kernel compilation.
4
+ SoftPool: https://arxiv.org/abs/2101.00440
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
+ def soft_pool2d(x, kernel_size=(2, 2), stride=None, force_inplace=False):
12
+ """
13
+ Apply soft pooling on 2D input tensor.
14
+
15
+ SoftPool approximates max pooling while maintaining differentiability
16
+ by using exponential weighting: y = sum(x * exp(x)) / sum(exp(x))
17
+
18
+ Args:
19
+ x: Input tensor of shape (N, C, H, W)
20
+ kernel_size: Pooling kernel size
21
+ stride: Stride (defaults to kernel_size)
22
+ force_inplace: Unused, for API compatibility
23
+
24
+ Returns:
25
+ Pooled tensor
26
+ """
27
+ if stride is None:
28
+ stride = kernel_size
29
+
30
+ if isinstance(kernel_size, int):
31
+ kernel_size = (kernel_size, kernel_size)
32
+ if isinstance(stride, int):
33
+ stride = (stride, stride)
34
+
35
+ # Use unfold to extract patches
36
+ batch, channels, height, width = x.shape
37
+ kh, kw = kernel_size
38
+ sh, sw = stride
39
+
40
+ # Calculate output dimensions
41
+ out_h = (height - kh) // sh + 1
42
+ out_w = (width - kw) // sw + 1
43
+
44
+ # Apply exponential weighting
45
+ # For numerical stability, subtract max before exp
46
+ x_unfold = F.unfold(x, kernel_size=kernel_size, stride=stride) # (N, C*kh*kw, out_h*out_w)
47
+ x_unfold = x_unfold.view(batch, channels, kh * kw, out_h * out_w)
48
+
49
+ # Softmax-style weighting for soft pooling
50
+ x_max = x_unfold.max(dim=2, keepdim=True)[0]
51
+ exp_x = torch.exp(x_unfold - x_max) # Numerical stability
52
+
53
+ # Weighted sum: sum(x * exp(x)) / sum(exp(x))
54
+ softpool = (x_unfold * exp_x).sum(dim=2) / (exp_x.sum(dim=2) + 1e-8)
55
+
56
+ # Reshape to output format
57
+ softpool = softpool.view(batch, channels, out_h, out_w)
58
+
59
+ return softpool
60
+
61
+
62
+ class SoftPool2d(nn.Module):
63
+ """
64
+ SoftPool 2D Layer.
65
+
66
+ A differentiable pooling operation that approximates max pooling
67
+ using exponential weighting.
68
+ """
69
+
70
+ def __init__(self, kernel_size=(2, 2), stride=None, force_inplace=False):
71
+ super(SoftPool2d, self).__init__()
72
+ self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
73
+ self.stride = stride if stride is not None else self.kernel_size
74
+ self.force_inplace = force_inplace
75
+
76
+ def forward(self, x):
77
+ return soft_pool2d(x, self.kernel_size, self.stride, self.force_inplace)
78
+
79
+ def extra_repr(self):
80
+ return f'kernel_size={self.kernel_size}, stride={self.stride}'