Integrate Cohere with Astra DB Serverless

After you embed text with Cohere’s embedding models, Astra DB Serverless can index the embeddings and run similarity searches over them.

This tutorial uses a portion of the Stanford Question Answering Dataset (SQuAD), a collection of questions and answers, to set up a Retrieval-Augmented Generation (RAG) pipeline. You embed the questions and store them alongside their answers in the database. Then you embed the user’s query and run a similarity search to retrieve the answers to the stored questions that are most similar to it.
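To make the retrieval step concrete, here is a minimal, self-contained sketch in plain Python. It is not how Astra DB works internally: it assumes cosine similarity as the metric and uses hypothetical 3-dimensional toy vectors with placeholder questions and answers instead of real Cohere embeddings and SQuAD rows. The helper names cosine_similarity and retrieve are illustrative only.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def retrieve(query_vector: list[float], documents: list[dict], limit: int = 2) -> list[dict]:
    """Rank stored documents by how similar their question vector is to the query."""
    scored = [(cosine_similarity(query_vector, doc["$vector"]), doc) for doc in documents]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in scored[:limit]]


# Hypothetical toy data, not real embeddings: each document pairs the
# embedding of a stored question with that question's answer.
documents = [
    {"question": "Toy question A", "answer": "Toy answer A", "$vector": [1.0, 0.0, 0.0]},
    {"question": "Toy question B", "answer": "Toy answer B", "$vector": [0.0, 1.0, 0.0]},
    {"question": "Toy question C", "answer": "Toy answer C", "$vector": [0.7, 0.7, 0.0]},
]

# Hypothetical embedding of the user's query.
query_vector = [0.9, 0.1, 0.0]
for doc in retrieve(query_vector, documents):
    print(doc["answer"])
```

In the real pipeline, the vectors come from cohere.embed and Astra DB Serverless performs the ranking when you call vector_find.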

Prerequisites

The code samples on this page assume the following:

Install packages

Install and import the required Python packages:

  • cohere is the interface for the Cohere models.

  • astrapy is the Python interface for Astra DB’s Data API.

  • datasets provides the SQuAD dataset, whose question-and-answer data you store in Astra DB Serverless. The datasets library can also load many other datasets from the Hugging Face Datasets Hub.

  • python-dotenv allows the program to load the required credentials from a .env file.

The samples on this page use the legacy astrapy classes AstraDB and AstraDBCollection, so the following command pins astrapy to a 0.x release.

pip install -U cohere "astrapy<1.0" datasets python-dotenv

import cohere
import os
from dotenv import load_dotenv
from astrapy.db import AstraDB, AstraDBCollection
from datasets import load_dataset

# ...

Set environment variables

Create a .env file with your Cohere and Astra DB Serverless credentials:

  1. Complete the ASTRA_DB_APPLICATION_TOKEN and ASTRA_DB_API_ENDPOINT lines with the values from the Astra Portal.

  2. Complete the COHERE_API_KEY line with your Cohere API key from the Cohere dashboard.

    ASTRA_DB_APPLICATION_TOKEN=TOKEN
    ASTRA_DB_API_ENDPOINT=ENDPOINT
    COHERE_API_KEY=COHERE_API_KEY
  3. Load the credentials in your Python script:

    # ...
    
    load_dotenv()
    
    token = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
    api_endpoint = os.getenv("ASTRA_DB_API_ENDPOINT")
    cohere_key = os.getenv("COHERE_API_KEY")
    
    astra_db = AstraDB(token=token, api_endpoint=api_endpoint)
    # This rebinds the name `cohere` from the module to the client instance,
    # so later calls such as cohere.embed(...) go through the client.
    cohere = cohere.Client(cohere_key)
    
    # ...
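os.getenv returns None when a variable is missing, which tends to surface later as a confusing authentication error. A small guard like the following, a sketch with a hypothetical helper named require_env, fails fast with a clear message instead.

```python
import os


def require_env(name: str) -> str:
    """Return the value of an environment variable, or fail with a clear message."""
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"Missing required environment variable: {name}")
    return value


# Hypothetical usage, after load_dotenv() has populated os.environ from .env:
# token = require_env("ASTRA_DB_APPLICATION_TOKEN")
# api_endpoint = require_env("ASTRA_DB_API_ENDPOINT")
# cohere_key = require_env("COHERE_API_KEY")
```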

Create a collection

To create a collection, call the create_collection method.

For a vector store application, you must create your collection with a specified embedding dimension. The collection accepts only embeddings of exactly this length, so select your Cohere embedding model first and use its output dimension for the collection. This tutorial uses embed-english-v3.0, which produces 1024-dimensional embeddings.

In your code, set DIMENSION to the embedding dimension of your chosen model. See the [Cohere API Reference](https://docs.cohere.com/reference/embed) for a list of available models and their dimensions.

# ...

# embed-english-v3.0 returns 1024-dimensional embeddings.
DIMENSION = 1024

astra_db.create_collection(collection_name="cohere", dimension=DIMENSION)
collection = AstraDBCollection(
    collection_name="cohere", astra_db=astra_db
)

# ...
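If you want the code to derive DIMENSION from the model name and catch length mismatches before inserting, a lookup like the following is one option. The dimensions listed are those of Cohere’s v3 embedding models; the table is hand-copied and can go stale, so confirm it against the Cohere API Reference. MODEL_DIMENSIONS, MODEL, and check_dimension are hypothetical names.

```python
# Hand-copied output dimensions of Cohere's v3 embedding models.
MODEL_DIMENSIONS = {
    "embed-english-v3.0": 1024,
    "embed-multilingual-v3.0": 1024,
    "embed-english-light-v3.0": 384,
    "embed-multilingual-light-v3.0": 384,
}

MODEL = "embed-english-v3.0"
DIMENSION = MODEL_DIMENSIONS[MODEL]


def check_dimension(embedding: list[float], model: str) -> None:
    """Raise if an embedding's length does not match the model's dimension."""
    expected = MODEL_DIMENSIONS[model]
    if len(embedding) != expected:
        raise ValueError(f"{model} should produce {expected} values, got {len(embedding)}")


# Hypothetical usage before inserting:
# check_dimension(embeddings[0], MODEL)
```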

Prepare the data

  1. Load the SQuAD dataset.

    # ...
    
    # Select the first 2,000 rows of the training set.
    # These rows contain both questions and answers.
    squad = load_dataset('squad', split='train[:2000]')
    
    # Show some example questions.
    print(squad[:20]["question"])
    
    # ...
  2. Ask Cohere for the embeddings of all of these questions. Use the same model you chose when you created the collection, so that the embedding length matches the collection’s dimension.

    # ...
    
    embeddings = cohere.embed(
        texts=list(squad["question"]),  # pass a plain list of strings
        model="embed-english-v3.0",
        input_type="search_document",
        truncate="END",
    ).embeddings
    
    # Check that the embeddings are the correct length (should equal DIMENSION).
    print(len(embeddings[0]))
    
    # ...

    When using Cohere for RAG, embed the documents with an input_type of search_document, and then embed the query with an input_type of search_query.

    The truncate value of 'END' means that if the provided text is too long to embed, the model cuts off the end of the text and returns an embedding of only the beginning. The other options are 'START', which cuts off the beginning of the text, and 'NONE', which returns an error if the text is too long.

  3. Combine each dictionary, which represents a row from the SQuAD dataset, with its generated embedding.

    This process creates one dictionary per row that keeps the SQuAD dataset keys and values untouched and adds the embedding under the '$vector' key. Embeddings must be top-level values of the '$vector' key to be valid vector search targets in Astra DB Serverless.

    # ...
    
    to_insert = []
    for row, embedding in zip(squad, embeddings):
        to_insert.append({**row, "$vector": embedding})
    
    # ...
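The three truncate modes can be illustrated with a toy function. This is an approximation, not Cohere’s implementation: Cohere truncates by model tokens, while this hypothetical truncate_words helper counts whitespace-separated words.

```python
def truncate_words(text: str, max_words: int, mode: str = "END") -> str:
    """Toy model of Cohere's truncate modes, counting words instead of tokens."""
    if max_words < 1:
        raise ValueError("max_words must be at least 1")
    words = text.split()
    if len(words) <= max_words:
        # Short enough: every mode leaves the text unchanged.
        return text
    if mode == "END":
        # Cut off the end and keep the beginning.
        return " ".join(words[:max_words])
    if mode == "START":
        # Cut off the beginning and keep the end.
        return " ".join(words[-max_words:])
    if mode == "NONE":
        # No truncation allowed: report the error instead.
        raise ValueError("input is too long and truncation is disabled")
    raise ValueError(f"unknown truncate mode: {mode}")


print(truncate_words("one two three four", 2, "END"))
print(truncate_words("one two three four", 2, "START"))
```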

Insert the data

Use the insert_many method to insert your documents into Astra DB.

You have a list of 2,000 dictionaries to insert, but insert_many accepts only up to 20 documents per call, so insert them in batches of 20 with a loop.

# ...

batch_size = 20
for i in range(0, len(to_insert), batch_size):
    res = collection.insert_many(documents=to_insert[i : i + batch_size])
    # Print the index of the first document in each batch to show progress.
    print(i)

# ...
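The batching loop generalizes to a small helper that yields fixed-size slices. This is a sketch: the name batched is hypothetical here (Python 3.12 added a similar itertools.batched, which yields tuples). The same helper could batch texts if you call Cohere’s embed endpoint directly, which documents a limit of 96 texts per call.

```python
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def batched(items: Sequence[T], size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of at most `size` items, preserving order."""
    if size < 1:
        raise ValueError("size must be at least 1")
    for start in range(0, len(items), size):
        yield items[start:start + size]


# Hypothetical usage with the collection from this tutorial:
# for batch in batched(to_insert, 20):
#     collection.insert_many(documents=list(batch))

print([len(b) for b in batched(list(range(45)), 20)])
```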

Embed the query

Call cohere.embed again, this time with an input_type of 'search_query', and use the same model and truncate values as for the documents. To search for an answer to a different question, replace the text in user_query.

# ...

user_query = "What's in front of Notre Dame?"
embedded_query = cohere.embed(
    texts=[user_query],
    model="embed-english-v3.0",
    input_type='search_query',
    truncate='END'
).embeddings[0]

# ...
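Because the documents and the query must be embedded with the same model and truncate settings but different input_type values, it can help to keep the shared settings in one place. The following sketch uses hypothetical names (EMBED_SETTINGS, RAG_INPUT_TYPES, embed_kwargs) and only builds keyword arguments; it does not call Cohere.

```python
# Hypothetical shared settings: documents and queries must use the same model
# and truncate values so that their embeddings are comparable.
EMBED_SETTINGS = {"model": "embed-english-v3.0", "truncate": "END"}

# The two input types this tutorial uses for RAG.
RAG_INPUT_TYPES = ("search_document", "search_query")


def embed_kwargs(texts: list[str], input_type: str) -> dict:
    """Build keyword arguments for an embed call; only input_type differs."""
    if input_type not in RAG_INPUT_TYPES:
        raise ValueError(f"expected one of {RAG_INPUT_TYPES}, got {input_type!r}")
    return {"texts": list(texts), "input_type": input_type, **EMBED_SETTINGS}


# Hypothetical usage with the client created earlier (not executed here):
# embeddings = cohere.embed(**embed_kwargs(list(squad["question"]), "search_document")).embeddings
# embedded_query = cohere.embed(**embed_kwargs([user_query], "search_query")).embeddings[0]
print(embed_kwargs(["What's in front of Notre Dame?"], "search_query"))
```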

Find the answer

  1. Use the vector_find method to retrieve the documents whose questions are most similar to the embedded query. The limit parameter caps the number of documents returned.

    # ...
    
    results = collection.vector_find(embedded_query, limit=50)
    
    # ...
  2. Extract the answer value and similarity score from those rows.

    # ...
    
    print(f"Query: {user_query}")
    print("Answers:")
    for idx, answer in enumerate(results):
        print(f"\t Answer {idx}: {answer['answers']['text']}, Score: {answer['$similarity']}")
    
    # ...
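vector_find can return several documents with the same answer text when different stored questions share an answer. The following sketch, with a hypothetical helper named top_unique_answers, keeps only the distinct answers with the highest similarity. It assumes each result is a dictionary with the SQuAD answers field and the $similarity key, as in the printing loop above; the sample results are made up for illustration.

```python
def top_unique_answers(results: list[dict], k: int = 3) -> list[tuple[str, float]]:
    """Return up to k distinct answer strings, highest similarity first."""
    seen: set[str] = set()
    answers: list[tuple[str, float]] = []
    # Sort defensively by similarity, highest first.
    for doc in sorted(results, key=lambda d: d["$similarity"], reverse=True):
        # In SQuAD, answers["text"] is a list of answer strings.
        for text in doc["answers"]["text"]:
            if len(answers) >= k:
                return answers
            if text not in seen:
                seen.add(text)
                answers.append((text, doc["$similarity"]))
    return answers


# Hypothetical results shaped like the documents returned by vector_find.
sample_results = [
    {"answers": {"text": ["Answer X"]}, "$similarity": 0.91},
    {"answers": {"text": ["Answer Y"]}, "$similarity": 0.95},
    {"answers": {"text": ["Answer X"]}, "$similarity": 0.89},
    {"answers": {"text": ["Answer Z"]}, "$similarity": 0.80},
]
print(top_unique_answers(sample_results, k=2))
```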

© 2024 DataStax | Privacy policy | Terms of use
