File size: 1,579 Bytes
4b7d0f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import json
import os
import sys
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoModel

def format_texts(texts):
    """Wrap each input string in the classification task prompt.

    The embedding model is prompt-sensitive: prefixing each text with
    "task: classification | query: " yields embeddings optimized for
    classifying texts against preset labels.
    """
    prompt_prefix = "task: classification | query: "
    return [prompt_prefix + text for text in texts]

def infer(texts):
    """Score *texts* with the LionGuard-2-lite moderation classifier.

    Embeds the prompt-formatted texts with EmbeddingGemma, then runs the
    LionGuard classifier head over those embeddings.

    Returns the classifier's predict() output — consumed downstream as a
    mapping from category name to per-text scores.
    """
    # Classifier head, pulled directly from the Hugging Face Hub.
    classifier = AutoModel.from_pretrained(
        "govtech/lionguard-2-lite", trust_remote_code=True
    )

    # Embedding backbone, downloaded from the 🤗 Hub.
    embedder = SentenceTransformer("google/embeddinggemma-300m")

    # NOTE: use encode() instead of encode_documents()
    vectors = embedder.encode(format_texts(texts))

    return classifier.predict(vectors)


if __name__ == "__main__":

    # Load the data: the first CLI argument should be a JSON-encoded list
    # of texts; on any parse problem we fall back to built-in samples.
    try:
        input_data = sys.argv[1]
        batch_text = json.loads(input_data)
        if not isinstance(batch_text, list):
            # Valid JSON that isn't a list (e.g. "5" or an object) would
            # crash infer() later — treat it as bad input here instead.
            raise ValueError("input must be a JSON array of texts")
        print("Using provided input texts")

    except (json.JSONDecodeError, IndexError, ValueError) as e:
        print(f"Error parsing input data: {e}")
        print("Falling back to default sample texts")

        batch_text = ["Eh you damn stupid lah!", "Have a nice day :)"]

    # Generate the scores and predictions; results maps each category
    # name to a sequence of per-text scores.
    results = infer(batch_text)
    for i, text in enumerate(batch_text):
        print(f"Text: '{text}'")
        for category in results:
            print(f"[Text {i+1}] {category} score: {results[category][i]:.4f}")
        print("---------------------------------------------")