File size: 1,579 Bytes
4b7d0f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import json
import os
import sys
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoModel
def format_texts(texts):
    """Wrap each input text in the classification prompt template.

    The prompt instructs the embedding model to produce embeddings
    optimized for classifying texts against preset labels.
    """
    template = "task: classification | query: {}"
    return [template.format(text) for text in texts]
def infer(texts):
    """Score *texts* with LionGuard-2-lite over EmbeddingGemma embeddings.

    Downloads both models from the Hugging Face Hub on every call, embeds
    the prompted texts, and returns the classifier's prediction results
    (a mapping of category name -> per-text scores — confirm against the
    model card).
    """
    # Classifier head, loaded directly from the Hub (custom code enabled).
    classifier = AutoModel.from_pretrained("govtech/lionguard-2-lite", trust_remote_code=True)
    # Embedding backbone from the 🤗 Hub.
    encoder = SentenceTransformer("google/embeddinggemma-300m")
    prompted = format_texts(texts)
    # NOTE: use encode() instead of encode_documents()
    vectors = encoder.encode(prompted)
    # Run inference on the embeddings.
    return classifier.predict(vectors)
if __name__ == "__main__":
    # Load the data: first CLI argument is expected to be a JSON array of strings.
    try:
        input_data = sys.argv[1]
        batch_text = json.loads(input_data)
        print("Using provided input texts")
    # IndexError: no argument given; JSONDecodeError: argument is not valid JSON.
    except (json.JSONDecodeError, IndexError) as e:
        print(f"Error parsing input data: {e}")
        print("Falling back to default sample texts")
        batch_text = ["Eh you damn stupid lah!", "Have a nice day :)"]
    # Generate the scores and predictions
    results = infer(batch_text)
    # enumerate instead of range(len(...)); iterate items() instead of
    # subscripting via keys() — same output, idiomatic iteration.
    for i, text in enumerate(batch_text):
        print(f"Text: '{text}'")
        for category, scores in results.items():
            print(f"[Text {i+1}] {category} score: {scores[i]:.4f}")
        print("---------------------------------------------")
|