본문 바로가기

심층학습/Natural Language Processing

순환 신경망(Recurrent Neural Network) - Teacher Forcing & BPTT

본 포스팅은 Ian Goodfellow 등 2인이 저술한 심층학습(Deep Learning Adaptive Computation and Machine Learning)및 오일석 저술의 기계학습, 사이토 고키 저술의 밑바닥부터 시작하는 딥러닝 등을 기반으로 작성합니다.

 

저번 포스팅에서 RNN에 대해서 개괄적으로 이해해보았으니, 이번 포스팅에서는 Teacher Forcing과 BPTT에 대해서 이해해보도록 하겠습니다. 

 

1. Teacher Forcing

 

티쳐 포싱은 RNN에서 Target 단어를 디코더의 다음 입력으로 넣어주는 기법을 의미합니다. 예시를 통해 이해해보도록 하죠.

원래 예측하려는 결과값은 "Two People Running~"이라고 가정합시다. 이를 예측하기 위해 Two라는 단어를 입력했습니다. 그런데 "Two"라는 단어 뒤에 "Birds"를 예측했습니다. 그러면 RNN은 이전의 예측값을 다시 입력값으로 활용하니 "Two Birds Flying~" 다음과 같이 예측합니다. 이전 예측을 고려해주는 RNN의 장점이, 한 번 예측이 잘못된 경우 첫 단추를 잘못 끼운거마냥 그 후의 예측이 모두 틀리게 되면서 엄청난 단점으로 작용하는 것입니다. 

 

그래서 Teacher Forcing을 이용하기도 합니다. 예측을 "Two Birds"를 했음에도 불구하고, 실제 데이터 값(Ground Truth)인 "Two People"을 넣어줌으로써, "Two birds Running~"과 같이 후의 예측 오차를 줄여주는 것입니다. 

 

간단히 비유를 하자면, 

문제 A : a값을 구하시오
문제 B : 문제 A에서 구한 답을 이용해 b를 구하시오
문제 C : 문제 B에서 구한 답을 이용해 c를 구하시오

다음과 같은 문제가 있다고 가정합시다. 이 문제들은 문제 A에서 푼 답을 이용해 B를 풀고, C를 푸는 서로 연결된 문제입니다. 

만일 Teacher Forcing을 사용하지 않을 경우 학생이 문제 A, B, C를 순서대로 풀고, 답을 한꺼번에 작성해서 제출함 -> 이를 한꺼번에 채점해 점수를 알려주는 시스템입니다. 

Teacher Forcing을 사용하면, 학생이 문제 A를 풀고 답을 제출하면 그 답을 채점한 뒤 정답 a를 알려줍니다. 그리고 학생은 정답 a를 가지고 문제 B를 풀고 정답 b를 제출하고 이것이 계속 반복됩니다.

 

학생의 점수는 어떤 게 높을까요. 당연히 후자가 높을 것입니다. 전자의 경우 첫단추를 잘못 꿰는 순간, 뒤의 문제는 우수수 틀릴 테니까 말이죠. 결국 Teacher Forcing의 장점은 학습 속도를 빠르게 할 수 있다는 점입니다.

 

2. BPTT(Back-Propagation through Time) RNN : 시간방향으로 펼친 신경망의 오차역전파법

RNN에도 역시 역전파 계산이 존재합니다. 역전파 계산이란 인공지능이 순전파로 계산한 예측값과 실제값의 오차를 이용해, 경사 하강법을 통해서 가중치를 업데이트하는 과정입니다. 하지만 RNN은 시계열 데이터(Sequential Data)를 입력으로 받기에, 이를 고려해줄 필요가 있습니다. 하지만 과정은 거의 동일합니다. Loss값의 Gradient를 계산하고 그 값을 이용해 Parameter를 수정하는거죠.

 

그런데 중요한 것은 시계열 데이터의 특성을 고려해줄 필요가 있다고 했죠? 그게 뭐냐면 바로 시계열 데이터는 이전 값들을 누적해 저장하는 특성이 있습니다. 만일 RNN 계층의 중간 데이터를 저장해두지 않는다면, 이런 특성을 잃어버리게 되죠. 결국, 시계열 데이터가 길어지면 길어질 수록, BPTT가 소모하는 메모리 자원의 양이 증가한다는 것입니다. 또한 시간의 크기가 커지면 역전파때의 기울기 불안정 현상이 나타납니다.

 

3. Truncated BPTT

결국 이를 해결하기 위한 방법으로는 Truncated BPTT를 이용할 수 있습니다. Truncated의 뜻은 "잘린" 이라는 뜻으로, 적당한 길이로 데이터를 '잘라내서' 오차역전파법을 수행하는 것이죠. 신경망의 연결을 일부 끊고, 데이터의 크기를 줄이는 방법입니다. 그런데 가장 중요한 점은 "순전파"의 연결을 끊어서는 안 된다는 점입니다! 역전파의 연결만 적당한 길이로 잘라내서 학습을 하는 것입니다.

 

예를 들어보죠. 길이가 100인 데이터가 있다고 할 때, 자연어 처리 문제에서는 단어 100개짜리 말뭉치가 될 것입니다. 이 데이터를 RNN 계층으로 쭉 펼치면 가로로 100인 긴 신경망이 됩니다. 이렇게 되면 두 가지 문제가 발생합니다. 

 

1. 계층이 길어지면서 계산량과 메모리 사용량이 기하급수적으로 증가합니다. 

2. 계층이 길어질때마다 Gradient의 값이 조금씩 작아져서, 초기 시간에 도달하기 전에 0이 되어 기울기가 소멸할 수 있습니다. 

'심층학습 > Natural Language Processing' 카테고리의 다른 글

Sequencial Model  (0) 2023.11.24
순환 신경망(Recurrent Neural Network) - 개요  (0) 2022.03.12