Spaces:
Runtime error
Runtime error
#!/usr/bin/env python
# encoding: utf-8
'''
@author: Xu Yan
@file: basic_blocks.py
@time: 2021/4/14 22:53
'''
| import torch.nn as nn | |
| import torchsparse.nn as spnn | |
class BasicConvolutionBlock(nn.Module):
    """Conv3d -> BatchNorm -> ReLU stack built from ``torchsparse.nn`` layers.

    Args:
        inc: First positional argument of ``spnn.Conv3d`` (input channels).
        outc: Second positional argument of ``spnn.Conv3d`` and the argument
            of ``spnn.BatchNorm`` (output channels).
        ks: Forwarded to ``spnn.Conv3d`` as ``kernel_size``.
        stride: Forwarded to ``spnn.Conv3d`` as ``stride``.
        dilation: Forwarded to ``spnn.Conv3d`` as ``dilation``.
    """

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        conv = spnn.Conv3d(
            inc, outc, kernel_size=ks, dilation=dilation, stride=stride
        )
        norm = spnn.BatchNorm(outc)
        # NOTE(review): the positional ``True`` is presumably the ``inplace``
        # flag -- confirm against the installed torchsparse version.
        activation = spnn.ReLU(True)
        # Order matters: the position inside Sequential names each submodule.
        self.net = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Apply the conv/norm/activation stack to ``x`` and return the result.

        ``x`` is handed straight to the torchsparse layers; presumably a
        torchsparse sparse tensor -- verify against callers.
        """
        return self.net(x)
class BasicDeconvolutionBlock(nn.Module):
    """Transposed Conv3d -> BatchNorm -> ReLU stack from ``torchsparse.nn``.

    Args:
        inc: First positional argument of ``spnn.Conv3d`` (input channels).
        outc: Second positional argument of ``spnn.Conv3d`` and the argument
            of ``spnn.BatchNorm`` (output channels).
        ks: Forwarded to ``spnn.Conv3d`` as ``kernel_size``.
        stride: Forwarded to ``spnn.Conv3d`` as ``stride``.
    """

    def __init__(self, inc, outc, ks=3, stride=1):
        super().__init__()
        layers = [
            # ``transposed=True`` is what distinguishes this block from the
            # plain convolution block; no dilation argument is passed here.
            spnn.Conv3d(
                inc, outc, kernel_size=ks, stride=stride, transposed=True
            ),
            spnn.BatchNorm(outc),
            spnn.ReLU(True),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the deconv/norm/activation stack to ``x`` and return it."""
        out = self.net(x)
        return out
class ResidualBlock(nn.Module):
    """Two-convolution residual unit built from ``torchsparse.nn`` layers.

    Main branch (``self.net``): Conv3d -> BatchNorm -> ReLU -> Conv3d ->
    BatchNorm. Only the first convolution receives ``stride``; the second is
    always built with ``stride=1``. Both use ``ks`` and ``dilation``.

    Shortcut branch (``self.downsample``): an empty ``nn.Sequential`` when
    ``inc == outc`` and ``stride == 1``; otherwise a kernel-size-1 Conv3d
    carrying ``stride`` followed by BatchNorm.

    Args:
        inc: Input channel count given to the first and shortcut convolutions.
        outc: Output channel count used by every convolution and BatchNorm.
        ks: Forwarded as ``kernel_size`` to both main-branch convolutions.
        stride: Forwarded to the first main-branch and shortcut convolutions.
        dilation: Forwarded to both main-branch convolutions.
    """

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        main_layers = [
            spnn.Conv3d(
                inc, outc, kernel_size=ks, dilation=dilation, stride=stride
            ),
            spnn.BatchNorm(outc),
            spnn.ReLU(True),
            spnn.Conv3d(
                outc, outc, kernel_size=ks, dilation=dilation, stride=1
            ),
            spnn.BatchNorm(outc),
        ]
        self.net = nn.Sequential(*main_layers)

        if inc == outc and stride == 1:
            # Nothing to project: an empty Sequential passes its input through.
            shortcut = nn.Sequential()
        else:
            # Projection built with the same ``outc`` and ``stride`` as the
            # first main-branch convolution.
            shortcut = nn.Sequential(
                spnn.Conv3d(
                    inc, outc, kernel_size=1, dilation=1, stride=stride
                ),
                spnn.BatchNorm(outc),
            )
        self.downsample = shortcut

        # Final activation, applied after the two branches are added.
        self.ReLU = spnn.ReLU(True)

    def forward(self, x):
        """Return ``ReLU(net(x) + downsample(x))``."""
        main = self.net(x)
        skip = self.downsample(x)
        return self.ReLU(main + skip)