LangChainでBigQueryデータを使ったグラウンディングを実装してみた

記事タイトルとURLをコピーする

G-gen 又吉です。当記事では、Google Cloud の LLM (Vertex AI PaLM API) と LangChain を組み合わせて、自然言語から BigQuery 上の統計データを取得する方法を紹介します。

はじめに

LangChain とは、大規模言語モデル (LLM) アプリケーションを効率よく実装するためのフレームワークです。LangChain についての詳細は以下の記事をご参照ください。

blog.g-gen.co.jp

LangChain にはさまざまな機能が提供されておりますが、今回は Agents 機能を用いて SQL データベースと対話するエージェントを作成します。

LLM のハルシネーション (幻覚) を抑制する手法として、指定した情報源だけに基づいて LLM に回答を生成させることをグラウンディングといいます。

最近では、テキストや画像をエンべディング (ベクトル化) する技術とベクトル検索の技術を組み合わせて RAG (Retrieval Augmented Generation) を行うグラウンディング手法が人気となっています。

しかし、自社のデータベースに保存された構造化データに対し、LLM が自然言語を理解して SQL を生成し、最終的に分析結果 (統計データ) を取得して回答を生成させたい場合にはどうしたらよいでしょうか。

例えば「先月の売上を教えて」というユーザーの入力に対し、LLM がデータベースに対して適切な SQL を発行し、先月の売上を算出して回答を生成するなどです。この場合も、指定した情報源(自社の売上データ)を基に回答を生成しているため、グラウンディングにあたります。

今回は、この構成を LangChain を用いて実装してみたいと思います。

今回の構成図

準備

実行環境

当記事では、Colab Enterprise の Notebook を使用します。Colab Enterprise は、マネージドな Notebook のためインフラストラクチャを管理せず実装に注力できます。

Colab Enterprise の Notebook 作成方法は公式ドキュメントのクイックスタートをご参考下さい。

cloud.google.com

使用するデータ

今回使用するデータは、架空の販売データ (sample-data) を使用します。

尚、BigQuery テーブルの作成は省略しますが、作成後のテーブルプレビューは以下になります。

sales_data テーブル

実装

ライブラリのインストール

Notebook が立ち上がり、ランタイムと接続できましたら以下のコードを実行してライブラリのインストールとインポートを行います。

# input:[1]
!pip install  sqlalchemy sqlalchemy-bigquery google-cloud-bigquery langchain

ライブラリのインストールができたら、ランタイムを再起動してから以下を実行します。

# input:[2]
from google.cloud import bigquery
from sqlalchemy import *
from sqlalchemy.engine import create_engine
from sqlalchemy.schema import *
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase
from langchain.llms import VertexAI
from langchain.agents import AgentExecutor

関数の定義

概要

ユーザーの入力 (自然言語) を引数に、BigQuery 上のデータから分析結果を出力する関数を定義します。

# input:[3]
project = ${PROJECT_ID}
dataset = ${DATASET_ID}
sqlalchemy_url = f"bigquery://{project}/{dataset}"
  
# SQLAlchemy エンジンを初期化
db = SQLDatabase.from_uri(database_uri=sqlalchemy_url)
  
# Vertex AI 基盤モデルを初期化
llm = VertexAI(
    model_name="text-bison@001",
    max_output_tokens=1024,
    temperature=0.1,
    top_p=0.8,
    top_k=40,
    verbose=True,
)
  
# SQL データベースと対話するエージェントの初期化
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    top_k=10,
)
  
def run_query(query):
    res = agent_executor.run(query)
    return res

解説

SQLAlchemy エンジンを初期化

8 行目で SQLAlchemy エンジンを初期化しています。

SQLDatabase オブジェクトの from_uri メソッドで URI から SQLAlchemy エンジンを初期化できます。

URI の記法については、Python-bigquery-sqlalchemy のドキュメントをご参照ください。また、BigQuery への認証情報については、google-api-core のドキュメントをご参照ください。

SQL データベースと対話するエージェントの初期化

20-26 行目でSQL データベースと対話するエージェントの初期化しています。

SQLDatabaseToolkit オブジェクトは、SQL データベースと対話するためのツールキットであり、create_sql_agent でエージェントを初期化しています。

このエージェントによって、ユーザーの入力から「次にどのようなアクションを取るか」を考え、実行、観測して答えがでるまで繰り返します。今回の場合は、ユーザーの入力に対して、SQL を生成して BigQuery にジョブを実行する操作を、最終的な結果を得るまで繰り返してくれます。

実行

それでは、実行してみましょう。

# input:[4]
run_query("プロダクト別の売上の上位10位を教えてください。")
# output:[4]
> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input: 
Observation: sales_data
Thought:I found the table sales_data. I should query the schema of this table to see what columns I can use in my query.
Action: sql_db_schema
Action Input: sales_data
Observation: 
CREATE TABLE `sales_data` (
   `ctm_id` STRING, 
   `ctm_name` STRING, 
   `ctm_gender` STRING, 
   `ctm_birthdate` DATE, 
   `ctm_age` INT64, 
   `ctm_postcode` STRING, 
   `ctm_prefecture` STRING, 
   `ctm_prefecture_code` INT64, 
   `ctm_phone` STRING, 
   `ctm_email` STRING, 
   `ctm_ip_address` STRING, 
   `sls_date` DATE, 
   `sls_product` STRING, 
   `sls_amount` INT64
)

/*
3 rows from sales_data table:
ctm_id  ctm_name    ctm_gender  ctm_birthdate   ctm_age ctm_postcode    ctm_prefecture  ctm_prefecture_code ctm_phone   ctm_email   ctm_ip_address  sls_date    sls_product sls_amount
C1000000044 矢部 剛  男 1996-07-07  25 134-5508    北海道   1  080-5565-3368   GCK5pwtZy@test.jp 168.52.65.10    2021-02-28  Chromebook  3631
C1000000044 矢部 剛  男 1996-07-07  25 134-5508    北海道   1  080-5565-3368   GCK5pwtZy@test.jp 168.52.65.10    2020-02-28  Chromebook  3631
C1000000044 矢部 剛  男 1996-07-07  25 134-5508    北海道   1  080-5565-3368   GCK5pwtZy@test.jp 168.52.65.10    2022-02-28  Chromebook  3631
*/
Thought:The table sales_data has the columns sls_product and sls_amount. I can use these columns to query the product and amount of sales.
Action: sql_db_query
Action Input: SELECT sls_product, SUM(sls_amount) FROM sales_data GROUP BY sls_product ORDER BY SUM(sls_amount) DESC LIMIT 10
Observation: [('Chromebook', 19160341), ('Mac', 11702719), ('マウス', 8605347), ('iPhone', 8178375), ('Windows', 5620327), ('HDD', 5419557), ('Android', 4475812), ('キーボード', 4391927), ('SSD', 3995346), ('充電器', 2162517)]
Thought:I now know the final answer
Final Answer: The top 10 products by sales are: Chromebook, Mac, マウス, iPhone, Windows, HDD, Android, キーボード, SSD, 充電器.

> Finished chain.
The top 10 products by sales are: Chromebook, Mac, マウス, iPhone, Windows, HDD, Android, キーボード, SSD, 充電器.

処理の流れを一部見てみます。

  • [3 行目] Action : BigQuery データセットのテーブル一覧を取得
  • [5 行目] Observation : sales_data テーブルを検出
  • [6 行目] Action : sales_data テーブルのスキーマを取得して列を確認する
  • [33 行目] Thought : sales_data テーブルには、sls_product 列と sls_amount 列があり、これらの列を使用してプロダクトと売上高をクエリできます。
  • [35 行目] Action Input : SELECT sls_product, SUM(sls_amount) FROM sales_data GROUP BY sls_product ORDER BY SUM(sls_amount) DESC LIMIT 10

上記のように、ユーザーの入力に対して、「次にどのようなアクションを取るか」とエージェントが考えて実行しています。

そして最終的なアプトプットが Finished chain として 40 行目に出力されてます。なお、統計結果は正しいものでした。

せっかくなので、別のクエリも試してみます。

# input:[5]
run_query("2020年9月1日から2020年9月30日までの売上合計を教えてください。")
# output:[5]
> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input: 
Observation: sales_data
Thought:I found the table sales_data. I should query the schema of this table to see what columns I can use in my query.
Action: sql_db_schema
Action Input: sales_data
Observation: 
CREATE TABLE `sales_data` (
   `ctm_id` STRING, 
   `ctm_name` STRING, 
   `ctm_gender` STRING, 
   `ctm_birthdate` DATE, 
   `ctm_age` INT64, 
   `ctm_postcode` STRING, 
   `ctm_prefecture` STRING, 
   `ctm_prefecture_code` INT64, 
   `ctm_phone` STRING, 
   `ctm_email` STRING, 
   `ctm_ip_address` STRING, 
   `sls_date` DATE, 
   `sls_product` STRING, 
   `sls_amount` INT64
)

/*
3 rows from sales_data table:
ctm_id  ctm_name    ctm_gender  ctm_birthdate   ctm_age ctm_postcode    ctm_prefecture  ctm_prefecture_code ctm_phone   ctm_email   ctm_ip_address  sls_date    sls_product sls_amount
C1000000044 矢部 剛  男 1996-07-07  25 134-5508    北海道   1  080-5565-3368   GCK5pwtZy@test.jp 168.52.65.10    2021-02-28  Chromebook  3631
C1000000044 矢部 剛  男 1996-07-07  25 134-5508    北海道   1  080-5565-3368   GCK5pwtZy@test.jp 168.52.65.10    2020-02-28  Chromebook  3631
C1000000044 矢部 剛  男 1996-07-07  25 134-5508    北海道   1  080-5565-3368   GCK5pwtZy@test.jp 168.52.65.10    2022-02-28  Chromebook  3631
*/
Thought:The table sales_data has the columns sls_date and sls_amount. I can use these columns to query the total sales for the period 2020-09-01 to 2020-09-30.
Action: sql_db_query
Action Input: SELECT SUM(sls_amount) FROM sales_data WHERE sls_date BETWEEN '2020-09-01' AND '2020-09-30'
Observation: [(997305,)]
Thought:I now know the final answer
Final Answer: The total sales for the period 2020-09-01 to 2020-09-30 is 997305.

> Finished chain.
The total sales for the period 2020-09-01 to 2020-09-30 is 997305.

こちらも最終的なアプトプットは正しいものでした。

このような LangChain を用いることで、自然言語によるインプットから LLM が自社の構造化データから統計的な分析結果を返してくれるグラウンディング手法も容易に実装ができますね。

又吉 佑樹(記事一覧)

クラウドソリューション部

はいさい、沖縄出身のクラウドエンジニア!

セールスからエンジニアへ転身。Google Cloud 全 11 資格保有。Google Cloud Champion Innovator (AI/ML)。Google Cloud Partner Top Engineer 2024。Google Cloud 公式ユーザー会 Jagu'e'r でエバンジェリスト。好きな分野は生成 AI。