import requests


class LLM:
    """Minimal client for a llama.cpp server's /completion endpoint."""

    def __init__(self, system_prompt=None, temperature=0.8, max_new_tokens=1000):
        """
        Initializes the LLM class with the given parameters.

        Args:
            system_prompt (str, optional): The system prompt to use. Defaults to
                "Be precise and keep to the given information.".
            temperature (float, optional): The sampling temperature to use when
                generating new tokens. Defaults to 0.8.
            max_new_tokens (int, optional): The maximum number of new tokens to
                generate. Defaults to 1000.
        """
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        if system_prompt is None:
            self.system_prompt = "Be precise and keep to the given information."
        else:
            self.system_prompt = system_prompt

    def generate(self, prompt, repeat_penalty=1.2):
        """
        Generates new tokens based on the given prompt.

        Args:
            prompt (str): The prompt to use for generating new tokens.
            repeat_penalty (float, optional): Penalty applied to repeated
                tokens. Defaults to 1.2.

        Returns:
            str: The generated tokens, or None if the request failed.
        """
        # Make a POST request to the llama.cpp server's completion endpoint
        headers = {"Content-Type": "application/json"}
        url = "http://localhost:8080/completion"
        payload = {
            "prompt": prompt,
            # "system_prompt": self.system_prompt,  # TODO https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#change-system-prompt-on-runtime
            "temperature": self.temperature,
            "n_predict": self.max_new_tokens,
            "top_k": 30,
            "repeat_penalty": repeat_penalty,
        }
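        # NOTE: self.system_prompt is stored in __init__ but never sent, since
        # the llama.cpp server handles the system prompt separately (see the
        # TODO link above). A possible workaround, offered here as an untested
        # assumption that depends on the model's expected prompt format, is to
        # prepend it to the user prompt before posting:
        #
        #   payload["prompt"] = f"{self.system_prompt}\n\n{prompt}"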
        response = requests.post(url, headers=headers, json=payload)
        if not response.ok:
            # Surface the server's error body; the caller gets None back.
            print(response.content)
            return None
        # The server responds with a JSON object whose "content" field holds
        # the generated text.
        return response.json()['content']
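

# Example usage: a minimal sketch that assumes a llama.cpp server is already
# running on localhost:8080 (the port used above; the exact launch command
# depends on your llama.cpp build, e.g. `./server -m model.gguf`).
if __name__ == "__main__":
    llm = LLM(temperature=0.5)
    answer = llm.generate("What is the capital of France?")
    if answer is not None:
        print(answer)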