This repository has been archived by the owner on Apr 4, 2024. It is now read-only.
Commit 97ef41a (1 parent: 3bca395)
Showing 17 changed files with 397 additions and 341 deletions.
Empty file.
@@ -0,0 +1,133 @@
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from gtts import gTTS
import torchaudio
import logging
import os
import requests
from pydub import AudioSegment
import subprocess
import time
import librosa
import torch
import customtkinter
import httpx
from CTkMenuBar import *  # Add-on downloaded from https://github.com/Akascape/CTkMenuBar
import re
from tkinter import StringVar
import pygame
from .sentence_translator import SentenceTranslator

class CustomTranslator:
    def __init__(self):
        self.processor = None
        self.model = None
        self.target_language = StringVar()
        self.target_language.set("en")  # Default target language

    def load_model(self):
        # Load the processor and model lazily, only if they haven't been loaded yet
        if self.processor is None:
            self.processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")

        if self.model is None:
            self.model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")

        # Move the model to GPU if available
        if torch.cuda.is_available():
            self.model.to('cuda')

    def unload_model(self):
        # Unload the processor and model if they have been loaded
        if self.processor is not None:
            del self.processor
            self.processor = None

        if self.model is not None:
            # Move the model back to CPU before deleting
            if torch.cuda.is_available():
                self.model.to('cpu')
            del self.model
            self.model = None

    def process_audio_chunk(self, input_path, target_language, chunk_idx, output_path):
        try:
            self.load_model()

            # Load the input audio file with librosa (mono, original sampling rate)
            input_waveform, input_sampling_rate = librosa.load(input_path, sr=None, mono=True)

            # Convert the NumPy array to a PyTorch tensor if needed
            if not isinstance(input_waveform, torch.Tensor):
                input_waveform = torch.tensor(input_waveform)

            forced_decoder_ids = self.processor.get_decoder_prompt_ids(language=target_language, task="translate")

            # Whisper expects 16 kHz audio; resample if the input uses a different rate
            if input_sampling_rate != 16000:
                resampler = torchaudio.transforms.Resample(orig_freq=input_sampling_rate, new_freq=16000)
                input_waveform = resampler(input_waveform)

            # Extract input features with the processor
            input_features = self.processor(input_waveform.numpy(), sampling_rate=16000, return_tensors="pt")

            # Generate token ids on the same device as the model (CPU or GPU)
            device = next(self.model.parameters()).device
            predicted_ids = self.model.generate(input_features["input_features"].to(device),
                                                forced_decoder_ids=forced_decoder_ids)

            # Decode token ids to text
            transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

            # Bug fix: collapse duplicate successive words in the transcription
            words = transcription.split()
            if words:
                cleaned_words = [words[0]]
                for word in words[1:]:
                    if word != cleaned_words[-1]:
                        cleaned_words.append(word)
                transcription = ' '.join(cleaned_words)

            translation_chunk_output_path = f"{output_path}_Translation_chunk{chunk_idx + 1}.mp3"

            # Translate the transcription when the target language is not English
            if target_language != "en":
                translator = SentenceTranslator(src="en", dst=target_language)
                translated_text = translator(transcription)

                # Generate the final audio output from the translated text
                self.generate_audio(translated_text, translation_chunk_output_path, target_language)
                logging.info(f"Processing successful. Translated text: {translated_text}")
                return translated_text
            else:
                self.generate_audio(transcription, translation_chunk_output_path, target_language)
                logging.info(f"Processing successful. Translated text: {transcription}")
                return transcription

        except Exception as e:
            # Log the error and re-raise it to the caller
            logging.error(f"Error processing audio: {e}")
            raise

        finally:
            # Ensure the model is unloaded even if an exception occurs
            self.unload_model()

    def generate_audio(self, text, output_path, target_language):
        # Synthesize speech for the given text with gTTS and save it as an MP3 file
        tts = gTTS(text, lang=target_language, slow=False)
        tts.save(output_path)

    def play_audio(self, audio_path):  # disabled for now
        pygame.mixer.init()
        pygame.mixer.music.load(audio_path)
        pygame.mixer.music.play()

    def stop_audio(self):
        if not pygame.mixer.get_init():
            pygame.mixer.init()
        try:
            pygame.mixer.music.stop()
        except pygame.error:
            pass
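
For orientation, a minimal usage sketch of this class follows. The importing module path ("custom_translator") and the file names are assumptions for illustration; they are not shown in this commit.

# Hypothetical usage sketch of CustomTranslator; the module name
# "custom_translator" and the file names below are assumptions.
from AudioFileTranslator_S2ST.custom_translator import CustomTranslator

translator = CustomTranslator()
translated_text = translator.process_audio_chunk(
    input_path="chunk0.wav",      # hypothetical audio chunk
    target_language="fr",
    chunk_idx=0,
    output_path="output",         # writes "output_Translation_chunk1.mp3"
)
print(translated_text)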
@@ -0,0 +1,13 @@
# AudioFileTranslator_S2ST/main.py

from tkinter import Tk
from AudioFileTranslator_S2ST.translator_gui import TranslatorGUI
import customtkinter


def run_gui():
    # Create the customtkinter root window, attach the GUI, and start the event loop
    root = customtkinter.CTk()
    app = TranslatorGUI(root)
    root.mainloop()


if __name__ == "__main__":
    run_gui()
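
Assuming the project root is on the import path, the entry point can presumably be launched with python -m AudioFileTranslator_S2ST.main, since translator_gui is imported via the absolute package path rather than relatively.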
@@ -0,0 +1,86 @@
import requests
import httpx


class SentenceTranslator:
    def __init__(self, src, dst, patience=-1, timeout=30, error_messages_callback=None):
        self.src = src
        self.dst = dst
        self.patience = patience
        self.timeout = timeout
        self.error_messages_callback = error_messages_callback

    def __call__(self, sentence):
        try:
            # Handle the special case of an empty string
            if not sentence:
                return None

            translated_sentence = self.GoogleTranslate(sentence, src=self.src, dst=self.dst, timeout=self.timeout)
            if not translated_sentence:
                return None

            # A trailing newline indicates a failed translation; retry while patience allows
            fail_to_translate = translated_sentence[-1] == '\n'
            while fail_to_translate and self.patience:
                translated_sentence = self.GoogleTranslate(translated_sentence, src=self.src, dst=self.dst,
                                                           timeout=self.timeout)
                if translated_sentence[-1] == '\n':
                    if self.patience == -1:
                        continue
                    self.patience -= 1
                else:
                    fail_to_translate = False

            return translated_sentence

        except KeyboardInterrupt:
            if self.error_messages_callback:
                self.error_messages_callback("Cancelling all tasks")
            else:
                print("Cancelling all tasks")
            return

        except Exception as e:
            if self.error_messages_callback:
                self.error_messages_callback(e)
            else:
                print(e)
            return

    def GoogleTranslate(self, text, src, dst, timeout=30):
        # Unofficial Google Translate web endpoint; passing the query as params
        # lets requests/httpx URL-encode the text properly
        url = 'https://translate.googleapis.com/translate_a/single'
        params = {'client': 'gtx', 'sl': src, 'tl': dst, 'dt': 't', 'q': text}
        headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64)',
                   'Referer': 'https://translate.google.com'}

        try:
            response = requests.get(url, params=params, headers=headers, timeout=timeout)
            if response.status_code == 200:
                # The first element of the response holds the translated segments
                response_json = response.json()[0]
                translation = ""
                for segment in response_json:
                    translation = translation + segment[0]
                return translation
            return

        except requests.exceptions.ConnectionError:
            # Fall back to httpx if requests cannot establish a connection
            with httpx.Client() as client:
                response = client.get(url, params=params, headers=headers, timeout=timeout)
                if response.status_code == 200:
                    response_json = response.json()[0]
                    translation = ""
                    for segment in response_json:
                        translation = translation + segment[0]
                    return translation
                return

        except KeyboardInterrupt:
            if self.error_messages_callback:
                self.error_messages_callback("Cancelling all tasks")
            else:
                print("Cancelling all tasks")
            return

        except Exception as e:
            if self.error_messages_callback:
                self.error_messages_callback(e)
            else:
                print(e)
            return
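
A minimal sketch of standalone usage, assuming the module is importable from the package as in the relative import above; the sample sentence and language pair are illustrative only.

# Hypothetical standalone usage of SentenceTranslator (sample text and
# language pair are illustrative, not taken from the commit).
from AudioFileTranslator_S2ST.sentence_translator import SentenceTranslator

translator = SentenceTranslator(src="en", dst="fr")
translated = translator("Hello, how are you?")  # returns the translated string, or None on failure
print(translated)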