SQL agents with LangGraph 🦜🕸️
Creating accurate SQL queries with LLMs becomes challenging as query complexity increases. Simple prompts suffice for basic SQL, but complex joins and logic require detailed prompts, iterative feedback, and error handling. This post explores building an agentic SQL generation workflow using LangGraph, a framework in the LangChain ecosystem designed for creating stateful, multi-node graphs. It explains how to set up the graph with nodes, edges, and state management, integrate error propagation without breaking flow, and optimize prompt engineering to improve SQL generation accuracy. Experiments using the Sakila database show how richer prompts—adding schema details and few-shot examples—significantly improve query quality. For consistently correct SQL, especially with complex joins, introducing SQL views is recommended.
SQL query generation
Generating SQL commands with an LLM is highly sensitive to the details given in the prompt. A plain LLM call is enough for simple SQL commands, but more complex queries need more informative prompts and a trial-and-error pipeline that feeds mistakes back so the LLM can correct them step by step and generate SQL accurately. One way to build this is an agentic workflow with LangGraph, which achieves the same thing by adopting a graph approach. LangGraph is part of the LangChain ecosystem and focuses on building agents as directed graphs rather than chains.
LangGraph is a library for building stateful, multi-actor applications with LLMs, used to create agent and multi-agent workflows. Compared to other LLM frameworks, it offers these core benefits: cycles, controllability, and persistence.
To build an SQL agent on this platform, you need a state, nodes, and edges. Nodes are the tools (i.e. functions), edges are the logical routes that define how the graph moves between nodes and decides when to stop, and the state stores a snapshot of the graph at any point. Graphs can be either state-based or message-based; we use a state graph to build the SQL agent, with either a Python TypedDict or a Pydantic BaseModel as the data structure that holds the graph state.
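The key idea behind a state graph is that each node returns only a partial update, and annotated keys are merged with a reducer (like LangGraph's `add_messages`) instead of being overwritten. A stdlib-only sketch of that merging idea, with plain dicts standing in for LangGraph's `State` (the function names here are illustrative, not LangGraph's API):

```python
def add_messages(existing, new):
    """Reducer: append new messages instead of overwriting the list."""
    return existing + new

def apply_update(state, update, reducers):
    """Merge a node's partial update into the state, using a reducer for
    annotated keys and plain replacement for everything else."""
    merged = dict(state)
    for key, value in update.items():
        if key in reducers:
            merged[key] = reducers[key](merged.get(key, []), value)
        else:
            merged[key] = value
    return merged

state = {"messages": [{"role": "user", "content": "top 3 PG films?"}], "answer": []}
update = {"messages": [{"role": "ai", "content": "SELECT ..."}]}  # a node's return value
state = apply_update(state, update, reducers={"messages": add_messages})
# the user question and the AI reply now coexist in state["messages"]
```

This is why each node below can simply `return {'messages': [message]}` without losing the conversation history.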
One of the main components in this workflow is the SQL query generation node: we give it the question (what we want), and the LLM generates a SQL query. We can do that like this,
```python
def query_gen_node(state: State):
    llm = init_llm()
    query_gen_chain = query_generation | llm.bind_tools([db_query_tool], tool_choice='db_query_tool')
    message = query_gen_chain.invoke(state)
    return {
        'messages': [message]
    }
```
Because we use a state graph approach, all our nodes receive the state as the first argument (you can take a config as the second argument if you need it). So here our state looks like this,
```python
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import AnyMessage, add_messages

class State(TypedDict):
    answer: list[dict]
    messages: Annotated[list[AnyMessage], add_messages]
```
The init_llm function uses a LangChain chat model to initialize an LLM, to which we bind db_query_tool, another Python function that executes the generated query so we can check whether it is error-free. init_llm looks something like this,
```python
from langchain_ollama import ChatOllama

def init_llm():
    # here you can use any chat model supported by
    # LangChain, e.g. ChatOpenAI, ...
    return ChatOllama(model='llama3.1:latest')
```
db_query_tool is a tool by definition. There are a few ways to define a tool in LangGraph; the easiest is a Python function with an informative docstring, like this:
```python
import pprint

from utils.db_tools import get_engine
from langchain_community.utilities import SQLDatabase

def sql_database():
    return SQLDatabase(get_engine())

def db_query_tool(query: str) -> str:
    """
    Execute a SQL query.
    """
    database = sql_database()
    # usable tables from the connection
    print("*-" * 20)
    print("Usable tables from connection: ")
    print("*-" * 20)
    pprint.pprint(database.get_usable_table_names())
    print("*-" * 20)
    results = database.run_no_throw(query, include_columns=True)
    # handle query errors
    if results == "":
        return "Empty: There are no results for this query. Return Empty as the answer."
    if not results:
        return "Error: There was an error in the query. Please check the query and try again."
    return results
```
Alternatively, you can use the @tool decorator to convert the Python function into a LangGraph tool, or use the class-based approach.
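The reason a plain function with a docstring is enough is that Python already carries everything a tool schema needs: the name, the type hints, and the description. A stdlib-only sketch of that introspection (a simplified stand-in for what the decorator does, not LangChain's actual implementation; `tool_spec` is a hypothetical helper):

```python
import inspect
from typing import get_type_hints

def db_query_tool(query: str) -> str:
    """Execute a SQL query against the database and return the results."""
    return f"ran: {query}"

def tool_spec(fn) -> dict:
    """Derive a minimal tool schema from a function's signature and docstring."""
    hints = get_type_hints(fn)
    hints.pop("return", None)  # only the parameters go into the schema
    return {
        "name": fn.__name__,
        "description": inspect.getdoc(fn),
        "parameters": {name: hint.__name__ for name, hint in hints.items()},
    }

spec = tool_spec(db_query_tool)
# spec == {"name": "db_query_tool",
#          "description": "Execute a SQL query against the database and return the results.",
#          "parameters": {"query": "str"}}
```

This is why an informative docstring matters: it is the only description of the tool the LLM ever sees.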
We also need a node to actually run the SQL query and capture its output, which is either an error message or, if the query runs successfully, the retrieved data. Even though we bound db_query_tool to the LLM earlier, the LLM cannot execute the tool itself; that is why we create a tool node. The LLM call generates the query and sends it to the tool node by calling the bound tool. Before creating the tool node, we need a mechanism to propagate errors through the workflow without raising exceptions that break the flow. This gives the workflow a feedback loop, so SQL queries are generated by trial and error.
Error propagation
LangGraph has ways to achieve this without much fuss, like the below:
```python
from typing import Any
from langchain_core.messages import ToolMessage
from langgraph.prebuilt import ToolNode
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks

def handle_tool_error(state) -> dict:
    error = state.get('error')
    tool_calls = state['messages'][-1].tool_calls
    return {
        'messages': [
            ToolMessage(
                content=f"Error: {repr(error)}\n, fix your mistakes.",
                tool_call_id=tc['id']
            )
            for tc in tool_calls
        ]
    }
```
This is how we let the LLM know that it made a mistake in SQL generation and that executing the query raised an error. In db_query_tool we execute SQL queries through the SQLDatabase wrapper class with results = database.run_no_throw(query, include_columns=True), so failures do not raise exceptions that break execution; instead, errors are exported as messages into the state, where they can be passed back to the LLM as feedback.
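The "return the error instead of raising it" pattern can be reproduced with plain `sqlite3`, which makes the feedback loop easier to see. A minimal sketch under stated assumptions (an in-memory SQLite table stands in for the real database; `run_no_throw` here is a local function illustrating the idea behind LangChain's method of the same name, not its implementation):

```python
import sqlite3

def run_no_throw(conn: sqlite3.Connection, query: str) -> str:
    """Execute a query; on failure, return the error as text instead of
    raising, so it can flow back to the LLM as feedback."""
    try:
        rows = conn.execute(query).fetchall()
    except sqlite3.Error as exc:
        return f"Error: {exc}"
    return repr(rows)

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE film (title TEXT, rating TEXT)")
conn.execute("INSERT INTO film VALUES ('ALI FOREVER', 'PG')")

ok = run_no_throw(conn, "SELECT title FROM film WHERE rating = 'PG'")
bad = run_no_throw(conn, "SELECT nope FROM film")  # references a missing column
# ok is a result string; bad starts with "Error:" and can be fed back to the LLM
```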
Then we create our tool node, which executes SQL queries and handles errors as feedback to improve the LLM's generations,
```python
def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """
    Creates a ToolNode with a fallback mechanism to gracefully handle errors encountered
    during tool execution and surface them to the agent for appropriate handling.
    """
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key='error'
    )
```
The last node we need prepares the results we got from executing the LLM-generated SQL. Because I want the database results as JSON objects, I decided to use an output parser from LangChain.
```python
from langchain_core.messages import HumanMessage
from agent.states import State
from langchain_core.output_parsers import JsonOutputParser

def final_answer_node(state: State):
    last_message = state["messages"][-1]
    try:
        # the tool output is a Python repr, so nudge it towards valid JSON
        last_message = HumanMessage(content=last_message.content
                                    .replace("'", "\"")
                                    .replace("None", "\"no data\""))
        json_parser = JsonOutputParser()
        answer = json_parser.invoke(last_message)
    except Exception:
        # fall back to the raw content if parsing fails
        answer = last_message.content
    return {
        'messages': state['messages'],
        'answer': answer
    }
```
Those are all the nodes I need for this SQL agent. If you need extra steps, you can add more nodes to the graph.
Feedback loop
In LangGraph there are a few types of edges: normal, conditional, entry, and conditional entry. Each has a different purpose; for this agent, we only need normal and conditional edges. Conditional edges are routers that decide which node runs next based on the logic implemented. The logic for our SQL generation is: if the LLM generates error-free SQL, execute it and fetch the data; if not, go back to the LLM with feedback that says 'the query you created is wrong, and this is what I got; create a new one based on this feedback'. To do this we use add_conditional_edges.
```python
workflow.add_conditional_edges("execute_query", should_continue)
```
The routing is done by a Python function that implements the logic above. It looks something like this,
```python
from typing import Literal

def should_continue(state: State) -> Literal["final_answer", "query_gen"]:
    messages = state["messages"]
    last_message = messages[-1]
    if last_message.content.startswith("Error:"):
        print('error in query, go into query gen node')
        return "query_gen"
    else:
        print('final answer was gathered, go into final answer node')
        return "final_answer"
```
Complete graph
This is what it looks like when we put all those parts together to build our agent. START and END are special nodes that mark the entry and exit points of the graph.
```python
from langgraph.graph import StateGraph, START, END
from langgraph.graph.graph import CompiledGraph
from states import State
from tools.database_tools import db_query_tool
from tools.tool_nodes import query_gen_node, should_continue, final_answer_node
from tools.error_handling import create_tool_node_with_fallback

def run() -> CompiledGraph:
    workflow = StateGraph(State)
    workflow.add_node("query_gen", query_gen_node)
    workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))
    workflow.add_node('final_answer', final_answer_node)
    workflow.add_edge(START, "query_gen")
    workflow.add_edge("query_gen", "execute_query")
    workflow.add_conditional_edges("execute_query", should_continue)
    workflow.add_edge("final_answer", END)
    # Compile the workflow into a runnable
    app = workflow.compile()
    return app
```
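The control flow this graph encodes can be traced with a few plain functions: run the entry node, follow fixed edges, and let the conditional edge pick the next node. A toy runner under stated assumptions (stubbed nodes instead of a real LLM and database; this sketches the routing semantics, not LangGraph's engine):

```python
END = "__end__"

def run_graph(nodes, edges, conditional, entry, state):
    """Walk the graph: `edges` maps node -> next node; `conditional` maps a
    node to a router that inspects the state and returns the next node."""
    current = entry
    while current != END:
        state = nodes[current](state)
        current = conditional[current](state) if current in conditional else edges[current]
    return state

# the first execution attempt fails, the second succeeds
attempts = iter(["Error: bad column", "[('TELEGRAPH VOYAGE', 1631.25)]"])

def query_gen(state):     return {**state, "query": "SELECT ..."}
def execute_query(state): return {**state, "result": next(attempts)}
def final_answer(state):  return {**state, "answer": state["result"]}
def should_continue(state):
    return "query_gen" if state["result"].startswith("Error:") else "final_answer"

state = run_graph(
    nodes={"query_gen": query_gen, "execute_query": execute_query, "final_answer": final_answer},
    edges={"query_gen": "execute_query", "final_answer": END},
    conditional={"execute_query": should_continue},
    entry="query_gen",
    state={},
)
# the error loops back through query_gen once, then the answer is captured
```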
Prompt engineering
As always, the prompt matters more than most people realize for getting better outputs from LLMs. Two things we must do are give the LLM the table schemas we expect it to use, and annotate the columns as clearly as possible: primary keys, foreign keys, de-duplication rules, and most importantly the allowed values of any categorical columns being used. Otherwise, the LLM will hallucinate when user queries refer to those categorical columns.
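Those categorical values do not have to be written by hand; they can be read from the database when the prompt is built. A sketch using an in-memory SQLite table as a stand-in for Sakila's film table (`categorical_description` is a hypothetical helper; table and column names come from trusted schema config, never from user input):

```python
import sqlite3

def categorical_description(conn, table: str, column: str) -> str:
    """Collect a column's distinct values into a prompt-ready line so the
    LLM never has to guess valid category labels."""
    rows = conn.execute(f"SELECT DISTINCT {column} FROM {table} ORDER BY {column}").fetchall()
    values = ", ".join(row[0] for row in rows)
    return f"{column} column only can be one of, [{values}]"

# in-memory stand-in for the Sakila film table
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE film (title TEXT, rating TEXT)")
conn.executemany("INSERT INTO film VALUES (?, ?)",
                 [("ALI FOREVER", "PG"), ("TELEGRAPH VOYAGE", "PG-13"), ("TITANS JERK", "G")])

desc = categorical_description(conn, "film", "rating")
# desc == "rating column only can be one of, [G, PG, PG-13]"
```

The resulting line can be slotted straight into the prompt's CATEGORICAL_VALUES placeholder used below.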
Let's run some prompt engineering experiments to see how the SQL generation improves with each iteration of prompt optimization.
- database: Sakila
- tables: payment, rental, inventory, film
- user question: what are the top 3 in terms of revenue, PG rating films?
Exp 01: initial prompt
```python
from langchain.prompts import ChatPromptTemplate
from utils.db_tools import get_engine
from tools.table_schemas import get_table_schemas

query_generation = """
You are a SQL query expert capable of analyzing questions and generating error-free SQL queries to answer them.

## Main Tasks
1. Thoroughly understand the question before generating SQL queries.
2. Decide query type: Aggregate (using aggregation functions) or Normal query.
3. Write SQL queries to answer the question precisely.
4. Ensure the SQL queries are free of errors.

## Resources
- Schema:
{SCHEMA}
- Categorical values:
{CATEGORICAL_VALUES}

### Step 1: Question Analysis
- Identify relevant columns from Schema.
- Recognize syntactically similar column names or values.
- Determine if numerical results are required.
- Identify necessary filters.
- Assess need for aggregation functions.
- Determine if joins or sub-queries are needed.

### Step 2: SQL Query Writing
Requirements to write correct SQL queries:
1. Select relevant columns from Step 1.
2. Use aliases, especially for aggregation functions.
3. Use sub-queries/CTEs and joins for complex queries.
4. Apply WHERE clause for filtering if needed.
5. Filter NULL values.
6. Avoid ORDER BY unless required.
7. Limit to 10 records (if user wants to limit output results).
8. Ignore limit for aggregation functions.
9. Use LIKE with wildcards (%, _) instead of direct string comparison.
10. Use IN operator for multiple category comparisons.
11. Include grouping columns in SELECT clause.
12. Cast date columns to TEXT for better readability.
13. Avoid DML statements.
14. Sanitize input to prevent SQL injections.

### Step 3: Error Checking
- Check for SQL syntax errors.
- Check for logical errors.
- Verify query matches Step 1 actions.
- Correct any SQL syntax errors.

## Response Guidelines
**YOU MUST CALL THE CORRECT TOOL.**
<Tools>
    <Tool>db_query_tool</Tool>
</Tools>
"""

schemas = get_table_schemas(['payment', 'rental', 'inventory', 'film'], get_engine())
table_schema = ''
for table, schema in schemas.items():
    table_schema += f'Table: {table}\n'
    table_schema += schema
    table_schema += '\n\n'

categorical_desc = "rating column only can be one of, [PG, G, NC-17, PG-13, R]"

query_generation = ChatPromptTemplate.from_messages([
    ('system', query_generation.format(SCHEMA=table_schema,
                                       CATEGORICAL_VALUES=categorical_desc)),
    ('placeholder', '{messages}')
])
```
Wrong query:

```sql
SELECT title, rental_rate FROM film WHERE rating = 'PG' ORDER BY rental_rate DESC LIMIT 3;
```
```python
{
    'answer': '[{"title": "BEHAVIOR RUNAWAY", "rental_rate": Decimal("4.99")}, {"title": "BIRCH ANTITRUST", "rental_rate": Decimal("4.99")}, {"title": "ALI FOREVER", "rental_rate": Decimal("4.99")}]',
    'messages': [
        HumanMessage(content='what are the top 3 in-term of revenue in PG rating films?',
                     id='ce9697be-50d0-417f-a168-c1652ae4c35a'),
        AIMessage(content='',
                  additional_kwargs={'tool_calls': [{'id': 'call_p6ts', 'function': {
                      'arguments': '{"query":"SELECT title, rental_rate FROM film WHERE rating = \'PG\' ORDER BY rental_rate DESC LIMIT 3;"}',
                      'name': 'db_query_tool'}, 'type': 'function'}]},
                  response_metadata={'token_usage': {'completion_tokens': 55, 'prompt_tokens': 1584, 'total_tokens': 1639,
                                                     'completion_time': 0.180338999, 'prompt_time': 0.131903739,
                                                     'queue_time': 0.023737848000000006, 'total_time': 0.312242738},
                                     'model_name': 'llama3-70b-8192', 'system_fingerprint': 'fp_87cbfbbc4d',
                                     'finish_reason': 'tool_calls', 'logprobs': None},
                  id='run-7cbfb8bc-637d-4eb0-9379-454b89f5a829-0',
                  tool_calls=[{'name': 'db_query_tool',
                               'args': {'query': "SELECT title, rental_rate FROM film WHERE rating = 'PG' ORDER BY rental_rate DESC LIMIT 3;"},
                               'id': 'call_p6ts',
                               'type': 'tool_call'}]),
        ToolMessage(content="[{'title': 'BEHAVIOR RUNAWAY', 'rental_rate': Decimal('4.99')}, {'title': 'BIRCH ANTITRUST', 'rental_rate': Decimal('4.99')}, {'title': 'ALI FOREVER', 'rental_rate': Decimal('4.99')}]",
                    name='db_query_tool', id='683f5266-5b78-4bfc-b5cd-96cbdd6b3b8c', tool_call_id='call_p6ts')
    ]
}
```
Exp 02: initial prompt + column descriptions
Let's give more details about the columns in our tables,
""" | |
film table column descriptions: | |
film_id: Unique identifier for each film | |
title: Name of the film | |
description: Brief synopsis of the film's plot | |
release_year: Year the film was released | |
language_id: ID of the film's primary language | |
original_language_id: ID of the film's original language (if different) | |
rental_duration: Standard rental period in days | |
rental_rate: Cost to rent the film | |
length: Duration of the film in minutes | |
replacement_cost: Fee charged if the film is not returned or damaged | |
rating: Film's rating (e.g., G, PG, R) | |
special_features: Additional content or formats available | |
last_update: Timestamp of the most recent update to the record | |
inventory table column descriptions: | |
inventory_id: Unique identifier for each inventory item | |
film_id: Foreign key referencing the film table, indicating which film this inventory item is | |
store_id: Identifier of the store where this inventory item is located | |
last_update: Timestamp of the most recent update to this inventory record | |
payment table column descriptions: | |
payment_id: Unique identifier for each payment transaction | |
customer_id: Foreign key referencing the customer who made the payment | |
staff_id: Foreign key referencing the staff member who processed the payment | |
rental_id: Foreign key referencing the associated rental transaction | |
amount: The payment amount | |
payment_date: Date and time when the payment was made | |
last_update: Timestamp of the most recent update to this payment record | |
rental table column descriptions: | |
rental_id: Unique identifier for each rental transaction | |
rental_date: Date and time when the rental was made | |
inventory_id: Foreign key referencing the specific inventory item rented | |
customer_id: Foreign key referencing the customer who rented the item | |
return_date: Date and time when the item was returned (null if not yet returned) | |
staff_id: Foreign key referencing the staff member who processed the rental | |
last_update: Timestamp of the most recent update to this rental record | |
""" |
Let's add this to the prompt and see how it improves things.

Wrong (but close) query:

```sql
SELECT f.title, SUM(p.amount) as revenue FROM film f JOIN inventory i ON f.film_id = i.film_id JOIN rental r ON i.inventory_id = r.inventory_id JOIN payment p ON r.rental_id = p.rental_id WHERE f.rating = 'PG' GROUP BY f.title ORDER BY revenue DESC LIMIT 3;
```
```python
{
    'answer': '[{"title": "TELEGRAPH VOYAGE", "revenue": Decimal("231.73")}, {"title": "GOODFELLAS SALUTE", "revenue": Decimal("209.69")}, {"title": "TITANS JERK", "revenue": Decimal("201.71")}]',
    'messages': [
        HumanMessage(content='what are the top 3 in-term of revenue in PG rating films?',
                     id='724a1d3c-7915-4cd3-85c0-7ba2a6736f12'),
        AIMessage(content='',
                  additional_kwargs={'tool_calls': [{'id': 'call_fvz7', 'function': {
                      'arguments': '{"query":"SELECT f.title, SUM(p.amount) as revenue FROM film f JOIN inventory i ON f.film_id = i.film_id JOIN rental r ON i.inventory_id = r.inventory_id JOIN payment p ON r.rental_id = p.rental_id WHERE f.rating=\'PG\' GROUP BY f.title ORDER BY revenue DESC LIMIT 3"}',
                      'name': 'db_query_tool'}, 'type': 'function'}]},
                  response_metadata={'token_usage': {'completion_tokens': 135, 'prompt_tokens': 2007, 'total_tokens': 2142,
                                                     'completion_time': 0.43354108, 'prompt_time': 0.168751257,
                                                     'queue_time': 0.004758699000000005, 'total_time': 0.602292337},
                                     'model_name': 'llama3-70b-8192', 'system_fingerprint': 'fp_c1a4bcec29',
                                     'finish_reason': 'tool_calls', 'logprobs': None},
                  id='run-d370a7e1-c1e3-44f4-8c0c-81ebb4bd2652-0',
                  tool_calls=[{'name': 'db_query_tool',
                               'args': {'query': "SELECT f.title, SUM(p.amount) as revenue FROM film f JOIN inventory i ON f.film_id = i.film_id JOIN rental r ON i.inventory_id = r.inventory_id JOIN payment p ON r.rental_id = p.rental_id WHERE f.rating='PG' GROUP BY f.title ORDER BY revenue DESC LIMIT 3"},
                               'id': 'call_fvz7',
                               'type': 'tool_call'}]),
        ToolMessage(content="[{'title': 'TELEGRAPH VOYAGE', 'revenue': Decimal('231.73')}, {'title': 'GOODFELLAS SALUTE', 'revenue': Decimal('209.69')}, {'title': 'TITANS JERK', 'revenue': Decimal('201.71')}]",
                    name='db_query_tool', id='ad92694d-ba63-47c1-a026-dcd53d4e4681', tool_call_id='call_fvz7')
    ]
}
```
Exp 03: initial prompt + column descriptions + few-shot examples
Let's give some example queries with joins and subqueries so the LLM gets an idea of how to combine these tables to get answers.
""" | |
Question: which film has the most rentals? | |
SQL Query: | |
select title, | |
count(rental_id) as rental_count | |
from film | |
join (select i.film_id, | |
r.rental_id | |
from inventory as i | |
join rental as r on i.inventory_id = r.inventory_id) as t | |
on film.film_id = t.film_id | |
group by title | |
order by rental_count desc; | |
Question: which film has the most revenue? | |
SQL Query: | |
select title, | |
sum(amount * rental_period) as revenue | |
from film | |
join (select i.film_id, | |
r.* | |
from inventory as i | |
join (select rental.inventory_id, | |
rental.rental_id, | |
amount, | |
datediff(return_date, rental_date) as rental_period | |
from rental | |
join payment on rental.rental_id = payment.rental_id) as r | |
on i.inventory_id = r.inventory_id) as t | |
on film.film_id = t.film_id | |
group by title | |
order by revenue desc | |
limit 1; | |
Question: which film has the highest rental period? | |
SQL Query: | |
select title, | |
max(datediff(return_date, rental_date)) as rental_period | |
from film | |
join (select i.film_id, | |
r.rental_id, | |
rental_date, | |
return_date | |
from inventory as i | |
join rental as r | |
on i.inventory_id = r.inventory_id) as t | |
on film.film_id = t.film_id | |
group by title; | |
""" |
Correct (but not consistently generated) query:

```sql
SELECT title, sum(amount * rental_period) as revenue FROM film JOIN (SELECT i.film_id, r.* FROM inventory AS i JOIN (SELECT rental.inventory_id, rental.rental_id, amount, DATEDIFF(return_date, rental_date) as rental_period FROM rental JOIN payment on rental.rental_id = payment.rental_id) AS r ON i.inventory_id = r.inventory_id) AS t ON film.film_id = t.film_id WHERE rating = 'PG' GROUP BY title ORDER BY revenue DESC LIMIT 3
```
```python
{
    'answer': '[{"title": "TELEGRAPH VOYAGE", "revenue": Decimal("1631.25")}, {"title": "PELICAN COMFORTS", "revenue": Decimal("1234.47")}, {"title": "TITANS JERK", "revenue": Decimal("1217.44")}]',
    'messages': [
        HumanMessage(content='what are the top 3 in-term of revenue in PG rating films?',
                     id='6a0c3175-16b2-49db-b33f-a096ffc65556'),
        AIMessage(content='',
                  additional_kwargs={'tool_calls': [{'id': 'call_8y82', 'function': {
                      'arguments': '{"query": "select title, sum(amount * datediff(return_date, rental_date)) as revenue from film join (select i.film_id, r.* from inventory as i join (select rental.inventory_id, rental.rental_id, amount, return_date, rental_date from rental join payment on rental.rental_id = payment.rental_id) as r on i.inventory_id = r.inventory_id) as t on film.film_id = t.film_id where film.rating = \'PG\' group by title order by revenue desc limit 3"}',
                      'name': 'db_query_tool'}, 'type': 'function'}]},
                  response_metadata={'token_usage': {'completion_tokens': 118, 'prompt_tokens': 1779, 'total_tokens': 1897,
                                                     'completion_time': 0.472, 'prompt_time': 0.420915458,
                                                     'queue_time': 0.00543037399999996, 'total_time': 0.892915458},
                                     'model_name': 'llama-3.1-70b-versatile', 'system_fingerprint': 'fp_9260b4bb2e',
                                     'finish_reason': 'tool_calls', 'logprobs': None},
                  id='run-621a4be6-433a-463d-bd17-f024c8427c42-0',
                  tool_calls=[{'name': 'db_query_tool',
                               'args': {'query': "select title, sum(amount * datediff(return_date, rental_date)) as revenue from film join (select i.film_id, r.* from inventory as i join (select rental.inventory_id, rental.rental_id, amount, return_date, rental_date from rental join payment on rental.rental_id = payment.rental_id) as r on i.inventory_id = r.inventory_id) as t on film.film_id = t.film_id where film.rating = 'PG' group by title order by revenue desc limit 3"},
                               'id': 'call_8y82',
                               'type': 'tool_call'}]),
        ToolMessage(content="[{'title': 'TELEGRAPH VOYAGE', 'revenue': Decimal('1631.25')}, {'title': 'PELICAN COMFORTS', 'revenue': Decimal('1234.47')}, {'title': 'TITANS JERK', 'revenue': Decimal('1217.44')}]",
                    name='db_query_tool', id='7b28c830-bd83-4148-b38d-42771ac350d4', tool_call_id='call_8y82')
    ]
}
```
With a few examples the agent did eventually generate the correct SQL, but it does not do so consistently. LLMs normally do well on simple SQL queries, but here they struggle with queries that need complex joins to get the correct answer. One way to overcome this is to use a SQL view, which avoids the need for complex joins: decide what you actually need from the database, create a view for it, and give that view to the agent to query.
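For the revenue question, that view could pre-join the four tables once, so the agent only ever writes a flat SELECT. A sketch using in-memory SQLite with a schema simplified from Sakila (table, column, and view names here are illustrative):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE film (film_id INTEGER, title TEXT, rating TEXT);
CREATE TABLE inventory (inventory_id INTEGER, film_id INTEGER);
CREATE TABLE rental (rental_id INTEGER, inventory_id INTEGER);
CREATE TABLE payment (payment_id INTEGER, rental_id INTEGER, amount REAL);

-- the view hides the three-way join the LLM kept getting wrong
CREATE VIEW film_revenue AS
SELECT f.title, f.rating, SUM(p.amount) AS revenue
FROM film f
JOIN inventory i ON f.film_id = i.film_id
JOIN rental r    ON i.inventory_id = r.inventory_id
JOIN payment p   ON r.rental_id = p.rental_id
GROUP BY f.title, f.rating;

INSERT INTO film VALUES (1, 'TELEGRAPH VOYAGE', 'PG'), (2, 'TITANS JERK', 'PG');
INSERT INTO inventory VALUES (10, 1), (11, 2);
INSERT INTO rental VALUES (100, 10), (101, 11);
INSERT INTO payment VALUES (1000, 100, 4.99), (1001, 101, 2.99);
""")

# the agent's query against the view needs no joins at all
rows = conn.execute(
    "SELECT title, revenue FROM film_revenue WHERE rating = 'PG' "
    "ORDER BY revenue DESC LIMIT 3"
).fetchall()
```

With the view's schema (and a short description of it) in the prompt, even the complex-join question reduces to the kind of simple query the agent already handles reliably.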