-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_preprocessing,py
56 lines (49 loc) · 1.56 KB
/
data_preprocessing,py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import pickle
import struct
import random
import numpy as np
import pandas as pd
import urllib.request as req
import gzip, os, os.path
def unzip_file(savepath, baseurl, files):
    """Download gzip archives into *savepath* and decompress each one there.

    Every name in *files* is fetched from ``baseurl + "/" + name`` unless a
    file with that name already exists in *savepath*.  Each archive is then
    decompressed into *savepath*; the output file name is the archive name
    with every ".gz" occurrence removed.

    Args:
        savepath: Directory that receives the archives and the decompressed
            files.  It is created, together with any missing parents, when
            it does not exist yet.
        baseurl: URL prefix the file names are appended to.
        files: Archive file names.  The collection is traversed twice, so
            pass a list or tuple rather than a one-shot iterator.
    """
    # makedirs(exist_ok=True) creates missing parent directories and does not
    # fail when the directory appears between a check and the creation.
    os.makedirs(savepath, exist_ok=True)

    for f in files:
        url = baseurl + "/" + f
        loc = savepath + "/" + f
        print("download:", url)
        if not os.path.exists(loc):
            # Download under a temporary name and rename on success, so an
            # interrupted transfer is never mistaken for a finished archive
            # by the existence check above on the next run.
            partial = loc + ".part"
            req.urlretrieve(url, partial)
            os.replace(partial, loc)

    for f in files:
        gz_file = savepath + "/" + f
        raw_file = savepath + "/" + f.replace(".gz", "")
        print("gzip:", f)
        # Read the whole archive and close it before opening the output:
        # for a name without ".gz" both paths are the same file.
        with gzip.open(gz_file, "rb") as fp:
            body = fp.read()
        with open(raw_file, "wb") as w:
            w.write(body)
def to_csv(name, maxdata):
    """Convert a label/image file pair under ./data/ into a CSV file.

    Reads ``./data/<name>-labels-idx1-ubyte`` and
    ``./data/<name>-images-idx3-ubyte`` and writes ``./data/<name>.csv``.
    Each record becomes one line: the label followed by every pixel value,
    comma separated and terminated by CRLF.  The first ten records are also
    written as plain-text PGM images named
    ``./data/<name>-<index>-<label>.pgm``.

    Args:
        name: Prefix shared by the input and output file names.
        maxdata: Maximum number of records to convert.  Fewer are written
            when the label file announces fewer records.

    Raises:
        OSError: If an input file cannot be opened or an output file cannot
            be written.
        struct.error: If a header or a label cannot be read in full.
    """
    base = "./data/" + name
    # One "with" closes all three files even when a read or write fails.
    # newline="" stops text mode from turning the explicit "\r\n" record
    # terminator into "\r\r\n" on Windows.
    with open(base + "-labels-idx1-ubyte", "rb") as lbl_f, \
            open(base + "-images-idx3-ubyte", "rb") as img_f, \
            open(base + ".csv", "w", encoding="utf-8", newline="") as csv_f:
        # Each file starts with two big-endian uint32 values; only the
        # second one of the label file (the record count) is used here.
        _, lbl_count = struct.unpack(">II", lbl_f.read(8))
        _, _ = struct.unpack(">II", img_f.read(8))
        # The image file continues with the image height and width.
        rows, cols = struct.unpack(">II", img_f.read(8))
        pixels = rows * cols
        # The PGM header lists width before height; use the dimensions from
        # the image header instead of assuming 28x28.
        pgm_header = "P2 {0} {1} 255\n".format(cols, rows)

        for idx in range(lbl_count):
            # ">=" so that exactly maxdata records are converted.
            if idx >= maxdata:
                break
            label = struct.unpack("B", lbl_f.read(1))[0]
            # Iterating a bytes object yields one int per pixel.
            sdata = [str(value) for value in img_f.read(pixels)]
            csv_f.write(str(label) + "," + ",".join(sdata) + "\r\n")

            # Dump the first ten images as PGM for a quick visual check.
            if idx < 10:
                iname = "./data/{0}-{1}-{2}.pgm".format(name, idx, label)
                with open(iname, "w", encoding="utf-8") as f:
                    f.write(pgm_header + " ".join(sdata))