You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

196 lines
4.7 KiB
C++

/*!
* \file CyclopsMLModel.h
* \date 2019/11/18
*
* \author Lin, Chi
* Contact: lin.chi@hzleaper.com
*
*
* \note
*/
#ifndef __CyclopsMLModel_h_
#define __CyclopsMLModel_h_
#include "CyclopsParam.h"
#include "CyclopsEnums.h"
#include "CyclopsGrid.h"
/*! \brief Base class for machine-learning model */
class CyclopsMLModel
{
public:
/*! Parameter configurations */
CyclopsParams mParams;
/*! Default auto-tune search grid */
CyclopsGrids mDefaultGrids;
/*! Termination criteria */
TermCriteria mTerm;
virtual ~CyclopsMLModel() {}
typedef std::shared_ptr<CyclopsMLModel> Ptr;
/*! Model's type */
MLModel getType() const { return mType; }
/*! Set termination criteria */
virtual void setTermCriteria(const TermCriteria& term) { mTerm = term; }
/*! Serialized to file or in-memory string */
virtual void serialize(cv::FileStorage& fs) const {
mParams.serialize(fs);
fs << "name" << mName;
}
/*! Deserialize from file or in-memory string */
virtual void deserialize(const cv::FileNode& node) {
mParams.deserialize(node);
node["name"] >> mName;
}
/*! Whether the model is trained or not */
virtual bool isTrained() const { return false; }
/*! Do training of given samples and their labels */
virtual bool train(const cv::Mat& samples, const cv::Mat& labels) { return false; }
/*! Prediction of given sample
* @param sample feature vector
* @param pConfidence output confidence score, pass null if not wanted
* @param return prediction result, the label
*/
virtual float predict(const cv::Mat& sample, float* pConfidence = nullptr) { return -FLT_MAX; }
/*! Backup parameters */
virtual void backupAll() { mParams.backupAll(); }
/*! Restore parameter from last backup */
virtual void restoreAll() { mParams.restoreAll(); }
/*! Whether any parameter is modified */
virtual bool isDirty() { return mParams.isDirty(); }
/*! Whether we need to fix some bad parameter configuration
* @param sampleSize uniformed sample size
* @param doFix true if we should also do the real fix
* @return true if fix is needed
*/
virtual bool needFix(const Size& sampleSize, bool doFix) {
return !mParams.isValid(sampleSize, doFix).empty();
}
/*! Clear trained cache */
virtual void clear() {}
/*! Get model's name */
const std::string& getName() const {
return mName;
}
protected:
MLModel mType;
std::string mName;
};
/*! \brief Classification model: K-Nearest-Neighbor */
class KNNModel : public CyclopsMLModel
{
public:
virtual ~KNNModel() {}
typedef std::shared_ptr<KNNModel> Ptr;
/*! Create a KNN model with default parameters */
static KNNModel::Ptr create();
enum Param {
/*! Number of neighbors used for prediction */
K = 0,
/*! Matching algorithm, see MatchingAlgorithm */
Matching,
/*! Distance type used for matching, see DistanceType */
Distance,
/*! How neighbors are weighted in prediction voting, see WeightBy */
Weight,
};
enum MatchingAlgorithm {
/*! Naive Brute-Force algorithm */
BruteForce = 0,
/*! Samples are organized and searched in KD-Tree */
KDTree,
/*! Sample are organized via K-Means algorithm */
Kmeans,
/*! Composite with KD-Tree and K-Means */
Composite,
/*! Sample are organized via hierarchical clustering */
Hierarchical,
/*! Sample are organized via auto-tuned indexing */
Auto
};
enum DistanceType {
/*! L2 distance */
Euclidean = 0,
/*! L1 distance */
Manhattan,
/*! Max(ai, bi), can only use together with BruteForce or Kmeans or Hierarchical matching */
Max,
/*! Sum(Min(ai, bi)) */
HistIntersection,
/*! Sum((Sqrt(ai) - Sqrt(bi))^2) */
Hellinger,
/*! Sum((ai - bi)^2/(ai + bi)) for ai + bi > 0 */
ChiSquare,
/*! Count(a xor b), used for binary feature, can only use together with BruteForce or Hierarchical matching */
Hamming,
};
enum WeightBy {
/*! Weight by distance to provided feature vector */
ByDistance = 0,
/*! Uniformed weight */
Uniformed
};
};
/*! \brief Classification model: Support Vector Machine */
class SVCModel : public CyclopsMLModel
{
public:
virtual ~SVCModel() {}
typedef std::shared_ptr<SVCModel> Ptr;
/*! Create a SVC model with default parameters */
static SVCModel::Ptr create();
enum Param {
/*! Type of SVM, see SVMType */
SVMType = 0,
/*! Constant used by C-SVM */
C,
/*! Nu used by Nu-SVM */
Nu,
/*! Kernel function, see KernelType */
KernelType,
/*! Parameter in kernel function, g */
Gamma,
/*! Parameter in kernel function, c */
Coef,
/*! Parameter in kernel function, n */
Degree
};
enum SVMType {
/*! C-SVM */
CSupport = 0,
/*! Nu-SVM */
NuSupport
};
enum KernelType {
/*! x dot y */
Linear = 0,
/*! (g (x dot y) + c)^n */
Poly,
/*! exp(-g |x - y|^2) */
RBF,
/*! tanh(g (x dot y) + c) */
Sigmoid,
// CHI2, Inter, not useful
};
};
#endif // CyclopsMLModel_h_