# Text-to-speech demo: synthesize speech from text using fairseq's
# FastSpeech 2 model (trained on LJSpeech) with a HiFi-GAN vocoder,
# and save the generated audio to a WAV file.
import nltk
import torch
from fairseq import utils
from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
from scipy.io.wavfile import write

|
|
# Download the required NLTK resource |
|
nltk.download('averaged_perceptron_tagger') |
|
|
|
# Model loading |
|
models, cfg, task = load_model_ensemble_and_task_from_hf_hub( |
|
"facebook/fastspeech2-en-ljspeech", |
|
arg_overrides={"vocoder": "hifigan", "fp16": False} |
|
) |
|
|
|
# Set device |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
# Move all models to the correct device |
|
for model in models: |
|
model.to(device) |
|
|
|
# Update configuration and build generator after moving models |
|
TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg) |
|
generator = task.build_generator(models, cfg) |
|
|
|
# Ensure the vocoder is on the correct device |
|
generator.vocoder.model.to(device) |
|
|
|
# Define your text |
|
text = """Hi there, thanks for having me! My interest in electric cars really started back when I was a teenager...""" |
|
|
|
# Convert text to model input |
|
sample = TTSHubInterface.get_model_input(task, text) |
|
|
|
# Recursively move all tensors in sample to the correct device |
|
sample = utils.move_to_cuda(sample) if torch.cuda.is_available() else sample |
|
|
|
|
|
|
|
# Generate speech |
|
wav, rate = TTSHubInterface.get_prediction(task, models[0], generator, sample) |
|
|
|
from scipy.io.wavfile import write |
|
|
|
# If wav is a tensor, convert it to a NumPy array |
|
if isinstance(wav, torch.Tensor): |
|
wav = wav.cpu().numpy() |
|
|
|
# Save the audio to a WAV file |
|
write('output_fair.wav', rate, wav) |