-
Notifications
You must be signed in to change notification settings - Fork 14
/
demo.py
61 lines (46 loc) · 1.46 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from load_model import load_model
from rectify import inference
"""
1. load_model()
根据模型类型,导入存储在硬盘中的模型文件至内存。
Parameters:
None
Returns:
- model: {UNetRNN}
模型对象,包括模型各层结构和预训练的参数。
- device: {device}
torch.device类对象,表示分配给torch.Tensor进行运算的设备。包含设备类型("cpu"或"cuda")和设备序号。
Example:
from load_model import load_model
model, device = load_model()
2. inferecne(input_path, output_path, model, device)
校正推理,对单张图像进行校正处理。
Parameters:
- input_path: {str}
待校正图像路径
- output_path: {str}
图像保存路径
- model: {UNetRNN}
模型对象,包括模型各层结构和预训练的参数。
- device: {device}
torch.device类对象,表示分配给torch.Tensor进行运算的设备。包含设备类型("cpu"或"cuda")和设备序号。
Example:
from rectify import inference
from load_model import load_model
input = 'example/card.jpg'
output = 'result/card.png'
model, device = load_model()
inference(input, output, trained_model, device)
"""
if __name__ == "__main__":
"""
Demo
"""
input1 = 'example/card1.jpg'
input2 = 'example/card2.jpg'
output1 = 'result/card1.png'
output2 = 'result/card2.png'
trained_model, device = load_model()
inference(input1, output1, trained_model, device)
inference(input2, output2, trained_model, device)
print("Done.")