pytorch에서 Dataset을 다루는 방법과 DataLoader를 사용하는 방법에 대해 알아보겠습니다
https://tutorials.pytorch.kr/beginner/basics/data_tutorial.html 여기서 공부를 하였고 정리하였습니다.
1) 데이터셋 불러오기
- TorchVision Fashion-MNIST 데이터셋을 불러오는 예제를 하겠습니다.
- Fashion-MNIST는 흑백의 28x28 이미지와 10개의 class 중 하나의 label로 구성되어 있습니다.
- datasets.FashionMNIST()
입력 매개변수
- root : train 또는 test 데이터가 저장되는 경로
- train : train용 또는 test용 데이터셋 여부를 지정 True / False
- download = True시 root에 데이터가 없는 경우 인터넷에서 다운
- transform, target_transform = feature와 label, transform을 지정
training_data = datasets.FashionMNIST(
root = "data",
train = True,
download = True,
transform = ToTensor()
)
test_data = datasets.FashionMNIST(
root = "data",
train = False,
download = True,
transform = ToTensor()
)
- root에 "data"로 설정했기 때문에
data 폴더가 생성되며 FashionMNIST 데이터가 다운로드 받아진 것을 볼 수 있습니다.
2) 데이터셋을 시각화
- 데이터셋에 list 처럼 직접 접근하고 시각화를 할 수 있습니다.
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize = (8,8))
for i in range(1, 10):
sample_idx = torch.randint(len(training_data), size = (1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(3, 3, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap = "gray")
plt.show()
출력)
3) 사용자 정의 데이터셋 구성
- 사용자 정의 Dataset 클래스는 반드시! 3개 함수를 구현해야 합니다.
- __init__
- __len__
- __ getitem__
- FashionMNIST 이미지들은 img_dir 디렉터리에 저장되며 label은 annotations_file의 csv 파일에 별도로 저장됩니다.
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
# __init__ 함수는 Data 객체가 instantiate 될 때 한번 실행
# 여기서는 annotations_file이 포함된 디렉토리와 두가지 변형에 대해 초기화 한다.
def __init__(self, annotations_file, img_dir, transform = None, target_transform = None ):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
# __len__ 함수는 데이터셋의 샘플 개수를 반환
def __len__(self):
return len(self.img_labels)
# __getitem__ 함수는 주어진 idx에 해당하는 샘플을 데이터셋에서 불러와서 반환
# idx를 기반으로 이미지의 위치를 식별하고 read_image를 사용하여 이미지를 텐서로 변환
# self.img_labels의 csv 데이터로부터 해당하는 label을 가져오고 해당하면 transform 함수를 호출한 뒤
# 텐서 이미지와 label을 dict 형으로 반환
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform :
label = self.target_transform(label)
return image, label
4) DataLoader로 학습용 데이터 준비
- DataLoader 함수는 데이터셋의 feature을 가져오고 하나의 샘플에 label을 지정하는 일을 한 번에 합니다.
- 모델을 학습할 때 minibatch로 전달하고 epoch 마다 데이터를 다시 섞어서(shuffle) overfit을 막습니다.
- DataLoader는 간단한 API로 복잡한 과정들을 반복 가능한 객체(iteratable)입니다.
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size = 64, shuffle = True)
test_dataloader = DataLoader(test_data, batch_size = 64, shuffle = True)
5) DataLoader를 통해 반복
- DataLoader에 데이터셋을 불러온 뒤에는 필요에 따라 데이터셋을 반복(iterate) 할 수 있습니다.
- 아래의 각 반복(iteration)은 각각 (batch_size = 64의 feature과 label을 포함하는) train_features와 train_labels의 batch로 반환했습니다.
- shuffle을 True로 했기 때문에 배치를 반복한 뒤 데이터가 섞입니다.
# 이미지와 label을 표시하는 것
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape : {train_features.size()}")
print(f"Labels batch shape : {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label : {label}")
출력)
- 파이토치에서 MNIST 데이터셋을 다운로드하고 DataLoader를 통해 데이터를 준비하는 방법에 대해 알아봤습니다.
다음 포스팅으로 찾아오겠습니다
감사합니다 :)
'AI > Pytorch' 카테고리의 다른 글
[Pytorch] Transform(변형) (0) | 2021.07.27 |
---|---|
[Pytorch] 텐서(Tensor) 조작 기초(3) (0) | 2021.07.22 |
[Pytorch] 텐서(Tensor) 조작 기초(2) (0) | 2021.07.21 |
[Pytorch] 텐서(Tensor) 조작 기초(1) (0) | 2021.06.28 |