Refactor file paths and load prompts from prompts.yaml

This commit is contained in:
lasseedfast 2024-10-08 09:45:46 +02:00
parent cc259b08aa
commit b22088cd3d

View File

@ -24,15 +24,15 @@ except LookupError:
script_dir = os.path.dirname(os.path.abspath(__file__))
# Construct the absolute path to the prompts.yaml file
prompts_path = os.path.join(script_dir, 'prompts.yaml')
prompts_path = os.path.join(script_dir, "prompts.yaml")
# Load prompts from configuration file
with open(prompts_path, 'r') as file:
with open(prompts_path, "r") as file:
prompts = yaml.safe_load(file)
CUSTOM_SYSTEM_PROMPT = prompts['CUSTOM_SYSTEM_PROMPT']
GET_SENTENCES_PROMPT = prompts['GET_SENTENCES_PROMPT']
EXPLANATION_PROMPT = prompts['EXPLANATION_PROMPT']
CUSTOM_SYSTEM_PROMPT = prompts["CUSTOM_SYSTEM_PROMPT"]
GET_SENTENCES_PROMPT = prompts["GET_SENTENCES_PROMPT"]
EXPLANATION_PROMPT = prompts["EXPLANATION_PROMPT"]
class LLM:
@ -85,7 +85,7 @@ class LLM:
memory (bool, optional): Whether to use memory. Defaults to True.
keep_alive (int, optional): Keep-alive duration in seconds. Defaults to 3600.
"""
dotenv.load_dotenv()
if model:
self.model = model
else:
@ -99,9 +99,12 @@ class LLM:
else:
self.messages = [{"role": "system", "content": CUSTOM_SYSTEM_PROMPT}]
if openai_key: # For use with OpenAI
# Check if OpenAI key is provided
if openai_key: # Use OpenAI
self.use_openai(openai_key, model)
else: # For use with Ollama
elif os.getenv("OPENAI_API_KEY") != '': # Use OpenAI
self.use_openai(os.getenv("OPENAI_API_KEY"), model)
else: # Use Ollama
self.use_ollama(model)
def use_openai(self, key, model):
@ -128,7 +131,7 @@ class LLM:
if model:
self.model = model
else:
self.model = os.getenv("OPENAI_MODEL")
self.model = os.getenv("LLM_MODEL")
def use_ollama(self, model):
"""
@ -233,6 +236,11 @@ class Highlighter:
llm_memory (bool): Flag to enable or disable memory for the language model.
llm_keep_alive (int): The keep-alive duration for the language model in seconds.
"""
dotenv.load_dotenv()
# Ensure a model is provided as an argument or set in the environment
assert llm_model or os.getenv("LLM_MODEL"), "LLM_MODEL must be provided as argument or set in the environment."
self.silent = silent
self.comment = comment
self.llm_params = {
@ -245,6 +253,7 @@ class Highlighter:
"keep_alive": llm_keep_alive,
}
async def highlight(
self,
user_input,
@ -270,12 +279,15 @@ class Highlighter:
), "You need to provide either a PDF filename, a list of filenames or data in JSON format."
if data:
docs = [item['pdf_filename'] for item in data]
docs = [item["pdf_filename"] for item in data]
if not docs:
docs = [pdf_filename]
tasks = [self.annotate_pdf(user_input, doc, pages=item.get('pages')) for doc, item in zip(docs, data or [{}]*len(docs))]
tasks = [
self.annotate_pdf(user_input, doc, pages=item.get("pages"))
for doc, item in zip(docs, data or [{}] * len(docs))
]
pdf_buffers = await asyncio.gather(*tasks)
combined_pdf = pymupdf.open()
@ -404,17 +416,33 @@ if __name__ == "__main__":
import json
# Set up argument parser for command-line interface
parser = argparse.ArgumentParser()
parser.add_argument("--user_input", type=str, help="The user input")
parser.add_argument("--pdf_filename", type=str, help="The PDF filename")
parser.add_argument("--silent", action="store_true", help="No user warnings")
parser.add_argument("--openai_key", type=str, help="OpenAI API key")
parser.add_argument("--comment", action="store_true", help="Include comments")
parser = argparse.ArgumentParser(
description=(
"Highlight sentences in PDF documents using an LLM.\n\n"
"For more information, visit: https://github.com/lasseedfast/pdf-highlighter/blob/main/README.md"
)
)
parser.add_argument(
"--user_input",
type=str,
required=True,
help="The text input from the user to highlight in the PDFs.",
)
parser.add_argument("--pdf_filename", type=str, help="The PDF filename to process.")
parser.add_argument("--silent", action="store_true", help="Suppress warnings.")
parser.add_argument("--openai_key", type=str, help="API key for OpenAI.")
parser.add_argument("--llm_model", type=str, help="The model name for the language model.")
parser.add_argument(
"--comment",
action="store_true",
help="Include comments in the highlighted PDF.",
)
parser.add_argument(
"--data",
type=json.loads,
help="The data in JSON format (fields: user_input, pdf_filename, list_of_pages)",
help="Data in JSON format (fields: user_input, pdf_filename, list_of_pages).",
)
args = parser.parse_args()
# Initialize the Highlighter class with the provided arguments
@ -422,6 +450,7 @@ if __name__ == "__main__":
silent=args.silent,
openai_key=args.openai_key,
comment=args.comment,
llm_model=args.llm_model,
)
# Define the main asynchronous function to highlight the PDF
@ -432,9 +461,12 @@ if __name__ == "__main__":
data=args.data,
)
# Save the highlighted PDF to a new file
filename = args.pdf_filename.replace(".pdf", "_highlighted.pdf")
await save_pdf_to_file(
highlighted_pdf, args.pdf_filename.replace(".pdf", "_highlighted.pdf")
highlighted_pdf, filename
)
# Print the clickable file path
print(f'''Highlighted PDF saved to "file://{filename.replace(' ', '%20')}"''')
# Run the main function using asyncio
asyncio.run(main())
asyncio.run(main())