|
|
import functools
import json
import os
import sys

import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoModel
|
|
|
|
|
def format_texts(texts):
    """Wrap each input string in the LionGuard classification prompt.

    Args:
        texts: Iterable of raw input strings.

    Returns:
        list[str]: One prompt-formatted string per input, in order.
    """
    template = "task: classification | query: {}"
    return [template.format(text) for text in texts]
|
|
|
|
|
@functools.cache
def _load_models():
    """Load and cache the classifier and embedding model.

    Loading happens once per process; subsequent ``infer`` calls reuse the
    same objects instead of re-downloading/re-initialising both models on
    every invocation (the original code reloaded both models per call).

    Returns:
        tuple: ``(classifier, embedder)`` — the LionGuard classification
        head and the sentence-embedding model it consumes.
    """
    # trust_remote_code executes code from the model repo — acceptable here
    # only because the repo is the intended, known source.
    classifier = AutoModel.from_pretrained(
        "govtech/lionguard-2-lite", trust_remote_code=True
    )
    embedder = SentenceTransformer("google/embeddinggemma-300m")
    return classifier, embedder


def infer(texts):
    """Run LionGuard moderation scoring over a batch of texts.

    Args:
        texts: Iterable of raw input strings to score.

    Returns:
        Mapping of category name -> per-text score array, as produced by
        the model's ``predict`` method.
    """
    classifier, embedder = _load_models()
    embeddings = embedder.encode(format_texts(texts))
    return classifier.predict(embeddings)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Expect a JSON array of strings as the first CLI argument; fall back
    # to built-in samples when the argument is missing or not valid JSON.
    try:
        batch_text = json.loads(sys.argv[1])
        print("Using provided input texts")
    except (json.JSONDecodeError, IndexError) as e:
        print(f"Error parsing input data: {e}")
        print("Falling back to default sample texts")
        batch_text = ["Eh you damn stupid lah!", "Have a nice day :)"]

    results = infer(batch_text)
    # results maps category -> per-text score sequence; print every
    # category's score for each input text.
    for i, text in enumerate(batch_text):
        print(f"Text: '{text}'")
        for category, scores in results.items():
            print(f"[Text {i + 1}] {category} score: {scores[i]:.4f}")
        print("---------------------------------------------")
|
|
|