forked from qinzheng93/CNN-Calculator
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path IntroNet.py
More file actions
44 lines (36 loc) · 1.3 KB
/
IntroNet.py
File metadata and controls
44 lines (36 loc) · 1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from __future__ import print_function
from DNNCalculator import Tensor, DNNCalculator
class IntroNetCalc(DNNCalculator):
    '''
    IntroNet Model.
    Source: https://arxiv.org/abs/1702.00832

    Counts parameters and FLOPs of the IntroNet layer stack by running it
    through the DNNCalculator layer helpers (Conv2d, MaxPool2d, Flatten,
    Linear). The totals are read back from self.params / self.flops
    afterwards (presumably accumulated by those helpers -- see
    DNNCalculator).
    '''

    def __init__(self, only_mac=True):
        '''
        :param only_mac: forwarded unchanged to DNNCalculator.__init__
            (the flag name suggests counting only multiply-accumulates;
            confirm against DNNCalculator).
        '''
        super(IntroNetCalc, self).__init__(only_mac)

    def IntroNet(self, tensor):
        '''
        Run the IntroNet layer stack over ``tensor`` via the calculator.

        Layers: Conv(128, 8x2) -> MaxPool(2x1) -> Conv(64, 16x1)
        -> MaxPool(2x1) -> Flatten -> Linear 128 -> 64 -> 32 -> 10.

        :param tensor: input Tensor; calculate() passes Tensor(1, 128, 2).
        :return: the Tensor produced by the final 10-unit Linear layer.
        '''
        # Width strides are 1, not 0: a stride of 0 is undefined for
        # convolution/pooling (the usual output-size formula
        # (in + 2*pad - k) // stride + 1 divides by it). Assuming
        # Tensor(c, h, w) order, the 8x2 kernel already collapses the width
        # of 2 to 1, so a width stride of 1 keeps it at 1.
        # NOTE(review): with these strides the standard formulas give
        # 324,330 parameters (biases included) -- confirm against
        # DNNCalculator's output.
        tensor = self.Conv2d(tensor, out_c=128, size=(8, 2),
                             stride=(1, 1), padding=(0, 0))
        tensor = self.MaxPool2d(tensor, size=(2, 1), stride=(2, 1))
        tensor = self.Conv2d(tensor, out_c=64, size=(16, 1),
                             stride=(1, 1), padding=(0, 0))
        tensor = self.MaxPool2d(tensor, size=(2, 1), stride=(2, 1))
        tensor = self.Flatten(tensor)
        tensor = self.Linear(tensor, 128)
        tensor = self.Linear(tensor, 64)
        tensor = self.Linear(tensor, 32)
        tensor = self.Linear(tensor, 10)
        return tensor

    def calculate(self):
        '''
        Compute and print the parameter and FLOP counts for IntroNet.

        Feeds a Tensor(1, 128, 2) input through IntroNet(), then prints
        ``self.params`` and ``self.flops``.

        :return: ``(params, flops)`` -- the same values that were printed.
        '''
        self.IntroNet(Tensor(1, 128, 2))
        print('params: {}, flops: {}'.format(self.params, self.flops))
        return self.params, self.flops
if __name__ == '__main__':
    # Script entry point: build the IntroNet calculator with only_mac
    # enabled and print its parameter / FLOP totals.
    only_mac = True
    IntroNetCalc(only_mac=only_mac).calculate()