Refactor file paths and load prompts from prompts.yaml

This commit is contained in:
lasseedfast 2024-10-08 09:45:46 +02:00
parent cc259b08aa
commit b22088cd3d

View File

@ -24,15 +24,15 @@ except LookupError:
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
# Construct the absolute path to the prompts.yaml file # Construct the absolute path to the prompts.yaml file
prompts_path = os.path.join(script_dir, 'prompts.yaml') prompts_path = os.path.join(script_dir, "prompts.yaml")
# Load prompts from configuration file # Load prompts from configuration file
with open(prompts_path, 'r') as file: with open(prompts_path, "r") as file:
prompts = yaml.safe_load(file) prompts = yaml.safe_load(file)
CUSTOM_SYSTEM_PROMPT = prompts['CUSTOM_SYSTEM_PROMPT'] CUSTOM_SYSTEM_PROMPT = prompts["CUSTOM_SYSTEM_PROMPT"]
GET_SENTENCES_PROMPT = prompts['GET_SENTENCES_PROMPT'] GET_SENTENCES_PROMPT = prompts["GET_SENTENCES_PROMPT"]
EXPLANATION_PROMPT = prompts['EXPLANATION_PROMPT'] EXPLANATION_PROMPT = prompts["EXPLANATION_PROMPT"]
class LLM: class LLM:
@ -85,7 +85,7 @@ class LLM:
memory (bool, optional): Whether to use memory. Defaults to True. memory (bool, optional): Whether to use memory. Defaults to True.
keep_alive (int, optional): Keep-alive duration in seconds. Defaults to 3600. keep_alive (int, optional): Keep-alive duration in seconds. Defaults to 3600.
""" """
dotenv.load_dotenv()
if model: if model:
self.model = model self.model = model
else: else:
@ -99,9 +99,12 @@ class LLM:
else: else:
self.messages = [{"role": "system", "content": CUSTOM_SYSTEM_PROMPT}] self.messages = [{"role": "system", "content": CUSTOM_SYSTEM_PROMPT}]
if openai_key: # For use with OpenAI # Check if OpenAI key is provided
if openai_key: # Use OpenAI
self.use_openai(openai_key, model) self.use_openai(openai_key, model)
else: # For use with Ollama elif os.getenv("OPENAI_API_KEY") != '': # Use OpenAI
self.use_openai(os.getenv("OPENAI_API_KEY"), model)
else: # Use Ollama
self.use_ollama(model) self.use_ollama(model)
def use_openai(self, key, model): def use_openai(self, key, model):
@ -128,7 +131,7 @@ class LLM:
if model: if model:
self.model = model self.model = model
else: else:
self.model = os.getenv("OPENAI_MODEL") self.model = os.getenv("LLM_MODEL")
def use_ollama(self, model): def use_ollama(self, model):
""" """
@ -233,6 +236,11 @@ class Highlighter:
llm_memory (bool): Flag to enable or disable memory for the language model. llm_memory (bool): Flag to enable or disable memory for the language model.
llm_keep_alive (int): The keep-alive duration for the language model in seconds. llm_keep_alive (int): The keep-alive duration for the language model in seconds.
""" """
dotenv.load_dotenv()
# Ensure the model is provided as an argument or set in the environment
assert llm_model or os.getenv("LLM_MODEL"), "LLM_MODEL must be provided as argument or set in the environment."
self.silent = silent self.silent = silent
self.comment = comment self.comment = comment
self.llm_params = { self.llm_params = {
@ -245,6 +253,7 @@ class Highlighter:
"keep_alive": llm_keep_alive, "keep_alive": llm_keep_alive,
} }
async def highlight( async def highlight(
self, self,
user_input, user_input,
@ -270,12 +279,15 @@ class Highlighter:
), "You need to provide either a PDF filename, a list of filenames or data in JSON format." ), "You need to provide either a PDF filename, a list of filenames or data in JSON format."
if data: if data:
docs = [item['pdf_filename'] for item in data] docs = [item["pdf_filename"] for item in data]
if not docs: if not docs:
docs = [pdf_filename] docs = [pdf_filename]
tasks = [self.annotate_pdf(user_input, doc, pages=item.get('pages')) for doc, item in zip(docs, data or [{}]*len(docs))] tasks = [
self.annotate_pdf(user_input, doc, pages=item.get("pages"))
for doc, item in zip(docs, data or [{}] * len(docs))
]
pdf_buffers = await asyncio.gather(*tasks) pdf_buffers = await asyncio.gather(*tasks)
combined_pdf = pymupdf.open() combined_pdf = pymupdf.open()
@ -404,17 +416,33 @@ if __name__ == "__main__":
import json import json
# Set up argument parser for command-line interface # Set up argument parser for command-line interface
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser(
parser.add_argument("--user_input", type=str, help="The user input") description=(
parser.add_argument("--pdf_filename", type=str, help="The PDF filename") "Highlight sentences in PDF documents using an LLM.\n\n"
parser.add_argument("--silent", action="store_true", help="No user warnings") "For more information, visit: https://github.com/lasseedfast/pdf-highlighter/blob/main/README.md"
parser.add_argument("--openai_key", type=str, help="OpenAI API key") )
parser.add_argument("--comment", action="store_true", help="Include comments") )
parser.add_argument(
"--user_input",
type=str,
required=True,
help="The text input from the user to highlight in the PDFs.",
)
parser.add_argument("--pdf_filename", type=str, help="The PDF filename to process.")
parser.add_argument("--silent", action="store_true", help="Suppress warnings.")
parser.add_argument("--openai_key", type=str, help="API key for OpenAI.")
parser.add_argument("--llm_model", type=str, help="The model name for the language model.")
parser.add_argument(
"--comment",
action="store_true",
help="Include comments in the highlighted PDF.",
)
parser.add_argument( parser.add_argument(
"--data", "--data",
type=json.loads, type=json.loads,
help="The data in JSON format (fields: user_input, pdf_filename, list_of_pages)", help="Data in JSON format (fields: user_input, pdf_filename, list_of_pages).",
) )
args = parser.parse_args() args = parser.parse_args()
# Initialize the Highlighter class with the provided arguments # Initialize the Highlighter class with the provided arguments
@ -422,6 +450,7 @@ if __name__ == "__main__":
silent=args.silent, silent=args.silent,
openai_key=args.openai_key, openai_key=args.openai_key,
comment=args.comment, comment=args.comment,
llm_model=args.llm_model,
) )
# Define the main asynchronous function to highlight the PDF # Define the main asynchronous function to highlight the PDF
@ -432,9 +461,12 @@ if __name__ == "__main__":
data=args.data, data=args.data,
) )
# Save the highlighted PDF to a new file # Save the highlighted PDF to a new file
filename = args.pdf_filename.replace(".pdf", "_highlighted.pdf")
await save_pdf_to_file( await save_pdf_to_file(
highlighted_pdf, args.pdf_filename.replace(".pdf", "_highlighted.pdf") highlighted_pdf, filename
) )
# Print the clickable file path
print(f'''Highlighted PDF saved to "file://{filename.replace(' ', '%20')}"''')
# Run the main function using asyncio # Run the main function using asyncio
asyncio.run(main()) asyncio.run(main())