emb|从视频到音频:使用VIT进行音频分类

emb|从视频到音频:使用VIT进行音频分类

文章图片

emb|从视频到音频:使用VIT进行音频分类

文章图片

emb|从视频到音频:使用VIT进行音频分类

就机器学习而言 , 音频本身是一个有广泛应用的完整的领域 , 包括语音识别、音乐分类和声音事件检测等等 。 传统上音频分类一直使用谱图分析和隐马尔可夫模型等方法 , 这些方法已被证明是有效的 , 但也有其局限性 。 近期VIT已经成为音频任务的一个有前途的替代品 , OpenAI的Whisper就是一个很好的例子 。

在本文中 , 我们将利用ViT - Vision Transformer的是一个Pytorch实现在音频分类数据集GTZAN数据集-音乐类型分类上训练它 。
数据集介绍GTZAN 数据集是在音乐流派识别 (MGR) 研究中最常用的公共数据集 。这些文件是在 2000-2001 年从各种来源收集的 , 包括个人 CD、收音机、麦克风录音 , 代表各种录音条件下的声音 。

这个数据集由子文件夹组成 , 每个子文件夹是一种类型 。

加载数据集我们将加载每个.wav文件 , 并通过librosa库生成相应的Mel谱图 。
mel谱图是声音信号的频谱内容的一种可视化表示 , 它的垂直轴表示mel尺度上的频率 , 水平轴表示时间 。 它是音频信号处理中常用的一种表示形式 , 特别是在音乐信息检索领域 。
梅尔音阶(Mel scale , 英语:mel scale)是一个考虑到人类音高感知的音阶 。 因为人类不会感知线性范围的频率 , 也就是说我们在检测低频差异方面要胜于高频 。例如 , 我们可以轻松分辨出500 Hz和1000 Hz之间的差异 , 但是即使之间的距离相同 , 我们也很难分辨出10000 Hz和10500 Hz之间的差异 。 所以梅尔音阶解决了这个问题 , 如果梅尔音阶的差异相同 , 则意指人类感觉到的音高差异将相同 。
def wav2melspec(fp):
   y sr = librosa.load(fp)
   S = librosa.feature.melspectrogram(y=y sr=sr n_mels=128)
   log_S = librosa.amplitude_to_db(S ref=np.max)
   img = librosa.display.specshow(log_S sr=sr x_axis='time' y_axis='mel')
   # get current figure without white border
   img = plt.gcf()
   img.gca().xaxis.set_major_locator(plt.NullLocator())
   img.gca().yaxis.set_major_locator(plt.NullLocator())
   img.subplots_adjust(top = 1 bottom = 0 right = 1 left = 0
           hspace = 0 wspace = 0)
   img.gca().xaxis.set_major_locator(plt.NullLocator())
   img.gca().yaxis.set_major_locator(plt.NullLocator())
   # to pil image
   img.canvas.draw()
   img = Image.frombytes('RGB' img.canvas.get_width_height() img.canvas.tostring_rgb())
   return img
上述函数将产生一个简单的mel谱图:

现在我们从文件夹中加载数据集 , 并对图像应用转换 。
class AudioDataset(Dataset):
   def __init__(self root transform=None):
       self.root = root
       self.transform = transform
       self.classes = sorted(os.listdir(root))
       self.class_to_idx = {c: i for i c in enumerate(self.classes)
       self.samples = [

       for c in self.classes:
           for fp in os.listdir(os.path.join(root c)):
               self.samples.append((os.path.join(root c fp) self.class_to_idx[c
))
   
   def __len__(self):