需求:将图片文件保存成TFRecord格式.
解决方法:基于tensorflow、cv2、numpy等库完成该功能.
注:改编自网上代码
1) 准备要训练的手写识别的图片文件,并按照目录结构存放。见下图示意:
2) 生成训练图片和标签对应的文本文件,见下图示意:
3) 编写将图片生成TFRecord的代码,代码见下:
import numpy as np
import cv2
import tensorflow as tf
import matplotlib.pyplot as plt
from io import StringIO,BytesIO
# 将value转化成int64字节属性
def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` backed by an ``Int64List``."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
# 将value转化成bytes属性
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` backed by a ``BytesList``."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
# 从训练文本里读取样本、标签。返回样本、标签列表以及行数
def load_file(examples_list_file):
    """Read the sample list file and return its paths, labels and row count.

    Every row of the list file must contain two space-separated columns:
    an image path and an integer label.

    Args:
        examples_list_file: Anything ``np.genfromtxt`` accepts as input
            (a file name or a file-like object).

    Returns:
        A tuple ``(examples, labels, count)`` where ``examples`` is an array
        of byte strings (the raw paths, not decoded), ``labels`` is an array
        of 64-bit integers and ``count`` is the number of rows read.
    """
    # NOTE: 'S120' is a fixed-width field, so paths longer than 120 bytes
    # are truncated by NumPy.
    rows = np.genfromtxt(
        examples_list_file,
        delimiter=" ",
        dtype=[('col1', 'S120'), ('col2', 'i8')],
    )
    # genfromtxt yields a 0-d array when the file has exactly one row; such
    # an array can be neither iterated nor measured with len(), so promote
    # it to 1-d before using it.
    rows = np.atleast_1d(rows)
    # Copy the fields out so callers get independent, contiguous arrays
    # rather than views into the structured array.
    return rows['col1'].copy(), rows['col2'].copy(), len(rows)
def extract_image(filename, resize_height, resize_width):
    """Load an image file as a resized RGB float32 array scaled to [0, 1].

    Args:
        filename: Path of the image file to read.
        resize_height: Target number of rows of the returned image.
        resize_width: Target number of columns of the returned image.

    Returns:
        A ``float32`` array of shape ``(resize_height, resize_width, 3)``
        in RGB channel order, with pixel values divided by 255.

    Raises:
        IOError: If OpenCV cannot read ``filename``.
    """
    image = cv2.imread(filename)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cv2.resize.
        raise IOError('cannot read image file: %s' % filename)
    # cv2.resize expects dsize as (width, height), not (height, width).
    image = cv2.resize(image, (resize_width, resize_height))
    # OpenCV loads images as BGR; reorder the channels to RGB.
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return (rgb_image / 255.).astype(np.float32)
def Image2TFRecord(trainDir, trainLabelFile, tfFile):
    """Convert the images listed in a label file into one TFRecord file.

    Each listed image is loaded via ``extract_image`` at 28x28, and written
    as a ``tf.train.Example`` with the features ``image_raw`` (raw float32
    bytes), ``height``, ``width``, ``depth`` and ``label``.

    Args:
        trainDir: Root directory that the paths in the label file are
            relative to.
        trainLabelFile: List file read by ``load_file`` (path and label per
            row).
        tfFile: Path of the TFRecord file to create.
    """
    resize_height = 28
    resize_width = 28
    examples, labels, _ = load_file(trainLabelFile)
    # NOTE(review): tf.python_io is the TensorFlow 1.x module name; newer
    # releases expose the writer as tf.io.TFRecordWriter - confirm against
    # the installed version.
    writer = tf.python_io.TFRecordWriter(tfFile)
    try:
        for i, (example, label) in enumerate(zip(examples, labels)):
            # load_file returns the paths as byte strings.
            relative_path = example.decode(encoding="utf-8")
            print('No.%d' % (i))
            print(relative_path)
            root = trainDir + '/' + relative_path
            print(root)
            image = extract_image(root, resize_height, resize_width)
            print('shape: %d, %d, %d, label: %d' % (image.shape[0], image.shape[1], image.shape[2], label))
            # Serialize the pixel buffer; tobytes() replaces the deprecated
            # tostring() alias and returns the same bytes.
            image_raw = image.tobytes()
            record = tf.train.Example(features=tf.train.Features(feature={
                'image_raw': _bytes_feature(image_raw),
                'height': _int64_feature(image.shape[0]),
                'width': _int64_feature(image.shape[1]),
                'depth': _int64_feature(image.shape[2]),
                # Hand the protobuf a native Python int, not a NumPy scalar.
                'label': _int64_feature(int(label))
            }))
            writer.write(record.SerializeToString())
    finally:
        # Close the writer even when an image fails to load, so the output
        # file handle is not leaked.
        writer.close()
if __name__ == '__main__':
    # Example invocations (paths are specific to the author's machine):
    ##Image2TFRecord('E:/Python/mnist_img_data4','E:/Python/mnist_img_data4/train.txt','E:/Python/mnist_img_output/a44.tfrecords')
    ##TFrcords2Img('E:/Python/mnist_img_output/a44.tfrecords')

    # Further reading on np.genfromtxt:
    # https://docs.scipy.org/doc/numpy-1.10.1/user/basics.io.genfromtxt.html

    # Fixed-width parsing: an integer delimiter means every field is exactly
    # 3 characters wide, so each row must be padded to 9 characters.
    data = "  1  2  3\n  4  5 67\n890123  4"
    print(np.genfromtxt(StringIO(data), delimiter=3))
    # Result:
    # [[  1.   2.   3.]
    #  [  4.   5.  67.]
    #  [890. 123.   4.]]

    # Fixed-width parsing with three columns of widths 4, 3 and 2.
    data = "123456789\n   4  7 9\n   4567 9"
    print(np.genfromtxt(StringIO(data), delimiter=(4, 3, 2)))
    # Result:
    # [[1234.  567.   89.]
    #  [   4.    7.    9.]
    #  [   4.  567.    9.]]

    data = "1, abc , 2\n 3, xxx, 4"
    # Surrounding whitespace is kept by default; autostrip=True removes it.
    print(np.genfromtxt(StringIO(data), delimiter=",", dtype="|S5"))
    print(np.genfromtxt(StringIO(data), delimiter=",", dtype="|S5", autostrip=True))
    # Result:
    # [[b'1' b' abc ' b' 2']
    #  [b'3' b' xxx' b' 4']]
    # [[b'1' b'abc' b'2']
    #  [b'3' b'xxx' b'4']]

    # usecols selects columns by index; negative indices count from the end.
    data = "1 2 3\n4 5 6"
    print(np.genfromtxt(StringIO(data), usecols=(0, -1)))
    # Result:
    # [[1. 3.]
    #  [4. 6.]]

    # When the columns are named, usecols can select them by name.
    print(np.genfromtxt(StringIO(data), names="a, b, c", usecols=("b", "c")))
    # Result:
    # [(2., 3.) (5., 6.)]
##DataType的类型:
4) 在E:/Python/mnist_img_output目录下生成a44.tfrecords