boostcampaitech4lv23nlp1 / final-project-level3-nlp-03

Multi-Modal Model for DocVQA(Document Visual Question Answering)
3 stars 0 forks source link

Variables should be reassigned when using Torch function 'clamp'. #25

Closed chanmuzi closed 1 year ago

chanmuzi commented 1 year ago

기존 BaselineTrainer.py를 구성하는 코드의 일부는 아래와 같다.

ignored_index = start_logits.size(1)
start_positions = batch['start_positions'].to(self.device)
end_positions = batch['end_positions'].to(self.device)

start_positions.clamp(0, ignored_index)
end_positions.clamp(0, ignored_index)

criterion = nn.CrossEntropyLoss(ignore_index=ignored_index)
loss = (criterion(start_logits, start_positions) + criterion(end_logits, end_positions)) / 2

이때 start_positionsend_positions에 적용되는 torch의 함수 clamp는 기존의 변수에서 변경된 값을 출력할 뿐이고 변경된 값을 저장(할당)해주지는 않는다. 따라서 처음에 의도한 바와 달리 min = 0, max = ignored_index가 반영되지 않은 상태로 loss를 계산하고 있는 것이다.

clamp함수가 기존 변수의 값을 변경하지 않는다는 것을 확인하는 코드와 결과는 아래와 같다. (실제 Trainer에서 돌아가는 것과 마찬가지로 3차원 tensor를 예시로 들었다)

example2 = torch.rand(2,3,4)

print(example2.clamp(0,0.5))
print(example2)

print('-------------------------------------------')
print(torch.clamp(example2,0,0.5))
print(example2)
tensor([[[0.5000, 0.0361, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.0982, 0.5000, 0.5000, 0.5000]],

        [[0.1020, 0.4356, 0.5000, 0.5000],
         [0.5000, 0.1884, 0.5000, 0.5000],
         [0.4171, 0.4317, 0.5000, 0.3482]]])
tensor([[[0.7487, 0.0361, 0.8313, 0.8034],
         [0.7987, 0.9131, 0.7549, 0.6861],
         [0.0982, 0.8372, 0.7606, 0.9634]],

        [[0.1020, 0.4356, 0.8022, 0.8345],
         [0.9490, 0.1884, 0.8695, 0.8135],
         [0.4171, 0.4317, 0.7643, 0.3482]]])
-------------------------------------------
tensor([[[0.5000, 0.0361, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.0982, 0.5000, 0.5000, 0.5000]],

        [[0.1020, 0.4356, 0.5000, 0.5000],
         [0.5000, 0.1884, 0.5000, 0.5000],
         [0.4171, 0.4317, 0.5000, 0.3482]]])
tensor([[[0.7487, 0.0361, 0.8313, 0.8034],
         [0.7987, 0.9131, 0.7549, 0.6861],
         [0.0982, 0.8372, 0.7606, 0.9634]],

        [[0.1020, 0.4356, 0.8022, 0.8345],
         [0.9490, 0.1884, 0.8695, 0.8135],
         [0.4171, 0.4317, 0.7643, 0.3482]]])

clamp 함수는 변수.clamp(최솟값,최댓값) 또는 torch.clamp(변수,최솟값,최댓값) 형태로 쓰인다. 둘 중 어떤 것을 사용해도 원래 변수 example2의 값이 변경되지 않았다는 것을 확인할 수 있다.

따라서 변수 = 변수.clamp(최솟값,최댓값) 또는 변수 = torch.clamp(변수,최솟값,최댓값) 으로 변경해야 한다.

chanmuzi commented 1 year ago

fix: clamp #25에 반영 완료