Code Practice#
Welcome to the online supplementary material for the "Connecting Music Audio and Natural Language" tutorial.
Introduction#
Have you ever imagined an AI that can understand a piece of music and find related music? Today, we will build exactly such a system! Using PyTorch and Hugging Face, we will create a query-by-description music retrieval model that lets users search for and explore music with natural language queries.

What We Will Build#
Today, we will build our second model!
Audio Encoder: MusicFM, Minz Won et al., A Foundation Model for Music Informatics
Text Encoder: RoBERTa-based, Yinhan Liu et al., RoBERTa: A Robustly Optimized BERT Pretraining Approach
By the end of this tutorial, you will have:
A working music retrieval model built on a pretrained music model (MusicFM) and a pretrained language model (RoBERTa)
The two modalities are connected using a contrastive training approach
Prerequisites#
Basic Python knowledge
Familiarity with deep learning concepts
A Google Colab account (it's free!)
Let's Get Started! 🚀#
Step 1: Setting Up Our Environment#
First, let's set up the Google Colab environment. Create a new notebook and make sure a GPU runtime is enabled (trust me, you'll need it!).
import torch
print("GPU Available:", torch.cuda.is_available())
print("GPU Device Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")
Note: if this cell raises ModuleNotFoundError: No module named 'torch' (which can happen outside Colab, where PyTorch is not preinstalled), install the dependencies first, as shown below.
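A rough install along these lines should cover everything used in this tutorial. This is a minimal sketch, not an official requirements list; exact versions are up to you, and on Colab most of these packages are already available.
# Install the libraries used throughout this tutorial
!pip install torch torchaudio transformers datasets numpy tqdm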
Step 2: Understanding the Data 📊#
We will use a subset of the LP-MusicCaps-MTT dataset. Why this dataset? Because it is well suited for learning:
Manageable size (3k training examples, 300 test examples)
10-second CC-licensed audio clips
💡 Tip: if you want the original content, you can download it from LP-MusicCaps-MTT
import torchaudio
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset
from IPython.display import Audio
# Load the MagnaTagATune-based dataset (this can take about 1 min)
dataset = load_dataset("mulab-mir/lp-music-caps-magnatagatune-3k", split="train")
print("Original Magnatagatune Tags: ", dataset[10]['tags'])
print("-"*10)
print("LP-MusicCaps Captions: ")
print("\n".join(dataset[10]['texts']))
Audio(dataset[10]['audio']["array"], rate=22050)
Original Magnatagatune Tags: ['vocals', 'female', 'guitar', 'girl', 'pop', 'female vocal', 'rock', 'female vocals', 'female singer']
----------
LP-MusicCaps Captions:
Get ready to be blown away by the powerful and energetic female vocals accompanied by a catchy guitar riff in this upbeat pop-rock anthem, performed by an incredibly talented girl singer with impressive female vocal range.
A catchy pop-rock song featuring strong female vocals and a prominent guitar riff.
Get ready to experience the dynamic and captivating sound of a female singer with powerful vocals, accompanied by the electric strumming of a guitar - this pop/rock tune will have you hooked on the mesmerizing female vocals of this talented girl.
This song is a powerful combination of female vocals, guitar, and rock influences, with a pop beat that keeps the tempo up. The female singer's voice is full of emotion, creating a sense of vulnerability and rawness. The acoustic sound is perfect for a girl's night out, with the melancholic folk vibe that captures the heart of a female vocalist who tells a story through her music.
Step 3: Creating Our Dataset Class 🎨#
This is where things get interesting! We need a custom dataset class that will:
Load the precomputed music embeddings (x) together with the caption text and its precomputed text embedding (y)
import torch
import random
from torch.utils.data import Dataset
class MusicTextDataset(Dataset):
    def __init__(self, split="train"):
        self.data = load_dataset("mulab-mir/lp-music-caps-magnatagatune-3k", split=split)
        # Precomputed MusicFM audio embeddings and RoBERTa caption embeddings
        musicfm_embeds = load_dataset("mulab-mir/lp-music-caps-magnatagatune-3k-musicfm-embedding", split=split)
        roberta_embeds = load_dataset("mulab-mir/lp-music-caps-magnatagatune-3k-roberta-embedding", split=split)
        self.track2embs = {i["track_id"]: i["embedding"] for i in musicfm_embeds}
        self.caption2embs = {i["track_id"]: i["embedding"] for i in roberta_embeds}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int):
        item = self.data[index]
        # Randomly sample one of the captions written for this track
        text = random.choice(item['texts'])
        h_audio = torch.tensor(self.track2embs[item['track_id']])
        h_text = torch.tensor(self.caption2embs[item['track_id']])
        return {
            "track_id": item["track_id"],
            "text": text,
            "h_audio": h_audio,
            "h_text": h_text
        }
train_dataset = MusicTextDataset(split="train")
test_dataset = MusicTextDataset(split="test")
tr_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, num_workers=0, shuffle=True, drop_last=True)
te_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=128, num_workers=0, shuffle=False, drop_last=True)
for item in test_dataset:
    print(item["track_id"])
    print(item["text"])
    print(item["h_audio"].shape)
    print(item["h_text"].shape)
    break
18754
A powerful female Indian vocalist captivates listeners with her mesmerizing rock singing, infusing her foreign roots into an electrifying blend of contemporary sounds, delivering a captivating performance that evades the realm of opera.
torch.Size([1024])
torch.Size([768])
Step 4: Building and Training Our Model Architecture 🏗️#
Now for the exciting part: building our model! We will use a modern architecture that combines:
MusicFM for audio understanding
RoBERTa for text understanding
Projection heads and a contrastive objective to connect the music and language latent spaces
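Concretely, the loss computed in the model's forward pass below is the symmetric, CLIP-style InfoNCE objective over a batch of N audio-text pairs, with cosine similarity sim(·,·) between the projected embeddings and a learnable temperature τ (stored as logit_scale):

$$\mathcal{L}_{a \to t} = -\frac{1}{N}\sum_{i=1}^{N} \log \frac{\exp\big(\mathrm{sim}(z^{a}_{i}, z^{t}_{i})/\tau\big)}{\sum_{j=1}^{N} \exp\big(\mathrm{sim}(z^{a}_{i}, z^{t}_{j})/\tau\big)}, \qquad \mathcal{L} = \tfrac{1}{2}\big(\mathcal{L}_{a \to t} + \mathcal{L}_{t \to a}\big)$$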
class JointEmbeddingModel(torch.nn.Module):
    def __init__(self, joint_dim=128, temperature=0.07):
        super().__init__()
        self.joint_dim = joint_dim
        self.temperature = temperature
        # Learnable logit scale, initialized to log(1/temperature) as in CLIP
        self.init_temperature = torch.tensor([np.log(1/temperature)])
        self.logit_scale = nn.Parameter(self.init_temperature, requires_grad=True)
        self.text_embedding_dim = 768    # RoBERTa dim
        self.audio_embedding_dim = 1024  # MusicFM dim
        # Projection heads mapping each modality into the joint embedding space
        self.audio_projection = nn.Sequential(
            nn.Linear(self.audio_embedding_dim, self.joint_dim, bias=False),
            nn.ReLU(),
            nn.Linear(self.joint_dim, self.joint_dim, bias=False)
        )
        self.text_projection = nn.Sequential(
            nn.Linear(self.text_embedding_dim, self.joint_dim, bias=False),
            nn.ReLU(),
            nn.Linear(self.joint_dim, self.joint_dim, bias=False)
        )

    @property
    def device(self):
        return list(self.parameters())[0].device

    @property
    def dtype(self):
        return list(self.parameters())[0].dtype

    def audio_forward(self, h_audio):
        z_audio = self.audio_projection(h_audio)
        return z_audio

    def text_forward(self, h_text):
        z_text = self.text_projection(h_text)
        return z_text

    def simple_contrastive_loss(self, z1, z2):
        z1 = nn.functional.normalize(z1, dim=1)
        z2 = nn.functional.normalize(z2, dim=1)
        # logit_scale.exp() is the learnable inverse temperature, clamped for stability
        temperature = torch.clamp(self.logit_scale.exp(), max=100)
        logits = torch.einsum('nc,mc->nm', [z1, z2]) * temperature.to(self.device)
        N = logits.shape[0]  # batch size per GPU
        labels = torch.arange(N, dtype=torch.long, device=self.device)
        return torch.nn.functional.cross_entropy(logits, labels)

    def forward(self, batch):
        z_audio = self.audio_forward(batch['h_audio'].to(self.device))
        z_text = self.text_forward(batch['h_text'].to(self.device))
        # Symmetric contrastive loss: audio-to-text and text-to-audio
        loss_a2t = self.simple_contrastive_loss(z_audio, z_text)
        loss_t2a = self.simple_contrastive_loss(z_text, z_audio)
        loss = (loss_a2t + loss_t2a) / 2
        return loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = JointEmbeddingModel()
model.to(device)
train_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"training model with: {train_params} trainable params and {frozen_params} frozen params")
training model with: 262145 trainable params and 0 frozen params
def train(model, dataloader, optimizer, epoch):
    model.train()
    total_loss = 0
    pbar = tqdm(dataloader, desc=f'TRAIN Epoch {epoch:02}')
    for batch in pbar:
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    epoch_loss = total_loss / len(dataloader)
    return epoch_loss
def test(model, dataloader):
    model.eval()
    total_loss = 0
    for batch in dataloader:
        with torch.no_grad():
            loss = model(batch)
        total_loss += loss.item()
    epoch_loss = total_loss / len(dataloader)
    return epoch_loss
NUM_EPOCHS = 10
lr = 1e-2
# Define the optimizer (the contrastive loss is computed inside the model's forward pass)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
for epoch in range(NUM_EPOCHS):
    train_loss = train(model, tr_dataloader, optimizer, epoch)
    valid_loss = test(model, te_dataloader)
    print("[Epoch %d/%d] [Train Loss: %.4f] [Valid Loss: %.4f]" % (epoch + 1, NUM_EPOCHS, train_loss, valid_loss))
[Epoch 1/10] [Train Loss: 4.7925] [Valid Loss: 4.4868]
[Epoch 2/10] [Train Loss: 4.3727] [Valid Loss: 4.2788]
[Epoch 3/10] [Train Loss: 4.1146] [Valid Loss: 3.9048]
[Epoch 4/10] [Train Loss: 3.8462] [Valid Loss: 3.7907]
[Epoch 5/10] [Train Loss: 3.6809] [Valid Loss: 3.7047]
[Epoch 6/10] [Train Loss: 3.5244] [Valid Loss: 3.5302]
[Epoch 7/10] [Train Loss: 3.2971] [Valid Loss: 3.5063]
[Epoch 8/10] [Train Loss: 3.1679] [Valid Loss: 3.3523]
[Epoch 9/10] [Train Loss: 3.0083] [Valid Loss: 3.3572]
[Epoch 10/10] [Train Loss: 2.9240] [Valid Loss: 3.2919]
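Before moving on to inference, it is worth persisting what you trained. Here is a minimal sketch using the standard PyTorch checkpoint API; the file name joint_embedding_model.pt is just an example, not something the tutorial requires.
# Save the trained projection heads and logit scale for later reuse
torch.save(model.state_dict(), "joint_embedding_model.pt")
# Later, restore them into a freshly constructed model
model = JointEmbeddingModel()
model.load_state_dict(torch.load("joint_embedding_model.pt", map_location=device))
model.to(device)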
Inference and Building the Retrieval Engine#
Load the model and embeddings
Extract the item embedding database (i.e., a vector database)
Extract the query embedding
Measure distance (similarity)
# load model
model.eval()
print("let's start inference!")
let's start inference!
# build metadata db
dataset = load_dataset("mulab-mir/lp-music-caps-magnatagatune-3k", split="test")
meta_db = {i["track_id"]:i for i in tqdm(dataset)}
def get_item_vector_db(model, dataloader):
    track_ids, audios, item_joint_embedding = [], [], []
    for item in tqdm(dataloader):
        h_audio = item['h_audio']
        with torch.no_grad():
            z_audio = model.audio_forward(h_audio.to(model.device))
        item_joint_embedding.append(z_audio.detach().cpu())
        track_ids.extend(item['track_id'])
    item_vector_db = torch.cat(item_joint_embedding, dim=0)
    return item_vector_db, track_ids
item_vector_db, track_ids = get_item_vector_db(model, te_dataloader)
text_encoder = AutoModel.from_pretrained("roberta-base").to(device)
text_tokenizer = AutoTokenizer.from_pretrained("roberta-base")
def get_query_embedding(query, model, text_encoder, text_tokenizer):
    encode = text_tokenizer([query],
                            padding='longest',
                            truncation=True,
                            max_length=128,
                            return_tensors="pt")
    input_ids = encode["input_ids"].to(device)
    attention_mask = encode["attention_mask"].to(device)
    with torch.no_grad():
        text_output = text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Mean-pool the last hidden states into a single text embedding, then project it
        h_text = text_output["last_hidden_state"].mean(dim=1)
        z_text = model.text_forward(h_text)
    query_vector = z_text.detach().cpu()
    return query_vector
def retrieval_fn(query, model, item_vector_db, topk=3):
    query_vector = get_query_embedding(query, model, text_encoder, text_tokenizer)
    query_vector = nn.functional.normalize(query_vector, dim=1)
    item_vector_db = nn.functional.normalize(item_vector_db, dim=1)
    # Cosine similarity between the query and every item in the database
    similarity_metrics = query_vector @ item_vector_db.T
    _, indices = torch.topk(similarity_metrics, k=topk)
    for i in indices.flatten():
        item = meta_db[track_ids[i]]
        print("track_id: ", item['track_id'])
        print("ground truth tags: ", item["tags"])
        display(Audio(item["audio"]["array"], rate=22050))
    return indices
query = "country guitar with no vocal"
indices = retrieval_fn(query, model, item_vector_db)
track_id: 46649
ground truth tags: ['guitar', 'banjo', 'folk', 'strings', 'country', 'no vocals']
track_id: 48072
ground truth tags: ['no voice', 'guitar', 'strings', 'country', 'violin']
track_id: 33437
ground truth tags: ['duet', 'classical', 'guitar', 'acoustic', 'classical guitar', 'no vocals', 'spanish', 'slow']
Wrapping Up 🎉#
Congratulations! You have built a complete query-by-description system. But this is only the beginning; there are many ways to improve and extend this model:
Try different architectures
Experiment with larger datasets
Implement better evaluation metrics (see the sketch below)
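For the last point, a common starting metric is Recall@K: the fraction of text queries for which the correct track appears among the top K retrieved items. Below is a minimal sketch that reuses item_vector_db, track_ids, test_dataset, and get_query_embedding from above; it is an illustrative helper rather than part of the original tutorial, and because the test dataloader uses drop_last=True, a few test tracks may be missing from the database.
def recall_at_k(model, dataset, item_vector_db, track_ids, k=10):
    db = nn.functional.normalize(item_vector_db, dim=1)
    hits = 0
    for item in tqdm(dataset):
        # Use one of the track's captions as the text query
        query_vector = get_query_embedding(item["text"], model, text_encoder, text_tokenizer)
        query_vector = nn.functional.normalize(query_vector, dim=1)
        similarity = query_vector @ db.T
        _, indices = torch.topk(similarity, k=k)
        retrieved = [track_ids[i] for i in indices.flatten()]
        hits += int(item["track_id"] in retrieved)
    return hits / len(dataset)

print("Recall@10:", recall_at_k(model, test_dataset, item_vector_db, track_ids, k=10))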
Further Learning Resources 📚#
Now go build something amazing! 🌟