PyTorch 데이터 파이프라인 구축: Dataset과 DataLoader의 구조 및 최적화

PyTorch에서 효율적인 데이터 파이프라인을 구축하기 위해서는 DatasetDataLoader의 역할을 명확히 이해해야 합니다.

  • Dataset은 데이터 저장소이자 인덱싱 매커니즘입니다. 원본 데이터의 위치를 정의하고, 특정 인덱스에 해당하는 단일 샘플을 추출 및 전처리하는 로직을 담당합니다.
  • DataLoader는 배치 조립 및 전송 파이프라인입니다. Dataset으로부터 샘플을 가져와 지정된 규칙(배치 크기, 셔플 등)에 따라 묶고, 멀티프로세싱을 통해 GPU로 데이터를 지속적으로 공급합니다.

다음은 원본 데이터부터 모델 입력까지의 데이터 흐름과 각 컴포넌트의 역할을 나타낸 아키텍처 다이어그램입니다.


flowchart TD
    A["Raw Data\n(Images, Text, etc.)"] --> B[Dataset Class]
    subgraph B [Dataset Logic]
        B1["__init__\nLoad metadata/paths"] --> B2["__getitem__\nFetch & preprocess single item"]
    end
    B2 --> C["Single Sample\n(e.g., Tensor, Label)"]
    C --> D{DataLoader Engine}
    subgraph D [Batching & Multiprocessing]
        D1["Sampler\nIndex generation"] --> D2["Batch Assembly"]
        D3["Custom collate_fn\n(Optional)"] --> D2
    end
    D2 --> E["Mini-Batch\n(e.g., Tensor[B,C,H,W])"]
    E --> F["Model Training on GPU"]

1. Dataset: 데이터 추상화 및 인덱싱

Dataset은 PyTorch의 추상 클래스를 상속받아 구현합니다. 필수로 오버라이딩해야 하는 두 가지 매직 메서드는 다음과 같습니다.

  • __len__: 전체 데이터셋의 샘플 개수를 반환합니다.
  • __getitem__(index): 주어진 인덱스를 기반으로 전처리된 단일 샘플(데이터와 타겟)을 반환합니다.

커스텀 Dataset 구현 예시


import os
import torch
import json
from torch.utils.data import Dataset
from PIL import Image

class VisionClassificationDataset(Dataset):
    """이미지 분류를 위한 커스텀 데이터셋 구현"""
    
    def __init__(self, image_root, annotation_path, augmentor=None):
        """
        데이터셋 초기화 및 메타데이터 로드
        Args:
            image_root (str): 이미지가 저장된 루트 디렉토리
            annotation_path (str): 어노테이션 JSON 파일 경로
            augmentor (callable, optional): 이미지 증강 파이프라인
        """
        self.image_root = image_root
        self.augmentor = augmentor
        
        # JSON 형식의 어노테이션 파일 파싱
        with open(annotation_path, 'r', encoding='utf-8') as f:
            annotations = json.load(f)
            
        self.data_info = annotations['samples']
        self.class_mapping = annotations['class_to_idx']
    
    def __len__(self):
        return len(self.data_info)
    
    def __getitem__(self, index):
        item = self.data_info[index]
        img_path = os.path.join(self.image_root, item['file_name'])
        target_class = self.class_mapping[item['category']]
        
        # 이미지 로드 및 RGB 변환
        img = Image.open(img_path).convert('RGB')
        
        # 데이터 증강 적용
        if self.augmentor:
            img = self.augmentor(img)
            
        # 타겟을 텐서로 변환
        target = torch.tensor(target_class, dtype=torch.long)
        
        return img, target

Dataset의 두 가지 주요 유형

  • Map-style Dataset: 인덱스를 통해 특정 샘플에 무작위로 접근할 수 있는 방식입니다. 메모리에 메타데이터를 올릴 수 있거나 파일 시스템 경로로 매핑 가능한 경우에 적합합니다.
  • Iterable-style Dataset: __iter__ 메서드를 통해 순차적으로 데이터를 스트리밍하는 방식입니다. 단일 머신의 메모리나 디스크 용량을 초과하는 대규모 데이터셋(예: 대용량 로그, 웹 크롤링 데이터)을 처리할 때 사용됩니다.

2. DataLoader: 배치 프로세싱 및 멀티스레딩 엔진

DataLoaderDataset 객체를 래핑하여 미니배치 생성, 데이터 셔플링, 멀티프로세싱 기반의 비동기 데이터 로딩을 관리합니다.

DataLoader 구성 및 훈련 루프 통합


from torch.utils.data import DataLoader
from torchvision import transforms

# 이미지 전처리 및 증강 파이프라인 정의
data_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])
])

# Dataset 인스턴스 생성
train_dataset = VisionClassificationDataset(
    image_root='./dataset/train_images',
    annotation_path='./dataset/train_annotations.json',
    augmentor=data_transforms
)

# DataLoader 구성
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,            # 미니배치 크기
    shuffle=True,             # 에포크마다 데이터 순서 섞기
    num_workers=8,            # 데이터 로딩에 사용할 워커 프로세스 수
    pin_memory=True,          # GPU 전송 속도 향상을 위한 페이지 잠금 메모리 사용
    drop_last=True,           # 배치 크기에 맞지 않는 마지막 잔여 데이터 삭제
    persistent_workers=True   # 에포크가 끝난 후에도 워커 프로세스 유지 (재시작 오버헤드 감소)
)

# 훈련 루프 내에서의 활용
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for epoch in range(num_epochs):
    for images, targets in train_loader:
        # GPU로 데이터 비동기 전송
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        
        # 모델 순전파, 손실 계산, 역전파 등 훈련 로직 수행
        # ...

DataLoader의 내부 작동 원리

  • sampler: 데이터 인덱스의 샘플링 순서를 결정합니다. shuffle=True일 경우 RandomSampler가, False일 경우 SequentialSampler가 기본으로 사용됩니다. 불균형 데이터셋을 위한 WeightedRandomSampler 등으로 교체할 수 있습니다.
  • batch_sampler: sampler가 생성한 인덱스를 배치 단위로 그룹화합니다.
  • collate_fn: __getitem__에서 반환된 개별 샘플 리스트를 하나의 배치 텐서로 병합하는 함수입니다. 가변 길이의 시퀀스나 복잡한 데이터 구조를 다룰 때 필수적으로 커스터마이징해야 합니다.

커스텀 collate_fn을 활용한 가변 시퀀스 처리


import torch
from torch.nn.utils.rnn import pad_sequence

def sequence_padding_collator(batch):
    """
    가변 길이의 텍스트 시퀀스를 처리하기 위한 커스텀 collate_fn
    batch: Dataset.__getitem__이 반환한 튜플의 리스트 [(seq1, label1), (seq2, label2), ...]
    """
    # 시퀀스와 레이블 분리
    sequences, labels = zip(*batch)
    
    # 레이블 텐서 스태킹
    batch_labels = torch.stack(labels)
    
    # 시퀀스 패딩 (가장 긴 시퀀스 기준으로 0 채움)
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=0)
    
    # 패딩되지 않은 유효 토큰을 나타내는 어텐션 마스크 생성
    attention_masks = (padded_sequences != 0).to(torch.float32)
    
    return padded_sequences, attention_masks, batch_labels

# 커스텀 collate_fn을 적용한 DataLoader
text_loader = DataLoader(
    dataset=text_dataset, 
    batch_size=32, 
    collate_fn=sequence_padding_collator
)

3. 대규모 모델 훈련을 위한 파이프라인 최적화 전략

  • 역할의 명확한 분리:
    • Dataset__getitem__에는 단일 샘플에 대한 경량 전처리(이미지 디코딩, 리사이징 등)만 포함합니다.
    • 배치 수준의 복잡한 연산(패딩, 마스킹 등)은 DataLoadercollate_fn으로 위임하여 멀티프로세싱의 이점을 극대화합니다.
  • I/O 및 메모리 최적화:
    • num_workers 튜닝: CPU 코어 수와 스토리지 I/O 속도에 맞춰 최적의 워커 수를 실험적으로 찾습니다. 일반적으로 4~8 사이에서 최적의 성능을 보입니다.
    • pin_memorynon_blocking: GPU 훈련 시 pin_memory=True로 설정하고, 데이터 전송 시 non_blocking=True를 사용하여 CPU-GPU 간 데이터 전송과 모델 연산을 오버랩시킵니다.
    • 지연 로딩(Lazy Loading): __init__에서 대용량 데이터를 메모리에 모두 적재하지 말고, 파일 경로만 저장한 뒤 __getitem__에서 동적으로 읽도록 설계합니다.
  • 분산 및 대규모 데이터 처리:
    • 메모리에 올릴 수 없는 거대 데이터셋은 IterableDataset을 사용하여 스트리밍 방식으로 처리합니다.
    • 멀티 GPU 분산 훈련 환경에서는 DistributedSampler를 사용하여 각 프로세스가 중복되지 않는 데이터 서브셋을 학습하도록 보장해야 합니다.

태그: PyTorch Dataset DataLoader collate_fn IterableDataset

6월 12일 17:27에 게시됨