공부하고 기록하는, 경제학과 출신 개발자의 노트

프로그래밍/이것저것_개발일지

PaliGemma 모델로 Object Detection Fine Tuning한 방법 정리

inspirit941 2024. 10. 31. 16:17
반응형

 

https://aifactory.space/task/2733/overview

 

2024 Gemma 파인튜닝톤 (아이디어톤)

🕹️ Gemma 파인튜닝 어디까지 해봤니?

aifactory.space

 

 

AIFactory Gemma 파인튜닝 아이디어톤에 제출해서, 3등 우수상으로 입상한 내용

  • PaliGemma로 Object Classification을 위한 데이터 준비 방법을 기록하기 위한 것.

PaliGemma란?

https://developers.googleblog.com/ko/gemma-explained-paligemma-architecture/

 

Gemma 설명: PaliGemma 아키텍처- Google Developers Blog

Gemma AI Announcements 전 세계 누구나 이해할 수 있도록 – Gemma 2를 사용한 다국어 AI 발전과 150,000달러가 걸린 챌린지 2024년 10월 3일

developers.googleblog.com

 

 

구글에서 2024년 6월에 발표한 sLLM 모델인 Gemma 종류 중 하나인 Vision-Language 모델.

  • 이미지와 텍스트를 입력받아서, 텍스트 응답을 생성한다.
  • Object Detection / Segmentation, Image Caption 등 다양한 기능을 사용할 수 있다고 함

학습에 사용할 데이터

https://www.aihub.or.kr/aihubdata/data/view.do?dataSetSn=71385

 

AI-Hub

샘플 데이터 ? ※샘플데이터는 데이터의 이해를 돕기 위해 별도로 가공하여 제공하는 정보로써 원본 데이터와 차이가 있을 수 있으며, 데이터에 따라서 민감한 정보는 일부 마스킹(*) 처리가 되

www.aihub.or.kr

 

AI허브에서 제공하는 '생활폐기물 데이터 활용 * 환류' 데이터를 활용한다.

 

Yolo모델 학습을 위해 구축된 데이터로, 데이터 형식은 아래와 같다.

 

스크린샷 2024-10-31 오후 3 13 21

 

{
    "objects": [
        {
            "id": "3422738f-5f99-43b6-97f3-062c43803def",
            "class_id": "292c5b9c-e203-4ec1-9ca8-7579273fc70c",
            "tracking_id": 1,
            "class_name": "c_6",
            "annotation_type": "box",
            "annotation": {
                "coord": {
                    "x": 543.73,
                    "y": 34.75,
                    "width": 182.93334597017542,
                    "height": 181.57
                },
                "meta": {
                    "z_index": 0,
                    "visible": false,
                    "alpha": 1,
                    "color": "#C9F6EA"
                }
            },
            "properties": []
        },
    ...
  ]
}

 

특이점이라면, bounding box를 위한 좌표가 Top-Left Width Height 방식이라는 것.

  • x와 y좌표가 top-left이므로, PaliGemma 학습을 위해 데이터 형식을 바꿔야 할 경우 유의해야 한다

 

PaliGemma로 Object Detection 학습시키기

PaliGemma로 데이터를 학습하려면, 데이터 형식을 아래와 같이 맞춰야 한다.

  • image: 학습에 사용할 이미지파일 이름
  • prefix: input prompt라고 보면 된다. 여러 개 이미지를 검출하려면 detect [CLASS1] ; [CLASS2] 처럼 세미콜론으로 구분한다.
  • suffix: PaliGemma가 이해할 수 있는 object의 x_min, x_min, x_max, y_max 좌표. bounding box 역할.

예시

{
  "image": "butterfly (199).jpg", 
  "prefix": "detect squirrel ; butterfly", 
  "suffix": "<loc0429><loc0022><loc0949><loc0592> butterfly ; <loc0002><loc0104><loc0429><loc0456> butterfly ; <loc0128><loc0438><loc0657><loc0997> butterfly"
}
...

 

좌표 변환

PaliGemma로 bounding box를 넘기려면 <locXXXX><locYYYY><locXXXX><locYYYY> 형태를 만들어야 한다.

  • 좌표 값은 이미지를 1024 * 1024로 normalize했을 때의 x_min, y_min, x_max, y_max 좌표값.
import json
import os
import pandas as pd
from PIL import Image

## 이미지의 width, height 확인하기
def get_image_metadata(imagePath: str):
    im = Image.open(imagePath)
    width, height = im.size
    return width, height

## 변환 함수
def convert_yolo_coords_to_llm(yolo_coord, image_width, image_height):
    x_upper = yolo_coord['x'] # left upper x
    y_upper = yolo_coord['y'] # left upper y
    width = yolo_coord['width']
    height = yolo_coord['height']


    x_min = max(0, x_upper)
    y_min = max(0, y_upper)
    x_max = min(x_min + width, image_width)
    y_max = min(y_min + height, image_height)

    # Normalize to PaliGemma's 0-1023 range
    y_min_norm = int((y_min / image_height) * 1024)
    x_min_norm = int((x_min / image_width) * 1024)
    y_max_norm = int((y_max / image_height) * 1024)
    x_max_norm = int((x_max / image_width) * 1024)

    # Create PaliGemma format string
    paligemma_format = f"<loc{y_min_norm:04d}><loc{x_min_norm:04d}><loc{y_max_norm:04d}><loc{x_max_norm:04d}>"
    return paligemma_format

## 학습에 사용할 메타데이터 정보를 metadata.csv로 저장하는 함수
def create_json_to_metadata(source_dir: str, destination_dir: str):
    image_list, prefix_list, suffix_list = [], [], []
    if not source_dir.endswith(".json"):
        json_files = glob.glob(source_dir + '/*.json', recursive=True)
    else:
        json_files = [source_dir]
    for json_file in json_files:
        # print(json_file)
        with open(json_file, "r") as f:
            data = json.load(f)

        image_name = data['Image']
        image_width, image_height = get_image_metadata(os.path.join("data", image_name))
        prefix_temp, suffix_temp = [], []
        for obj in data['objects']:
            class_name = obj['class_name']
            coords = obj['annotation']['coord']
            paligemma_bbox = convert_yolo_coords_to_llm(coords, image_width, image_height)
            prefix = class_name
            suffix = paligemma_bbox+ " " + class_name
            prefix_temp.append(prefix)
            suffix_temp.append(suffix)

        prefix = "detect " + " ; ".join(set(prefix_temp))
        suffix = " ; ".join(suffix_temp)
        prefix_list.append(prefix)
        suffix_list.append(suffix)
        image_list.append(image_name)

    ## dataset 패키지로 학습데이터를 구축하기 위한 metadata.csv 파일 생성
    df = pd.DataFrame({"file_name": image_list, "text": prefix_list, "suffix": suffix_list})
    df.to_csv(os.path.join(destination_dir, "metadata.csv"), index=False)

dataset 구축

아래와 같은 형태로 파일 시스템을 만들어준다.

data
├── A1C_20220818_000001.jpg
├── A1C_20220818_000016.jpg
├── A1C_20220818_000018.jpg
├── A1C_20220818_000027.jpg
├── A1C_20220818_000155.jpg
├── A2C_20220831_001277.jpg
├── A2C_20220831_003731.jpg
├── A2C_20220831_009409.jpg
├── A2C_20220831_009819.jpg
├── A2C_20220831_010724.jpg
├── A3C_20221018_000370.jpg
├── A3C_20221018_002413.jpg
├── A3C_20221018_003176.jpg
├── A3C_20221018_005824.jpg
├── A3C_20221018_006616.jpg
├── A4C_20221020_000000.jpg
├── A4C_20221020_000001.jpg
├── A4C_20221020_000011.jpg
├── A4C_20221020_000012.jpg
├── A4C_20221020_000014.jpg
├── A5C_20221107_000015.jpg
├── A5C_20221107_000020.jpg
├── A5C_20221107_000021.jpg
├── A5C_20221107_000022.jpg
├── A5C_20221107_000024.jpg
├── A6C_20221118_000007.jpg
├── A6C_20221118_000012.jpg
├── A6C_20221118_000018.jpg
├── A6C_20221118_000045.jpg
├── A6C_20221118_000047.jpg
├── A7C_20221118_000042.jpg
├── A7C_20221118_000047.jpg
├── A7C_20221118_000071.jpg
├── A7C_20221118_000115.jpg
├── A7C_20221118_000117.jpg
├── A8C_20221123_000000.jpg
├── A8C_20221123_000003.jpg
├── A8C_20221123_000004.jpg
├── A8C_20221123_000005.jpg
├── A8C_20221123_000006.jpg
├── A9C_20221124_000000.jpg
├── A9C_20221124_000001.jpg
├── A9C_20221124_000006.jpg
├── A9C_20221124_000008.jpg
├── A9C_20221124_000009.jpg
├── B10_20221115_000006.jpg
├── B10_20221115_000007.jpg
├── B10_20221115_000008.jpg
├── B10_20221115_000009.jpg
├── B10_20221115_000010.jpg
├── B1_20220715_000000.jpg
├── B1_20220715_000001.jpg
├── B1_20220715_000002.jpg
├── B1_20220715_000003.jpg
├── B1_20220715_000004.jpg
├── B2_20220721_000000.jpg
├── B2_20220721_000002.jpg
├── B2_20220721_000007.jpg
├── B2_20220721_000010.jpg
├── B2_20220721_000019.jpg
├── B3_20220823_000002.jpg
├── B3_20220823_000003.jpg
├── B3_20220823_000004.jpg
├── B3_20220823_000006.jpg
├── B3_20220823_000007.jpg
├── B4_20220916_000000.jpg
├── B4_20220916_000002.jpg
├── B4_20220916_000005.jpg
├── B4_20220916_000006.jpg
├── B4_20220916_000007.jpg
├── B5_20220926_000011.jpg
├── B5_20220926_000012.jpg
├── B5_20220926_000013.jpg
├── B5_20220926_000014.jpg
├── B5_20220926_000015.jpg
├── B6_20221017_000005.jpg
├── B6_20221017_000011.jpg
├── B6_20221017_000012.jpg
├── B6_20221017_000015.jpg
├── B6_20221017_000018.jpg
├── B7_20221020_005837.jpg
├── B7_20221020_005841.jpg
├── B7_20221020_005846.jpg
├── B7_20221020_005849.jpg
├── B7_20221020_005853.jpg
├── B8_20221101_000054.jpg
├── B8_20221101_000060.jpg
├── B8_20221101_000063.jpg
├── B8_20221101_000065.jpg
├── B8_20221101_000066.jpg
├── B9_20221104_000019.jpg
├── B9_20221104_000024.jpg
├── B9_20221104_000026.jpg
├── B9_20221104_000028.jpg
├── B9_20221104_000030.jpg
├── C_20220715_000000.jpg
├── C_20220715_000005.jpg
├── C_20220715_000009.jpg
├── C_20220715_000010.jpg
├── C_20220715_000011.jpg
├── C_20220715_000012.jpg
├── C_20220715_000017.jpg
├── C_20220715_000018.jpg
├── C_20220715_000019.jpg
├── C_20220715_000023.jpg
└── metadata.csv

0 directories, 106 files

 

metadata.csv 파일 형태는 아래와 같다.

 

스크린샷 2024-10-31 오후 3 53 17



위와 같은 파일 디렉토리를 datasets 패키지로 불러와서 학습에 사용하려면, 아래 코드를 실행한다.

 

from datasets import load_dataset, load_from_disk

dataset = load_dataset("imagefolder", data_dir="data")

학습에 사용한 코드 전체

 

전반적인 파인튜닝 코드는 reference에 적힌 유튜브 영상의 소스코드를 거의 그대로 사용했다.

파라미터 튜닝해가며 학습 성능을 비교할 만큼 GPU 리소스가 충분하지 않았기 때문.

 

'아이디어톤' 제출 코드라서, 말 그대로 PoC처럼 학습이 가능하다는 것만 증명해도 충분하다고 생각했다.

 

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

 

colab 인스턴스의 제한시간이 짧아서, 제대로 된 학습을 완전히 수행하기는 어려웠음.

Reference

반응형