09-多分类问题

数据集
手写数字集:

一共有10个类别。
softmax layer
设$Z^l\in\mathbb{R}^K$是第$l$层的线性输出,则这个多分类函数为:
为了满足的条件$P(y=i)\ge0$,且$\sum\limits_{i=0}^{N}P(y=i)=1$,其中$N$为类别总数。
损失函数 - Cross Entropy

通过pytorch进行调用torch.nn.CrossEntropyLoss()
使用交叉熵,定义$y$需要使用长整型变量,其中的$0$表示第$0$个分类。
举个实例:
CrossEntropyLoss VS NLLLoss
实战
导入相应包
预处理数据
通过transforms进行处理图像,主要是改变PIL Image到Tensor,从单通道变成多通道,使用transforms.ToTensor()进行实现。

transforms.Normalize((0.1307,),(0.3081,)),其中第一个参数表示均值,第二个参数表示方差,这个函数用途是进行归一化,即映射到0/1分布。
设计模型

构造损失函数和优化器
训练和测试
训练
测试
运行
特征提取
傅里叶变换
wavelet
CNN:自动提取
作业
kaggle多分类问题:https://www.kaggle.com/c/otto-group-product-classification-challenge/data
最后更新于
这有帮助吗?
