Refactor file paths and load prompts from prompts.yaml
This commit is contained in:
parent
cc259b08aa
commit
b22088cd3d
@ -24,15 +24,15 @@ except LookupError:
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# Construct the absolute path to the prompts.yaml file
|
||||
prompts_path = os.path.join(script_dir, 'prompts.yaml')
|
||||
prompts_path = os.path.join(script_dir, "prompts.yaml")
|
||||
|
||||
# Load prompts from configuration file
|
||||
with open(prompts_path, 'r') as file:
|
||||
with open(prompts_path, "r") as file:
|
||||
prompts = yaml.safe_load(file)
|
||||
|
||||
CUSTOM_SYSTEM_PROMPT = prompts['CUSTOM_SYSTEM_PROMPT']
|
||||
GET_SENTENCES_PROMPT = prompts['GET_SENTENCES_PROMPT']
|
||||
EXPLANATION_PROMPT = prompts['EXPLANATION_PROMPT']
|
||||
CUSTOM_SYSTEM_PROMPT = prompts["CUSTOM_SYSTEM_PROMPT"]
|
||||
GET_SENTENCES_PROMPT = prompts["GET_SENTENCES_PROMPT"]
|
||||
EXPLANATION_PROMPT = prompts["EXPLANATION_PROMPT"]
|
||||
|
||||
|
||||
class LLM:
|
||||
@ -85,7 +85,7 @@ class LLM:
|
||||
memory (bool, optional): Whether to use memory. Defaults to True.
|
||||
keep_alive (int, optional): Keep-alive duration in seconds. Defaults to 3600.
|
||||
"""
|
||||
dotenv.load_dotenv()
|
||||
|
||||
if model:
|
||||
self.model = model
|
||||
else:
|
||||
@ -99,9 +99,12 @@ class LLM:
|
||||
else:
|
||||
self.messages = [{"role": "system", "content": CUSTOM_SYSTEM_PROMPT}]
|
||||
|
||||
if openai_key: # For use with OpenAI
|
||||
# Check if OpenAI key is provided
|
||||
if openai_key: # Use OpenAI
|
||||
self.use_openai(openai_key, model)
|
||||
else: # For use with Ollama
|
||||
elif os.getenv("OPENAI_API_KEY") != '': # Use OpenAI
|
||||
self.use_openai(os.getenv("OPENAI_API_KEY"), model)
|
||||
else: # Use Ollama
|
||||
self.use_ollama(model)
|
||||
|
||||
def use_openai(self, key, model):
|
||||
@ -128,7 +131,7 @@ class LLM:
|
||||
if model:
|
||||
self.model = model
|
||||
else:
|
||||
self.model = os.getenv("OPENAI_MODEL")
|
||||
self.model = os.getenv("LLM_MODEL")
|
||||
|
||||
def use_ollama(self, model):
|
||||
"""
|
||||
@ -233,6 +236,11 @@ class Highlighter:
|
||||
llm_memory (bool): Flag to enable or disable memory for the language model.
|
||||
llm_keep_alive (int): The keep-alive duration for the language model in seconds.
|
||||
"""
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# Ensure both model are provided or set in the environment
|
||||
assert llm_model or os.getenv("LLM_MODEL"), "LLM_MODEL must be provided as argument or set in the environment."
|
||||
|
||||
self.silent = silent
|
||||
self.comment = comment
|
||||
self.llm_params = {
|
||||
@ -245,6 +253,7 @@ class Highlighter:
|
||||
"keep_alive": llm_keep_alive,
|
||||
}
|
||||
|
||||
|
||||
async def highlight(
|
||||
self,
|
||||
user_input,
|
||||
@ -270,12 +279,15 @@ class Highlighter:
|
||||
), "You need to provide either a PDF filename, a list of filenames or data in JSON format."
|
||||
|
||||
if data:
|
||||
docs = [item['pdf_filename'] for item in data]
|
||||
docs = [item["pdf_filename"] for item in data]
|
||||
|
||||
if not docs:
|
||||
docs = [pdf_filename]
|
||||
|
||||
tasks = [self.annotate_pdf(user_input, doc, pages=item.get('pages')) for doc, item in zip(docs, data or [{}]*len(docs))]
|
||||
tasks = [
|
||||
self.annotate_pdf(user_input, doc, pages=item.get("pages"))
|
||||
for doc, item in zip(docs, data or [{}] * len(docs))
|
||||
]
|
||||
pdf_buffers = await asyncio.gather(*tasks)
|
||||
|
||||
combined_pdf = pymupdf.open()
|
||||
@ -404,17 +416,33 @@ if __name__ == "__main__":
|
||||
import json
|
||||
|
||||
# Set up argument parser for command-line interface
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--user_input", type=str, help="The user input")
|
||||
parser.add_argument("--pdf_filename", type=str, help="The PDF filename")
|
||||
parser.add_argument("--silent", action="store_true", help="No user warnings")
|
||||
parser.add_argument("--openai_key", type=str, help="OpenAI API key")
|
||||
parser.add_argument("--comment", action="store_true", help="Include comments")
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Highlight sentences in PDF documents using an LLM.\n\n"
|
||||
"For more information, visit: https://github.com/lasseedfast/pdf-highlighter/blob/main/README.md"
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--user_input",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The text input from the user to highlight in the PDFs.",
|
||||
)
|
||||
parser.add_argument("--pdf_filename", type=str, help="The PDF filename to process.")
|
||||
parser.add_argument("--silent", action="store_true", help="Suppress warnings.")
|
||||
parser.add_argument("--openai_key", type=str, help="API key for OpenAI.")
|
||||
parser.add_argument("--llm_model", type=str, help="The model name for the language model.")
|
||||
parser.add_argument(
|
||||
"--comment",
|
||||
action="store_true",
|
||||
help="Include comments in the highlighted PDF.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data",
|
||||
type=json.loads,
|
||||
help="The data in JSON format (fields: user_input, pdf_filename, list_of_pages)",
|
||||
help="Data in JSON format (fields: user_input, pdf_filename, list_of_pages).",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize the Highlighter class with the provided arguments
|
||||
@ -422,6 +450,7 @@ if __name__ == "__main__":
|
||||
silent=args.silent,
|
||||
openai_key=args.openai_key,
|
||||
comment=args.comment,
|
||||
llm_model=args.llm_model,
|
||||
)
|
||||
|
||||
# Define the main asynchronous function to highlight the PDF
|
||||
@ -432,9 +461,12 @@ if __name__ == "__main__":
|
||||
data=args.data,
|
||||
)
|
||||
# Save the highlighted PDF to a new file
|
||||
filename = args.pdf_filename.replace(".pdf", "_highlighted.pdf")
|
||||
await save_pdf_to_file(
|
||||
highlighted_pdf, args.pdf_filename.replace(".pdf", "_highlighted.pdf")
|
||||
highlighted_pdf, filename
|
||||
)
|
||||
# Print the clickable file path
|
||||
print(f'''Highlighted PDF saved to "file://{filename.replace(' ', '%20')}"''')
|
||||
|
||||
# Run the main function using asyncio
|
||||
asyncio.run(main())
|
||||
asyncio.run(main())
|
||||
|
Loading…
x
Reference in New Issue
Block a user