2024. 6. 26. 22:57ㆍDL
기존에 LCEL기반의 LLM pipeline에서 Agent를 활용하는 방식으로 전환을 하고 있었습니다. 그 이유는 아래와 같습니다.
1. Retriever를 구조화하여 사용하기 힘들다.
2. 복잡한 형태, 크로스 도매인을 가진 질문에 대하여 답변하기 더 용이하다.
하지만 단점도 있습니다.
1. 답변을 생성하기 위한 과정이 추가되다보니 답변이 느리다.
2. LLM의 가끔의 멍청함으로 포맷을 정확히 지키지 못해, 추가 Token발생으로 비용이 쪼~금 더 나갈때가 있다.
그러나 이 단점중 2번째인 Formating 문제가 Claude 3.5 Sonnet이 출시되며 말끔하게 해결되어 Agent 전환을 하고 있는데...
Agent의 답변은 여러 단계를 거쳐 마지막 Final Answer가 출력되는 시점에 일괄 출력되도록 되어 있더군요 이러면.. Stream으로 호출해도, Stream_log로 호출을 해도 각 Agent의 진행 단계별 json을 떨굴뿐 우리가 기존 LCEL로만 구성한 Chain형태의 LLM pipeline에서처럼 글자, Token단위의 Streaming출력은 기대할 수 없었습니다.
그래서 여러 방법을 고안하던중 그나마.. 현실적인 수준에서 해결한 방법을 공유하고자 합니다.
1. LLM 객체에 직접적으로 CallBack함수 추가
바로 코드부터 보겠습니다.
from langchain_community.chat_models.bedrock import BedrockChat
import asyncio
from langchain.callbacks.base import BaseCallbackHandler
class StreamingCallback(BaseCallbackHandler):
def __init__(self):
self.queue = asyncio.Queue()
async def on_llm_new_token(self, token, **kwargs):
await self.queue.put(token)
bedrock_client = boto3_session.client(
service_name="bedrock-runtime",
region_name="ap-southeast-2"
)
claude_sonnet_llm = BedrockChat(
client=bedrock_second_client,
model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
model_kwargs={'temperature': 0},
streaming=True,
verbose=True,
callbacks=[StreamingCallback()]
다음과 같이 구성합니다. BaseCallbackHandler를 상속받은 StreamingCallback을 생성합니다.
그리고 내부에 초기화 함수에 비동기 Queue를 하나 생성하고 아래 비동기 함수로 토큰이 생성될 때마다. Queue에 데이터를 넣도록 합니다.
이렇게 설정하면, 새로운 토큰이 생성될 때 마다 Queue에 데이터가 쌓이게 되겠군요?
2. 이 CallBack에 저장된 LLM Token streaming을 Langserve가 아니라 FastAPI의 자체 Post로 구현한다.
마찬가지로 바로 코드부터 보겠습니다. 이미 Langserve를 사용한 순간부터 FastAPI는 선언 돼 있을거구요 이 선언된 FastAPI로 그냥 Post형태로 구성하도록 하겠습니다.
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from Chains.TextAgentV2 import agnet_pipeline
from ChainModule.BedrockCaht import StreamingCallback
import asyncio
import json
class QuestionInput(BaseModel):
question: str
chat_history: List[Tuple[str, str]] = Field(default_factory=list)
model_config = {
"json_schema_extra": {
"examples": [
{
"question": "와우 친구들 빡빡이 아저씨야!",
"chat_history": [
("이것은 질문입니다.", "이것은 답변입니다."),
("blahblah", "salasala")
]
}
]
}
}
app = FastAPI(
title="LangChain Server",
version="1.0",
description="Spin up a simple api server using Langchain's Runnable interfaces",
)
# Add filter to the logger
logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
@app.post("/evpedia_chat")
async def stream_agent(input_data: QuestionInput):
try:
question = input_data.question
chat_history = input_data.chat_history
callback_handler = StreamingCallback()
async def run_agent():
try:
await agnet_pipeline.ainvoke(
{"chat_history": chat_history, "question": question},
config={"callbacks": [callback_handler]}
)
except Exception as e:
# 오류 발생 시 오류 메시지를 스트림에 추가
await callback_handler.queue.put(f"Error: {str(e)}")
finally:
# 스트림 종료 신호 전송
await callback_handler.queue.put(None)
# run_agent를 태스크로 생성하고 참조 유지
task = asyncio.create_task(run_agent())
async def stream_tokens_with_error_handling():
try:
async for token in stream_tokens(callback_handler):
yield token
except Exception as e:
# 스트리밍 중 발생한 오류 처리
yield f"data: Error during streaming: {str(e)}\n\n"
finally:
# 태스크가 완료되지 않았다면 취소
if not task.done():
task.cancel()
return StreamingResponse(stream_tokens_with_error_handling(), media_type="text/event-stream")
except Exception as e:
# 요청 처리 중 발생한 전반적인 오류 처리
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
@app.get("/ping")
async def root():
return "pong"
아까 생성한 StreamingCallback을 여기서도 선언합니다. 그리고 Post를 구성하기 위한 pydatic Type Checker와 FastApi 객체를 선언합니다.
그리고 post로 사용할 함수를 선언하고 내부를 위 코드처럼 구성합니다. 여기서 가장 중요한점은 Generator를 사용하여 CallbackHandler의 Queue로 부터 데이터를 생성하는것 입니다. 이 과정에서 데이터의 원본 형태를 유지하기 위해서 데이터를 json형태로 변환하여 내리도록 합니다. (이러지 않으면 데이터가 깨져서 문제가 생깁니다.)
그리고는 FastAPI에서 지원하는 StreamingResponse를 사용하여 비동기 Generator를 실행시킵니다.
3. Client
import asyncio
import aiohttp
from IPython.display import clear_output
import json
token_save = []
async def stream_agent_response(chat_history, question):
url = "http://localhost:8000/agent_chat"
payload = {
"chat_history": chat_history,
"question": question
}
async with aiohttp.ClientSession() as session:
async with session.post(url, json=payload) as response:
if response.status != 200:
print(f"Error: {response.status}")
print(await response.text())
return
full_response = ""
async for line in response.content:
if line:
line = line.decode('utf-8').strip()
if line.startswith("data: "):
data = line[6:]
if data == "[DONE]":
break
token_data = json.loads(data)
token_save.append(token_data)
full_response += token_data['token']
clear_output(wait=True)
print(full_response, end="", flush=True)
print()
return full_response
async def main():
chat_history = []
while True:
question = input("\nAsk a question about electric vehicles (or type 'exit' to quit): ")
if question.lower() == 'exit':
break
response = await stream_agent_response(chat_history, question)
if response:
if "Error: " not in response:
chat_history.append((question, response))
else:
pass
# Jupyter Notebook에서 실행하기 위한 코드
import nest_asyncio
nest_asyncio.apply()
# 실행
asyncio.run(main())
Client는 다음과 같이 구성하고 실행해보면 아주 잘 떨어지는것을 확인 할 수 있습니다.
'DL' 카테고리의 다른 글
[LLM] 프롬프트를 활용하며 느낀점 & 효과가 있던 방법 1편 (0) | 2024.08.17 |
---|---|
[LLM] Langchain ReAct Agent에 DALL-E tool 추가하기 (0) | 2024.08.02 |
[LLM] Claude 3.5 Sonnet 출시 개인적인 느낀점 (0) | 2024.06.22 |
[LLM] Data Splitters (with Langchain) (0) | 2024.06.22 |
[LLM] RAG란 무엇인가? (0) | 2024.06.20 |