Theory/Pytorch

[파이토치 트랜스포머 #7] 4장 파이토치 심화 - 1) 과대적합과 과소적합, 배치 정규화

남디윤 2024. 4. 12. 15:30

 

과대적합과 과소적합

  • 과대적합 Overfitting: 모델이 훈련 데이터에서는 우수하게 예측, 새로운 데이터에서는 제대로 예측하지 못해 오차가 크게 발생하는 것을 의미
  • 과소적합 Underfitting: 훈련 데이터에서도 성능이 좋지 않고, 새로운 데이터에 대해서도 성능이 좋지 않음.
  • 공통점
    • 성능 저하
    • 모델 선택 실패
      • 모델을 변경해 문제 완화 가능
      • 과대적합: 모델의 구조가 너무 복잡
      • 과소적합: 모델의 구조가 너무 단순
    • 편향-분산 트레이드오프
      • 모델이 훈련 데이터와 새로운 데이터에 대해서도 우수한 성능을 보이려면 낮은 편향과 낮은 분산을 가져야 함
        • 분산이 높으면 추정치에 대한 변동 폭이 커짐
        • 데이터가 갖고 있는 노이즈까지 학습 과정에 포함
        • → 과대 적합 발생
      • 모델이 복잡할수록 분산은 커지고, 편향은 작아짐
      • 모델이 단순할수록 분산은 작아지고, 편향은 커짐
      • → 분산과 편향의 균형을 맞춰야 함

 

과대적합과 과소적합 문제 해결

  • 과대적합: 모델의 일반화(Generalization) 능력을 저하해 문제가 발생
  • 과소적합: 모델이 데이터의 특징을 제대로 학습할 수 없을 때 발생
  • 피하기 위한 방법
    • 데이터 수집: 모델이 훈련 데이터에서 노이즈를 학습하지 않으면서 일반적인 규칙을 찾을 수 있게 학습 데이터 수를 증가
    • 피처 엔지니어링: 신규 데이터 수집이 어려운 경우라면, 기존 훈련 데이터에서 변수나 특징을 추출하거나 피처를 더 작은 차원으로 축소. 피처 엔지니어링을 적합하게 진행하면 노이즈에 더 강한 모델 구축 가능
    • 모델 변경: 과대 적합의 경우, 깊은(Deep) 구조의 모델일 가능성이 높기에 모델 계층을 축소하거나 간단한 모델로 변경 필요. 과소 적합의 경우, 계층을 확장하거나 더 복잡한 모델로 변경 필요
    • 조기 중단 Early Stopping: 검증 데이터세트로 성능을 지속적으로 평가해 모델의 성능이 저하되기 전에 모델 학습을 조기 중단
    • 배치 정규화 Batch Normalization: 모델의 계층마다 평균과 분산을 조정해 내부 공변량 변화를 줄여 과대적합 방지
    • 가중치 초기화 Weight Initialization: 모델의 매개변수 최적화 전에 가중치 초깃값을 설정하는 프로세스를 의미. 학습 시 기울기가 매우 작아지거나 커지는 문제 발생 가능. 이러한 문제를 방지하는 방법
    • 정칙화 Regularization: 목적 함수에 패널티를 부여하는 방법. 학습 조기 중단, L1 정칙화, L2 정칙화, 드롭아웃, 가중치 감쇠 등

 

 

배치 정규화

  • 배치 정규화 Batch Normalization
    • 내부 공변량 변화 Internal Covariate Shift를 줄여 과대 적합을 방지하는 기술
    • 각 계층에 대한 입력이 일반화되고 독립적으로 정규화가 수행되므로 더 빠르게 값을 수렴할 수 있음.
    • 입력이 정규화되므로 초기 가중치에 대한 영향 줄이기 가능
  • 배치 단위 학습
    • 인공 신경망 학습 시 배치 단위 학습 진행
    • 상위 계층의 매개변수가 갱신될 때마다 현재 계층에 전달되는 데이터의 분포도 변경됨
    • 각 계층은 배치 단위의 데이터로 인해 계속 변화되는 입력 분포를 학습해야 하기 때문에 인공 신경망의 성능과 안전성이 낮아지고 학습 속도가 저하됨
    • ⇒ 내부 공변량 변화: 계층마다 입력 분포가 변경되는 현상
  • 내부 공변량 변화 발생
    • 은닉층에서 다음 은닉층으로 전달될 때 입력값이 균일하지 않아 가중치가 제대로 갱신되지 않을 수 있음
    • 학습 불안정 & 속도 저하 → 가중치 일정한 값 수렴 어려움
  • 모델 학습 시 초기 가중치 값에 민감 → 일반화 어려움 → 더 많은 학습 데이터 요구
  • 방식: 미니 배치의 입력을 정규화

 

정규화 종류

  • 배치 정규화 이외에도 계층 정규화 Layer Normalization, 인스턴스 정규화 Instance Normalization, 그룹 정규화 Group Normalization이 있음
  • 데이터 예시: 이미지 데이터 (특징: 차원, 채널, 미니 배치 * 특징은 데이터에 따라 달라질 수 있음)

  • 배치 정규화
    • 미니 배치에서 계산된 평균 및 분산을 기반으로 계층의 입력을 정규화
    • 이미지 데이터
      • 이미지 데이터 전체를 대상으로 정규화 하는 것이 아닌
      • 이미지 데이터 채널별로 정규화 수행
    • 컴퓨터 비전과 관련된 모델 중 합성곱 신경망(CNN)이나 다층 퍼셉트론(MLP)과 같은 순방향 신경망 Feedforward Neural Network 에서 주로 사용됨
  • 계층 정규화
    • 미니 배치의 샘플 전체를 계산하는 방법이 아닌, 채널 축으로 계산
    • 샘플이 서로 다른 길이를 가지더라도 정규화 수행 가능
    • 신경망 모델 중 자연어 처리에서 주로 사용되며, 순환 신경망(RNN)이나 트랜스포머 기반 모델에서 주로 사용됨
  • 인스턴스 정규화
    • 채널과 샘플을 기준으로 정규화를 수행
    • 입력이 다른 분포를 갖는 작업에 적합
    • 생성적 적대 신경망(GAN)이나 이미지의 스타일을 변환하는 스타일 변환 Style Transfer 모델에서 주로 사용됨
  • 그룹 정규화
    • 채널을 N개의 그룹으로 나누고 각 그룹 내에서 정규화를 수행
    • 그룹 하나로 설정 → 인스턴스 정규화 동일
    • 그룹의 개수=채널의 개수 → 계층 정규화 동일
    • 배치 크기가 작거나 채널 수가 매우 많은 경우에 주로 사용됨
    • 합성곱 신경망(CNN)의 배치 크기가 작으면 배치 정규화가 배치의 평균과 분산이 데이터 세트를 표현한다고 보기 어렵기 때문에 배치 정규화의 대안으로 사용됨

 

배치 정규화 수식 및 적용

  • 배치 정규화 수식
    • $\epsilon$ 엡실론: 분모가 0이 되는 현상 방지하는 작은 상수
    • $\gamma, \beta$ 감마, 베타: 학습 가능 매개변수, 활성화 함수에서 발생하는 음수의 영역 처리할 수 있게 값을 조절하는 스케일 Scale 값과 시프트 Shift 값
  • $$
    y_i = \gamma \left( \frac{x_i - E[X]}{\sqrt{Var[X] + \epsilon}} \right) + \beta
    $$
  • pytorch 정규화 예시
    import torch
    from torch import nn
    
    x = torch.FloatTensor(
    [
    [-0.6577, -0.5797, 0.6360],
    [0.7392, 0.2145, 1.523],
    [0.2432, 0.5662, 0.322]
    ]
    )
    
    print(nn.BatchNorm1d(3)(x))

 

  • 정규화 클래스
    • 배치 정규화
      • torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d
        • 1d: 2D/3D 입력 데이터에 배치 정규화 수행
        • 2d: 4D 입력 데이터에 대해 배치 정규화 수행
          • 예시) 이미지 데이터 [N, C, H, W] 형태
          • N은 배치 크기를 의미합니다.
          • C는 채널의 수를 의미합니다 (예를 들어, RGB 이미지의 경우 3).
          • H는 이미지의 높이를 의미합니다.
          • W는 이미지의 너비를 의미합니다.
        • 3d: 5D 입력 데이터에 대해 배치 정규화 수행
          • 예시) 시간에 따른 변화를 포함하는 동영상 데이터 [N, C, D, H, W] 형태
            • N은 배치 크기입니다.
            • C는 채널의 수입니다 (예를 들어, RGB 이미지의 경우 3).
            • D는 깊이를 의미하며, 시간적 차원이나 연속된 이미지 프레임을 나타낼 수 있습니다.
            • H는 높이입니다.
            • W는 너비입니다.
    • 계층 정규화
      • torch.nn.LayerNorm
    • 인스턴스 정규화
      • torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d
    • 그룹 정규화
      • torch.nn.GroupNrom