LogoLogo
  • README
  • 前端编程
    • 01 Node JS
    • 02-ES6详解
    • 03-NPM详解
    • 04-Babel详解
    • 05-前端模块化开发
    • 06-WebPack详解
    • 07-Vue详解
    • 08-Git详解
    • 09-微信小程序
  • 人工智能
    • 机器学习
      • 二次分配问题
      • 非负矩阵
      • 概率潜在语义分析
      • 概率图模型
      • 集成学习
      • 降维
      • 距离度量
      • 决策树
      • 逻辑回归
      • 马尔可夫决策过程
      • 马尔可夫链蒙特卡洛法
      • 朴素贝叶斯法
      • 谱聚类
      • 奇异值分解
      • 潜在狄利克雷分配
      • 潜在语义分析
      • 强化学习
      • 社区算法
      • 时间序列模型
      • 特征工程
      • 条件随机场
      • 图论基础
      • 线性分类
      • 线性回归
      • 信息论中的熵
      • 隐马尔科夫模型
      • 支持向量机
      • 主成分分析
      • EM算法
      • Hermite 矩阵的特征值不等式
      • k-means聚类
      • k近邻法
      • PageRank算法
    • 深度学习
      • Pytorch篇
        • 01-线性模型
        • 02-梯度下降法
        • 03-反向传播
        • 04-pytorch入门
        • 05-用pytorch实现线性回归
        • 06-logistic回归
        • 07-处理多维特征的输入
        • 08-加载数据集
        • 09-多分类问题
        • 10-卷积神经网络
        • 11-循环神经网络
    • 图神经网络
      • 图神经网络笔记01
        • 01-图(Graphs)的结构
        • 02-网络的性质和随机图模型
        • 03-网络工具
        • 04-网络中的主题和结构角色
        • 05-网络中的社区结构
      • 图神经网络笔记02
        • 01-深度学习引言
        • 02-神经网络基础
        • 03-卷积神经网络
        • 04-图信号处理与图卷积神经网络
        • 05-GNN的变体与框架-
        • [06-Google PPRGo 两分钟分类千万节点的最快GNN](人工智能/图神经网络/图神经网络笔记02/06-Google%20PPRGo 两分钟分类千万节点的最快GNN.md)
        • 07-序列模型
        • 08-变分自编码器
        • 09-对抗生成网络
  • 日常记录
    • 健身日记
    • 面经记录
    • 自动生成Summary文件
  • 实战项目
    • 谷粒商城
      • 00-项目概述
      • 01-分布式基础-全栈开发篇
      • 02-分布式高级-微服务架构篇
      • 03-高可用集群-架构师提升篇
  • 数据库
    • MySQL笔记
      • 01-MySQL基础篇
      • 02-MySQL架构篇
      • 03-MySQL索引及调优篇
      • 04-MySQL事务篇
      • 05-MySQL日志与备份篇
    • Redis笔记
      • 01-Redis基础篇
      • 02-Redis高级篇
    • 02-Redis篇
  • 算法笔记
    • 01-算法基础篇
    • 02-算法刷题篇
  • 职能扩展
    • 产品运营篇
  • Go编程
    • 01-Go基础
      • 01-Go基础篇
  • Java编程
    • 01-Java基础
      • 01-Java基础篇
      • 02-多线程篇
      • 03-注射与反解篇
      • 04-JUC并发编程篇
      • 05-JUC并发编程与源码分析
      • 06-JVM原理篇
      • 07-Netty原理篇
      • 08-设计模式篇
    • 02 Java Web
      • 01-Mybatis篇
      • 01-Mybatis篇(新版)
      • 02-Spring篇
      • 02-Spring篇(新版)
      • 03-SpringMVC篇
      • 04-MybatisPlus篇
    • 03-Java微服务
      • 01-SpringBoot篇
      • 01-SpringBoot篇(新版)
      • 02-SpringSecurity篇
      • 03-Shiro篇
      • 04-Swagger篇
      • 05-Zookeeper篇
      • 06-Dubbo篇
      • 07-SpringCloud篇
      • 08-SpringAlibaba篇
      • 09-SpringCloud篇(新版)
    • 04-Java中间件
      • 数据库篇
        • 01-分库分表概述
        • 02-MyCat篇
        • 03-MyCat2篇
        • 04-Sharding-jdbc篇
        • 05-ElasticSearch篇
      • 消息中间件篇
        • 01-MQ概述
        • 02-RabbitMQ篇
        • 03-Kafka篇
        • 04-RocketMQ篇
        • 05-Pulsar篇
    • 05-扩展篇
      • Dubbo篇
      • SpringBoot篇
      • SpringCloud篇
    • 06-第三方技术
      • 01-CDN技术篇
      • 02-POI技术篇
      • 03-第三方支付技术篇
      • 04-第三方登录技术篇
      • 05-第三方短信接入篇
      • 06-视频点播技术篇
      • 07-视频直播技术篇
    • 07-云原生
      • 01-Docker篇
      • 02-Kubernetes篇
      • 03-Kubesphere篇
  • Linux运维
    • 01-Linux篇
    • 02-Nginx篇
  • Python编程
    • 01-Python基础
      • 01.配置环境
      • 02.流程控制
      • 03.数值
      • 04.操作符
      • 05.列表
      • 06.元祖
      • 07.集合
      • 08.字典
      • 09.复制
      • 10.字符串
      • 11.函数
      • 12.常见内置函数
      • 13.变量
      • 14.异常和语法错误
      • 15.时间和日期
      • 16.正则表达式
    • 02 Python Web
      • flask篇
        • 01.前言
        • 02.路由
        • 03.模板
        • 04.视图进阶
        • 05.flask-sqlalchemy
        • 06.表单WTForms
        • 07.session与cookie
        • 08.上下文
        • 09.钩子函数
        • 10.flask 信号
        • 11.RESTFUL
        • 13.flask-mail
        • 14.flask+celery
        • 15.部署
        • 16.flask-login
        • 17.flask-cache
        • 18.flask-babel
        • 19.flask-dashed
        • 20.flask-pjax
        • 21.flask上传文件到第三方
        • 22.flask-restless
        • 23.flask-redis
        • 24.flask-flash
        • 25.消息通知
        • 26.分页
    • 03-Python数据分析
      • Matplotlib
      • Numpy
      • Pandas
      • Seaborn
    • 04-Python爬虫
      • 1.准备工作
      • 2.请求模块的使用
      • 3.解析模块的使用
      • 4.数据存储
      • 5.识别验证码
      • 6.爬取APP
      • 7.爬虫框架
      • 8.分布式爬虫
由 GitBook 提供支持
在本页
  • 数据集
  • softmax layer
  • 损失函数 - Cross Entropy
  • CrossEntropyLoss VS NLLLoss
  • 实战
  • 导入相应包
  • 预处理数据
  • 设计模型
  • 构造损失函数和优化器
  • 训练和测试
  • 特征提取
  • 作业

这有帮助吗?

在GitHub上编辑
  1. 人工智能
  2. 深度学习
  3. Pytorch篇

09-多分类问题

上一页08-加载数据集下一页10-卷积神经网络

最后更新于3年前

这有帮助吗?

image-20210303165309014

数据集

手写数字集:

一共有10个类别。

softmax layer

设$Z^l\in\mathbb{R}^K$是第$l$层的线性输出,则这个多分类函数为:

P(y=i)=eZi∑j=0K−1eZj,i∈{0,⋯ ,K−1}P(y=i)=\frac{e^{Z_i}}{\sum\limits^{K-1}_{j=0}e^{Z_j}},i\in \{0,\cdots,K-1\}P(y=i)=j=0∑K−1​eZj​eZi​​,i∈{0,⋯,K−1}

为了满足的条件$P(y=i)\ge0$,且$\sum\limits_{i=0}^{N}P(y=i)=1$,其中$N$为类别总数。

损失函数 - Cross Entropy

Loss(y,y^)=−ylog⁡y^Loss(y,\hat y) = -y\log{\hat y}Loss(y,y^​)=−ylogy^​
y = np.array([1.,0.,0.])
z = np.array([0.2,0.1,-0.1])
y_pred = np.exp(z) / np.exp(z).sum()
loss = (-y*np.log(y_pred)).sum()
print(loss)

通过pytorch进行调用torch.nn.CrossEntropyLoss()

y = torch.LongTensor([0])
z = torch.Tensor([[0.2,0.1,-0.1]])
criterion = torch.nn.CrossEntropyLoss()
loss = criterion(z,y)
print(loss)

使用交叉熵,定义$y$需要使用长整型变量,其中的$0$表示第$0$个分类。

举个实例:

import torch
criterion = torch.nn.CrossEntropyLoss()
Y = torch.LongTensor([2,0,1])

Y_pred1 = torch.Tensor([ 
    [0.1,0.2,0.9],# 2
    [1.1,0.1,0.2],# 0
    [0.2,2.1,0.1],# 1
])
Y_pred2 = torch.Tensor([
    [0.8,0.2,0.3],
    [0.2,0.3,0.5],
    [0.2,0.2,0.5]
])

l1 = criterion(Y_pred1,Y)
l2 = criterion(Y_pred2,Y)
print("Batch Loss1 = {},Batch Loss2 = {}".format(l1.data,l2.data))

CrossEntropyLoss VS NLLLoss

CrossEntropyLoss <==> LogSoftmax + NLLLoss

实战

导入相应包

import torch
from torchvision import transforms # 对图像进行处理
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F # 使用激活函数
import torch.optim as optim

预处理数据

batch_size = 64

# 用于处理图像
transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize((0.1307,),(0.3081,))
])

# 导入训练数据
train_dataset = datasets.MNIST(root='./数据集/mnist/',
							  train=True,
							  download=True,
                              transform = transform)
# 打乱顺序
train_loader = DataLoader(train_dataset,
                         shuffle=True,
                         batch_size=batch_size)

# 导入测试数据
test_dataset = datasets.MNIST(root='./数据集/mnist/',
							  train=False,
							  download=True,
                              transform = transform)
# 打乱顺序
test_loader = DataLoader(test_dataset,
                        shuffle=True,
                        batch_size=batch_size)

通过transforms进行处理图像,主要是改变PIL Image到Tensor,从单通道变成多通道,使用transforms.ToTensor()进行实现。

transforms.Normalize((0.1307,),(0.3081,)),其中第一个参数表示均值,第二个参数表示方差,这个函数用途是进行归一化,即映射到0/1分布。

设计模型

class Net(torch.nn.Module):
    
    def __init__(self,):
        super(Net,self).__init__()
        self.l1 = torch.nn.Linear(784,512)
        self.l2 = torch.nn.Linear(512,256)
        self.l3 = torch.nn.Linear(256,128)
        self.l4 = torch.nn.Linear(128,64)
        self.l5 = torch.nn.Linear(64,10)
        
    def forward(self,x):
        x = x.view(-1,784)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = F.relu(self.l4(x))
        x = self.l5(x)
        return x
    
model = Net()

构造损失函数和优化器

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),lr=0.01,momentum=0.5)

训练和测试

  • 训练

def train(epoch):
    running_loss = 0.0
    for batch_idx,data in enumerate(train_loader,0):
        inputs,target = data
        
        optimizer.zero_grad()
        
        # forward + backward + update
        outputs = model(inputs)
        
        loss = criterion(outputs,target)
        loss.backward()
        
        optimizer.step()
        
        running_loss += loss.item()
        
        if batch_idx % 300 == 299:
            print('[{:d},{:5d}] loss:{:.3f}'.format(
            		epoch+1,batch_idx+1,running_loss/300))
            running_loss = 0.0
  • 测试

def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images,labels = data
            outputs = model(images)
            _,predicted = torch.max(outputs.data,dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
    print("Accuracy on test set:{:.2%}".format(correct/total))
  • 运行

if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()

特征提取

  • 傅里叶变换

  • wavelet

  • CNN:自动提取

作业

kaggle多分类问题:https://www.kaggle.com/c/otto-group-product-classification-challenge/data

image-20210303171439663
image-20210303175145206
image-20210304101427030
image-20210304102023388