parent
638e2a00d3
commit
5ee1a062f1
14 changed files with 1704 additions and 2185 deletions
@ -1,440 +0,0 @@ |
||||
import os |
||||
import base64 |
||||
import re |
||||
import json |
||||
from typing import Any, Callable, Iterator, Literal, Mapping, Optional, Sequence, Union |
||||
|
||||
import tiktoken |
||||
from ollama import Client, AsyncClient, ResponseError, ChatResponse, Message, Tool, Options |
||||
from ollama._types import JsonSchemaValue, ChatRequest |
||||
|
||||
import env_manager |
||||
from colorprinter.print_color import * |
||||
|
||||
env_manager.set_env() |
||||
|
||||
tokenizer = tiktoken.get_encoding("cl100k_base") |
||||
|
||||
# Define a base class for common functionality |
||||
class BaseClient:
    """Mixin that assembles /api/chat payloads for the sync/async subclasses.

    Subclasses (MyClient / MyAsyncClient) provide the actual ``_request``
    implementation; this class only builds the request body.
    """

    def chat(
        self,
        model: str = '',
        messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
        *,
        tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
        stream: bool = False,
        format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
        options: Optional[Union[Mapping[str, Any], Options]] = None,
        keep_alive: Optional[Union[float, str]] = None,
    ) -> Union[ChatResponse, Iterator[ChatResponse]]:
        """Build a ChatRequest from the arguments and dispatch it via
        ``self._request``; returns the parsed response or a stream iterator."""
        request_body = ChatRequest(
            model=model,
            messages=list(messages or []),
            tools=list(tools or []),
            stream=stream,
            format=format,
            options=options,
            keep_alive=keep_alive,
        ).model_dump(exclude_none=True)
        return self._request(
            ChatResponse,
            'POST',
            '/api/chat',
            json=request_body,
            stream=stream,
        )
||||
|
||||
# Define your custom MyAsyncClient class |
||||
class MyAsyncClient(AsyncClient, BaseClient):
    """Async Ollama client that merges per-call headers into each request and
    records the last raw HTTP response on ``self.last_response`` (the LLM
    class reads the ``X-Chosen-Backend`` header from it)."""

    async def _request(self, response_type, method, path, headers=None, **kwargs):
        """Send an HTTP request with combined default + per-call headers.

        Returns a parsed ``response_type`` model, or a chunk iterator when
        ``stream=True``. Raises ResponseError for HTTP status >= 400.
        """
        # Merge default headers with per-call headers (per-call wins)
        all_headers = {**self._client.headers, **(headers or {})}

        # Handle streaming separately
        if kwargs.get('stream'):
            kwargs.pop('stream')
            async with self._client.stream(method, path, headers=all_headers, **kwargs) as response:
                self.last_response = response  # Store the response object
                if response.status_code >= 400:
                    # The streamed body must be read before .text is available.
                    await response.aread()
                    raise ResponseError(response.text, response.status_code)
                # NOTE(review): returning exits the ``async with`` block, which
                # closes the response before the caller iterates the stream —
                # verify chunks remain readable after this point.
                return self._stream(response_type, response)
        else:
            # Make the HTTP request with the combined headers
            # NOTE(review): pop() without a default raises KeyError if a caller
            # ever omits ``stream``; chat() always supplies it today.
            kwargs.pop('stream')
            response = await self._request_raw(method, path, headers=all_headers, **kwargs)
            self.last_response = response  # Store the response object

            if response.status_code >= 400:
                raise ResponseError(response.text, response.status_code)
            return response_type.model_validate_json(response.content)

    async def chat(
        self,
        model: str = '',
        messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
        *,
        tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
        stream: bool = False,
        format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
        options: Optional[Union[Mapping[str, Any], Options]] = None,
        keep_alive: Optional[Union[float, str]] = None,
    ) -> Union[ChatResponse, Iterator[ChatResponse]]:
        """Async counterpart of BaseClient.chat: build the /api/chat payload
        and await the request."""
        return await self._request(
            ChatResponse,
            'POST',
            '/api/chat',
            json=ChatRequest(
                model=model,
                messages=[message for message in messages or []],
                tools=[tool for tool in tools or []],
                stream=stream,
                format=format,
                options=options,
                keep_alive=keep_alive,
            ).model_dump(exclude_none=True),
            stream=stream,
        )
||||
|
||||
# Define your custom MyClient class |
||||
class MyClient(Client, BaseClient):
    """Synchronous Ollama client that merges per-call headers into each request
    and records the last raw HTTP response on ``self.last_response``."""

    def _request(self, response_type, method, path, headers=None, **kwargs):
        """Send an HTTP request with combined default + per-call headers.

        Returns a parsed ``response_type`` model, or a chunk iterator when
        ``stream=True``. Raises ResponseError for HTTP status >= 400.
        """
        # Merge default headers with per-call headers (per-call wins)
        all_headers = {**self._client.headers, **(headers or {})}

        # Handle streaming separately
        if kwargs.get('stream'):
            kwargs.pop('stream')
            with self._client.stream(method, path, headers=all_headers, **kwargs) as response:
                self.last_response = response  # Store the response object
                if response.status_code >= 400:
                    # NOTE(review): unlike the async variant, the body is not
                    # read before accessing .text on a streamed error response —
                    # confirm this does not raise in the HTTP client.
                    raise ResponseError(response.text, response.status_code)
                # NOTE(review): returning exits the ``with`` block, closing the
                # response before the caller iterates the stream — verify.
                return self._stream(response_type, response)
        else:
            # Make the HTTP request with the combined headers
            # NOTE(review): pop() without a default raises KeyError if a caller
            # ever omits ``stream``; chat() always supplies it today.
            kwargs.pop('stream')
            response = self._request_raw(method, path, headers=all_headers, **kwargs)
            self.last_response = response  # Store the response object

            if response.status_code >= 400:
                raise ResponseError(response.text, response.status_code)
            return response_type.model_validate_json(response.content)
||||
|
||||
class LLM: |
||||
""" |
||||
LLM class for interacting with a language model. |
||||
""" |
||||
|
||||
def __init__( |
||||
self, |
||||
system_message="You are an assistant.", |
||||
temperature=0.01, |
||||
model: Optional[Literal["small", "standard", "vision"]] = "standard", |
||||
max_length_answer=4096, |
||||
messages=None, |
||||
chat=True, |
||||
chosen_backend=None, |
||||
) -> None: |
||||
|
||||
self.model = self.get_model(model) |
||||
self.system_message = system_message |
||||
self.options = {"temperature": temperature} |
||||
self.messages = messages or [{"role": "system", "content": self.system_message}] |
||||
self.max_length_answer = max_length_answer |
||||
self.chat = chat |
||||
self.chosen_backend = chosen_backend |
||||
|
||||
# Initialize the client with the host and default headers |
||||
credentials = f"{os.getenv('LLM_API_USER')}:{os.getenv('LLM_API_PWD_LASSE')}" |
||||
encoded_credentials = base64.b64encode(credentials.encode()).decode() |
||||
default_headers = { |
||||
"Authorization": f"Basic {encoded_credentials}", |
||||
} |
||||
host_url = os.getenv("LLM_API_URL").rstrip("/api/chat/") |
||||
self.client = MyClient(host=host_url, headers=default_headers) |
||||
self.async_client = MyAsyncClient(host=host_url, headers=default_headers) |
||||
|
||||
def get_model(self, model_alias): |
||||
models = { |
||||
"standard": "LLM_MODEL", |
||||
"small": "LLM_MODEL_SMALL", |
||||
"vision": "LLM_MODEL_VISION", |
||||
"standard_64k": "LLM_MODEL_64K", |
||||
} |
||||
return os.getenv(models.get(model_alias, "LLM_MODEL")) |
||||
|
||||
def count_tokens(self): |
||||
num_tokens = 0 |
||||
for i in self.messages: |
||||
for k, v in i.items(): |
||||
if k == "content": |
||||
if not isinstance(v, str): |
||||
v = str(v) |
||||
tokens = tokenizer.encode(v) |
||||
num_tokens += len(tokens) |
||||
return int(num_tokens) |
||||
|
||||
    def generate(
        self,
        query: str = None,
        user_input: str = None,
        context: str = None,
        stream: bool = False,
        tools: list = None,
        function_call: dict = None,
        images: list = None,
        model: Optional[Literal["small", "standard", "vision"]] = None,
        temperature: float = None,
    ):
        """
        Generates a response from the language model based on the provided inputs.

        Args:
            query: User prompt; whitespace around newlines is collapsed.
            user_input: When given, replaces the last history entry after the
                call (optionally combined with a summarized ``context``).
            context: Background text; summarized when longer than 2000 chars.
            stream: When True, return a generator of content chunks
                (see ``read_stream``); otherwise return the reply string.
            tools: Tool specs; dicts are converted to ``Tool`` models.
            function_call: NOTE(review): accepted but never used in this body.
            images: File paths, base64 strings, or raw bytes; forces the
                vision model.
            model: Alias resolved via ``get_model``; defaults to ``self.model``.
            temperature: Per-call override. NOTE(review): a value of 0.0 is
                falsy and silently falls back to the instance default.

        Returns:
            str reply (or an error message), or a generator when streaming.
        """

        # Prepare the model and temperature
        model = self.get_model(model) if model else self.model
        temperature = temperature if temperature else self.options["temperature"]

        # Normalize whitespace and add the query to the messages
        query = re.sub(r"\s*\n\s*", "\n", query)
        message = {"role": "user", "content": query}

        # Handle images if any
        if images:
            import base64  # NOTE(review): redundant — base64 is imported at module level

            base64_images = []
            # Heuristic: a string made only of base64 characters is assumed to
            # be already-encoded image data. NOTE(review): short filenames
            # without dots or slashes would also match — confirm inputs.
            base64_pattern = re.compile(r"^[A-Za-z0-9+/]+={0,2}$")

            for image in images:
                if isinstance(image, str):
                    if base64_pattern.match(image):
                        base64_images.append(image)
                    else:
                        # Treat the string as a file path and encode the bytes.
                        with open(image, "rb") as image_file:
                            base64_images.append(
                                base64.b64encode(image_file.read()).decode("utf-8")
                            )
                elif isinstance(image, bytes):
                    base64_images.append(base64.b64encode(image).decode("utf-8"))
                else:
                    print_red("Invalid image type")

            message["images"] = base64_images
            # Use the vision model
            model = self.get_model("vision")

        self.messages.append(message)

        # Prepare headers
        headers = {}
        if self.chosen_backend:
            headers["X-Chosen-Backend"] = self.chosen_backend

        if model == self.get_model("small"):
            headers["X-Model-Type"] = "small"

        # Prepare options
        options = Options(**self.options)
        options.temperature = temperature

        # Prepare tools if any
        if tools:
            tools = [
                Tool(**tool) if isinstance(tool, dict) else tool
                for tool in tools
            ]

        # Adjust the options for long messages
        # NOTE(review): len(self.messages) counts messages, not characters —
        # the 15000 threshold looks like it was meant for text length; confirm.
        if self.chat or len(self.messages) > 15000:
            num_tokens = self.count_tokens() + self.max_length_answer // 2
            if num_tokens > 8000:
                # Long history: switch to the large-context model.
                model = self.get_model("standard_64k")
                headers["X-Model-Type"] = "large"

        # Call the client.chat method
        # NOTE(review): neither BaseClient.chat nor ollama's Client.chat
        # declares a ``headers`` kwarg — confirm which chat() this resolves to.
        try:
            response = self.client.chat(
                model=model,
                messages=self.messages,
                headers=headers,
                tools=tools,
                stream=stream,
                options=options,
                keep_alive=3600 * 24 * 7,  # keep the model loaded for a week
            )
        except ResponseError as e:
            print_red("Error!")
            print(e)
            return "An error occurred."

        # If user_input is provided, update the last message
        if user_input:
            if context:
                if len(context) > 2000:
                    # Replace the long context with an LLM-generated summary
                    # and tell the model (via the system message) that parts of
                    # the history may be summarised.
                    context = self.make_summary(context)
                    user_input = (
                        f"{user_input}\n\nUse the information below to answer the question.\n"
                        f'"""{context}"""\n[This is a summary of the context provided in the original message.]'
                    )
                    system_message_info = "\nSometimes some of the messages in the chat history are summarised, then that is clearly indicated in the message."
                    if system_message_info not in self.messages[0]["content"]:
                        self.messages[0]["content"] += system_message_info
            self.messages[-1] = {"role": "user", "content": user_input}

        # Remember which backend served the request so follow-ups are sticky.
        self.chosen_backend = self.client.last_response.headers.get("X-Chosen-Backend")

        # Handle streaming response
        if stream:
            return self.read_stream(response)
        else:
            # Process the response
            if isinstance(response, ChatResponse):
                result = response.message.content.strip('"')
                self.messages.append({"role": "assistant", "content": result.strip('"')})
                # NOTE(review): Message is a pydantic model — confirm it
                # supports .get(); attribute access may be intended.
                if tools and not response.message.get("tool_calls"):
                    print_yellow("No tool calls in response".upper())
                if not self.chat:
                    # Stateless mode: keep only the system message.
                    self.messages = [self.messages[0]]
                return result
            else:
                print_red("Unexpected response type")
                return "An error occurred."
||||
|
||||
def make_summary(self, text): |
||||
# Implement your summary logic using self.client.chat() |
||||
summary_message = { |
||||
"role": "user", |
||||
"content": f'Summarize the text below:\n"""{text}"""\nRemember to be concise and detailed. Answer in English.', |
||||
} |
||||
messages = [ |
||||
{"role": "system", "content": "You are summarizing a text. Make it detailed and concise. Answer ONLY with the summary. Don't add any new information."}, |
||||
summary_message, |
||||
] |
||||
try: |
||||
response = self.client.chat( |
||||
model=self.get_model("small"), |
||||
messages=messages, |
||||
options=Options(temperature=0.01), |
||||
keep_alive=3600 * 24 * 7, |
||||
) |
||||
summary = response.message.content.strip() |
||||
print_blue("Summary:", summary) |
||||
return summary |
||||
except ResponseError as e: |
||||
print_red("Error generating summary:", e) |
||||
return "Summary generation failed." |
||||
|
||||
def read_stream(self, response): |
||||
# Implement streaming response handling if needed |
||||
buffer = "" |
||||
message = "" |
||||
first_chunk = True |
||||
prev_content = None |
||||
for chunk in response: |
||||
if chunk: |
||||
content = chunk.message.content |
||||
if first_chunk and content.startswith('"'): |
||||
content = content[1:] |
||||
first_chunk = False |
||||
|
||||
if chunk.done: |
||||
if prev_content and prev_content.endswith('"'): |
||||
prev_content = prev_content[:-1] |
||||
if prev_content: |
||||
yield prev_content |
||||
break |
||||
else: |
||||
if prev_content: |
||||
yield prev_content |
||||
prev_content = content |
||||
self.messages.append({"role": "assistant", "content": message.strip('"')}) |
||||
|
||||
    async def async_generate(
        self,
        query: str = None,
        user_input: str = None,
        context: str = None,
        stream: bool = False,
        tools: list = None,
        function_call: dict = None,
        images: list = None,
        model: Optional[Literal["small", "standard", "vision"]] = None,
        temperature: float = None,
    ):
        """
        Asynchronous method to generate a response from the language model.

        Mirrors ``generate`` but awaits the async client.
        NOTE(review): the image handling, long-message adjustment, and the
        ``user_input``/``context`` rewriting from ``generate`` are still
        placeholder ``...`` statements here — those parameters are accepted
        but not fully honored on this path.
        """

        # Prepare the model and temperature
        model = self.get_model(model) if model else self.model
        # NOTE(review): temperature == 0.0 is falsy and falls back to default.
        temperature = temperature if temperature else self.options["temperature"]

        # Normalize whitespace and add the query to the messages
        query = re.sub(r"\s*\n\s*", "\n", query)
        message = {"role": "user", "content": query}

        # Handle images if any
        if images:
            # (Image handling code as in the generate method)
            ...

        self.messages.append(message)

        # Prepare headers
        # NOTE(review): headers are built here but never passed to chat() below.
        headers = {}
        if self.chosen_backend:
            headers["X-Chosen-Backend"] = self.chosen_backend

        if model == self.get_model("small"):
            headers["X-Model-Type"] = "small"

        # Prepare options
        options = Options(**self.options)
        options.temperature = temperature

        # Prepare tools if any
        if tools:
            tools = [
                Tool(**tool) if isinstance(tool, dict) else tool
                for tool in tools
            ]

        # Adjust options for long messages
        # (Adjustments as needed)
        ...

        # Call the async client's chat method
        try:
            response = await self.async_client.chat(
                model=model,
                messages=self.messages,
                tools=tools,
                stream=stream,
                options=options,
                keep_alive=3600 * 24 * 7,  # keep the model loaded for a week
            )
        except ResponseError as e:
            print_red("Error!")
            print(e)
            return "An error occurred."

        # Process the response
        # NOTE(review): with stream=True the client returns an iterator, which
        # fails the isinstance check below and yields "An error occurred." —
        # confirm streaming is actually supported on the async path.
        if isinstance(response, ChatResponse):
            result = response.message.content.strip('"')
            self.messages.append({"role": "assistant", "content": result.strip('"')})
            return result
        else:
            print_red("Unexpected response type")
            return "An error occurred."
||||
|
||||
# Usage example: one async round-trip against the configured backend.
if __name__ == "__main__":
    import asyncio

    llm = LLM()

    async def main():
        # async_generate returns the reply string (or an error message).
        result = await llm.async_generate(query="Hello, how are you?")
        print(result)

    asyncio.run(main())
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,85 @@ |
||||
|
||||
|
||||
def create_plan_questions(agent, question):
    """Build the prompt that asks the model to split *question* into smaller,
    researchable sub-questions.

    Args:
        agent: Object whose optional ``project`` may carry a
            ``notes_summary`` attribute (included in the prompt when present).
        question: The journalist's original question.

    Returns:
        str: The assembled prompt text.
    """
    query = f"""
A journalist wants to get a report that answers this question: "{question}"
THIS IS *NOT* A QUESTION YOU CAN ANSWER! Instead, you need to split it into multiple questions that can be answered through research.
The questions should be specific and focused on a single aspect of the topic.
For example, if the question is "What are the effects of climate change on agriculture?", you could split it into:
- How does temperature change affect crop yields?
- What are the impacts of changing rainfall patterns on agriculture?
- How does increased CO2 levels affect plant growth?
"""

    # Add project notes summary if available
    if agent.project and hasattr(agent.project, "notes_summary"):
        query += f'''\nTo help you understand the subject, here is a summary of notes the journalist has done: \n"""{agent.project.notes_summary}\n"""\n'''

    query += """
Answer ONLY with the questions you have divided the original question into, not the answers to them (this will be done using research in a later step).
If the original question asked by the journalist is already specific, you can keep it as is.
Answer in a structured format with each of your question on a new line.
"""
    return query
||||
|
||||
def create_plan(agent, question):
    """
    Build the prompt that asks the model to create a research plan for
    answering *question*. It should be used after create_plan_questions and
    be in the same chat.

    Args:
        agent: Object exposing ``available_sources`` (mapping of source name
            to item count).
        question: The journalist's original question; double quotes are
            replaced with single quotes before being embedded.

    Returns:
        str: The assembled prompt text.
    """
    # Describe each known source type with its item count.
    available_sources_str = ''
    for source, count in agent.available_sources.items():
        if source == 'scientific articles':
            available_sources_str += f'- Scientific articles the journalist has gathered. Number of articles: {count}\n'
        elif source == 'other articles':
            available_sources_str += f'- Other articles the journalist has gathered, such as blog posts, news articles, etc. Number of articles: {count}\n'
        elif source == 'notes':
            available_sources_str += f"- The journalist's own notes. Number of notes: {count}\n"
        elif source == 'transcribed interviews':
            available_sources_str += f'- Transcribed interviews (already done, you can\'t produce new ones). Number of interviews: {count}\n'
    # The analysis tool is always available regardless of gathered sources.
    available_sources_str += '- An analyzing tool that can analyze the information you gather.\n'

    # BUG FIX: typos in the prompt text are corrected ("writiong" -> "writing",
    # "an result" -> "and result", "Kepp" -> "Keep", "journalists" grammar) —
    # prompt quality directly affects model output.
    query = f"""
Thanks! Now, create a research plan for answering the original question: "{question.replace('"', "'")}".
Include the questions you just created and any additional steps needed to answer the original question.
Include what type of information you need from what available sources.

*Available sources are:*
{available_sources_str}

All of the above sources are available in a database/LLM model, but you need to specify what you need. Be as precise as possible.

You are working in a limited context and can't access the internet or external databases, and some "best practices" might not apply, like cross-referencing sources. Therefore, make the plan basic, easy to follow and with the available sources in mind.

*IMPORTANT! Each step should try to answer one or many of the questions you created, and result in a summary of the information you found.*

*Please structure the plan like:*
## Step 1:
- Task1: Description of task and outcome
- Task2: Description of task and outcome
## Step 2:
- Task1: Description of task and outcome
Etc, with as many steps and tasks as needed.
Do NOT include the writing of the report as a step, ONLY the tasks needed to gather information. The report will be written in a later step.

*Example of a plan:*
'''
Question: "What are the effects of climate change on agriculture?"
## Step 1: Read the notes
- Task1: Read the notes and pick out the most relevant information for the question.
- Task2: Summarize the information in a structured format. Try to formulate a hypothesis based on the notes and the question.
## Step 2: Read scientific articles
- Task1: Search for scientific articles to find information about the effects of climate change on agriculture. Use the information from the first step along with the question to formulate search queries.
- Task2: Read the articles and summarize the information in a structured format. Keep the focus on the information that is relevant for the question.
## Step 3: Analyze the information
- Task1: Use the analyzing tool to analyze the information you gathered in the previous steps. Try to find patterns and connections between the different sources.
- Task2: From the information you gathered, and in regard to the question, is there any information that contradicts each other? If so, try to find out why. Is it because of the sources, or is it because of the information itself?
## Step 4: Read other articles
- Task1: Search for other articles to find information about the effects of climate change on agriculture.
- Task2: Read the articles and summarize the information in a structured format. Pick out some interesting facts that are related to what you found in the scientific articles (if there are any).
'''

The example above is just an example, you can use other steps and tasks that are more relevant for the question.
"""

    return query
||||
@ -0,0 +1,369 @@ |
||||
import requests |
||||
import json |
||||
import argparse |
||||
from typing import Optional, List, Literal, Union |
||||
from colorprinter.print_color import * |
||||
|
||||
|
||||
def search_semantic_scholar(
    query: str,
    limit: int = 10,
    fields: Optional[List[str]] = None,
    # BUG FIX: the default was a mutable list literal shared across calls;
    # an immutable tuple keeps the same default filter without that hazard.
    publication_types: Optional[
        List[
            Literal[
                "Review",
                "JournalArticle",
                "CaseReport",
                "ClinicalTrial",
                "Conference",
                "Dataset",
                "Editorial",
                "LettersAndComments",
                "MetaAnalysis",
                "News",
                "Study",
                "Book",
                "BookSection",
            ]
        ]
    ] = ("JournalArticle",),
    open_access: bool = False,
    min_citation_count: Optional[int] = None,
    date_range: Optional[str] = None,
    year_range: Optional[str] = None,
    fields_of_study: Optional[
        List[
            Literal[
                "Computer Science",
                "Medicine",
                "Chemistry",
                "Biology",
                "Materials Science",
                "Physics",
                "Geology",
                "Psychology",
                "Art",
                "History",
                "Geography",
                "Sociology",
                "Business",
                "Political Science",
                "Economics",
                "Philosophy",
                "Mathematics",
                "Engineering",
                "Environmental Science",
                "Agricultural and Food Sciences",
                "Education",
                "Law",
                "Linguistics",
            ]
        ]
    ] = None,
):
    """
    Search for papers on Semantic Scholar with various filters.

    Parameters:
    -----------
    query : str
        The search query term
    limit : int
        Number of results to return (max 100)
    fields : List[str], optional
        List of fields to include in the response
    publication_types : List[str], optional
        Filter by publication types (defaults to journal articles)
    open_access : bool
        Only include papers with open access PDFs
    min_citation_count : int, optional
        Minimum number of citations
    date_range : str, optional
        Date range in format "YYYY-MM-DD:YYYY-MM-DD"
    year_range : str, optional
        Year range in format "YYYY-YYYY" or "YYYY-" or "-YYYY"
    fields_of_study : List[str], optional
        List of fields of study to filter by

    Returns:
    --------
    list or None
        The "data" list from the JSON response, or None on error
    """
    # Define the API endpoint URL
    url = "https://api.semanticscholar.org/graph/v1/paper/search"

    # Set up default fields if not provided
    if fields is None:
        fields = [
            "title",
            "url",
            "abstract",
            "year",
            "publicationDate",
            "authors.name",
            "citationCount",
            "openAccessPdf",
            "tldr",
        ]

    # Build query parameters
    params = {"query": query, "limit": limit, "fields": ",".join(fields)}

    # Add optional filters if provided
    if publication_types:
        params["publicationTypes"] = ",".join(publication_types)

    if open_access:
        params["openAccessPdf"] = ""

    # BUG FIX: the original truthiness test silently dropped an explicit 0;
    # an explicit None check keeps 0 meaningful (no minimum).
    if min_citation_count is not None:
        params["minCitationCount"] = str(min_citation_count)

    if date_range:
        params["publicationDateOrYear"] = date_range

    if year_range:
        params["year"] = year_range

    if fields_of_study:
        params["fieldsOfStudy"] = ",".join(fields_of_study)

    # Send the API request (bounded so a stalled connection cannot hang forever)
    try:
        response = requests.get(url, params=params, timeout=30)
        response.raise_for_status()  # Raise an exception for HTTP errors
        return response.json().get("data", [])
    except requests.exceptions.HTTPError as e:
        print(f"HTTP Error: {e}")
        print(f"Response text: {response.text}")
        return None
    except requests.exceptions.RequestException as e:
        print(f"Error: {e}")
        return None
||||
|
||||
|
||||
def main(
    query: Optional[str] = None,
    limit: int = 10,
    fields: Optional[List[str]] = None,
    publication_types: Optional[
        List[
            Literal[
                "Review",
                "JournalArticle",
                "CaseReport",
                "ClinicalTrial",
                "Conference",
                "Dataset",
                "Editorial",
                "LettersAndComments",
                "MetaAnalysis",
                "News",
                "Study",
                "Book",
                "BookSection",
            ]
        ]
    ] = None,
    open_access: bool = False,
    min_citation_count: Optional[int] = None,
    date_range: Optional[str] = None,
    year_range: Optional[str] = None,
    fields_of_study: Optional[
        List[
            Literal[
                "Computer Science",
                "Medicine",
                "Chemistry",
                "Biology",
                "Materials Science",
                "Physics",
                "Geology",
                "Psychology",
                "Art",
                "History",
                "Geography",
                "Sociology",
                "Business",
                "Political Science",
                "Economics",
                "Philosophy",
                "Mathematics",
                "Engineering",
                "Environmental Science",
                "Agricultural and Food Sciences",
                "Education",
                "Law",
                "Linguistics",
            ]
        ]
    ] = None,
):
    """CLI-style driver: run a Semantic Scholar search and print every result.

    All parameters are forwarded to ``search_semantic_scholar`` unchanged.
    Prints a notice and returns early when nothing is found.
    """
    # Search for papers
    papers = search_semantic_scholar(
        query=query,
        limit=limit,
        fields=fields,
        publication_types=publication_types,
        open_access=open_access,
        min_citation_count=min_citation_count,
        date_range=date_range,
        year_range=year_range,
        fields_of_study=fields_of_study,
    )

    if not papers:
        print("No results found or an error occurred.")
        return

    # Print results
    print_green(f"\nFound {len(papers)} papers matching your query: '{query}'")

    # BUG FIX: the original called exit() inside this loop, terminating the
    # whole process after printing only the first paper. Print them all.
    for paper in papers:
        print(paper)
||||
|
||||
|
||||
def search_paper_by_title(
    title: str,
    fields: Optional[List[str]] = None
):
    """
    Search for the single paper that best matches *title*.

    Parameters:
    -----------
    title : str
        The title to search for
    fields : List[str], optional
        List of fields to include in the response

    Returns:
    --------
    dict or None
        JSON data for the best matching paper, or None if no match or error
    """
    # Title-match endpoint returns at most one best-scoring paper.
    endpoint = "https://api.semanticscholar.org/graph/v1/paper/search/match"

    # Fall back to a standard field set when the caller does not specify one.
    if fields is None:
        fields = [
            "title",
            "abstract",
            "year",
            "authors.name",
            "externalIds",
            "url",
            "publicationDate",
            "journal",
            "citationCount",
            "openAccessPdf",
        ]

    request_params = {"query": title, "fields": ",".join(fields)}

    try:
        response = requests.get(endpoint, params=request_params)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.HTTPError as http_err:
        # 404 means "no match" rather than a hard failure.
        if http_err.response.status_code == 404:
            print(f"No paper found matching title: {title}")
            return None
        print(f"HTTP Error: {http_err}")
        print(f"Response text: {http_err.response.text}")
        return None
    except requests.exceptions.RequestException as req_err:
        print(f"Error: {req_err}")
        return None
||||
|
||||
def get_paper_details(
    paper_id: str,
    fields: Optional[List[str]] = None
):
    """
    Get detailed information about a paper by its identifier.

    Parameters:
    -----------
    paper_id : str
        The paper identifier. Can be:
        - Semantic Scholar ID (e.g., 649def34f8be52c8b66281af98ae884c09aef38b)
        - DOI (e.g., DOI:10.18653/v1/N18-3011, or a bare 10.x DOI)
        - arXiv ID (e.g., ARXIV:2106.15928)
        - etc.
    fields : List[str], optional
        List of fields to include in the response

    Returns:
    --------
    dict or None
        JSON data for the paper, or None if not found or error
    """
    # Set up default fields if not provided
    if fields is None:
        fields = [
            "title",
            "abstract",
            "year",
            "authors.name",
            "externalIds",
            "url",
            "publicationDate",
            "journal",
            "citationCount",
            "openAccessPdf",
        ]

    # Add DOI: prefix if it's a DOI without the prefix
    if paper_id.startswith("10.") and "DOI:" not in paper_id:
        paper_id = f"DOI:{paper_id}"

    # BUG FIX: the original formatted the URL *before* adding the DOI: prefix,
    # so the prefix never reached the request. Build the URL afterwards.
    url = f"https://api.semanticscholar.org/graph/v1/paper/{paper_id}"

    # Build query parameters
    params = {"fields": ",".join(fields)}

    # Send the API request
    try:
        response = requests.get(url, params=params)
        response.raise_for_status()  # Raise an exception for HTTP errors
        return response.json()
    except requests.exceptions.HTTPError as e:
        if e.response.status_code == 404:
            print(f"No paper found with ID: {paper_id}")
            return None
        else:
            print(f"HTTP Error: {e}")
            print(f"Response text: {e.response.text}")
            return None
    except requests.exceptions.RequestException as e:
        print(f"Error: {e}")
        return None
||||
|
||||
# Demo invocation: fetch and print one open-access "machine learning" paper.
if __name__ == "__main__":
    main(
        query="machine learning",
        limit=1,
        fields=[
            "title",
            "url",
            "abstract",
            "tldr",
            "externalIds",
            "year",
            "influentialCitationCount",
            "fieldsOfStudy",
            "publicationDate",
            "journal",
        ],
        open_access=True,
    )
||||
@ -1,345 +0,0 @@ |
||||
import os |
||||
import urllib |
||||
import streamlit as st |
||||
from _base_class import StreamlitBaseClass |
||||
import feedparser |
||||
import requests |
||||
from bs4 import BeautifulSoup |
||||
from urllib.parse import urljoin |
||||
from utils import fix_key |
||||
from colorprinter.print_color import * |
||||
from datetime import datetime, timedelta |
||||
|
||||
|
||||
class RSSFeedsPage(StreamlitBaseClass): |
||||
def __init__(self, username: str): |
||||
super().__init__(username=username) |
||||
self.page_name = "RSS Feeds" |
||||
|
||||
# Initialize attributes from session state if available |
||||
for k, v in st.session_state.get(self.page_name, {}).items(): |
||||
setattr(self, k, v) |
||||
|
||||
    def run(self):
        """Render the page: show the selected feed, draw the sidebar actions,
        then persist page state back into the session."""
        # First visit: no feed selected yet.
        if "selected_feed" not in st.session_state:
            st.session_state["selected_feed"] = None
        self.update_current_page(self.page_name)
        self.display_feed()

        self.sidebar_actions()

        # Persist state to session_state
        self.update_session_state(page_name=self.page_name)
||||
|
||||
def select_rss_feeds(self): |
||||
# Fetch RSS feeds from the user's ArangoDB collection |
||||
rss_feeds = self.get_rss_feeds() |
||||
if rss_feeds: |
||||
feed_options = [feed["title"] for feed in rss_feeds] |
||||
with st.sidebar: |
||||
st.subheader("Show your feeds") |
||||
selected_feed_title = st.selectbox( |
||||
"Select a feed", options=feed_options, index=None |
||||
) |
||||
if selected_feed_title: |
||||
st.session_state["selected_feed"] = [ |
||||
feed["_key"] |
||||
for feed in rss_feeds |
||||
if feed["title"] == selected_feed_title |
||||
][0] |
||||
st.rerun() |
||||
|
||||
else: |
||||
st.write("You have no RSS feeds added.") |
||||
|
||||
def get_rss_feeds(self): |
||||
return list(self.user_arango.db.collection("rss_feeds").all()) |
||||
|
||||
def sidebar_actions(self): |
||||
with st.sidebar: |
||||
# Select a feed to show |
||||
self.select_rss_feeds() |
||||
st.subheader("Add a New RSS Feed") |
||||
rss_url = st.text_input("Website URL or RSS Feed URL") |
||||
if st.button("Discover Feeds"): |
||||
if rss_url: |
||||
with st.spinner("Discovering feeds..."): |
||||
feeds = self.discover_feeds(rss_url) |
||||
if feeds: |
||||
st.session_state["discovered_feeds"] = feeds |
||||
st.rerun() |
||||
else: |
||||
st.error("No RSS feeds found at the provided URL.") |
||||
if "discovered_feeds" in st.session_state: |
||||
st.subheader("Select a Feed to Add") |
||||
feeds = st.session_state["discovered_feeds"] |
||||
feed_options = [f"{feed['title']} ({feed['href']})" for feed in feeds] |
||||
selected_feed = st.selectbox("Available Feeds", options=feed_options) |
||||
selected_feed_url = feeds[feed_options.index(selected_feed)]["href"] |
||||
if st.button("Preview Feed"): |
||||
feed_data = feedparser.parse(selected_feed_url) |
||||
st.write(f"{feed_data.feed.get('title', 'No title')}") |
||||
description = html_to_markdown( |
||||
feed_data.feed.get("description", "No description") |
||||
) |
||||
st.write(f"_{description}_") |
||||
for entry in feed_data.entries[:5]: |
||||
print("ENTRY:") |
||||
with st.expander(entry.title): |
||||
summary = ( |
||||
entry.summary |
||||
if "summary" in entry |
||||
else "No summary available" |
||||
) |
||||
markdown_summary = html_to_markdown(summary) |
||||
st.markdown(markdown_summary) |
||||
if st.button( |
||||
"Add RSS Feed", |
||||
on_click=self.add_rss_feed, |
||||
args=(selected_feed_url, feed_data, description), |
||||
): |
||||
|
||||
del st.session_state["discovered_feeds"] |
||||
st.success("RSS Feed added.") |
||||
st.rerun() |
||||
|
||||
def discover_feeds(self, url): |
||||
try: |
||||
if not url.startswith("http"): |
||||
url = "https://" + url |
||||
|
||||
# Check if the input URL is already an RSS feed |
||||
f = feedparser.parse(url) |
||||
if len(f.entries) > 0: |
||||
return [ |
||||
{ |
||||
"href": url, |
||||
"title": f.feed.get("title", "No title"), |
||||
"icon": self.get_site_icon(url), |
||||
} |
||||
] |
||||
|
||||
# If not, proceed to discover feeds from the webpage |
||||
raw = requests.get(url).text |
||||
result = [] |
||||
possible_feeds = [] |
||||
html = BeautifulSoup(raw, "html.parser") |
||||
|
||||
# Find the site icon |
||||
icon_url = self.get_site_icon(url, html) |
||||
|
||||
# Find all <link> tags with rel="alternate" and type containing "rss" or "xml" |
||||
feed_urls = html.findAll("link", rel="alternate") |
||||
for f in feed_urls: |
||||
t = f.get("type", None) |
||||
if t and ("rss" in t or "xml" in t): |
||||
href = f.get("href", None) |
||||
if href: |
||||
possible_feeds.append(urljoin(url, href)) |
||||
|
||||
# Find all <a> tags with href containing "rss", "xml", or "feed" |
||||
parsed_url = urllib.parse.urlparse(url) |
||||
base = parsed_url.scheme + "://" + parsed_url.hostname |
||||
atags = html.findAll("a") |
||||
for a in atags: |
||||
href = a.get("href", None) |
||||
if href and ("rss" in href or "xml" in href or "feed" in href): |
||||
possible_feeds.append(urljoin(base, href)) |
||||
|
||||
# Validate the possible feeds using feedparser |
||||
for feed_url in list(set(possible_feeds)): |
||||
f = feedparser.parse(feed_url) |
||||
if len(f.entries) > 0: |
||||
result.append( |
||||
{ |
||||
"href": feed_url, |
||||
"title": f.feed.get("title", "No title"), |
||||
"icon": icon_url, |
||||
} |
||||
) |
||||
|
||||
return result |
||||
except Exception as e: |
||||
print(f"Error discovering feeds: {e}") |
||||
return [] |
||||
|
||||
|
||||
def add_rss_feed(self, url, feed_data, description): |
||||
try: |
||||
icon_url = feed_data["feed"]["image"]["href"] |
||||
except Exception as e: |
||||
icon_url = self.get_site_icon(url) |
||||
|
||||
title = feed_data["feed"].get("title", "No title") |
||||
print_blue(title) |
||||
icon_path = download_icon(icon_url) if icon_url else None |
||||
_key = fix_key(url) |
||||
now_timestamp = datetime.now().isoformat() # Convert datetime to ISO format string |
||||
|
||||
self.user_arango.db.collection("rss_feeds").insert( |
||||
{ |
||||
"_key": _key, |
||||
"url": url, |
||||
"title": title, |
||||
"icon_path": icon_path, |
||||
"description": description, |
||||
'fetched_timestamp': now_timestamp, # Add the timestamp field |
||||
'feed_data': feed_data, |
||||
}, |
||||
overwrite=True, |
||||
) |
||||
|
||||
feed = self.get_feed_from_arango(_key) |
||||
now_timestamp = datetime.now().isoformat() # Convert datetime to ISO format string |
||||
if feed: |
||||
self.update_feed(_key, feed) |
||||
else: |
||||
self.base_arango.db.collection("rss_feeds").insert( |
||||
{ |
||||
"_key": _key, |
||||
"url": url, |
||||
"title": title, |
||||
"icon_path": icon_path, |
||||
"description": description, |
||||
'fetched_timestamp': now_timestamp, # Add the timestamp field |
||||
"feed_data": feed_data, |
||||
}, |
||||
overwrite=True, |
||||
overwrite_mode="update", |
||||
) |
||||
def update_feed(self, feed_key, feed=None): |
||||
""" |
||||
Updates RSS feed that already exists in the ArangoDB base database. |
||||
|
||||
Args: |
||||
feed_key (str): The key identifying the feed in the database. |
||||
|
||||
Returns: |
||||
dict: The parsed feed data. |
||||
|
||||
Raises: |
||||
Exception: If there is an error updating the feed in the database. |
||||
""" |
||||
if not feed: |
||||
feed = self.get_feed_from_arango(feed_key) |
||||
|
||||
feed_data = feedparser.parse(feed["url"]) |
||||
print_rainbow(feed_data['feed']) |
||||
feed["feed_data"] = feed_data |
||||
if self.username not in feed.get("users", []): |
||||
feed["users"] = feed.get("users", []) + [self.username] |
||||
fetched_timestamp = datetime.now().isoformat() # Convert datetime to ISO format string |
||||
|
||||
# Update the fetched_timestamp in the database |
||||
self.base_arango.db.collection("rss_feeds").update( |
||||
{ |
||||
"_key": feed["_key"], |
||||
"fetched_timestamp": fetched_timestamp, |
||||
"feed_data": feed_data, |
||||
} |
||||
) |
||||
return feed_data |
||||
|
||||
|
||||
def update_session_state(self, page_name=None): |
||||
# Update session state |
||||
if page_name: |
||||
st.session_state[page_name] = self.__dict__ |
||||
|
||||
def get_site_icon(self, url, html=None): |
||||
try: |
||||
if not html: |
||||
raw = requests.get(url).text |
||||
html = BeautifulSoup(raw, "html.parser") |
||||
|
||||
icon_link = html.find("link", rel="icon") |
||||
if icon_link: |
||||
icon_url = icon_link.get("href", None) |
||||
if icon_url: |
||||
return urljoin(url, icon_url) |
||||
|
||||
# Fallback to finding other common icon links |
||||
icon_link = html.find("link", rel="shortcut icon") |
||||
if icon_link: |
||||
icon_url = icon_link.get("href", None) |
||||
if icon_url: |
||||
return urljoin(url, icon_url) |
||||
|
||||
return None |
||||
except Exception as e: |
||||
print(f"Error getting site icon: {e}") |
||||
return None |
||||
|
||||
def get_feed_from_arango(self, feed_key): |
||||
""" |
||||
Retrieve an RSS feed from the ArangoDB base databse. |
||||
|
||||
Args: |
||||
feed_key (str): The key of the RSS feed to retrieve from the ArangoDB base database. |
||||
|
||||
Returns: |
||||
dict: The RSS feed document retrieved from the ArangoDB base database. |
||||
""" |
||||
return self.base_arango.db.collection("rss_feeds").get(feed_key) |
||||
|
||||
|
||||
def get_feed(self, feed_key): |
||||
feed = self.get_feed_from_arango(feed_key) |
||||
feed_data = feed["feed_data"] |
||||
fetched_time = datetime.fromisoformat(feed['fetched_timestamp']) # Parse the timestamp string |
||||
|
||||
if datetime.now() - fetched_time < timedelta(hours=1): |
||||
return feed_data |
||||
else: |
||||
return self.update_feed(feed_key) |
||||
|
||||
|
||||
def display_feed(self): |
||||
if st.session_state["selected_feed"]: |
||||
feed_data = self.get_feed(st.session_state["selected_feed"]) |
||||
|
||||
st.title(feed_data['feed'].get("title", "No title")) |
||||
st.write(feed_data['feed'].get("description", "No description")) |
||||
st.write("**Recent Entries:**") |
||||
for entry in feed_data['entries'][:5]: |
||||
with st.expander(entry['title']): |
||||
summary = ( |
||||
entry['summary'] if "summary" in entry else "No summary available" |
||||
) |
||||
markdown_summary = html_to_markdown(summary) |
||||
st.markdown(markdown_summary) |
||||
st.markdown(f"[Read more]({entry['link']})") |
||||
|
||||
|
||||
def html_to_markdown(html):
    """Convert an HTML snippet into plain-text Markdown.

    ``<br>`` becomes a newline, ``<strong>``/``<em>`` become
    ``**bold**``/``*italic*`` markers, and each ``<p>`` is flattened to
    its text followed by a blank line.
    """
    soup = BeautifulSoup(html, "html.parser")
    # (tag name, renderer) pairs. Order matters: inline tags are
    # rewritten before <p>, so formatting markers inside a paragraph
    # survive the paragraph's own flattening.
    renderers = [
        ("br", lambda tag: "\n"),
        ("strong", lambda tag: f"**{tag.text}**"),
        ("em", lambda tag: f"*{tag.text}*"),
        ("p", lambda tag: f"{tag.text}\n\n"),
    ]
    for tag_name, render in renderers:
        for tag in soup.find_all(tag_name):
            tag.replace_with(render(tag))
    return soup.get_text()
||||
|
||||
|
||||
def download_icon(icon_url, save_folder="external_icons"):
    """Download a site icon into *save_folder* and return its local path.

    Args:
        icon_url: Absolute URL of the icon image.
        save_folder: Directory the icon is written into; created if missing.

    Returns:
        str | None: Path of the saved file, or None on any failure
        (non-200 response, network/timeout error). Errors are printed,
        not raised, matching the rest of this module.
    """
    try:
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(save_folder, exist_ok=True)

        # timeout keeps a dead host from hanging the caller forever.
        response = requests.get(icon_url, stream=True, timeout=10)
        if response.status_code == 200:
            # Strip query string / fragment so "favicon.ico?v=2" yields
            # a clean filename rather than one containing "?v=2".
            clean_url = icon_url.split("?")[0].split("#")[0]
            icon_name = os.path.basename(clean_url)
            if not icon_name:
                # URL ended with "/" — fall back to a generic name.
                icon_name = "icon"
            icon_path = os.path.join(save_folder, icon_name)
            with open(icon_path, "wb") as f:
                for chunk in response.iter_content(1024):
                    f.write(chunk)
            return icon_path
        else:
            print(f"Failed to download icon: {response.status_code}")
            return None
    except Exception as e:
        print(f"Error downloading icon: {e}")
        return None
||||
Loading…
Reference in new issue