/*! * \file TrainResult.h * \date 2019/12/03 * * \author Lin, Chi * Contact: lin.chi@hzleaper.com * * * \note */ #ifndef __TrainResult_h_ #define __TrainResult_h_ #include "CVUtils.h" #include "CyclopsLock.h" #include "StdUtils.h" #include struct TrainResult { int id = -1; double testAccuracy = -1; double testConfidence = -1; double time = 0; std::string trainLog; inline void setInvalid() { testAccuracy = -FLT_MAX; testConfidence = -1; } inline bool isInvalid() const { return testAccuracy == -FLT_MAX; } inline bool isBetter(const TrainResult& tr) { return testAccuracy > tr.testAccuracy || (testAccuracy == tr.testAccuracy && testConfidence > tr.testConfidence); } inline bool isBetter(double testAcc, double testConf) { return testAccuracy > testAcc || (testAccuracy == testAcc && testConfidence > testConf); } inline bool isSameGood(const TrainResult& tr) { return testAccuracy == tr.testAccuracy && testConfidence == tr.testConfidence; } inline bool isSameGood(double testAcc, double testConf) { return testAccuracy == testAcc && testConfidence == testConf; } inline bool isBetterOrSameGood(const TrainResult& tr) { return testAccuracy >= tr.testAccuracy || (testAccuracy == tr.testAccuracy && testConfidence >= tr.testConfidence); } inline bool isBetterOrSameGood(double testAcc, double testConf) { return testAccuracy >= testAcc || (testAccuracy == testAcc && testConfidence >= testConf); } }; class TrainResults : public AsyncResult { public: typedef std::shared_ptr Ptr; TrainResults(AsyncFunc func) : AsyncResult(func), mLastValidIndex(-1) {} virtual int count() { CyclopsSharedLockGuard shared_gaurd(&mLock); return mResults.size(); } TrainResult get(int idx) { CyclopsSharedLockGuard shared_gaurd(&mLock); if (idx < 0 || idx >= mResults.size()) return TrainResult(); return mResults[idx]; } std::list getBest(int id) { CyclopsSharedLockGuard shared_gaurd(&mLock); auto it = mBestIndex.find(id); if (it != mBestIndex.end()) return it->second; else return std::list(); } int getLastValid() { return mLastValidIndex; } protected: friend class LLClassifier; void setTotal(int num) { CyclopsLockGuard write_gaurd(&mLock); mTotal = num; mResults.reserve(num); } void add(const TrainResult& r) { CyclopsLockGuard write_gaurd(&mLock); mResults.push_back(r); if (!r.isInvalid()) mLastValidIndex = mResults.size() - 1; } void addBest(int id, int idx) { CyclopsLockGuard write_gaurd(&mLock); mBestIndex[id].push_back(idx); } void clearBest(int id) { CyclopsLockGuard write_gaurd(&mLock); mBestIndex[id].clear(); } private: std::vector mResults; // actual finished results std::atomic_int mLastValidIndex; std::map > mBestIndex; }; #endif // TrainResult_h_