import os
from xml.sax.saxutils import escape

import whisperx

class Transcription:
    """Collects word-level transcription entries and renders them as XML."""

    def __init__(self):
        self.transcriptions = []

    def add_transcription(self, json_data):
        """Add entries from JSON with 'lines' of speaker/start/end/text fields."""
        for line in json_data['lines']:
            transcription_entry = {
                'speaker': line['speakerDesignation'],
                'start_time': self.convert_time_format(line['startTime']),
                'end_time': self.convert_time_format(line['endTime']),
                'text': line['text']
            }
            self.transcriptions.append(transcription_entry)
    def add_segments(self, segments):
        # segments: list of dicts, each with 'words' (list of dicts with
        # 'word', 'start', 'end', 'speaker')
        for segment in segments:
            for word_info in segment.get('words', []):
                # Skip words whisperx could not align (they lack timestamps)
                if 'start' not in word_info or 'end' not in word_info:
                    continue
                # Convert numpy float64 to Python float if needed
                start = float(word_info['start'])
                end = float(word_info['end'])
                transcription_entry = {
                    'speaker': word_info.get('speaker', 'UU'),
                    'start_time': f"{start:.3f}",
                    'end_time': f"{end:.3f}",
                    'text': word_info['word']
                }
                self.transcriptions.append(transcription_entry)
    def convert_time_format(self, time_str):
        # Convert an "HH:MM:SS,mmm" timestamp to seconds,
        # e.g. "00:01:02,500" -> "62.500"
        hours, minutes, seconds = time_str.split(':')
        seconds, milliseconds = seconds.split(',')
        return f"{int(hours) * 3600 + int(minutes) * 60 + int(seconds)}.{milliseconds}"
    def to_xml(self, audio_file_id=None):
        xml_transcription = "<Transcription Revision=\"1\">\n<p>\n"
        for entry in self.transcriptions:
            start = float(entry['start_time'])
            end = float(entry['end_time'])
            length = end - start
            # Escape the word text so '&', '<' and '>' cannot break the XML
            xml_transcription += (
                f"<w sp=\"{entry['speaker']}\" s=\"{start:.3f}\" "
                f"l=\"{length:.3f}\">{escape(entry['text'])}</w>\n"
            )
        xml_transcription += "</p>\n</Transcription>"
        return xml_transcription
    def clear_transcriptions(self):
        self.transcriptions = []
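

# Illustrative example of the XML shape produced by to_xml() (an assumption
# for documentation, not output captured from a real run): one entry with
# speaker "SPEAKER_00", start 0.500 s and end 0.820 s renders as
#
#   <Transcription Revision="1">
#   <p>
#   <w sp="SPEAKER_00" s="0.500" l="0.320">hello</w>
#   </p>
#   </Transcription>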


def transcribe(audio_file, min_speakers=2, max_speakers=4):
    # Read the Hugging Face token from the environment rather than
    # hard-coding a secret in the source
    HF_TOKEN = os.environ.get("HF_TOKEN")
    device = "cpu"
    batch_size = 16  # reduce if low on memory
    compute_type = "int8"  # int8 suits CPU; "float16" on GPU may improve accuracy

    # 1. Transcribe with original whisper (batched)
    model = whisperx.load_model("turbo", device, compute_type=compute_type)
    audio = whisperx.load_audio(audio_file)
    result = model.transcribe(audio, batch_size=batch_size)

    # 2. Align whisper output; Swedish uses a dedicated wav2vec2 model
    if result["language"] == "sv":
        model_a, metadata = whisperx.load_align_model(
            language_code="sv",
            device=device,
            model_name="viktor-enzell/wav2vec2-large-voxrex-swedish-4gram",
        )
    else:
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device=device
        )
    result = whisperx.align(
        result["segments"], model_a, metadata, audio, device,
        return_char_alignments=False,
    )

    # 3. Assign speaker labels via diarization
    diarize_model = whisperx.diarize.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)
    # diarize_model.model.embedding_batch_size = 4
    # diarize_model.model.segmentation_batch_size = 4
    diarize_segments = diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
    result = whisperx.assign_word_speakers(diarize_segments, result)
    return result["segments"]  # word-level segments for Transcription.add_segments()
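

# Minimal usage sketch (assumes an audio file "interview.wav" exists and the
# HF_TOKEN environment variable holds a Hugging Face token; both are
# illustrative assumptions, not part of the original file):
if __name__ == "__main__":
    segments = transcribe("interview.wav", min_speakers=2, max_speakers=4)
    transcription = Transcription()
    transcription.add_segments(segments)
    print(transcription.to_xml())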