{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "id": "uq9k8YYUKjnp" }, "outputs": [], "source": [ "import os\n", "import urllib.request\n", "import zipfile\n", "import json\n", "import pandas as pd\n", "import time\n", "import torch\n", "import numpy as np\n", "import pandas as pd\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader, TensorDataset\n", "from sklearn.model_selection import train_test_split\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "L5h3Tsa0LIoo" }, "outputs": [], "source": [ "def unzip_archive(filepath, dir_path):\n", " with zipfile.ZipFile(f\"{filepath}\", 'r') as zip_ref:\n", " zip_ref.extractall(dir_path)\n", "\n", "unzip_archive(os.getcwd() + '/data/raw/spotify_million_playlist_dataset.zip', os.getcwd() + '/data/raw/playlists')\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import shutil\n", "\n", "def make_dir(directory):\n", " if os.path.exists(directory):\n", " shutil.rmtree(directory)\n", " os.makedirs(directory)\n", " else:\n", " os.makedirs(directory)\n", " \n", "directory = os.getcwd() + '/data/raw/data'\n", "make_dir(directory)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "cols = [\n", " 'name',\n", " 'pid',\n", " 'num_followers',\n", " 'pos',\n", " 'artist_name',\n", " 'track_name',\n", " 'album_name'\n", "]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "qyCujIu8cDGg", "outputId": "0964ace3-2916-49e3-eebf-2e08e61d95d9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mpd.slice.188000-188999.json\t100/1000\t10.0%" ] } ], "source": [ "\n", "directory = os.getcwd() + '/data/raw/playlists/data'\n", "df = pd.DataFrame()\n", "index = 0\n", "# Loop through all files in the directory\n", "for filename in os.listdir(directory):\n", " # Check if the item is a file (not a subdirectory)\n", " if os.path.isfile(os.path.join(directory, filename)):\n", " if filename.find('.json') != -1 :\n", " index += 1\n", "\n", " # Print the filename or perform operations on the file\n", " print(f'\\r{filename}\\t{index}/1000\\t{((index/1000)*100):.1f}%', end='')\n", "\n", " # If you need the full file path, you can use:\n", " full_path = os.path.join(directory, filename)\n", "\n", " with open(full_path, 'r') as file:\n", " json_data = json.load(file)\n", "\n", " temp = pd.DataFrame(json_data['playlists'])\n", " expanded_df = temp.explode('tracks').reset_index(drop=True)\n", "\n", " # Normalize the JSON data\n", " json_normalized = pd.json_normalize(expanded_df['tracks'])\n", "\n", " # Concatenate the original DataFrame with the normalized JSON data\n", " result = pd.concat([expanded_df.drop(columns=['tracks']), json_normalized], axis=1)\n", " \n", " result = result[cols]\n", "\n", " df = pd.concat([df, result], axis=0, ignore_index=True)\n", " \n", " if index % 50 == 0:\n", " df.to_parquet(f'{os.getcwd()}/data/raw/data/playlists_{index % 1000}.parquet')\n", " del df\n", " df = pd.DataFrame()\n", " if index % 100 == 0:\n", " break" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import pyarrow.parquet as pq\n", "\n", "def read_parquet_folder(folder_path):\n", " dataframes = []\n", " for file in os.listdir(folder_path):\n", " if file.endswith('.parquet'):\n", " file_path = os.path.join(folder_path, file)\n", " df = pd.read_parquet(file_path)\n", " dataframes.append(df)\n", " \n", " return pd.concat(dataframes, ignore_index=True)\n", "\n", "folder_path = os.getcwd() + '/data/raw/data'\n", "df = read_parquet_folder(folder_path)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "directory = os.getcwd() + '/data/raw/mappings'\n", "make_dir(directory)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def create_ids(df, col, name):\n", " # Create a dictionary mapping unique values to IDs\n", " value_to_id = {val: i for i, val in enumerate(df[col].unique())}\n", "\n", " # Create a new column with the IDs\n", " df[f'{name}_id'] = df[col].map(value_to_id)\n", " df[[f'{name}_id', col]].drop_duplicates().to_csv(os.getcwd() + f'/data/raw/mappings/{name}.csv')\n", " # df = df.drop(col, axis=1)\n", " return df" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "df = create_ids(df, 'artist_name', 'artist')\n", "df = create_ids(df, 'pid', 'playlist')\n", "df = create_ids(df, 'track_name', 'song')\n", "df = create_ids(df, 'album_name', 'album')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "df['artist_count'] = df.groupby(['playlist_id','artist_id'])['song_id'].transform('nunique')\n", "df['album_count'] = df.groupby(['playlist_id','artist_id'])['album_id'].transform('nunique')\n", "df['song_count'] = df.groupby(['playlist_id','artist_id'])['song_id'].transform('count')" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "df['playlist_songs'] = df.groupby(['playlist_id'])['pos'].transform('max')\n", "df['playlist_songs'] += 1" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "df['artist_percent'] = df['artist_count'] / df['playlist_songs']\n", "df['song_percent'] = df['song_count'] / df['playlist_songs']\n", "df['album_percent'] = df['album_count'] / df['playlist_songs']" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
namepidnum_followersposartist_nametrack_namealbum_nameartist_idplaylist_idsong_idalbum_idartist_countalbum_countsong_countplaylist_songsartist_percentsong_percentalbum_percent
212throwbacks14300520R. KellyIgnition - RemixChocolate Factory10852031521111930.0051810.0051810.005181
213throwbacks14300521Backstreet BoysI Want It That WayOriginal Album Classics10952041531111930.0051810.0051810.005181
214throwbacks14300522*NSYNCBye Bye ByeNo Strings Attached11052051541111930.0051810.0051810.005181
215throwbacks14300523Fountains Of WayneStacy's MomWelcome Interstate Managers11152061551111930.0051810.0051810.005181
216throwbacks14300524Bowling For Soup1985A Hangover You Don't Deserve11252071561111930.0051810.0051810.005181
.........................................................
400throwbacks1430052188JoJoToo Little, Too Late - Radio VersionToo Little, Too Late19953902931111930.0051810.0051810.005181
401throwbacks1430052189Spice GirlsWannabe - Radio EditSpice20053912941111930.0051810.0051810.005181
402throwbacks1430052190MiMSThis Is Why I'm HotMusic Is My Savior20153922951111930.0051810.0051810.005181
403throwbacks1430052191RihannaDisturbiaGood Girl Gone Bad11553932963331930.0155440.0155440.015544
404throwbacks1430052192DEVBass Down LowThe Night The Sun Came Up17953942642121930.0103630.0103630.005181
\n", "

193 rows × 18 columns

\n", "
" ], "text/plain": [ " name pid num_followers pos artist_name \\\n", "212 throwbacks 143005 2 0 R. Kelly \n", "213 throwbacks 143005 2 1 Backstreet Boys \n", "214 throwbacks 143005 2 2 *NSYNC \n", "215 throwbacks 143005 2 3 Fountains Of Wayne \n", "216 throwbacks 143005 2 4 Bowling For Soup \n", ".. ... ... ... ... ... \n", "400 throwbacks 143005 2 188 JoJo \n", "401 throwbacks 143005 2 189 Spice Girls \n", "402 throwbacks 143005 2 190 MiMS \n", "403 throwbacks 143005 2 191 Rihanna \n", "404 throwbacks 143005 2 192 DEV \n", "\n", " track_name album_name \\\n", "212 Ignition - Remix Chocolate Factory \n", "213 I Want It That Way Original Album Classics \n", "214 Bye Bye Bye No Strings Attached \n", "215 Stacy's Mom Welcome Interstate Managers \n", "216 1985 A Hangover You Don't Deserve \n", ".. ... ... \n", "400 Too Little, Too Late - Radio Version Too Little, Too Late \n", "401 Wannabe - Radio Edit Spice \n", "402 This Is Why I'm Hot Music Is My Savior \n", "403 Disturbia Good Girl Gone Bad \n", "404 Bass Down Low The Night The Sun Came Up \n", "\n", " artist_id playlist_id song_id album_id artist_count album_count \\\n", "212 108 5 203 152 1 1 \n", "213 109 5 204 153 1 1 \n", "214 110 5 205 154 1 1 \n", "215 111 5 206 155 1 1 \n", "216 112 5 207 156 1 1 \n", ".. ... ... ... ... ... ... \n", "400 199 5 390 293 1 1 \n", "401 200 5 391 294 1 1 \n", "402 201 5 392 295 1 1 \n", "403 115 5 393 296 3 3 \n", "404 179 5 394 264 2 1 \n", "\n", " song_count playlist_songs artist_percent song_percent album_percent \n", "212 1 193 0.005181 0.005181 0.005181 \n", "213 1 193 0.005181 0.005181 0.005181 \n", "214 1 193 0.005181 0.005181 0.005181 \n", "215 1 193 0.005181 0.005181 0.005181 \n", "216 1 193 0.005181 0.005181 0.005181 \n", ".. ... ... ... ... ... \n", "400 1 193 0.005181 0.005181 0.005181 \n", "401 1 193 0.005181 0.005181 0.005181 \n", "402 1 193 0.005181 0.005181 0.005181 \n", "403 3 193 0.015544 0.015544 0.015544 \n", "404 2 193 0.010363 0.010363 0.005181 \n", "\n", "[193 rows x 18 columns]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df[df['playlist_id'] == 5]" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
playlist_idartist_idartist_percent
0000.571429
1000.571429
2000.571429
3000.571429
4000.571429
\n", "
" ], "text/plain": [ " playlist_id artist_id artist_percent\n", "0 0 0 0.571429\n", "1 0 0 0.571429\n", "2 0 0 0.571429\n", "3 0 0 0.571429\n", "4 0 0 0.571429" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "artists = df.loc[:,['playlist_id','artist_id','album_id','album_percent']]\n", "artists.head()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "X = artists.loc[:,['playlist_id','artist_id','album_id']]\n", "y = artists.loc[:,'album_percent']\n", "\n", "# Split our data into training and test sets\n", "X_train, X_val, y_train, y_val = train_test_split(X,y,random_state=0, test_size=0.2)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def prep_dataloaders(X_train,y_train,X_val,y_val,batch_size):\n", " # Convert training and test data to TensorDatasets\n", " trainset = TensorDataset(torch.from_numpy(np.array(X_train)).long(), \n", " torch.from_numpy(np.array(y_train)).float())\n", " valset = TensorDataset(torch.from_numpy(np.array(X_val)).long(), \n", " torch.from_numpy(np.array(y_val)).float())\n", "\n", " # Create Dataloaders for our training and test data to allow us to iterate over minibatches \n", " trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)\n", " valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False)\n", "\n", " return trainloader, valloader\n", "\n", "batchsize = 64\n", "trainloader,valloader = prep_dataloaders(X_train,y_train,X_val,y_val,batchsize)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "class NNColabFiltering(nn.Module):\n", " \n", " def __init__(self, n_playlists, n_artists, embedding_dim_users, embedding_dim_items, n_activations, rating_range):\n", " super().__init__()\n", " self.user_embeddings = nn.Embedding(num_embeddings=n_playlists,embedding_dim=embedding_dim_users)\n", " self.item_embeddings = nn.Embedding(num_embeddings=n_artists,embedding_dim=embedding_dim_items)\n", " self.fc1 = nn.Linear(embedding_dim_users+embedding_dim_items,n_activations)\n", " self.fc2 = nn.Linear(n_activations,1)\n", " self.rating_range = rating_range\n", "\n", " def forward(self, X):\n", " # Get embeddings for minibatch\n", " embedded_users = self.user_embeddings(X[:,0])\n", " embedded_items = self.item_embeddings(X[:,1])\n", " # Concatenate user and item embeddings\n", " embeddings = torch.cat([embedded_users,embedded_items],dim=1)\n", " # Pass embeddings through network\n", " preds = self.fc1(embeddings)\n", " preds = F.relu(preds)\n", " preds = self.fc2(preds)\n", " # Scale predicted ratings to target-range [low,high]\n", " preds = torch.sigmoid(preds) * (self.rating_range[1]-self.rating_range[0]) + self.rating_range[0]\n", " return preds" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "def train_model(model, criterion, optimizer, dataloaders, device, num_epochs=5, scheduler=None):\n", " model = model.to(device) # Send model to GPU if available\n", " since = time.time()\n", "\n", " costpaths = {'train':[],'val':[]}\n", "\n", " for epoch in range(num_epochs):\n", " print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n", " print('-' * 10)\n", "\n", " # Each epoch has a training and validation phase\n", " for phase in ['train', 'val']:\n", " if phase == 'train':\n", " model.train() # Set model to training mode\n", " else:\n", " model.eval() # Set model to evaluate mode\n", "\n", " running_loss = 0.0\n", "\n", " # Get the inputs and labels, and send to GPU if available\n", " index = 0\n", " for (inputs,labels) in dataloaders[phase]:\n", " inputs = inputs.to(device)\n", " labels = labels.to(device)\n", "\n", " # Zero the weight gradients\n", " optimizer.zero_grad()\n", "\n", " # Forward pass to get outputs and calculate loss\n", " # Track gradient only for training data\n", " with torch.set_grad_enabled(phase == 'train'):\n", " outputs = model.forward(inputs).view(-1)\n", " loss = criterion(outputs, labels)\n", "\n", " # Backpropagation to get the gradients with respect to each weight\n", " # Only if in train\n", " if phase == 'train':\n", " loss.backward()\n", " # Update the weights\n", " optimizer.step()\n", "\n", " # Convert loss into a scalar and add it to running_loss\n", " running_loss += np.sqrt(loss.item()) * labels.size(0)\n", " print(f'\\r{running_loss} {index} {index / len(dataloaders[phase])}', end='')\n", " index +=1\n", "\n", " # Step along learning rate scheduler when in train\n", " if (phase == 'train') and (scheduler is not None):\n", " scheduler.step()\n", "\n", " # Calculate and display average loss and accuracy for the epoch\n", " epoch_loss = running_loss / len(dataloaders[phase].dataset)\n", " costpaths[phase].append(epoch_loss)\n", " print('{} loss: {:.4f}'.format(phase, epoch_loss))\n", "\n", " time_elapsed = time.time() - since\n", " print('Training complete in {:.0f}m {:.0f}s'.format(\n", " time_elapsed // 60, time_elapsed % 60))\n", "\n", " return costpaths" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0/2\n", "----------\n", "910724978601.7391 123493 100.00%\n", "train loss: 115229.4395\n", "227700857865.127 30873 100.00%\n", "val loss: 115239.3512\n", "Epoch 1/2\n", "----------\n", "910727409277.4519 123493 100.00%\n", "train loss: 115229.7471\n", "227700857865.127 30873 100.00%\n", "val loss: 115239.3512\n", "Epoch 2/2\n", "----------\n", "910734475316.9005 123493 100.00%\n", "train loss: 115230.6411\n", "227700857865.127 30873 100.00%\n", "val loss: 115239.3512\n", "Training complete in 71m 54s\n" ] } ], "source": [ "dataloaders = {'train':trainloader, 'val':valloader}\n", "n_playlists = X.loc[:,'playlist_id'].max()+1\n", "n_artists = X.loc[:,'artist_id'].max()+1\n", "n_albums = X.loc[:,'album_id'].max()+1\n", "model = NNColabFiltering(\n", " n_playlists,\n", " n_artists,\n", " embedding_dim_users=50,\n", " embedding_dim_items=50,\n", " n_activations = 100,\n", " rating_range=[0.,n_albums]\n", ")\n", "criterion = nn.MSELoss()\n", "lr=0.001\n", "n_epochs=10\n", "wd=1e-3\n", "optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "costpaths = train_model(model,criterion,optimizer,dataloaders, device, n_epochs, scheduler=None)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the cost over training and validation sets\n", "fig,ax = plt.subplots(1,2,figsize=(15,5))\n", "for i,key in enumerate(costpaths.keys()):\n", " ax_sub=ax[i%3]\n", " ax_sub.plot(costpaths[key])\n", " ax_sub.set_title(key)\n", " ax_sub.set_xlabel('Epoch')\n", " ax_sub.set_ylabel('Loss')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "# Save the entire model\n", "torch.save(model, os.getcwd() + '/models/recommender.pt')" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "def generate_recommendations(artist_album, playlists, model, playlist_id, device, top_n=10, batch_size=1024):\n", " model.eval()\n", "\n", "\n", " all_movie_ids = torch.tensor(artist_album['artist_album_id'].values, dtype=torch.long, device=device)\n", " user_ids = torch.full((len(all_movie_ids),), playlist_id, dtype=torch.long, device=device)\n", "\n", " # Initialize tensor to store all predictions\n", " all_predictions = torch.zeros(len(all_movie_ids), device=device)\n", "\n", " # Generate predictions in batches\n", " with torch.no_grad():\n", " for i in range(0, len(all_movie_ids), batch_size):\n", " batch_user_ids = user_ids[i:i+batch_size]\n", " batch_movie_ids = all_movie_ids[i:i+batch_size]\n", "\n", " input_tensor = torch.stack([batch_user_ids, batch_movie_ids], dim=1)\n", " batch_predictions = model(input_tensor).squeeze()\n", " all_predictions[i:i+batch_size] = batch_predictions\n", "\n", " # Convert to numpy for easier handling\n", " predictions = all_predictions.cpu().numpy()\n", "\n", " albums_listened = set(playlists.loc[playlists['playlist_id'] == playlist_id, 'artist_album_id'].tolist())\n", "\n", " unlistened_mask = np.isin(artist_album['artist_album_id'].values, list(albums_listened), invert=True)\n", "\n", " # Get top N recommendations\n", " top_indices = np.argsort(predictions[unlistened_mask])[-top_n:][::-1]\n", " recs = artist_album['artist_album_id'].values[unlistened_mask][top_indices]\n", "\n", " recs_names = artist_album.loc[artist_album['artist_album_id'].isin(recs)]\n", " album, artist = recs_names['album_name'].values, recs_names['artist_name'].values\n", "\n", " return album.tolist(), artist.tolist() " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Precision: 5.0609978643478826e-06\n", "Recall: 5.0609978643478826e-06\n" ] } ], "source": [ "from torchmetrics import Precision, Recall\n", "\n", "precision = Precision(task=\"multiclass\", num_classes=num_classes).to(device) \n", "recall = Recall(task=\"multiclass\", num_classes=num_classes).to(device) \n", "\n", "\n", "model.eval()\n", "with torch.no_grad():\n", " for batch in dataloaders['val']:\n", " inputs, targets = batch\n", " inputs = inputs.to(device)\n", " targets = targets.to(device)\n", "\n", " outputs = model(inputs)\n", "\n", " # For binary classification\n", " preds = torch.argmax(outputs, dim=1)\n", "\n", " # Update metrics\n", " precision(preds, targets)\n", " recall(preds, targets)\n", "\n", "# Compute final metrics\n", "final_precision = precision.compute()\n", "final_recall = recall.compute()\n", "\n", "print(f\"Precision: {final_precision}\")\n", "print(f\"Recall: {final_recall}\")" ] } ], "metadata": { "colab": { "machine_shape": "hm", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 0 }