본문 바로가기
Language-LAB/Pytorch

torch.nn.Embedding

by JS LAB 2025. 3. 23.
728x90
반응형
import torch.nn as nn

# 크기 3의 텐서 10개가 포함된 임베딩 모듈
embedding = nn.Embedding(10, 3)

# 각각 4개의 인덱스로 구성된 2개의 표본 배치
input = torch.LongTensor([[1,2,3,4], [5,6,7,8]])
embedding(input)

 

사전 크기(num_embeddings=10)는 nn.Embedding(10, 3)에서 첫 번째 인자로 설정되며, 이는 임베딩 테이블(행렬)의 크기를 결정합니다. 이를 좀 더 구체적으로 설명하면:

nn.Embedding(10, 3)에서 사전 크기(10)의 역할

  • nn.Embedding(10, 3)은 크기 (10, 3)인 행렬을 내부적으로 생성합니다.
  • 즉, 10개의 행이 있고, 각 행마다 3차원 벡터가 존재합니다.
  • 각 행(인덱스 0~9)은 특정 단어(토큰)와 매핑됩니다.

input = torch.LongTensor([[1,2,3,4], [5,6,7,8]])가 임베딩에 적용될 때

  • input은 다음과 같은 형태를 가짐: 
  • [[1, 2, 3, 4], [5, 6, 7, 8]]
  • embedding(input)을 호출하면, 각 정수값을 인덱스로 사용하여 임베딩 테이블에서 해당하는 행을 가져옴.
  • 예를 들어, embedding(input)이 작동하는 방식:
    • 1 → 임베딩 테이블의 1번 행을 가져옴 → [x1, y1, z1]
    • 2 → 임베딩 테이블의 2번 행을 가져옴 → [x2, y2, z2]
    • 4 → 임베딩 테이블의 4번 행을 가져옴 → [x4, y4, z4]
    • 5 → 임베딩 테이블의 5번 행을 가져옴 → [x5, y5, z5]

정리

  • 사전 크기 10은 내부 임베딩 테이블의 행 개수를 의미함.
  • 인덱스 0~9까지만 접근 가능하며, 그 이상의 값이 들어가면 에러 발생 (IndexError).
  • embedding(input)은 input에서 제공된 인덱스를 사용해 임베딩 테이블에서 해당 벡터(3차원)를 가져옴.

즉, 사전 크기(10)는 모델 내부의 임베딩 행렬 크기를 정의하는 역할을 하고, 실제 동작에서는 input 텐서에 포함된 인덱스를 사용해 해당 임베딩 벡터를 조회하는 방식입니다.

728x90
반응형

'Language-LAB > Pytorch' 카테고리의 다른 글

Hook 메커니즘  (0) 2025.03.24
PyTorch에서 모델이나 레이어 호출  (0) 2025.03.24
offsets 메커니즘  (0) 2025.03.23
torch.nn.Sequential  (0) 2025.03.23