test script
This commit is contained in:
parent
3c83a57e44
commit
cf24a172a2
1 changed files with 189 additions and 0 deletions
189
test_inference.py
Normal file
189
test_inference.py
Normal file
|
|
@ -0,0 +1,189 @@
|
||||||
|
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from services.model_repository import TensorRTModelRepository
|
||||||
|
from services.stream_decoder import StreamDecoderFactory
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# COCO class names for YOLOv8
|
||||||
|
COCO_CLASSES = [
|
||||||
|
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
|
||||||
|
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
|
||||||
|
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
|
||||||
|
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
|
||||||
|
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
||||||
|
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
|
||||||
|
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
|
||||||
|
'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
|
||||||
|
'scissors', 'teddy bear', 'hair drier', 'toothbrush'
|
||||||
|
]
|
||||||
|
|
||||||
|
def postprocess(output, confidence_threshold=0.25, iou_threshold=0.45):
|
||||||
|
"""
|
||||||
|
Post-processes the output of a YOLOv8 model to extract bounding boxes, scores, and class IDs.
|
||||||
|
"""
|
||||||
|
# output shape: (batch_size, 84, 8400)
|
||||||
|
# 84 = 4 (bbox) + 80 (classes)
|
||||||
|
|
||||||
|
# Transpose the output to (batch_size, 8400, 84)
|
||||||
|
output = output.transpose(1, 2)
|
||||||
|
|
||||||
|
boxes = []
|
||||||
|
scores = []
|
||||||
|
class_ids = []
|
||||||
|
|
||||||
|
for detection in output[0]:
|
||||||
|
# First 4 values are bbox (cx, cy, w, h)
|
||||||
|
# The rest are class scores
|
||||||
|
|
||||||
|
class_scores = detection[4:]
|
||||||
|
max_score, max_class_id = torch.max(class_scores, 0)
|
||||||
|
|
||||||
|
if max_score > confidence_threshold:
|
||||||
|
|
||||||
|
cx, cy, w, h = detection[:4]
|
||||||
|
|
||||||
|
# Convert from center-width-height to x1-y1-x2-y2
|
||||||
|
x1 = cx - w / 2
|
||||||
|
y1 = cy - h / 2
|
||||||
|
x2 = cx + w / 2
|
||||||
|
y2 = cy + h / 2
|
||||||
|
|
||||||
|
boxes.append([x1.item(), y1.item(), x2.item(), y2.item()])
|
||||||
|
scores.append(max_score.item())
|
||||||
|
class_ids.append(max_class_id.item())
|
||||||
|
|
||||||
|
if not boxes:
|
||||||
|
return [], [], []
|
||||||
|
|
||||||
|
# Perform Non-Maximum Suppression (NMS)
|
||||||
|
# This is a simplified version. For production, use a library like torchvision.ops.nms
|
||||||
|
indices = []
|
||||||
|
boxes_np = np.array(boxes)
|
||||||
|
scores_np = np.array(scores)
|
||||||
|
|
||||||
|
order = scores_np.argsort()[::-1]
|
||||||
|
|
||||||
|
while order.size > 0:
|
||||||
|
i = order[0]
|
||||||
|
indices.append(i)
|
||||||
|
|
||||||
|
xx1 = np.maximum(boxes_np[i, 0], boxes_np[order[1:], 0])
|
||||||
|
yy1 = np.maximum(boxes_np[i, 1], boxes_np[order[1:], 1])
|
||||||
|
xx2 = np.minimum(boxes_np[i, 2], boxes_np[order[1:], 2])
|
||||||
|
yy2 = np.minimum(boxes_np[i, 3], boxes_np[order[1:], 3])
|
||||||
|
|
||||||
|
w = np.maximum(0.0, xx2 - xx1 + 1)
|
||||||
|
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||||
|
inter = w * h
|
||||||
|
|
||||||
|
ovr = inter / ((boxes_np[i, 2] - boxes_np[i, 0] + 1) * (boxes_np[i, 3] - boxes_np[i, 1] + 1) + \
|
||||||
|
(boxes_np[order[1:], 2] - boxes_np[order[1:], 0] + 1) * \
|
||||||
|
(boxes_np[order[1:], 3] - boxes_np[order[1:], 1] + 1) - inter)
|
||||||
|
|
||||||
|
inds = np.where(ovr <= iou_threshold)[0]
|
||||||
|
order = order[inds + 1]
|
||||||
|
|
||||||
|
final_boxes = [boxes[i] for i in indices]
|
||||||
|
final_scores = [scores[i] for i in indices]
|
||||||
|
final_class_ids = [class_ids[i] for i in indices]
|
||||||
|
|
||||||
|
return final_boxes, final_scores, final_class_ids
|
||||||
|
|
||||||
|
|
||||||
|
def test_rtsp_stream_with_inference():
|
||||||
|
"""
|
||||||
|
Decodes an RTSP stream and runs inference, printing bounding boxes and class names.
|
||||||
|
"""
|
||||||
|
load_dotenv()
|
||||||
|
rtsp_url = os.getenv("CAMERA_URL_1")
|
||||||
|
if not rtsp_url:
|
||||||
|
print("Error: CAMERA_URL_1 not found in .env file.")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("=" * 80)
|
||||||
|
print("RTSP Stream + TensorRT Inference")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
decoder_factory = StreamDecoderFactory(gpu_id=0)
|
||||||
|
model_repo = TensorRTModelRepository(gpu_id=0, default_num_contexts=1)
|
||||||
|
|
||||||
|
# Setup camera stream
|
||||||
|
decoder = decoder_factory.create_decoder(rtsp_url, buffer_size=1)
|
||||||
|
decoder.start()
|
||||||
|
|
||||||
|
# Load inference model
|
||||||
|
model_path = "models/yolov8n.trt"
|
||||||
|
try:
|
||||||
|
model_repo.load_model(
|
||||||
|
model_id="camera_main",
|
||||||
|
file_path=model_path
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading model: {e}")
|
||||||
|
print(f"Please ensure '{model_path}' exists.")
|
||||||
|
decoder.stop()
|
||||||
|
return
|
||||||
|
|
||||||
|
print("\nWaiting for stream to buffer frames...")
|
||||||
|
time.sleep(3)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
frame_gpu = decoder.get_latest_frame(rgb=True)
|
||||||
|
|
||||||
|
if frame_gpu is None:
|
||||||
|
time.sleep(0.1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Preprocess frame for YOLOv8
|
||||||
|
# Resize to 640x640, normalize, and add batch dimension
|
||||||
|
frame_float = frame_gpu.unsqueeze(0).float() # Convert to float here
|
||||||
|
frame_resized = torch.nn.functional.interpolate(
|
||||||
|
frame_float, size=(640, 640), mode='bilinear', align_corners=False
|
||||||
|
)
|
||||||
|
frame_normalized = frame_resized.float() / 255.0
|
||||||
|
|
||||||
|
# Run inference
|
||||||
|
try:
|
||||||
|
outputs = model_repo.infer(
|
||||||
|
model_id="camera_main",
|
||||||
|
inputs={"images": frame_normalized},
|
||||||
|
synchronize=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Post-process the output
|
||||||
|
output_tensor = outputs['output0']
|
||||||
|
boxes, scores, class_ids = postprocess(output_tensor)
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print(f"\n--- Frame at {time.time():.2f} ---")
|
||||||
|
if boxes:
|
||||||
|
for box, score, class_id in zip(boxes, scores, class_ids):
|
||||||
|
class_name = COCO_CLASSES[class_id]
|
||||||
|
print(
|
||||||
|
f" Detected: {class_name} "
|
||||||
|
f"(confidence: {score:.2f}) at "
|
||||||
|
f"bbox: [{box[0]:.0f}, {box[1]:.0f}, {box[2]:.0f}, {box[3]:.0f}]"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(" No objects detected.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Inference failed: {e}")
|
||||||
|
|
||||||
|
time.sleep(0.03) # ~30 FPS
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nStopping...")
|
||||||
|
finally:
|
||||||
|
# Cleanup
|
||||||
|
decoder.stop()
|
||||||
|
model_repo.unload_model("camera_main")
|
||||||
|
print("Stream and model unloaded.")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_rtsp_stream_with_inference()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue