import chromadb
import os
from typing import Union, List, Dict, Tuple, Any
import re

from chromadb.config import Settings
from dotenv import load_dotenv
from colorprinter.print_color import *
from models import ChunkSearchResults

load_dotenv(".env")
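
# The remote-client path reads its connection settings from .env. A minimal
# sketch of the expected entries (values are placeholders, not real
# credentials):
#
#   CHROMA_HOST=chroma.example.org
#   CHROMA_CLIENT_AUTH_CREDENTIALS=<token>
#   CHROMA_AUTH_TOKEN_TRANSPORT_HEADER=X-Chroma-Token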


class ChromaDB:
    def __init__(self, local_deployment: bool = False, db="sci_articles", host=None):
        if local_deployment:
            self.db = chromadb.PersistentClient(f"chroma_{db}")
        else:
            if not host:
                host = os.getenv("CHROMA_HOST")
            credentials = os.getenv("CHROMA_CLIENT_AUTH_CREDENTIALS")
            auth_token_transport_header = os.getenv(
                "CHROMA_AUTH_TOKEN_TRANSPORT_HEADER"
            )
            self.db = chromadb.HttpClient(
                host=host,
                # database=db,
                settings=Settings(
                    chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
                    chroma_client_auth_credentials=credentials,
                    chroma_auth_token_transport_header=auth_token_transport_header,
                ),
            )
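
    # A ChromaDB() call with no arguments connects to the remote host
    # configured in .env; local_deployment=True uses an on-disk client
    # instead. Illustrative only:
    #
    #   chroma = ChromaDB(local_deployment=True, db="sci_articles")
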
    def query(
        self,
        query,
        collection,
        n_results=6,
        n_sources=3,
        max_retries=None,
        where: dict = None,
        **kwargs,
    ):
        """
        Query the vector database for relevant documents.

        Args:
            query (str): The query text to search for.
            collection (str): The name of the collection to search within.
            n_results (int, optional): The number of results to return. Defaults to 6.
            n_sources (int, optional): The number of unique sources to return. Defaults to 3.
            max_retries (int, optional): The maximum number of retries for querying. Defaults to n_sources.
            where (dict, optional): Additional filtering criteria for the query. Defaults to None.
            **kwargs: Additional keyword arguments to pass to the query.

        Returns:
            dict: A dictionary containing the query results with keys 'ids', 'metadatas', 'documents', and 'distances'.
        """
        if not isinstance(n_sources, int):
            n_sources = int(n_sources)
        if not isinstance(n_results, int):
            n_results = int(n_results)
        if not max_retries:
            max_retries = n_sources
        if n_sources > n_results:
            n_sources = n_results
        col = self.db.get_collection(collection)
        sources = []
        n = 0
        print('Collection', collection)
        result = {"ids": [[]], "metadatas": [[]], "documents": [[]], "distances": [[]]}
        while True:
            n += 1
            if n > max_retries:
                break
            if where == {}:
                where = None
            print_rainbow(kwargs)
            print('N_results:', n_results)
            print('Sources:', sources)
            print('Query:', query)
            r = col.query(
                query_texts=query,
                n_results=n_results - len(sources),
                where=where,
                **kwargs,
            )
            if r["ids"][0] == []:
                if result["ids"][0] == []:
                    print_rainbow(r)
                    print_red("No results found in vector database.")
                else:
                    print_red("No more results found in vector database.")
                break
            # Manually extend each list within the lists of lists
            for key in result:
                if key in r:
                    result[key][0].extend(r[key][0])
            # Order accumulated results by ascending distance
            combined = sorted(
                zip(
                    result["distances"][0],
                    result["ids"][0],
                    result["metadatas"][0],
                    result["documents"][0],
                ),
                key=lambda x: x[0],
            )
            (
                result["distances"][0],
                result["ids"][0],
                result["metadatas"][0],
                result["documents"][0],
            ) = map(list, zip(*combined))
            # Recompute the unique source ids seen so far (appending on every
            # retry would double-count sources already collected)
            sources = list({i["_id"] for i in result["metadatas"][0]})
            if len(sources) >= n_sources:
                break
            elif n != max_retries:
                # Trim the accumulated results so the next retry can backfill
                for k, v in result.items():
                    if k not in r["included"]:
                        continue
                    result[k][0] = v[0][: n_results - (n_sources - len(sources))]
                if where and "_id" in where:
                    # Exclude sources we already have from the next query
                    where["_id"]["$in"] = [
                        i for i in where["_id"]["$in"] if i not in sources
                    ]
                    if where["_id"]["$in"] == []:
                        break
            else:
                break
        return result
    def search(
        self,
        query: str,
        collection: str,
        n_results: int = 6,
        n_sources: int = 3,
        where: dict = None,
        format_results: bool = False,
        **kwargs,
    ) -> Union[dict, List[dict]]:
        """
        An enhanced search method that provides a cleaner interface for querying and processing results.

        Args:
            query (str): The search query
            collection (str): Collection name to search in
            n_results (int): Maximum number of results to return
            n_sources (int): Maximum number of unique sources to include
            where (dict, optional): Additional filtering criteria
            format_results (bool): Whether to return a flat list of chunk dicts
            **kwargs: Additional arguments to pass to the query

        Returns:
            Union[dict, List[dict]]: Raw Chroma-style result dict, or a list of
            chunk dicts when format_results is True.
        """
        # Get raw query results with existing query method
        result = self.query(
            query=query,
            collection=collection,
            n_results=n_results,
            n_sources=n_sources,
            where=where,
            **kwargs,
        )
        # If no formatting requested, return raw results
        if not format_results:
            return result
        # Process results into dictionary format
        combined_chunks = []
        for doc, meta, dist, _id in zip(
            result["documents"][0],
            result["metadatas"][0],
            result["distances"][0],
            result["ids"][0],
        ):
            combined_chunks.append(
                {"document": doc, "metadata": meta, "distance": dist, "id": _id}
            )
        return combined_chunks
    def clean_result_text(self, documents: list) -> list:
        """
        Clean text in document results by removing footnote references.

        Args:
            documents (list): List of document dictionaries

        Returns:
            list: Documents with cleaned text
        """
        for doc in documents:
            if "document" in doc:
                doc["document"] = re.sub(r"\[\d+\]", "", doc["document"])
        return documents
    def filter_by_unique_sources(
        self, results: list, n_sources: int, source_key: str = "_id"
    ) -> Tuple[List, List]:
        """
        Filters search results to keep only a specified number of unique sources.

        Args:
            results (list): List of documents from search
            n_sources (int): Maximum number of unique sources to include
            source_key (str): The key in metadata that identifies the source

        Returns:
            tuple: (filtered_results, remaining_results)
        """
        sources = set()
        filtered_results = []
        remaining_results = []
        for item in results:
            source_id = item["metadata"].get(source_key, "no_id")
            if source_id not in sources and len(sources) < n_sources:
                sources.add(source_id)
                filtered_results.append(item)
            else:
                remaining_results.append(item)
        return filtered_results, remaining_results
    def backfill_results(
        self, filtered_results: list, remaining_results: list, n_results: int
    ) -> list:
        """
        Adds additional results from remaining_results to filtered_results
        until n_results is reached.

        Args:
            filtered_results (list): Initial filtered results
            remaining_results (list): Other results that can be added
            n_results (int): Target number of total results

        Returns:
            list: Combined results up to n_results
        """
        if len(filtered_results) >= n_results:
            return filtered_results[:n_results]
        needed = n_results - len(filtered_results)
        return filtered_results + remaining_results[:needed]
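
    # Sketch of the filter/backfill pair on made-up chunk dicts: with results
    # sorted by distance, filter_by_unique_sources keeps the closest chunk per
    # source until n_sources sources are seen, and backfill_results tops the
    # list back up from the leftovers:
    #
    #   kept, rest = chroma.filter_by_unique_sources(chunks, n_sources=2)
    #   final = chroma.backfill_results(kept, rest, n_results=5)
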
    def search_chunks(
        self,
        query: str,
        collections: List[str],
        n_results: int = 7,
        n_sources: int = 4,
        where: dict = None,
        **kwargs,
    ) -> List[dict]:
        """
        Complete pipeline for processing chunks: search, filter, clean, and format.

        Args:
            query (str): The search query
            collections (List[str]): List of collection names to search
            n_results (int): Maximum number of results to return
            n_sources (int): Maximum number of unique sources to include
            where (dict, optional): Additional filtering criteria
            **kwargs: Additional arguments to pass to search

        Returns:
            List[dict]: Cleaned chunk dicts with Chroma IDs, sorted by distance.
        """
        combined_chunks = []
        if isinstance(collections, str):
            collections = [collections]
        # Search all collections
        for collection in collections:
            chunks = self.search(
                query=query,
                collection=collection,
                n_results=n_results,
                n_sources=n_sources,
                where=where,
                format_results=True,
                **kwargs,
            )
            for chunk in chunks:
                combined_chunks.append({
                    "document": chunk["document"],
                    "metadata": chunk["metadata"],
                    "distance": chunk["distance"],
                    "id": chunk["id"],
                })
        # Sort and filter results
        combined_chunks.sort(key=lambda x: x["distance"])
        # Filter by unique sources and backfill
        closest_chunks, remaining_chunks = self.filter_by_unique_sources(
            combined_chunks, n_sources
        )
        closest_chunks = self.backfill_results(
            closest_chunks, remaining_chunks, n_results
        )
        # Clean text
        closest_chunks = self.clean_result_text(closest_chunks)
        return closest_chunks
    def add_document(self, _id, collection: str, document: str, metadata: dict = None):
        """
        Adds a single document to a specified collection in the database.

        Args:
            _id (str): Arango ID for the document, used as a unique identifier.
            collection (str): The name of the collection to add the document to.
            document (str): The document text to be added.
            metadata (dict, optional): Metadata to be associated with the document. Defaults to None.

        Returns:
            None
        """
        col = self.db.get_or_create_collection(collection)
        if metadata is None:
            metadata = {}
        col.add(ids=[_id], documents=[document], metadatas=[metadata])
    def add_chunks(self, collection: str, chunks: list, _key, metadata: dict = None):
        """
        Adds chunks to a specified collection in the database.

        Args:
            collection (str): The name of the collection to add chunks to.
            chunks (list): A list of chunks to be added to the collection.
            _key: A key used to generate unique IDs for the chunks.
            metadata (dict, optional): Metadata to be associated with each chunk. Defaults to None.

        Returns:
            None
        """
        col = self.db.get_or_create_collection(collection)
        ids = []
        metadatas = []
        for number, _chunk in enumerate(chunks):
            if metadata:
                # Copy per chunk so each entry gets its own "number" instead of
                # every entry sharing (and overwriting) one mutated dict
                metadatas.append({**metadata, "number": number})
            else:
                metadatas.append({})
            ids.append(f"{_key}_{number}")
        col.add(ids=ids, metadatas=metadatas, documents=chunks)
    def get_collection(self, collection: str) -> chromadb.Collection:
        """
        Retrieves a collection from the database, creating it if it does not exist.

        Args:
            collection (str): The name of the collection to retrieve.

        Returns:
            chromadb.Collection: The requested collection.
        """
        return self.db.get_or_create_collection(collection)


def is_reference_chunk(text: str) -> bool:
    """
    Determine if a text chunk primarily consists of academic references.

    Args:
        text (str): Text chunk to analyze

    Returns:
        bool: True if the chunk appears to be mainly references
    """
    # Count significant reference indicators
    indicators = 0
    # Check for DOI links (very strong indicator)
    doi_matches = len(re.findall(r'https?://doi\.org/10\.\d+/\S+', text))
    if doi_matches >= 2:  # Multiple DOIs almost certainly means references
        return True
    elif doi_matches == 1:
        indicators += 3
    # Check for citation patterns with year, volume, pages (e.g., 2018;178:551–60)
    citation_patterns = len(re.findall(r'\d{4};\d+:\d+[-–]\d+', text))
    indicators += citation_patterns * 2
    # Check for year patterns in brackets [YYYY]
    year_brackets = len(re.findall(r'\[\d{4}\]', text))
    indicators += year_brackets
    # Check for multiple lines starting with author name patterns
    lines = [line.strip() for line in text.split('\n') if line.strip()]
    author_started_lines = 0
    for line in lines:
        # Common pattern in references: starts with Author Name(s)
        if re.match(r'^\s*[A-Z][a-z]+\s+[A-Z][a-z]+', line):
            author_started_lines += 1
    # If multiple lines start with author names (common in reference lists)
    if author_started_lines >= 2:
        indicators += 2
    # Check for academic reference terms
    if re.search(r'\bet al\b|\bet al\.\b', text, re.IGNORECASE):
        indicators += 1
    # Return True if we have sufficient indicators
    return indicators >= 4  # Adjust threshold as needed
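

# Sketch of the heuristic on a made-up reference-list fragment (two DOI links
# short-circuit the scoring and return True immediately):
#
#   text = (
#       "Smith J, Jones K. A study. J Sci. 2018;178:551-60. "
#       "https://doi.org/10.1000/aaa https://doi.org/10.1000/bbb"
#   )
#   is_reference_chunk(text)  # -> True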
if __name__ == "__main__":
from colorprinter.print_color import *
chroma = ChromaDB()
print(chroma.db.list_collections())
print('DB', chroma.db.database)
print('SETTINGS', chroma.db.get_version())
result = chroma.search_chunks(
query="What is Open Science)",
collections="lasse__other_documents",
n_results=2,
n_sources=3,
max_retries=4,
)
collection = chroma.db.get_or_create_collection("lasse__other_documents")
result = collection.query(
query_texts="What is Open Science?",
n_results=2,
)
from pprint import pprint
pprint(result)
#print_rainbow(result["metadatas"][0])