| import argparse |
| import logging |
| import os |
| import wandb |
| import gradio as gr |
|
|
| import zipfile |
| import pickle |
| from pathlib import Path |
| from typing import List, Any, Dict |
| from PIL import Image |
| from pathlib import Path |
|
|
| from transformers import AutoTokenizer |
| from sentence_transformers import SentenceTransformer, util |
| from multilingual_clip import pt_multilingual_clip |
| import torch |
|
|
| from pathlib import Path |
| from typing import Callable, Dict, List, Tuple |
| from PIL.Image import Image |
|
|
print(__file__)

# Force CPU inference; must be set before anything triggers CUDA initialization.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

logging.basicConfig(level=logging.INFO)
DEFAULT_APPLICATION_NAME = "FashGen"

APP_DIR = Path(__file__).resolve().parent
README = APP_DIR / "README.md"

DEFAULT_PORT = 11700

# Locations where the wandb artifacts are materialized on local disk.
EMBEDDINGS_DIR = "artifacts/img-embeddings"
EMBEDDINGS_FILE = os.path.join(EMBEDDINGS_DIR, "embeddings.pkl")
RAW_PHOTOS_DIR = "artifacts/raw-photos"

# SECURITY FIX: an API key was previously hard-coded here; never commit
# credentials to source. wandb.login() picks up the WANDB_API_KEY environment
# variable (or ~/.netrc) instead.
wandb.login()
api = wandb.Api()
artifact_embeddings = api.artifact("ryparmar/fashion-aggregator/unimoda-images:v1")
artifact_embeddings.download(EMBEDDINGS_DIR)
artifact_raw_photos = api.artifact("ryparmar/fashion-aggregator/unimoda-raw-images:v1")
artifact_raw_photos.download("artifacts")

# The raw-photos artifact ships as a single zip; unpack it into RAW_PHOTOS_DIR.
with zipfile.ZipFile("artifacts/unimoda.zip", 'r') as zip_ref:
    zip_ref.extractall(RAW_PHOTOS_DIR)
|
|
|
|
class TextEncoder:
    """Embeds text queries with a multilingual CLIP text model."""

    def __init__(self, model_path="M-CLIP/XLM-Roberta-Large-Vit-B-32"):
        # Load the multilingual CLIP text tower together with its tokenizer.
        self.model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    @torch.no_grad()
    def encode(self, query: str) -> torch.Tensor:
        """Return the embedding tensor for a single text query."""
        return self.model.forward([query], self.tokenizer)
|
|
|
|
class ImageEnoder:
    """Embeds a PIL image with a CLIP vision model.

    NOTE(review): the class name is misspelled ("Enoder") but is kept
    unchanged so existing callers keep working.
    """

    def __init__(self, model_path="clip-ViT-B-32"):
        # Sentence-transformers CLIP checkpoint handles image inputs directly.
        self.model = SentenceTransformer(model_path)

    @torch.no_grad()
    def encode(self, image: Image) -> torch.Tensor:
        """Return the embedding tensor for a single image."""
        return self.model.encode([image], convert_to_tensor=True, show_progress_bar=False)
|
|
|
|
class Retriever:
    """Retrieves relevant images for a given text query.

    Loads precomputed image embeddings from a pickle file and ranks them
    against text-query embeddings with cosine similarity.
    """

    def __init__(self, image_embeddings_path=None):
        """Load encoders and the precomputed image-embedding index.

        Args:
            image_embeddings_path: path to a pickle holding a
                ``(image_names, image_embeddings)`` tuple. Required.

        Raises:
            ValueError: if ``image_embeddings_path`` is not provided.
        """
        if image_embeddings_path is None:
            # Fail early with a clear message instead of the opaque
            # TypeError that open(None) would raise.
            raise ValueError("image_embeddings_path is required")

        self.text_encoder = TextEncoder()
        self.image_encoder = ImageEnoder()

        with open(image_embeddings_path, "rb") as file:
            self.image_names, self.image_embeddings = pickle.load(file)
        # Embeddings were generated with repo-relative paths; strip the prefix
        # so names can later be joined with RAW_PHOTOS_DIR.
        self.image_names = [
            img_name.replace("fashion-aggregator/fashion_aggregator/data/photos/", "")
            for img_name in self.image_names
        ]
        print("Images:", len(self.image_names))

    @torch.no_grad()
    def predict(self, text_query: str, k: int = 10) -> List[Any]:
        """Return the top-k hits (dicts with 'corpus_id' and 'score') for a query."""
        query_emb = self.text_encoder.encode(text_query)
        relevant_images = util.semantic_search(query_emb, self.image_embeddings, top_k=k)[0]
        return relevant_images

    @torch.no_grad()
    def search_images(self, text_query: str, k: int = 6) -> Dict[str, List[Any]]:
        """Return file paths and similarity scores of the top-k images for a query."""
        images = self.predict(text_query, k)
        paths_and_scores = {"path": [], "score": []}
        for img in images:
            paths_and_scores["path"].append(os.path.join(RAW_PHOTOS_DIR, self.image_names[img["corpus_id"]]))
            paths_and_scores["score"].append(img["score"])
        return paths_and_scores
|
|
|
|
def main(args):
    """Build the predictor backend and serve the gradio frontend."""
    predictor = PredictorBackend(url=args.model_url)
    frontend = make_frontend(predictor.run, flagging=args.flagging, gantry=args.gantry, app_name=args.application)
    frontend.launch(
        server_port=args.port,  # fix: --port was parsed but never applied
    )
|
|
|
|
def make_frontend(
    fn: Callable[[Image], str], flagging: bool = False, gantry: bool = False, app_name: str = "fashion-aggregator"
):
    """Creates a gradio.Interface frontend for text to image search function.

    Args:
        fn: callable mapping a text query to a gallery of relevant images.
        flagging: when True, users may manually flag model outputs.
        gantry: accepted for CLI compatibility; Gantry logging is not wired
            up here yet (TODO).
        app_name: application name, reserved for the future Gantry integration.
    """
    # Fix: ``flagging`` was previously ignored and flagging was always
    # disabled (while flagging_options was still passed); honour the flag.
    allow_flagging = "manual" if flagging else "never"

    frontend = gr.Interface(
        fn=fn,
        outputs=gr.Gallery(label="Relevant Items"),
        inputs=gr.components.Textbox(label="Item Description"),
        title="FashGen",
        description=__doc__,  # NOTE(review): module has no docstring, so this is None
        cache_examples=False,
        allow_flagging=allow_flagging,
        flagging_options=["incorrect", "offensive", "other"],
    )
    return frontend
|
|
|
|
class PredictorBackend:
    """Interface to a backend that serves predictions.

    To communicate with a backend accessible via a URL, provide the url kwarg.
    Otherwise, runs a predictor locally.
    """

    def __init__(self, url=None):
        if url is not None:
            # Remote mode: only ``_predict`` is wired up; ``run`` still needs
            # ``_search_images``, which has no remote implementation yet.
            self.url = url
            self._predict = self._predict_from_endpoint
        else:
            model = Retriever(image_embeddings_path=EMBEDDINGS_FILE)
            self._predict = model.predict
            self._search_images = model.search_images

    def run(self, text: str):
        """Search for ``text``, log metrics, and return matching image paths."""
        pred, metrics = self._predict_with_metrics(text)
        self._log_inference(pred, metrics)
        return pred

    def _predict_with_metrics(self, text: str) -> Tuple[List[str], Dict[str, float]]:
        """Search images and compute summary metrics for logging."""
        paths_and_scores = self._search_images(text)
        scores = paths_and_scores["score"]
        # Guard against an empty result set to avoid ZeroDivisionError.
        mean_score = sum(scores) / len(scores) if scores else 0.0
        metrics = {"mean_score": mean_score}
        return paths_and_scores["path"], metrics

    def _predict_from_endpoint(self, text: str):
        # Fix: this method was referenced in __init__ but never defined, so
        # passing ``url`` crashed with AttributeError at construction time.
        # Remote inference is still unimplemented; fail with a clear error.
        raise NotImplementedError(f"Remote prediction via {self.url} is not implemented yet")

    def _log_inference(self, pred, metrics):
        # Emit one METRIC line per metric, then the prediction payload.
        for key, value in metrics.items():
            logging.info(f"METRIC {key} {value}")
        logging.info(f"PRED >begin\n{pred}\nPRED >end")
|
|
|
|
def _make_parser():
    """Construct the CLI argument parser for the app entry point."""
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument(
        "--model_url",
        type=str,
        default=None,
        help="Identifies a URL to which to send image data. Data is base64-encoded, converted to a utf-8 string, and then set via a POST request as JSON with the key 'image'. Default is None, which instead sends the data to a model running locally.",
    )
    p.add_argument(
        "--port",
        type=int,
        default=DEFAULT_PORT,
        help=f"Port on which to expose this server. Default is {DEFAULT_PORT}.",
    )
    p.add_argument(
        "--flagging",
        action="store_true",
        help="Pass this flag to allow users to 'flag' model behavior and provide feedback.",
    )
    p.add_argument(
        "--gantry",
        action="store_true",
        help="Pass --flagging and this flag to log user feedback to Gantry. Requires GANTRY_API_KEY to be defined as an environment variable.",
    )
    p.add_argument(
        "--application",
        type=str,
        default=DEFAULT_APPLICATION_NAME,
        help=f"Name of the Gantry application to which feedback should be logged, if --gantry and --flagging are passed. Default is {DEFAULT_APPLICATION_NAME}.",
    )
    return p
|
|
|
|
if __name__ == "__main__":
    # Parse CLI options and hand off to the application entry point.
    cli_args = _make_parser().parse_args()
    main(cli_args)