图像分类任务是计算机视觉的基础,也是新手必须要掌握的知识和技能,其中卷积神经网络(CNN)就是最常用的提取图像特征的方法。
数据集来源于飞桨新手赛,一个cat_12_train目录,包含2160张大小不一的图像,标注信息存放在data_list.txt中。每行的格式内容如下:
图像名称 分类ID
cat_12_train/8GOkTtqw7E6IHZx4olYnhzvXLCiRsUfM.jpg 0
cat_12_train/hwQDH3VBabeFXISfjlWEmYicoyr6qK1p.jpg 0
cat_12_train/RDgZKvM6sp3Tx9dlqiLNEVJjmcfQ0zI4.jpg 0
cat_12_train/ArBRzHyphTxFS2be9XLaU58m34PudlEf.jpg 0部分图像
还有一个cat_12_test目录,包含240张测试图像,没有标注信息。
为了方便的模型训练,需要把数据集拆分成训练集和验证集,就是把data_list.txt 按8:2拆分成train_list.txt和val_list.txt
先从data_list.txt读所有的数据存放在numpy数组datanp中,并随机乱序,然后对datanp进行切片,把切片后的数据写回训练集和验证集文本文件。
import os
import numpy as np
datadir="cat_12"
datalst=[]
#打开文本文件
with open(os.path.join(datadir,"data_list.txt")) as f:
for line in f.readlines():
d=line.split(" ")
imgpath=d[0]
imglabel=d[1]
datalst.append([imgpath,imglabel])
datanp=np.array(datalst)
np.random.shuffle(datanp)
trainnp=datanp[:int(0.8*len(datanp))]
valnp=datanp[int(0.8*len(datanp)):]
#写入训练集文本文件
with open(os.path.join(datadir,"train_list.txt"),"w") as f:
for d in trainnp:
f.write(d[0]+" "+d[1])
#写入验证集文本文件
with open(os.path.join(datadir,"val_list.txt"),"w") as f:
for d in valnp:
f.write(d[0]+" "+d[1])
这里使用的是ResNet50_vd_ssld分类模型,使用paddlex来训练,代码和过程这里省略,可以参看前面的文章:2.人工智能-图像分类
把训练模型转换成部署模型,对给的240个测试数据集进行预测,并保存预测结果到result.csv文件。
import paddlex as pdx
import cv2
import os
import pandas as pd
import numpy as np
import matplotlib.image as mpimg
testdir="cat_12/cat_12_test"
testlst=[]
predictor=pdx.deploy.Predictor("inference_model",use_gpu=True)
for f in os.listdir(testdir):
fpath=os.path.join(testdir,f)
#img = cv2.imdecode(np.fromfile(fpath,dtype = np.uint8),-1)
img=cv2.imread(fpath)
#因编码问题读不出图像内容的
if img is None:
img=mpimg.imread(fpath)
img=cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
result=predictor.predict(img)
print(result[0]["category"],result[0]["score"])
testlst.append([f,result[0]["category"]])
#把testlst列表写入csv
pd.DataFrame(testlst).to_csv("test.csv",index=False,header=False)Qt29gPjYZwv3B6RJh5yiTWXrVImue1FH.jpg
这里有一个注意项:在cv2读取图像的时候,如果文件名存在中文或编码问题,会引起cv2无法读取出图像内容,返回None,导致程序无法正常进行。碰到这种情况大部分解决方式就是中文改为英文。例如上面的这张图像。
本文在这里换个思路,cv2无法读取,就换成matplotlib.image来读取。
提交结果来看,0.91的得分并不是很高,看到最高的分数都达到0.99以上。如果对训练数据进行预处理,和训练参数的调整,应该还是可以提高分数的。
#说明:
1.结果文件总行数应为240,否则成绩无效;
2.结果文件命名应为result.csv,每行内容格式为:文件名,分类结果,文件名和分类结果使用英文逗号","分隔;
评价指标
评比标准为分类正确的准确率,计算方式如下:
准确率 = 正确的分类个数 / 测试集图片个数提交结果
| 留言与评论(共有 0 条评论) “” |