test script
This commit is contained in:
parent
3c83a57e44
commit
cf24a172a2
1 changed files with 189 additions and 0 deletions
189
test_inference.py
Normal file
189
test_inference.py
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
|
||||
import time
|
||||
import torch
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from services.model_repository import TensorRTModelRepository
|
||||
from services.stream_decoder import StreamDecoderFactory
|
||||
import numpy as np
|
||||
|
||||
# COCO class names for YOLOv8
|
||||
COCO_CLASSES = [
|
||||
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
|
||||
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
|
||||
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
|
||||
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
|
||||
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
||||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
|
||||
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
|
||||
'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
|
||||
'scissors', 'teddy bear', 'hair drier', 'toothbrush'
|
||||
]
|
||||
|
||||
def postprocess(output, confidence_threshold=0.25, iou_threshold=0.45):
|
||||
"""
|
||||
Post-processes the output of a YOLOv8 model to extract bounding boxes, scores, and class IDs.
|
||||
"""
|
||||
# output shape: (batch_size, 84, 8400)
|
||||
# 84 = 4 (bbox) + 80 (classes)
|
||||
|
||||
# Transpose the output to (batch_size, 8400, 84)
|
||||
output = output.transpose(1, 2)
|
||||
|
||||
boxes = []
|
||||
scores = []
|
||||
class_ids = []
|
||||
|
||||
for detection in output[0]:
|
||||
# First 4 values are bbox (cx, cy, w, h)
|
||||
# The rest are class scores
|
||||
|
||||
class_scores = detection[4:]
|
||||
max_score, max_class_id = torch.max(class_scores, 0)
|
||||
|
||||
if max_score > confidence_threshold:
|
||||
|
||||
cx, cy, w, h = detection[:4]
|
||||
|
||||
# Convert from center-width-height to x1-y1-x2-y2
|
||||
x1 = cx - w / 2
|
||||
y1 = cy - h / 2
|
||||
x2 = cx + w / 2
|
||||
y2 = cy + h / 2
|
||||
|
||||
boxes.append([x1.item(), y1.item(), x2.item(), y2.item()])
|
||||
scores.append(max_score.item())
|
||||
class_ids.append(max_class_id.item())
|
||||
|
||||
if not boxes:
|
||||
return [], [], []
|
||||
|
||||
# Perform Non-Maximum Suppression (NMS)
|
||||
# This is a simplified version. For production, use a library like torchvision.ops.nms
|
||||
indices = []
|
||||
boxes_np = np.array(boxes)
|
||||
scores_np = np.array(scores)
|
||||
|
||||
order = scores_np.argsort()[::-1]
|
||||
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
indices.append(i)
|
||||
|
||||
xx1 = np.maximum(boxes_np[i, 0], boxes_np[order[1:], 0])
|
||||
yy1 = np.maximum(boxes_np[i, 1], boxes_np[order[1:], 1])
|
||||
xx2 = np.minimum(boxes_np[i, 2], boxes_np[order[1:], 2])
|
||||
yy2 = np.minimum(boxes_np[i, 3], boxes_np[order[1:], 3])
|
||||
|
||||
w = np.maximum(0.0, xx2 - xx1 + 1)
|
||||
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||
inter = w * h
|
||||
|
||||
ovr = inter / ((boxes_np[i, 2] - boxes_np[i, 0] + 1) * (boxes_np[i, 3] - boxes_np[i, 1] + 1) + \
|
||||
(boxes_np[order[1:], 2] - boxes_np[order[1:], 0] + 1) * \
|
||||
(boxes_np[order[1:], 3] - boxes_np[order[1:], 1] + 1) - inter)
|
||||
|
||||
inds = np.where(ovr <= iou_threshold)[0]
|
||||
order = order[inds + 1]
|
||||
|
||||
final_boxes = [boxes[i] for i in indices]
|
||||
final_scores = [scores[i] for i in indices]
|
||||
final_class_ids = [class_ids[i] for i in indices]
|
||||
|
||||
return final_boxes, final_scores, final_class_ids
|
||||
|
||||
|
||||
def test_rtsp_stream_with_inference():
|
||||
"""
|
||||
Decodes an RTSP stream and runs inference, printing bounding boxes and class names.
|
||||
"""
|
||||
load_dotenv()
|
||||
rtsp_url = os.getenv("CAMERA_URL_1")
|
||||
if not rtsp_url:
|
||||
print("Error: CAMERA_URL_1 not found in .env file.")
|
||||
return
|
||||
|
||||
print("=" * 80)
|
||||
print("RTSP Stream + TensorRT Inference")
|
||||
print("=" * 80)
|
||||
|
||||
# Initialize components
|
||||
decoder_factory = StreamDecoderFactory(gpu_id=0)
|
||||
model_repo = TensorRTModelRepository(gpu_id=0, default_num_contexts=1)
|
||||
|
||||
# Setup camera stream
|
||||
decoder = decoder_factory.create_decoder(rtsp_url, buffer_size=1)
|
||||
decoder.start()
|
||||
|
||||
# Load inference model
|
||||
model_path = "models/yolov8n.trt"
|
||||
try:
|
||||
model_repo.load_model(
|
||||
model_id="camera_main",
|
||||
file_path=model_path
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error loading model: {e}")
|
||||
print(f"Please ensure '{model_path}' exists.")
|
||||
decoder.stop()
|
||||
return
|
||||
|
||||
print("\nWaiting for stream to buffer frames...")
|
||||
time.sleep(3)
|
||||
|
||||
try:
|
||||
while True:
|
||||
frame_gpu = decoder.get_latest_frame(rgb=True)
|
||||
|
||||
if frame_gpu is None:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
# Preprocess frame for YOLOv8
|
||||
# Resize to 640x640, normalize, and add batch dimension
|
||||
frame_float = frame_gpu.unsqueeze(0).float() # Convert to float here
|
||||
frame_resized = torch.nn.functional.interpolate(
|
||||
frame_float, size=(640, 640), mode='bilinear', align_corners=False
|
||||
)
|
||||
frame_normalized = frame_resized.float() / 255.0
|
||||
|
||||
# Run inference
|
||||
try:
|
||||
outputs = model_repo.infer(
|
||||
model_id="camera_main",
|
||||
inputs={"images": frame_normalized},
|
||||
synchronize=True
|
||||
)
|
||||
|
||||
# Post-process the output
|
||||
output_tensor = outputs['output0']
|
||||
boxes, scores, class_ids = postprocess(output_tensor)
|
||||
|
||||
# Print results
|
||||
print(f"\n--- Frame at {time.time():.2f} ---")
|
||||
if boxes:
|
||||
for box, score, class_id in zip(boxes, scores, class_ids):
|
||||
class_name = COCO_CLASSES[class_id]
|
||||
print(
|
||||
f" Detected: {class_name} "
|
||||
f"(confidence: {score:.2f}) at "
|
||||
f"bbox: [{box[0]:.0f}, {box[1]:.0f}, {box[2]:.0f}, {box[3]:.0f}]"
|
||||
)
|
||||
else:
|
||||
print(" No objects detected.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Inference failed: {e}")
|
||||
|
||||
time.sleep(0.03) # ~30 FPS
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nStopping...")
|
||||
finally:
|
||||
# Cleanup
|
||||
decoder.stop()
|
||||
model_repo.unload_model("camera_main")
|
||||
print("Stream and model unloaded.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rtsp_stream_with_inference()
|
||||
Loading…
Add table
Add a link
Reference in a new issue