|
|
--- |
|
|
language: en |
|
|
license: mit |
|
|
tags: |
|
|
- audio |
|
|
- audio-classification |
|
|
- musical-instruments |
|
|
- wav2vec2 |
|
|
- transformers |
|
|
- pytorch |
|
|
datasets: |
|
|
- custom |
|
|
metrics: |
|
|
- accuracy |
|
|
- roc_auc |
|
|
model-index: |
|
|
- name: epoch_musical_instruments_identification_2 |
|
|
results: |
|
|
- task: |
|
|
type: audio-classification |
|
|
name: Musical Instrument Classification |
|
|
metrics: |
|
|
- type: accuracy |
|
|
value: 0.9333 |
|
|
name: Accuracy |
|
|
- type: roc_auc |
|
|
value: 0.9859 |
|
|
name: ROC AUC (Macro) |
|
|
- type: loss |
|
|
value: 1.0639 |
|
|
name: Validation Loss |
|
|
base_model: |
|
|
- facebook/wav2vec2-base-960h |
|
|
--- |
|
|
|
|
|
# Musical Instrument Classification Model |
|
|
|
|
|
This model is a fine-tuned version of [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) for musical instrument classification. It can identify 9 different musical instruments from audio recordings with high accuracy. |
|
|
|
|
|
## Model Description |
|
|
|
|
|
- **Model type:** Audio Classification |
|
|
- **Base model:** facebook/wav2vec2-base-960h |
|
|
- **Language:** Audio (no specific language) |
|
|
- **License:** MIT |
|
|
- **Fine-tuned on:** Custom musical instrument dataset (200 samples for each class) |
|
|
|
|
|
## Performance |
|
|
|
|
|
The model achieves the following results on the evaluation set after 5 epochs of training:
|
|
|
|
|
- **Final Accuracy:** 93.33% |
|
|
- **Final ROC AUC (Macro):** 98.59% |
|
|
- **Final Validation Loss:** 1.064 |
|
|
- **Evaluation Runtime:** 14.18 seconds |
|
|
- **Evaluation Speed:** 25.39 samples/second |
|
|
|
|
|
### Training Progress |
|
|
|
|
|
| Epoch | Training Loss | Validation Loss | ROC AUC | Accuracy | |
|
|
|-------|---------------|-----------------|---------|----------| |
|
|
| 1 | 1.9872 | 1.8875 | 0.9248 | 0.6639 | |
|
|
| 2 | 1.8652 | 1.4793 | 0.9799 | 0.8000 | |
|
|
| 3 | 1.3868 | 1.2311 | 0.9861 | 0.8194 | |
|
|
| 4 | 1.3242 | 1.1121 | 0.9827 | 0.9250 | |
|
|
| 5 | 1.1869 | 1.0639 | 0.9859 | 0.9333 | |
|
|
|
|
|
## Supported Instruments |
|
|
|
|
|
The model can classify the following 9 musical instruments: |
|
|
|
|
|
1. **Acoustic Guitar** |
|
|
2. **Bass Guitar** |
|
|
3. **Drum Set** |
|
|
4. **Electric Guitar** |
|
|
5. **Flute** |
|
|
6. **Hi-Hats** |
|
|
7. **Keyboard** |
|
|
8. **Trumpet** |
|
|
9. **Violin** |
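
The label names are stored in the model configuration. As a quick sanity check, the mapping shipped with the checkpoint can be listed directly (a minimal sketch that only relies on the `id2label` field of the uploaded config):

```python
from transformers import AutoConfig

# Load only the configuration to inspect the label mapping (no weights are downloaded)
config = AutoConfig.from_pretrained("Bhaveen/epoch_musical_instruments_identification_2")

# id2label maps class indices to instrument names; label2id is the inverse mapping
for idx, label in sorted(config.id2label.items()):
    print(idx, label)
```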
|
|
|
|
|
## Usage |
|
|
|
|
|
### Quick Start with Pipeline |
|
|
|
|
|
```python |
|
|
from transformers import pipeline |
|
|
import torchaudio |
|
|
|
|
|
# Load the classification pipeline |
|
|
classifier = pipeline("audio-classification", model="Bhaveen/epoch_musical_instruments_identification_2") |
|
|
|
|
|
# Load and preprocess audio |
|
|
audio, rate = torchaudio.load("your_audio_file.wav") |
|
|
transform = torchaudio.transforms.Resample(rate, 16000) |
|
|
audio = transform(audio).mean(dim=0).numpy()[:48000]  # downmix to mono and keep the first 3 seconds (48,000 samples)
|
|
|
|
|
# Classify the audio |
|
|
result = classifier(audio) |
|
|
print(result) |
|
|
``` |
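
If ffmpeg is available on the system, the pipeline can also decode a file path on its own, so the manual resampling step above becomes optional (an alternative usage sketch; the `top_k` argument simply limits the number of returned labels):

```python
# Pass the file path directly; decoding and resampling are handled by the pipeline (requires ffmpeg)
result = classifier("your_audio_file.wav", top_k=3)
print(result)
```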
|
|
|
|
|
### Using Transformers Directly |
|
|
|
|
|
```python |
|
|
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification |
|
|
import torchaudio |
|
|
import torch |
|
|
|
|
|
# Load model and feature extractor |
|
|
model_name = "Bhaveen/epoch_musical_instruments_identification_2" |
|
|
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name) |
|
|
model = AutoModelForAudioClassification.from_pretrained(model_name) |
|
|
|
|
|
# Load and preprocess audio |
|
|
audio, rate = torchaudio.load("your_audio_file.wav") |
|
|
transform = torchaudio.transforms.Resample(rate, 16000) |
|
|
audio = transform(audio).mean(dim=0).numpy()[:48000]  # downmix to mono and keep the first 3 seconds (48,000 samples)
|
|
|
|
|
# Extract features and make prediction |
|
|
inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt") |
|
|
with torch.no_grad(): |
|
|
outputs = model(**inputs) |
|
|
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1) |
|
|
predicted_class = torch.argmax(predictions, dim=-1) |
|
|
|
|
|
print(f"Predicted instrument: {model.config.id2label[predicted_class.item()]}") |
|
|
``` |
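
To inspect the full probability distribution instead of only the top prediction, the same `predictions` tensor can be expanded label by label (a short follow-up to the example above):

```python
# Rank all 9 classes by predicted probability
probs = predictions[0]
ranked = sorted(
    ((model.config.id2label[i], p.item()) for i, p in enumerate(probs)),
    key=lambda item: item[1],
    reverse=True,
)
for label, prob in ranked:
    print(f"{label}: {prob:.3f}")
```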
|
|
|
|
|
## Training Details |
|
|
|
|
|
### Dataset and Preprocessing |
|
|
|
|
|
- **Custom dataset** with audio recordings of 9 musical instruments |
|
|
- **Train/Test Split:** 80/20 based on file numbering (files numbered below 160 are used for training)
|
|
- **Data Balancing:** Random oversampling applied to minority classes |
|
|
- **Audio Preprocessing:** |
|
|
- Resampling to 16,000 Hz |
|
|
- Fixed length of 48,000 samples (3 seconds) |
|
|
- Truncation of longer audio files |
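
The steps above can be expressed as a small helper function. This is an illustrative sketch rather than the original training code; the function name and the zero-padding of short clips are assumptions:

```python
import torch
import torchaudio

TARGET_RATE = 16000      # Hz
TARGET_LENGTH = 48000    # 3 seconds at 16 kHz


def preprocess(path: str) -> torch.Tensor:
    """Load a WAV file, downmix to mono, resample to 16 kHz, and fix the length to 48,000 samples."""
    waveform, rate = torchaudio.load(path)
    waveform = waveform.mean(dim=0)  # downmix multi-channel audio to mono
    if rate != TARGET_RATE:
        waveform = torchaudio.transforms.Resample(orig_freq=rate, new_freq=TARGET_RATE)(waveform)
    if waveform.shape[0] > TARGET_LENGTH:
        waveform = waveform[:TARGET_LENGTH]  # truncate longer clips
    else:
        # zero-pad shorter clips (assumption: the training code may have handled this differently)
        waveform = torch.nn.functional.pad(waveform, (0, TARGET_LENGTH - waveform.shape[0]))
    return waveform
```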
|
|
|
|
|
### Training Configuration |
|
|
|
|
|
```python |
|
|
# Training hyperparameters |
|
|
batch_size = 1 |
|
|
gradient_accumulation_steps = 4 |
|
|
learning_rate = 5e-6 |
|
|
num_train_epochs = 5 |
|
|
warmup_steps = 50 |
|
|
weight_decay = 0.02 |
|
|
``` |
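
For reference, these values map onto the standard `TrainingArguments` of the Trainer API roughly as follows (an illustrative sketch, not the original training script; `output_dir` and the evaluation strategy are assumptions):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",              # assumption: any local directory works
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,       # effective batch size of 4
    learning_rate=5e-6,
    num_train_epochs=5,
    warmup_steps=50,
    weight_decay=0.02,
    evaluation_strategy="epoch",         # assumption: matches the per-epoch metrics reported above
)
```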
|
|
|
|
|
### Model Architecture |
|
|
|
|
|
- **Base Model:** facebook/wav2vec2-base-960h |
|
|
- **Classification Head:** Added for 9-class classification |
|
|
- **Parameters:** ~95M trainable parameters |
|
|
- **Features:** Wav2Vec2 audio representations with fine-tuned classification layer |
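
The parameter count can be verified directly from the published checkpoint (a quick sketch):

```python
from transformers import AutoModelForAudioClassification

# Load the checkpoint and count its trainable parameters (~95M expected)
model = AutoModelForAudioClassification.from_pretrained("Bhaveen/epoch_musical_instruments_identification_2")
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_params / 1e6:.1f}M")
```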
|
|
|
|
|
## Technical Specifications |
|
|
|
|
|
- **Audio Format:** WAV files |
|
|
- **Sample Rate:** 16,000 Hz |
|
|
- **Input Length:** 3 seconds (48,000 samples) |
|
|
- **Model Framework:** PyTorch + Transformers |
|
|
- **Inference Device:** GPU recommended (CUDA) |
|
|
|
|
|
## Evaluation Metrics |
|
|
|
|
|
The model uses the following evaluation metrics: |
|
|
|
|
|
- **Accuracy:** Standard classification accuracy |
|
|
- **ROC AUC:** Macro-averaged ROC AUC with one-vs-rest approach |
|
|
- **Multi-class Classification:** Softmax probabilities for all 9 instrument classes |
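
Both metrics correspond to standard scikit-learn calls. Below is a minimal sketch of how they could be computed from evaluation labels and softmax scores (the function name and array shapes are assumptions, not the exact evaluation code):

```python
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score


def compute_metrics(y_true: np.ndarray, probs: np.ndarray) -> dict:
    """y_true: integer labels, shape (n,); probs: softmax probabilities, shape (n, 9)."""
    return {
        "accuracy": accuracy_score(y_true, probs.argmax(axis=1)),
        # macro-averaged, one-vs-rest ROC AUC over the 9 instrument classes
        "roc_auc": roc_auc_score(y_true, probs, multi_class="ovr", average="macro"),
    }
```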
|
|
|
|
|
|
|
|
|
|
|
## Limitations and Considerations |
|
|
|
|
|
1. **Audio Duration:** The model expects 3-second clips; longer audio is truncated and shorter clips may degrade performance
|
|
2. **Single Instrument Focus:** Optimized for recordings of a single instrument; mixtures of instruments may produce unreliable predictions
|
|
3. **Audio Quality:** Performance depends on audio quality and recording conditions |
|
|
4. **Sample Rate:** Input audio must be resampled to 16 kHz, the rate expected by the wav2vec2 feature extractor
|
|
5. **Domain Specificity:** Trained on a specific set of instrument recordings; it may not generalize to all instrument variants or playing styles
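
One simple way to guard against points 1 and 2 is to check the model's confidence before trusting a prediction. The snippet below continues from the "Using Transformers Directly" example; the threshold value is an arbitrary assumption, not part of the released model:

```python
# Flag low-confidence predictions instead of blindly trusting the argmax
CONFIDENCE_THRESHOLD = 0.5  # assumption: tune for your own data

top_prob, top_idx = torch.max(predictions, dim=-1)
if top_prob.item() < CONFIDENCE_THRESHOLD:
    print("Prediction is uncertain; the clip may contain mixed or out-of-domain instruments.")
else:
    print(f"Predicted instrument: {model.config.id2label[top_idx.item()]} ({top_prob.item():.2f})")
```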
|
|
|
|
|
## Training Environment |
|
|
|
|
|
- **Platform:** Google Colab |
|
|
- **GPU:** CUDA-enabled device |
|
|
- **Libraries:** |
|
|
- transformers==4.28.1 |
|
|
- torchaudio==0.12 |
|
|
- datasets |
|
|
- evaluate |
|
|
- imbalanced-learn (imblearn)
|
|
|
|
|
## Model Files |
|
|
|
|
|
The repository contains: |
|
|
- Model weights and configuration |
|
|
- Feature extractor configuration |
|
|
- Training logs and metrics |
|
|
- Label mappings (id2label, label2id) |
|
|
|
|
|
--- |
|
|
|
|
|
*Model trained as part of a hackathon project* |