데이터야놀자 2023 - Whisper 모델의 불필요한 Weight 줄여서 학습비용 절약하기
https://youtu.be/MDE2HRsfr7g?si=7s4yRsOlUSZ8_nnu
Socar AI Research 조충현님.
- STT 모델 경량화 연구내용의 소개.
카셰어링 업체가 자체 STT를 구축해야 했던 이유?
- CS 상담사의 업무에는 여러 가지가 있다.
- 상담가이드 문서 검색, 유저 정보 확인, 예약내역 확인, 상담내용 기록...
- 이걸 매 상담마다 반복해야 함.
- 이렇다보니 상담이 많아지면 대기시간이 길어짐... CS만족도 하락.
상담원은 고객업무에만 집중하도록, 상담 외 업무는 AI를 사용하는 방향으로.
그러려면, 고객과 상담원의 대화 내역을 문서화하는 작업은 필수.
- 문서가 있어야 AI가 뭐라도 할 수 있기 때문.
Whisper: OpenAI에서 공개한 오픈소스 STT.
- 코드와 weight 공개되어 있음
- MultiTask Model (STT / Translate 가능)
- 영어 위주로 학습되어 있었기에 한글 성능은 좋지 않다.
그러나 모델 크기가 클수록 추가학습에 드는 비용이 크다.
모델 학습비용은 아끼면서 성능을 유지하는 방법은?
어떻게 학습 비용을 낮출 것인가
한국어 데이터만 추가학습하면 되는데, 모델 전체의 파라미터를 학습대상으로 쓴다...
- Pruning과 Adapter 기능을 써보자.
Adapter: 작은 사이즈의 network layer 추가하고, 해당 layer만 학습에 참여시키는 방식.
- 전체 모덿을 학습하는 게 아니므로 비용 절감, 적은 리소스로 학습 가능.
- 예컨대 LLaMa 모델 전체를 학습하는 게 아니라, adapter에 해당하는 영역만 학습한 뒤 LLaMa에 붙이는 식.
방법: Whisper 모델에 LoRA adapter를 붙인다.
Adapter 중에서는 LoRA 사용.
- 모델 output을 input으로 넣는 방식이 아니라, 모델과 같은 형태의 input 받아서 병렬처리한 뒤 모델의 output과 합치는 방식.
- Whisper 모델에 Lora Adapter를 붙이고
- 한국어 인식에 덜 중요한 파라미터는 Pruning으로 제거.
- Pruning 기법은 Lottery Ticket Hypothesis 사용.
Adapter 붙이고 pruning... Adapter와 연결된 파라미터가 Pruning으로 삭제되는 걸 방지하기 위함.
- adapter와 연결된 파라미터가 삭제되면 학습 성능에 영향을 미치기 때문.
Lottery Ticket Hypothesis (LTH)
- 모델 파라미터 초기화 -> N회 반복하며 학습 -> 학습된 파라미터의 p%를 pruning
- 제거된 파라미터 빼고 나머지는 학습 전 값으로 초기화.
학습 전 값으로 초기화했더니 성능저하가 많이 없었다는 게 논문의 내용.
결과 및 평가
- 의도대로 동작했는가
- 다른 한국어 데이터에도 잘 동작하는가?
- Pruning의 효과는 어느 정도인가?
STT 성능 평가 방법: CER / WER.
- Character Error Rate / Word Error Rate의 약자.
- 음절 / 단어 오류율을 계산한 수치.
모델이 잘못 삭제한 거, 잘못 대체된 거, 잘못 추가된 거 전부 더해서 비율 계산하는 것. 수치가 낮을수록 좋다.
강의자료에 N (정답) 기호가 빠져 있는데, 수치가 낮을수록 좋다는 표현으로 보았을 땐 ((S+D+I) / N) * 100 인데 오타난 것으로 보임
기법이 원하는대로 작동하는가?
- Zero-shot: pretrained 모델 그대로
- FFT: 모델 전체를 학습
- LoRA: adapter 붙여서 adapter만 학습
- LTH: LTH 튜닝
학습 결과 더 적은 파라미터로도 성능 하락 없이 비슷한 수준의 퍼포먼스를 기록했고, 가장 학습데이터가 많았던 영어도 모델 전체를 학습한 것과 비슷한 성능을 보임
다른 한국어 데이터에서도 성능이 괜찮은가
네이버 클로바에서 제공하는 고객센터 도메인 음성데이터 사용.
- 전체 학습한 결과가 성능이 제일 좋지만, 최적화 기법을 사용한 모델이 다른 비교군보다 성능이 좋다.
Pruning의 적정 수치는 어느 정도?
50% 까지는 기존 모델 대비 준수한 성능을 보이지만, 50% 넘어갈 경우 성능 하락이 점차 커진다.
정리
결과 정리해서 논문도 제출.
이후 고민점
- 성능 향상 방법은?
- 더 나은 경량화 방법은?
- Pruning 속도 향상법?
Q&A
Q. STT에서 주로 쓰는 경량화 기법이 있는지? / STT 쪽 주요 이슈는?
우리가 찾아본 바로는 STT 쪽의 경량화 기법이 없었다. 그래서 LLM 쪽 경량화 기법을 실험해본 것.
잘 적용이 된다는 걸 증명함.
음성 데이터... Whisper 이전까지는 STT 쪽 성능 자체가 좋지 않았음. 지금도 whisper가 완벽하지 않기에, 성능 쪽 이슈가 많음.
Q. STT 쪽 테스트 / 성능 검사는 어떻게 하나?
정답이 있고, whisper가 만들어낸 문장을 찾아서 Error rate 계산해서 비교하는 식.
Q. 중요하지 않은 파라미터는 어떻게 파악해야 하나?
Pruning: 학습해서 바뀌는 파라미터 값이 있는데, 변한 값의 magnitude 기준으로 가장 작은 것들 상위 n% 날리는 식. (절대값으로 치환했을 때 가장 변화량이 적은 것)
원래 Pruning에 논리적인 근거는 딱히 없음. 경험적으로 이렇게 하니까 좋던데? 정도.
Q. whisper의 base 모델이 궁금하다
음성 받아서 CNN 통과 -> transformer 통과하면 text 변환.
Q. LoRA에는 LTH 적용 안한 거 맞죠?
그렇다. LoRA 연결한 뒤 whisper에만 LTH 적용한 것.
Q. 1~3번 학습 반복 횟수는 어느 정도였나?
5회를 넘지 않는다. 5회 넘으면 비용이슈도 있고, 학습 시간도 오래 걸린다.
Pruning 소요시간도 데이터양에 따라 다르지만, 5회 반복 시 3시간은 학습에 쓰임.
Pruning하려면 학습 반복이 필요한데... 학습시간이 오래 걸린다는 단점이 있음.