PyTorch Dataset 클래스 이해하기
PyTorch 딥러닝 프레임워크에서 데이터 관리는 매우 중요합니다. 특히 대규모 데이터셋을 효율적으로 처리하기 위해서는 torch.utils.data.Dataset 클래스를 이해하고 활용하는 것이 필수적입니다. 이 클래스는 데이터셋을 추상화하여, 개별 샘플에 접근하고 전체 데이터셋의 길이를 파악할 수 있도록 해줍니다. 사용자는 자신의 데이터 형식에 맞춰 Dataset 클래스를 상속받아 사용자 정의 데이터셋을 구현할 수 있습니다.
본 문서에서는 PyTorch 튜토리얼에서 제공하는 개미-벌 분류 데이터셋(hymenoptera_data)을 예시로 들어 이미지 데이터를 로드하고, 사용자 정의 Dataset 클래스를 구현하는 방법을 설명합니다.
- 개미-벌 분류 데이터셋 다운로드 링크: https://download.pytorch.org/tutorial/hymenoptera_data.zip
- 이미지 데이터 로딩 기본 예시 (PIL 라이브러리 사용) 먼저 Python의 PIL(Pillow) 라이브러리를 사용하여 단일 이미지를 로드하고 기본 정보를 확인하는 방법을 살펴보겠습니다.
from PIL import Image
import os
# 이미지 파일 경로 설정 (예시, 실제 경로에 맞게 수정 필요)
# Windows 환경에서는 백슬래시를 두 번 사용하거나 os.path.join을 권장합니다.
image_file_path = "D:\\data\\hymenoptera_data\\train\\ants\\0013035.jpg"
# Image.open() 메서드를 사용하여 이미지 로드
loaded_image = Image.open(image_file_path)
# 로드된 이미지의 크기 (너비, 높이) 출력
print(f"이미지 크기: {loaded_image.size}") # 예시 출력: (768, 512)
# 이미지 표시 (실행 환경에 따라 별도의 뷰어가 팝업될 수 있음)
# loaded_image.show()
위 코드는 지정된 경로의 이미지를 열고, 그 크기를 콘솔에 출력합니다. loaded_image.show()는 이미지를 화면에 띄워 보여줍니다.
- 사용자 정의
Dataset클래스 구현 데이터셋의 모든 이미지를 개별적으로 처리하고, PyTorch 모델 학습에 적합한 형태로 준비하기 위해torch.utils.data.Dataset을 상속받는 사용자 정의 클래스를 만듭니다. 이 클래스는 최소한__init__,__len__,__getitem__세 가지 메서드를 구현해야 합니다.
__init__(self, ...): 데이터셋 초기화 시 호출되며, 파일 경로, 레이블 정보 등을 설정합니다.__len__(self): 데이터셋에 포함된 전체 샘플의 수를 반환합니다.__getitem__(self, index): 주어진 인덱스에 해당하는 하나의 샘플(이미지, 레이블 등)을 로드하고 반환합니다.
from torch.utils.data import Dataset
import os
from PIL import Image
from typing import Tuple
class CustomImageDataset(Dataset):
"""
주어진 경로에서 이미지와 해당 레이블을 로드하는 사용자 정의 데이터셋 클래스.
"""
def __init__(self, base_directory: str, class_folder: str):
"""
데이터셋을 초기화합니다.
:param base_directory: 데이터셋의 루트 경로 (예: "dataset/train")
:param class_folder: 현재 데이터셋이 나타내는 클래스의 폴더 이름 (예: "ants", "bees")
"""
self.base_directory = base_directory
self.class_folder = class_folder
# 특정 클래스 폴더의 전체 경로를 구성합니다.
self.full_class_path = os.path.join(self.base_directory, self.class_folder)
# 해당 클래스 폴더 내의 모든 이미지 파일 이름 목록을 가져옵니다.
# 일반적인 이미지 파일 확장자를 필터링하여 정확성을 높입니다.
self.image_filenames = [
f for f in os.listdir(self.full_class_path)
if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))
]
def __len__(self) -> int:
"""
데이터셋에 포함된 총 이미지 샘플 수를 반환합니다.
"""
return len(self.image_filenames)
def __getitem__(self, index: int) -> Tuple[Image.Image, str]:
"""
주어진 인덱스에 해당하는 이미지와 레이블을 로드하고 반환합니다.
:param index: 데이터셋 내의 샘플 인덱스
:return: (PIL.Image 객체, 레이블 문자열) 튜플
"""
# 인덱스에 해당하는 이미지 파일 이름과 전체 경로를 가져옵니다.
image_name = self.image_filenames[index]
image_full_path = os.path.join(self.full_class_path, image_name)
# 이미지를 로드하고, 일관성을 위해 RGB 형식으로 변환합니다.
sample_image = Image.open(image_full_path).convert("RGB")
# 레이블은 해당 이미지의 클래스 폴더 이름을 사용합니다.
sample_label = self.class_folder
return sample_image, sample_label
# 사용자 정의 데이터셋 사용 예시
data_root_path = "dataset/train" # 실제 데이터셋 'train' 폴더의 경로로 변경 필요
# 'ants' 클래스 데이터셋 인스턴스 생성
ants_dataset_instance = CustomImageDataset(data_root_path, "ants")
# 'bees' 클래스 데이터셋 인스턴스 생성
bees_dataset_instance = CustomImageDataset(data_root_path, "bees")
print(f"개미 데이터셋의 총 이미지 수: {len(ants_dataset_instance)}")
print(f"벌 데이터셋의 총 이미지 수: {len(bees_dataset_instance)}")
# 첫 번째 개미 이미지와 레이블 가져오기 예시
# first_ant_img, first_ant_label = ants_dataset_instance[0]
# print(f"첫 번째 개미 이미지의 레이블: {first_ant_label}")
# first_ant_img.show() # 이미지를 화면에 표시
# 여러 데이터셋을 하나로 결합하기 위해 PyTorch의 ConcatDataset 사용
from torch.utils.data import ConcatDataset
# 'ants'와 'bees' 데이터셋을 결합하여 전체 훈련 데이터셋 생성
combined_train_dataset = ConcatDataset([ants_dataset_instance, bees_dataset_instance])
print(f"결합된 훈련 데이터셋의 총 이미지 수: {len(combined_train_dataset)}")
위 코드에서 ConcatDataset은 여러 Dataset 객체를 하나로 합치는 PyTorch의 유틸리티 클래스입니다. 이를 통해 '개미'와 '벌' 데이터셋을 하나의 훈련 데이터셋으로 통합할 수 있습니다.
- 데이터셋 레이블 텍스트 파일 생성 (선택 사항)
경우에 따라 이미지 파일의 레이블을 별도의 텍스트 파일로 저장해야 할 수도 있습니다. 다음 코드는 특정 이미지 폴더의 모든 이미지에 대해 해당 이미지의 클래스 레이블이 담긴 텍스트 파일을 생성하는 예시입니다. 이 예시에서는 폴더 이름에서 레이블을 추출하고, 각 이미지 파일 이름과 동일한 이름의
.txt파일을 생성합니다.
import os
# 기본 설정 경로
base_data_directory = "dataset/train" # 데이터셋의 루트 경로
image_source_folder = "ants" # 레이블을 생성할 이미지가 있는 폴더 (예: 'ants', 'bees')
output_label_directory = "generated_labels_txt" # 생성될 레이블 파일이 저장될 폴더 이름
# 이미지 폴더 이름에서 클래스 레이블 추출 (예: "ants" -> "ants")
extracted_label_name = image_source_folder
# 이미지 원본 폴더의 전체 경로
full_image_source_path = os.path.join(base_data_directory, image_source_folder)
# 레이블 텍스트 파일이 저장될 출력 폴더의 전체 경로
full_output_label_path = os.path.join(base_data_directory, output_label_directory)
# 출력 레이블 폴더가 없으면 새로 생성합니다.
os.makedirs(full_output_label_path, exist_ok=True)
# 이미지 원본 폴더 내의 모든 이미지 파일 목록을 가져옵니다.
image_files_to_process = [
f for f in os.listdir(full_image_source_path)
if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))
]
print(f"'{image_source_folder}' 폴더의 이미지에 대해 '{extracted_label_name}' 레이블 텍스트 파일 생성 중...")
# 각 이미지 파일에 대해 레이블 텍스트 파일 생성
for image_filename in image_files_to_process:
# 파일 이름에서 확장자를 제거하여 텍스트 파일의 기본 이름으로 사용합니다.
base_filename_without_ext = os.path.splitext(image_filename)[0]
# 생성될 레이블 텍스트 파일의 전체 경로를 구성합니다.
label_output_file_path = os.path.join(full_output_label_path, f"{base_filename_without_ext}.txt")
# 텍스트 파일에 추출된 레이블을 기록합니다.
try:
with open(label_output_file_path, 'w', encoding='utf-8') as f:
f.write(extracted_label_name)
# print(f"레이블 파일 생성: '{label_output_file_path}'")
except IOError as e:
print(f"오류: '{label_output_file_path}' 생성 실패 - {e}")
print("레이블 텍스트 파일 생성 작업이 완료되었습니다.")
이 과정을 통해 각 이미지에 대응하는 텍스트 레이블 파일이 생성되며, 이는 나중에 데이터셋을 구성할 때 활용될 수 있습니다. 예를 들어, dataset/train/ants/0013035.jpg 이미지에 대해서는 dataset/train/generated_labels_txt/0013035.txt 파일에 "ants"라는 내용이 저장됩니다.