Refactor file paths and load prompts from prompts.yaml
This commit is contained in:
parent
cc259b08aa
commit
b22088cd3d
@ -24,15 +24,15 @@ except LookupError:
|
|||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
# Construct the absolute path to the prompts.yaml file
|
# Construct the absolute path to the prompts.yaml file
|
||||||
prompts_path = os.path.join(script_dir, 'prompts.yaml')
|
prompts_path = os.path.join(script_dir, "prompts.yaml")
|
||||||
|
|
||||||
# Load prompts from configuration file
|
# Load prompts from configuration file
|
||||||
with open(prompts_path, 'r') as file:
|
with open(prompts_path, "r") as file:
|
||||||
prompts = yaml.safe_load(file)
|
prompts = yaml.safe_load(file)
|
||||||
|
|
||||||
CUSTOM_SYSTEM_PROMPT = prompts['CUSTOM_SYSTEM_PROMPT']
|
CUSTOM_SYSTEM_PROMPT = prompts["CUSTOM_SYSTEM_PROMPT"]
|
||||||
GET_SENTENCES_PROMPT = prompts['GET_SENTENCES_PROMPT']
|
GET_SENTENCES_PROMPT = prompts["GET_SENTENCES_PROMPT"]
|
||||||
EXPLANATION_PROMPT = prompts['EXPLANATION_PROMPT']
|
EXPLANATION_PROMPT = prompts["EXPLANATION_PROMPT"]
|
||||||
|
|
||||||
|
|
||||||
class LLM:
|
class LLM:
|
||||||
@ -85,7 +85,7 @@ class LLM:
|
|||||||
memory (bool, optional): Whether to use memory. Defaults to True.
|
memory (bool, optional): Whether to use memory. Defaults to True.
|
||||||
keep_alive (int, optional): Keep-alive duration in seconds. Defaults to 3600.
|
keep_alive (int, optional): Keep-alive duration in seconds. Defaults to 3600.
|
||||||
"""
|
"""
|
||||||
dotenv.load_dotenv()
|
|
||||||
if model:
|
if model:
|
||||||
self.model = model
|
self.model = model
|
||||||
else:
|
else:
|
||||||
@ -99,9 +99,12 @@ class LLM:
|
|||||||
else:
|
else:
|
||||||
self.messages = [{"role": "system", "content": CUSTOM_SYSTEM_PROMPT}]
|
self.messages = [{"role": "system", "content": CUSTOM_SYSTEM_PROMPT}]
|
||||||
|
|
||||||
if openai_key: # For use with OpenAI
|
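# Select the backend: an explicit openai_key first, then OPENAI_API_KEY from the environment, otherwise Ollama. NOTE(review): os.getenv() returns None when the variable is unset and None != '' is True, so the Ollama branch runs only if OPENAI_API_KEY is set to an empty string — confirm this is intended.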
|
||||||
|
if openai_key: # Use OpenAI
|
||||||
self.use_openai(openai_key, model)
|
self.use_openai(openai_key, model)
|
||||||
else: # For use with Ollama
|
elif os.getenv("OPENAI_API_KEY") != '': # Use OpenAI
|
||||||
|
self.use_openai(os.getenv("OPENAI_API_KEY"), model)
|
||||||
|
else: # Use Ollama
|
||||||
self.use_ollama(model)
|
self.use_ollama(model)
|
||||||
|
|
||||||
def use_openai(self, key, model):
|
def use_openai(self, key, model):
|
||||||
@ -128,7 +131,7 @@ class LLM:
|
|||||||
if model:
|
if model:
|
||||||
self.model = model
|
self.model = model
|
||||||
else:
|
else:
|
||||||
self.model = os.getenv("OPENAI_MODEL")
|
self.model = os.getenv("LLM_MODEL")
|
||||||
|
|
||||||
def use_ollama(self, model):
|
def use_ollama(self, model):
|
||||||
"""
|
"""
|
||||||
@ -233,6 +236,11 @@ class Highlighter:
|
|||||||
llm_memory (bool): Flag to enable or disable memory for the language model.
|
llm_memory (bool): Flag to enable or disable memory for the language model.
|
||||||
llm_keep_alive (int): The keep-alive duration for the language model in seconds.
|
llm_keep_alive (int): The keep-alive duration for the language model in seconds.
|
||||||
"""
|
"""
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
# Ensure the LLM model is either passed as an argument or set in the environment
|
||||||
|
assert llm_model or os.getenv("LLM_MODEL"), "LLM_MODEL must be provided as argument or set in the environment."
|
||||||
|
|
||||||
self.silent = silent
|
self.silent = silent
|
||||||
self.comment = comment
|
self.comment = comment
|
||||||
self.llm_params = {
|
self.llm_params = {
|
||||||
@ -245,6 +253,7 @@ class Highlighter:
|
|||||||
"keep_alive": llm_keep_alive,
|
"keep_alive": llm_keep_alive,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def highlight(
|
async def highlight(
|
||||||
self,
|
self,
|
||||||
user_input,
|
user_input,
|
||||||
@ -270,12 +279,15 @@ class Highlighter:
|
|||||||
), "You need to provide either a PDF filename, a list of filenames or data in JSON format."
|
), "You need to provide either a PDF filename, a list of filenames or data in JSON format."
|
||||||
|
|
||||||
if data:
|
if data:
|
||||||
docs = [item['pdf_filename'] for item in data]
|
docs = [item["pdf_filename"] for item in data]
|
||||||
|
|
||||||
if not docs:
|
if not docs:
|
||||||
docs = [pdf_filename]
|
docs = [pdf_filename]
|
||||||
|
|
||||||
tasks = [self.annotate_pdf(user_input, doc, pages=item.get('pages')) for doc, item in zip(docs, data or [{}]*len(docs))]
|
tasks = [
|
||||||
|
self.annotate_pdf(user_input, doc, pages=item.get("pages"))
|
||||||
|
for doc, item in zip(docs, data or [{}] * len(docs))
|
||||||
|
]
|
||||||
pdf_buffers = await asyncio.gather(*tasks)
|
pdf_buffers = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
combined_pdf = pymupdf.open()
|
combined_pdf = pymupdf.open()
|
||||||
@ -404,17 +416,33 @@ if __name__ == "__main__":
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
# Set up argument parser for command-line interface
|
# Set up argument parser for command-line interface
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument("--user_input", type=str, help="The user input")
|
description=(
|
||||||
parser.add_argument("--pdf_filename", type=str, help="The PDF filename")
|
"Highlight sentences in PDF documents using an LLM.\n\n"
|
||||||
parser.add_argument("--silent", action="store_true", help="No user warnings")
|
"For more information, visit: https://github.com/lasseedfast/pdf-highlighter/blob/main/README.md"
|
||||||
parser.add_argument("--openai_key", type=str, help="OpenAI API key")
|
)
|
||||||
parser.add_argument("--comment", action="store_true", help="Include comments")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--user_input",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The text input from the user to highlight in the PDFs.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--pdf_filename", type=str, help="The PDF filename to process.")
|
||||||
|
parser.add_argument("--silent", action="store_true", help="Suppress warnings.")
|
||||||
|
parser.add_argument("--openai_key", type=str, help="API key for OpenAI.")
|
||||||
|
parser.add_argument("--llm_model", type=str, help="The model name for the language model.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--comment",
|
||||||
|
action="store_true",
|
||||||
|
help="Include comments in the highlighted PDF.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data",
|
"--data",
|
||||||
type=json.loads,
|
type=json.loads,
|
||||||
help="The data in JSON format (fields: user_input, pdf_filename, list_of_pages)",
|
help="Data in JSON format (fields: user_input, pdf_filename, list_of_pages).",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Initialize the Highlighter class with the provided arguments
|
# Initialize the Highlighter class with the provided arguments
|
||||||
@ -422,6 +450,7 @@ if __name__ == "__main__":
|
|||||||
silent=args.silent,
|
silent=args.silent,
|
||||||
openai_key=args.openai_key,
|
openai_key=args.openai_key,
|
||||||
comment=args.comment,
|
comment=args.comment,
|
||||||
|
llm_model=args.llm_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Define the main asynchronous function to highlight the PDF
|
# Define the main asynchronous function to highlight the PDF
|
||||||
@ -432,9 +461,12 @@ if __name__ == "__main__":
|
|||||||
data=args.data,
|
data=args.data,
|
||||||
)
|
)
|
||||||
# Save the highlighted PDF to a new file
|
# Save the highlighted PDF to a new file
|
||||||
|
filename = args.pdf_filename.replace(".pdf", "_highlighted.pdf")
|
||||||
await save_pdf_to_file(
|
await save_pdf_to_file(
|
||||||
highlighted_pdf, args.pdf_filename.replace(".pdf", "_highlighted.pdf")
|
highlighted_pdf, filename
|
||||||
)
|
)
|
||||||
|
# Print the clickable file path
|
||||||
|
print(f'''Highlighted PDF saved to "file://{filename.replace(' ', '%20')}"''')
|
||||||
|
|
||||||
# Run the main function using asyncio
|
# Run the main function using asyncio
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
Loading…
x
Reference in New Issue
Block a user