
Gemini Multimodal RAG Applications with LangChain

inspirit941 2024. 5. 10. 18:09

https://www.youtube.com/live/vxF8-ay9Bzk?si=n9uDkQSpvdr1gkbJ

 

The LangChain introduction at the start and the hands-on part that simply calls a Google LLM through LangChain to get a response are left out.
Only the multimodal-related content is summarized here.

 

 


Vertex AI Integration with LangChain

(slide screenshots)

 

In short: Google's models can also be used easily from LangChain.

(slide screenshots)

 

Embeddings and vector-store search are supported as well (a minimal sketch follows below).

(slide screenshot)
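A minimal, illustrative sketch of that embedding + vector-store flow (the embedding model name matches the one used later in this post; the sample texts, collection name, and query are made up):

# Embed a couple of texts with Vertex AI and run a similarity search.
from langchain_google_vertexai import VertexAIEmbeddings
from langchain_community.vectorstores import Chroma

embeddings = VertexAIEmbeddings(model_name="textembedding-gecko@latest")
demo_store = Chroma.from_texts(
  ["Gemini is Google's multimodal model.", "Chroma stores embedding vectors locally."],
  embedding=embeddings,
  collection_name="embedding_demo",
)
print(demo_store.similarity_search("which model is multimodal?", k=1)[0].page_content)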

 

A multimodal retriever that can take various kinds of input is also possible.

MultiModal RAG Google

Setup

from langchain_google_vertexai import VertexAI, ChatVertexAI, VertexAIEmbeddings

# Replace with your own GCP project ID and region
PROJECT_ID = "project_id"
REGION = "region"

import vertexai
vertexai.init(project=PROJECT_ID, location=REGION)

Data Loading

import logging
import zipfile
import requests

logging.basicConfig(level=logging.INFO)
data_url = "https://storage.googleapis.com/benchmarks-artifacts/langchain-docs-benchmarking/cj.zip"
result = requests.get(data_url)
filename = "cj.zip"
with open(filename, 'wb') as file:
  file.write(result.content)

with zipfile.ZipFile(filename, 'r') as zip_ref:
  zip_ref.extractall()
from langchain_community.document_loaders import PyPDFLoader

loader = PyPDFLoader("./cj/cj.pdf")
docs = loader.load()
tables = []  # no separate table extraction in this walkthrough; tables stay empty
texts = [d.page_content for d in docs]
len(texts)  # returns 21
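The image-summary step further down iterates over .jpg files under ./cj, so it can help to confirm what the zip actually extracted (purely an illustrative sanity check):

import os

# List a few of the extracted .jpg page images that the image-summary step will consume.
print(sorted(f for f in os.listdir("./cj") if f.endswith(".jpg"))[:3])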

Multi-Vector Retrieval

from langchain.prompts import PromptTemplate
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda

def generate_text_summaries(texts, tables, summarize_texts=False):
  """
  Summarize text elements
  texts: List of str
  tables: List of str
  summarize_texts: Bool to summarize texts
  """

  # Prompt
  prompt_text = """You are an assistant tasked with summarizing tables and text for retrieval. \
  These summaries will be embedded and used to retrieve the raw text or table elements. \
  Give a concise summary of the table or text that is well optimized for retrieval. Table or text: {element}
  """
  prompt = PromptTemplate.from_template(prompt_text)
  empty_response = RunnableLambda(
    lambda x: AIMessage(content="Error processing document")
  )
  model = VertexAI(temperature=0, model_name="gemini-pro", max_output_tokens=1024).with_fallbacks([empty_response])
  summarize_chain = {"element":lambda x: x} | prompt | model | StrOutputParser()

  ## init empty summaries
  text_summaries = []
  table_summaries = []
  if texts and summarize_texts:
    text_summaries = summarize_chain.batch(texts, {"max_concurrency": 1})
  elif texts:
    text_summaries = texts

  if tables:
    table_summaries = summarize_chain.batch(tables, {"max_concurrency":1})
  return text_summaries, table_summaries

text_summaries, table_summaries = generate_text_summaries(
  texts, tables, summarize_texts=True
)

print(len(text_summaries))  # returns 21

import base64
import os

from langchain_core.messages import HumanMessage

def encode_image(image_path):
  with open(image_path, "rb") as image_file:
    return base64.b64encode(image_file.read()).decode('utf-8')

def image_summarize(img_base64, prompt):
  """Summarize a base64-encoded image with a multimodal Gemini model"""
  model = ChatVertexAI(model_name="gemini-pro-vision", max_output_tokens=1024)
  msg = model(
    [
      HumanMessage(
        content=[
          {"type": "text", "text": prompt},
          {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
          },
        ]
      )
    ]
  )
  return msg.content

def generate_img_summaries(path):
  """
  Generate summaries and base64-encoded strings for images
  path: Path to list of .jpg files extracted by Unstructured
  """

  img_base64_list = []
  image_summaries = []

  prompt = """You are an assistant tasked with summarizing images for retrieval. \
  These summaries will be embedded and used to retrieve the raw image. \
  Give a concise summary of the image that is well optimized for retrieval. """

  # apply to each image file in the directory
  for img_file in sorted(os.listdir(path)):
    if img_file.endswith(".jpg"):
      img_path = os.path.join(path, img_file)
      base64_image = encode_image(img_path)
      img_base64_list.append(base64_image)
      image_summaries.append(image_summarize(base64_image, prompt))

  return img_base64_list, image_summaries

# image summaries
img_base64_list, image_summaries = generate_img_summaries("./cj")

Saving to the Vector Store

import uuid

from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document

def create_multi_vector_retriever(
  vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images
):
  """
  Create retriever that indexes summaries, but returns raw images or texts.
  """

  store = InMemoryStore()
  id_key = "doc_id"

  retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
  )

  # helper function to add documents to the vectorstore and docstore
  def add_documents(retriever, doc_summaries, doc_contents):
    doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
    summary_docs = [
      Document(page_content=s, metadata={id_key:doc_ids[i]}) for i, s in enumerate(doc_summaries)
    ]
    retriever.vectorstore.add_documents(summary_docs)
    retriever.docstore.mset(list(zip(doc_ids, doc_contents)))

  # add texts, tables, images
  if text_summaries:
    add_documents(retriever, text_summaries, texts)
  if table_summaries:
    add_documents(retriever, table_summaries, tables)
  if image_summaries:
    add_documents(retriever, image_summaries, images)

  return retriever

# The vectorstore used to index the summaries.
# The embedding model ultimately determines RAG quality; the VertexAIEmbeddings class
# is said to support a range of SOTA embedding models as well.
vectorstores = Chroma(
  collection_name="mm_rag_cj_blog",
  embedding_function=VertexAIEmbeddings(model_name="textembedding-gecko@latest"),
)

# create retriever
retriever_multi_vector_img = create_multi_vector_retriever(
  vectorstores,
  text_summaries,
  texts,
  table_summaries,
  tables,
  image_summaries,
  img_base64_list,
)
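A quick, illustrative sanity check of how the multi-vector retriever is organized (the query string is just an example): summaries are embedded in the vectorstore, while the docstore returns the raw text, table, or base64 image for the matching doc_id.

# Illustrative: look up one summary in the vectorstore and fetch its raw content from the docstore.
sample = retriever_multi_vector_img.vectorstore.similarity_search("EV / NTM multiples", k=1)[0]
raw = retriever_multi_vector_img.docstore.mget([sample.metadata["doc_id"]])[0]
print(sample.page_content[:120])  # the summary that was embedded
print(str(raw)[:120])             # the raw content that will be returned to the LLM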

Building a RAG

import io
import re

from IPython.display import HTML, display
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from PIL import Image


def plt_img_base64(img_base64):
  """Display base64 encoded string as image"""
  image_html=f'<img src="data:image/jpeg;base64,{img_base64}" />'
  display(HTML(image_html))

def looks_like_base64(sb):
  """check if the string looks like base64"""
  return re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb) is not None

def is_image_data(b64data):
  """check if the base64 data is an image by looking at the start of the data"""
  image_signatures = {
    b"\xff\xd8\xff": "jpg",
    b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a": "png",
    b"\x47\x49\x46\x38": "gif",
    b"\x52\x49\x46\x46": "webp",
  }
  try:
    header = base64.b64decode(b64data)[:8]  # decode and check the first 8 bytes
    for sig, format in image_signatures.items():
      if header.startswith(sig):
        return True
    return False
  except Exception:
    return False

def resize_base64_image(base64_string, size=(128, 128)):
  """Resize an image encoded as a Base64 string"""
  # Decode base64 string
  img_data = base64.b64decode(base64_string)
  img = Image.open(io.BytesIO(img_data))

  # resize the image
  resized_img = img.resize(size, Image.LANCZOS)

  buffered = io.BytesIO()
  resized_img.save(buffered, format=img.format)
  return base64.b64encode(buffered.getvalue()).decode("utf-8")


def split_image_text_types(docs):
  """Split base64-encoded images and texts"""
  b64_images = []
  texts = []
  for doc in docs:
    if isinstance(doc, Document):
      doc = doc.page_content
    if looks_like_base64(doc) and is_image_data(doc):
      doc = resize_base64_image(doc, size=(1300, 600))
      b64_images.append(doc)
    else:
      texts.append(doc)
  if len(b64_images) > 0:
    return {"images": b64_images[:1], "texts": []}
  return {"images": b64_images, "texts": texts}

def img_prompt_func(data_dict):
  """Join the context into a single string"""
  formatted_texts = "\n".join(data_dict['context']['texts'])
  messages = []

  text_message = {
    "type": "text",
    "text": (
      "You are a financial analyst tasked with providing investment advice.\n"
      "You will be given a mix of text(s), table(s), and image(s), usually of charts or graphs.\n"
      "Use this information to provide investment advice related to the user question.\n"
      f"User-provided question: {data_dict['question']}\n\n"
      "Text and / or tables:\n"
      f"{formatted_texts}"
    ),
  }
  messages.append(text_message)
  if data_dict['context']['images']:
    for image in data_dict['context']['images']:
      image_message = {
        "type": "image_url",
        "image_url":{"url":f"data:image/jpeg;base64,{image}"}
      }
      messages.append(image_message)
  return [HumanMessage(content=messages)]

def multi_modal_rag_chain(retriever):
  """Multi-Modal RAG Chain"""

  # multi-modal LLM
  model = ChatVertexAI(
    temperature=0, model_name="gemini-pro-vision", max_output_tokens=1024
  )

  # rag pipeline
  chain = (
    {
      "context": retriever | RunnableLambda(split_image_text_types),
      "question": RunnablePassthrough(),
    }
    | RunnableLambda(img_prompt_func)
    | model
    | StrOutputParser()
  )
  return chain

## create RAG chain
chain_multimodal_rag = multi_modal_rag_chain(retriever_multi_vector_img)

query = "What are the EV / NTM and NTM rev growth for MongoDB, CloudFlare, and DataDog?"
docs = retriever_multi_vector_img.get_relevant_documents(query, limit=1)

print(len(docs))  # returns 4

(screenshots: the retrieved docs and one retrieved image rendered from base64)

 

  • Looking at the docs response, there is text, but base64-encoded images come back as well.
  • Rendered as images they look like the screenshots above; an image matching the query was retrieved (see the sketch below for how to render them).
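A small illustrative snippet for inspecting the retrieval result with the helpers defined earlier (it previews text chunks and renders any base64 image docs):

# Preview text docs and render base64 image docs from the retrieval result.
for d in docs:
  content = d.page_content if isinstance(d, Document) else d
  if looks_like_base64(content) and is_image_data(content):
    plt_img_base64(content)   # display the retrieved chart image
  else:
    print(content[:200])      # preview the retrieved text chunk
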
result = chain_multimodal_rag.invoke(query)

from IPython.display import Markdown as md

md(result)

result = chain_multimodal_rag.invoke("what are the EV / NTM rev growth for Adobe?")
md(result)  # returns "Adobe's EV/NTM is 12.3x and NTM Rev Growth is 12%"