LogoLogo
  • README
  • 前端编程
    • 01 Node JS
    • 02-ES6详解
    • 03-NPM详解
    • 04-Babel详解
    • 05-前端模块化开发
    • 06-WebPack详解
    • 07-Vue详解
    • 08-Git详解
    • 09-微信小程序
  • 人工智能
    • 机器学习
      • 二次分配问题
      • 非负矩阵
      • 概率潜在语义分析
      • 概率图模型
      • 集成学习
      • 降维
      • 距离度量
      • 决策树
      • 逻辑回归
      • 马尔可夫决策过程
      • 马尔可夫链蒙特卡洛法
      • 朴素贝叶斯法
      • 谱聚类
      • 奇异值分解
      • 潜在狄利克雷分配
      • 潜在语义分析
      • 强化学习
      • 社区算法
      • 时间序列模型
      • 特征工程
      • 条件随机场
      • 图论基础
      • 线性分类
      • 线性回归
      • 信息论中的熵
      • 隐马尔科夫模型
      • 支持向量机
      • 主成分分析
      • EM算法
      • Hermite 矩阵的特征值不等式
      • k-means聚类
      • k近邻法
      • PageRank算法
    • 深度学习
      • Pytorch篇
        • 01-线性模型
        • 02-梯度下降法
        • 03-反向传播
        • 04-pytorch入门
        • 05-用pytorch实现线性回归
        • 06-logistic回归
        • 07-处理多维特征的输入
        • 08-加载数据集
        • 09-多分类问题
        • 10-卷积神经网络
        • 11-循环神经网络
    • 图神经网络
      • 图神经网络笔记01
        • 01-图(Graphs)的结构
        • 02-网络的性质和随机图模型
        • 03-网络工具
        • 04-网络中的主题和结构角色
        • 05-网络中的社区结构
      • 图神经网络笔记02
        • 01-深度学习引言
        • 02-神经网络基础
        • 03-卷积神经网络
        • 04-图信号处理与图卷积神经网络
        • 05-GNN的变体与框架-
        • [06-Google PPRGo 两分钟分类千万节点的最快GNN](人工智能/图神经网络/图神经网络笔记02/06-Google%20PPRGo 两分钟分类千万节点的最快GNN.md)
        • 07-序列模型
        • 08-变分自编码器
        • 09-对抗生成网络
  • 日常记录
    • 健身日记
    • 面经记录
    • 自动生成Summary文件
  • 实战项目
    • 谷粒商城
      • 00-项目概述
      • 01-分布式基础-全栈开发篇
      • 02-分布式高级-微服务架构篇
      • 03-高可用集群-架构师提升篇
  • 数据库
    • MySQL笔记
      • 01-MySQL基础篇
      • 02-MySQL架构篇
      • 03-MySQL索引及调优篇
      • 04-MySQL事务篇
      • 05-MySQL日志与备份篇
    • Redis笔记
      • 01-Redis基础篇
      • 02-Redis高级篇
    • 02-Redis篇
  • 算法笔记
    • 01-算法基础篇
    • 02-算法刷题篇
  • 职能扩展
    • 产品运营篇
  • Go编程
    • 01-Go基础
      • 01-Go基础篇
  • Java编程
    • 01-Java基础
      • 01-Java基础篇
      • 02-多线程篇
      • 03-注射与反解篇
      • 04-JUC并发编程篇
      • 05-JUC并发编程与源码分析
      • 06-JVM原理篇
      • 07-Netty原理篇
      • 08-设计模式篇
    • 02 Java Web
      • 01-Mybatis篇
      • 01-Mybatis篇(新版)
      • 02-Spring篇
      • 02-Spring篇(新版)
      • 03-SpringMVC篇
      • 04-MybatisPlus篇
    • 03-Java微服务
      • 01-SpringBoot篇
      • 01-SpringBoot篇(新版)
      • 02-SpringSecurity篇
      • 03-Shiro篇
      • 04-Swagger篇
      • 05-Zookeeper篇
      • 06-Dubbo篇
      • 07-SpringCloud篇
      • 08-SpringAlibaba篇
      • 09-SpringCloud篇(新版)
    • 04-Java中间件
      • 数据库篇
        • 01-分库分表概述
        • 02-MyCat篇
        • 03-MyCat2篇
        • 04-Sharding-jdbc篇
        • 05-ElasticSearch篇
      • 消息中间件篇
        • 01-MQ概述
        • 02-RabbitMQ篇
        • 03-Kafka篇
        • 04-RocketMQ篇
        • 05-Pulsar篇
    • 05-扩展篇
      • Dubbo篇
      • SpringBoot篇
      • SpringCloud篇
    • 06-第三方技术
      • 01-CDN技术篇
      • 02-POI技术篇
      • 03-第三方支付技术篇
      • 04-第三方登录技术篇
      • 05-第三方短信接入篇
      • 06-视频点播技术篇
      • 07-视频直播技术篇
    • 07-云原生
      • 01-Docker篇
      • 02-Kubernetes篇
      • 03-Kubesphere篇
  • Linux运维
    • 01-Linux篇
    • 02-Nginx篇
  • Python编程
    • 01-Python基础
      • 01.配置环境
      • 02.流程控制
      • 03.数值
      • 04.操作符
      • 05.列表
      • 06.元祖
      • 07.集合
      • 08.字典
      • 09.复制
      • 10.字符串
      • 11.函数
      • 12.常见内置函数
      • 13.变量
      • 14.异常和语法错误
      • 15.时间和日期
      • 16.正则表达式
    • 02 Python Web
      • flask篇
        • 01.前言
        • 02.路由
        • 03.模板
        • 04.视图进阶
        • 05.flask-sqlalchemy
        • 06.表单WTForms
        • 07.session与cookie
        • 08.上下文
        • 09.钩子函数
        • 10.flask 信号
        • 11.RESTFUL
        • 13.flask-mail
        • 14.flask+celery
        • 15.部署
        • 16.flask-login
        • 17.flask-cache
        • 18.flask-babel
        • 19.flask-dashed
        • 20.flask-pjax
        • 21.flask上传文件到第三方
        • 22.flask-restless
        • 23.flask-redis
        • 24.flask-flash
        • 25.消息通知
        • 26.分页
    • 03-Python数据分析
      • Matplotlib
      • Numpy
      • Pandas
      • Seaborn
    • 04-Python爬虫
      • 1.准备工作
      • 2.请求模块的使用
      • 3.解析模块的使用
      • 4.数据存储
      • 5.识别验证码
      • 6.爬取APP
      • 7.爬虫框架
      • 8.分布式爬虫
由 GitBook 提供支持
在本页
  • 数据集
  • 激活函数
  • logistic回归模型
  • 损失函数
  • 对比
  • 实现

这有帮助吗?

在GitHub上编辑
  1. 人工智能
  2. 深度学习
  3. Pytorch篇

06-logistic回归

上一页05-用pytorch实现线性回归下一页07-处理多维特征的输入

最后更新于3年前

这有帮助吗?

数据集

  • 手写数字数据集

image-20210226164852513

这个数据集:训练集有60000个样本,测试集有10000个样本,类别有10个。

import torchvision

# torchvision 是一个数据集集合的模块
## root设置数据集存放的路径,train表示是否下载训练集,download表示是否进行下载
train_set = torchvision.datasets.MNIST(root="./数据集/mnist",
					train=True,download=True)
test_set = torchvision.datasets.MNIST(root="./数据集/mnist",
					train=False,download=True)
  • CIFAR-10 数据集

这个数据集:训练集有50000个样本,测试集有10000个样本,类别有10个。

import torchvision

# torchvision 是一个数据集集合的模块
## root设置数据集存放的路径,train表示是否下载训练集,download表示是否进行下载
train_set = torchvision.datasets.CIFAR10(root="./数据集/cifar10",
						train=True,download=True)
test_set = torchvision.datasets.CIFAR10(root="./数据集/cifar10",
						train=False,download=True)

激活函数

  • sigmoid函数

σ(x)=11+e−x\sigma(x) = \frac{1}{1+e^{-x}}σ(x)=1+e−x1​

分类的结果是概率,结果需要在$[0,1]$中

其他的一些激活函数:

1.  erf(π2x)2.  x1+x23.  tanh⁡(x)4.  2πarctan⁡(π2x)5.  2πgd(π2x)6.  x1+∣x∣\begin{align} & 1.\ \ \mathbb{erf}(\frac{\sqrt{\pi}}{2}x) \\ & 2.\ \ \frac{x}{\sqrt{1+x^2}} \\ & 3.\ \ \tanh(x) \\ & 4.\ \ \frac{2}{\pi}\arctan(\frac{\pi}{2}x) \\ & 5.\ \ \frac{2}{\pi}\mathbb{gd}(\frac{\pi}{2}x) \\ & 6.\ \ \frac{x}{1+|x|}\\ \end{align}​1.  erf(2π​​x)2.  1+x2​x​3.  tanh(x)4.  π2​arctan(2π​x)5.  π2​gd(2π​x)6.  1+∣x∣x​​​

logistic回归模型

  • 定义模型

y^=x∗w+b\hat y = x*w + by^​=x∗w+b
  • 逻辑回归模型

y^=σ(x∗w+b)\hat y = \sigma(x*w+b)y^​=σ(x∗w+b)
  • 线性单元

  • 逻辑回归单元

损失函数

  • 线性函数损失函数

loss=(y^−y)2=(x∗w−y)2loss = (\hat y - y)^2 = (x*w-y)^2loss=(y^​−y)2=(x∗w−y)2
  • 二分类损失函数

loss=−(ylog⁡y^+(1−y)log⁡(1−y^))loss = -(y\log{\hat y}+(1-y)\log{(1-\hat y)})loss=−(ylogy^​+(1−y)log(1−y^​))
  • 小批量损失函数

loss=−1N∑n=1Nynlog⁡y^n+(1−yn)log⁡(1−y^n)loss = -\frac{1}{N}\sum\limits_{n=1}^Ny_n\log{\hat y_n} + (1-y_n)\log{(1-\hat y_n)}loss=−N1​n=1∑N​yn​logy^​n​+(1−yn​)log(1−y^​n​)

对比

  • 线性

import torch

class LinearModel(torch.nn.Module):
    
    def __init__(self,):
        super(LinearModel,self).__init__()
        self.linear = torch.nn.Linear(1,1)
        
    def forward(self,X):
        y_pred = self.linear(x)
        return y_pred
  • 逻辑回归

class LogisticRegressionModel(torch.nn.Module):
    
    def __init__(self,):
        super(LogisitcRegressionModel,self).__init__()
        self.linear = torch.nn.Linear(1,1)
        
    def forward(self,X):
    	# 调用sigmoid激活函数
        y_pred = torch.sigmoid(self.linear(X))
        return y_pred

实现

  • 步骤

  1. 预处理数据

  2. 设计模型并使用类torch.nn.Module

  3. 构造损失函数和优化器

  4. 训练数据

  • 代码

x_data = torch.Tensor([
    [1.0],[2.0],[3.0]
])
y_data = torch.Tensor([
    [0],[0],[1]
])


class LogisticRegressionModel(torch.nn.Module):
    
    def __init__(self,):
        super(LogisticRegressionModel,self).__init__()
        self.linear = torch.nn.Linear(1,1)
        
    def forward(self,X):
        y_pred = torch.sigmoid(self.linear(X))
        return y_pred
    
    def fit(self,X,y):
        criterion = torch.nn.BCELoss(reduction='sum')
        optimizer = torch.optim.SGD(self.parameters(),lr=0.01)
        
        for epoch in range(1000):
            y_pred = self.forward(x_data)
            loss = criterion(y_pred,y)
            print(epoch,loss.item())
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

model = LogisticRegressionModel()
model.fit(x_data,y_data)
  • 测试

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0,10,200)
x_t = torch.Tensor(x).view((200,1))
y_t = model(x_t)
y = y_t.data.numpy()

plt.plot(x,y)
plt.plot([0,10],[0.5,0.5],c='r')
plt.xlabel("Hours")
plt.ylabel("Probability of Pass")
plt.grid()
plt.show()
image-20210226165735229
image-20210226172204981
image-20210226172228412
image-20210226173657175
image-20210227113654374