PyTorch框架-TensorDataset
在学习深度学习前有必要先学习下当下深度学习的主流框架PyTorch、tensorflow、MXnet,但本篇文章主要介绍PyTorch的TensorDataset和DataLoader。我目前接触的框架也只有PyTorch,其他框架自己按需学习吧 :)
先看下面的代码,就能大致知道TensorDataset和DataLoader是干什么的了
TensorDataset
1 |
|
输出结果如下:
1 |
|
TensorDataset(a,b)将两个tensor传入到TensorDataset里,
得到TensorDataset后可以对其进行切片输出,例如TensorDataset(a,b)[0:2],指的是取a和b前面两个元素。
除了进行切片输出外,还可以用循环来取数据,例如a是矩阵,b是向量,用for循环取出来的分别就是矩阵的每一行和向量的每个元素。
DataLoader
DataLoader进行数据封装
1 |
|
输出:
1 |
|
DataLoader(dataset=train_ids,batch_size=4,shuffle=True)这里分别是传入上面讲的TensorDataset类型的train_ids,batch_size=4指的是一个批次里有几个元素,这里等于4,是指一个batch里取四个元素。shuffle=True每次都洗牌。取到的都不一样。
将batch_size设置为1时,输出结果如下:
1 |
|
1 |
|
将shuffle改为False, 每次取到的数据都是一样的
1 |
|
觉得不错的话,支持一根棒棒糖吧 ୧(๑•̀⌄•́๑)૭
wechat pay
alipay
PyTorch框架-TensorDataset
http://yuting0907.github.io/posts/8bd2f176.html