博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
tensorflow saver简介+Demo with linear-model
阅读量:5266 次
发布时间:2019-06-14

本文共 1660 字,大约阅读时间需要 5 分钟。

 

tf.train.Saver提供Save和Restore Tensorflow变量的功能,常用于保存、还原模型训练结果,这在自己的训练和迁移学习中都很有用。

 

训练、保存脚本:

import tensorflow as tfcheckpoint_dir = './ckpt/'x_train = [1, 2, 3, 6, 8]y_train = [4.8, 8.5, 10.4, 21.0, 25.3]x = tf.placeholder(tf.float32, name='x')y = tf.placeholder(tf.float32, name='y')W = tf.Variable(1, dtype=tf.float32, name='W')b = tf.Variable(0, dtype=tf.float32, name='b')# 定义模型linear_model = W * x + b with tf.name_scope("loss-model"):    loss = tf.reduce_sum(tf.square(linear_model - y))    acc = tf.sqrt(loss)train_step = tf.train.GradientDescentOptimizer(0.001).minimize(loss)sess = tf.Session()init = tf.global_variables_initializer()sess.run(init)variable_saver = tf.train.Saver(max_to_keep=3)# 训练、保存variablesfor i in range(1000):    sess.run([train_step], {x: x_train, y: y_train})    if i%10 == 0:        variable_saver.save(sess, checkpoint_dir, i)curr_W, curr_b, curr_loss, curr_acc = sess.run([W, b, loss, acc], {x: x_train, y: y_train})print("After train W: %f, b: %f, loss: %f, acc: %f" % (curr_W, curr_b, curr_loss, curr_acc))

运行保存的文件如下

ckpt

还原保存的变量:

import tensorflow as tfcheckpoint_dir = './ckpt/'W = tf.Variable(1, dtype=tf.float32, name='W')b = tf.Variable(0, dtype=tf.float32, name='b')sess = tf.Session()init = tf.global_variables_initializer()sess.run(init)variable_saver = tf.train.Saver(max_to_keep=3)latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)if latest_checkpoint is not None:    variable_saver.restore(sess, latest_checkpoint)curr_W, curr_b = sess.run([W, b])print("After train W: %f, b: %f" % (curr_W, curr_b))

 

 

参考了:https://blog.csdn.net/gzj_1101/article/details/80299610

转载于:https://www.cnblogs.com/xbit/p/10071455.html

你可能感兴趣的文章
HTML <select> 标签
查看>>
tju 1782. The jackpot
查看>>
湖南多校对抗赛(2015.03.28) H SG Value
查看>>
hdu1255扫描线计算覆盖两次面积
查看>>
hdu1565 用搜索代替枚举找可能状态或者轮廓线解(较优),参考poj2411
查看>>
bzoj3224 splay板子
查看>>
程序存储问题
查看>>
Mac版OBS设置详解
查看>>
优雅地书写回调——Promise
查看>>
android主流开源库
查看>>
AX 2009 Grid控件下多选行
查看>>
PHP的配置
查看>>
Struts框架----进度1
查看>>
Round B APAC Test 2017
查看>>
MySQL 字符编码问题详细解释
查看>>
Ubuntu下面安装eclipse for c++
查看>>
让IE浏览器支持CSS3圆角属性的方法
查看>>
巡风源码阅读与分析---nascan.py
查看>>
LiveBinding应用 dataBind 数据绑定
查看>>
Linux重定向: > 和 &> 区别
查看>>