import ChatTTS
import torch
import numpy as np
import torchaudio
import re
import time



# 设置环境变量以避免内存碎片化
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# 使用 CPU 进行计算
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

chat = ChatTTS.Chat()
chat.load(compile=False)  # Set to True for better performance
# 记录开始时间
start_time = time.time()
print(f"Start Time: {start_time}")
text = """
好的,让我们用小朋友听得懂的话来聊聊太阳吧!

### 太阳是什么?

太阳呀,就像一个大大的、热热的气球,挂在天空中。它是我们太阳系里的老大,所有的行星,包括我们住的地球,都绕着它转圈圈。太阳真的好大好大,如果把太阳比作一个大西瓜,那地球就像一粒小芝麻。

### 太阳是怎么发光发热的?

太阳里面有很多小小的氢气球,它们在太阳的肚子里挤来挤去,最后抱在一起变成了氦气球。这个过程就像小朋友们玩碰碰车,撞在一起就会产生能量,太阳就是通过这种方式发光发热的。

### 太阳的能量有什么用?

太阳发出的能量,就像一个超级大的灯泡,给我们的地球带来了光明和温暖。植物宝宝用这个能量做游戏,变成了我们吃的食物。我们也可以用这个能量来发电,让家里的灯亮起来,电视放起来。



    
"""

# 使用固定音色
guding_spk = torch.load("speaker/seed_1397.pt", map_location=device, weights_only=True)

# 相关参数,见文档说明
params_infer_code = ChatTTS.Chat.InferCodeParams(
    spk_emb=guding_spk,  # add sampled speaker
    temperature=.3,     # using custom temperature
    top_P=0.7,          # top P decode
    top_K=20,           # top K decode,
)

# 根据效果自行选择
# use oral_(0-9), laugh_(0-2), break_(0-7)
# to generate special token in text to synthesize.
params_refine_text = ChatTTS.Chat.RefineTextParams(
    prompt='[oral_0][laugh_0][break_4]',
)


def number_to_chinese(num):
    if not isinstance(num, int) or num < 0 or num >= 10 ** 9:
        raise ValueError("Number must be an integer between 0 and 999,999,999 inclusive.")

    units = ["", "十", "百", "千", "万", "十", "百", "千", "亿"]
    digits = ["零", "一", "二", "三", "四", "五", "六", "七", "八", "九"]

    def convert_chunk(chunk):
        result = ""
        zero_flag = False
        for i, digit in enumerate(reversed(str(chunk))):
            d = int(digit)
            if d == 0:
                if not zero_flag:
                    result = digits[d] + result
                    zero_flag = True
            else:
                result = digits[d] + units[i] + result
                zero_flag = False
        return result

    if num == 0:
        return digits[0]

    result = ""
    if num >= 100000000:
        result += convert_chunk(num // 100000000) + units[8]
        num %= 100000000
        if num == 0:
            return result

    if num >= 10000:
        result += convert_chunk(num // 10000) + units[4]
        num %= 10000
        if num == 0:
            return result

    result += convert_chunk(num)

    # Remove trailing zeros
    result = result.rstrip(digits[0])

    # Replace multiple zeros with a single zero
    result = re.sub(r'零+', digits[0], result)

    return result
def percentage_to_chinese(percent):
    if not isinstance(percent, int) or percent < 0 or percent > 100:
        raise ValueError("Percentage must be an integer between 0 and 100 inclusive.")
    return f"百分之{number_to_chinese(percent)}"


def replace_numbers_with_chinese(text):
    def replace(match):
        num = int(match.group())
        return number_to_chinese(num)

    # Use regex to find all numbers in the text
    return re.sub(r'\d+', replace, text)

def replace_percentages_with_chinese(text):
    def replace(match):
        num = int(match.group(1))
        return percentage_to_chinese(num)

    # Use regex to find all percentages in the text
    return re.sub(r'(\d+)%', replace, text)

def replace_special_characters(text):
    """
    将文本中的特殊字符替换为对应的汉字。
    目前支持将 '+' 替换为 '加'。
    """
    replacements = {
        '+': '加',
        '×': '乘',
        '÷': '除',
        # '.': '点'
    }
    for char, replacement in replacements.items():
        text = text.replace(char, replacement)
    return text




def remove_bracketed_content(text):
    """
    移除文本中所有以 [ 开头和 ] 结尾的内容,包括字母、数字和中文字符。
    """
    return re.sub(r'\[.*?\]', '', text)

def replace_punctuation_with_uv_break(text):
    """
    将文本中的标点符号替换为 [uv_break]。
    """
    # 定义常见的标点符号
    punctuation_pattern = r'[!@#¥();;·:“”【】《》、…:—!?]'
    return re.sub(punctuation_pattern, ' ', text)

text = remove_bracketed_content(text)


def text_to_sentences(text, max_length=18):
    # 去除多余的空行和缩进
    text = re.sub(r'\s+', ' ', text).strip()

    # 使用正则表达式根据中英文标点符号分割句子
    punctuation_pattern = re.compile(r'[,。!?.,!?]')
    raw_sentences = punctuation_pattern.split(text)
    # print(f'首次分割后的句子数量:{len(raw_sentences)},其中为空句子数量:{raw_sentences.count("")},{raw_sentences}')

    sentences = []
    current_sentence = ""

    for sentence in raw_sentences:
        sentence = sentence.strip()
        if not sentence:
            continue

        # 如果当前句子加上新句子超过 max_length,则先保存当前句子
        if len(current_sentence) + len(sentence) + 1 > max_length:
            if current_sentence:
                sentences.append(current_sentence)
                current_sentence = ""

        # 如果当前句子为空,则直接赋值
        if not current_sentence:
            current_sentence = sentence
        else:
            current_sentence += ", " + sentence

    # 添加最后一个句子
    if current_sentence:
        sentences.append(current_sentence)

    return sentences




# 将文本分割成列表
#去除最后的[]标注内容
text = remove_bracketed_content(text)
# 替换文本中的百分比
text = replace_percentages_with_chinese(text)
# print(f"百分数:{text}")
# 替换文本中的数字
text = replace_numbers_with_chinese(text)
# print(f"数:{text}")
text=replace_special_characters(text)
# print(f"特殊字符:{text}")
texts = text_to_sentences(text)
print(f"列表:{texts}")

wavs = []

subtitle_lines = []

current_time = 0.0  # 当前时间戳,单位为秒

for index, text in enumerate(texts):
    text=replace_punctuation_with_uv_break(text)
    wav = chat.infer(text, skip_refine_text=True, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
    # 确保音频数据是一维数组
    if wav.ndim > 1:
        wav = wav.squeeze()
    wavs.append(wav)

    # 计算音频持续时间
    sample_rate = 24000  # 假设采样率为24000 Hz
    duration = len(wav) / sample_rate

    # 生成 SRT 格式的字幕行
    start_time = current_time
    end_time = current_time + duration

    start_time_str = time.strftime('%H:%M:%S,', time.gmtime(start_time)) + f"{int((start_time % 1) * 1000):03d}"
    end_time_str = time.strftime('%H:%M:%S,', time.gmtime(end_time)) + f"{int((end_time % 1) * 1000):03d}"

    subtitle_line = f"{index + 1}\n{start_time_str} --> {end_time_str}\n{text}\n"
    subtitle_lines.append(subtitle_line)

    # 更新当前时间戳
    current_time = end_time

def merge_wavs(wavs, output_file, sample_rate=24000):
    # 将所有wav数据连接成一个numpy数组
    merged_wav = np.concatenate(wavs, axis=0)

    # 转换为torch张量并添加通道维度
    merged_wav_tensor = torch.from_numpy(merged_wav).unsqueeze(0)

    # 保存合并后的音频文件
    torchaudio.save(output_file, merged_wav_tensor, sample_rate)

try:
    merge_wavs(wavs, "wwww.wav", 24000)
except Exception as e:
    print(f"Error merging wavs: {e}")
finally:
    # 释放未使用的缓存
    torch.cuda.empty_cache()


# 写入字幕文件
with open("wwww.srt", "w", encoding="utf-8") as subtitle_file:
    subtitle_file.writelines(subtitle_lines)

# 记录结束时间
end_time = time.time()
print(f"End Time: {end_time}")  # 添加调试信息

# 计算并打印耗时
elapsed_time = end_time - start_time
print(f"Total time taken: {elapsed_time:.2f} seconds")

 

 

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐