PyTorch Framework: TensorDataset

Before diving into deep learning, it is worth learning one of today's mainstream frameworks: PyTorch, TensorFlow, or MXNet. This article focuses on PyTorch's TensorDataset and DataLoader. PyTorch is the only framework I have worked with so far, so learn the others as you need them :)

The code below gives a good sense of what TensorDataset and DataLoader do.

TensorDataset

# Learning TensorDataset and DataLoader
import torch
from torch.utils.data import TensorDataset, DataLoader

a = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
])
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
train_ids = TensorDataset(a, b)

# Slicing
print(train_ids[0:2])
print(train_ids)
print('=' * 80)

# Iterating sample by sample
for x_train, y_label in train_ids:
    print(x_train, y_label)

# Wrapping the dataset in a DataLoader
print('=' * 80)
train_loader = DataLoader(dataset=train_ids, batch_size=4, shuffle=True)
for i, data in enumerate(train_loader, 1):
    x_data, label = data
    print('batch:{0} x_data:{1} label:{2}'.format(i, x_data, label))

The output looks like this:

(tensor([[1, 2, 3],
         [4, 5, 6]]), tensor([44, 55]))
<torch.utils.data.dataset.TensorDataset object at 0x7fb6189168d0>
================================================================================
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
================================================================================
batch:1 x_data:tensor([[7, 8, 9],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3]]) label:tensor([66, 44, 55, 44])
batch:2 x_data:tensor([[1, 2, 3],
        [7, 8, 9],
        [4, 5, 6],
        [1, 2, 3]]) label:tensor([44, 66, 55, 44])
batch:3 x_data:tensor([[4, 5, 6],
        [7, 8, 9],
        [7, 8, 9],
        [4, 5, 6]]) label:tensor([55, 66, 66, 55])

TensorDataset(a, b) wraps the two tensors into a single dataset.

Once you have the TensorDataset, you can slice it: TensorDataset(a, b)[0:2] returns the first two rows of a together with the first two elements of b.

Besides slicing, you can also iterate over it: since a is a matrix and b is a vector, each iteration yields one row of a paired with the corresponding element of b.
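As a minimal sketch of this pairing behavior (the tensor values here are my own, not the ones from the post): TensorDataset lines its tensors up along the first dimension, so its length is the number of samples and indexing returns one (features, label) tuple. It also requires all tensors to share the same first dimension.

```python
import torch
from torch.utils.data import TensorDataset

features = torch.arange(12).reshape(4, 3)   # 4 samples, 3 features each
labels = torch.tensor([0, 1, 0, 1])         # one label per sample

ds = TensorDataset(features, labels)
print(len(ds))   # 4 -- one entry per row of features
x, y = ds[0]
print(x, y)      # first row of features and its label
```

If `features` had 4 rows but `labels` had 3 elements, the TensorDataset constructor would raise an assertion error about mismatched sizes.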

DataLoader

Wrapping the dataset in a DataLoader

# Wrapping the dataset in a DataLoader
print('=' * 80)
train_loader = DataLoader(dataset=train_ids, batch_size=4, shuffle=True)
for i, data in enumerate(train_loader, 1):
    x_data, label = data
    print('batch:{0} x_data:{1} label:{2}'.format(i, x_data, label))

Output:

batch:1 x_data:tensor([[7, 8, 9],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3]]) label:tensor([66, 44, 55, 44])
batch:2 x_data:tensor([[1, 2, 3],
        [7, 8, 9],
        [4, 5, 6],
        [1, 2, 3]]) label:tensor([44, 66, 55, 44])
batch:3 x_data:tensor([[4, 5, 6],
        [7, 8, 9],
        [7, 8, 9],
        [4, 5, 6]]) label:tensor([55, 66, 66, 55])

In DataLoader(dataset=train_ids, batch_size=4, shuffle=True), the first argument is the TensorDataset built above; batch_size=4 means each batch contains four samples; and shuffle=True reshuffles the data on every pass, so each epoch yields a different order.
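A small sketch of how batch_size interacts with the dataset length (the 10-sample dataset here is my own example): 12 samples split evenly into three batches of 4 above, but when the batch size does not divide the dataset length, the final batch is smaller unless you pass drop_last=True.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

a = torch.arange(30).reshape(10, 3)   # 10 samples, 3 features each
b = torch.arange(10)
ds = TensorDataset(a, b)

# 10 samples with batch_size=4 -> batches of size 4, 4, 2
loader = DataLoader(ds, batch_size=4, shuffle=False)
print([xb.shape[0] for xb, yb in loader])   # [4, 4, 2]

# drop_last=True discards the incomplete final batch
loader_dropped = DataLoader(ds, batch_size=4, shuffle=False, drop_last=True)
print([xb.shape[0] for xb, yb in loader_dropped])   # [4, 4]
```

This matters in training loops that assume every batch has the same size, e.g. when reshaping by a fixed batch dimension.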

With batch_size set to 1, the output looks like this:

# Wrapping the dataset in a DataLoader
print('=' * 80)
train_loader = DataLoader(dataset=train_ids, batch_size=1, shuffle=True)
for i, data in enumerate(train_loader, 1):
    x_data, label = data
    print('batch:{0} x_data:{1} label:{2}'.format(i, x_data, label))
================================================================================
batch:1 x_data:tensor([[1, 2, 3]]) label:tensor([44])
batch:2 x_data:tensor([[7, 8, 9]]) label:tensor([66])
batch:3 x_data:tensor([[1, 2, 3]]) label:tensor([44])
batch:4 x_data:tensor([[4, 5, 6]]) label:tensor([55])
batch:5 x_data:tensor([[4, 5, 6]]) label:tensor([55])
batch:6 x_data:tensor([[4, 5, 6]]) label:tensor([55])
batch:7 x_data:tensor([[4, 5, 6]]) label:tensor([55])
batch:8 x_data:tensor([[7, 8, 9]]) label:tensor([66])
batch:9 x_data:tensor([[7, 8, 9]]) label:tensor([66])
batch:10 x_data:tensor([[7, 8, 9]]) label:tensor([66])
batch:11 x_data:tensor([[1, 2, 3]]) label:tensor([44])
batch:12 x_data:tensor([[1, 2, 3]]) label:tensor([44])

With shuffle=False, the data comes back in the same order every time:

# Wrapping the dataset in a DataLoader
print('=' * 80)
train_loader = DataLoader(dataset=train_ids, batch_size=1, shuffle=False)
for i, data in enumerate(train_loader, 1):
    x_data, label = data
    print('batch:{0} x_data:{1} label:{2}'.format(i, x_data, label))
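There is a middle ground between the two settings: if you want shuffle=True but a reproducible order between runs, DataLoader accepts a generator argument. A seeded torch.Generator produces the same shuffled order each time the seed is reset (a sketch with my own toy dataset):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

ds = TensorDataset(torch.arange(12).reshape(6, 2), torch.arange(6))

g = torch.Generator()
g.manual_seed(0)
loader = DataLoader(ds, batch_size=2, shuffle=True, generator=g)
order1 = [y.tolist() for _, y in loader]

g.manual_seed(0)          # reset the seed -> identical shuffle
order2 = [y.tolist() for _, y in loader]
print(order1 == order2)   # True
```

This is handy for debugging: you keep the statistical benefit of shuffling while making runs comparable.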


PyTorch Framework: TensorDataset
http://yuting0907.github.io/2022/07/27/PyTorch框架-TensorDataset/
Author: Echo Yu, published July 27, 2022