Language-LAB/Pytorch

Hook 메커니즘

JS LAB 2025. 3. 24. 14:36
728x90
반응형

PyTorch Hook 간단 정리

훅(Hook)은 PyTorch에서 모델의 내부 동작을 관찰하고 수정할 수 있게 해주는 기능입니다.

1. 모듈 Hook 기본

Forward Hook (출력 확인/수정)

import torch
import torch.nn as nn

# 간단한 모델
model = nn.Linear(5, 2)

# 출력값 저장 함수
outputs = {}
def save_output(name):
    def hook(module, input, output):
        outputs[name] = output  # 출력값 저장
    return hook

# 훅 등록하기
handle = model.register_forward_hook(save_output('linear'))

# 모델 실행
x = torch.randn(3, 5)  # 샘플 입력
y = model(x)

# 저장된 출력값 확인
print("저장된 출력:", outputs['linear'])

# 훅 제거
handle.remove()

Forward Pre-Hook (입력 확인/수정)

# 입력값 수정 함수
def double_input(module, input):
    # input은 항상 튜플 형태
    return (input[0] * 2,)

# 훅 등록
handle = model.register_forward_pre_hook(double_input)

# 모델 실행 (입력이 2배로 증가됨)
x = torch.ones(3, 5)
y = model(x)  # 실제로는 model(x*2)와 같음

# 훅 제거
handle.remove()

Backward Hook (그래디언트 확인/수정)

# 그래디언트 확인 함수
grads = {}
def save_grad(name):
    def hook(module, grad_input, grad_output):
        grads[name] = grad_output[0]  # 출력 그래디언트 저장
    return hook

# 훅 등록
handle = model.register_full_backward_hook(save_grad('linear'))

# 모델 실행 및 역전파
x = torch.randn(3, 5, requires_grad=True)
y = model(x)
y.sum().backward()  # 역전파 실행

# 저장된 그래디언트 확인
print("저장된 그래디언트:", grads['linear'])

# 훅 제거
handle.remove()

2. 텐서 Hook (그래디언트 접근)

# 텐서 생성
x = torch.randn(3, 5, requires_grad=True)

# 그래디언트 출력 함수
def print_grad(grad):
    print("그래디언트 값:", grad)
    return grad  # 그대로 반환

# 텐서에 훅 등록
handle = x.register_hook(print_grad)

# 연산 및 역전파
y = x.sum() * 2
y.backward()  # 그래디언트 계산 시 훅 실행됨

# 훅 제거
handle.remove()

3. 실용적인 예제

특성 맵 시각화 (가장 간단한 버전)

import torch.nn as nn
import matplotlib.pyplot as plt

# CNN 모델
model = nn.Sequential(
    nn.Conv2d(3, 16, 3),
    nn.ReLU()
)

# 특성 맵 저장
feature_maps = {}
model[0].register_forward_hook(
    lambda m, i, o: feature_maps.update({'conv': o.detach()})
)

# 이미지 실행
img = torch.randn(1, 3, 32, 32)
_ = model(img)

# 첫 4개 특성 맵 시각화
plt.figure(figsize=(10, 2.5))
for i in range(4):
    plt.subplot(1, 4, i+1)
    plt.imshow(feature_maps['conv'][0, i].numpy())
    plt.title(f'특성 맵 #{i+1}')
    plt.axis('off')
plt.show()

모델 디버깅 (NaN 감지)

# NaN 감지 훅
def detect_nan(module, input, output):
    if torch.isnan(output).any():
        print(f"NaN 발견: {module.__class__.__name__}")
        # 추가 디버깅 정보
        print(f"입력 형태: {input[0].shape}")
        print(f"입력 범위: {input[0].min()} ~ {input[0].max()}")

# 모든 레이어에 훅 적용
for name, module in model.named_modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        module.register_forward_hook(detect_nan)

간단한 그래디언트 모니터링

# 그래디언트 크기 모니터링 훅
def monitor_gradient(module, grad_input, grad_output):
    grad = grad_output[0]
    print(f"그래디언트 통계: 평균={grad.mean():.4f}, 최대={grad.max():.4f}")

# 특정 레이어에 적용
model.layer.register_full_backward_hook(monitor_gradient)

4. Hook 종류 한눈에 보기

종류 메서드 함수 형태 주요 용도
Forward Hook register_forward_hook fn(module, input, output) 레이어 출력 확인/수정
Pre-Forward Hook register_forward_pre_hook fn(module, input) 레이어 입력 전처리
Backward Hook register_full_backward_hook fn(module, grad_input, grad_output) 그래디언트 확인/수정
텐서 Hook tensor.register_hook fn(grad) 특정 텐서 그래디언트 처리

5. 사용 시 주의사항

  1. 메모리 관리: .detach()로 텐서 복사본 만들어 메모리 누수 방지
  2. 훅 제거: handle.remove()로 사용 후 제거
  3. 반환값:
    • Forward Hook: 값 반환 시 출력 대체됨, None 반환 시 출력 유지
    • Pre-Hook: 튜플 형태로 반환해야 함

훅은 모델 내부를 들여다보고 디버깅하는 데 매우 유용한 PyTorch의 강력한 기능입니다!

728x90
반응형