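# Translate non-English speeches stored in ArangoDB into English, using a local
# GGUF model loaded through LangChain's LlamaCpp wrapper.
# An optional path to a .gguf model file can be passed as the first command-line
# argument; otherwise the default model under model_files/ is used.
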
from arango_things import arango_db, get_documents
from sys import argv
from datetime import datetime
from langchain.llms import LlamaCpp
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.text_splitter import RecursiveCharacterTextSplitter


def translate(text, llm):
    """Translate one chunk of text into English with the given LlamaCpp model."""
    template = """
You are a professional translator. Only translate, nothing else, and never add anything of your own.
Translate this text into English.
Text: {text}
Translation:
"""
    prompt = PromptTemplate(template=template, input_variables=["text"])
    llm_chain = LLMChain(prompt=prompt, llm=llm)
    return llm_chain.run(text)

# Callbacks support token-wise streaming
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
n_gpu_layers = 80 # Change this value based on your model and your GPU VRAM pool.
n_batch = 4096 # Should be between 1 and n_ctx, consider the amount of VRAM in your GPU.
if len(argv) > 1:
    # Model path given on the command line, e.g. model_files/mistral-7b-openorca.Q4_K_M.gguf
    model = argv[1]
    model_folder = model[:model.rfind('/')]
    model_filename = model[model.rfind('/')+1:]
else:
    # Default model.
    model_folder = 'model_files'
    model_filename = 'mistral-7b-openorca.Q4_K_M.gguf'  # or e.g. 'mistral-7b-openorca.Q5_K_S.gguf'
llm = LlamaCpp(
    model_path=f'{model_folder}/{model_filename}',
    n_gpu_layers=n_gpu_layers,
    n_batch=n_batch,
    n_ctx=4096,
    temperature=0,
    max_tokens=2500,
    callback_manager=callback_manager,
    verbose=True,  # Verbose is required to pass to the callback manager
)
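# Split long speeches into chunks of at most 4000 characters so that each
# translation prompt stays within the model's 4096-token context window.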
text_splitter = RecursiveCharacterTextSplitter(chunk_size=4000, chunk_overlap=0)
# Get records without translation.
query = '''
FOR doc IN speeches
FILTER doc.language != 'EN'
FILTER CHAR_LENGTH(doc.translation) < 10
SORT RAND()
LIMIT 1
RETURN doc
'''
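# Translate one randomly chosen record per iteration until the query returns
# no more untranslated documents.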
while True:
    cursor = arango_db.aql.execute(query=query, count=True)
    if cursor.count() == 1:
        record = cursor.next()
    else:
        print('Done!')
        break
    # Translate the record using the local LlamaCpp model.
    try:
        print(f'\n\n{record["_key"]}\n')
        translation = []
        splitted_text = text_splitter.split_text(record['text'])
        for text in splitted_text:
            translation.append(translate(text, llm))
        record['translation'] = ' '.join(translation)
        record['translation_metas'] = {'with': 'LlamaCpp', 'model': model_filename, 'date': datetime.today().strftime('%Y-%m-%d')}
        arango_db.collection("speeches").update(record)
    except Exception as exc:
        # Skip this record on failure and move on to the next one.
        print(f'Translation failed: {exc}')