RISC-V MCU中文社区

Keras入门第1讲之keras可视化

发表于 2023-05-27 12:45:57
0
624
0

keras可视化可以帮助我们直观的查看所搭建的模型拓扑结构,以及模型的训练的过程,方便我们优化模型。

模型可视化又分为模型拓扑结构可视化以及训练过程可视化。

以上一讲的mnist为例,演示不同可视化方法:

1 Netron 查看h5模型

参考《TFlite之格式解析 Netron部分,Netron 是一款常见的可视化工具,支持网页查看常见的AI模型,支持非常丰富的格式(ONNX, Tensorflow, Pytorch, Keras, Caffe等),网页地址: https://netron.app/

将上一讲生成的keras_mnist.h5导入,得到模型结构,如下图:

2 keras的model.summary()方法

对于一些简单的模型,可以直接使用keras提供的model.summary()方法,如上一讲的mnist模型,代码中:

# 搭建好模型后,加上这一句

print("model:")
model.summary()

输出模型如下:

model:
Model: "sequential"
_________________________________________________________________
Layer (type)               Output Shape             Param #
=================================================================
flatten (Flatten)           (None, 784)               0

dense (Dense)               (None, 784)               615440

dense_1 (Dense)             (None, 10)               7850

=================================================================
Total params: 623,290
Trainable params: 623,290
Non-trainable params: 0
_________________________________________________________________

可见模型有一个flatten层,两个全连接层。

3 keras的graphviz功能

keras.utils.vis_utils模块提供了画出Keras模型的函数(利用graphviz)

import tensorflow as tf
import tensorflow.keras as keras

# 搭建好模型后,加上下一句
keras.utils.plot_model(model, to_file='model.svg', show_shapes=True)

注意:我这里按照官网生成model.png会失败,寻找网上的解决方案也无法解决,只好换用svg格式,svg格式可以通过浏览器直接打开,但是可能显示不全,可以修改model.svg的参数解决,参照参考2:

# 要调节width、height参数,以及viewBox参数

<svg width="825pt" height="825pt"
viewBox="0.00 0.00 825.00 825.00"   # viewBox="x, y, width, height"

浏览器解析,模型如下:

plot_model函数定义:


tf.keras.utils.plot_model(
   model,                       # keras model句柄
   to_file='model.png',         # 保存文件名及格式(这里使用svg格式)
   show_shapes=False,           # 是否显示形状信息,默认不显示
   show_layer_names=True,       # 显示layer名
   rankdir='TB',                # 横向显示(LR), 纵向显示(TB)
   expand_nested=False,         # 是否将嵌套模型扩展到聚类中
   dpi=96
)

4 训练历史可视化

Keras Model 上的 fit() 方法返回一个 History 对象。History.history 是一个记录了连续迭代的训练/验证损失值和评估值的字典。可以通过matplotlib将数据展示出来(这里就不使用matplotlib画图了,将数据打印出来):

# 截取部分代码如下:

import tensorflow as tf
import tensorflow.keras as keras

# step4: train
history = model.fit(x_train, y_train, batch_size=64, epochs=5)

# 打印history字典中的keys值
print(history.history.keys())
# 获取验证准确率数据
print(history.history['accuracy'])
# 获取训练时的损失值
print(history.history['loss'])

执行结果为:

dict_keys(['loss', 'accuracy'])
[0.9143333435058594, 0.9519500136375427, 0.9597333073616028, 0.9615499973297119, 0.9619333148002625]
[3.5814223289489746, 0.38597372174263, 0.26168128848075867, 0.23524414002895355, 0.24421487748622894]

5 训练过程的可视化:keras + Tensorboard

Tensorboard提供训练过程可视化的功能,是通过keras的回调函数来实现的。

# 截取部分代码如下:

import tensorflow as tf
import tensorflow.keras as keras
from keras.callbacks import TensorBoard

tbCallBack = TensorBoard() # 默认日志放到./logs 文件夹下
history = model.fit(x_train, y_train, batch_size=64, epochs=5, callbacks=[tbCallBack])

通过shell下执行如下命令,然后使用浏览器可打开tensorboard面板:

tensorboard --logdir /path/to/logs

Tensorboard 功能强大,对tensorflow、keras、还是pytorch都提供良好支持,这里先不做展开,可以参考3。

参考:

  1. 可视化 Visualization - Keras 中文文档

  2. viewBox_svg设置宽高_Lyrelion的博客-CSDN博客

  3. (80条消息) Tensorboard深入详解(一)


喜欢0
用户评论
sureZ-ok

sureZ-ok 实名认证

懒的都不写签名

积分
问答
粉丝
关注
专栏作者
  • RV-STAR 开发板
  • RISC-V处理器设计系列课程
  • 培养RISC-V大学土壤 共建RISC-V教育生态
RV-STAR 开发板