Axion / softpool.py
Dhenenjay's picture
Upload softpool.py with huggingface_hub
a9abf27 verified
"""Pure PyTorch SoftPool implementation."""
import torch
import torch.nn as nn
import torch.nn.functional as F
def soft_pool2d(x, kernel_size=(2, 2), stride=None, force_inplace=False):
    """Apply 2-D SoftPool (softmax-weighted pooling) to a 4-D tensor.

    Each pooling window ``w`` is reduced to ``sum(softmax(w) * w)``, i.e. an
    average of the window's values weighted by their exponentials.  No
    padding or dilation is applied.

    Args:
        x: Tensor indexed as ``(batch, channels, height, width)``.
        kernel_size: Window size, either an int (square window) or a pair
            ``(kh, kw)``.
        stride: Step between windows, either an int or a pair ``(sh, sw)``.
            Defaults to ``kernel_size`` (non-overlapping windows).
        force_inplace: Accepted for signature compatibility only; this
            implementation never reads it.

    Returns:
        Tensor of shape ``(batch, channels, out_h, out_w)`` where
        ``out_h = (height - kh) // sh + 1`` and
        ``out_w = (width - kw) // sw + 1``.

    Raises:
        ValueError: If ``x`` does not have exactly four dimensions.
    """
    if x.ndim != 4:
        raise ValueError(
            "soft_pool2d expects a 4-D input (batch, channels, height, width), "
            "got {} dimension(s)".format(x.ndim)
        )

    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if stride is None:
        stride = kernel_size
    elif isinstance(stride, int):
        stride = (stride, stride)

    batch, channels, height, width = x.shape
    kh, kw = kernel_size
    sh, sw = stride
    out_h = (height - kh) // sh + 1
    out_w = (width - kw) // sw + 1

    # F.unfold yields (batch, channels * kh * kw, out_h * out_w) with the
    # channel index varying slowest, so the view below isolates each
    # window's kh * kw elements on dim 2.
    windows = F.unfold(x, kernel_size=kernel_size, stride=stride)
    windows = windows.view(batch, channels, kh * kw, out_h * out_w)

    # softmax subtracts the per-window maximum internally (overflow-safe)
    # and normalises exactly, so no epsilon is needed in the denominator:
    # the exponential sum is always >= 1 once the max has been removed.
    weights = F.softmax(windows, dim=2)
    pooled = (windows * weights).sum(dim=2)
    return pooled.view(batch, channels, out_h, out_w)
class SoftPool2d(nn.Module):
    """Module wrapper around :func:`soft_pool2d`.

    Args:
        kernel_size: Window size, either an int (square window) or a
            two-element sequence ``(kh, kw)``.  Stored as a tuple.
        stride: Step between windows, an int or a two-element sequence.
            Defaults to ``kernel_size``.  Stored as given.
        force_inplace: Stored and forwarded to :func:`soft_pool2d`.
    """

    def __init__(self, kernel_size=(2, 2), stride=None, force_inplace=False):
        super().__init__()
        # Only a bare int is duplicated; any other sequence (tuple, list, ...)
        # is converted to a tuple so that e.g. ``[3, 2]`` is not mistaken
        # for a scalar and nested into ``([3, 2], [3, 2])``.
        if isinstance(kernel_size, int):
            self.kernel_size = (kernel_size, kernel_size)
        else:
            self.kernel_size = tuple(kernel_size)
        self.stride = stride if stride is not None else self.kernel_size
        self.force_inplace = force_inplace

    def forward(self, x):
        """Pool ``x`` with the configured kernel size and stride."""
        return soft_pool2d(
            x, self.kernel_size, self.stride, force_inplace=self.force_inplace
        )