import os
import re
from typing import List, Tuple, Union

import chromadb
from chromadb.config import Settings
from colorprinter.print_color import *
from dotenv import load_dotenv

from models import ChunkSearchResults

load_dotenv(".env")


class ChromaDB:
    def __init__(self, local_deployment: bool = False, db="sci_articles", host=None):
        if local_deployment:
            self.db = chromadb.PersistentClient(f"chroma_{db}")
        else:
            if not host:
                host = os.getenv("CHROMA_HOST")
            credentials = os.getenv("CHROMA_CLIENT_AUTH_CREDENTIALS")
            auth_token_transport_header = os.getenv(
                "CHROMA_AUTH_TOKEN_TRANSPORT_HEADER"
            )
            self.db = chromadb.HttpClient(
                host=host,
                # database=db,
                settings=Settings(
                    chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
                    chroma_client_auth_credentials=credentials,
                    chroma_auth_token_transport_header=auth_token_transport_header,
                ),
            )

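    # Minimal construction sketch: the remote path expects CHROMA_HOST,
    # CHROMA_CLIENT_AUTH_CREDENTIALS and CHROMA_AUTH_TOKEN_TRANSPORT_HEADER
    # in the environment (or .env).
    #
    #   chroma = ChromaDB()                      # HttpClient against CHROMA_HOST
    #   local = ChromaDB(local_deployment=True)  # PersistentClient in ./chroma_sci_articles
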
    def query(
        self,
        query,
        collection,
        n_results=6,
        n_sources=3,
        max_retries=None,
        where: dict = None,
        **kwargs,
    ):
        """
        Query the vector database for relevant documents.

        Args:
            query (str): The query text to search for.
            collection (str): The name of the collection to search within.
            n_results (int, optional): The number of results to return. Defaults to 6.
            n_sources (int, optional): The number of unique sources to return. Defaults to 3.
            max_retries (int, optional): The maximum number of retries for querying. Defaults to n_sources.
            where (dict, optional): Additional filtering criteria for the query. Defaults to None.
            **kwargs: Additional keyword arguments to pass to the query.

        Returns:
            dict: A dictionary containing the query results with keys 'ids', 'metadatas', 'documents', and 'distances'.
        """
        if not isinstance(n_sources, int):
            n_sources = int(n_sources)
        if not isinstance(n_results, int):
            n_results = int(n_results)
        if not max_retries:
            max_retries = n_sources
        if n_sources > n_results:
            n_sources = n_results
        col = self.db.get_collection(collection)
        sources = []
        n = 0
        print('Collection', collection)
        result = {"ids": [[]], "metadatas": [[]], "documents": [[]], "distances": [[]]}

        while True:
            n += 1
            if n > max_retries:
                break
            if where == {}:
                where = None

            print_rainbow(kwargs)
            print('N_results:', n_results)
            print('Sources:', sources)
            print('Query:', query)
            r = col.query(
                query_texts=query,
                n_results=n_results - len(sources),
                where=where,
                **kwargs,
            )
            if r["ids"][0] == []:
                if result["ids"][0] == []:
                    print_rainbow(r)
                    print_red("No results found in vector database.")
                else:
                    print_red("No more results found in vector database.")
                break

            # Manually extend each list within the lists of lists
            for key in result:
                if key in r:
                    result[key][0].extend(r[key][0])

            # Order result by distance
            combined = sorted(
                zip(
                    result["distances"][0],
                    result["ids"][0],
                    result["metadatas"][0],
                    result["documents"][0],
                ),
                key=lambda x: x[0],
            )
            (
                result["distances"][0],
                result["ids"][0],
                result["metadatas"][0],
                result["documents"][0],
            ) = map(list, zip(*combined))
            # Recompute the unique sources seen so far (a set avoids counting
            # the same source twice across retries).
            sources = list({i["_id"] for i in result["metadatas"][0]})
            if len(sources) >= n_sources:
                break
            elif n != max_retries:
                for k, v in result.items():
                    if k not in r["included"]:
                        continue
                    result[k][0] = v[0][: n_results - (n_sources - len(sources))]
                if where and "_id" in where:
                    where["_id"]["$in"] = [
                        i for i in where["_id"]["$in"] if i not in sources
                    ]
                    if where["_id"]["$in"] == []:
                        break
            else:
                break
        return result

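    # Sketch of a filtered call (the IDs are made up): Chroma's `$in`
    # operator restricts matches to the listed `_id` metadata values, and
    # the retry loop above keeps querying until `n_sources` unique sources
    # are collected or `max_retries` is hit.
    #
    #   result = chroma.query(
    #       "mitochondrial DNA repair",
    #       collection="sci_articles",
    #       n_results=6,
    #       n_sources=3,
    #       where={"_id": {"$in": ["articles/123", "articles/456"]}},
    #   )
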
    def search(
        self,
        query: str,
        collection: str,
        n_results: int = 6,
        n_sources: int = 3,
        where: dict = None,
        format_results: bool = False,
        **kwargs,
    ) -> Union[dict, ChunkSearchResults]:
        """
        An enhanced search method that provides a cleaner interface for querying and processing results.

        Args:
            query (str): The search query.
            collection (str): Collection name to search in.
            n_results (int): Maximum number of results to return.
            n_sources (int): Maximum number of unique sources to include.
            where (dict, optional): Additional filtering criteria.
            format_results (bool): Whether to flatten the raw results into chunk dictionaries.
            **kwargs: Additional arguments to pass to the query.

        Returns:
            Union[dict, ChunkSearchResults]: The raw query results when
            format_results is False, otherwise a list of chunk dictionaries
            with 'document', 'metadata', 'distance' and 'id' keys.
        """
        # Get raw query results with the existing query method
        result = self.query(
            query=query,
            collection=collection,
            n_results=n_results,
            n_sources=n_sources,
            where=where,
            **kwargs,
        )

        # If no formatting requested, return raw results
        if not format_results:
            return result

        # Process results into dictionary format
        combined_chunks = []
        for doc, meta, dist, _id in zip(
            result["documents"][0],
            result["metadatas"][0],
            result["distances"][0],
            result["ids"][0],
        ):
            combined_chunks.append(
                {"document": doc, "metadata": meta, "distance": dist, "id": _id}
            )

        return combined_chunks

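    # Usage sketch: the same call returns either Chroma's raw nested dict or
    # a flat list of chunk dictionaries, depending on `format_results`.
    #
    #   raw = chroma.search("open peer review", collection="sci_articles")
    #   chunks = chroma.search(
    #       "open peer review", collection="sci_articles", format_results=True
    #   )
    #   # chunks[0] -> {"document": ..., "metadata": ..., "distance": ..., "id": ...}
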
    def clean_result_text(self, documents: list) -> list:
        """
        Clean text in document results by removing footnote references.

        Args:
            documents (list): List of document dictionaries.

        Returns:
            list: Documents with cleaned text.
        """
        for doc in documents:
            if "document" in doc:
                doc["document"] = re.sub(r"\[\d+\]", "", doc["document"])
        return documents

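    # Example of the footnote stripping (made-up input):
    #
    #   docs = [{"document": "Open Science improves reproducibility [12]."}]
    #   chroma.clean_result_text(docs)
    #   # -> [{"document": "Open Science improves reproducibility ."}]
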
    def filter_by_unique_sources(
        self, results: list, n_sources: int, source_key: str = "_id"
    ) -> Tuple[List, List]:
        """
        Filters search results to keep only a specified number of unique sources.

        Args:
            results (list): List of documents from search.
            n_sources (int): Maximum number of unique sources to include.
            source_key (str): The key in metadata that identifies the source.

        Returns:
            tuple: (filtered_results, remaining_results)
        """
        sources = set()
        filtered_results = []
        remaining_results = []

        for item in results:
            source_id = item["metadata"].get(source_key, "no_id")
            if source_id not in sources and len(sources) < n_sources:
                sources.add(source_id)
                filtered_results.append(item)
            else:
                remaining_results.append(item)

        return filtered_results, remaining_results

    def backfill_results(
        self, filtered_results: list, remaining_results: list, n_results: int
    ) -> list:
        """
        Adds additional results from remaining_results to filtered_results
        until n_results is reached.

        Args:
            filtered_results (list): Initial filtered results.
            remaining_results (list): Other results that can be added.
            n_results (int): Target number of total results.

        Returns:
            list: Combined results up to n_results.
        """
        if len(filtered_results) >= n_results:
            return filtered_results[:n_results]

        needed = n_results - len(filtered_results)
        return filtered_results + remaining_results[:needed]

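    # How the two helpers compose (hypothetical three-chunk input): keep at
    # most two unique sources, then top the list back up to three chunks
    # from the leftovers.
    #
    #   chunks = [
    #       {"metadata": {"_id": "a"}, "distance": 0.1},
    #       {"metadata": {"_id": "a"}, "distance": 0.2},
    #       {"metadata": {"_id": "b"}, "distance": 0.3},
    #   ]
    #   kept, rest = chroma.filter_by_unique_sources(chunks, n_sources=2)
    #   # kept -> one chunk each for "a" and "b"; rest -> the duplicate "a" chunk
    #   final = chroma.backfill_results(kept, rest, n_results=3)
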
    def search_chunks(
        self,
        query: str,
        collections: List[str],
        n_results: int = 7,
        n_sources: int = 4,
        where: dict = None,
        **kwargs,
    ) -> ChunkSearchResults:
        """
        Complete pipeline for processing chunks: search, filter, clean, and format.

        Args:
            query (str): The search query.
            collections (List[str]): List of collection names to search.
            n_results (int): Maximum number of results to return.
            n_sources (int): Maximum number of unique sources to include.
            where (dict, optional): Additional filtering criteria.
            **kwargs: Additional arguments to pass to search.

        Returns:
            ChunkSearchResults: Processed chunks with Chroma IDs.
        """
        combined_chunks = []

        if isinstance(collections, str):
            collections = [collections]

        # Search all collections and pool the formatted chunks
        for collection in collections:
            chunks = self.search(
                query=query,
                collection=collection,
                n_results=n_results,
                n_sources=n_sources,
                where=where,
                format_results=True,
                **kwargs,
            )
            # search() already returns plain chunk dictionaries, so they can
            # be pooled directly.
            combined_chunks.extend(chunks)

        # Sort pooled results by distance (closest first)
        combined_chunks.sort(key=lambda x: x["distance"])

        # Filter by unique sources and backfill
        closest_chunks, remaining_chunks = self.filter_by_unique_sources(
            combined_chunks, n_sources
        )
        closest_chunks = self.backfill_results(
            closest_chunks, remaining_chunks, n_results
        )

        # Clean text
        closest_chunks = self.clean_result_text(closest_chunks)
        return closest_chunks

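    # Pipeline sketch (collection names are illustrative): chunks from both
    # collections are merged, sorted by distance, capped at `n_sources`
    # unique `_id`s, backfilled to `n_results`, and stripped of footnotes.
    #
    #   chunks = chroma.search_chunks(
    #       query="What is Open Science?",
    #       collections=["sci_articles", "lasse__other_documents"],
    #       n_results=7,
    #       n_sources=4,
    #   )
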
    def add_document(self, _id, collection: str, document: str, metadata: dict = None):
        """
        Adds a single document to a specified collection in the database.

        Args:
            _id (str): Arango ID for the document, used as a unique identifier.
            collection (str): The name of the collection to add the document to.
            document (str): The document text to be added.
            metadata (dict, optional): Metadata to be associated with the document. Defaults to None.

        Returns:
            None
        """
        col = self.db.get_or_create_collection(collection)
        if metadata is None:
            metadata = {}
        col.add(ids=[_id], documents=[document], metadatas=[metadata])

    def add_chunks(self, collection: str, chunks: list, _key, metadata: dict = None):
        """
        Adds chunks to a specified collection in the database.

        Args:
            collection (str): The name of the collection to add chunks to.
            chunks (list): A list of chunks to be added to the collection.
            _key: A key used to generate unique IDs for the chunks.
            metadata (dict, optional): Metadata to be associated with each chunk. Defaults to None.

        Returns:
            None
        """
        col = self.db.get_or_create_collection(collection)
        ids = []
        metadatas = []
        for number, _chunk in enumerate(chunks):
            # Copy the metadata per chunk so each entry gets its own
            # "number" instead of all sharing (and overwriting) one dict.
            if metadata:
                metadatas.append({**metadata, "number": number})
            else:
                metadatas.append({})
            ids.append(f"{_key}_{number}")
        col.add(ids=ids, metadatas=metadatas, documents=chunks)

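    # Ingestion sketch (IDs and text are made up): a whole document gets one
    # entry, while chunks get ids of the form "<_key>_<number>" plus a
    # per-chunk "number" in their metadata.
    #
    #   chroma.add_document(
    #       "articles/123", "sci_articles", "Full text...", {"title": "Demo"}
    #   )
    #   chroma.add_chunks(
    #       "sci_articles",
    #       chunks=["First chunk...", "Second chunk..."],
    #       _key="123",
    #       metadata={"_id": "articles/123"},
    #   )
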
    def get_collection(self, collection: str) -> chromadb.Collection:
        """
        Retrieves a collection from the database, creating it if it does not exist.

        Args:
            collection (str): The name of the collection to retrieve.

        Returns:
            chromadb.Collection: The requested collection.
        """
        return self.db.get_or_create_collection(collection)

    @staticmethod
    def is_reference_chunk(text: str) -> bool:
        """
        Determine if a text chunk primarily consists of academic references.

        Args:
            text (str): Text chunk to analyze.

        Returns:
            bool: True if the chunk appears to be mainly references.
        """
        # Count significant reference indicators
        indicators = 0

        # Check for DOI links (very strong indicator)
        doi_matches = len(re.findall(r"https?://doi\.org/10\.\d+/\S+", text))
        if doi_matches >= 2:  # Multiple DOIs almost certainly means references
            return True
        elif doi_matches == 1:
            indicators += 3

        # Check for citation patterns with year, volume, pages (e.g. 2018;178:551–60)
        citation_patterns = len(re.findall(r"\d{4};\d+:\d+[-–]\d+", text))
        indicators += citation_patterns * 2

        # Check for year patterns in brackets [YYYY]
        year_brackets = len(re.findall(r"\[\d{4}\]", text))
        indicators += year_brackets

        # Check for multiple lines starting with author name patterns
        lines = [line.strip() for line in text.split("\n") if line.strip()]
        author_started_lines = 0

        for line in lines:
            # Common pattern in references: starts with Author Name(s)
            if re.match(r"^\s*[A-Z][a-z]+\s+[A-Z][a-z]+", line):
                author_started_lines += 1

        # If multiple lines start with author names (common in reference lists)
        if author_started_lines >= 2:
            indicators += 2

        # Check for academic reference terms
        if re.search(r"\bet al\b|\bet al\.\b", text, re.IGNORECASE):
            indicators += 1

        # Return True if we have sufficient indicators
        return indicators >= 4  # Adjust threshold as needed


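# Quick sanity check for the heuristic (made-up snippet): two DOI links
# alone are enough to flag a chunk as a reference list.
#
#   text = (
#       "Smith J. Some paper. https://doi.org/10.1000/abc123\n"
#       "Jones K. Another paper. https://doi.org/10.1000/def456"
#   )
#   assert ChromaDB.is_reference_chunk(text)

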
if __name__ == "__main__":
    from pprint import pprint

    chroma = ChromaDB()
    print(chroma.db.list_collections())
    print('DB', chroma.db.database)
    print('VERSION', chroma.db.get_version())

    result = chroma.search_chunks(
        query="What is Open Science?",
        collections="lasse__other_documents",
        n_results=2,
        n_sources=3,
        max_retries=4,
    )

    collection = chroma.db.get_or_create_collection("lasse__other_documents")
    result = collection.query(
        query_texts="What is Open Science?",
        n_results=2,
    )
    pprint(result)
    # print_rainbow(result["metadatas"][0])