You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
153 lines
5.0 KiB
Python
153 lines
5.0 KiB
Python
|
3 months ago
|
import json
|
||
|
|
import socket
|
||
|
|
import threading
|
||
|
|
import time
|
||
|
|
from ultralytics import YOLO
|
||
|
|
import argparse
|
||
|
|
import logging
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
# Configure root logging for the service: INFO level, timestamped messages.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
class YOLOService:
    """Run Ultralytics YOLO predict / train / validate jobs for the socket server.

    A single instance is shared by every client-handler thread. The Ultralytics
    thread-safe inference guide advises against using one model object from
    several threads at once, so the model cache and each cached model are
    guarded by locks.

    Every public method catches exceptions and returns a JSON-serializable dict
    whose ``'status'`` is ``'success'`` or ``'error'``.
    """

    def __init__(self):
        self.models = {}  # model_path -> loaded YOLO model (cache)
        self.default_model = None  # optional default model path; not assigned by this class
        self._cache_lock = threading.Lock()  # guards self.models and self._model_locks
        self._model_locks = {}  # model_path -> lock serializing calls on that model

    def load_model(self, model_path):
        """Return the cached model for *model_path*, loading it on first use.

        The check-and-insert runs under a lock so concurrent requests for the
        same path do not load the model twice.
        """
        with self._cache_lock:
            if model_path not in self.models:
                logger.info(f"Loading model: {model_path}")
                self.models[model_path] = YOLO(model_path)
            return self.models[model_path]

    def _model_lock(self, model_path):
        """Return the lock that serializes use of the cached *model_path* model."""
        with self._cache_lock:
            return self._model_locks.setdefault(model_path, threading.Lock())

    @staticmethod
    def _save_path_for(output_path, index):
        """Return the file path for the *index*-th saved result inside *output_path*."""
        # Path joining handles a trailing '/' correctly (no doubled separator).
        return Path(output_path) / f"{index}.jpg"

    @staticmethod
    def _path_str(value):
        """Convert a path-like value to ``str`` for JSON encoding; keep ``None``."""
        return None if value is None else str(value)

    def predict(self, model_path, input_path, output_path=None, **kwargs):
        """Run inference with the model at *model_path* on *input_path*.

        Args:
            model_path: model file/name passed to ``YOLO``; cached across calls.
            input_path: source passed to ``model.predict``.
            output_path: optional directory. When given, it is created if
                missing and each result is saved as ``<output_path>/<i>.jpg``.
            **kwargs: forwarded to ``model.predict``.

        Returns:
            dict with ``status`` and ``message``; on success also ``output_path``.
        """
        try:
            model = self.load_model(model_path)
            # Results are also consumed under the lock, because kwargs may request
            # a lazily evaluated (streamed) result.
            with self._model_lock(model_path):
                results = model.predict(input_path, **kwargs)

                if output_path:
                    Path(output_path).mkdir(parents=True, exist_ok=True)
                    for i, result in enumerate(results):
                        result.save(str(self._save_path_for(output_path, i)))

            return {
                'status': 'success',
                'message': f'Prediction completed. Results saved to {output_path}' if output_path else 'Prediction completed',
                'output_path': output_path
            }
        except Exception as e:
            logger.exception(f"Prediction failed for model {model_path}")
            return {
                'status': 'error',
                'message': str(e)
            }

    def train(self, model_path, data_path, epochs=50, **kwargs):
        """Train the model at *model_path* on the dataset config *data_path*.

        A fresh ``YOLO`` instance is used rather than the cached one. Ultralytics
        loads the trained weights back into the model object after ``train()``,
        which would otherwise silently replace the cached inference model for
        *model_path* and block predictions on it for the whole run.

        Args:
            model_path: model file/config passed to ``YOLO``.
            data_path: dataset config forwarded as ``data=``.
            epochs: number of training epochs.
            **kwargs: forwarded to ``model.train``.

        Returns:
            dict with ``status`` and ``message``; on success also
            ``results_path`` (a ``str``, or ``None`` if no directory is known).
        """
        try:
            logger.info(f"Loading model for training: {model_path}")
            model = YOLO(model_path)
            results = model.train(data=data_path, epochs=epochs, **kwargs)

            # train() may return None; fall back to the trainer's output directory.
            save_dir = getattr(results, 'save_dir', None)
            if save_dir is None:
                save_dir = getattr(getattr(model, 'trainer', None), 'save_dir', None)

            return {
                'status': 'success',
                'message': 'Training completed',
                # save_dir is a pathlib.Path, which json.dumps cannot encode.
                'results_path': self._path_str(save_dir)
            }
        except Exception as e:
            logger.exception(f"Training failed for model {model_path}")
            return {
                'status': 'error',
                'message': str(e)
            }

    def validate(self, model_path, data_path=None, **kwargs):
        """Evaluate the model at *model_path*, optionally on dataset *data_path*.

        Args:
            model_path: model file/name passed to ``YOLO``; cached across calls.
            data_path: optional dataset config. It is only forwarded when given,
                so ``None`` does not override the model's own dataset setting.
            **kwargs: forwarded to ``model.val``.

        Returns:
            dict with ``status`` and ``message``; on success also ``metrics``
            (the ``str()`` of the validation result).
        """
        try:
            model = self.load_model(model_path)
            if data_path is not None:
                kwargs['data'] = data_path
            with self._model_lock(model_path):
                # Ultralytics exposes validation as Model.val(); there is no validate().
                results = model.val(**kwargs)

            return {
                'status': 'success',
                'message': 'Validation completed',
                'metrics': str(results)
            }
        except Exception as e:
            logger.exception(f"Validation failed for model {model_path}")
            return {
                'status': 'error',
                'message': str(e)
            }
|
||
|
|
|
||
|
|
def _recv_request(client_socket, chunk_size=65536, max_bytes=1024 * 1024, idle_timeout=5.0):
    """Read one UTF-8 JSON request from *client_socket* and return its text.

    The wire protocol has no length prefix, so bytes are accumulated until they
    parse as a complete JSON document. A single ``recv`` is not enough: larger
    or segmented requests arrive in several pieces. After the first chunk,
    reads use *idle_timeout*, so a request that never completes gets an error
    reply instead of hanging the handler. The socket's previous timeout is
    restored before returning.

    Returns:
        The request text, or ``''`` if the peer closed without sending anything.

    Raises:
        ValueError: the request exceeds *max_bytes*, or is not valid UTF-8.
        socket.timeout: the very first read timed out (only possible if the
            socket already had a timeout configured).
    """
    buffer = bytearray()
    original_timeout = client_socket.gettimeout()
    try:
        while True:
            try:
                chunk = client_socket.recv(chunk_size)
            except socket.timeout:
                if not buffer:
                    raise
                break  # sender went quiet mid-request; the caller reports the parse error
            if not chunk:
                break  # peer closed (or shut down) its sending side
            buffer += chunk
            if len(buffer) > max_bytes:
                raise ValueError(f'Request exceeds {max_bytes} bytes')
            # A complete object/array ends with '}' or ']'; skip pointless parses otherwise.
            if buffer.rstrip().endswith((b'}', b']')):
                try:
                    text = buffer.decode('utf-8')
                    json.loads(text)
                    return text
                except ValueError:
                    pass  # still incomplete (e.g. '}' inside a string, or a split UTF-8 char)
            client_socket.settimeout(idle_timeout)
    finally:
        client_socket.settimeout(original_timeout)
    return buffer.decode('utf-8')


def handle_client(client_socket, yolo_service):
    """Serve a single JSON request on *client_socket*, then close it.

    The request is a JSON object ``{"command": ..., "params": {...}}``. For the
    commands ``predict``, ``train`` and ``validate``, ``params`` is passed as
    keyword arguments to the matching *yolo_service* method. That method's
    result dict is sent back as UTF-8 JSON. Any other command, or any failure,
    produces ``{"status": "error", "message": ...}``. If the peer sends nothing,
    no reply is sent.
    """
    try:
        data = _recv_request(client_socket)
        if not data:
            return

        request = json.loads(data)
        if not isinstance(request, dict):
            raise ValueError('Request must be a JSON object')
        command = request.get('command')
        params = request.get('params') or {}  # tolerate a missing or null "params"

        if command in ('predict', 'train', 'validate'):
            response = getattr(yolo_service, command)(**params)
        else:
            response = {'status': 'error', 'message': f'Unknown command: {command}'}

        # sendall: a plain send() may transmit only part of a large response.
        client_socket.sendall(json.dumps(response).encode('utf-8'))

    except Exception as e:
        logger.exception('Failed to handle client request')
        error_response = {'status': 'error', 'message': str(e)}
        try:
            client_socket.sendall(json.dumps(error_response).encode('utf-8'))
        except OSError:
            # The client is gone; nothing more can be reported to it.
            logger.warning('Could not send error response to client')
    finally:
        client_socket.close()
|
||
|
|
|
||
|
|
def main():
    """Parse CLI options, optionally preload a model, and serve clients until Ctrl+C.

    Each accepted connection is handed to ``handle_client`` on its own daemon
    thread, so a long-running job does not block ``accept()``.
    """
    parser = argparse.ArgumentParser(description='YOLO Service')
    parser.add_argument('--host', default='localhost', help='Host to bind to')
    parser.add_argument('--port', type=int, default=9999, help='Port to bind to')
    parser.add_argument('--model', help='Default model to preload')

    args = parser.parse_args()

    yolo_service = YOLOService()

    # Preload the default model so the first request does not pay the load cost.
    if args.model:
        if Path(args.model).exists():
            yolo_service.load_model(args.model)
            yolo_service.default_model = args.model
            logger.info(f"Preloaded model: {args.model}")
        else:
            logger.warning(f"Model path not found, skipping preload: {args.model}")

    # The context manager closes the listening socket even if bind()/listen() fails.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((args.host, args.port))
        server.listen(5)

        logger.info(f"YOLO Service started on {args.host}:{args.port}")

        try:
            while True:
                client_socket, addr = server.accept()
                logger.info(f"Connection from {addr}")

                # One daemon thread per client; daemon threads do not keep the process alive.
                client_thread = threading.Thread(
                    target=handle_client,
                    args=(client_socket, yolo_service),
                    daemon=True,
                )
                client_thread.start()
        except KeyboardInterrupt:
            logger.info("Shutting down server...")
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    main()
|