Integrate Cohere with Astra DB Serverless
After you embed text with Cohere’s embedding models, Astra DB Serverless indexes the embeddings so that you can perform similarity searches.
This tutorial uses a portion of the Stanford Question Answering Dataset (SQuAD), which consists of questions and answers, to set up a Retrieval-Augmented Generation (RAG) pipeline. You embed the questions and store them alongside the answers in the database. To retrieve relevant answers, you then embed the user’s query and perform a similarity search.
Prerequisites
The code samples on this page assume the following:
- You have an active Astra account.
- You have created a vector database.
- You have created an application token with the Database Administrator role.
Install packages
Install and import the required Python packages:
- `cohere` is the interface for the Cohere models.
- `astrapy` is the Python interface for Astra DB’s Data API.
- `datasets` contains the SQuAD dataset with the question-and-answer data to be stored in Astra DB Serverless. The `datasets` package can access many datasets from the Hugging Face Datasets Hub.
- `python-dotenv` allows the program to load the required credentials from a `.env` file.
pip install -U cohere astrapy datasets python-dotenv
import cohere
import os
from dotenv import load_dotenv
from astrapy.db import AstraDB, AstraDBCollection
from astrapy.ops import AstraDBOps
from datasets import load_dataset
# ...
Set environment variables
Create a `.env` file with your Cohere and Astra DB Serverless credentials:
- Complete the `ASTRA_DB_APPLICATION_TOKEN` and `ASTRA_DB_API_ENDPOINT` lines with the values from the Astra Portal.
- Complete the `COHERE_API_KEY` line with your Cohere API key from the Cohere dashboard.

ASTRA_DB_APPLICATION_TOKEN=TOKEN
ASTRA_DB_API_ENDPOINT=ENDPOINT
COHERE_API_KEY=COHERE_API_KEY
- Load the credentials in your test file:
# ...
load_dotenv()

token = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
api_endpoint = os.getenv("ASTRA_DB_API_ENDPOINT")
cohere_key = os.getenv("COHERE_API_KEY")

astra_db = AstraDB(token=token, api_endpoint=api_endpoint)
cohere = cohere.Client(cohere_key)
# ...
Create a collection
To create a collection, call the `create_collection` method.
For a vector store application, you must create your collection with a specified embedding dimension; only embeddings of that exact length can be inserted into the collection. Therefore, select your Cohere embedding model first, and use that model’s output dimension when you create the collection. In your code, set `DIMENSION` to the embedding dimension of your chosen model.
See the [Cohere API Reference](https://docs.cohere.com/reference/embed) for a list of available models and their dimensions.
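As a sketch, you could record the dimension alongside the model name. The dimensions below are assumptions based on the v3.0 model family; confirm them against the Cohere API reference before use:

```python
# Embedding dimensions for some Cohere embedding models.
# These values are assumptions; verify them against the
# Cohere API reference before relying on them.
MODEL_DIMENSIONS = {
    "embed-english-v3.0": 1024,
    "embed-english-light-v3.0": 384,
    "embed-multilingual-v3.0": 1024,
}

MODEL = "embed-english-v3.0"  # model used in this tutorial
DIMENSION = MODEL_DIMENSIONS[MODEL]
```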
# ...
astra_db.create_collection(collection_name="cohere", dimension=DIMENSION)
collection = AstraDBCollection(
collection_name="cohere", astra_db=astra_db
)
# ...
Prepare the data
- Load the SQuAD dataset.

# ...
# Select the first 2,000 rows of the training set.
# These rows contain both questions and answers.
squad = load_dataset('squad', split='train[:2000]')

# Show some example questions.
squad["question"][0:20]
# ...
- Ask Cohere for the embeddings of all of these questions. The `model` you select must match the embedding dimension of the collection.

# ...
embeddings = cohere.embed(
    texts=squad["question"],
    model="embed-english-v3.0",
    input_type="search_document",
    truncate="END",
).embeddings

# Check that the embeddings are the correct length.
len(embeddings[0])
# ...
When using Cohere for RAG, embed the documents with an `input_type` of `search_document`, and then embed the query with an `input_type` of `search_query`.

The `truncate` value of `END` means that if the provided text is too long to embed, the model cuts off the end of the text and returns an embedding of only the beginning part. Other options for this parameter include `START`, which cuts off the beginning of the text, and `NONE`, which returns an error if the text is too long.

- Combine each dictionary, which represents a row from the SQuAD dataset, with its generated embedding. This process creates one dictionary with the SQuAD dataset keys and values untouched and the embedding associated with the `$vector` key. Embeddings must be top-level values associated with the `$vector` key to be valid vector search targets in Astra DB Serverless.
# ...
to_insert = []
for i in range(len(squad)):
    to_insert.append({**squad[i], "$vector": embeddings[i]})
# ...
Insert the data
Use the `insert_many` method to insert your documents into Astra DB.
You have a list of 2,000 dictionaries to insert, but the Data API caps the number of documents per `insert_many` call, so insert them in batches of 20.
# ...
batch_size = 20
i = 0
while i < len(to_insert):
    res = collection.insert_many(documents=to_insert[i : i + batch_size])
    print(i)
    i += batch_size
# ...
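The batching loop above can also be written as a small reusable helper. This `batched` function is not part of astrapy; it is just a sketch of the same slicing logic:

```python
def batched(items, batch_size=20):
    """Yield successive slices of at most batch_size items.

    Equivalent to the while-loop above: each yielded slice can be
    passed directly to collection.insert_many(documents=...).
    """
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]
```

With this helper, the insertion loop becomes `for batch in batched(to_insert): collection.insert_many(documents=batch)`.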
Embed the query
Call `cohere.embed` again with an `input_type` of `search_query`.
Use the same `model` and `truncate` values.
Replace the text in `user_query` to search for an answer to a different question.
# ...
user_query = "What's in front of Notre Dame?"
embedded_query = cohere.embed(
texts=[user_query],
model="embed-english-v3.0",
input_type='search_query',
truncate='END'
).embeddings[0]
# ...
Find the answer
- Use the `vector_find` method to extract a limited number of rows with questions similar to the embedded query.

# ...
results = collection.vector_find(embedded_query, limit=50)
# ...
- Extract the answer value and similarity score from those rows.

# ...
print(f"Query: {user_query}")
print("Answers:")
for idx, answer in enumerate(results):
    print(f"\t Answer {idx}: {answer['answers']['text']}, Score: {answer['$similarity']}")
# ...
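Because several SQuAD rows can share the same answer text, you may want to deduplicate the results before printing. A minimal sketch, assuming each result row has the same shape as the documents inserted above (`unique_answers` is a hypothetical helper, not part of astrapy):

```python
def unique_answers(results):
    """Keep only the first (highest-ranked) occurrence of each answer.

    Each result row is expected to carry a SQuAD-style 'answers' dict
    (where 'text' is a list of strings) and the '$similarity' score
    returned by the vector search.
    """
    seen = set()
    unique = []
    for row in results:
        # Use the tuple of answer strings as the dedup key.
        key = tuple(row["answers"]["text"])
        if key not in seen:
            seen.add(key)
            unique.append((row["answers"]["text"], row["$similarity"]))
    return unique
```

Since `vector_find` returns rows ordered by similarity, keeping the first occurrence keeps the best score for each answer.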