由于三国系列暂时没找到对应较大较全的图像数据集,本次图片的操作使用 ImageNet 数据集,ImageNet 是一个著名的公共图像数据库,包含数百万张图像和数千个用于图像分类的对象。用于训练对象分类、检测和分割等任务的模型,它包含超过 1400 万张图像。
在 Python 中处理图像数据的时候,例如应用卷积神经网络(也称CNN)等算法可以处理大量图像数据集,这里就需要学习如何用最简单的方式存储、读取数据。
对于图像数据处理应该有有个定量的比较方式,读取和写入文件需要多长时间,以及将使用多少磁盘内存。
分别用不同的方式去处理、解决图像的存储、性能优化的问题。
整套学习自学教程中应用的数据都是《三國志》、《真·三國無雙》系列游戏中的内容。
图像数据集 CIFAR-10,由 60000 个 32x32 像素的彩色图像组成,这些图像属于不同的对象类别,例如狗、猫和飞机等等。相对而言 CIFAR 不是一个非常大的数据集,但如使用完整的 TinyImages 数据集,那么将需要大约 400GB 的可用磁盘空间。
文中的代码应用的数据集下载地址 CIFAR-10 数据集 。
这份数据是使用 Pickle 进行了序列化和批量保存。pickle模块可以序列化任何 Python 对象,而无需进行任何额外的代码或转换。
图像加载到 NumPy 数组中
import numpy as np
import pickle
from pathlib import Path
# 文件路径
data_dir = Path("Datasets/cifar-10-batches-py/")
# 解码功能
def unpickle(file):
with open(file, "rb") as fo:
dict = pickle.load(fo, encoding="bytes")
return dict
images, labels = [], []
for batch in data_dir.glob("data_batch_*"):
batch_data = unpickle(batch)
for i, flat_im in enumerate(batch_data[b"data"]):
im_channels = []
# 每个图像都是扁平化的,通道按 R, G, B 的顺序排列
for j in range(3):
im_channels.append(
flat_im[j * 1024 : (j + 1) * 1024].reshape((32, 32))
)
# 重建原始图像
images.append(np.dstack((im_channels)))
# 保存标签
labels.append(batch_data[b"labels"][i])
print("加载 CIFAR-10 训练集:")
print(f" - np.shape(images) {np.shape(images)}")
print(f" - np.shape(labels) {np.shape(labels)}")
加载 CIFAR-10 训练集:
- np.shape(images) (50000, 32, 32, 3)
- np.shape(labels) (50000,)
安装三方库 Pillow 用于图像处理 。
pip install Pillow
LMDB 也称为『闪电数据库』,代表闪电内存映射数据库,通过键值存储,而不是关系数据库。有点类似与 Spark 中 RDD 的数据存储方式。
安装三方库 lmdb 用于图像处理 。
pip install lmdb
HDF5 全称 Hierarchical Data Format,是一种可移植、紧凑的科学数据格式。
安装三方库 h5py 用于图像处理 。
pip install h5py
3种不同的方式进行数据读取操作
from pathlib import Path
disk_dir = Path("Datasets/cifar-10-batches-py/disk/")
lmdb_dir = Path("Datasets/cifar-10-batches-py/lmdb/")
hdf5_dir = Path("Datasets/cifar-10-batches-py/hdf5/")
同时加载的数据可以创建文件夹分开保存
disk_dir.mkdir(parents=True, exist_ok=True)
lmdb_dir.mkdir(parents=True, exist_ok=True)
hdf5_dir.mkdir(parents=True, exist_ok=True)
使用 Pillow 完成输入是一个单一的图像 image,在内存中作为一个 NumPy 数组,并且使用唯一的图像 ID 对其进行命名image_id。
单个图像保存到磁盘
from PIL import Image
import csv
def store_single_disk(image, image_id, label):
""" 将单个图像作为 .png 文件存储在磁盘上。
参数:
---------------
image 图像数组, (32, 32, 3) 格式
image_id 图像的整数唯一 ID
label 图像标签
"""
Image.fromarray(image).save(disk_dir / f"{image_id}.png")
with open(disk_dir / f"{image_id}.csv", "wt") as csvfile:
writer = csv.writer(
csvfile, delimiter=" ", quotechar="|", quoting=csv.QUOTE_MINIMAL
)
writer.writerow([label])
LMDB 是一个键值对存储系统,其中每个条目都保存为一个字节数组,键将是每个图像的唯一标识符,值将是图像本身。
键和值都应该是字符串。 常见的用法是将值序列化为字符串,然后在读回时将其反序列化。
用于重建的图像尺寸,某些数据集可能包含不同大小的图像会使用到这个方法。
class CIFAR_Image:
def __init__(self, image, label):
self.channels = image.shape[2]
self.size = image.shape[:2]
self.image = image.tobytes()
self.label = label
def get_image(self):
""" 将图像作为 numpy 数组返回 """
image = np.frombuffer(self.image, dtype=np.uint8)
return image.reshape(*self.size, self.channels)
单个图像保存到 LMDB
import lmdb
import pickle
def store_single_lmdb(image, image_id, label):
""" 将单个图像存储到 LMDB
参数:
---------------
image 图像数组, (32, 32, 3) 格式
image_id 图像的整数唯一 ID
label 图像标签
"""
map_size = image.nbytes * 10
# 创建新的 LMDB 环境
env = lmdb.open(str(lmdb_dir / f"single_lmdb"), map_size=map_size)
# 开始一个新的写事务
with env.begin(write=True) as txn:
# 所有键值对都必须是字符串
value = CIFAR_Image(image, label)
key = f"{image_id:08}"
txn.put(key.encode("ascii"), pickle.dumps(value))
env.close()
一个 HDF5 文件可以包含多个数据集。例如创建两个数据集,分别应用于图像和元数据。
import h5py
def store_single_hdf5(image, image_id, label):
""" 将单个图像存储到 HDF5 文件
参数:
---------------
image 图像数组, (32, 32, 3) 格式
image_id 图像的整数唯一 ID
label 图像标签
"""
# 创建一个新的 HDF5 文件
file = h5py.File(hdf5_dir / f"{image_id}.h5", "w")
# 在文件中创建数据集
dataset = file.create_dataset(
"image", np.shape(image), h5py.h5t.STD_U8BE, data=image
)
meta_set = file.create_dataset(
"meta", np.shape(label), h5py.h5t.STD_U8BE, data=label
)
file.close()
将保存单个图像的所有三个函数放入字典中。
_store_single_funcs = dict(
disk=store_single_disk,
lmdb=store_single_lmdb,
hdf5=store_single_hdf5
)
以三种不同的方式存储保存 CIFAR 中的第一张图像及其对应的标签。
from timeit import timeit
store_single_timings = dict()
for method in ("disk", "lmdb", "hdf5"):
t = timeit(
"_store_single_funcs[method](image, 0, label)",
setup="image=images[0]; label=labels[0]",
number=1,
globals=globals(),
)
store_single_timings[method] = t
print(f"存储方法: {method}, 使用耗时: {t}")
来一个表格看看对比。
存储方法 | 存储耗时 | 使用内存 |
Disk | 0.011781 | 8 K |
LMDB | 0.001933 | 32 K |
HDF5 | 0.001986 | 8 K |
同单个图像存储方法类似,修改代码进行多个图像数据的存储。将多个图像保存为 .png 文件等同于多次调用 store_single_method() 。但此方法不适用于 LMDB 或 HDF5,因为每个图像都有不同的数据库文件。
def store_many_disk(images, labels):
""" 参数:
---------------
images 图像数组 (N, 32, 32, 3) 格式
labels 标签数组 (N,1) 格式
"""
num_images = len(images)
# 一张一张保存所有图片
for i, image in enumerate(images):
Image.fromarray(image).save(disk_dir / f"{i}.png")
# 将所有标签保存到 csv 文件
with open(disk_dir / f"{num_images}.csv", "w") as csvfile:
writer = csv.writer(
csvfile, delimiter=" ", quotechar="|", quoting=csv.QUOTE_MINIMAL
)
for label in labels:
writer.writerow([label])
def store_many_lmdb(images, labels):
""" 参数:
---------------
images 图像数组 (N, 32, 32, 3) 格式
labels 标签数组 (N,1) 格式
"""
num_images = len(images)
map_size = num_images * images[0].nbytes * 10
# 为所有图像创建一个新的 LMDB 数据库
env = lmdb.open(str(lmdb_dir / f"{num_images}_lmdb"), map_size=map_size)
# 在一个事务中写入所有图像
with env.begin(write=True) as txn:
for i in range(num_images):
# 所有键值对都必须是字符串
value = CIFAR_Image(images[i], labels[i])
key = f"{i:08}"
txn.put(key.encode("ascii"), pickle.dumps(value))
env.close()
def store_many_hdf5(images, labels):
""" 参数:
---------------
images 图像数组 (N, 32, 32, 3) 格式
labels 标签数组 (N,1) 格式
"""
num_images = len(images)
# 创建一个新的 HDF5 文件
file = h5py.File(hdf5_dir / f"{num_images}_many.h5", "w")
# 在文件中创建数据集
dataset = file.create_dataset(
"images", np.shape(images), h5py.h5t.STD_U8BE, data=images
)
meta_set = file.create_dataset(
"meta", np.shape(labels), h5py.h5t.STD_U8BE, data=labels
)
file.close()
使用 100000 个图像进行测试
cutoffs = [10, 100, 1000, 10000, 100000]
images = np.concatenate((images, images), axis=0)
labels = np.concatenate((labels, labels), axis=0)
# 确保有 100,000 个图像和标签
print(np.shape(images))
print(np.shape(labels))
(100000, 32, 32, 3)
(100000,)
创建一个计算方式进行对比
_store_many_funcs = dict(
disk=store_many_disk, lmdb=store_many_lmdb, hdf5=store_many_hdf5
)
from timeit import timeit
store_many_timings = {"disk": [], "lmdb": [], "hdf5": []}
for cutoff in cutoffs:
for method in ("disk", "lmdb", "hdf5"):
t = timeit(
"_store_many_funcs[method](images_, labels_)",
setup="images_=images[:cutoff]; labels_=labels[:cutoff]",
number=1,
globals=globals(),
)
store_many_timings[method].append(t)
# 打印出方法、截止时间和使用时间
print(f"Method: {method}, Time usage: {t}")
Method: disk, Time usage: 0.006555199999999983
Method: lmdb, Time usage: 0.002043500000000087
Method: hdf5, Time usage: 0.0014462000000001751
Method: disk, Time usage: 0.05839030000000012
Method: lmdb, Time usage: 0.00463260000000032
Method: hdf5, Time usage: 0.0014360999999998292
Method: disk, Time usage: 0.5384359000000001
Method: lmdb, Time usage: 0.0415766999999998
Method: hdf5, Time usage: 0.002570799999999984
Method: disk, Time usage: 5.3288174
Method: lmdb, Time usage: 0.25626450000000034
Method: hdf5, Time usage: 0.01716279999999948
Method: disk, Time usage: 53.710009
Method: lmdb, Time usage: 2.387375200000008
Method: hdf5, Time usage: 0.24997290000000305
PLOT 显示具有多个数据集和匹配图例的单个图
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['SimHei']
def plot_with_legend(
x_range, y_data, legend_labels, x_label, y_label, title, log=False
):
""" 参数:
--------------
x_range 包含 x 数据的列表
y_data 包含 y 值的列表
legend_labels 字符串图例标签列表
x_label x 轴标签
y_label y 轴标签
"""
plt.figure(figsize=(10, 7))
if len(y_data) != len(legend_labels):
raise TypeError(
"数据集的数量与标签的数量不匹配"
)
all_plots = []
for data, label in zip(y_data, legend_labels):
if log:
temp, = plt.loglog(x_range, data, label=label)
else:
temp, = plt.plot(x_range, data, label=label)
all_plots.append(temp)
plt.title(title)
plt.xlabel(x_label)
plt.ylabel(y_label)
plt.legend(handles=all_plots)
plt.show()
# 获取要显示的存储时间数据
disk_x_stroe = store_many_timings["disk"]
lmdb_x_stroe = store_many_timings["lmdb"]
hdf5_x_stroe = store_many_timings["hdf5"]
plot_with_legend(
cutoffs,
[disk_x_stroe, lmdb_x_stroe, hdf5_x_stroe],
["PNG files", "LMDB", "HDF5"],
"图片数量",
"耗时(秒)",
"存储的时间",
log=False,
)
plot_with_legend(
cutoffs,
[disk_x_stroe, lmdb_x_stroe, hdf5_x_stroe],
["PNG files", "LMDB", "HDF5"],
"图片数量",
"耗时(秒)",
"存储的时间对数",
log=True,
)
def read_single_disk(image_id):
""" 参数:
---------------
image_id 图像的整数唯一 ID
返回结果:
---------------
images 图像数组 (N, 32, 32, 3) 格式
labels 标签数组 (N,1) 格式
"""
image = np.array(Image.open(disk_dir / f"{image_id}.png"))
with open(disk_dir / f"{image_id}.csv", "r") as csvfile:
reader = csv.reader(
csvfile, delimiter=" ", quotechar="|", quoting=csv.QUOTE_MINIMAL
)
label = int(next(reader)[0])
return image, label
def read_single_lmdb(image_id):
""" 参数:
---------------
image_id 图像的整数唯一 ID
返回结果:
---------------
images 图像数组 (N, 32, 32, 3) 格式
labels 标签数组 (N,1) 格式
"""
# 打开 LMDB 环境
env = lmdb.open(str(lmdb_dir / f"single_lmdb"), readonly=True)
# 开始一个新的事务
with env.begin() as txn:
# 进行编码
data = txn.get(f"{image_id:08}".encode("ascii"))
# 加载的 CIFAR_Image 对象
cifar_image = pickle.loads(data)
# 检索相关位
image = cifar_image.get_image()
label = cifar_image.label
env.close()
return image, labels
def read_single_hdf5(image_id):
""" 参数:
---------------
image_id 图像的整数唯一 ID
返回结果:
---------------
images 图像数组 (N, 32, 32, 3) 格式
labels 标签数组 (N,1) 格式
"""
# 打开 HDF5 文件
file = h5py.File(hdf5_dir / f"{image_id}.h5", "r+")
image = np.array(file["/image"]).astype("uint8")
label = int(np.array(file["/meta"]).astype("uint8"))
return image, labels
_read_single_funcs = dict(
disk=read_single_disk, lmdb=read_single_lmdb, hdf5=read_single_hdf5
)
from timeit import timeit
read_single_timings = {"disk": [], "lmdb": [], "hdf5": []}
for method in ("disk", "lmdb", "hdf5"):
t = timeit(
"_read_single_funcs[method](0)",
setup="image=images[0]; label=labels[0]",
number=1,
globals=globals(),
)
read_single_timings[method] = t
print(f"读取方法: {method}, 使用耗时: {t}")
读取方法: disk, 使用耗时: 0.0010008999999797652
读取方法: lmdb, 使用耗时: 0.000521800000001349
读取方法: hdf5, 使用耗时: 0.0010201999999992495
存储方法 | 存储耗时 |
Disk | 0.0010008 |
LMDB | 0.0005218 |
HDF5 | 0.0010201 |
同单个图像存储方法类似,修改代码进行多个图像数据的存储。将多个图像保存为 .png 文件等同于多次调用 read_single_method() 。但此方法不适用于 LMDB 或 HDF5,因为每个图像都有不同的数据库文件。
def read_many_disk(num_images):
""" 参数:
---------------
num_images 要读取的图像数量
返回结果:
---------------
images 图像数组 (N, 32, 32, 3) 格式
labels 标签数组 (N,1) 格式
"""
images, labels = [], []
# 循环遍历所有ID,一张一张地读取每张图片
for image_id in range(num_images):
images.append(np.array(Image.open(disk_dir / f"{image_id}.png")))
with open(disk_dir / f"{num_images}.csv", "r") as csvfile:
reader = csv.reader(
csvfile, delimiter=" ", quotechar="|", quoting=csv.QUOTE_MINIMAL
)
for row in reader:
if row !=[]:
labels.append(int(row[0]))
return images, labels
def read_many_lmdb(num_images):
""" 参数:
---------------
num_images 要读取的图像数量
返回结果:
---------------
images 图像数组 (N, 32, 32, 3) 格式
labels 标签数组 (N,1) 格式
"""
images, labels = [], []
env = lmdb.open(str(lmdb_dir / f"{num_images}_lmdb"), readonly=True)
# 开始一个新的事务
with env.begin() as txn:
# 在一个事务中读取,也可以拆分成多个事务分别读取
for image_id in range(num_images):
data = txn.get(f"{image_id:08}".encode("ascii"))
# CIFAR_Image 对象,作为值存储
cifar_image = pickle.loads(data)
# 检索相关位
images.append(cifar_image.get_image())
labels.append(cifar_image.label)
env.close()
return images, labels
def read_many_hdf5(num_images):
""" 参数:
---------------
num_images 要读取的图像数量
返回结果:
---------------
images 图像数组 (N, 32, 32, 3) 格式
labels 标签数组 (N,1) 格式
"""
images, labels = [], []
# 打开 HDF5 文件
file = h5py.File(hdf5_dir / f"{num_images}_many.h5", "r+")
images = np.array(file["/images"]).astype("uint8")
labels = np.array(file["/meta"]).astype("uint8")
return images, labels
创建一个计算方式进行对比
_read_many_funcs = dict(
disk=read_many_disk, lmdb=read_many_lmdb, hdf5=read_many_hdf5
)
from timeit import timeit
read_many_timings = {"disk": [], "lmdb": [], "hdf5": []}
for cutoff in cutoffs:
for method in ("disk", "lmdb", "hdf5"):
t = timeit(
"_read_many_funcs[method](num_images)",
setup="num_images=cutoff",
number=1,
globals=globals(),
)
read_many_timings[method].append(t)
# 打印出方法、截止时间和经过时间
print(f"读取方法: {method}, No. images: {cutoff}, 耗时: {t}")
读取方法: disk, No. images: 10, 耗时: 0.007880700000015395
读取方法: lmdb, No. images: 10, 耗时: 0.00048730000000318796
读取方法: hdf5, No. images: 10, 耗时: 0.0009300000000109776
读取方法: disk, No. images: 100, 耗时: 0.04341629999998986
读取方法: lmdb, No. images: 100, 耗时: 0.0012704000000098858
读取方法: hdf5, No. images: 100, 耗时: 0.0010410999999805881
读取方法: disk, No. images: 1000, 耗时: 0.4195468000000062
读取方法: lmdb, No. images: 1000, 耗时: 0.01065929999998616
读取方法: hdf5, No. images: 1000, 耗时: 0.004244100000022399
读取方法: disk, No. images: 10000, 耗时: 4.095435699999996
读取方法: lmdb, No. images: 10000, 耗时: 0.10225660000000403
读取方法: hdf5, No. images: 10000, 耗时: 0.023506800000006933
读取方法: disk, No. images: 100000, 耗时: 41.69310229999999
读取方法: lmdb, No. images: 100000, 耗时: 1.0663754999999924
读取方法: hdf5, No. images: 100000, 耗时: 0.23246740000001864
PLOT 显示具有多个数据集和匹配图例的单个图
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['SimHei']
def plot_with_legend(
x_range, y_data, legend_labels, x_label, y_label, title, log=False
):
""" 参数:
--------------
x_range 包含 x 数据的列表
y_data 包含 y 值的列表
legend_labels 字符串图例标签列表
x_label x 轴标签
y_label y 轴标签
"""
plt.figure(figsize=(10, 7))
if len(y_data) != len(legend_labels):
raise TypeError(
"数据集的数量与标签的数量不匹配"
)
all_plots = []
for data, label in zip(y_data, legend_labels):
if log:
temp, = plt.loglog(x_range, data, label=label)
else:
temp, = plt.plot(x_range, data, label=label)
all_plots.append(temp)
plt.title(title)
plt.xlabel(x_label)
plt.ylabel(y_label)
plt.legend(handles=all_plots)
plt.show()
# 获取要显示的读取时间数据
disk_x_read = read_many_timings["disk"]
lmdb_x_read = read_many_timings["lmdb"]
hdf5_x_read = read_many_timings["hdf5"]
plot_with_legend(
cutoffs,
[disk_x_read, lmdb_x_read, hdf5_x_read],
["PNG files", "LMDB", "HDF5"],
"图片数量",
"读取耗时(秒)",
"读取时间",
log=False,
)
plot_with_legend(
cutoffs,
[disk_x_read, lmdb_x_read, hdf5_x_read],
["PNG files", "LMDB", "HDF5"],
"图片数量",
"读取耗时(秒)",
"读取时间对数",
log=True,
)
同一张图表上查看读取和写入时间
plot_with_legend(
cutoffs,
[disk_x_r, lmdb_x_r, hdf5_x_r, disk_x, lmdb_x, hdf5_x],
[
"Read PNG",
"Read LMDB",
"Read HDF5",
"Write PNG",
"Write LMDB",
"Write HDF5",
],
"Number of images",
"Seconds",
"Log Store and Read Times",
log=False,
)
各种存储方式使用磁盘空间
disk_mem = [28, 207, 2009, 21032, 201296]
lmdb_mem = [60, 420, 4000, 49000, 403000]
hdf5_mem = [38, 307, 2900, 29500, 313000]
X = [disk_mem, lmdb_mem, hdf5_mem]
ind = np.arange(3)
width = 0.35
plt.subplots(figsize=(8, 10))
plots = [plt.bar(ind, [row[0] for row in X], width)]
for i in range(1, len(cutoffs)):
plots.append(
plt.bar(
ind, [row[i] for row in X], width, bottom=[row[i - 1] for row in X]
)
)
plt.ylabel("内存使用(KB)")
plt.title("不同方法使用的磁盘内存")
plt.xticks(ind, ("PNG", "LMDB", "HDF5"))
plt.yticks(np.arange(0, 400000, 100000))
plt.legend(
[plot[0] for plot in plots], ("10", "100", "1,000", "10,000", "100,000")
)
plt.show()
虽然 HDF5 和 LMDB 都占用更多的磁盘空间,但是使用和性能在很大程度上受操作系统、存储数据大小的因素影响。
通常对于大的数据集,可以通过并行化来加速操作。 也就是并发处理。
作为.png 文件存储到磁盘实际上允许完全并发。只要图像名称不同就可以从不同的线程读取多个图像,或一次写入多个文件。
如果将所有 CIFAR 分成十组,那么可以为一组中的每个读取设置十个进程,并且相应的处理时间可以减少到原来的10%左右。
有兴趣的话可以自己试一下。
| 留言与评论(共有 0 条评论) “” |