https://www.youtube.com/live/vxF8-ay9Bzk?si=n9uDkQSpvdr1gkbJ
The LangChain introduction at the beginning and the simple exercise of getting results from a Google LLM through LangChain are left out.
Only the multimodal-related content is summarized here.
Vertex AI Integration with LangChain
In short: Google's models can also be used easily from LangChain.
Embeddings and vector-store search are supported as well.
A multimodal retriever that accepts various kinds of input is also possible.
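A minimal sketch of that integration (my own example, not from the talk): assuming the Vertex AI project setup shown in the next section has already been done, the usual LangChain interfaces work unchanged. The model names follow the ones used later in this post and may need updating for your environment.
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
from langchain_community.vectorstores import Chroma
# Chat model through the standard LangChain interface
llm = ChatVertexAI(model_name="gemini-pro", temperature=0)
print(llm.invoke("Explain multi-vector retrieval in one sentence.").content)
# Embeddings + vector-store similarity search
embeddings = VertexAIEmbeddings(model_name="textembedding-gecko@latest")
vs = Chroma.from_texts(
    ["Vertex AI models plug into LangChain.", "A multi-vector retriever indexes summaries."],
    embedding=embeddings,
)
print(vs.similarity_search("How does LangChain use Vertex AI?", k=1)[0].page_content)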
MultiModal RAG Google
Setup
from langchain_google_vertexai import VertexAI, ChatVertexAI, VertexAIEmbeddings
PROJECT_ID = "project_id"
REGION = 'region'
import vertexai
vertexai.init(project = PROJECT_ID, location = REGION)
Data Loading
import logging
import zipfile
import requests
logging.basicConfig(level=logging.INFO)
data_url = "https://storage.googleapis.com/benchmarks-artifacts/langchain-docs-benchmarking/cj.zip"
result = requests.get(data_url)
filename = "cj.zip"
with open(filename, "wb") as file:
    file.write(result.content)
with zipfile.ZipFile(filename, "r") as zip_ref:
    zip_ref.extractall()
from langchain_community.document_loaders import PyPDFLoader
loader = PyPDFLoader("./cj/cj.pdf")
docs = loader.load()
tables = []
texts = [d.page_content for d in docs]
len(texts)  # returns 21
Multi-Vector Retrieval
from langchain.prompts import PromptTemplate
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda
def generate_text_summaries(texts, tables, summarize_texts=False):
    """
    Summarize text elements
    texts: List of str
    tables: List of str
    summarize_texts: Bool to summarize texts
    """
    # Prompt
    prompt_text = """You are an assistant tasked with summarizing tables and text for retrieval. \
These summaries will be embedded and used to retrieve the raw text or table elements. \
Give a concise summary of the table or text that is well optimized for retrieval. Table or text: {element}
"""
    prompt = PromptTemplate.from_template(prompt_text)
    # Fallback response so one failed document does not break the whole batch
    empty_response = RunnableLambda(
        lambda x: AIMessage(content="Error processing document")
    )
    model = VertexAI(
        temperature=0, model_name="gemini-pro", max_output_tokens=1024
    ).with_fallbacks([empty_response])
    summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()
    # Initialize empty summaries
    text_summaries = []
    table_summaries = []
    if texts and summarize_texts:
        text_summaries = summarize_chain.batch(texts, {"max_concurrency": 1})
    elif texts:
        text_summaries = texts
    if tables:
        table_summaries = summarize_chain.batch(tables, {"max_concurrency": 1})
    return text_summaries, table_summaries
text_summaries, table_summaries = generate_text_summaries(
    texts, tables, summarize_texts=True
)
print(len(text_summaries))  # returns 21
import base64
import os
from langchain_core.messages import HumanMessage
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
def image_summarize(img_base64, prompt):
    model = ChatVertexAI(model_name="gemini-pro-vision", max_output_tokens=1024)
    msg = model.invoke(
        [
            HumanMessage(
                content=[
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
                    },
                ]
            )
        ]
    )
    return msg.content
def generate_img_summaries(path):
    """
    Generate summaries and base64-encoded strings for images
    path: Path to list of .jpg files extracted by Unstructured
    """
    img_base64_list = []
    image_summaries = []
    prompt = """You are an assistant tasked with summarizing images for retrieval. \
These summaries will be embedded and used to retrieve the raw image. \
Give a concise summary of the image that is well optimized for retrieval."""
    # Apply to every image in the directory
    for img_file in sorted(os.listdir(path)):
        if img_file.endswith(".jpg"):
            img_path = os.path.join(path, img_file)
            base64_image = encode_image(img_path)
            img_base64_list.append(base64_image)
            image_summaries.append(image_summarize(base64_image, prompt))
    return img_base64_list, image_summaries
# image summaries
img_base64_list, image_summaries = generate_img_summaries("./cj")
Storing in a Vector Store
import uuid
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
def create_multi_vector_retriever(
    vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images
):
    """
    Create retriever that indexes summaries, but returns raw images or texts.
    """
    store = InMemoryStore()
    id_key = "doc_id"
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=store,
        id_key=id_key,
    )
    # Helper function to add documents to the vectorstore and docstore:
    # the summary (with a shared doc_id) is embedded into the vectorstore,
    # while the raw content is kept in the docstore under the same id.
    def add_documents(retriever, doc_summaries, doc_contents):
        doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
        summary_docs = [
            Document(page_content=s, metadata={id_key: doc_ids[i]})
            for i, s in enumerate(doc_summaries)
        ]
        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(list(zip(doc_ids, doc_contents)))
    # Add texts, tables, and images
    if text_summaries:
        add_documents(retriever, text_summaries, texts)
    if table_summaries:
        add_documents(retriever, table_summaries, tables)
    if image_summaries:
        add_documents(retriever, image_summaries, images)
    return retriever
# The vectorstore used to index the summaries.
# The embedding model's quality ultimately drives RAG quality; the VertexAIEmbeddings class reportedly supports a variety of SOTA embedding models.
vectorstores = Chroma(
    collection_name="mm_rag_cj_blog",
    embedding_function=VertexAIEmbeddings(model_name="textembedding-gecko@latest"),
)
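Since embedding quality drives retrieval quality, a quick sanity check like the following can help before wiring up the retriever (illustrative only; the sample queries and the numpy cosine computation are my own addition, not from the talk).
# Illustrative check: embedding dimensionality and a sample cosine similarity
import numpy as np
emb = VertexAIEmbeddings(model_name="textembedding-gecko@latest")
v1 = np.array(emb.embed_query("EV / NTM revenue multiples"))
v2 = np.array(emb.embed_query("enterprise value to forward revenue"))
print(len(v1))  # embedding dimensionality
print(v1 @ v2 / (np.linalg.norm(v1) * np.linalg.norm(v2)))  # cosine similarity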
# create retriever
retriever_multi_vector_img = create_multi_vector_retriever(
    vectorstores,
    text_summaries,
    texts,
    table_summaries,
    tables,
    image_summaries,
    img_base64_list,
)
Building the RAG Chain
import io
import re
from IPython.display import HTML, display
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from PIL import Image
def plt_img_base64(img_base64):
    """Display a base64-encoded string as an image"""
    image_html = f'<img src="data:image/jpeg;base64,{img_base64}" />'
    display(HTML(image_html))
def looks_like_base64(sb):
    """Check if the string looks like base64"""
    return re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb) is not None
def is_image_data(b64data):
    """Check if the base64 data is an image by looking at the start of the data"""
    image_signatures = {
        b"\xFF\xD8\xFF": "jpg",
        b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A": "png",
        b"\x47\x49\x46\x38": "gif",
        b"\x52\x49\x46\x46": "webp",
    }
    try:
        header = base64.b64decode(b64data)[:8]  # decode and get the first 8 bytes
        for sig, format in image_signatures.items():
            if header.startswith(sig):
                return True
        return False
    except Exception:
        return False
def resize_base64_image(base64_string, size=(128, 128)):
    """Resize an image encoded as a base64 string"""
    # Decode base64 string
    img_data = base64.b64decode(base64_string)
    img = Image.open(io.BytesIO(img_data))
    # Resize the image
    resized_img = img.resize(size, Image.LANCZOS)
    buffered = io.BytesIO()
    resized_img.save(buffered, format=img.format)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
def split_image_text_types(docs):
    """Split base64-encoded images and texts"""
    b64_images = []
    texts = []
    for doc in docs:
        if isinstance(doc, Document):
            doc = doc.page_content
        if looks_like_base64(doc) and is_image_data(doc):
            doc = resize_base64_image(doc, size=(1300, 600))
            b64_images.append(doc)
        else:
            texts.append(doc)
    if len(b64_images) > 0:
        return {"images": b64_images[:1], "texts": []}
    return {"images": b64_images, "texts": texts}
def img_prompt_func(data_dict):
    """Join the context into a single string"""
    formatted_texts = "\n".join(data_dict["context"]["texts"])
    messages = []
    text_message = {
        "type": "text",
        "text": (
            "You are a financial analyst tasked with providing investment advice.\n"
            "You will be given a mix of text, tables, and image(s), usually of charts or graphs.\n"
            "Use this information to provide investment advice related to the user question.\n"
            f"User-provided question: {data_dict['question']}\n\n"
            "Text and / or tables:\n"
            f"{formatted_texts}"
        ),
    }
    messages.append(text_message)
    if data_dict["context"]["images"]:
        for image in data_dict["context"]["images"]:
            image_message = {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image}"},
            }
            messages.append(image_message)
    return [HumanMessage(content=messages)]
def multi_modal_rag_chain(retriever):
    """Multi-Modal RAG Chain"""
    # multi-modal LLM
    model = ChatVertexAI(
        temperature=0, model_name="gemini-pro-vision", max_output_tokens=1024
    )
    # rag pipeline
    chain = (
        {
            "context": retriever | RunnableLambda(split_image_text_types),
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(img_prompt_func)
        | model
        | StrOutputParser()
    )
    return chain
## create RAG chain
chain_multimodal_rag = multi_modal_rag_chain(retriever_multi_vector_img)
query = "What are the EV / NTM and NTM rev growth for MongoDB, CloudFlare, and DataDog?"
docs = retriever_multi_vector_img.get_relevant_documents(query, limit=1)
print(len(docs))  # returns 4
- Looking at the docs response, you can see that alongside plain text there are also base64-encoded images.
- Rendered as an image (see the snippet below), you can confirm that an image matching the query was retrieved.
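A small sketch of that rendering step, reusing the helpers defined earlier (the filtering mirrors split_image_text_types):
# Render any base64-encoded images among the retrieved docs
for d in docs:
    content = d.page_content if isinstance(d, Document) else d
    if looks_like_base64(content) and is_image_data(content):
        plt_img_base64(content)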
result = chain_multimodal_rag.invoke(query)
from IPython.display import Markdown as md
md(result)
result = chain_multimodal_rag.invoke("what are the EV / NTM rev growth for Adobe?")
md(result)  # returns "Adobe's EV/NTM is 12.3x and NTM Rev Growth is 12%"