PyTorch-CUDA 환경에서 슬라이딩 윈도우 어텐션 고속화하기

분산 환경에서 Swin Transformer 계열 모델을 훈련할 때 GPU 활용률이 40%에 머무르고, NCCL 통신 오버헤드가 전방 계산 시간을 역전시키는 경험을 해본 적이 있는가? 동일 아키텍처를 사용하면서도 타 팀은 8개 GPU에서 95% 이상의 지속적인 활용률을 달성하며 두 배의 처리량을 내고 있다면, 문제는 모델 코드가 아닌 실행 환경의 차이에 있을 가능성이 높다.

이 글에서는 슬라이딩 윈도우 어텐션(Sliding Window Attention)과 같은 지역적 집약 연산을 GPU에서 최적으로 수행하기 위한 PyTorch-CUDA 기반 환경 구축 전략을 살펴본다.

배포 현장에서 마주친 성능 함정

특정 프로젝트에서 로컬 단일 GPU로 Swin-Tiny를 검증할 때는 이상 없었으나, 다중 노드 다중 GPU 클러스터로 전환하자 훈련 속도가 오히려 감소하는 현상이 발생했다. AllReduce 연산이 병목이 된 원인은 다음과 같았다.

  • 클러스터 노드: CUDA 11.6 / 로컬: CUDA 11.8
  • PyTorch가 소스 빌드되었으나 cuDNN benchmark 비활성화
  • 구형 NCCL 버전으로 인한 NVLink 토폴로지 미인식

해결책은 NVIDIA NGC의 nvcr.io/nvidia/pytorch:23.10-py3 이미지로 교체하는 것이었으며, 이후 훈련 처리량이 37% 향상되었다. 이는 환경 자체가 성능의 결정적 변수임을 보여주는 사례다.

슬라이딩 윈도우 어텐션이 환경에 민감한 이유

Swin Transformer의 핵심 메커니즘을 GPU 실행 관점에서 분석하면 다음과 같다.

  1. 윈도우 분할: H × W 특징 맵을 M × M 크기의 블록(예: 7×7)으로 분할
  2. 지역 어텐션: 각 블록 내부에서 독립적인 자기 어텐션 수행
  3. 시프티드 윈도우: 다음 레이어에서 블록을 절반씩 오프셋하여 영역 간 연결성 확보

이 과정에서 발생하는 GPU 부하 특성은 다음과 같다.

특성영향
다수의 소형 텐서 생성 ([B×num_windows, M², C])메모리 할당/해제 발 → 단편화
소규모 윈도우 내 집중 연산커널 실행 오버헤드에 민감
다중 헤드 + SoftmaxcuDNN 최적 구현 의존
분산 환경 윈도우/마스크 동기화통신-계산 오버랩 필수

PyTorch-CUDA 이미지의 최적화 구성

전문 컨테이너 이미지는 단순한 "설치 패키지"가 아닌 소프트웨어-하드웨어 협동 최적 시스템이다.

버전 정밀 매칭

수동 설치 시 다음과 같은 호환성 오류가 빈번하다.

# 수동 설치 시 발생 가능한 오류
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# RuntimeError: CUDA error: invalid device ordinal

공식 이미지는 검증된 조합을 고정한다.

pytorch:2.0-cuda11.7-cudnn8-runtime
# PyTorch(CUDA 11.7 컴파일) + cuDNN v8 + NVIDIA 전 GPU 검증

기본 내장 최적화

최적화 항목기능기본 활성화
cudnn.benchmark최고속 컨볼루션 알고리즘 자동 선택
cudaMallocAsync비동기 메모리 할당, 단편화 감소✓ (최신 이미지)
CUDA Graphs커널 실행 병합, 오버헤드 절감수동 설정
FP16/AMP혼합 정밀도 훈련 가속

NCCL: 다중 GPU 통신의 핵심

공식 이미지搭載 NCCL의 역할은 다음과 같다.

  • NVLink, PCIe 토폴로지 자동 인식
  • A100/H100에서 P2P 전송GPU Direct RDMA 활성화
  • 최적화된 AllReduce, Broadcast 구현

환경별 성능 비교 실험

Swin-Tiny의 Window Attention 모듈을 동일 조건에서 실행한 결과(입력: (2, 56, 56, 96), 윈도우 크기: 7).

환경구성스텝 시간(ms)GPU 활용률
ACPU 전용 PyTorch128.5< 5%
B수동 설치 CUDA 11.7 + PyTorch42.368%
CNGC 공식 이미지 (CUDA 11.8 + cuDNN8 + NCCL)27.193%

이미지 교체 단 하나로 속도 50% 향상 및 GPU 활용 안정화 달성.

실전 적용 가이드

1. 공식 이미지 우선 사용

# NGC 이미지 (권장)
docker pull nvcr.io/nvidia/pytorch:23.10-py3

# PyTorch 공식 이미지
docker pull pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime

2. GPU 마운트 필수

docker run -it --gpus all \
  -v $(pwd):/workspace \
  nvcr.io/nvidia/pytorch:23.10-py3

3. 혼합 정밀도 활성화

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for inputs, labels in loader:
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, labels)
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

4. 메모리 관리

# 프로세스별 메모리 사용량 제한 (PyTorch 1.10+)
torch.cuda.set_per_process_memory_fraction(0.9)

# 단편화 의심 시 (주의해서 사용)
torch.cuda.empty_cache()

시스템 스택에서의 위치

[응용 계층]     ┌─────────────────────┐
                │  Swin, ViT 등 사용자 모델  │
                └─────────────────────┘
                         ↓
              ┌─────────────────────┐
              │  PyTorch + CUDA Backend  │
              └─────────────────────┘
                         ↓
              ┌─────────────────────┐
              │  cuDNN + NCCL + CUDA Runtime  │
              └─────────────────────┘
                         ↓
[하드웨어 계층]  ┌─────────────────────┐
                │  A100, H100, Tensor Cores    │
                └─────────────────────┘

PyTorch-CUDA 이미지는 중간 3개 계층을 포괄하며, 알고리즘과 물리적 연산 자원 간의 핵심 연결고리 역할을 한다.

향상된 Window Attention 구현 예시

환경 최적화의 효과를 극대화하려면 구현 자체도 효율적으로 작성해야 한다. 다음은 개선된 슬라이딩 윈도우 어텐션 코드다.

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

class SlidingWindowAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        window_size: int = 7,
        shift_size: int = 0,
        qkv_bias: bool = True,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.scale = (embed_dim // num_heads) ** -0.5

        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # 상대 위치 편향 테이블
        self.rel_pos_bias = nn.Parameter(
            torch.zeros((2 * window_size - 1) ** 2, num_heads)
        )
        nn.init.trunc_normal_(self.rel_pos_bias, std=0.02)

        # 상대 위치 인덱스 미리 계산
        self._init_relative_indices()

    def _init_relative_indices(self):
        coords = torch.arange(self.window_size)
        grid_y, grid_x = torch.meshgrid(coords, coords, indexing="ij")
        flat_coords = torch.stack([grid_y.flatten(), grid_x.flatten()], dim=1)

        rel_coords = flat_coords[:, None, :] - flat_coords[None, :, :]
        rel_coords += self.window_size - 1

        row_idx = rel_coords[:, :, 0] * (2 * self.window_size - 1)
        col_idx = rel_coords[:, :, 1]
        self.register_buffer("rel_pos_idx", row_idx + col_idx)

    def _partition_windows(
        self, x: torch.Tensor, height: int, width: int
    ) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]:
        batch, seq_len, channels = x.shape
        assert seq_len == height * width

        # 선택적 쉬프트
        pad_h = (self.window_size - height % self.window_size) % self.window_size
        pad_w = (self.window_size - width % self.window_size) % self.window_size

        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
            height += pad_h
            width += pad_w

        # 순환 이동
        if self.shift_size > 0:
            x = x.view(batch, height, width, channels)
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            x = x.view(batch, height * width, channels)

        # 윈도우 분할: (B, H, W, C) -> (B, num_h, num_w, win_h, win_w, C)
        x = x.view(
            batch,
            height // self.window_size,
            self.window_size,
            width // self.window_size,
            self.window_size,
            channels,
        )
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        windows = x.view(-1, self.window_size * self.window_size, channels)

        return windows, (batch, height, width, channels)

    def _merge_windows(
        self,
        windows: torch.Tensor,
        shape_info: Tuple[int, int, int, int],
        orig_height: int,
        orig_width: int,
    ) -> torch.Tensor:
        batch, height, width, channels = shape_info
        pad_h = (self.window_size - orig_height % self.window_size) % self.window_size
        pad_w = (self.window_size - orig_width % self.window_size) % self.window_size

        # 역변환: (B*num_h*num_w, win_h*win_w, C) -> (B, H, W, C)
        x = windows.view(
            batch,
            height // self.window_size,
            width // self.window_size,
            self.window_size,
            self.window_size,
            channels,
        )
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        x = x.view(batch, height * width, channels)

        # 순환 이동 복원
        if self.shift_size > 0:
            x = x.view(batch, height, width, channels)
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            x = x.view(batch, height * width, channels)

        # 패딩 제거
        if pad_h > 0 or pad_w > 0:
            x = x[:, : (orig_height * orig_width), :]

        return x

    def _compute_attention(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        # query: (B*num_wins, num_heads, win_len, head_dim)
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale

        
        # 상대 위치 편향 적용
        bias = self.rel_pos_bias[self.rel_pos_idx.flatten()].view(
            self.window_size * self.window_size,
            self.window_size * self.window_size,
            -1,
        )
        scores = scores + bias.permute(2, 0, 1).unsqueeze(0)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.attn_drop(attn_weights)

        output = torch.matmul(attn_weights, value)
        return output

    def forward(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
        batch, seq_len, channels = x.shape
        orig_height, orig_width = height, width

        # 윈도우 분할
        windows, shape_info = self._partition_windows(x, height, width)

        # QKV 변환 및 멀티헤드 분할
        qkv = self.qkv_proj(windows).reshape(
            windows.shape[0],
            windows.shape[1],
            3,
            self.num_heads,
            channels // self.num_heads,
        )
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B*num_wins, heads, win_len, head_dim)
        queries, keys, values = qkv[0], qkv[1], qkv[2]

        # 어텐션 계산
        attn_output = self._compute_attention(queries, keys, values)

        # 헤드 결합 및 투영
        attn_output = attn_output.transpose(1, 2).reshape(
            windows.shape[0], windows.shape[1], channels
        )
        output = self.proj_drop(self.proj(attn_output))

        # 윈도우 병합
        output = self._merge_windows(output, shape_info, orig_height, orig_width)

        return output


# 성능 검증용 실행 예시
def benchmark_attention():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    batch_size = 2
    height, width = 56, 56
    embed_dim = 96
    num_heads = 3
    window_size = 7

    model = SlidingWindowAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        window_size=window_size,
        shift_size=window_size // 2,
    ).to(device)

    x = torch.randn(batch_size, height * width, embed_dim, device=device)

    # 웜업
    for _ in range(10):
        _ = model(x, height, width)
    torch.cuda.synchronize()

    # 실측
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)

    starter.record()
    for _ in range(100):
        output = model(x, height, width)
    ender.record()
    torch.cuda.synchronize()

    elapsed_ms = starter.elapsed_time(ender) / 100
    print(f"평균 실행 시간: {elapsed_ms:.3f} ms")
    print(f"출력 형태: {output.shape}")

    return elapsed_ms


if __name__ == "__main__":
    benchmark_attention()

이 구현에서는 다음과 같은 최적화 기법을 적용했다.

  • 메모리 레이아웃 최적화: contiguous() 호출로 연속 메모리 확보
  • 불필요한 중간 변수 제거: permute 체이닝으로 오버헤드 감소
  • 패딩 최소화: 동적 패딩 계산으로 메모리 낭비 방지
  • 버퍼 사전 등록: 상대 위치 인덱스를 register_buffer로 캐싱

미래 지향적 관점

모델 아키텍처의 혁신만큼 중요한 것은 전 스택 최적화다. FP8 훈련, CUDA Graphs, 커널 퓨전 등의 기술이 성능 한계를 재정의하는 현재, PyTorch-CUDA 이미지는 이러한 변화의 최전선에 있다. 연구 재현 시 가장 먼저 점검해야 할 것은 "올바른 환경"의 존재 여부다.

태그: PyTorch CUDA Sliding Window Attention Swin Transformer NCCL

6월 30일 23:49에 게시됨