加载Google官方深度学习模型

前言

上篇教程《使用深度学习进行语义分析》对文本进行分析,其中使用了一个很重要的二分模型。对于这种通用的模型,Google官方准备了一个叫做tensorflow_hub的仓库,里面有很多已经训练好的模型,主要包括文本、图片、视频、音频四个领域,用户可以直接通过pip下载引用。

下载tensorflow_hub和tensorflow-datasets

pip3 install tensorflow-hubpip3 install tensorflow-datasets

国内网络如果访问有问题,可以跳过本篇教程,或者简单浏览,不会影响后面的学习。

准备数据

直接从tensorflow-datasets中获取,数据形式基本和上篇教程相同

train_data, validation_data, test_data = tensorflow_datasets.load(    name="imdb_reviews",    split=('train[:60%]', 'train[60%:]', 'test'),    data_dir='temp',    as_supervised=True)

构建模型

神经网络模型通过层叠加的方式构建,会有三个非常重要的思考影响结构:

  1. 如何表示文本
  2. 有多少层
  3. 每个层有多少隐藏单元

这里还是和上篇教程基本相同的处理模式,但是不再自己构建TextVectorization层,而是直接从tensorflow_hub拿google/nnlm-en-dim50/2

embedding = "https://tfhub.dev/google/nnlm-en-dim50/2"hub_layer = hub.KerasLayer(embedding, input_shape=[],                           dtype=tf.string, trainable=True)model = tf.keras.Sequential()model.add(hub_layer)model.add(tf.keras.layers.Dense(16, activation='relu'))model.add(tf.keras.layers.Dense(1))model.summary()

它的作用和TextVectorization非常相似,也是将长句子转化为嵌入式数组(embedding vector)。

嵌入式数组

训练模型

history = model.fit(train_data.shuffle(10000).batch(512),                    epochs=10,                    validation_data=validation_data.batch(512),                    verbose=1)

模型评估

results = model.evaluate(test_data.batch(512), verbose=2)for name, value in zip(model.metrics_names, results):    print("%s: %.3f" % (name, value))

从输出中可以看到有85%的正确率

49/49 - 1s - loss: 0.3644 - accuracy: 0.8508 - 1s/epoch - 27ms/steploss: 0.364accuracy: 0.851
发表评论
留言与评论(共有 0 条评论) “”
   
验证码:

相关文章

推荐文章