-
Notifications
You must be signed in to change notification settings - Fork 461
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: initializing qa on sql databases with langchain notebook
- Loading branch information
Showing
1 changed file
with
382 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,382 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ef942509-6981-40b7-894d-4f6c33607de2", | ||
"metadata": { | ||
"papermill": {}, | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"<img width=\"10%\" alt=\"Naas\" src=\"https://landen.imgix.net/jtci2pxwjczr/assets/5ice39g4.png?w=160\"/>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ca4c2cd5-6acc-49c1-9336-72eb54c5ec83", | ||
"metadata": { | ||
"papermill": {}, | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"# Langchain - QA on SQL Database" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "df7190c7-9f84-4ef9-afed-cbeec9189bdf", | ||
"metadata": { | ||
"papermill": {}, | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"**Tags**: #langchain #toolkit #database #qa #python #sql" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "cda2f1ea-c2df-45c5-9329-d832077dadfb", | ||
"metadata": { | ||
"papermill": {}, | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"**Author:** [Léa Buendé](https://www.linkedin.com/in/leabuende)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "6b607685-d026-4f09-8724-013227fcf2c5", | ||
"metadata": { | ||
"papermill": {}, | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"**Last update:** 2023-10-19 (Created: 2023-10-08) " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "09056bc2-135e-46e8-a422-4ef3dfabf41e", | ||
"metadata": { | ||
"papermill": {}, | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"**Description:** This notebook demonstrates how to perform queries on a SQL Database, as well as how to generate SQL queries from questions or text." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "a250dcaf-66c0-40a0-a7f5-c2ff8d50a4c6", | ||
"metadata": { | ||
"papermill": {}, | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"**References:** \n", | ||
"- [Langchain - SQL](https://python.langchain.com/docs/use_cases/qa_structured/sql)\n", | ||
"- [Langsmith - Text-to-SQL](https://smith.langchain.com/hub/rlm/text-to-sql)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "70cb3b03-1af8-47b0-a299-2c798e4e0452", | ||
"metadata": { | ||
"papermill": {}, | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"## Input" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "90208573-1f14-4e2b-8900-e304035a2570", | ||
"metadata": { | ||
"papermill": {}, | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"### Import libraries\n", | ||
"Libraries to be used within this notebook.<br>\n", | ||
"\n", | ||
"Note: You may need to restart the kernel to use updated packages." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "de70f076-2e23-4d17-9e65-d7812366aad5", | ||
"metadata": { | ||
"papermill": {}, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"try:\n", | ||
" import langchain\n", | ||
" import langchain_experimental\n", | ||
" import openai\n", | ||
"except: \n", | ||
" !pip install langchain langchain-experimental openai --user\n", | ||
"\n", | ||
"from langchain.utilities import SQLDatabase\n", | ||
"from langchain.llms import OpenAI\n", | ||
"import naas" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "0e4b4668-1cfa-4aa8-b453-2c9f863b6f29", | ||
"metadata": { | ||
"papermill": {}, | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"### Setup variables\n", | ||
"\n", | ||
"- `openai_api_key`: [Get your API key here](https://openai.com/docs/api-overview/).\n", | ||
"- `temp`: Default value is 0 but preferred to have 0.7. You can change this value according to your requirements\n", | ||
"- `database_uri`: URI of your database. We will create one for the purpose of this notebook." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 31, | ||
"id": "54158c1d-3ddf-45ed-8989-9f1fd2240b6f", | ||
"metadata": { | ||
"papermill": {}, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"openai_api_key = naas.secret.get(\"OPENAI_API_KEY\") or \"Paste-your-key-here\"\n", | ||
"temp = 0\n", | ||
"sql_uri = \"sqlite:///Chinook.db\" # Replace with the URI of your database" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "2f7ab5dc", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Creating a mock database (Optional)\n", | ||
"The Chinook database is a sample database for you to run quick tests, such as trying out this notebook. It represents a digital media store, including tables for artists, albums, media tracks, invoices and customers.\n", | ||
"\n", | ||
"To install the Chinook database :\n", | ||
"- Download the [Chinook.sql file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) in the same directory as this notebook, and save it as Chinook_Sqlite.sql\n", | ||
"- In your terminal, run ```sqlite3 Chinook.db``` to connect to SQLite\n", | ||
"- Then, run ```.read Chinook_Sqlite.sql``` to execute the script from the file\n", | ||
"\n", | ||
"Your database is now initialized." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3ba70fc1", | ||
"metadata": {}, | ||
"source": [ | ||
"## *Text to SQL query*" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "9b1db6dd", | ||
"metadata": {}, | ||
"source": [ | ||
"## Model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 33, | ||
"id": "86ba23e8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain.chat_models import ChatOpenAI\n", | ||
"from langchain.chains import create_sql_query_chain" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "4dbca1bc", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Build an SQL query using Create_sql_query_chain" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f66d1a9b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"db = SQLDatabase.from_uri(sql_uri) \n", | ||
"chain = create_sql_query_chain(ChatOpenAI(temperature=temp, openai_api_key=openai_api_key), db)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "4443a2ae", | ||
"metadata": {}, | ||
"source": [ | ||
"## Output" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e1eecee9", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Display result" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c3e94abb", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"response = chain.invoke({\"question\":\"How many songs are in the Grunge playlist ?\"})\n", | ||
"print(response)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ebfef2e5", | ||
"metadata": {}, | ||
"source": [ | ||
"You can test out the SQL query by running :" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "cfb1d20e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"db.run(response)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "a4923704", | ||
"metadata": {}, | ||
"source": [ | ||
"## *Text to SQL query and execution*" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "4abf9e57", | ||
"metadata": {}, | ||
"source": [ | ||
"## Model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 39, | ||
"id": "df536f67", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_experimental.sql import SQLDatabaseChain" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3caef398", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Build and run an SQL query using SQLDatabaseChain " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 40, | ||
"id": "9cc14e45", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"db = SQLDatabase.from_uri(sql_uri) \n", | ||
"llm = OpenAI(temperature=temp, verbose=True, openai_api_key=openai_api_key)\n", | ||
"db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "20a4e159", | ||
"metadata": {}, | ||
"source": [ | ||
"## Output" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "9ea4deda", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Display result" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "eb795f63", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"db_chain.run(\"Who is the artist with the most albums ?\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3.11.3 64-bit", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.3" | ||
}, | ||
"naas": { | ||
"notebook_id": "658a8244b01dfe532c4faaf8423ce808aca1020f8afdf491d99810b8e1cb4ba1", | ||
"notebook_path": "LangChain/LangChain_CSV_Agent.ipynb" | ||
}, | ||
"papermill": { | ||
"default_parameters": {}, | ||
"environment_variables": {}, | ||
"parameters": {}, | ||
"version": "2.4.0" | ||
}, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" | ||
} | ||
}, | ||
"widgets": { | ||
"application/vnd.jupyter.widget-state+json": { | ||
"state": {}, | ||
"version_major": 2, | ||
"version_minor": 0 | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |