"""Small TCP service that exposes Ultralytics YOLO predict/train/validate.

Protocol: a client connects, sends one UTF-8 JSON object of the form
``{"command": "predict" | "train" | "validate", "params": {...}}`` and
receives one UTF-8 JSON object containing at least ``status`` and
``message``. The connection is closed after the response is sent.
"""

import argparse
import json
import logging
import socket
import threading
import time
from pathlib import Path

from ultralytics import YOLO

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Size of each individual recv() call.
_RECV_CHUNK_SIZE = 4096
# Upper bound on a single request so a misbehaving client cannot exhaust memory.
_MAX_REQUEST_BYTES = 1024 * 1024
# How long to wait for the remainder of a request whose first chunk(s) did not
# yet form valid JSON. The first recv() still blocks without a timeout.
_FOLLOWUP_RECV_TIMEOUT = 2.0
# Returned by _receive_request when the peer closed without sending anything.
_NO_REQUEST = object()


class YOLOService:
    """Load YOLO models on demand, cache them, and run commands against them.

    Each public command method returns a JSON-serialisable dict with a
    ``status`` key (``'success'`` or ``'error'``) and a ``message`` key;
    exceptions raised by the underlying model are converted to error dicts.

    NOTE(review): model_path / input_path / output_path come straight from the
    network client. Loading a model file presumably executes whatever the
    checkpoint contains, so this service should only be reachable by trusted
    clients (the default bind address is localhost).
    NOTE(review): a cached model instance is shared by all client threads;
    confirm against the Ultralytics thread-safety guidance whether concurrent
    calls on one instance are acceptable for your workload.
    """

    def __init__(self):
        self.models = {}  # Cache of loaded models, keyed by model path
        self.default_model = None
        # Guards self.models: handle_client runs in one thread per connection.
        self._models_lock = threading.Lock()

    def load_model(self, model_path):
        """Return the model for *model_path*, loading and caching it if needed."""
        with self._models_lock:
            model = self.models.get(model_path)
            if model is None:
                logger.info("Loading model: %s", model_path)
                model = YOLO(model_path)
                self.models[model_path] = model
            return model

    def predict(self, model_path, input_path, output_path=None, **kwargs):
        """Run prediction on *input_path* and optionally save rendered results.

        Args:
            model_path: Path or name of the model to use.
            input_path: Prediction source forwarded to ``model.predict``.
            output_path: Optional directory; result ``i`` is saved there as
                ``{i}.jpg``. The directory is created when missing.
            **kwargs: Extra keyword arguments forwarded to ``model.predict``.

        Returns:
            A dict with ``status``, ``message`` and (on success) ``output_path``.
        """
        try:
            model = self.load_model(model_path)
            results = model.predict(input_path, **kwargs)

            if output_path:
                # Path joining avoids the "dir//0.jpg" the string formatting
                # produced for a trailing slash; mkdir makes the save reliable.
                output_dir = Path(output_path)
                output_dir.mkdir(parents=True, exist_ok=True)
                for index, result in enumerate(results):
                    result.save(str(output_dir / f"{index}.jpg"))

            if output_path:
                message = f'Prediction completed. Results saved to {output_path}'
            else:
                message = 'Prediction completed'
            return {
                'status': 'success',
                'message': message,
                'output_path': output_path,
            }
        except Exception as e:
            logger.exception("Prediction failed")
            return {'status': 'error', 'message': str(e)}

    def train(self, model_path, data_path, epochs=50, **kwargs):
        """Train the model on the dataset described by *data_path*.

        Args:
            model_path: Path or name of the model to train.
            data_path: Dataset description forwarded as ``data=``.
            epochs: Number of training epochs.
            **kwargs: Extra keyword arguments forwarded to ``model.train``.

        Returns:
            A dict with ``status``, ``message`` and (on success)
            ``results_path``, the save directory as a string.
        """
        try:
            model = self.load_model(model_path)
            results = model.train(data=data_path, epochs=epochs, **kwargs)
            return {
                'status': 'success',
                'message': 'Training completed',
                # str() keeps the response JSON-serialisable when save_dir is
                # a pathlib.Path.
                'results_path': str(results.save_dir),
            }
        except Exception as e:
            logger.exception("Training failed")
            return {'status': 'error', 'message': str(e)}

    def validate(self, model_path, data_path=None, **kwargs):
        """Validate the model, optionally against the dataset in *data_path*.

        Args:
            model_path: Path or name of the model to validate.
            data_path: Optional dataset description forwarded as ``data=``.
            **kwargs: Extra keyword arguments forwarded to ``model.val``.

        Returns:
            A dict with ``status``, ``message`` and (on success) ``metrics``,
            the string form of the validation result.
        """
        try:
            model = self.load_model(model_path)
            val_kwargs = dict(kwargs)
            if data_path is not None:
                # Only forward ``data`` when supplied so that an explicit None
                # does not replace whatever dataset the model would use.
                val_kwargs['data'] = data_path
            # The Ultralytics model API names this method ``val``.
            results = model.val(**val_kwargs)
            return {
                'status': 'success',
                'message': 'Validation completed',
                'metrics': str(results),
            }
        except Exception as e:
            logger.exception("Validation failed")
            return {'status': 'error', 'message': str(e)}


def _receive_request(client_socket):
    """Read and decode one JSON request from *client_socket*.

    The protocol has no framing, so chunks are accumulated until they parse
    as JSON. After an incomplete chunk, further data is awaited for at most
    ``_FOLLOWUP_RECV_TIMEOUT`` seconds so that a genuinely malformed request
    still gets an error reply instead of blocking forever.

    Returns:
        The decoded JSON value, or ``_NO_REQUEST`` if the peer sent nothing.

    Raises:
        ValueError: If the data never forms valid UTF-8 JSON or exceeds
            ``_MAX_REQUEST_BYTES``.
    """
    buffer = bytearray()
    parse_error = None
    try:
        while True:
            try:
                chunk = client_socket.recv(_RECV_CHUNK_SIZE)
            except socket.timeout:
                break
            if not chunk:
                break
            buffer += chunk
            if len(buffer) > _MAX_REQUEST_BYTES:
                raise ValueError(f'Request exceeds {_MAX_REQUEST_BYTES} bytes')
            try:
                return json.loads(buffer.decode('utf-8'))
            except ValueError as exc:
                # JSONDecodeError and UnicodeDecodeError both derive from
                # ValueError; either may just mean "more data is coming".
                parse_error = exc
                client_socket.settimeout(_FOLLOWUP_RECV_TIMEOUT)
    finally:
        # Restore blocking mode so the later response send is not time-limited.
        client_socket.settimeout(None)

    if not buffer:
        return _NO_REQUEST
    raise parse_error


def _send_response(client_socket, response):
    """Serialise *response* as JSON and send all of it to the client."""
    client_socket.sendall(json.dumps(response).encode('utf-8'))


def handle_client(client_socket, yolo_service):
    """Serve a single client connection: read one request, send one response.

    Args:
        client_socket: Connected socket; always closed before returning.
        yolo_service: The ``YOLOService`` that executes the command.
    """
    try:
        request = _receive_request(client_socket)
        if request is _NO_REQUEST:
            return
        if not isinstance(request, dict):
            raise ValueError('Request must be a JSON object')

        command = request.get('command')
        handlers = {
            'predict': yolo_service.predict,
            'train': yolo_service.train,
            'validate': yolo_service.validate,
        }
        handler = handlers.get(command) if isinstance(command, str) else None

        if handler is None:
            response = {'status': 'error', 'message': f'Unknown command: {command}'}
        else:
            response = handler(**(request.get('params') or {}))

        _send_response(client_socket, response)
    except Exception as e:
        logger.exception("Failed to handle client request")
        try:
            _send_response(client_socket, {'status': 'error', 'message': str(e)})
        except OSError:
            # The peer is gone; nothing more can be reported to it.
            logger.warning("Could not deliver error response to client")
    finally:
        client_socket.close()


def main():
    """Parse command-line options and run the accept loop until interrupted."""
    parser = argparse.ArgumentParser(description='YOLO Service')
    parser.add_argument('--host', default='localhost', help='Host to bind to')
    parser.add_argument('--port', type=int, default=9999, help='Port to bind to')
    parser.add_argument('--model', help='Default model to preload')
    args = parser.parse_args()

    yolo_service = YOLOService()

    # Preload the default model
    if args.model:
        if Path(args.model).exists():
            yolo_service.load_model(args.model)
            logger.info("Preloaded model: %s", args.model)
        else:
            logger.warning("Model to preload not found, skipping: %s", args.model)

    # Start the socket server; the context manager closes it on exit.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((args.host, args.port))
        server.listen(5)

        logger.info("YOLO Service started on %s:%s", args.host, args.port)

        try:
            while True:
                client_socket, addr = server.accept()
                logger.info("Connection from %s", addr)

                # One daemon thread per client so shutdown is not blocked by
                # connections that are still being served.
                client_thread = threading.Thread(
                    target=handle_client,
                    args=(client_socket, yolo_service),
                )
                client_thread.daemon = True
                client_thread.start()
        except KeyboardInterrupt:
            logger.info("Shutting down server...")


if __name__ == '__main__':
    main()