python-rtsp-worker/test_inference.py
2025-11-09 01:07:16 +07:00

189 lines
6.5 KiB
Python

import time
import torch
import os
from dotenv import load_dotenv
from services.model_repository import TensorRTModelRepository
from services.stream_decoder import StreamDecoderFactory
import numpy as np
# The 80 COCO object-detection class names used with the YOLOv8 model.
# Position i holds the name for class id i, as returned by postprocess().
COCO_CLASSES = [
    # 0-9
    'person', 'bicycle', 'car', 'motorcycle', 'airplane',
    'bus', 'train', 'truck', 'boat', 'traffic light',
    # 10-19
    'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
    'cat', 'dog', 'horse', 'sheep', 'cow',
    # 20-29
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
    'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    # 30-39
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    # 40-49
    'wine glass', 'cup', 'fork', 'knife', 'spoon',
    'bowl', 'banana', 'apple', 'sandwich', 'orange',
    # 50-59
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
    'cake', 'chair', 'couch', 'potted plant', 'bed',
    # 60-69
    'dining table', 'toilet', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    # 70-79
    'toaster', 'sink', 'refrigerator', 'book', 'clock',
    'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush',
]
def _nms(boxes, scores, iou_threshold, class_ids=None):
    """
    Greedy non-maximum suppression over [x1, y1, x2, y2] boxes.

    Args:
        boxes: (N, 4) numpy array of corner coordinates.
        scores: (N,) numpy array of confidences.
        iou_threshold: A box is dropped when its IoU with an already-kept,
            higher-scoring box is greater than this value.
        class_ids: Optional (N,) numpy array. When given, only boxes with the
            same class id can suppress each other (class-aware NMS).

    Returns:
        List of kept row indices, in descending score order.
    """
    # Widths/areas use the legacy "+1" pixel convention of the original
    # implementation so results stay identical.
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        xx1 = np.maximum(x1[best], x1[rest])
        yy1 = np.maximum(y1[best], y1[rest])
        xx2 = np.minimum(x2[best], x2[rest])
        yy2 = np.minimum(y2[best], y2[rest])
        inter = np.maximum(0.0, xx2 - xx1 + 1) * np.maximum(0.0, yy2 - yy1 + 1)
        iou = inter / (areas[best] + areas[rest] - inter)
        survive = iou <= iou_threshold
        if class_ids is not None:
            # Boxes of a different class are never suppressed by this one.
            survive |= class_ids[rest] != class_ids[best]
        order = rest[survive]
    return keep


def postprocess(output, confidence_threshold=0.25, iou_threshold=0.45, agnostic=True):
    """
    Post-process raw YOLOv8 output into bounding boxes, scores and class IDs.

    Only the first image of the batch is decoded.

    Args:
        output: Tensor of shape (batch, 4 + num_classes, num_anchors), e.g.
            (1, 84, 8400) for an 80-class model. The first four channels are
            the box as (cx, cy, w, h); the remaining channels are class scores.
        confidence_threshold: An anchor is kept only if its best class score is
            strictly greater than this value.
        iou_threshold: IoU above which a lower-scoring box is suppressed by NMS.
        agnostic: If True (default, original behaviour), NMS ignores classes, so
            overlapping boxes of different classes suppress each other. If
            False, suppression happens only within the same class.

    Returns:
        (boxes, scores, class_ids) as Python lists sorted by descending score.
        Each box is [x1, y1, x2, y2] in the coordinate space of the model
        output (no rescaling to the source frame is done here). Returns
        ([], [], []) when no anchor passes the confidence threshold.
    """
    # (num_anchors, 4 + num_classes) for the first image. Everything below is
    # vectorized: the previous per-anchor loop issued a max() plus several
    # .item() device->host syncs for every one of the anchors.
    preds = output[0].transpose(0, 1)
    best_scores, best_ids = preds[:, 4:].max(dim=1)
    mask = best_scores > confidence_threshold
    if not bool(mask.any()):
        return [], [], []

    cand = preds[mask]
    cx, cy, w, h = cand[:, 0], cand[:, 1], cand[:, 2], cand[:, 3]
    # Convert from center-width-height to x1-y1-x2-y2.
    corners = torch.stack((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2), dim=1)

    boxes = corners.tolist()
    scores = best_scores[mask].tolist()
    class_ids = best_ids[mask].tolist()

    kept = _nms(
        np.array(boxes),
        np.array(scores),
        iou_threshold,
        None if agnostic else np.array(class_ids),
    )
    return (
        [boxes[i] for i in kept],
        [scores[i] for i in kept],
        [class_ids[i] for i in kept],
    )
def _preprocess_frame(frame_gpu, size=(640, 640)):
    """
    Convert one decoded frame into the normalized float batch fed to the model.

    Adds a batch dimension, bilinearly resizes the two trailing axes to
    ``size`` (plain resize, no letterboxing, so aspect ratio is not kept),
    and scales values by 1/255.
    """
    # NOTE(review): interpolate() resizes the last two axes, so this assumes
    # get_latest_frame(rgb=True) returns a channel-first (3, H, W) tensor with
    # 0-255 values — confirm; an (H, W, 3) frame would be resized incorrectly.
    batch = frame_gpu.unsqueeze(0).float()
    resized = torch.nn.functional.interpolate(
        batch, size=size, mode='bilinear', align_corners=False
    )
    return resized / 255.0


def _print_detections(boxes, scores, class_ids):
    """Print a timestamped header followed by one line per detection."""
    print(f"\n--- Frame at {time.time():.2f} ---")
    if not boxes:
        print(" No objects detected.")
        return
    # Boxes are printed exactly as postprocess() returns them, i.e. presumably
    # in the resized model-input space rather than the camera's resolution.
    for box, score, class_id in zip(boxes, scores, class_ids):
        class_name = COCO_CLASSES[class_id]
        print(
            f" Detected: {class_name} "
            f"(confidence: {score:.2f}) at "
            f"bbox: [{box[0]:.0f}, {box[1]:.0f}, {box[2]:.0f}, {box[3]:.0f}]"
        )


def test_rtsp_stream_with_inference():
    """
    Decode an RTSP stream and run YOLOv8 TensorRT inference on it.

    Reads the stream URL from ``CAMERA_URL_1`` (loaded via .env), loads
    ``models/yolov8n.trt`` as model "camera_main", and then repeatedly takes the
    latest frame, runs inference, and prints bounding boxes with class names
    until interrupted with Ctrl+C. The decoder is always stopped, and the model
    is unloaded if it was loaded, no matter where the interruption happens.
    """
    load_dotenv()
    rtsp_url = os.getenv("CAMERA_URL_1")
    if not rtsp_url:
        print("Error: CAMERA_URL_1 not found in .env file.")
        return

    print("=" * 80)
    print("RTSP Stream + TensorRT Inference")
    print("=" * 80)

    # Initialize components
    decoder_factory = StreamDecoderFactory(gpu_id=0)
    model_repo = TensorRTModelRepository(gpu_id=0, default_num_contexts=1)

    # Setup camera stream
    decoder = decoder_factory.create_decoder(rtsp_url, buffer_size=1)
    decoder.start()

    model_id = "camera_main"
    model_path = "models/yolov8n.trt"
    model_loaded = False
    # The cleanup scope starts right after decoder.start(): previously a Ctrl+C
    # during model loading or the initial buffering sleep skipped the cleanup
    # and left the decoder running and the model loaded.
    try:
        try:
            model_repo.load_model(model_id=model_id, file_path=model_path)
        except Exception as e:
            print(f"Error loading model: {e}")
            print(f"Please ensure '{model_path}' exists.")
            return
        model_loaded = True

        print("\nWaiting for stream to buffer frames...")
        time.sleep(3)

        while True:
            frame_gpu = decoder.get_latest_frame(rgb=True)
            if frame_gpu is None:
                time.sleep(0.1)
                continue

            frame_normalized = _preprocess_frame(frame_gpu)

            # Best-effort per frame: a failed inference is reported and the
            # loop moves on to the next frame.
            try:
                outputs = model_repo.infer(
                    model_id=model_id,
                    inputs={"images": frame_normalized},
                    synchronize=True
                )
                boxes, scores, class_ids = postprocess(outputs['output0'])
                _print_detections(boxes, scores, class_ids)
            except Exception as e:
                print(f"Inference failed: {e}")

            time.sleep(0.03)  # ~30 FPS
    except KeyboardInterrupt:
        print("\nStopping...")
    finally:
        # Cleanup
        decoder.stop()
        if model_loaded:
            model_repo.unload_model(model_id)
            print("Stream and model unloaded.")
if __name__ == "__main__":
    # Script entry point: run the live RTSP + TensorRT demo until Ctrl+C.
    test_rtsp_stream_with_inference()