---
language: en
license: mit
tags:
- audio
- audio-classification
- musical-instruments
- wav2vec2
- transformers
- pytorch
datasets:
- custom
metrics:
- accuracy
- roc_auc
model-index:
- name: epoch_musical_instruments_identification_2
results:
- task:
type: audio-classification
name: Musical Instrument Classification
metrics:
- type: accuracy
value: 0.9333
name: Accuracy
- type: roc_auc
value: 0.9859
name: ROC AUC (Macro)
- type: loss
value: 1.0639
name: Validation Loss
base_model:
- facebook/wav2vec2-base-960h
---
# Musical Instrument Classification Model
This model is a fine-tuned version of [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) for musical instrument classification. It can identify 9 different musical instruments from audio recordings with high accuracy.
## Model Description
- **Model type:** Audio Classification
- **Base model:** facebook/wav2vec2-base-960h
- **Language:** Audio (no specific language)
- **License:** MIT
- **Fine-tuned on:** Custom musical instrument dataset (200 samples for each class)
## Performance
The model achieves excellent performance on the evaluation set after 5 epochs of training:
- **Final Accuracy:** 93.33%
- **Final ROC AUC (Macro):** 98.59%
- **Final Validation Loss:** 1.064
- **Evaluation Runtime:** 14.18 seconds
- **Evaluation Speed:** 25.39 samples/second
### Training Progress
| Epoch | Training Loss | Validation Loss | ROC AUC | Accuracy |
|-------|---------------|-----------------|---------|----------|
| 1 | 1.9872 | 1.8875 | 0.9248 | 0.6639 |
| 2 | 1.8652 | 1.4793 | 0.9799 | 0.8000 |
| 3 | 1.3868 | 1.2311 | 0.9861 | 0.8194 |
| 4 | 1.3242 | 1.1121 | 0.9827 | 0.9250 |
| 5 | 1.1869 | 1.0639 | 0.9859 | 0.9333 |
## Supported Instruments
The model classifies the following 9 musical instruments; the checkpoint's own `id2label` mapping can be inspected as shown after the list:
1. **Acoustic Guitar**
2. **Bass Guitar**
3. **Drum Set**
4. **Electric Guitar**
5. **Flute**
6. **Hi-Hats**
7. **Keyboard**
8. **Trumpet**
9. **Violin**
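The exact integer-to-label mapping ships with the checkpoint, so you can confirm the ordering directly from the config (the printed mapping below is illustrative):
```python
from transformers import AutoConfig

# Inspect the label mapping stored in the checkpoint's config
config = AutoConfig.from_pretrained("Bhaveen/epoch_musical_instruments_identification_2")
print(config.id2label)  # e.g. {0: "Acoustic Guitar", 1: "Bass Guitar", ...}
```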
## Usage
### Quick Start with Pipeline
```python
from transformers import pipeline
import torchaudio

# Load the classification pipeline
classifier = pipeline(
    "audio-classification",
    model="Bhaveen/epoch_musical_instruments_identification_2",
)

# Load the audio, collapse to mono, resample to 16 kHz, keep 3 s (48,000 samples)
audio, rate = torchaudio.load("your_audio_file.wav")
audio = audio.mean(dim=0)
audio = torchaudio.transforms.Resample(rate, 16000)(audio)
audio = audio.numpy()[:48000]

# Classify the audio
result = classifier(audio)
print(result)
```
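The pipeline returns a list of `{"label": ..., "score": ...}` dictionaries sorted by score, so `result[0]["label"]` holds the top prediction.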
### Using Transformers Directly
```python
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import torchaudio
import torch

# Load model and feature extractor
model_name = "Bhaveen/epoch_musical_instruments_identification_2"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForAudioClassification.from_pretrained(model_name)

# Load the audio, collapse to mono, resample to 16 kHz, keep 3 s (48,000 samples)
audio, rate = torchaudio.load("your_audio_file.wav")
audio = audio.mean(dim=0)
audio = torchaudio.transforms.Resample(rate, 16000)(audio)
audio = audio.numpy()[:48000]

# Extract features and make a prediction
inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_class = torch.argmax(predictions, dim=-1)
print(f"Predicted instrument: {model.config.id2label[predicted_class.item()]}")
```
## Training Details
### Dataset and Preprocessing
- **Custom dataset** with audio recordings of 9 musical instruments
- **Train/Test Split:** 80/20 by file numbering (files numbered below 160 used for training)
- **Data Balancing:** Random oversampling applied to minority classes
- **Audio Preprocessing** (see the sketch after this list):
  - Resampling to 16,000 Hz
  - Fixed length of 48,000 samples (3 seconds)
  - Truncation of longer audio files
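The card does not include the exact training code, so the following is a minimal sketch of the preprocessing and balancing steps described above; `load_clip`, `file_paths`, and `labels` are hypothetical names:
```python
import numpy as np
import torchaudio
from imblearn.over_sampling import RandomOverSampler

TARGET_SR = 16_000
TARGET_LEN = 48_000  # 3 seconds at 16 kHz

def load_clip(path: str) -> np.ndarray:
    """Load a WAV file as a mono, 16 kHz array of exactly 3 seconds."""
    audio, rate = torchaudio.load(path)
    audio = audio.mean(dim=0)  # collapse stereo to mono
    if rate != TARGET_SR:
        audio = torchaudio.transforms.Resample(rate, TARGET_SR)(audio)
    audio = audio.numpy()[:TARGET_LEN]  # truncate longer clips
    if len(audio) < TARGET_LEN:         # zero-pad shorter clips
        audio = np.pad(audio, (0, TARGET_LEN - len(audio)))
    return audio

# Balance classes by oversampling file indices; file_paths and labels
# stand in for the real dataset listing.
file_paths = ["guitar_001.wav", "guitar_002.wav", "flute_001.wav"]
labels = [0, 0, 4]
indices = np.arange(len(file_paths)).reshape(-1, 1)
resampled_idx, resampled_labels = RandomOverSampler(random_state=0).fit_resample(indices, labels)
```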
### Training Configuration
```python
# Training hyperparameters
batch_size = 1                    # per-device batch size
gradient_accumulation_steps = 4   # effective batch size of 4
learning_rate = 5e-6
num_train_epochs = 5
warmup_steps = 50
weight_decay = 0.02
```
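For reference, a sketch of how these values might map onto `transformers.TrainingArguments`; the `output_dir` and per-epoch evaluation cadence are assumptions, not taken from the card:
```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="wav2vec2-instruments",  # hypothetical output path
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,      # effective batch size of 4
    learning_rate=5e-6,
    num_train_epochs=5,
    warmup_steps=50,
    weight_decay=0.02,
    evaluation_strategy="epoch",        # matches the per-epoch metrics table above
)
```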
### Model Architecture
- **Base Model:** facebook/wav2vec2-base-960h
- **Classification Head:** Added for 9-class classification
- **Parameters:** ~95M trainable parameters (a quick check follows below)
- **Features:** Wav2Vec2 audio representations with fine-tuned classification layer
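A quick way to check the parameter count yourself:
```python
from transformers import AutoModelForAudioClassification

model = AutoModelForAudioClassification.from_pretrained(
    "Bhaveen/epoch_musical_instruments_identification_2"
)
# Roughly 95M for the wav2vec2-base encoder plus the classification head
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```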
## Technical Specifications
- **Audio Format:** WAV files
- **Sample Rate:** 16,000 Hz
- **Input Length:** 3 seconds (48,000 samples)
- **Model Framework:** PyTorch + Transformers
- **Inference Device:** GPU recommended (CUDA); see the sketch below
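A minimal sketch of GPU inference, reusing `model` and `inputs` from the direct-usage example above:
```python
import torch

# Move the model and features to the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    outputs = model(**inputs)
```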
## Evaluation Metrics
The model is evaluated with the following metrics (a computation sketch follows the list):
- **Accuracy:** Standard classification accuracy
- **ROC AUC:** Macro-averaged ROC AUC with a one-vs-rest approach
- **Scoring:** Softmax probabilities over all 9 instrument classes, used for the ROC AUC computation
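The card does not include the training script, but a Trainer-style `compute_metrics` callback matching these metrics could look like this sketch:
```python
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Numerically stable softmax over the class dimension
    exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = exp / exp.sum(axis=-1, keepdims=True)
    preds = probs.argmax(axis=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "roc_auc": roc_auc_score(labels, probs, multi_class="ovr", average="macro"),
    }
```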
## Limitations and Considerations
1. **Audio Duration:** The model expects 3-second clips; longer audio is truncated, and shorter clips should be zero-padded to 48,000 samples and may classify less reliably
2. **Single Instrument Focus:** Optimized for single-instrument classification; mixtures of instruments may produce uncertain results
3. **Audio Quality:** Performance depends on audio quality and recording conditions
4. **Sample Rate:** Input must be resampled to 16kHz for optimal performance
5. **Domain Specificity:** Trained on specific instrument recordings, may not generalize to all variants or playing styles
## Training Environment
- **Platform:** Google Colab
- **GPU:** CUDA-enabled device
- **Libraries:**
- transformers==4.28.1
- torchaudio==0.12
- datasets
- evaluate
- imbalanced-learn (imported as `imblearn`)
## Model Files
The repository contains the following (see the download sketch after the list):
- Model weights and configuration
- Feature extractor configuration
- Training logs and metrics
- Label mappings (id2label, label2id)
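To fetch everything locally, `huggingface_hub`'s `snapshot_download` works (a sketch; the default cache location applies unless you pass `local_dir`):
```python
from huggingface_hub import snapshot_download

# Download the full repository: weights, configs, and label mappings
local_path = snapshot_download("Bhaveen/epoch_musical_instruments_identification_2")
print(local_path)
```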
---
*Model trained as part of a hackathon project*