Skip to main content

Command Palette

Search for a command to run...

TensorFlow 2 Keras实现线性回归

Updated
1 min read

介绍

线性回归是入门机器学习必学的算法,其也是最基础的算法之一。

接下来,我们以线性回归为例,使用 TensorFlow 2 提供的 API 和 Eager Execution 机制对其进行实现。

线性回归是一种较为简单,但十分重要的机器学习方法,它也是神经网络的基础。

如下所示,线性回归要解决的问题就是如何找到最理想的直线去拟合散点样本

image-20211012155654480

对于一个线性回归问题,一般来讲有 2 种解决方法,分别是:

  • 最小二乘法
    • 代数求解
    • 矩阵求解
  • 梯度下降法。

本次,我们将使用梯度下降方法来解决线性回归问题。

Keras 方式实现

配合 TensorFlow 提供的高阶 API,我们省去了定义线性函数,定义损失函数,以及定义优化算法等 3 个步骤。

不过,高阶 API 实现过程实际上还不够精简,我们可以完全使用 TensorFlow Keras API 来实现线性回归

Keras 本来是一个用 Python 编写的独立高阶神经网络 API,它能够以 TensorFlow, CNTK,或者 Theano 作为后端运行

目前,TensorFlow 已经吸纳 Keras,并组成了 tf.keras 模块。官方介绍,tf.keras 和单独安装的 Keras 略有不同,但考虑到未来的发展趋势,主要以学习 tf.keras 为主。

image-20211012160727469

初始化

image-20211012162409785

我们这里使用 Keras 提供的 Sequential 顺序模型结构向其中添加一个线性层。不同的地方在于,Keras 顺序模型第一层为线性层时,规定需指定输入维度,这里为 input_dim=1

image-20211012161515170

接下来,直接使用 .compile 编译模型,指定损失函数为 MSE 平方损失函数,优化器选择 SGD 随机梯度下降。然后,就可以使用 .fit 传入数据开始迭代了。

`image-20211012162556372`

batch_size 是采用小批次训练的参数,主要用于解决一次性传入数据过多无法训练的问题。当然,由于示例数据本身较少,这里意义不大,但还是按照常规使用方法进行设置。

你会发现,完全使用 Keras 高阶 API 实际上只需要 4 行核心代码即可完成,相比于低阶 API 简化了很多。

完整代码:

import tensorflow as tf
TRUE_W = 3.0
TRUE_b = 2.0
NUM_SAMPLES = 100

X = tf.random.normal(shape=[NUM_SAMPLES,1]).numpy()
noise = tf.random.normal(shape=[NUM_SAMPLES,1]).numpy()
y = X * TRUE_W + TRUE_b + noise

# 模型训练
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(units=1,input_dim=1))
model.compile(optimizer='sgd',loss='mse')
model.fit(X,y,epochs=10,batch_size=32)

More from this blog

【两万字总结】Spark 部署与入门

Spark 介绍 核心概念 Spark 是 UC Berkeley AMP lab 开发的一个集群计算的框架,类似于 Hadoop,但有很多的区别。 最大的优化是让计算任务的中间结果可以存储在内存中,不需要每次都写入 HDFS,更适用于需要迭代的 MapReduce 算法场景中,可以获得更好的性能提升。 例如一次排序测试中,对 100TB 数据进行排序,Spark 比 Hadoop 快三倍,并且只需要十分之一的机器。 Spark 集群目前最大的可以达到 8000 节点,处理的数据达到 PB 级别...

Oct 20, 202115 min read

【引言】浙大机器学习课程记录

机器学习的定义 第一种定义 ARTHUR SAMUEL对Machine learning 的定义 Machine Learning is Fields of study that gives computers the ability to learn without being explicitly programmed 机器学习是这样的领域,它赋予计算机学习的能力,(这种学历能力)不是通过显著式编程获得的 显著式编程 提前人为指定规律的编程方式 非显著式编程 让计算机自己总结规律的...

Oct 19, 20212 min read

TensorFlow 2 基础概念语法与常用模块

TensorFlow 2 简介 TensorFlow 是由谷歌在 2015 年 11 月发布的深度学习开源工具,我们可以用它来快速构建深度神经网络,并训练深度学习模型。运用 TensorFlow 及其他开源框架的主要目的,就是为我们提供一个更利于搭建深度学习网络的模块工具箱,使开发时能够简化代码,最终呈现出的模型更加简洁易懂。 2019 年,TensorFlow 推出了 2.0 版本,也意味着 TensorFlow 从 1.x 正式过度到 2.x 时代。根据 TensorFlow 官方 介绍内...

Oct 12, 20213 min read

uiu's log

27 posts

Insist on programming & Love open source