You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
wheeldetect/3part/Cyclops/include/TrainResult.h

124 lines
2.8 KiB
C++

/*!
* \file TrainResult.h
* \date 2019/12/03
*
* \author Lin, Chi
* Contact: lin.chi@hzleaper.com
*
*
* \note
*/
#ifndef __TrainResult_h_
#define __TrainResult_h_
#include "CVUtils.h"
#include "CyclopsLock.h"
#include "StdUtils.h"
#include <atomic>
/*!
 * \brief Outcome of a single training run.
 *
 * Results are ranked by testAccuracy first; testConfidence is only used to
 * break exact accuracy ties. Exact floating-point equality is used on purpose:
 * two results are "the same good" only when both scores match exactly.
 */
struct TrainResult
{
	int id = -1;                 // identifier; -1 when unset
	double testAccuracy = -1;    // -FLT_MAX marks an invalid result (see setInvalid)
	double testConfidence = -1;
	double time = 0;             // NOTE(review): unit not visible in this header — confirm with producer
	std::string trainLog;

	//! Mark this result invalid (accuracy -FLT_MAX, confidence -1), so it
	//! ranks below any result whose accuracy is greater than -FLT_MAX.
	void setInvalid() {
		testAccuracy = -FLT_MAX; testConfidence = -1;
	}
	//! True when testAccuracy holds the -FLT_MAX "invalid" marker.
	bool isInvalid() const { return testAccuracy == -FLT_MAX; }

	//! Strictly better than tr: higher accuracy, or equal accuracy and higher confidence.
	bool isBetter(const TrainResult& tr) const {
		return isBetter(tr.testAccuracy, tr.testConfidence);
	}
	//! Strictly better than (testAcc, testConf), using the same ordering as above.
	bool isBetter(double testAcc, double testConf) const {
		return testAccuracy > testAcc ||
			(testAccuracy == testAcc && testConfidence > testConf);
	}
	//! Exactly equal accuracy and confidence.
	bool isSameGood(const TrainResult& tr) const {
		return isSameGood(tr.testAccuracy, tr.testConfidence);
	}
	//! Exactly equal to (testAcc, testConf).
	bool isSameGood(double testAcc, double testConf) const {
		return testAccuracy == testAcc && testConfidence == testConf;
	}
	//! isBetter(tr) || isSameGood(tr).
	bool isBetterOrSameGood(const TrainResult& tr) const {
		return isBetterOrSameGood(tr.testAccuracy, tr.testConfidence);
	}
	//! isBetter(testAcc, testConf) || isSameGood(testAcc, testConf).
	bool isBetterOrSameGood(double testAcc, double testConf) const {
		// Strict '>' on accuracy: with '>=' an equal accuracy would win
		// regardless of confidence, making the tie-break clause dead code.
		return testAccuracy > testAcc ||
			(testAccuracy == testAcc && testConfidence >= testConf);
	}
};
/*!
 * \brief Append-only collection of TrainResult entries produced by an
 *        asynchronous training job.
 *
 * Readers take CyclopsSharedLockGuard and writers take CyclopsLockGuard on
 * mLock. mLock and mTotal are not declared in this header — presumably they
 * come from AsyncResult (TODO confirm). Only LLClassifier (friend) can add
 * results or record best indices.
 */
class TrainResults : public AsyncResult
{
public:
	typedef std::shared_ptr<TrainResults> Ptr;

	TrainResults(AsyncFunc func) : AsyncResult(func), mLastValidIndex(-1)
	{}

	//! Number of results added so far.
	virtual int count() {
		CyclopsSharedLockGuard sharedGuard(&mLock);
		return static_cast<int>(mResults.size());
	}

	//! Copy of the result at idx, or a default-constructed TrainResult when
	//! idx is out of range.
	TrainResult get(int idx) {
		CyclopsSharedLockGuard sharedGuard(&mLock);
		// Sign is checked first, so the size comparison stays in int and
		// avoids a signed/unsigned mix.
		if (idx < 0 || idx >= static_cast<int>(mResults.size())) return TrainResult();
		return mResults[idx];
	}

	//! Copy of the result indices recorded via addBest() for id; empty if none.
	std::list<int> getBest(int id) {
		CyclopsSharedLockGuard sharedGuard(&mLock);
		auto it = mBestIndex.find(id);
		if (it != mBestIndex.end()) return it->second;
		else return std::list<int>();
	}

	//! Index of the most recently added valid result, or -1 if none yet.
	//! Read without taking mLock because mLastValidIndex is atomic.
	int getLastValid() {
		return mLastValidIndex;
	}

protected:
	friend class LLClassifier;

	//! Store the expected total in mTotal and pre-allocate room for num results.
	void setTotal(int num) {
		CyclopsLockGuard writeGuard(&mLock);
		mTotal = num;
		// reserve() takes an unsigned size: a negative num would wrap to a
		// huge value and throw std::length_error.
		if (num > 0)
			mResults.reserve(num);
	}

	//! Append r; if r is valid, remember its index as the last valid one.
	void add(const TrainResult& r) {
		CyclopsLockGuard writeGuard(&mLock);
		mResults.push_back(r);
		if (!r.isInvalid())
			mLastValidIndex = static_cast<int>(mResults.size()) - 1;
	}

	//! Record idx as one of the best result indices for id.
	void addBest(int id, int idx) {
		CyclopsLockGuard writeGuard(&mLock);
		mBestIndex[id].push_back(idx);
	}

	//! Forget all best indices recorded for id (keeps an empty entry for id).
	void clearBest(int id) {
		CyclopsLockGuard writeGuard(&mLock);
		mBestIndex[id].clear();
	}

private:
	std::vector<TrainResult> mResults; // actual finished results
	std::atomic_int mLastValidIndex;   // -1 until the first valid result is added
	std::map<int, std::list<int> > mBestIndex; // id -> indices into mResults
};
#endif // TrainResult_h_