You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
90 lines
2.6 KiB
90 lines
2.6 KiB
from arango_things import arango_db, get_documents |
|
from sys import argv |
|
|
|
from datetime import datetime |
|
|
|
from langchain.llms import LlamaCpp |
|
from langchain.prompts import PromptTemplate |
|
from langchain.chains import LLMChain |
|
from langchain.callbacks.manager import CallbackManager |
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
|
|
|
def translate(text, llm):
    """Translate ``text`` into English with the given language model.

    A fixed instruction prompt is filled in with ``text`` and executed
    through a newly built ``LLMChain`` that wraps ``llm``.

    Args:
        text: The passage to translate; it replaces the ``{text}``
            placeholder of the prompt.
        llm: The language model object handed to ``LLMChain``.

    Returns:
        The value produced by ``LLMChain.run`` for the filled-in prompt.
    """
    instructions = """
    You are a professional translator. Only translate, nothing else, and never add anything of your own.
    Translate this text into English.

    Text: {text}

    Translation:
    """

    chain = LLMChain(
        llm=llm,
        prompt=PromptTemplate(input_variables=["text"], template=instructions),
    )
    return chain.run(text)
|
|
|
|
|
|
|
# Stream the model output to stdout token by token while it is generated.
callback_manager = CallbackManager(handlers=[StreamingStdOutCallbackHandler()])
|
|
|
|
|
# Passed to LlamaCpp as n_gpu_layers. Change this value based on your model
# and your GPU VRAM pool.
n_gpu_layers = 80

# Passed to LlamaCpp as n_batch. Should be between 1 and n_ctx; consider the
# amount of VRAM in your GPU.
n_batch = 4096
|
|
|
|
|
def split_model_path(model):
    """Split a model path into ``(folder, filename)`` at its last ``/``.

    Args:
        model: Path to a model file, e.g. ``'model_files/foo.gguf'``.

    Returns:
        A ``(folder, filename)`` tuple such that ``f'{folder}/{filename}'``
        points at the same file. A bare file name yields ``'.'`` as folder.
    """
    folder, separator, filename = model.rpartition('/')
    if not separator:
        # Without a '/' the previous rfind()-based slicing used index -1 and
        # silently cut the last character off to build the folder name.
        folder = '.'
    return folder, filename


# Use the model given as first command-line argument; fall back to the
# default model only when no argument was supplied. (Previously the defaults
# were assigned unconditionally and always overwrote the argument.)
if len(argv) > 1:
    model = argv[1]
    model_folder, model_filename = split_model_path(model)
else:
    model_folder = 'model_files'
    # Other quantization that has been used: 'mistral-7b-openorca.Q5_K_S.gguf'
    model_filename = 'mistral-7b-openorca.Q4_K_M.gguf'
|
|
|
# Load the selected model file through LlamaCpp.
llm = LlamaCpp(
    # Which model to load.
    model_path=f'{model_folder}/{model_filename}',
    # Context and batching / GPU settings.
    n_ctx=4096,
    n_batch=n_batch,
    n_gpu_layers=n_gpu_layers,
    # Generation settings.
    max_tokens=2500,
    temperature=0,
    # Output streaming; verbose is required to pass to the callback manager.
    verbose=True,
    callback_manager=callback_manager,
)
|
|
|
# Long texts are cut into chunks (chunk_size=4000, no overlap) that are
# translated one at a time.
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=0, chunk_size=4000)
|
|
|
# AQL query that fetches a single, randomly chosen speech that is not in
# English and whose translation is shorter than 10 characters, i.e. a record
# that still lacks a translation.
query = '''
    FOR doc IN speeches
    FILTER doc.language != 'EN'
    FILTER CHAR_LENGTH(doc.translation) < 10
    SORT RAND()
    LIMIT 1
    RETURN doc
'''
|
|
|
# A record whose translation fails stays untranslated and can be selected
# again by the query, so a systematic problem (broken model, failing database
# writes) would otherwise make this loop spin forever. Stop after this many
# failures in a row.
MAX_CONSECUTIVE_FAILURES = 5
consecutive_failures = 0

# Keep fetching one untranslated record at a time until the query returns
# nothing.
while True:
    cursor = arango_db.aql.execute(query=query, count=True)

    if cursor.count() != 1:
        print('Done!')
        break
    record = cursor.next()

    # Translate with the local LlamaCpp model, chunk by chunk, and store the
    # result together with some metadata about how it was produced.
    try:
        print(f'\n\n{record["_key"]}\n')

        chunks = text_splitter.split_text(record['text'])
        record['translation'] = ' '.join(translate(chunk, llm) for chunk in chunks)
        record['translation_metas'] = {
            'with': 'LlamaCpp',
            'model': model_filename,
            'date': datetime.today().strftime('%Y-%m-%d'),
        }
        arango_db.collection("speeches").update(record)
    except Exception as error:
        # Best effort: one bad record must not stop the whole run, but the
        # failure is reported instead of being swallowed silently. Catching
        # Exception (not a bare except) lets Ctrl-C and SystemExit through.
        consecutive_failures += 1
        print(f'\nFailed to translate {record.get("_key")}: {error!r}')
        if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
            print(f'Giving up after {consecutive_failures} consecutive failures.')
            break
    else:
        consecutive_failures = 0