import os
from xml.sax.saxutils import escape

import whisperx


class Transcription:
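    """Collects word-level transcript entries and serializes them to XML."""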

    def __init__(self):
        self.transcriptions = []

    def add_transcription(self, json_data):
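        """Append entries from JSON with a 'lines' list of timestamped, speaker-labelled text."""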
        for line in json_data['lines']:
            transcription_entry = {
                'speaker': line['speakerDesignation'],
                'start_time': self.convert_time_format(line['startTime']),
                'end_time': self.convert_time_format(line['endTime']),
                'text': line['text']
            }
            self.transcriptions.append(transcription_entry)

    def add_segments(self, segments):
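        """Append word-level entries from WhisperX-aligned, diarized segments."""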
        # segments: list of dicts, each with 'words' (list of dicts with
        # 'word', 'start', 'end', 'speaker')
        for segment in segments:
            for word_info in segment.get('words', []):
                # WhisperX omits timestamps for words it could not align; skip those
                if 'start' not in word_info or 'end' not in word_info:
                    continue
                # Convert numpy float64 to Python float if needed
                start = float(word_info['start'])
                end = float(word_info['end'])
                transcription_entry = {
                    # Fall back to 'UU' when diarization assigned no speaker
                    'speaker': word_info.get('speaker', 'UU'),
                    'start_time': f"{start:.3f}",
                    'end_time': f"{end:.3f}",
                    'text': word_info['word']
                }
                self.transcriptions.append(transcription_entry)

    def convert_time_format(self, time_str):
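        """Convert "HH:MM:SS,mmm" to seconds, e.g. "01:02:03,450" -> "3723.450"."""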
        hours, minutes, seconds = time_str.split(':')
        seconds, milliseconds = seconds.split(',')
        return f"{int(hours) * 3600 + int(minutes) * 60 + int(seconds)}.{milliseconds}"

    def to_xml(self, audio_file_id=None):
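        """Render entries as <w> elements carrying speaker, start, and length attributes.

        audio_file_id is accepted but currently unused.
        """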
        xml_transcription = "<Transcription Revision=\"1\">\n<p>\n"
        for entry in self.transcriptions:
            start = float(entry['start_time'])
            end = float(entry['end_time'])
            length = end - start
            # Escape &, <, > so the output stays well-formed XML
            text = escape(entry['text'])
            xml_transcription += f"<w sp=\"{entry['speaker']}\" s=\"{start:.3f}\" l=\"{length:.3f}\">{text}</w>\n"
        xml_transcription += "</p>\n</Transcription>"
        return xml_transcription

    def clear_transcriptions(self):
        self.transcriptions = []


def transcribe(audio_file, min_speakers=2, max_speakers=4):
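    """Transcribe an audio file with WhisperX, align it to word level,
    diarize speakers, and return the word-level segments."""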
    # Read the Hugging Face token from the environment (needed for the gated
    # pyannote diarization models) instead of hardcoding it
    HF_TOKEN = os.environ.get("HF_TOKEN")
    device = "cpu"
    batch_size = 16  # reduce if low on memory
    compute_type = "int8"  # int8 suits CPU inference; use "float16" on GPU (int8 may reduce accuracy)

    # Transcribe with the batched Whisper model
    model = whisperx.load_model("turbo", device, compute_type=compute_type)

    audio = whisperx.load_audio(audio_file)
    result = model.transcribe(audio, batch_size=batch_size)

    # Align the Whisper output to word-level timestamps; Swedish uses a
    # dedicated wav2vec2 model, other languages use the default alignment model
    if result["language"] == "sv":
        model_a, metadata = whisperx.load_align_model(language_code="sv", device=device, model_name="viktor-enzell/wav2vec2-large-voxrex-swedish-4gram")
    else:
        model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
    result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)

    # Assign speaker labels
    diarize_model = whisperx.diarize.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)
    # diarize_model.model.embedding_batch_size = 4
    # diarize_model.model.segmentation_batch_size = 4
    diarize_segments = diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)

    result = whisperx.assign_word_speakers(diarize_segments, result)
    return result["segments"]  # word-level segments, consumable by Transcription.add_segments
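

# Minimal usage sketch, assuming an input file "interview.wav" exists and
# HF_TOKEN is set in the environment for the pyannote diarization models:
if __name__ == "__main__":
    segments = transcribe("interview.wav", min_speakers=2, max_speakers=4)
    transcription = Transcription()
    transcription.add_segments(segments)
    print(transcription.to_xml())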