공부하고 기록하는, 경제학과 출신 개발자의 노트

프로그래밍/이것저것_개발일지

PaliGemma 모델로 Object Detection Fine Tuning한 방법 정리

inspirit941 2024. 10. 31. 16:17
반응형

 

https://aifactory.space/task/2733/overview

 

2024 Gemma 파인튜닝톤 (아이디어톤)

🕹️ Gemma 파인튜닝 어디까지 해봤니?

aifactory.space

 

 

AIFactory Gemma 파인튜닝 아이디어톤에 제출해서, 3등 우수상으로 입상한 내용

  • PaliGemma로 Object Classification을 위한 데이터 준비 방법을 기록하기 위한 것.

PaliGemma란?

https://developers.googleblog.com/ko/gemma-explained-paligemma-architecture/

 

Gemma 설명: PaliGemma 아키텍처- Google Developers Blog

Gemma AI Announcements 전 세계 누구나 이해할 수 있도록 – Gemma 2를 사용한 다국어 AI 발전과 150,000달러가 걸린 챌린지 2024년 10월 3일

developers.googleblog.com

 

 

구글에서 2024년 6월에 발표한 sLLM 모델인 Gemma 종류 중 하나인 Vision-Language 모델.

  • 이미지와 텍스트를 입력받아서, 텍스트 응답을 생성한다.
  • Object Detection / Segmentation, Image Caption 등 다양한 기능을 사용할 수 있다고 함

학습에 사용할 데이터

https://www.aihub.or.kr/aihubdata/data/view.do?dataSetSn=71385

 

AI-Hub

샘플 데이터 ? ※샘플데이터는 데이터의 이해를 돕기 위해 별도로 가공하여 제공하는 정보로써 원본 데이터와 차이가 있을 수 있으며, 데이터에 따라서 민감한 정보는 일부 마스킹(*) 처리가 되

www.aihub.or.kr

 

AI허브에서 제공하는 '생활폐기물 데이터 활용 * 환류' 데이터를 활용한다.

 

Yolo모델 학습을 위해 구축된 데이터로, 데이터 형식은 아래와 같다.

 

스크린샷 2024-10-31 오후 3 13 21

 

{
    "objects": [
        {
            "id": "3422738f-5f99-43b6-97f3-062c43803def",
            "class_id": "292c5b9c-e203-4ec1-9ca8-7579273fc70c",
            "tracking_id": 1,
            "class_name": "c_6",
            "annotation_type": "box",
            "annotation": {
                "coord": {
                    "x": 543.73,
                    "y": 34.75,
                    "width": 182.93334597017542,
                    "height": 181.57
                },
                "meta": {
                    "z_index": 0,
                    "visible": false,
                    "alpha": 1,
                    "color": "#C9F6EA"
                }
            },
            "properties": []
        },
    ...
  ]
}

 

특이점이라면, bounding box를 위한 좌표가 Top-Left Width Height 방식이라는 것.

  • x와 y좌표가 top-left이므로, PaliGemma 학습을 위해 데이터 형식을 바꿔야 할 경우 유의해야 한다

 

PaliGemma로 Object Detection 학습시키기

PaliGemma로 데이터를 학습하려면, 데이터 형식을 아래와 같이 맞춰야 한다.

  • image: 학습에 사용할 이미지파일 이름
  • prefix: input prompt라고 보면 된다. 여러 개 이미지를 검출하려면 detect [CLASS1] ; [CLASS2] 처럼 세미콜론으로 구분한다.
  • suffix: PaliGemma가 이해할 수 있는 object의 x_min, x_min, x_max, y_max 좌표. bounding box 역할.

예시

{
  "image": "butterfly (199).jpg", 
  "prefix": "detect squirrel ; butterfly", 
  "suffix": "<loc0429><loc0022><loc0949><loc0592> butterfly ; <loc0002><loc0104><loc0429><loc0456> butterfly ; <loc0128><loc0438><loc0657><loc0997> butterfly"
}
...

 

좌표 변환

PaliGemma로 bounding box를 넘기려면 <locXXXX><locYYYY><locXXXX><locYYYY> 형태를 만들어야 한다.

  • 좌표 값은 이미지를 1024 * 1024로 normalize했을 때의 x_min, y_min, x_max, y_max 좌표값.
import json
import os
import pandas as pd
from PIL import Image

## 이미지의 width, height 확인하기
def get_image_metadata(imagePath: str):
    im = Image.open(imagePath)
    width, height = im.size
    return width, height

## 변환 함수
def convert_yolo_coords_to_llm(yolo_coord, image_width, image_height):
    x_upper = yolo_coord['x'] # left upper x
    y_upper = yolo_coord['y'] # left upper y
    width = yolo_coord['width']
    height = yolo_coord['height']


    x_min = max(0, x_upper)
    y_min = max(0, y_upper)
    x_max = min(x_min + width, image_width)
    y_max = min(y_min + height, image_height)

    # Normalize to PaliGemma's 0-1023 range
    y_min_norm = int((y_min / image_height) * 1024)
    x_min_norm = int((x_min / image_width) * 1024)
    y_max_norm = int((y_max / image_height) * 1024)
    x_max_norm = int((x_max / image_width) * 1024)

    # Create PaliGemma format string
    paligemma_format = f"<loc{y_min_norm:04d}><loc{x_min_norm:04d}><loc{y_max_norm:04d}><loc{x_max_norm:04d}>"
    return paligemma_format

## 학습에 사용할 메타데이터 정보를 metadata.csv로 저장하는 함수
def create_json_to_metadata(source_dir: str, destination_dir: str):
    image_list, prefix_list, suffix_list = [], [], []
    if not source_dir.endswith(".json"):
        json_files = glob.glob(source_dir + '/*.json', recursive=True)
    else:
        json_files = [source_dir]
    for json_file in json_files:
        # print(json_file)
        with open(json_file, "r") as f:
            data = json.load(f)

        image_name = data['Image']
        image_width, image_height = get_image_metadata(os.path.join("data", image_name))
        prefix_temp, suffix_temp = [], []
        for obj in data['objects']:
            class_name = obj['class_name']
            coords = obj['annotation']['coord']
            paligemma_bbox = convert_yolo_coords_to_llm(coords, image_width, image_height)
            prefix = class_name
            suffix = paligemma_bbox+ " " + class_name
            prefix_temp.append(prefix)
            suffix_temp.append(suffix)

        prefix = "detect " + " ; ".join(set(prefix_temp))
        suffix = " ; ".join(suffix_temp)
        prefix_list.append(prefix)
        suffix_list.append(suffix)
        image_list.append(image_name)

    ## dataset 패키지로 학습데이터를 구축하기 위한 metadata.csv 파일 생성
    df = pd.DataFrame({"file_name": image_list, "text": prefix_list, "suffix": suffix_list})
    df.to_csv(os.path.join(destination_dir, "metadata.csv"), index=False)

dataset 구축

아래와 같은 형태로 파일 시스템을 만들어준다.

data
├── A1C_20220818_000001.jpg
├── A1C_20220818_000016.jpg
├── A1C_20220818_000018.jpg
├── A1C_20220818_000027.jpg
├── A1C_20220818_000155.jpg
├── A2C_20220831_001277.jpg
├── A2C_20220831_003731.jpg
├── A2C_20220831_009409.jpg
├── A2C_20220831_009819.jpg
├── A2C_20220831_010724.jpg
├── A3C_20221018_000370.jpg
├── A3C_20221018_002413.jpg
├── A3C_20221018_003176.jpg
├── A3C_20221018_005824.jpg
├── A3C_20221018_006616.jpg
├── A4C_20221020_000000.jpg
├── A4C_20221020_000001.jpg
├── A4C_20221020_000011.jpg
├── A4C_20221020_000012.jpg
├── A4C_20221020_000014.jpg
├── A5C_20221107_000015.jpg
├── A5C_20221107_000020.jpg
├── A5C_20221107_000021.jpg
├── A5C_20221107_000022.jpg
├── A5C_20221107_000024.jpg
├── A6C_20221118_000007.jpg
├── A6C_20221118_000012.jpg
├── A6C_20221118_000018.jpg
├── A6C_20221118_000045.jpg
├── A6C_20221118_000047.jpg
├── A7C_20221118_000042.jpg
├── A7C_20221118_000047.jpg
├── A7C_20221118_000071.jpg
├── A7C_20221118_000115.jpg
├── A7C_20221118_000117.jpg
├── A8C_20221123_000000.jpg
├── A8C_20221123_000003.jpg
├── A8C_20221123_000004.jpg
├── A8C_20221123_000005.jpg
├── A8C_20221123_000006.jpg
├── A9C_20221124_000000.jpg
├── A9C_20221124_000001.jpg
├── A9C_20221124_000006.jpg
├── A9C_20221124_000008.jpg
├── A9C_20221124_000009.jpg
├── B10_20221115_000006.jpg
├── B10_20221115_000007.jpg
├── B10_20221115_000008.jpg
├── B10_20221115_000009.jpg
├── B10_20221115_000010.jpg
├── B1_20220715_000000.jpg
├── B1_20220715_000001.jpg
├── B1_20220715_000002.jpg
├── B1_20220715_000003.jpg
├── B1_20220715_000004.jpg
├── B2_20220721_000000.jpg
├── B2_20220721_000002.jpg
├── B2_20220721_000007.jpg
├── B2_20220721_000010.jpg
├── B2_20220721_000019.jpg
├── B3_20220823_000002.jpg
├── B3_20220823_000003.jpg
├── B3_20220823_000004.jpg
├── B3_20220823_000006.jpg
├── B3_20220823_000007.jpg
├── B4_20220916_000000.jpg
├── B4_20220916_000002.jpg
├── B4_20220916_000005.jpg
├── B4_20220916_000006.jpg
├── B4_20220916_000007.jpg
├── B5_20220926_000011.jpg
├── B5_20220926_000012.jpg
├── B5_20220926_000013.jpg
├── B5_20220926_000014.jpg
├── B5_20220926_000015.jpg
├── B6_20221017_000005.jpg
├── B6_20221017_000011.jpg
├── B6_20221017_000012.jpg
├── B6_20221017_000015.jpg
├── B6_20221017_000018.jpg
├── B7_20221020_005837.jpg
├── B7_20221020_005841.jpg
├── B7_20221020_005846.jpg
├── B7_20221020_005849.jpg
├── B7_20221020_005853.jpg
├── B8_20221101_000054.jpg
├── B8_20221101_000060.jpg
├── B8_20221101_000063.jpg
├── B8_20221101_000065.jpg
├── B8_20221101_000066.jpg
├── B9_20221104_000019.jpg
├── B9_20221104_000024.jpg
├── B9_20221104_000026.jpg
├── B9_20221104_000028.jpg
├── B9_20221104_000030.jpg
├── C_20220715_000000.jpg
├── C_20220715_000005.jpg
├── C_20220715_000009.jpg
├── C_20220715_000010.jpg
├── C_20220715_000011.jpg
├── C_20220715_000012.jpg
├── C_20220715_000017.jpg
├── C_20220715_000018.jpg
├── C_20220715_000019.jpg
├── C_20220715_000023.jpg
└── metadata.csv

0 directories, 106 files

 

metadata.csv 파일 형태는 아래와 같다.

 

스크린샷 2024-10-31 오후 3 53 17



위와 같은 파일 디렉토리를 datasets 패키지로 불러와서 학습에 사용하려면, 아래 코드를 실행한다.

 

from datasets import load_dataset, load_from_disk

dataset = load_dataset("imagefolder", data_dir="data")

학습에 사용한 코드 전체

 

전반적인 파인튜닝 코드는 reference에 적힌 유튜브 영상의 소스코드를 거의 그대로 사용했다.

파라미터 튜닝해가며 학습 성능을 비교할 만큼 GPU 리소스가 충분하지 않았기 때문.

 

'아이디어톤' 제출 코드라서, 말 그대로 PoC처럼 학습이 가능하다는 것만 증명해도 충분하다고 생각했다.

 

 

colab 인스턴스의 제한시간이 짧아서, 제대로 된 학습을 완전히 수행하기는 어려웠음.

Reference

반응형