
Commit 1877b67

yerlan authored and committed
proper implementation with readme.md file

File tree

2 files changed: +166 -0 lines

README.md

Lines changed: 15 additions & 0 deletions
# Proper ResNet Implementation for CIFAR10/CIFAR100

The [torchvision model zoo](https://github.com/pytorch/vision/tree/master/torchvision/models) provides a number of implementations of various state-of-the-art architectures; however, most of them are defined and implemented for ImageNet. Usually it is straightforward to use them on other datasets, but sometimes these models need manual setup.

Unfortunately, none of the repositories with ResNets on CIFAR10 provides an implementation as described in the [original paper](https://arxiv.org/abs/1512.03385). If you simply use torchvision's models on CIFAR10, you get a model **that differs in number of layers and parameters**. That is unacceptable if you want to directly compare ResNets on CIFAR10.

The purpose of this repo is to provide a valid implementation of the ResNets for CIFAR10. The following models are provided:

| Name       | # layers | # params |
|------------|---------:|---------:|
| ResNet20   |       20 |    0.27M |
| ResNet32   |       32 |    0.46M |
| ResNet44   |       44 |    0.66M |
| ResNet56   |       56 |    0.85M |
| ResNet110  |      110 |     1.7M |
| ResNet1202 |     1202 |    19.4M |

Their implementation matches the description in the original paper.
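
A minimal usage sketch, assuming `resnet.py` from this repo is on the import path (the parameter count mirrors the `test` helper at the bottom of `resnet.py`):

```python
import torch
from resnet import resnet20

# Build the 20-layer CIFAR10 ResNet and push a dummy CIFAR-sized batch through it.
net = resnet20()
out = net(torch.randn(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 10]) -- one logit per CIFAR10 class

# Count trainable parameters; ResNet20 should land at roughly 0.27M.
print(sum(p.numel() for p in net.parameters() if p.requires_grad))
```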

resnet.py

Lines changed: 151 additions & 0 deletions
'''
Properly implemented ResNets for CIFAR10, as described in paper [1].

The implementation and structure of this file are hugely influenced by [2],
which is implemented for ImageNet and doesn't have option A for the identity
shortcut. Moreover, most of the implementations on the web are copy-pasted
from torchvision's resnet and have the wrong number of params.

Proper ResNets for CIFAR10 (e.g. for a fair comparison) have the following
numbers of layers and parameters:

name      | layers | params
ResNet20  |     20 |  0.27M
ResNet32  |     32 |  0.46M
ResNet44  |     44 |  0.66M
ResNet56  |     56 |  0.85M
ResNet110 |    110 |   1.7M
ResNet1202|   1202 |  19.4M

which this implementation indeed has.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
[2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

If you use this implementation in your work, please don't forget to mention
the author, Yerlan Idelbayev.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']


class LambdaLayer(nn.Module):
    """Wraps an arbitrary function as an nn.Module."""

    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)
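

# Hedged usage note (an illustration, not part of the original file): LambdaLayer
# lets a plain function act as a module, e.g. LambdaLayer(lambda x: x * 2) can sit
# inside nn.Sequential like any nn.Module. Option 'A' below uses it to build a
# parameter-free downsampling shortcut.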


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.compressible_conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.compressible_conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # For CIFAR10, the ResNet paper uses option A: a parameter-free
                # shortcut that subsamples spatially and zero-pads the extra channels.
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                # Option B: a projection shortcut (1x1 conv), as used for ImageNet.
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.compressible_conv1(x)))
        out = self.bn2(self.compressible_conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
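

# Shape check for option 'A' (an illustrative note, not in the original file):
# with x of shape (N, 16, 32, 32) and planes=32, x[:, :, ::2, ::2] subsamples to
# (N, 16, 16, 16), and F.pad then adds planes//4 = 8 zero channels on each side
# of the channel dimension, giving (N, 32, 16, 16) -- exactly the shape of the
# strided conv1/bn1 path, so the residual addition is well defined.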


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.compressible_conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.compressible_linear = nn.Linear(64, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage downsamples; the rest keep stride 1.
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.compressible_conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pooling: the final feature map is square, so pooling
        # with a kernel equal to its side length leaves a 1x1 map per channel.
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.compressible_linear(out)
        return out
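

# Depth arithmetic (an illustrative note, following the paper's 6n+2 scheme):
# each BasicBlock has 2 convs, each of the 3 stages stacks n blocks, and the
# stem conv plus the final linear layer add 2 more, so depth = 6n + 2.
# n = 3 gives the 20 layers of resnet20 below; n = 18 gives resnet110.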


def resnet20():
    return ResNet(BasicBlock, [3, 3, 3])


def resnet32():
    return ResNet(BasicBlock, [5, 5, 5])


def resnet44():
    return ResNet(BasicBlock, [7, 7, 7])


def resnet56():
    return ResNet(BasicBlock, [9, 9, 9])


def resnet110():
    return ResNet(BasicBlock, [18, 18, 18])


def resnet1202():
    return ResNet(BasicBlock, [200, 200, 200])


def test(net):
    import numpy as np
    total_params = 0

    for x in filter(lambda p: p.requires_grad, net.parameters()):
        total_params += np.prod(x.data.numpy().shape)
    print("Total number of params", total_params)
    # Count only weight tensors with more than one dimension (convs and the
    # linear layer), so BatchNorm scales and biases don't inflate the layer count.
    print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size()) > 1, net.parameters()))))


if __name__ == "__main__":
    for net_name in __all__:
        if net_name.startswith('resnet'):
            print(net_name)
            test(globals()[net_name]())
            print()
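
# Illustrative note on the expected output: resnet20's architecture fixes 20
# weight tensors with more than one dimension (1 stem conv + 9 blocks x 2 convs
# + 1 linear), so running `python resnet.py` should report "Total layers 20"
# and roughly 0.27M params, matching the README table.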

0 commit comments
