parent
111bd04e82
commit
74d2d700b4
2 changed files with 131 additions and 17 deletions
@ -0,0 +1,112 @@ |
||||
import os |
||||
import re |
||||
from typing import Optional, Dict, List, Any |
||||
|
||||
from dotenv import load_dotenv |
||||
from arangoasync import ArangoClient |
||||
from arangoasync.auth import Auth |
||||
|
||||
load_dotenv() |
||||
if "ARANGO_HOSTS" not in os.environ: |
||||
import env_manager |
||||
env_manager.set_env() |
||||
|
||||
|
||||
class Arango: |
||||
""" |
||||
Async wrapper for python-arango-async. |
||||
|
||||
Usage (preferred, ensures cleanup): |
||||
async with Arango() as ar: |
||||
await ar.create_collection("my_col") |
||||
docs = await ar.execute_aql("FOR d IN my_col RETURN d") |
||||
|
||||
Or manual: |
||||
ar = Arango() |
||||
await ar.connect() |
||||
... |
||||
await ar.close() |
||||
""" |
||||
|
||||
def __init__( |
||||
self, |
||||
db_name: Optional[str] = None, |
||||
username: Optional[str] = None, |
||||
password: Optional[str] = None, |
||||
hosts: Optional[str] = None, |
||||
): |
||||
self.hosts = hosts or os.environ.get("ARANGO_HOSTS") |
||||
self.db_name = db_name or os.environ.get("ARANGO_DB") |
||||
self.username = username or os.environ.get("ARANGO_USERNAME") |
||||
self.password = password or os.environ.get("ARANGO_PWD") |
||||
|
||||
self._client: Optional[ArangoClient] = None |
||||
self._db = None |
||||
self._auth = Auth(username=self.username, password=self.password) |
||||
|
||||
# context manager support |
||||
async def __aenter__(self): |
||||
await self.connect() |
||||
return self |
||||
|
||||
async def __aexit__(self, exc_type, exc, tb): |
||||
await self.close() |
||||
|
||||
async def connect(self): |
||||
""" |
||||
Create the ArangoClient and connect to the configured database. |
||||
Must be called before doing DB operations if not using `async with`. |
||||
""" |
||||
# create client (no network call yet) |
||||
self._client = ArangoClient(hosts=self.hosts) |
||||
|
||||
# IMPORTANT: client.db(...) is an async call and must be awaited, |
||||
# and you should pass auth if your server requires authentication. |
||||
self._db = await self._client.db(self.db_name, auth=self._auth) |
||||
return self |
||||
|
||||
async def close(self): |
||||
"""Close the underlying client (releases connections).""" |
||||
if self._client: |
||||
await self._client.close() |
||||
self._client = None |
||||
self._db = None |
||||
|
||||
def fix_key(self, _key: str) -> str: |
||||
return re.sub(r"[^A-Za-z0-9_\-\.@()+=;\$!*\'%:]", "_", _key) |
||||
|
||||
async def execute_aql( |
||||
self, |
||||
query: str, |
||||
bind_vars: Optional[Dict[str, Any]] = None, |
||||
batch_size: Optional[int] = None, |
||||
) -> List[Dict[str, Any]]: |
||||
""" |
||||
Execute AQL and return a list of documents. |
||||
Uses async cursor pattern from the docs. |
||||
""" |
||||
await self.connect() |
||||
assert self._db is not None, "call connect() or use `async with` first" |
||||
|
||||
# You may choose to use `async with await self._db.aql.execute(query) as cursor:` |
||||
# or just `cursor = await self._db.aql.execute(...)` and then 'async for' over it. |
||||
print(query) |
||||
cursor = await self._db.aql.execute(query, bind_vars=bind_vars or {}, batch_size=batch_size) |
||||
results: List[Dict[str, Any]] = [] |
||||
|
||||
# Using `async with` ensures the cursor is cleaned up server-side when done. |
||||
async with cursor: |
||||
async for doc in cursor: |
||||
results.appe |
||||
|
||||
|
||||
# Example usage |
||||
async def main(): |
||||
arango = Arango() |
||||
results = await arango.execute_aql("FOR doc IN talks LIMIT 2 RETURN doc") |
||||
print(results) |
||||
|
||||
|
||||
if __name__ == "__main__": |
||||
import asyncio |
||||
asyncio.run(main()) |
||||
Loading…
Reference in new issue