# Uploaded by: mangsense
# Commit message: "Upload handler.py"
# Commit: b7062c0 (verified)
import os
import torch
from transformers import AutoTokenizer, T5ForSequenceClassification
from typing import Dict, List, Any
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for Java vulnerability detection.

    Loads a tokenizer and a ``T5ForSequenceClassification`` checkpoint (the
    original header describes it as a LoRA fine-tuned CodeT5 classifier) and
    exposes a preprocess -> inference -> postprocess pipeline via ``__call__``.
    """

    # Class index -> human-readable label; any other index is reported as
    # "LABEL_<index>".
    _LABEL_NAMES = {0: "safe", 1: "vulnerable"}

    # Token budget handed to the tokenizer (inputs are truncated/padded to it).
    _MAX_LENGTH = 512

    _BAD_FORMAT_MESSAGE = (
        "Invalid request format. Expected {'inputs': 'Java code here'} "
        "or {'code': 'Java code here'}"
    )

    def __init__(self, path="."):
        """Load the tokenizer and the classification model.

        Args:
            path (str): Directory (or hub identifier) passed to
                ``from_pretrained`` for both tokenizer and model.
        """
        print(f"๐Ÿš€ Loading Java Vulnerability Detection Model from {path}")

        # Prefer the GPU when one is visible to torch.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"๐Ÿ“ Device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Half precision on GPU, full precision on CPU.
        self.model = T5ForSequenceClassification.from_pretrained(
            path,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        )

        # Move to the target device and switch to evaluation mode.
        self.model.to(self.device)
        self.model.eval()
        print("โœ… Model loaded successfully!")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run the full pipeline on one request.

        Args:
            data (dict): Request payload carrying the Java source under the
                ``"inputs"`` key or the ``"code"`` key.

        Returns:
            list: A one-element list holding the prediction dictionary built
            by ``postprocess``.

        Raises:
            ValueError: If the request has an unsupported shape or no code.
        """
        model_inputs = self.preprocess(data)
        logits = self.inference(model_inputs)
        return self.postprocess(logits)

    @classmethod
    def _extract_code(cls, request: Any) -> str:
        """Pull a single Java source string out of the supported request shapes."""
        if not isinstance(request, (dict, list, str)):
            raise ValueError(cls._BAD_FORMAT_MESSAGE)

        payload = request
        if isinstance(payload, list):
            if not payload:
                raise ValueError(cls._BAD_FORMAT_MESSAGE)
            # Only the first element of a list request is scored.
            payload = payload[0]

        if isinstance(payload, dict):
            payload = payload.get("inputs") or payload.get("code")
            if isinstance(payload, list):
                # A list under "inputs"/"code": score its first element rather
                # than formatting the list's repr into the prompt.
                payload = payload[0] if payload else None

        if payload is None:
            raise ValueError("No code provided in request")
        if not isinstance(payload, str):
            raise ValueError(cls._BAD_FORMAT_MESSAGE)
        if not payload.strip():
            raise ValueError("No code provided in request")
        return payload

    def preprocess(self, request: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Turn a request into tokenized tensors on the handler's device.

        Accepted shapes: a dict with ``"inputs"`` or ``"code"``, a non-empty
        list whose first element is such a dict or a plain string, or a plain
        string.

        Args:
            request (dict): API request payload.

        Returns:
            dict: Tokenizer output tensors moved to ``self.device``.

        Raises:
            ValueError: If the shape is unsupported, the code is missing,
                blank, or not a string.
        """
        code = self._extract_code(request)

        # Prompt template wrapped around the submitted code.
        input_text = f"Is this Java code vulnerable?:\n{code}"

        encoded = self.tokenizer(
            input_text,
            max_length=self._MAX_LENGTH,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        return {name: tensor.to(self.device) for name, tensor in encoded.items()}

    def inference(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the model without gradient tracking.

        Args:
            inputs (dict): Tensors produced by ``preprocess``.

        Returns:
            torch.Tensor: The ``logits`` attribute of the model output.
        """
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.logits

    def postprocess(self, logits: torch.Tensor) -> List[Dict[str, Any]]:
        """Convert logits into a readable prediction.

        Only the first row of ``logits`` is reported. A trailing dimension of
        size 1 is treated as a single sigmoid output; anything else goes
        through softmax.

        Args:
            logits (torch.Tensor): Model output logits (1-D or batched).

        Returns:
            list: One dictionary with ``label``, ``score``, ``probabilities``
            and a ``details`` sub-dictionary.
        """
        # Compute probabilities in float32: the model may emit float16 on GPU,
        # and half-precision softmax/sigmoid loses digits in the reported score.
        scores = logits.detach().float()
        if scores.dim() == 0:
            first_row = scores.reshape(1)
        else:
            # Flatten any leading dimensions so 1-D and batched logits both work.
            first_row = scores.reshape(-1, scores.shape[-1])[0]

        if first_row.shape[-1] == 1:
            # Binary classification with a single output unit.
            prob = torch.sigmoid(first_row[0]).item()
            predicted_class = 1 if prob > 0.5 else 0
            confidence = prob if predicted_class == 1 else (1 - prob)
            probabilities = {
                "LABEL_0": 1 - prob,
                "LABEL_1": prob
            }
        else:
            # Multi-class classification.
            class_probs = torch.softmax(first_row, dim=-1).tolist()
            predicted_class = int(torch.argmax(first_row, dim=-1).item())
            confidence = class_probs[predicted_class]
            probabilities = {
                f"LABEL_{index}": value
                for index, value in enumerate(class_probs)
            }

        result = {
            "label": self._LABEL_NAMES.get(
                predicted_class, f"LABEL_{predicted_class}"
            ),
            "score": confidence,
            "probabilities": probabilities,
            "details": {
                "is_vulnerable": predicted_class == 1,
                "confidence_percentage": f"{confidence * 100:.2f}%",
                "safe_probability": probabilities.get("LABEL_0", 0),
                "vulnerable_probability": probabilities.get("LABEL_1", 0)
            }
        }
        return [result]
# Local test code (runs only when this file is executed directly).
if __name__ == "__main__":
    # Manual smoke test: load the handler from the current directory and
    # score one Java snippet that builds an SQL query by string concatenation.
    sample_java = """
import java.sql.*;
public class SQLInjectionVulnerable {
public void getUser(String userInput) {
String query = "SELECT * FROM users WHERE username = '" + userInput + "'";
Statement statement = connection.createStatement();
ResultSet resultSet = statement.executeQuery(query);
}
}
"""
    endpoint = EndpointHandler(path=".")

    # Send the snippet through the same entry point the endpoint calls.
    prediction = endpoint({"inputs": sample_java})

    print("\n๐Ÿ“Š Test Result:")
    print(prediction)