"""Async helper around python-arango-async (``arangoasync``) for running AQL queries."""

import logging
import os
import re
from typing import Any, Dict, List, Optional

from arangoasync import ArangoClient
from arangoasync.auth import Auth
from dotenv import load_dotenv

# Pull settings from a local .env file first; if ARANGO_HOSTS is still not
# set afterwards, defer to env_manager.set_env() (presumably a project-local
# module that populates the ARANGO_* variables — confirm).
load_dotenv()
if os.environ.get("ARANGO_HOSTS") is None:
    import env_manager

    env_manager.set_env()


class Arango:
    """
    Async wrapper for python-arango-async.

    Usage (preferred, ensures cleanup):
        async with Arango() as ar:
            docs = await ar.execute_aql("FOR d IN my_col RETURN d")

    Or manual:
        ar = Arango()
        await ar.connect()
        ...
        await ar.close()

    Any connection setting not passed to the constructor is read from the
    ARANGO_HOSTS, ARANGO_DB, ARANGO_USERNAME and ARANGO_PWD environment
    variables.
    """

    def __init__(
        self,
        db_name: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        hosts: Optional[str] = None,
    ):
        """
        Store connection settings; no network activity happens here.

        Args:
            db_name: Database name (falls back to ARANGO_DB).
            username: User name (falls back to ARANGO_USERNAME).
            password: Password (falls back to ARANGO_PWD).
            hosts: Server URL(s) passed to ArangoClient (falls back to
                ARANGO_HOSTS).
        """
        # Falsy arguments (None or "") fall back to the environment.
        self.hosts = hosts or os.environ.get("ARANGO_HOSTS")
        self.db_name = db_name or os.environ.get("ARANGO_DB")
        self.username = username or os.environ.get("ARANGO_USERNAME")
        self.password = password or os.environ.get("ARANGO_PWD")

        self._client: Optional[ArangoClient] = None
        self._db = None
        self._auth = Auth(username=self.username, password=self.password)

    # context manager support
    async def __aenter__(self) -> "Arango":
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    async def connect(self) -> "Arango":
        """
        Create the ArangoClient and connect to the configured database.

        Must be called before doing DB operations if not using `async with`
        (``execute_aql`` also connects lazily). Calling it again replaces the
        current connection; the previous client is closed first so its
        connections are not leaked.

        Returns:
            self, so calls can be chained.
        """
        if self._client is not None:
            await self.close()

        # Creating the client makes no network call yet.
        client = ArangoClient(hosts=self.hosts)
        try:
            # client.db(...) is a coroutine and must be awaited; pass auth
            # for servers that require authentication.
            self._db = await client.db(self.db_name, auth=self._auth)
        except BaseException:
            # Don't leak the freshly created client if selecting the DB fails.
            await client.close()
            raise
        self._client = client
        return self

    async def close(self) -> None:
        """Close the underlying client (releases connections). Safe to call repeatedly."""
        # Detach state first so the instance is reset even if close() raises.
        client, self._client, self._db = self._client, None, None
        if client is not None:
            await client.close()

    def fix_key(self, _key: str) -> str:
        """
        Return ``_key`` with every character outside the whitelist below replaced by "_".

        The whitelist is letters, digits and ``_ - . @ ( ) + = ; $ ! * ' % :``.
        """
        # NOTE(review): ArangoDB's key rules also appear to allow ','; it is
        # replaced here. Kept as-is so already-stored keys stay stable —
        # confirm whether that is intentional.
        return re.sub(r"[^A-Za-z0-9_\-\.@()+=;\$!*\'%:]", "_", _key)

    async def execute_aql(
        self,
        query: str,
        bind_vars: Optional[Dict[str, Any]] = None,
        batch_size: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """
        Execute an AQL query and return all resulting documents as a list.

        Connects lazily on first use, so it works both inside ``async with``
        and on a bare ``Arango()`` instance; an existing connection is reused
        rather than opening a new client on every call.

        Args:
            query: AQL query string. Pass external values through
                ``bind_vars`` instead of formatting them into the query.
            bind_vars: Bind parameters for the query (none by default).
            batch_size: Optional cursor batch size forwarded to the server.

        Returns:
            Every document produced by the cursor, in order.

        Raises:
            RuntimeError: if no database handle is available after connecting.
        """
        if self._db is None:
            await self.connect()
        if self._db is None:
            raise RuntimeError("call connect() or use `async with` first")

        logging.getLogger(__name__).debug("Executing AQL: %s", query)
        cursor = await self._db.aql.execute(
            query, bind_vars=bind_vars or {}, batch_size=batch_size
        )

        # `async with` ensures the cursor is cleaned up server-side when done.
        async with cursor:
            return [doc async for doc in cursor]


# Example usage
async def main():
    """
    Run a small sample query and print its results.

    Uses ``async with`` so the client's connections are released even if
    the query raises.
    """
    async with Arango() as arango:
        results = await arango.execute_aql("FOR doc IN talks LIMIT 2 RETURN doc")
    print(results)


if __name__ == "__main__":
    # Script entry point: drive the async example on a fresh event loop.
    import asyncio as _asyncio

    _asyncio.run(main())