Tensorflow 音乐预测
在本文中,我将展示如何使用 TensorFlow 来预测音乐风格。
在我的示例中,我比较了电子音乐和古典音乐。
你可以在我的github上找到代码:
https://github.com/victordalet/sound_to_partition
i - 数据集
第一步,您需要创建一个数据集文件夹,并在里面添加一个音乐风格文件夹,例如我添加一个 techno 文件夹和 classic 文件夹,其中放置我的 wav 歌曲。
ii - 训练
我创建了一个训练脚本,运行时需要通过第一个命令行参数传入 max_epochs(训练轮数)。
请修改构造函数中的 classes 列表,使其与数据集文件夹中的子目录名称一一对应。
在 load_and_preprocess_data 方法中,我从各个类别目录中读取 wav 文件,并计算它们的梅尔频谱图。
训练时,我使用 Keras 构建的卷积神经网络模型。
import os
import sys
from typing import List

import librosa
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.image import resize
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input, MaxPooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical


class Train:
    """Train a small CNN that classifies WAV files by music style.

    The dataset directory must contain one sub-directory per entry of
    ``self.classes``; every ``.wav`` file inside a sub-directory is used
    as a sample of that class.
    """

    def __init__(self):
        # Filled in by create_model().
        self.x_train = None
        self.x_test = None
        self.y_train = None
        self.y_test = None
        self.model = None
        self.data_dir: str = 'dataset'
        # Sub-directory names of data_dir; the list index is the class label.
        self.classes: List[str] = ['techno', 'classic']
        # Number of training epochs, taken from the first CLI argument.
        self.max_epochs: int = int(sys.argv[1])

    @staticmethod
    def load_and_preprocess_data(data_dir, classes, target_shape=(128, 128)):
        """Load every WAV file of every class as a resized mel spectrogram.

        Args:
            data_dir: Root directory holding one sub-directory per class.
            classes: Sub-directory names; the position in this sequence
                becomes the integer label of the samples found there.
            target_shape: (height, width) every spectrogram is resized to.

        Returns:
            A ``(data, labels)`` tuple of NumPy arrays.
        """
        data = []
        labels = []
        for label, class_name in enumerate(classes):
            class_dir = os.path.join(data_dir, class_name)
            # Sort so the sample order (and therefore the seeded
            # train/test split) does not depend on the filesystem.
            for filename in sorted(os.listdir(class_dir)):
                if not filename.endswith('.wav'):
                    continue
                file_path = os.path.join(class_dir, filename)
                # sr=None keeps the file's native sample rate.
                audio_data, sample_rate = librosa.load(file_path, sr=None)
                mel_spectrogram = librosa.feature.melspectrogram(
                    y=audio_data, sr=sample_rate)
                # Add a trailing channel axis so the spectrogram can be
                # resized like a single-channel image.
                mel_spectrogram = resize(
                    np.expand_dims(mel_spectrogram, axis=-1), target_shape)
                data.append(mel_spectrogram)
                labels.append(label)
        return np.array(data), np.array(labels)

    def create_model(self):
        """Load the dataset, split it 80/20 and build the compiled CNN."""
        data, labels = self.load_and_preprocess_data(
            self.data_dir, self.classes)
        # Convert labels to one-hot encoding.
        labels = to_categorical(labels, num_classes=len(self.classes))
        self.x_train, self.x_test, self.y_train, self.y_test = \
            train_test_split(data, labels, test_size=0.2, random_state=42)

        input_shape = self.x_train[0].shape
        input_layer = Input(shape=input_shape)
        x = Conv2D(32, (3, 3), activation='relu')(input_layer)
        x = MaxPooling2D((2, 2))(x)
        x = Conv2D(64, (3, 3), activation='relu')(x)
        x = MaxPooling2D((2, 2))(x)
        x = Flatten()(x)
        x = Dense(64, activation='relu')(x)
        output_layer = Dense(len(self.classes), activation='softmax')(x)

        self.model = Model(input_layer, output_layer)
        self.model.compile(optimizer=Adam(learning_rate=0.001),
                           loss='categorical_crossentropy',
                           metrics=['accuracy'])

    def train_model(self):
        """Fit the model and print its accuracy on the held-out split."""
        self.model.fit(self.x_train, self.y_train,
                       epochs=self.max_epochs, batch_size=32,
                       validation_data=(self.x_test, self.y_test))
        # evaluate() returns [loss, accuracy] for the compiled metrics.
        test_accuracy = self.model.evaluate(
            self.x_test, self.y_test, verbose=0)
        print(test_accuracy[1])

    def save_model(self):
        """Save the trained model to ``weight.h5``."""
        self.model.save('weight.h5')


if __name__ == '__main__':
    if len(sys.argv) < 2:
        # Fail with a readable message instead of an IndexError traceback.
        sys.exit(f'usage: python {sys.argv[0]} <max_epochs>')
    trainer = Train()
    trainer.create_model()
    trainer.train_model()
    trainer.save_model()
iii-测试
为了测试和使用模型,我创建了下面这个类,用来加载训练好的权重并预测音乐的风格。
不要忘记在构造函数中填写正确的 classes 列表,其内容和顺序必须与训练时保持一致。
from typing import List

import librosa
import numpy as np
import tensorflow as tf
from tensorflow.image import resize
from tensorflow.keras.models import load_model


class Test:
    """Load the saved model and predict the music style of one audio file."""

    def __init__(self, audio_file_path: str):
        self.model = load_model('weight.h5')
        self.target_shape = (128, 128)
        # Class names; predictions are mapped back to names by index.
        self.classes: List[str] = ['techno', 'classic']
        self.audio_file_path: str = audio_file_path

    def test_audio(self, file_path, model):
        """Run ``model`` on the mel spectrogram of ``file_path``.

        Args:
            file_path: Path of the audio file to classify.
            model: Model whose ``predict`` method is called.

        Returns:
            A ``(class_probabilities, predicted_class_index)`` tuple.
        """
        # sr=None keeps the file's native sample rate.
        audio_data, sample_rate = librosa.load(file_path, sr=None)
        mel_spectrogram = librosa.feature.melspectrogram(
            y=audio_data, sr=sample_rate)
        mel_spectrogram = resize(
            np.expand_dims(mel_spectrogram, axis=-1), self.target_shape)
        # Add the leading batch axis: (1, height, width, 1).
        mel_spectrogram = tf.reshape(
            mel_spectrogram, (1,) + self.target_shape + (1,))

        predictions = model.predict(mel_spectrogram)
        class_probabilities = predictions[0]
        predicted_class_index = np.argmax(class_probabilities)
        return class_probabilities, predicted_class_index

    def test(self):
        """Classify ``self.audio_file_path`` and print the result."""
        class_probabilities, predicted_class_index = self.test_audio(
            self.audio_file_path, self.model)

        for class_label, probability in zip(self.classes, class_probabilities):
            print(f'Class: {class_label}, Probability: {probability:.4f}')

        predicted_class = self.classes[predicted_class_index]
        # This is the probability the model assigns to the predicted class,
        # not a measured accuracy; the printed label is kept as-is.
        confidence = class_probabilities[predicted_class_index]
        print(f'The audio is classified as: {predicted_class}')
        print(f'Accuracy: {confidence:.4f}')