CBAM 모듈을 활용한 사전 학습 ResNet 성능 향상 전략

1. ResNet 구조 분석과 CBAM 삽입 위치

사전 학습된 모델에 새로운 모듈을 추가하여 성능을 향상시킬 수 있는지 고민해볼 필요가 있습니다. 핵심은 기존 구조와 가중치를 유지하면서 어떻게 하면 특성 추출기의 파라미터를 손상시키지 않을 수 있는지에 있습니다. 본 글에서는 두 가지 질문에 답하고자 합니다.

  1. ResNet18에 CBAM 모듈을 어떻게 삽입할 것인가?
  2. 어떤 사전 학습 전략이 효율성을 극대화할 수 있는가?

ResNet18에 CBAM을 추가한다면, 대부분의 코드는 재사용 가능하지만 모델 정의 부분은 새로 작성해야 합니다. 이전에는 사전 학습 레이어를 동결하고 분류 헤드(fully connected layer)만 학습시켰지만, 이제는 각 잔차 블록(residual block)의 CBAM 어텐션 레이어도 추가 학습이 필요합니다.

먼저 데이터 전처리와 CBAM 모듈을 정의한 후, ResNet의 내부 구조를 살펴보겠습니다.


import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# 채널 어텐션 정의
class ChannelAttention(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, h, w = x.shape
        avg_feat = self.mlp(self.avg_pool(x).view(b, c))
        max_feat = self.mlp(self.max_pool(x).view(b, c))
        attn = self.sigmoid(avg_feat + max_feat).view(b, c, 1, 1)
        return x * attn

# 공간 어텐션 모듈
class SpatialAttention(nn.Module):
    def __init__(self, kernel=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel, padding=kernel//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        concat = torch.cat([avg_out, max_out], dim=1)
        attn = self.conv(concat)
        return x * self.sigmoid(attn)

# CBAM 모듈
class CBAM(nn.Module):
    def __init__(self, channels, reduction=16, kernel=7):
        super().__init__()
        self.channel_attn = ChannelAttention(channels, reduction)
        self.spatial_attn = SpatialAttention(kernel)

    def forward(self, x):
        x = self.channel_attn(x)
        x = self.spatial_attn(x)
        return x

# 데이터 전처리 (CIFAR-10)
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

사전 학습된 ResNet18 구조를 살펴보면 다음과 같습니다.


import torchvision.models as models
from torchinfo import summary

model = models.resnet18(pretrained=True)
model.eval()
summary(model, input_size=(1, 3, 224, 224))

ResNet-18은 크게 세 부분으로 나눌 수 있습니다.

단계 해당 레이어 주요 역할
입력 전처리 Conv2d ~ MaxPool2d 초기 특성 추출 및 특징 맵 크기 축소
핵심 특성 추출 4개의 Sequential 모듈 잔차 블록을 통해 깊고 다양한 수준의 특성 학습
분류 출력 AdaptiveAvgPool2d 및 Linear 특징 벡터 변환 및 클래스 매핑

입력 이미지가 [1, 3, 224, 224]일 때 형태 변화는 다음과 같습니다.

단계/레이어 출력 형태 채널 변화 크기 변화 설명
입력 [1, 3, 224, 224] - - 초기 입력
Conv2d (7x7, stride=2) [1, 64, 112, 112] 3→64 224→112 큰 커널로 초기 특성 추출
MaxPool2d (3x3, stride=2) [1, 64, 56, 56] 불변 112→56 추가 다운샘플링
Stage 1 (layer1) [1, 64, 56, 56] 불변 불변 2개의 BasicBlock, 깊이 확장
Stage 2 (layer2) [1, 128, 28, 28] 64→128 56→28 크기 반감, 채널 두 배
Stage 3 (layer3) [1, 256, 14, 14] 128→256 28→14 크기 반감, 채널 두 배
Stage 4 (layer4) [1, 512, 7, 7] 256→512 14→7 크기 반감, 채널 두 배
AdaptiveAvgPool2d [1, 512, 1, 1] 불변 7→1 글로벌 평균 풀링
Linear [1, 1000] 512→1000 - 1000개 클래스 매핑

BasicBlock의 핵심 개념: 잔차 학습(residual learning)은 네트워크가 목표 매핑 H(x) 대신 잔차 F(x) = H(x) - x를 학습하도록 합니다. 이는 "항등 매핑(identity mapping)" 학습을 단순화하여 네트워크가 깊어져도 성능 저하를 방지합니다. 숏컷 연결(shortcut connection)은 정보가 레이어를 건너뛰어 전파될 수 있게 하여 그래디언트 소실 문제를 완화합니다.

2. CBAM 모듈 삽입 위치에 대한 고찰

사전 학습된 모델에 새로운 모듈을 삽입할 때 고려해야 할 점은 기존 가중치를 보존하는 것입니다. CBAM을 마지막 분류 헤드 직전에 배치하는 것은 간단하지만, 어텐션 메커니즘이 최종 단계에서만 작동하여 중간 계층에서 더 나은 특성을 구축하는 데 도움이 되지 않습니다. 그러나 마지막 컨볼루션 이후에 배치하면 공간 어텐션의 효과가 사라질 수 있습니다(공간 차원이 없어지므로).

가장 효과적인 방법은 **각 잔차 블록의 출력에 CBAM을 적용**하는 것입니다. CBAM 모듈의 초기 상태가 "거의 직통(straight-through)"이기 때문에 이 접근 방식이 가능합니다. CBAM의 최종 연산은 return x * self.sigmoid(attention)입니다. 초기화 시, 모듈의 가중치가 거의 0에 가깝기 때문에 attention 값도 0에 가깝고, sigmoid(0)=0.5이므로 초기에는 x * 0.5에 가깝게 동작합니다. 이는 특성 값을 절반으로 줄이지만, 원래 특성의 구조와 관계를 보존하여 하위 계층이 안정적으로 학습을 시작할 수 있게 합니다.

따라서 CBAM을 각 잔차 블록 뒤에 삽입하는 것이 이상적입니다. 이렇게 하면 기존 사전 학습 가중치를 보존하면서 새로운 어텐션 모듈을 추가할 수 있습니다.


import torch.nn as nn
from torchvision import models

class ResNet18WithCBAM(nn.Module):
    def __init__(self, num_classes=10, pretrained=True, reduction=16, kernel=7):
        super().__init__()
        self.backbone = models.resnet18(pretrained=pretrained)
        
        # CIFAR-10의 32x32 입력에 맞게 첫 번째 컨볼루션 조정
        self.backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.backbone.maxpool = nn.Identity()
        
        # 각 잔차 블록 그룹 뒤에 CBAM 추가
        self.cbam1 = CBAM(64, reduction, kernel)
        self.cbam2 = CBAM(128, reduction, kernel)
        self.cbam3 = CBAM(256, reduction, kernel)
        self.cbam4 = CBAM(512, reduction, kernel)
        
        # 분류 헤드 수정
        self.backbone.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        
        x = self.backbone.layer1(x)
        x = self.cbam1(x)
        
        x = self.backbone.layer2(x)
        x = self.cbam2(x)
        
        x = self.backbone.layer3(x)
        x = self.cbam3(x)
        
        x = self.backbone.layer4(x)
        x = self.cbam4(x)
        
        x = self.backbone.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.backbone.fc(x)
        return x

3. 사전 학습 모델을 위한 학습 전략

효율적인 학습을 위해 두 가지 전략을 채택합니다.

차등 학습률 (Differential Learning Rates)

사전 학습된 레이어는 경험이 풍부한 전문가, 새로운 CBAM과 분류 헤드는 인턴에 비유할 수 있습니다. 인턴에게는 높은 학습률(예: 1e-3)을, 전문가에게는 낮은 학습률(예: 1e-5)을 적용하여 기존 지식을 보호하면서 새로운 모듈을 빠르게 학습시킵니다.

3단계 점진적 미세 조정 (Progressive Unfreezing)

1단계 (Epoch 1-5): 신규 모듈 워밍업

  • 학습 가능: 분류 헤드(fc) 및 모든 CBAM 모듈
  • 동결: ResNet18의 모든 컨볼루션 레이어(conv1, bn1, layer1~layer4)
  • 목표: 강력한 사전 학습 특성을 활용하여 새로운 작업의 분류 경계를 빠르게 학습
  • 학습률: 1e-3

2단계 (Epoch 6-20): 고급 특성 해제

  • 학습 가능: 1단계 + layer3, layer4 (고급 의미 특성)
  • 동결: conv1, bn1, layer1, layer2 (저수준 특성)
  • 목표: 고급 특성 추출 능력을 새로운 작업에 맞게 조정
  • 학습률: 1e-4

3단계 (Epoch 21-50): 전체 미세 조정

  • 학습 가능: 모든 레이어
  • 동결: 없음
  • 목표: 저수준 특성까지 새로운 작업에 맞게 세밀하게 조정
  • 학습률: 1e-5

이 전략은 저수준 레이어(가장자리, 질감 등 일반적인 특성)는 새로운 작업과도 공유 가능하므로 보호하고, 고수준 레이어(객체 개념 등 작업 특화 특성)를 우선적으로 조정하는 데 기반합니다.


def configure_training(model, phase):
    """학습 단계에 따라 레이어 학습 가능 상태 설정"""
    for param in model.parameters():
        param.requires_grad = False
    
    # 항상 CBAM과 분류 헤드는 학습
    for name, param in model.named_parameters():
        if 'cbam' in name or 'backbone.fc' in name:
            param.requires_grad = True
    
    if phase >= 2:
        for name, param in model.named_parameters():
            if 'backbone.layer3' in name or 'backbone.layer4' in name:
                param.requires_grad = True
    
    if phase >= 3:
        for param in model.parameters():
            param.requires_grad = True

def train_phased(model, criterion, train_loader, test_loader, device, total_epochs=50):
    train_losses, test_accs = [], []
    
    for epoch in range(1, total_epochs + 1):
        if epoch == 1:
            configure_training(model, phase=1)
            optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
            print("Phase 1: Training attention modules and classification head")
        elif epoch == 6:
            configure_training(model, phase=2)
            optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
            print("Phase 2: Unfreezing high-level layers (layer3, layer4)")
        elif epoch == 21:
            configure_training(model, phase=3)
            optimizer = optim.Adam(model.parameters(), lr=1e-5)
            print("Phase 3: Full fine-tuning")
        
        model.train()
        running_loss = 0.0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        # 평가
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                _, predicted = torch.max(output.data, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()
        
        epoch_loss = running_loss / len(train_loader)
        accuracy = 100 * correct / total
        train_losses.append(epoch_loss)
        test_accs.append(accuracy)
        
        print(f"Epoch {epoch}/{total_epochs} - Loss: {epoch_loss:.4f}, Test Accuracy: {accuracy:.2f}%")
    
    return train_losses, test_accs

# 모델 초기화 및 학습 실행
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet18WithCBAM().to(device)
criterion = nn.CrossEntropyLoss()

losses, accuracies = train_phased(model, criterion, train_loader, test_loader, device, total_epochs=50)
print(f"Final test accuracy: {accuracies[-1]:.2f}%")

태그: ResNet CBAM Attention Mechanism Transfer Learning Fine-tuning

6월 10일 23:54에 게시됨