import streamlit as st
import numpy as np
import PIL.Image  # import the submodule explicitly so PIL.Image.open works
import math

import torch
import torch.nn as nn
from torch.distributions import Categorical

from transformers import AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters used as defaults by the class definitions below. NOTE: these
# values are assumptions for illustration; they must match the configuration
# used when caption_model.pth was trained (the loaded checkpoint carries its
# own sizes, so they only affect class defaults, not the loaded model).
patch_size = 16
hidden_size = 192
num_heads = 8
num_layers = (3, 3)  # (encoder layers, decoder layers)
def extract_patches(image_tensor, patch_size=patch_size):
    # Get the dimensions of the image tensor
    bs, c, h, w = image_tensor.size()

    # Define the Unfold layer with appropriate parameters
    unfold = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)

    # Apply Unfold to the image tensor
    unfolded = unfold(image_tensor)

    # Reshape the unfolded tensor to the desired output shape
    # Output shape: BS x L x (C * patch_size^2), where L is the total number of patches
    unfolded = unfolded.transpose(1, 2).reshape(bs, -1, c * patch_size * patch_size)

    return unfolded
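
# Example (sketch with assumed shapes): for a batch of two 3-channel 128x128
# images and patch_size=16, each image yields (128/16)^2 = 64 patches of
# 3*16*16 = 768 flattened values:
#   x = torch.randn(2, 3, 128, 128)
#   p = extract_patches(x, patch_size=16)  # p.shape == (2, 64, 768)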

# Sinusoidal positional embeddings
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
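
# Example (sketch): embedding 4 sequence positions. Each position maps to a
# vector of sin/cos values at geometrically spaced frequencies, so nearby
# positions receive similar embeddings:
#   pos = torch.arange(4)
#   emb = SinusoidalPosEmb(hidden_size)(pos)  # emb.shape == (4, hidden_size)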

# Define a module for attention blocks
class AttentionBlock(nn.Module):
    def __init__(self, hidden_size=hidden_size, num_heads=num_heads, masking=True):
        super(AttentionBlock, self).__init__()
        self.masking = masking

        # Multi-head attention mechanism
        self.multihead_attn = nn.MultiheadAttention(hidden_size,
                                                    num_heads=num_heads,
                                                    batch_first=True,
                                                    dropout=0.0)

    def forward(self, x_in, kv_in, key_mask=None):
        # Apply causal masking if enabled
        if self.masking:
            bs, l, h = x_in.shape
            mask = torch.triu(torch.ones(l, l, device=x_in.device), 1).bool()
        else:
            mask = None

        # Perform multi-head attention operation
        return self.multihead_attn(x_in, kv_in, kv_in, attn_mask=mask,
                                   key_padding_mask=key_mask)[0]
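
# The causal mask is an upper-triangular boolean matrix: position i may only
# attend to positions <= i. For a length-3 sequence (illustrative):
#   torch.triu(torch.ones(3, 3), 1).bool()
#   # tensor([[False,  True,  True],
#   #         [False, False,  True],
#   #         [False, False, False]])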

# Define a module for a transformer block with self-attention
# and optional causal masking
class TransformerBlock(nn.Module):
    def __init__(self, hidden_size=hidden_size, num_heads=num_heads, decoder=False, masking=True):
        super(TransformerBlock, self).__init__()
        self.decoder = decoder

        # Layer normalization for the input
        self.norm1 = nn.LayerNorm(hidden_size)

        # Self-attention mechanism
        self.attn1 = AttentionBlock(hidden_size=hidden_size, num_heads=num_heads,
                                    masking=masking)

        if self.decoder:
            # Layer normalization for the output of the first attention layer
            self.norm2 = nn.LayerNorm(hidden_size)

            # Cross-attention mechanism for the decoder, with no masking
            self.attn2 = AttentionBlock(hidden_size=hidden_size,
                                        num_heads=num_heads, masking=False)

        # Layer normalization for the output before the MLP
        self.norm_mlp = nn.LayerNorm(hidden_size)

        # Multi-layer perceptron (MLP)
        self.mlp = nn.Sequential(nn.Linear(hidden_size, hidden_size * 4),
                                 nn.ELU(),
                                 nn.Linear(hidden_size * 4, hidden_size))

    def forward(self, x, input_key_mask=None, cross_key_mask=None, kv_cross=None):
        # Perform self-attention operation
        x = self.attn1(x, x, key_mask=input_key_mask) + x
        x = self.norm1(x)

        # If decoder, perform an additional cross-attention operation
        if self.decoder:
            x = self.attn2(x, kv_cross, key_mask=cross_key_mask) + x
            x = self.norm2(x)

        # Apply MLP and layer normalization
        x = self.mlp(x) + x
        return self.norm_mlp(x)
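
# Example (sketch): an encoder-style block maps a sequence to a sequence of the
# same shape. Note the post-norm design: LayerNorm follows each residual add.
#   blk = TransformerBlock(hidden_size, num_heads, decoder=False, masking=False)
#   y = blk(torch.randn(2, 10, hidden_size))  # y.shape == (2, 10, hidden_size)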

# Define a decoder module for the Transformer architecture
class Decoder(nn.Module):
    # num_layers defaults to the decoder entry of the (encoder, decoder) pair
    def __init__(self, num_emb, hidden_size=hidden_size, num_layers=num_layers[1], num_heads=num_heads):
        super(Decoder, self).__init__()

        # Create an embedding layer for tokens
        self.embedding = nn.Embedding(num_emb, hidden_size)

        # Initialize the embedding weights at a small scale
        self.embedding.weight.data = 0.001 * self.embedding.weight.data

        # Initialize sinusoidal positional embeddings
        self.pos_emb = SinusoidalPosEmb(hidden_size)

        # Create multiple transformer blocks as layers
        self.blocks = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads,
                             decoder=True) for _ in range(num_layers)
        ])

        # Define a linear layer for output prediction
        self.fc_out = nn.Linear(hidden_size, num_emb)

    def forward(self, input_seq, encoder_output, input_padding_mask=None,
                encoder_padding_mask=None):
        # Embed the input sequence
        input_embs = self.embedding(input_seq)
        bs, l, h = input_embs.shape

        # Add positional embeddings to the token embeddings
        seq_indx = torch.arange(l, device=input_seq.device)
        pos_emb = self.pos_emb(seq_indx).reshape(1, l, h).expand(bs, l, h)
        embs = input_embs + pos_emb

        # Pass the embeddings through each transformer block
        for block in self.blocks:
            embs = block(embs,
                         input_key_mask=input_padding_mask,
                         cross_key_mask=encoder_padding_mask,
                         kv_cross=encoder_output)

        return self.fc_out(embs)
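
# Example (sketch with assumed sizes): given a batch of token ids and encoder
# output, the decoder returns per-position logits over the vocabulary:
#   dec = Decoder(num_emb=30522)  # 30522 = distilbert-base-uncased vocab size
#   tokens = torch.randint(0, 30522, (2, 12))
#   enc_out = torch.randn(2, 64, hidden_size)
#   logits = dec(tokens, enc_out)  # logits.shape == (2, 12, 30522)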

# Define a Vision Encoder module for the Transformer architecture
class VisionEncoder(nn.Module):
    def __init__(self, image_size, channels_in, patch_size=patch_size, hidden_size=hidden_size,
                 num_layers=3, num_heads=num_heads):
        super(VisionEncoder, self).__init__()
        self.patch_size = patch_size

        # Linear projection from flattened patches to the hidden dimension
        self.fc_in = nn.Linear(channels_in * patch_size * patch_size, hidden_size)

        # Learned positional embedding, one per patch position
        seq_length = (image_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length,
                                                      hidden_size).normal_(std=0.02))

        # Create multiple transformer blocks as layers
        self.blocks = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads,
                             decoder=False, masking=False) for _ in range(num_layers)
        ])

    def forward(self, image):
        bs = image.shape[0]

        # Split the image into flattened patches and project them
        patch_seq = extract_patches(image, patch_size=self.patch_size)
        patch_emb = self.fc_in(patch_seq)

        # Add a unique positional embedding to each patch embedding
        embs = patch_emb + self.pos_embedding

        # Pass the embeddings through each transformer block
        for block in self.blocks:
            embs = block(embs)

        return embs
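
# Example (sketch): with image_size=128 and patch_size=16 the encoder sees
# (128/16)^2 = 64 patch tokens and returns one hidden vector per patch:
#   enc = VisionEncoder(image_size=128, channels_in=3)
#   feats = enc(torch.randn(2, 3, 128, 128))  # feats.shape == (2, 64, hidden_size)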

# Define a Vision Encoder-Decoder module for the Transformer architecture
class VisionEncoderDecoder(nn.Module):
    def __init__(self, image_size, channels_in, num_emb, patch_size=patch_size,
                 hidden_size=hidden_size, num_layers=num_layers, num_heads=num_heads):
        super(VisionEncoderDecoder, self).__init__()

        # Create an encoder and decoder with the specified parameters
        self.encoder = VisionEncoder(image_size=image_size, channels_in=channels_in,
                                     patch_size=patch_size, hidden_size=hidden_size,
                                     num_layers=num_layers[0], num_heads=num_heads)
        self.decoder = Decoder(num_emb=num_emb, hidden_size=hidden_size,
                               num_layers=num_layers[1], num_heads=num_heads)

    def forward(self, input_image, target_seq, padding_mask):
        # Generate the boolean padding mask for the target sequence
        # (True marks positions attention should ignore)
        bool_padding_mask = padding_mask == 0

        # Encode the input image
        encoded_seq = self.encoder(image=input_image)

        # Decode the target sequence conditioned on the encoded image
        decoded_seq = self.decoder(input_seq=target_seq,
                                   encoder_output=encoded_seq,
                                   input_padding_mask=bool_padding_mask)
        return decoded_seq
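
# Example (sketch of a training-style forward pass, assumed sizes):
#   ved = VisionEncoderDecoder(image_size=128, channels_in=3, num_emb=30522)
#   imgs = torch.randn(2, 3, 128, 128)
#   tgt = torch.randint(0, 30522, (2, 12))
#   pad = torch.ones(2, 12)  # 1 = real token, 0 = padding
#   logits = ved(imgs, tgt, pad)  # logits.shape == (2, 12, 30522)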

# Load the full pickled model; map_location keeps this working on CPU-only hosts
model = torch.load("caption_model.pth", map_location=device, weights_only=False)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def pred_transformer_caption(test_img):
    # Start generation from the tokenizer's [CLS] token (id 101 for
    # distilbert-base-uncased), which signals the network to start the caption
    sos_token = 101 * torch.ones(1, 1).long()

    # Set the temperature for sampling during generation
    temp = 0.5

    log_tokens = [sos_token]
    model.eval()
    with torch.no_grad():
        # Encode the input image (mixed precision only when CUDA is available)
        with torch.autocast(device_type=device.type, enabled=torch.cuda.is_available()):
            image_embedding = model.encoder(test_img.to(device))

        # Generate the caption tokens autoregressively
        for i in range(50):
            input_tokens = torch.cat(log_tokens, 1)

            # Decode the tokens generated so far into logits for the next token
            data_pred = model.decoder(input_tokens.to(device), image_embedding)

            # Sample from the temperature-scaled distribution over the next token
            dist = Categorical(logits=data_pred[:, -1] / temp)
            next_tokens = dist.sample().reshape(1, 1)

            # Append the sampled token to the sequence
            log_tokens.append(next_tokens.cpu())

            # Stop if the [SEP] token (id 102) is predicted, which marks the
            # end of the caption
            if next_tokens.item() == 102:
                break

    # Convert the list of token ids to a tensor and decode it into a string
    pred_text = torch.cat(log_tokens, 1)
    pred_text = tokenizer.decode(pred_text[0], skip_special_tokens=True)

    return pred_text
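
# Note: sampling with temp=0.5 yields slightly different captions on each run.
# For a deterministic caption, one could instead take the argmax at each step:
#   next_tokens = data_pred[:, -1].argmax(dim=-1).reshape(1, 1)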

## Dashboard
st.title("Caption App")

test_img = st.file_uploader(label="Upload the funny pic :)", type=["png", "jpg", "jpeg"])
caption = ""
if test_img:
    # Load the upload as an RGB image and resize it to the model's input size
    test_img = PIL.Image.open(test_img).convert("RGB")
    test_img = test_img.resize((128, 128))

    # Convert to a float array and min-max normalize to [0, 1]
    test_img = np.array(test_img).astype("float32")
    test_img = (test_img - np.amin(test_img)) / (np.amax(test_img) - np.amin(test_img))
    copy = test_img

    # HWC -> CHW and add a batch dimension: (1, 3, 128, 128)
    test_img = torch.from_numpy(test_img).permute(2, 0, 1).unsqueeze(0)

    caption = str(pred_transformer_caption(test_img))
    st.image(image=copy, caption=caption)