이미지 회전 각도 판별 대회 솔루션 분석

컴퓨터 비전 대회에서의 이미지 회전 각도 판별

이미지 회전 각도 판별은 컴퓨터 비전 대회의 대표적인 과제로, 입력 이미지의 회전 각도를 0°, 90°, 180°, 270° 중 하나로 분류합니다. 조명 변화, 배경 간섭, 객체 다양성 등이 모델 성능에 영향을 미치는 주요 요인입니다. 본 문서는 실제 대회 데이터를 활용한 종합 솔루션을 제시합니다.

문제 정의 및 데이터 분석

각 회전 상태별 특징:

  • 0°: 원본 방향
  • 90°: 시계 방향 90° 회전
  • 180°: 상하 반전
  • 270°: 시계 방향 270° 회전

데이터 분석 시 클래스 불균형과 이미지 품질(선명도, 조명, 배경 복잡도)을 반드시 확인해야 합니다.

import matplotlib.pyplot as plt

def show_rotated_examples(image_set, label_set, samples=4):
    fig, axes = plt.subplots(4, samples, figsize=(15, 12))
    for angle in range(4):
        angle_imgs = image_set[label_set == angle]
        for idx in range(samples):
            if idx < len(angle_imgs):
                axes[angle, idx].imshow(angle_imgs[idx])
                axes[angle, idx].set_title(f'각도: {angle*90}°')
                axes[angle, idx].axis('off')
    plt.tight_layout()
    plt.show()

기본 모델 구성

CNN 아키텍처 선택 가이드:

  • 경량: MobileNet, EfficientNet-B0
  • 중간: ResNet34, EfficientNet-B3
  • 고성능: ResNet50, EfficientNet-B5
import torch.nn as nn
import torch.optim as optim

def model_training(net, train_loader, valid_loader, epochs=20):
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(net.parameters(), lr=0.001)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, patience=3)
    
    top_accuracy = 0.0
    for epoch in range(epochs):
        net.train()
        total_loss = 0.0
        for imgs, lbls in train_loader:
            opt.zero_grad()
            outputs = net(imgs)
            loss = loss_fn(outputs, lbls)
            loss.backward()
            opt.step()
            total_loss += loss.item()
        
        net.eval()
        correct = 0
        with torch.no_grad():
            for imgs, lbls in valid_loader:
                outputs = net(imgs)
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == lbls).sum().item()
        
        acc = correct / len(valid_loader.dataset)
        lr_scheduler.step(acc)
        
        print(f'Epoch {epoch+1}: Loss: {total_loss:.4f}, Accuracy: {acc:.4f}')
        
        if acc > top_accuracy:
            top_accuracy = acc
            torch.save(net.state_dict(), 'top_model.pt')

데이터 증강 기법

회전 판별 과제에 특화된 증강 전략:

  • 색공간 변환: 채도/명암 조절, 그레이스케일 변환
  • 지역 가림: 객체 일부 가림
  • 혼합 기법: MixUp, CutMix 적용

대각도 회전 증강은 주의하여 사용해야 합니다.

from torchvision import transforms

def get_augmentation(img_dim=224):
    return transforms.Compose([
        transforms.RandomResizedCrop(img_dim, scale=(0.8, 1.0)),
        transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([transforms.GaussianBlur(3)], p=0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

추가 기술 적용

테스트 시점 증강(TTA): 다중 크기 변환과 지역 자르기를 조합하여 예측 정확도 향상

def tta_inference(model, input_img, tta_steps=8):
    preds = []
    # 기본 예측
    base_pred = model(input_img.unsqueeze(0))
    preds.append(base_pred)
    
    # 크기 변환
    scales = [0.9, 1.0, 1.1]
    for s in scales:
        resized = transforms.functional.resize(input_img, int(input_img.shape[1]*s))
        preds.append(model(resized.unsqueeze(0)))
    
    # 지역 자르기
    h, w = input_img.shape[1], input_img.shape[2]
    crop_dim = int(min(h, w) * 0.9)
    ...
    
    return torch.mean(torch.stack(preds), dim=0)

가짜 레이블: 고신뢰도 예측을 학습 데이터에 추가

def create_pseudo_labels(model, unlabeled_set, threshold=0.95):
    model.eval()
    pseudo_set = []
    for imgs, _ in unlabeled_set:
        outputs = model(imgs)
        probs = torch.softmax(outputs, dim=1)
        conf, preds = torch.max(probs, dim=1)
        mask = conf > threshold
        ...
    return pseudo_set

모델 앙상블: 다양한 아키텍처의 예측 가중 평균

def ensemble(models, input_img, weights=None):
    if weights is None: 
        weights = [1.0] * len(models)
    total = None
    for m, w in zip(models, weights):
        with torch.no_grad():
            pred = m(input_img.unsqueeze(0))
            total = pred*w if total is None else total + pred*w
    return total / sum(weights)

성능 최적화 팁

  • 학습률 조정: 코사인 감소 스케줄링 적용
  • 오류 분석: 잘못 분류된 사례 유형 식별
  • 자원 관리: AMP 자동 혼합 정밀도 사용

태그: ComputerVision ImageClassification DataAugmentation TestTimeAugmentation PseudoLabeling

6월 13일 23:29에 게시됨