Build a text-to-SQL generator

Time to complete: 60 min

Text-to-SQL is an application of LLMs that converts natural language queries into SQL. People can use text-to-SQL to query their databases, even if they don’t know SQL.

To improve the accuracy of text-to-SQL, you can use a technique called dynamic few-shot prompting. This tutorial shows how to implement dynamic few-shot prompting with Astra DB Serverless.

Prerequisites

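To complete this tutorial, you’ll need the following:

  * An OpenAI API key
  * An Astra DB Serverless database, plus an application token and the database's API endpoint
  * An Amazon Redshift database, plus its host URL, a user name and password, and the database name
  * Python with the openai, tenacity, astrapy (the examples use its astrapy.db module), redshift_connector, datasets, pandas, and tqdm packages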

Initialize the clients

First, initialize the OpenAI, Astra DB, and Amazon Redshift clients.

  1. Initialize the OpenAI client with your API key.

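    # Initialize the OpenAI client
    import os
    from getpass import getpass
    from typing import List

    import openai
    import pandas as pd  # used for the DataFrame results and type hints later in the tutorial
    from tenacity import retry, stop_after_attempt, wait_exponential

    if "OPENAI_API_KEY" not in os.environ:
        os.environ["OPENAI_API_KEY"] = getpass("OpenAI API Key: ")

    client = openai.OpenAI()

    # Retry with exponential backoff if the OpenAI API rate limits the request or returns an error,
    # and give up after 6 attempts instead of retrying forever
    @retry(wait=wait_exponential(multiplier=1.2, min=4, max=30), stop=stop_after_attempt(6))
    def chat_gpt_completion(prompt: str) -> str:
        """Send a single-turn prompt to the chat model and return the text of its reply."""
        return client.chat.completions.create(
            messages=[{
                "role": "user",
                "content": prompt,
            }],
            model="gpt-3.5-turbo",
        ).choices[0].message.content

    # Despite its name, this helper uses text-embedding-3-small (1536 dimensions), not an ada model
    @retry(wait=wait_exponential(multiplier=1.2, min=4, max=30), stop=stop_after_attempt(6))
    def ada_embedding(text: str) -> List[float]:
        """Return the embedding vector for a single piece of text."""
        return client.embeddings.create(
            input=text,
            model="text-embedding-3-small",
            timeout=10,
        ).data[0].embedding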
  2. Initialize the Astra client with your application token and API endpoint.

    # Initialize the Astra client
    import os
    from getpass import getpass

    from astrapy.db import AstraDB

    # Read the Astra token and API endpoint from the environment, and prompt only if they are missing.
    # (A prompt passed as the os.getenv default would run every time, because defaults are evaluated eagerly.)
    ASTRA_TOKEN = os.getenv("ASTRA_DB_TOKEN") or getpass("Astra DB Token: ")
    ASTRA_API_ENDPOINT = os.getenv("ASTRA_DB_API_ENDPOINT") or input("Astra DB Endpoint: ")

    astra_db = AstraDB(token=ASTRA_TOKEN, api_endpoint=ASTRA_API_ENDPOINT)
    Response
    Astra DB Token:  ········
    Astra DB Endpoint:  https://3a6953df-27bf-4272-a584-31f7b4dcaa80-us-east1.apps.astra.datastax.com
  3. Initialize the Amazon Redshift connection settings with your credentials, and define a helper that runs SQL statements.

    To connect to Amazon Redshift with other authentication methods, see Connecting to Amazon Redshift.

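    import os
    from getpass import getpass

    import pandas as pd
    import redshift_connector

    # Read the connection settings from the environment, and prompt only for missing ones
    # (an os.getenv default is evaluated eagerly, so it would prompt even when the variable is set)
    RS_HOST_URL = os.getenv("REDSHIFT_HOST_URL") or getpass("RedShift Host URL: ")
    RS_USER = os.getenv("REDSHIFT_USER") or input("RedShift User: ")
    RS_PASSWORD = os.getenv("REDSHIFT_PASSWORD") or getpass("RedShift Password: ")
    RS_DB_NAME = os.getenv("REDSHIFT_DB_NAME") or input("RedShift DB Name: ")

    def rs_execute(*statements: str, return_as_df: bool = True, commit: bool = True):
        """Execute Redshift statements in order on a fresh connection and return the result of the last one.

        The result is a pandas DataFrame (or the raw rows when return_as_df=False), or None when
        the last statement produces no result set; fetch_dataframe() may also return None for an empty result.
        """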
        with redshift_connector.connect(
            host=RS_HOST_URL,
            database=RS_DB_NAME,
            user=RS_USER,
            password=RS_PASSWORD,
        ) as conn:
            with conn.cursor() as cursor:
                for statement in statements:
                    cursor.execute(statement)
    
                try:
                    if return_as_df:
                        res = cursor.fetch_dataframe()
                    else:
                        res = cursor.fetchall()
                except redshift_connector.ProgrammingError:
                    # No result set (final statement was not a SELECT)
                    res = None
    
                if commit:
                    conn.commit()
    
                return res
    Response
    RedShift Host URL:  ········
    RedShift User:  test_redshift
    RedShift Password:  ········
    RedShift DB Name:  dev
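The wait_exponential(multiplier=1.2, min=4, max=30) settings on the OpenAI helpers control how long tenacity pauses between attempts. As we understand tenacity's wait_exponential, the pause after failed attempt n is multiplier * 2 ** (n - 1) seconds, clamped to the [min, max] range. The hypothetical function exponential_wait_schedule below reproduces that formula so you can see the resulting schedule; it is an illustration, not tenacity's own code.

```python
from typing import List


def exponential_wait_schedule(
    attempts: int, multiplier: float = 1.2, minimum: float = 4, maximum: float = 30
) -> List[float]:
    """Hypothetical helper: seconds to wait after each of `attempts` failed attempts.

    Assumes the wait after failed attempt n is multiplier * 2 ** (n - 1),
    clamped to [minimum, maximum], which is how we understand tenacity's
    wait_exponential to behave.
    """
    waits = []
    for attempt_number in range(1, attempts + 1):
        raw = multiplier * 2 ** (attempt_number - 1)
        waits.append(max(minimum, min(raw, maximum)))
    return waits


# With the tutorial's settings, the first two waits are raised to the 4 s floor
# and, from the sixth failure on, the wait is capped at 30 s.
print(exponential_wait_schedule(7))
```

Raising min shortens nothing but lengthens the early waits, which helps when rate limits reset on a fixed interval; max bounds the worst-case pause.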

Load the Spider dataset

The Spider dataset is a widely used benchmark for evaluating text-to-SQL systems. It pairs natural language questions with gold SQL queries (the ideal query to generate for each question), and each pair belongs to one of several databases, identified by its db_id.

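  1. Use the Hugging Face datasets package to load the Spider validation split and the matching schema descriptions.

    from datasets import load_dataset

    # Question and gold SQL pairs from Spider's validation split
    spider = load_dataset("spider", split="validation")
    spider_df = spider.to_pandas()
    # Schema descriptions for each Spider database
    spider_schema = load_dataset("richardr1126/spider-schema", split="train")
    spider_schema_df = spider_schema.to_pandas()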
  2. Show the first few rows of the dataset.

    spider_df.head(5)
    Response
       db_id           query                                                  question                                                                   query_toks                                       query_toks_no_value                              question_toks
    0  concert_singer  SELECT count(*) FROM singer                            How many singers do we have?                                               [SELECT, count, (, *, ), FROM, singer]           [select, count, (, *, ), from, singer]           [How, many, singers, do, we, have, ?]
    1  concert_singer  SELECT count(*) FROM singer                            What is the total number of singers?                                       [SELECT, count, (, *, ), FROM, singer]           [select, count, (, *, ), from, singer]           [What, is, the, total, number, of, singers, ?]
    2  concert_singer  SELECT name , country , age FROM singer ORDER BY name  Show name, country, age for all singers ordered by name                    [SELECT, name, ,, country, ,, age, FROM, singe…  [select, name, ,, country, ,, age, from, singe…  [Show, name, ,, country, ,, age, for, all, sin…
    3  concert_singer  SELECT name , country , age FROM singer ORDER BY name  What are the names, countries, and ages for every singer ordered by name?  [SELECT, name, ,, country, ,, age, FROM, singe…  [select, name, ,, country, ,, age, from, singe…  [What, are, the, names, ,, countries, ,, and, …
    4  concert_singer  SELECT avg(age) , min(age) , max(age) FROM singer      What is the average, minimum, and maximum age of all singers?              [SELECT, avg, (, age, ), ,, min, (, age, ), ,,…  [select, avg, (, age, ), ,, min, (, age, ), ,,…  [What, is, the, average, ,, minimum, ,, and, m…

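  3. Set aside a few questions for testing. These questions come from two of the Spider databases: concert_singer, whose tables describe singers, stadiums, and the concerts the singers performed in, and pets_1, whose tables describe students and the pets they own. The code also removes every other row that shares a gold query with a test question, so the few-shot examples can never contain the exact answer.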

    # Remove some questions to test the pipeline with
    test_ndxs = [21, 28, 60, 63, 75]  # Hand-picked mix of questions that baseline prompts answer correctly and incorrectly
    ndxs_to_drop = test_ndxs.copy()
    test_df = spider_df.loc[test_ndxs]
    
    for ndx in test_ndxs:
        row = spider_df.loc[ndx]
        print(f"Question ({row['db_id']}):\n    {row['question']}")
        print(f"Gold Query SQL:\n    {row['query']}")
        print("=" * 80)
    
        # Remove any rows that have the exact same query as the answer
        # This prevents data leakage so that the dynamic few-shot approach can't see the correct query
        same_query_rows = spider_df[spider_df['query'] == row['query']]
        for same_query_ndx in same_query_rows.index:
            ndxs_to_drop.append(same_query_ndx)
    
    spider_df = spider_df.drop(ndxs_to_drop)
    Response
    Question (concert_singer):
        How many concerts occurred in 2014 or 2015?
    Gold Query SQL:
        SELECT count(*) FROM concert WHERE YEAR  =  2014 OR YEAR  =  2015
    ================================================================================
    Question (concert_singer):
        Show the stadium names without any concert.
    Gold Query SQL:
        SELECT name FROM stadium WHERE stadium_id NOT IN (SELECT stadium_id FROM concert)
    ================================================================================
    Question (pets_1):
        What are the students' first names who have both cats and dogs as pets?
    Gold Query SQL:
        SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  =  'cat' INTERSECT SELECT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  =  'dog'
    ================================================================================
    Question (pets_1):
        Find the id of students who do not have a cat pet.
    Gold Query SQL:
        SELECT stuid FROM student EXCEPT SELECT T1.stuid FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  =  'cat'
    ================================================================================
    Question (pets_1):
        Find the first name and age of students who have a pet.
    Gold Query SQL:
        SELECT DISTINCT T1.fname ,  T1.age FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid
    ================================================================================
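The hold-out step above can be packaged as a reusable function. This is a sketch that assumes your examples live in a pandas DataFrame with a query column, as spider_df does; split_holdout is a hypothetical name. It returns the test rows plus an example pool from which every row sharing a gold query with a test row has been removed, and it keeps the original index labels so they can still serve as document IDs later.

```python
from typing import Sequence, Tuple

import pandas as pd


def split_holdout(df: pd.DataFrame, test_ndxs: Sequence[int]) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Hypothetical helper: return (test_df, pool_df) with no gold-query leakage into the pool."""
    test_df = df.loc[list(test_ndxs)]
    # Flag every row whose gold SQL equals any held-out query, including the held-out rows themselves
    leaked = df["query"].isin(set(test_df["query"]))
    pool_df = df[~leaked]
    return test_df, pool_df


# Toy data: rows 0 and 1 share a gold query, so holding out row 0 also removes row 1
toy = pd.DataFrame({
    "question": ["How many singers?", "Count the singers.", "List stadium names."],
    "query": ["SELECT count(*) FROM singer", "SELECT count(*) FROM singer", "SELECT name FROM stadium"],
})
test_rows, pool = split_holdout(toy, [0])
print(pool.index.tolist())  # [2]
```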

Set up a SQL database

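To check whether a generated query is correct, you need to run it. Recreate the schemas of the two test databases in Amazon Redshift and fill them with sample data, so that both the gold queries and the generated queries can be executed and their results compared.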

If you have an existing database with usage data formatted as (question, query) pairs, you can connect to that database instead of using the Spider dataset for the few-shot examples below.
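If you go that route, one way to reuse the rest of this tutorial unchanged is to shape your logged pairs into the same db_id, question, and query columns that spider_df has. The sketch below assumes your logs are available as (db_id, question, sql) tuples; pairs_to_examples is a hypothetical helper, and its cleanup rules (trim whitespace, drop rows without SQL, drop exact duplicates) are illustrative choices rather than requirements.

```python
from typing import Iterable, Tuple

import pandas as pd


def pairs_to_examples(pairs: Iterable[Tuple[str, str, str]]) -> pd.DataFrame:
    """Hypothetical helper: turn logged (db_id, question, sql) tuples into a spider_df-like frame."""
    df = pd.DataFrame(list(pairs), columns=["db_id", "question", "query"])
    # Trim whitespace so trivially different copies collapse into one
    for col in ["db_id", "question", "query"]:
        df[col] = df[col].str.strip()
    # Drop rows without SQL and exact duplicates, then renumber to get stable integer IDs
    df = df[df["query"] != ""].drop_duplicates()
    return df.reset_index(drop=True)


examples = pairs_to_examples([
    ("sales", "Total revenue in 2023?", "SELECT sum(amount) FROM orders WHERE year = 2023"),
    ("sales", "Total revenue in 2023? ", "SELECT sum(amount) FROM orders WHERE year = 2023"),
    ("sales", "Orders without a customer", ""),
])
print(len(examples))  # 1
```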

  1. Set up the concert_singer database. Create tables that match the Spider schema, and insert sample data that the test questions can run against.

    # Define concert_singer database
    CREATE_TABLES_SQL = """
    -- Create the stadium table
    CREATE TABLE stadium (
        Stadium_ID INT PRIMARY KEY,
        Location TEXT,
        Name TEXT,
        Capacity INT,
        Highest INT,
        Lowest INT,
        Average INT
    );
    
    -- Create the singer table
    CREATE TABLE singer (
        Singer_ID INT PRIMARY KEY,
        Name TEXT,
        Country TEXT,
        Song_Name TEXT,
        Song_release_year TEXT,
        Age INT,
        Is_male BOOLEAN
    );
    
    -- Create the concert table
    CREATE TABLE concert (
        concert_ID INT PRIMARY KEY,
        concert_Name TEXT,
        Theme TEXT,
        Stadium_ID INT,
        Year TEXT,
        FOREIGN KEY (Stadium_ID) REFERENCES stadium(Stadium_ID)
    );
    
    -- Create the singer_in_concert table
    CREATE TABLE singer_in_concert (
        concert_ID INT,
        Singer_ID INT,
        PRIMARY KEY (concert_ID, Singer_ID),
        FOREIGN KEY (concert_ID) REFERENCES concert(concert_ID),
        FOREIGN KEY (Singer_ID) REFERENCES singer(Singer_ID)
    );
    """
    
    POPULATE_DATA_SQL = """
    -- Populate the stadium table
    INSERT INTO stadium (Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average) VALUES
    (1, 'New York, USA', 'Liberty Stadium', 50000, 1000, 500, 750),
    (2, 'London, UK', 'Royal Arena', 60000, 1500, 600, 900),
    (3, 'Tokyo, Japan', 'Sunshine Dome', 55000, 1200, 550, 800),
    (4, 'Sydney, Australia', 'Ocean Field', 40000, 900, 400, 650),
    (5, 'Berlin, Germany', 'Eagle Grounds', 45000, 1100, 450, 700);
    
    -- Populate the singer table
    INSERT INTO singer (Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male) VALUES
    (1, 'John Doe', 'USA', 'Freedom Song', '2018', 28, TRUE),
    (2, 'Emma Stone', 'UK', 'Rolling Hills', '2019', 25, FALSE),
    (3, 'Haruki Tanaka', 'Japan', 'Tokyo Lights', '2020', 30, TRUE),
    (4, 'Alice Johnson', 'Australia', 'Ocean Waves', '2021', 27, FALSE),
    (5, 'Max Müller', 'Germany', 'Berlin Nights', '2017', 32, TRUE);
    
    -- Populate the concert table
    INSERT INTO concert (concert_ID, concert_Name, Theme, Stadium_ID, Year) VALUES
    (1, 'Freedom Fest', 'Pop', 1, '2021'),
    (2, 'Rock Mania', 'Rock', 2, '2022'),
    (3, 'Electronic Waves', 'Electronic', 3, '2020'),
    (4, 'Jazz Evenings', 'Jazz', 3, '2019'),
    (5, 'Classical Mornings', 'Classical', 5, '2023');
    
    -- Populate the singer_in_concert table
    INSERT INTO singer_in_concert (concert_ID, Singer_ID) VALUES
    (1, 1),
    (1, 2),
    (2, 3),
    (3, 4),
    (4, 5),
    (5, 1),
    (2, 2),
    (3, 3),
    (4, 4),
    (5, 5);
    """
    
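    # Split the script into individual statements; a plain split on ";" works here
    # because none of the string literals contains a semicolon
    statements = [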
        statement.strip() for statement in (CREATE_TABLES_SQL + POPULATE_DATA_SQL).split(";")
        if len(statement.strip()) > 0
    ]
    rs_execute(*statements)
  2. Set up the pets_1 database.

    # Define pets_1 database
    CREATE_TABLES_SQL = """
    -- Create the Student table
    CREATE TABLE Student (
        StuID INT PRIMARY KEY,
        LName VARCHAR(255),
        Fname VARCHAR(255),
        Age INT,
        Sex VARCHAR(10),
        Major INT,
        Advisor INT,
        city_code VARCHAR(50)
    );
    
    -- Create the Pets table
    CREATE TABLE Pets (
        PetID INT PRIMARY KEY,
        PetType VARCHAR(255),
        pet_age INT,
        weight INT
    );
    
    -- Create the Has_Pet table
    CREATE TABLE Has_Pet (
        StuID INT,
        PetID INT,
        FOREIGN KEY (StuID) REFERENCES Student(StuID),
        FOREIGN KEY (PetID) REFERENCES Pets(PetID)
    );
    """
    
    POPULATE_DATA_SQL = """
    -- Populate the Student table
    INSERT INTO Student (StuID, LName, Fname, Age, Sex, Major, Advisor, city_code) VALUES
    (101, 'Smith', 'John', 20, 'M', 501, 301, 'NYC'),
    (102, 'Johnson', 'Emma', 22, 'F', 502, 302, 'LAX'),
    (103, 'Williams', 'Michael', 21, 'M', 503, 303, 'CHI'),
    (104, 'Brown', 'Sarah', 23, 'F', 504, 304, 'HOU'),
    (105, 'Jones', 'David', 19, 'M', 505, 305, 'PHI');
    
    -- Populate the Pets table
    INSERT INTO Pets (PetID, PetType, pet_age, weight) VALUES
    (201, 'dog', 3, 20.5),
    (202, 'cat', 5, 10.2),
    (203, 'dog', 2, 8.1),
    (204, 'parrot', 4, 0.5),
    (205, 'hamster', 1, 0.7);
    
    -- Populate the Has_Pet table
    INSERT INTO Has_Pet (StuID, PetID) VALUES
    (101, 201),
    (101, 202),
    (105, 203),
    (103, 204),
    (104, 205),
    (105, 201);
    """
    
    statements = [
        statement.strip() for statement in (CREATE_TABLES_SQL + POPULATE_DATA_SQL).split(";")
        if len(statement.strip()) > 0
    ]
    rs_execute(*statements)
  3. Define an evaluation function that accepts a list of generated queries, in the same order as the rows of test_df. For each query, the function reports whether it runs successfully against the sample data and whether it returns the same results as the corresponding gold query.

    def eval_generated_queries(queries: List[str]) -> pd.DataFrame:
        """Run each generated query against the sample data and compare it with the gold query.

        queries[i] must be the prediction for test_df.iloc[i]. Returns one report row per query.
        """
        report = []

        for ndx, pred_sql in enumerate(queries):
            row = test_df.iloc[ndx]
            gold_sql = row["query"]
            gold_results = rs_execute(gold_sql)

            try:
                pred_results = rs_execute(pred_sql)
                err = None
            except redshift_connector.ProgrammingError as e:
                # The generated SQL failed to run (syntax error, unknown column, and so on)
                pred_results = None
                err = e

            if err is not None:
                correct = False
            elif gold_results is None or pred_results is None:
                # rs_execute can return None, so guard before comparing shapes
                correct = gold_results is None and pred_results is None
            elif gold_results.shape != pred_results.shape:
                # Different number of rows or columns
                correct = False
            else:
                # Sort columns by name so the same columns line up in both results
                pred_results = pred_results.reindex(sorted(pred_results.columns), axis=1)
                gold_results = gold_results.reindex(sorted(gold_results.columns), axis=1)
                if "order by" in gold_sql.lower():
                    # Ordered comparison: every row must match position by position
                    correct = bool((gold_results.values == pred_results.values).all())
                else:
                    # Unordered comparison: each column must hold the same set of values.
                    # Columns are paired by position, so differing aliases don't raise a KeyError.
                    correct = all(
                        set(gold_results.iloc[:, i].tolist()) == set(pred_results.iloc[:, i].tolist())
                        for i in range(gold_results.shape[1])
                    )

            report.append({
                "DB ID": row["db_id"],
                "Question": row["question"],
                "Correct": correct,
                "Error": err,
            })

        return pd.DataFrame(report)
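The comparison logic inside eval_generated_queries can also be exercised without a database. The sketch below is a stricter alternative, not the function above: results_match is a hypothetical helper that compares whole rows (as a multiset when order doesn't matter, as a sequence when it does) instead of comparing each column's set of values. It pairs columns by position, so column names and aliases are ignored, but column order matters; for example, it would score SELECT name, age and SELECT age, name as different. Treat it as one possible design choice.

```python
from collections import Counter
from typing import Optional

import pandas as pd


def results_match(gold: Optional[pd.DataFrame], pred: Optional[pd.DataFrame], ordered: bool) -> bool:
    """Hypothetical helper: compare two query results row by row.

    None stands for "no rows". Columns are compared by position, so names
    and aliases are ignored but column order matters.
    """
    if gold is None or pred is None:
        return gold is None and pred is None
    if gold.shape != pred.shape:
        return False
    gold_rows = [tuple(r) for r in gold.itertuples(index=False)]
    pred_rows = [tuple(r) for r in pred.itertuples(index=False)]
    if ordered:
        return gold_rows == pred_rows  # same rows in the same order
    return Counter(gold_rows) == Counter(pred_rows)  # same rows in any order, duplicates counted


gold = pd.DataFrame({"fname": ["John", "Emma"], "age": [20, 22]})
pred = pd.DataFrame({"first_name": ["Emma", "John"], "age": [22, 20]})
print(results_match(gold, pred, ordered=False))  # True
print(results_match(gold, pred, ordered=True))   # False
```

Counting duplicates matters for queries without DISTINCT: a prediction that returns each row twice would pass a per-column set check but fails here.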

Use dynamic few-shot prompting to generate SQL from text

Zero-shot prompting and few-shot prompting with a fixed set of examples are a good starting point for a text-to-SQL implementation. However, on this tutorial's test questions, these methods answer only 3 of 5 correctly. To improve accuracy, you can use dynamic few-shot prompting with Astra DB Serverless.
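The steps below call two helpers, few_shot_prompt(db_id, question, indices=...) and generate_sql(question=..., db_id=..., prompt_fn=..., debug_prompt=...), that this page uses but does not define. The following is a minimal sketch of what they could look like, modeled on the prompt layout visible in the debug output later on this page. The names format_example, build_few_shot_prompt, and generate_sql_with are hypothetical, and schema_for stands in for whatever code renders a database's schema text (for example, from spider_schema_df).

```python
import functools
from typing import Callable, Sequence

import pandas as pd


def format_example(schema_text: str, question: str, sql: str = "") -> str:
    """Render one example in the [Q]/[SQL] layout; sql is left empty for the test question."""
    return f"Convert text to SQL:\n\n{schema_text}\n\n[Q]: {question}\n\n[SQL]: {sql}"


def build_few_shot_prompt(
    schema_for: Callable[[str], str],
    examples: pd.DataFrame,
    db_id: str,
    question: str,
    indices: Sequence[int] = (),
) -> str:
    """Hypothetical few_shot_prompt: the example rows at `indices` (index labels), then the test question."""
    parts = []
    for i in indices:
        ex = examples.loc[i]
        parts.append("Here is an example: " + format_example(schema_for(ex["db_id"]), ex["question"], ex["query"]))
    parts.append("Here is the test question to be answered: " + format_example(schema_for(db_id), question))
    return "\n\n".join(parts)


def generate_sql_with(
    complete: Callable[[str], str],
    question: str,
    db_id: str,
    prompt_fn: Callable[[str, str], str],
    debug_prompt: bool = False,
) -> str:
    """Hypothetical generate_sql: build the prompt, optionally print it, and return the model's SQL."""
    prompt = prompt_fn(db_id, question)
    if debug_prompt:
        print("Prompting ChatGPT with:\n" + prompt)
    sql = complete(prompt).strip()
    # Chat models sometimes wrap SQL in Markdown fences; strip the backticks and a lowercase "sql" tag
    if sql.startswith("```"):
        sql = sql.strip("`").removeprefix("sql").strip()
    return sql


# Toy usage: a stub schema renderer and a stub model, so no API calls are made
toy_examples = pd.DataFrame({
    "db_id": ["shop"],
    "question": ["How many orders are there?"],
    "query": ["SELECT count(*) FROM orders"],
})
few_shot_prompt = functools.partial(build_few_shot_prompt, lambda db: f"[Schema]: {db}", toy_examples)


def fake_model(prompt: str) -> str:
    return "```sql\nSELECT count(*) FROM orders WHERE year = 2024\n```"


sql = generate_sql_with(
    fake_model,
    question="How many orders were placed in 2024?",
    db_id="shop",
    prompt_fn=lambda d, q: few_shot_prompt(d, q, indices=[0]),
)
print(sql)  # SELECT count(*) FROM orders WHERE year = 2024
```

To use these with the tutorial's objects, you could bind them, for example few_shot_prompt = functools.partial(build_few_shot_prompt, schema_for, spider_df) and generate_sql = functools.partial(generate_sql_with, chat_gpt_completion), where schema_for is a function you write. Passing an empty indices gives a zero-shot prompt, and a fixed list gives a static few-shot prompt.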

Dynamic few-shot prompting is a variation on few-shot prompting that selects the examples for each question at query time instead of using a fixed set. If you generate embeddings for the stored example questions, you can run a vector search to find the questions most similar to the one being asked, and then use those questions and their gold SQL queries as the examples in the few-shot prompt.
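To make the retrieval step concrete, here is a small in-memory sketch of dynamic example selection, assuming cosine similarity: normalize the embeddings, score every stored question against the incoming one, and keep the top k. In this tutorial, Astra DB Serverless performs the search for you through vector_find; the NumPy function top_k_similar below is a hypothetical illustration that assumes all embeddings fit in memory.

```python
from typing import List, Sequence

import numpy as np


def top_k_similar(query_emb: Sequence[float], example_embs: np.ndarray, k: int = 2) -> List[int]:
    """Hypothetical helper: row positions of the k stored embeddings most similar to query_emb."""
    q = np.asarray(query_emb, dtype=float)
    # Cosine similarity is the dot product of L2-normalized vectors
    q = q / np.linalg.norm(q)
    m = example_embs / np.linalg.norm(example_embs, axis=1, keepdims=True)
    sims = m @ q
    # argsort is ascending, so reverse it and keep the first k positions
    return np.argsort(sims)[::-1][:k].tolist()


stored = np.array([
    [1.0, 0.0, 0.0],  # e.g. "How many singers are there?"
    [0.0, 1.0, 0.0],  # e.g. "List pets by type."
    [0.9, 0.1, 0.0],  # e.g. "Count the singers."
])
print(top_k_similar([1.0, 0.05, 0.0], stored, k=2))  # [0, 2]
```

With the tutorial's data, example_embs could be np.array(spider_df["question_embedding"].tolist()), and the returned positions map back to rows through spider_df.index. A dedicated vector store avoids rescoring every example on each request once the example pool grows.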

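  1. Create an Astra DB Serverless collection to hold the remaining Spider examples. The vector dimension must match the embedding model: text-embedding-3-small produces 1536-dimensional vectors.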

    # Create a collection for vector search
    collection = astra_db.create_collection("text2sql_examples_test1", dimension=1536)
    Response
    Astra DB Collection[name="text2sql_examples_test1", endpoint="https://3a6953df-27bf-4272-a584-31f7b4dcaa80-us-east1.apps.astra.datastax.com"]
  2. Generate an embedding for each question left in spider_df. These rows form the example pool; the test questions and their duplicates were removed earlier.

    from tqdm.auto import tqdm
    
    tqdm.pandas()
    
    spider_embeddings = spider_df["question"].progress_apply(lambda question: ada_embedding(question))
    spider_df["question_embedding"] = spider_embeddings
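  3. Upload each example (question, gold SQL query, database ID, and embedding) to the collection. The document ID is the row's index label in spider_df, so you can map search results back to rows.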

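    def insert_doc(row):
        """Store one example; "$vector" is the field Astra DB uses for vector search."""
        return collection.insert_one({
            "_id": row.name,  # the spider_df index label, reused later to look up the row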
            "question": row['question'],
            "gold_sql": row['query'],
            "db_id": row["db_id"],
            "$vector": row["question_embedding"],
        })
    
    spider_df.progress_apply(insert_doc, axis=1)
    Response
    0          {'status': {'insertedIds': [0]}}
    1          {'status': {'insertedIds': [1]}}
    2          {'status': {'insertedIds': [2]}}
    3          {'status': {'insertedIds': [3]}}
    4          {'status': {'insertedIds': [4]}}
                           ...
    1029    {'status': {'insertedIds': [1029]}}
    1030    {'status': {'insertedIds': [1030]}}
    1031    {'status': {'insertedIds': [1031]}}
    1032    {'status': {'insertedIds': [1032]}}
    1033    {'status': {'insertedIds': [1033]}}
    Length: 1028, dtype: object
  4. Create a prompt function that embeds the incoming question, retrieves the two most similar stored questions with a vector search, and builds a few-shot prompt from them.

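    def dynamic_few_shot_prompt_fn(db_id: str, question: str) -> str:
        # Embed the incoming question with the same model used for the stored examples
        emb = ada_embedding(question)
        # Vector search: the 2 stored questions closest to this one
        docs = collection.vector_find(emb, limit=2)
        # Document IDs are spider_df index labels (see insert_doc)
        closest_q_ids = [doc["_id"] for doc in docs]

        return few_shot_prompt(db_id, question, indices=closest_q_ids)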
  5. Generate SQL from the test questions using the dynamic few-shot prompt.

    # Include the prompt in the output to show the model's context
    sql1 = generate_sql(
        question="How many concerts occurred in 2014 or 2015?",
        db_id="concert_singer",
        prompt_fn=dynamic_few_shot_prompt_fn,
        debug_prompt=True,
    )
    print(sql1)
    Response
    Prompting ChatGPT with:
    Here is an example: Convert text to SQL:
    
    [Schema : (values)]: | concert_singer | stadium : Stadium_ID , Location , Name , Capacity , Highest , Lowest , Average | singer : Singer_ID , Name , Country , Song_Name , Song_release_year , Age , Is_male | concert : concert_ID , concert_Name , Theme , Stadium_ID , Year | singer_in_concert : concert_ID , Singer_ID ;
    
    [Column names (type)]: stadium : Stadium_ID (number) | stadium : Location (text) | stadium : Name (text) | stadium : Capacity (number) | stadium : Highest (number) | stadium : Lowest (number) | stadium : Average (number) | singer : Singer_ID (number) | singer : Name (text) | singer : Country (text) | singer : Song_Name (text) | singer : Song_release_year (text) | singer : Age (number) | singer : Is_male (others) | concert : concert_ID (number) | concert : concert_Name (text) | concert : Theme (text) | concert : Stadium_ID (text) | concert : Year (text) | singer_in_concert : concert_ID (number) | singer_in_concert : Singer_ID (text) ;
    
    [Primary Keys]: stadium : Stadium_ID | singer : Singer_ID | concert : concert_ID | singer_in_concert : concert_ID
    
    [Foreign Keys]: concert : Stadium_ID equals stadium : Stadium_ID | singer_in_concert : Singer_ID equals singer : Singer_ID | singer_in_concert : concert_ID equals concert : concert_ID
    
    [Q]: What are the names and locations of the stadiums that had concerts that occurred in both 2014 and 2015?
    
    [SQL]: SELECT T2.name ,  T2.location FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id  =  T2.stadium_id WHERE T1.Year  =  2014 INTERSECT SELECT T2.name ,  T2.location FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id  =  T2.stadium_id WHERE T1.Year  =  2015
    
    Here is an example: Convert text to SQL:
    
    [Schema : (values)]: | concert_singer | stadium : Stadium_ID , Location , Name , Capacity , Highest , Lowest , Average | singer : Singer_ID , Name , Country , Song_Name , Song_release_year , Age , Is_male | concert : concert_ID , concert_Name , Theme , Stadium_ID , Year | singer_in_concert : concert_ID , Singer_ID ;
    
    [Column names (type)]: stadium : Stadium_ID (number) | stadium : Location (text) | stadium : Name (text) | stadium : Capacity (number) | stadium : Highest (number) | stadium : Lowest (number) | stadium : Average (number) | singer : Singer_ID (number) | singer : Name (text) | singer : Country (text) | singer : Song_Name (text) | singer : Song_release_year (text) | singer : Age (number) | singer : Is_male (others) | concert : concert_ID (number) | concert : concert_Name (text) | concert : Theme (text) | concert : Stadium_ID (text) | concert : Year (text) | singer_in_concert : concert_ID (number) | singer_in_concert : Singer_ID (text) ;
    
    [Primary Keys]: stadium : Stadium_ID | singer : Singer_ID | concert : concert_ID | singer_in_concert : concert_ID
    
    [Foreign Keys]: concert : Stadium_ID equals stadium : Stadium_ID | singer_in_concert : Singer_ID equals singer : Singer_ID | singer_in_concert : concert_ID equals concert : concert_ID
    
    [Q]: Which year has most number of concerts?
    
    [SQL]: SELECT YEAR FROM concert GROUP BY YEAR ORDER BY count(*) DESC LIMIT 1
    
    Here is the test question to be answered: Convert text to SQL:
    
    [Schema : (values)]: | concert_singer | stadium : Stadium_ID , Location , Name , Capacity , Highest , Lowest , Average | singer : Singer_ID , Name , Country , Song_Name , Song_release_year , Age , Is_male | concert : concert_ID , concert_Name , Theme , Stadium_ID , Year | singer_in_concert : concert_ID , Singer_ID ;
    
    [Column names (type)]: stadium : Stadium_ID (number) | stadium : Location (text) | stadium : Name (text) | stadium : Capacity (number) | stadium : Highest (number) | stadium : Lowest (number) | stadium : Average (number) | singer : Singer_ID (number) | singer : Name (text) | singer : Country (text) | singer : Song_Name (text) | singer : Song_release_year (text) | singer : Age (number) | singer : Is_male (others) | concert : concert_ID (number) | concert : concert_Name (text) | concert : Theme (text) | concert : Stadium_ID (text) | concert : Year (text) | singer_in_concert : concert_ID (number) | singer_in_concert : Singer_ID (text) ;
    
    [Primary Keys]: stadium : Stadium_ID | singer : Singer_ID | concert : concert_ID | singer_in_concert : concert_ID
    
    [Foreign Keys]: concert : Stadium_ID equals stadium : Stadium_ID | singer_in_concert : Singer_ID equals singer : Singer_ID | singer_in_concert : concert_ID equals concert : concert_ID
    
    [Q]: How many concerts occurred in 2014 or 2015?
    
    [SQL]:
    
    
    SELECT COUNT(*) FROM concert WHERE Year IN (2014, 2015)
    sql2 = generate_sql(
        question="Show the stadium names without any concert.",
        db_id="concert_singer",
        prompt_fn=dynamic_few_shot_prompt_fn,
    )
    print(sql2)
    Response
    SELECT name
    FROM stadium
    WHERE stadium_id NOT IN (SELECT stadium_id FROM concert)
    sql3 = generate_sql(
        question="What are the students' first names who have both cats and dogs as pets?",
        db_id="pets_1",
        prompt_fn=dynamic_few_shot_prompt_fn,
    )
    print(sql3)
    Response
    SELECT DISTINCT T1.Fname FROM student AS T1
    JOIN has_pet AS T2 ON T1.stuid = T2.stuid
    JOIN pets AS T3 ON T3.petid = T2.petid
    WHERE T3.pettype = 'cat' AND T1.stuid IN (
        SELECT DISTINCT T1.stuid FROM student AS T1
        JOIN has_pet AS T2 ON T1.stuid = T2.stuid
        JOIN pets AS T3 ON T3.petid = T2.petid
        WHERE T3.pettype = 'dog'
    )
    sql4 = generate_sql(
        question="Find the id of students who do not have a cat pet.",
        db_id="pets_1",
        prompt_fn=dynamic_few_shot_prompt_fn,
    )
    print(sql4)
    Response
    SELECT stuID FROM Student WHERE stuID NOT IN (SELECT stuID FROM Has_Pet JOIN Pets ON Has_Pet.petID = Pets.petID WHERE Pets.PetType = 'cat')
    sql5 = generate_sql(
        question="Find the first name and age of students who have a pet.",
        db_id="pets_1",
        prompt_fn=dynamic_few_shot_prompt_fn,
    )
    print(sql5)
    Response
    SELECT DISTINCT T1.fname, T1.age
    FROM student AS T1
    JOIN has_pet AS T2 ON T1.stuid = T2.stuid
    JOIN pets AS T3 ON T3.petid = T2.petid
  6. Evaluate all five generated queries against the gold queries to see how the LLM did.

    sqls = [sql1, sql2, sql3, sql4, sql5]
    report = eval_generated_queries(sqls)
    display(report)  # display() works in Jupyter/IPython; use print(report) elsewhere
    Response
       DB ID           Question                                                                 Correct  Error
    0  concert_singer  How many concerts occurred in 2014 or 2015?                              True     None
    1  concert_singer  Show the stadium names without any concert.                              True     None
    2  pets_1          What are the students' first names who have both cats and dogs as pets?  True     None
    3  pets_1          Find the id of students who do not have a cat pet.                       True     None
    4  pets_1          Find the first name and age of students who have a pet.                  True     None
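To compare runs, for example the zero-shot and fixed few-shot baselines against the dynamic few-shot prompt, it helps to reduce a report to accuracy figures. A minimal sketch, assuming a report DataFrame with the DB ID and Correct columns that eval_generated_queries returns; summarize_report is a hypothetical helper, and the toy report below is made-up input for illustration.

```python
import pandas as pd


def summarize_report(report: pd.DataFrame) -> pd.Series:
    """Hypothetical helper: per-database accuracy plus an "overall" entry."""
    correct = report["Correct"].astype(bool)
    per_db = correct.groupby(report["DB ID"]).mean()
    per_db["overall"] = correct.mean()
    return per_db


toy_report = pd.DataFrame({
    "DB ID": ["concert_singer", "concert_singer", "pets_1"],
    "Correct": [True, False, True],
})
print(summarize_report(toy_report))
```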

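With dynamic few-shot prompting, the model answered all 5 test questions correctly, compared with 3 of 5 for the zero-shot and fixed few-shot approaches. Using Astra DB Serverless as a store for dynamic few-shot examples is a low-effort way to improve the accuracy of a text-to-SQL application; because five questions is a small sample, measure the improvement on a larger set of your own questions before relying on it in production.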

© 2024 DataStax | Privacy policy | Terms of use
