You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
yudao/yolo_service.py

153 lines
5.0 KiB
Python

import json
import socket
import threading
import time
from ultralytics import YOLO
import argparse
import logging
from pathlib import Path
# Configure logging: INFO level, each record prefixed with timestamp and level.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the functions and methods in this file.
logger = logging.getLogger(__name__)
class YOLOService:
    """Cache YOLO models by path and run predict / train / validate jobs.

    Each job method catches every exception and returns a plain ``dict``
    with a ``'status'`` key (``'success'`` or ``'error'``) and a
    ``'message'`` key instead of raising. Extra values placed in the
    response are converted to strings so the dict can be passed to
    ``json.dumps``.
    """

    def __init__(self):
        # Maps model path -> loaded YOLO instance, so each path is loaded once.
        self.models = {}
        # Initialised to None; not read by any method of this class.
        self.default_model = None

    def load_model(self, model_path):
        """Return the model for ``model_path``, loading and caching it on first use."""
        if model_path not in self.models:
            logger.info(f"Loading model: {model_path}")
            self.models[model_path] = YOLO(model_path)
        return self.models[model_path]

    def predict(self, model_path, input_path, output_path=None, **kwargs):
        """Run prediction on ``input_path`` and optionally save each result.

        Args:
            model_path: Path (cache key) of the model to use.
            input_path: Source passed straight to ``model.predict``.
            output_path: Directory to write ``<index>.jpg`` files into. It is
                created if missing. A falsy value skips saving.
            **kwargs: Forwarded unchanged to ``model.predict``.

        Returns:
            A status dict; on success it also carries ``'output_path'``
            exactly as it was passed in.
        """
        try:
            model = self.load_model(model_path)
            results = model.predict(input_path, **kwargs)
            if output_path:
                # Build paths with pathlib so a trailing '/' on output_path
                # does not produce 'dir//0.jpg', and make sure the target
                # directory exists before anything is written into it.
                out_dir = Path(output_path)
                out_dir.mkdir(parents=True, exist_ok=True)
                for index, result in enumerate(results):
                    result.save(str(out_dir / f"{index}.jpg"))
            return {
                'status': 'success',
                'message': f'Prediction completed. Results saved to {output_path}' if output_path else 'Prediction completed',
                'output_path': output_path
            }
        except Exception as e:
            # Keep the traceback on the server; the caller only gets str(e).
            logger.exception("Prediction failed")
            return {
                'status': 'error',
                'message': str(e)
            }

    def train(self, model_path, data_path, epochs=50, **kwargs):
        """Train the model on ``data_path`` for ``epochs`` epochs.

        Args:
            model_path: Path (cache key) of the model to train.
            data_path: Passed to ``model.train`` as ``data``.
            epochs: Number of training epochs.
            **kwargs: Forwarded unchanged to ``model.train``.

        Returns:
            A status dict; on success ``'results_path'`` is the string form of
            the training result's ``save_dir``, or ``None`` when the result
            has no such attribute.
        """
        try:
            model = self.load_model(model_path)
            results = model.train(data=data_path, epochs=epochs, **kwargs)
            # save_dir may be a Path object, which json.dumps cannot encode,
            # and the training call may return None; normalise to str/None.
            save_dir = getattr(results, 'save_dir', None)
            return {
                'status': 'success',
                'message': 'Training completed',
                'results_path': str(save_dir) if save_dir is not None else None
            }
        except Exception as e:
            logger.exception("Training failed")
            return {
                'status': 'error',
                'message': str(e)
            }

    def validate(self, model_path, data_path=None, **kwargs):
        """Validate the model, optionally against the dataset at ``data_path``.

        Args:
            model_path: Path (cache key) of the model to validate.
            data_path: Passed to the model's validation call as ``data``.
            **kwargs: Forwarded unchanged to the validation call.

        Returns:
            A status dict; on success ``'metrics'`` is ``str()`` of the
            validation result.
        """
        try:
            model = self.load_model(model_path)
            # Ultralytics models expose validation as ``val()``; calling
            # ``validate()`` on the model raises AttributeError.
            results = model.val(data=data_path, **kwargs)
            return {
                'status': 'success',
                'message': 'Validation completed',
                'metrics': str(results)
            }
        except Exception as e:
            logger.exception("Validation failed")
            return {
                'status': 'error',
                'message': str(e)
            }
# Returned by _read_json_request when the peer closed without sending any bytes.
_NO_DATA = object()

# Commands a client may request; each maps to the service method of the same name.
_COMMANDS = ('predict', 'train', 'validate')


def _read_json_request(client_socket, chunk_size=4096, max_bytes=1 << 20,
                       followup_timeout=2.0):
    """Receive one JSON document from ``client_socket`` and return it parsed.

    A stream socket may deliver a request in several pieces, so bytes are
    accumulated and re-parsed after every ``recv`` until they form valid
    JSON. The first ``recv`` blocks with the socket's existing timeout; once
    a partial request has arrived, later reads wait at most
    ``followup_timeout`` seconds so a malformed request cannot stall the
    handler indefinitely.

    Returns:
        The decoded JSON value, or ``_NO_DATA`` if the peer closed the
        connection before sending anything.

    Raises:
        ValueError: the request exceeds ``max_bytes``, or the received bytes
            never became valid UTF-8 JSON (the last decode error is re-raised).
    """
    buffer = bytearray()
    last_error = None
    previous_timeout = client_socket.gettimeout()
    try:
        while True:
            try:
                chunk = client_socket.recv(chunk_size)
            except socket.timeout:
                chunk = b''  # peer went quiet mid-request; stop waiting
            if not chunk:
                break
            buffer += chunk
            if len(buffer) > max_bytes:
                raise ValueError(f'Request exceeds {max_bytes} bytes')
            try:
                return json.loads(buffer.decode('utf-8'))
            except ValueError as exc:
                # JSONDecodeError and UnicodeDecodeError both derive from
                # ValueError; either may just mean "more bytes to come".
                last_error = exc
                client_socket.settimeout(followup_timeout)
    finally:
        # Restore the original mode so sending the response is unaffected.
        client_socket.settimeout(previous_timeout)
    if last_error is not None:
        raise last_error
    return _NO_DATA


def handle_client(client_socket, yolo_service):
    """Serve a single JSON request on ``client_socket`` and close the socket.

    The request must be a JSON object of the form
    ``{"command": <name>, "params": {...}}`` where ``command`` is one of
    ``predict``, ``train`` or ``validate``; ``params`` is expanded as keyword
    arguments to the method of that name on ``yolo_service``. The method's
    return value is sent back JSON-encoded. Any failure is reported to the
    client as ``{"status": "error", "message": ...}``. If the peer closes
    without sending data, nothing is sent.

    Args:
        client_socket: Connected stream socket; always closed on return.
        yolo_service: Object providing ``predict``, ``train`` and ``validate``.
    """
    try:
        request = _read_json_request(client_socket)
        if request is _NO_DATA:
            return
        if not isinstance(request, dict):
            raise ValueError('Request must be a JSON object')

        command = request.get('command')
        params = request.get('params')
        if params is None:
            params = {}
        elif not isinstance(params, dict):
            raise ValueError("'params' must be a JSON object")

        if command in _COMMANDS:
            response = getattr(yolo_service, command)(**params)
        else:
            response = {'status': 'error', 'message': f'Unknown command: {command}'}
        # sendall keeps writing until the whole payload is out; send() may
        # transmit only a prefix of a large response.
        client_socket.sendall(json.dumps(response).encode('utf-8'))
    except Exception as e:
        logger.exception("Failed to handle client request")
        error_response = {'status': 'error', 'message': str(e)}
        try:
            client_socket.sendall(json.dumps(error_response).encode('utf-8'))
        except OSError:
            # The peer is gone; there is nobody left to report the error to.
            logger.warning("Could not deliver error response to client")
    finally:
        client_socket.close()
def main():
    """Parse command-line options, optionally preload a model, and serve clients.

    Listens on ``--host``/``--port`` (defaults ``localhost:9999``) and hands
    every accepted connection to ``handle_client`` on its own daemon thread.
    The accept loop runs until interrupted with Ctrl-C; the listening socket
    is closed on the way out.
    """
    cli = argparse.ArgumentParser(description='YOLO Service')
    cli.add_argument('--host', default='localhost', help='Host to bind to')
    cli.add_argument('--port', type=int, default=9999, help='Port to bind to')
    cli.add_argument('--model', help='Default model to preload')
    args = cli.parse_args()

    service = YOLOService()

    # Preload the default model, but only when one was given and the path exists.
    if args.model and Path(args.model).exists():
        service.load_model(args.model)
        logger.info(f"Preloaded model: {args.model}")

    # Start the socket server.
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    bind_address = (args.host, args.port)
    listener.bind(bind_address)
    listener.listen(5)
    logger.info(f"YOLO Service started on {args.host}:{args.port}")

    try:
        while True:
            connection, addr = listener.accept()
            logger.info(f"Connection from {addr}")
            # One daemon thread per client.
            worker = threading.Thread(
                target=handle_client,
                args=(connection, service),
                daemon=True,
            )
            worker.start()
    except KeyboardInterrupt:
        logger.info("Shutting down server...")
    finally:
        listener.close()
# Start the service only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()