Language-LAB/Pytorch

PyTorch에서 모델이나 레이어 호출

JS LAB 2025. 3. 24. 14:33
728x90
반응형

PyTorch 모델 호출 문법 정리

PyTorch에서 모델이나 레이어를 호출할 때 알아두면 좋은 내용을 깔끔하게 정리해드립니다.

1. 기본 호출 메커니즘

# 정의
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 5)

    def forward(self, x):
        return self.layer(x)

# 사용
model = MyModel()
output = model(input_tensor)  # model.forward(input_tensor)가 자동으로 호출됨

작동 원리:

  • model(input_tensor)는 내부적으로 model.__call__(input_tensor)를 실행
  • __call__ 메서드는 nn.Module에 정의되어 있으며, 이 메서드가 forward를 호출
  • 이 과정은 모델 계층 구조의 모든 레벨에서 재귀적으로 발생

2. 입력 매개변수 처리

# 여러 매개변수 전달
def forward(self, x, y):
    return x + y

# 호출
output = model(tensor1, tensor2)  # model.forward(tensor1, tensor2) 호출

# 선택적 매개변수
def forward(self, x, mask=None):
    if mask is not None:
        x = x * mask
    return x

# 호출
output1 = model(tensor)         # mask=None으로 호출
output2 = model(tensor, mask)   # mask 제공하여 호출

3. 다양한 입력 형태 처리하기

# 가변 인자 사용
def forward(self, *args, **kwargs):
    if len(args) == 1:
        # 입력이 하나인 경우
        return self._process_single(args[0])
    elif len(args) == 2:
        # 입력이 두 개인 경우
        return self._process_double(args[0], args[1])
    else:
        raise ValueError(f"Expected 1 or 2 arguments, got {len(args)}")

4. 호출 시 모드 관리

# 평가 모드로 설정
model.eval()
with torch.no_grad():
    output = model(test_input)  # 그래디언트 계산 없이 forward만 실행

# 학습 모드로 설정
model.train()
output = model(train_input)  # 그래디언트 계산 포함하여 forward 실행

5. 중첩된 모듈 호출

class ComplexModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(1000, 128)
        self.encoder = EncoderModule()  # 다른 nn.Module
        self.decoder = DecoderModule()  # 다른 nn.Module

    def forward(self, x):
        x = self.embed(x)          # embed.forward(x) 호출
        encoded = self.encoder(x)   # encoder.forward(x) 호출
        decoded = self.decoder(encoded)  # decoder.forward(encoded) 호출
        return decoded

6. 주의사항 및 팁

  1. 단일 forward: 하나의 클래스에 여러 forward 함수를 정의하면 마지막 것만 사용됩니다.

  2. hook 활용: register_forward_hook, register_backward_hook을 사용하여 호출 전후에 특정 작업 수행 가능.

    def hook_fn(module, input, output):
        print(f"Output shape: {output.shape}")
    
    model.register_forward_hook(hook_fn)
  3. 직접 forward 호출 지양: model.forward(x) 대신 model(x)를 사용해야 PyTorch의 훅 시스템과 모드 관리가 제대로 작동합니다.

  4. 계산 그래프 제어: with torch.no_grad():, torch.set_grad_enabled(False) 등으로 그래디언트 계산 제어 가능.

  5. JIT 컴파일 고려: 복잡한 모델은 torch.jit.script(model)로 컴파일하여 성능 향상 가능.

이러한 암묵적 호출 메커니즘은 PyTorch 코드를 더 간결하고 직관적으로 만들어줍니다.

728x90
반응형