parent
0ea28f1652
commit
cd53a49831
2 changed files with 50 additions and 1 deletions
@ -0,0 +1,48 @@ |
||||
import requests |
||||
|
||||
class LLM:
    """Thin client for a local llama.cpp server's /completion endpoint."""

    def __init__(self, system_prompt=None, temperature=0.8, max_new_tokens=1000):
        """
        Initializes the LLM class with the given parameters.

        Args:
            system_prompt (str, optional): The system prompt to use. Defaults to
                "Be precise and keep to the given information.".
            temperature (float, optional): The temperature to use for generating
                new tokens. Defaults to 0.8.
            max_new_tokens (int, optional): The maximum number of new tokens to
                generate. Defaults to 1000.
        """
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        if system_prompt is None:
            self.system_prompt = "Be precise and keep to the given information."
        else:
            self.system_prompt = system_prompt

    def generate(self, prompt, repeat_penalty=1.2, timeout=None):
        """
        Generates new tokens based on the given prompt.

        Args:
            prompt (str): The prompt to use for generating new tokens.
            repeat_penalty (float, optional): Penalty the server applies to
                repeated tokens. Defaults to 1.2.
            timeout (float, optional): Seconds to wait for the server before
                requests raises a Timeout. Defaults to None (wait forever),
                which matches the previous behavior.

        Returns:
            str or None: The generated text, or None when the request failed
            (the server's error body is printed).
        """
        # Make a POST request to the API endpoint.
        headers = {"Content-Type": "application/json"}
        url = "http://localhost:8080/completion"
        # Named 'payload' rather than 'json' to avoid shadowing the stdlib
        # module name.
        payload = {
            "prompt": prompt,
            #"system_prompt": self.system_prompt, #TODO https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#change-system-prompt-on-runtime
            "temperature": self.temperature,
            "n_predict": self.max_new_tokens,
            "top_k": 30,
            "repeat_penalty": repeat_penalty,
        }

        response = requests.post(url, headers=headers, json=payload, timeout=timeout)
        if not response.ok:
            # Best-effort error handling: report the failure and signal it
            # with an explicit None instead of falling off the end.
            print(response.content)
            return None
        return response.json()['content']
||||
|
||||
Loading…
Reference in new issue