Language-LAB/Pytorch
Hook 메커니즘
JS LAB
2025. 3. 24. 14:36
728x90
반응형
PyTorch Hook 간단 정리
훅(Hook)은 PyTorch에서 모델의 내부 동작을 관찰하고 수정할 수 있게 해주는 기능입니다.
1. 모듈 Hook 기본
Forward Hook (출력 확인/수정)
import torch
import torch.nn as nn
# 간단한 모델
model = nn.Linear(5, 2)
# 출력값 저장 함수
outputs = {}
def save_output(name):
def hook(module, input, output):
outputs[name] = output # 출력값 저장
return hook
# 훅 등록하기
handle = model.register_forward_hook(save_output('linear'))
# 모델 실행
x = torch.randn(3, 5) # 샘플 입력
y = model(x)
# 저장된 출력값 확인
print("저장된 출력:", outputs['linear'])
# 훅 제거
handle.remove()
Forward Pre-Hook (입력 확인/수정)
# 입력값 수정 함수
def double_input(module, input):
# input은 항상 튜플 형태
return (input[0] * 2,)
# 훅 등록
handle = model.register_forward_pre_hook(double_input)
# 모델 실행 (입력이 2배로 증가됨)
x = torch.ones(3, 5)
y = model(x) # 실제로는 model(x*2)와 같음
# 훅 제거
handle.remove()
Backward Hook (그래디언트 확인/수정)
# 그래디언트 확인 함수
grads = {}
def save_grad(name):
def hook(module, grad_input, grad_output):
grads[name] = grad_output[0] # 출력 그래디언트 저장
return hook
# 훅 등록
handle = model.register_full_backward_hook(save_grad('linear'))
# 모델 실행 및 역전파
x = torch.randn(3, 5, requires_grad=True)
y = model(x)
y.sum().backward() # 역전파 실행
# 저장된 그래디언트 확인
print("저장된 그래디언트:", grads['linear'])
# 훅 제거
handle.remove()
2. 텐서 Hook (그래디언트 접근)
# 텐서 생성
x = torch.randn(3, 5, requires_grad=True)
# 그래디언트 출력 함수
def print_grad(grad):
print("그래디언트 값:", grad)
return grad # 그대로 반환
# 텐서에 훅 등록
handle = x.register_hook(print_grad)
# 연산 및 역전파
y = x.sum() * 2
y.backward() # 그래디언트 계산 시 훅 실행됨
# 훅 제거
handle.remove()
3. 실용적인 예제
특성 맵 시각화 (가장 간단한 버전)
import torch.nn as nn
import matplotlib.pyplot as plt
# CNN 모델
model = nn.Sequential(
nn.Conv2d(3, 16, 3),
nn.ReLU()
)
# 특성 맵 저장
feature_maps = {}
model[0].register_forward_hook(
lambda m, i, o: feature_maps.update({'conv': o.detach()})
)
# 이미지 실행
img = torch.randn(1, 3, 32, 32)
_ = model(img)
# 첫 4개 특성 맵 시각화
plt.figure(figsize=(10, 2.5))
for i in range(4):
plt.subplot(1, 4, i+1)
plt.imshow(feature_maps['conv'][0, i].numpy())
plt.title(f'특성 맵 #{i+1}')
plt.axis('off')
plt.show()
모델 디버깅 (NaN 감지)
# NaN 감지 훅
def detect_nan(module, input, output):
if torch.isnan(output).any():
print(f"NaN 발견: {module.__class__.__name__}")
# 추가 디버깅 정보
print(f"입력 형태: {input[0].shape}")
print(f"입력 범위: {input[0].min()} ~ {input[0].max()}")
# 모든 레이어에 훅 적용
for name, module in model.named_modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
module.register_forward_hook(detect_nan)
간단한 그래디언트 모니터링
# 그래디언트 크기 모니터링 훅
def monitor_gradient(module, grad_input, grad_output):
grad = grad_output[0]
print(f"그래디언트 통계: 평균={grad.mean():.4f}, 최대={grad.max():.4f}")
# 특정 레이어에 적용
model.layer.register_full_backward_hook(monitor_gradient)
4. Hook 종류 한눈에 보기
종류 | 메서드 | 함수 형태 | 주요 용도 |
---|---|---|---|
Forward Hook | register_forward_hook |
fn(module, input, output) |
레이어 출력 확인/수정 |
Pre-Forward Hook | register_forward_pre_hook |
fn(module, input) |
레이어 입력 전처리 |
Backward Hook | register_full_backward_hook |
fn(module, grad_input, grad_output) |
그래디언트 확인/수정 |
텐서 Hook | tensor.register_hook |
fn(grad) |
특정 텐서 그래디언트 처리 |
5. 사용 시 주의사항
- 메모리 관리:
.detach()
로 텐서 복사본 만들어 메모리 누수 방지 - 훅 제거:
handle.remove()
로 사용 후 제거 - 반환값:
- Forward Hook: 값 반환 시 출력 대체됨,
None
반환 시 출력 유지 - Pre-Hook: 튜플 형태로 반환해야 함
- Forward Hook: 값 반환 시 출력 대체됨,
훅은 모델 내부를 들여다보고 디버깅하는 데 매우 유용한 PyTorch의 강력한 기능입니다!
728x90
반응형