Hands-on Code#

Welcome to the online supplementary materials for the "Connecting Music Audio and Natural Language" tutorial.

Introduction#

Have you ever imagined an AI that can understand a piece of music and find related tracks? Today, we're going to build exactly such a system! Using PyTorch and Hugging Face, we'll create a query-by-description music retrieval model that lets users search for and explore music with natural-language queries.

(Figure: Music-Retrieval Model)

What We'll Build#

Today, we'll build our second model!

By the end of this tutorial, you will have:

  • A working music retrieval model built on a pretrained music model (MusicFM) and a pretrained language model (RoBERTa)

  • A contrastive training setup that aligns the audio and text embedding spaces

Prerequisites#

  • Basic Python knowledge

  • Familiarity with deep learning concepts

  • A Google Colab account (it's free!)

Let's Get Started! 🚀#

Step 1: Setting Up Our Environment#

First, let's set up our Google Colab environment. Create a new notebook and make sure the GPU runtime is enabled (trust me, you'll need it!).

import torch
print("GPU Available:", torch.cuda.is_available())
print("GPU Device Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")
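Note: Colab ships with PyTorch preinstalled. If you are running this elsewhere and `import torch` fails with a ModuleNotFoundError, a minimal install from a notebook cell (covering the other libraries this tutorial uses) would look like:

!pip install torch torchaudio transformers datasets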

Step 2: Understanding the Data 📊#

We'll use a subset of the LP-MusicCaps-MTT dataset. Why this dataset? Because it's great for learning:

  • Manageable size (3k training examples, 300 test examples)

  • 10-second Creative Commons audio clips

💡 Tip: if you want the original data, you can download it from LP-MusicCaps-MTT

import torchaudio
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset
from IPython.display import Audio

# Load the MagnaTagATune dataset (takes ~1 min)
dataset = load_dataset("mulab-mir/lp-music-caps-magnatagatune-3k", split="train")
print("Original Magnatagatune Tags: ", dataset[10]['tags'])
print("-"*10)
print("LP-MusicCaps Captions: ")
print("\n".join(dataset[10]['texts']))
Audio(dataset[10]['audio']["array"], rate=22050)
Original Magnatagatune Tags:  ['vocals', 'female', 'guitar', 'girl', 'pop', 'female vocal', 'rock', 'female vocals', 'female singer']
----------
LP-MusicCaps Captions: 
Get ready to be blown away by the powerful and energetic female vocals accompanied by a catchy guitar riff in this upbeat pop-rock anthem, performed by an incredibly talented girl singer with impressive female vocal range.
A catchy pop-rock song featuring strong female vocals and a prominent guitar riff.
Get ready to experience the dynamic and captivating sound of a female singer with powerful vocals, accompanied by the electric strumming of a guitar - this pop/rock tune will have you hooked on the mesmerizing female vocals of this talented girl.
This song is a powerful combination of female vocals, guitar, and rock influences, with a pop beat that keeps the tempo up. The female singer's voice is full of emotion, creating a sense of vulnerability and rawness. The acoustic sound is perfect for a girl's night out, with the melancholic folk vibe that captures the heart of a female vocalist who tells a story through her music.

Step 3: Creating Our Dataset Class 🎨#

Here's where things get interesting! We need to create a custom dataset class that will:

  • Load the music data (x) and the description text (y), along with their precomputed MusicFM and RoBERTa embeddings

import torch
import random
from torch.utils.data import Dataset

class MusicTextDataset(Dataset):
    def __init__(self, split="train"):
        self.data = load_dataset("mulab-mir/lp-music-caps-magnatagatune-3k", split=split)
        musicfm_embeds = load_dataset("mulab-mir/lp-music-caps-magnatagatune-3k-musicfm-embedding", split=split)
        roberta_embeds = load_dataset("mulab-mir/lp-music-caps-magnatagatune-3k-roberta-embedding", split=split)
        self.track2embs = {i["track_id"]:i["embedding"] for i in musicfm_embeds}
        self.caption2embs = {i["track_id"]:i["embedding"] for i in roberta_embeds}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int):
        item = self.data[index]
        text = random.choice(item['texts'])
        h_audio = torch.tensor(self.track2embs[item['track_id']])
        h_text = torch.tensor(self.caption2embs[item['track_id']])
        return {
            "track_id": item["track_id"],
            "text": text,
            "h_audio": h_audio,
            "h_text": h_text
        }
train_dataset = MusicTextDataset(split="train")
test_dataset = MusicTextDataset(split="test")
tr_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, num_workers=0, shuffle=True, drop_last=True)
te_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=128, num_workers=0, shuffle=False, drop_last=False)  # keep every test item so the retrieval database built later is complete
for item in test_dataset:
    print(item["track_id"])
    print(item["text"])
    print(item["h_audio"].shape)
    print(item["h_text"].shape)
    break
18754
A powerful female Indian vocalist captivates listeners with her mesmerizing rock singing, infusing her foreign roots into an electrifying blend of contemporary sounds, delivering a captivating performance that evades the realm of opera.
torch.Size([1024])
torch.Size([768])

Step 4: Building and Training Our Model Architecture 🏗️#

Now for the exciting part: building our model! We'll use a modern architecture that combines:

  • MusicFM for audio understanding

  • RoBERTa for text understanding

  • Projection heads plus a contrastive objective (shown below) to connect the music and language latent spaces
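The contrastive objective we implement below is the standard CLIP-style InfoNCE loss. For a batch of N paired embeddings, with cosine similarity sim(·, ·) and a learnable temperature τ, the audio-to-text direction is

$$\mathcal{L}_{a \to t} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp\left(\mathrm{sim}(z^{a}_{i}, z^{t}_{i}) / \tau\right)}{\sum_{j=1}^{N} \exp\left(\mathrm{sim}(z^{a}_{i}, z^{t}_{j}) / \tau\right)}$$

and the final loss is the average of this term and its symmetric text-to-audio counterpart, which is exactly what `simple_contrastive_loss` and `forward` compute in the code below.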

class JointEmbeddingModel(torch.nn.Module):
    def __init__(self, joint_dim=128, temperature=0.07):
        super().__init__()
        self.joint_dim = joint_dim
        self.temperature = temperature
        # Learnable logit scale (inverse temperature), initialized to log(1/temperature) as in CLIP
        self.init_temperature = torch.tensor([np.log(1/temperature)])
        self.logit_scale = nn.Parameter(self.init_temperature, requires_grad=True)
        self.text_embedding_dim = 768   # RoBERTa hidden size
        self.audio_embedding_dim = 1024 # MusicFM hidden size
        self.audio_projection = nn.Sequential(
            nn.Linear(self.audio_embedding_dim, self.joint_dim, bias=False),
            nn.ReLU(),
            nn.Linear(self.joint_dim, self.joint_dim, bias=False)
        )
        self.text_projection = nn.Sequential(
            nn.Linear(self.text_embedding_dim, self.joint_dim, bias=False),
            nn.ReLU(),
            nn.Linear(self.joint_dim, self.joint_dim, bias=False)
        )

    @property
    def device(self):
        return list(self.parameters())[0].device

    @property
    def dtype(self):
        return list(self.parameters())[0].dtype

    def audio_forward(self, h_audio):
        z_audio = self.audio_projection(h_audio)
        return z_audio

    def text_forward(self, h_text):
        z_text = self.text_projection(h_text)
        return z_text

    def simple_contrastive_loss(self, z1, z2):
        z1 = nn.functional.normalize(z1, dim=1)
        z2 = nn.functional.normalize(z2, dim=1)
        # exp(logit_scale) acts as an inverse temperature; clamp it for numerical stability
        logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
        logits = torch.einsum('nc,mc->nm', [z1, z2]) * logit_scale.to(self.device)
        N = logits.shape[0]  # batch size per GPU
        labels = torch.arange(N, dtype=torch.long, device=self.device)
        return torch.nn.functional.cross_entropy(logits, labels)

    def forward(self, batch):
        z_audio = self.audio_forward(batch['h_audio'].to(self.device))
        z_text = self.text_forward(batch['h_text'].to(self.device))
        loss_a2t = self.simple_contrastive_loss(z_audio, z_text)
        loss_t2a = self.simple_contrastive_loss(z_text, z_audio)
        loss = (loss_a2t + loss_t2a) / 2
        return loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = JointEmbeddingModel()
model.to(device)
train_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"training model with: {train_params} trainable params and {frozen_params} frozen params")
training model with: 262145 trainable params and 0 frozen params
def train(model, dataloader, optimizer, epoch):
    model.train()
    total_loss = 0
    pbar = tqdm(dataloader, desc=f'TRAIN Epoch {epoch:02}')
    for batch in pbar:
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    epoch_loss = total_loss / len(dataloader)
    return epoch_loss

def test(model, dataloader):
    model.eval()
    total_loss = 0
    for batch in dataloader:
        with torch.no_grad():
            loss = model(batch)
        total_loss += loss.item()
    epoch_loss = total_loss / len(dataloader)
    return epoch_loss
NUM_EPOCHS = 10
lr = 1e-2
# Define optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
for epoch in range(NUM_EPOCHS):
    train_loss = train(model, tr_dataloader, optimizer, epoch)
    valid_loss = test(model, te_dataloader)
    print("[Epoch %d/%d] [Train Loss: %.4f] [Valid Loss: %.4f]" % (epoch + 1, NUM_EPOCHS, train_loss, valid_loss))
[Epoch 1/10] [Train Loss: 4.7925] [Valid Loss: 4.4868]
[Epoch 2/10] [Train Loss: 4.3727] [Valid Loss: 4.2788]
[Epoch 3/10] [Train Loss: 4.1146] [Valid Loss: 3.9048]
[Epoch 4/10] [Train Loss: 3.8462] [Valid Loss: 3.7907]
[Epoch 5/10] [Train Loss: 3.6809] [Valid Loss: 3.7047]
[Epoch 6/10] [Train Loss: 3.5244] [Valid Loss: 3.5302]
[Epoch 7/10] [Train Loss: 3.2971] [Valid Loss: 3.5063]
[Epoch 8/10] [Train Loss: 3.1679] [Valid Loss: 3.3523]
[Epoch 9/10] [Train Loss: 3.0083] [Valid Loss: 3.3572]
[Epoch 10/10] [Train Loss: 2.9240] [Valid Loss: 3.2919]
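Before moving on to inference, you may want to persist the trained projection heads. A minimal sketch using standard PyTorch serialization (the filename is just an example):

# Optional: save the trained weights so a later session can reload them
torch.save(model.state_dict(), "joint_embedding.pt")
# To restore: model.load_state_dict(torch.load("joint_embedding.pt", map_location=device))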

Inference and Building the Retrieval Engine#

  1. Load the model and embeddings

  2. Build the item embedding database (i.e., a vector database)

  3. Extract the query embedding

  4. Measure distance (similarity)

# load model
model.eval()
print("let's start inference!")
let's start inference!
# build metadata db
dataset = load_dataset("mulab-mir/lp-music-caps-magnatagatune-3k", split="test")
meta_db = {i["track_id"]:i for i in tqdm(dataset)}
def get_item_vector_db(model, dataloader):
    track_ids, item_joint_embedding = [], []
    for item in tqdm(dataloader):
        h_audio = item['h_audio']
        with torch.no_grad():
            z_audio = model.audio_forward(h_audio.to(model.device))
        item_joint_embedding.append(z_audio.detach().cpu())
        track_ids.extend(item['track_id'])
    item_vector_db = torch.cat(item_joint_embedding, dim=0)
    return item_vector_db, track_ids

item_vector_db, track_ids = get_item_vector_db(model, te_dataloader)
text_encoder = AutoModel.from_pretrained("roberta-base").to(device)
text_tokenizer = AutoTokenizer.from_pretrained("roberta-base")
def get_query_embedding(query, model, text_encoder, text_tokenizer):
    encode = text_tokenizer([query],
                          padding='longest',
                          truncation=True,
                          max_length=128,
                          return_tensors="pt")
    input_ids = encode["input_ids"].to(device)
    attention_mask = encode["attention_mask"].to(device)
    with torch.no_grad():
        text_output = text_encoder(input_ids=input_ids , attention_mask=attention_mask)
        h_text = text_output["last_hidden_state"].mean(dim=1)
        z_text = model.text_forward(h_text)
    query_vector = z_text.detach().cpu()
    return query_vector
def retrieval_fn(query, model, item_vector_db, topk=3):
    query_vector = get_query_embedding(query, model, text_encoder, text_tokenizer)
    query_vector = nn.functional.normalize(query_vector, dim=1)
    item_vector_db = nn.functional.normalize(item_vector_db, dim=1)
    similarity_metrics = query_vector @ item_vector_db.T
    _, indices = torch.topk(similarity_metrics, k=topk)
    for i in indices.flatten():
        item = meta_db[track_ids[i]]
        print("track_id: ", item['track_id'])
        print("ground truth tags: ", item["tags"])
        display(Audio(item["audio"]["array"], rate=22050))
query = "country guitar with no vocal"
indices = retrieval_fn(query, model, item_vector_db)
track_id:  46649
ground truth tags:  ['guitar', 'banjo', 'folk', 'strings', 'country', 'no vocals']
track_id:  48072
ground truth tags:  ['no voice', 'guitar', 'strings', 'country', 'violin']
track_id:  33437
ground truth tags:  ['duet', 'classical', 'guitar', 'acoustic', 'classical guitar', 'no vocals', 'spanish', 'slow']
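You can try other natural-language descriptions as well; the queries below are purely illustrative, and the results will depend on your trained model:

retrieval_fn("slow ambient electronic music with synthesizers", model, item_vector_db)
retrieval_fn("energetic techno beat with heavy drums", model, item_vector_db, topk=5)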

Wrap-Up 🎉#

Congratulations! You've built a complete query-by-description system. But this is just the beginning; there are many ways to improve and extend this model:

  • Try different architectures

  • Experiment with larger datasets

  • Implement better evaluation metrics (see the sketch after this list)
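For example, a standard retrieval metric is Recall@K: the fraction of text queries whose ground-truth track appears in the top-K results. A minimal sketch, reusing `model`, `test_dataset`, `item_vector_db`, and `track_ids` from above (the helper name `recall_at_k` is ours, not part of any library):

def recall_at_k(model, dataset, item_vector_db, track_ids, k=10):
    # Normalize the item database once; cosine similarity then reduces to a dot product
    item_db = nn.functional.normalize(item_vector_db, dim=1)
    hits = 0
    for item in dataset:
        # Embed the caption with the trained text projection head
        with torch.no_grad():
            z_text = model.text_forward(item["h_text"].unsqueeze(0).to(model.device))
        query = nn.functional.normalize(z_text.detach().cpu(), dim=1)
        _, indices = torch.topk((query @ item_db.T).flatten(), k=k)
        # Count a hit if the ground-truth track appears in the top-k results
        if any(track_ids[i] == item["track_id"] for i in indices):
            hits += 1
    return hits / len(dataset)

print("Recall@10:", recall_at_k(model, test_dataset, item_vector_db, track_ids))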

Further Learning Resources 📚#

  1. PyTorch Documentation

Now go build something amazing! 🌟