본문 바로가기
Language-LAB/Pytorch

offsets 메커니즘

by JS LAB 2025. 3. 23.
728x90
반응형

EmbeddingLayer의 offsets 메커니즘 상세 설명

offsets 메커니즘은 이 클래스의 가장 핵심적인 부분이며, 여러 필드의 범주형 특성을 하나의 임베딩 테이블에서 충돌 없이 관리하기 위한 기법입니다. 단계별로 자세히 설명해드리겠습니다.

 

 

# ref: layers.py in https://github.com/morningsky/multi_task_learning
class EmbeddingLayer(torch.nn.Module):

    def __init__(self, field_dims, embed_dim):
        super().__init__()
        self.embedding = torch.nn.Embedding(sum(field_dims), embed_dim)
        self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.long)
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        x = x + x.new_tensor(self.offsets).unsqueeze(0)
        return self.embedding(x)

 

1. 문제 상황 이해하기

먼저 이 코드가 해결하려는 문제를 이해해봅시다:

  • 여러 개의 범주형 필드가 있음 (예: 사용자 ID, 상품 ID, 카테고리 등)
  • 각 필드마다 서로 다른 수의 고유 카테고리가 존재 (예: 1000명의 사용자, 500개의 상품, 100개의 카테고리)
  • 이 모든 필드의 카테고리들을 하나의 통합 임베딩 테이블에서 관리하고자 함

여기서 핵심 문제는: "서로 다른 필드에서 같은 값을 가진 카테고리를 어떻게 구분할 것인가?" 입니다.

2. offsets 코드 분석

self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.long)

이 한 줄의 코드를 단계별로 분해하겠습니다:

2.1 누적합 계산: np.cumsum(field_dims)

np.cumsum()은 누적 합계를 계산합니다.

예를 들어 field_dims = [1000, 500, 100]일 때:

  • np.cumsum(field_dims) = [1000, 1500, 1600]

이 값들은 각 필드의 끝 지점에 해당합니다:

  • 1000: 첫 번째 필드(사용자 ID)의 끝 인덱스
  • 1500: 두 번째 필드(상품 ID)의 끝 인덱스
  • 1600: 세 번째 필드(카테고리)의 끝 인덱스

2.2 슬라이싱: [:-1]

누적합에서 마지막 값을 제외합니다.

  • [1000, 1500, 1600][:-1] = [1000, 1500]

이는 "현재 필드의 끝 인덱스"가 아닌 "다음 필드의 시작 인덱스"를 구하기 위함입니다.

2.3 리스트 앞에 0 추가: (0, *...)

*는 Python의 언패킹 연산자로, 리스트/배열의 요소들을 풀어서 개별 요소로 만듭니다.

  • (0, *[1000, 1500]) = (0, 1000, 1500)

0을 맨 앞에 추가하는 이유는 첫 번째 필드의 시작 인덱스가 0이기 때문입니다.

2.4 최종 결과: np.array(..., dtype=np.long)

최종적으로 Numpy 배열로 변환하고 자료형을 지정합니다.

  • np.array((0, 1000, 1500), dtype=np.long) = array([0, 1000, 1500])

3. 실제 작동 방식 - 시각적 예시

field_dims = [1000, 500, 100]일 때 임베딩 테이블과 오프셋의 관계:

임베딩 테이블 인덱스:
0                                999 1000                      1499 1500                 1599
|-----------------------------------|--------------------------------|----------------------|
        사용자 ID 범위 (1000개)                상품 ID 범위 (500개)         카테고리 범위 (100개)

오프셋 값:
[0,                               1000,                              1500]
 ↑                                 ↑                                  ↑
 사용자 ID 필드 시작 인덱스         상품 ID 필드 시작 인덱스           카테고리 필드 시작 인덱스

4. forward 메서드에서의 활용

def forward(self, x):
    x = x + x.new_tensor(self.offsets).unsqueeze(0)
    return self.embedding(x)

4.1 입력 예시

x가 다음과 같이 주어진다고 가정해봅시다:

[
  [5, 42, 7],  # 첫 번째 샘플: 사용자5, 상품42, 카테고리7
  [18, 30, 12]  # 두 번째 샘플: 사용자18, 상품30, 카테고리12
]

4.2 오프셋 적용

x.new_tensor(self.offsets)  # [0, 1000, 1500]

.unsqueeze(0)는 차원을 추가하여 배치 처리를 가능하게 합니다:

[[0, 1000, 1500]]  # 형태: [1, 3]

이제 브로드캐스팅을 통해 배치의 모든 샘플에 오프셋을 더합니다:

[5, 42, 7] + [0, 1000, 1500] = [5, 1042, 1507]
[18, 30, 12] + [0, 1000, 1500] = [18, 1030, 1512]

결과:

[
  [5, 1042, 1507],
  [18, 1030, 1512]
]

4.3 임베딩 룩업

이제 변환된 인덱스로 임베딩 테이블을 조회합니다:

  • 인덱스 5: 사용자 ID 5의 임베딩 벡터 (오프셋 0 + 5)
  • 인덱스 1042: 상품 ID 42의 임베딩 벡터 (오프셋 1000 + 42)
  • 인덱스 1507: 카테고리 7의 임베딩 벡터 (오프셋 1500 + 7)

등을 조회합니다.

5. 왜 이런 방식이 필요한가?

5.1 인덱스 충돌 방지

오프셋 없이 원래 인덱스를 그대로 사용하면:

  • 사용자 ID 5와 상품 ID 5는 같은 임베딩 벡터를 공유하게 됨
  • 하지만 이는 의미적으로 전혀 다른 엔티티임!

오프셋을 사용하면:

  • 사용자 ID 5 → 인덱스 5
  • 상품 ID 5 → 인덱스 1005 (1000 + 5)
  • 의미적으로 다른 엔티티가 다른 임베딩 벡터를 가지게 됨

5.2 통합 임베딩 테이블의 이점

별도의 임베딩 테이블을 사용하는 대신 오프셋으로 통합 테이블을 사용하는 이유:

  1. 구현 단순화: 하나의 임베딩 레이어만 관리
  2. 메모리 효율성: 특히 희소 필드가 있는 경우 유용
  3. 배치 처리 최적화: 단일 임베딩 조회 연산으로 모든 필드 처리 가능

6. 실제 예시 계산

field_dims = [1000, 500, 100], embed_dim = 16일 때:

  1. 임베딩 테이블 크기: [1600, 16] (1000 + 500 + 100 = 1600개 행, 각 16차원)
  2. 오프셋: [0, 1000, 1500]

입력 x = [[3, 45, 80]]일 때:

  1. 오프셋 적용: [3, 45, 80] + [0, 1000, 1500] = [3, 1045, 1580]
  2. 임베딩 조회:
    • 임베딩 테이블에서 3번, 1045번, 1580번 행 조회
    • 결과 형태: [1, 3, 16] (1개 샘플, 3개 필드, 각 16차원 벡터)

이 방식으로 각 필드의 범주형 특성이 고유한 임베딩 벡터를 가지면서도, 코드는 단순하고 효율적으로 유지됩니다.

728x90
반응형

'Language-LAB > Pytorch' 카테고리의 다른 글

Hook 메커니즘  (0) 2025.03.24
PyTorch에서 모델이나 레이어 호출  (0) 2025.03.24
torch.nn.Embedding  (0) 2025.03.23
torch.nn.Sequential  (0) 2025.03.23