"""Pure PyTorch SoftPool implementation."""
import torch
import torch.nn as nn
import torch.nn.functional as F
def soft_pool2d(x, kernel_size=(2, 2), stride=None, force_inplace=False):
    """Apply 2-D SoftPool: a softmax-weighted average over each pooling window.

    For a window with activations ``a_1..a_k`` the output is
    ``sum_i a_i * exp(a_i) / sum_j exp(a_j)``.

    Args:
        x: Input tensor of shape ``(N, C, H, W)`` or unbatched ``(C, H, W)``.
        kernel_size: Window size as an int or a ``(kh, kw)`` pair.
        stride: Window stride as an int or a ``(sh, sw)`` pair; defaults to
            ``kernel_size`` when ``None``.
        force_inplace: Accepted for signature compatibility; not used by this
            pure-PyTorch implementation.

    Returns:
        Tensor of shape ``(N, C, out_h, out_w)``, or ``(C, out_h, out_w)`` for
        unbatched input. Here ``out_h = (H - kh) // sh + 1`` and
        ``out_w = (W - kw) // sw + 1``. No padding is applied, so trailing
        rows and columns that do not fill a whole window are dropped.

    Raises:
        ValueError: If ``x`` is neither 3-D nor 4-D.
    """
    if stride is None:
        stride = kernel_size
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(stride, int):
        stride = (stride, stride)

    # Accept unbatched (C, H, W) input by adding a temporary batch axis.
    unbatched = x.dim() == 3
    if unbatched:
        x = x.unsqueeze(0)
    elif x.dim() != 4:
        raise ValueError(
            "soft_pool2d expects a 3-D (C, H, W) or 4-D (N, C, H, W) input, "
            f"got shape {tuple(x.shape)}"
        )

    batch, channels, height, width = x.shape
    kh, kw = kernel_size
    sh, sw = stride
    out_h = (height - kh) // sh + 1
    out_w = (width - kw) // sw + 1

    # F.unfold returns (N, C * kh * kw, L) with the channel index outermost.
    # That layout lets each channel's kh*kw window values move onto their own axis.
    windows = F.unfold(x, kernel_size=kernel_size, stride=stride)
    windows = windows.view(batch, channels, kh * kw, out_h * out_w)

    # softmax over the window axis gives the SoftPool weights. It is computed
    # stably, so large activations do not overflow, and the weights sum to one
    # with no epsilon fudge in the denominator.
    weights = torch.softmax(windows, dim=2)
    pooled = (windows * weights).sum(dim=2).view(batch, channels, out_h, out_w)
    return pooled.squeeze(0) if unbatched else pooled
class SoftPool2d(nn.Module):
    """Module wrapper around ``soft_pool2d``.

    Args:
        kernel_size: Window size as an int or a two-element sequence. An int
            is expanded to a square window, and other sequences are stored
            as a tuple.
        stride: Window stride; defaults to the (normalised) ``kernel_size``
            when ``None``. Otherwise it is stored as given.
        force_inplace: Accepted for signature compatibility; not used.
    """

    def __init__(self, kernel_size=(2, 2), stride=None, force_inplace=False):
        super().__init__()
        self.kernel_size = self._to_pair(kernel_size)
        self.stride = self.kernel_size if stride is None else stride

    @staticmethod
    def _to_pair(value):
        """Expand an int to ``(value, value)``; convert sequences to a tuple."""
        # Only ints are expanded. A list such as [2, 3] must become (2, 3),
        # not ([2, 3], [2, 3]).
        if isinstance(value, int):
            return (value, value)
        return tuple(value)

    def forward(self, x):
        """Return ``soft_pool2d(x, self.kernel_size, self.stride)``."""
        return soft_pool2d(x, self.kernel_size, self.stride)

    def extra_repr(self):
        return f"kernel_size={self.kernel_size}, stride={self.stride}"