You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
124 lines
2.8 KiB
C++
124 lines
2.8 KiB
C++
/*!
 * \file TrainResult.h
 * \date 2019/12/03
 *
 * \author Lin, Chi
 * Contact: lin.chi@hzleaper.com
 *
 * \note
 */
|
|
|
|
#ifndef __TrainResult_h_
|
|
#define __TrainResult_h_
|
|
|
|
#include "CVUtils.h"
|
|
#include "CyclopsLock.h"
|
|
#include "StdUtils.h"
|
|
#include <atomic>
|
|
|
|
/*!
 * \brief Outcome of a single training run.
 *
 * Results are ranked first by testAccuracy and then, only on an exact
 * accuracy tie, by testConfidence; higher is better for both. A result is
 * marked invalid by setting testAccuracy to the -FLT_MAX sentinel
 * (see setInvalid() / isInvalid()).
 *
 * All comparisons use exact floating-point equality for the tie test.
 */
struct TrainResult
{
    int id = -1;                  // NOTE(review): -1 presumably means "not assigned" — confirm with callers
    double testAccuracy = -1;
    double testConfidence = -1;
    double time = 0;              // NOTE(review): units not visible in this header — confirm
    std::string trainLog;

    /// Mark this result invalid: accuracy becomes the -FLT_MAX sentinel,
    /// so it ranks below any accuracy greater than -FLT_MAX.
    void setInvalid() {
        testAccuracy = -FLT_MAX;
        testConfidence = -1;
    }

    /// \return true if testAccuracy equals the sentinel written by setInvalid().
    bool isInvalid() const { return testAccuracy == -FLT_MAX; }

    /// \return true if this result ranks strictly above \p tr.
    bool isBetter(const TrainResult& tr) const {
        return isBetter(tr.testAccuracy, tr.testConfidence);
    }

    /// \return true if this result ranks strictly above (testAcc, testConf):
    ///         higher accuracy, or equal accuracy with higher confidence.
    bool isBetter(double testAcc, double testConf) const {
        return testAccuracy > testAcc ||
            (testAccuracy == testAcc && testConfidence > testConf);
    }

    /// \return true if this result ranks exactly equal to \p tr.
    bool isSameGood(const TrainResult& tr) const {
        return isSameGood(tr.testAccuracy, tr.testConfidence);
    }

    /// \return true if both accuracy and confidence are exactly equal.
    bool isSameGood(double testAcc, double testConf) const {
        return testAccuracy == testAcc && testConfidence == testConf;
    }

    /// \return true if this result ranks at or above \p tr.
    bool isBetterOrSameGood(const TrainResult& tr) const {
        return isBetterOrSameGood(tr.testAccuracy, tr.testConfidence);
    }

    /// \return true if this result ranks at or above (testAcc, testConf):
    ///         higher accuracy, or equal accuracy with confidence >= testConf.
    bool isBetterOrSameGood(double testAcc, double testConf) const {
        // The accuracy clause must be strict: with '>=' any accuracy tie would
        // succeed regardless of confidence, bypassing the tie-break entirely.
        return testAccuracy > testAcc ||
            (testAccuracy == testAcc && testConfidence >= testConf);
    }
};
|
|
|
|
class TrainResults : public AsyncResult
|
|
{
|
|
public:
|
|
typedef std::shared_ptr<TrainResults> Ptr;
|
|
|
|
TrainResults(AsyncFunc func) : AsyncResult(func), mLastValidIndex(-1)
|
|
{}
|
|
|
|
virtual int count() {
|
|
CyclopsSharedLockGuard shared_gaurd(&mLock);
|
|
return mResults.size();
|
|
}
|
|
TrainResult get(int idx) {
|
|
CyclopsSharedLockGuard shared_gaurd(&mLock);
|
|
if (idx < 0 || idx >= mResults.size()) return TrainResult();
|
|
return mResults[idx];
|
|
}
|
|
std::list<int> getBest(int id) {
|
|
CyclopsSharedLockGuard shared_gaurd(&mLock);
|
|
auto it = mBestIndex.find(id);
|
|
if (it != mBestIndex.end()) return it->second;
|
|
else return std::list<int>();
|
|
}
|
|
int getLastValid() {
|
|
return mLastValidIndex;
|
|
}
|
|
|
|
protected:
|
|
|
|
friend class LLClassifier;
|
|
|
|
void setTotal(int num) {
|
|
CyclopsLockGuard write_gaurd(&mLock);
|
|
mTotal = num;
|
|
mResults.reserve(num);
|
|
}
|
|
|
|
void add(const TrainResult& r) {
|
|
CyclopsLockGuard write_gaurd(&mLock);
|
|
mResults.push_back(r);
|
|
if (!r.isInvalid())
|
|
mLastValidIndex = mResults.size() - 1;
|
|
}
|
|
|
|
void addBest(int id, int idx) {
|
|
CyclopsLockGuard write_gaurd(&mLock);
|
|
mBestIndex[id].push_back(idx);
|
|
}
|
|
|
|
void clearBest(int id) {
|
|
CyclopsLockGuard write_gaurd(&mLock);
|
|
mBestIndex[id].clear();
|
|
}
|
|
|
|
private:
|
|
std::vector<TrainResult> mResults; // actual finished results
|
|
std::atomic_int mLastValidIndex;
|
|
std::map<int, std::list<int> > mBestIndex;
|
|
};
|
|
|
|
#endif // TrainResult_h_
|
|
|