Skip to content

Commit 1ca5f75

Browse files
authored
Merge pull request #95 from GWmodel-Lab/fix/GTWR-lambda
Feature: GTWR optimize for lambda&bandwidth
2 parents 00110f7 + fde6bb4 commit 1ca5f75

4 files changed

Lines changed: 341 additions & 67 deletions

File tree

include/gwmodelpp/GTWR.h

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,22 @@ class GTWR : public GWRBase, public IBandwidthSelectable, public IParallelizabl
7777
return std::string(GWM_LOG_TAG_LAMBDA_OPTIMIZATION) + std::to_string(lambda) + "," + std::to_string(criterion);
7878
}
7979

80+
const double getLambda() {
81+
if (mStdistance != nullptr) {
82+
return mStdistance->lambda();
83+
} else {
84+
throw std::runtime_error("mStdistance is not initialized");
85+
}
86+
}
87+
88+
const double getAngle(){
89+
if (mStdistance != nullptr) {
90+
return mStdistance->angle();
91+
} else {
92+
throw std::runtime_error("mStdistance is not initialized");
93+
}
94+
}
95+
8096
private:
8197

8298
/**
@@ -578,7 +594,7 @@ class GTWR : public GWRBase, public IBandwidthSelectable, public IParallelizabl
578594
* @param bandwidthWeight 传入带宽值,来获取权重,后续更方便改成多元优化.
579595
* @return double 返回优选以后的lambda值.
580596
*/
581-
double LambdaAutoSelection(BandwidthWeight* bandwidthWeight);
597+
double lambdaAutoSelection(BandwidthWeight* bandwidthWeight);
582598

583599
/**
584600
* \~english
@@ -594,7 +610,59 @@ class GTWR : public GWRBase, public IBandwidthSelectable, public IParallelizabl
594610
* @param rsquare 根据输入的lambda值和带宽获取的R方值.
595611
* @return Status 算法运行状态。
596612
*/
597-
Status RsquareByLambda(BandwidthWeight* bandwidthWeight,double lambda, double& rsquare);
613+
Status r_squareByLambda(BandwidthWeight* bandwidthWeight,double lambda, double& rsquare);
614+
615+
struct Parameter {
616+
GTWR* instance; // GTWR实例
617+
BandwidthWeight* bandwidth; // 带宽
618+
double lambda; // lambda
619+
};
620+
621+
/**
622+
* \~english
623+
* @brief criterion function for gsl_multimin_function and params.
624+
* @param v gsl_vector target,
625+
* @param params the params.
626+
* @return criterion.
627+
* \~chinese
628+
* @brief 构建gsl的gsl_multimin_function以及优化指标
629+
* @param v gsl的优化向量(lambda, bw)
630+
* @param params 传入的参数,从void*转换成Parameter*
631+
* @return CV或AIC的指标值
632+
*/
633+
static double criterion_function (const gsl_vector *v, void *params);
634+
635+
/**
636+
* \~english
637+
* @brief gsl lambda bandwidth auto-selection function.
638+
* @param bandwidth BandwidthWeight,
639+
* @param max_iter max iter, internal set 1000.
640+
* @param min_step min steps for optimize change, internal set 0.01.
641+
* @return vector for (lambda, bw).
642+
* \~chinese
643+
* @brief 优化的主函数
644+
* @param bandwidth BandwidthWeight类型,带宽
645+
* @param max_iter 最大迭代次数,设置为1000
646+
* @param min_step 优化中的步长,设置为0.01(变化阈值为步长/1000)
647+
* @return 优化结果,一个向量:(lambda, bw).
648+
*/
649+
arma::vec lambdaBwAutoSelection(BandwidthWeight* bandwidth, size_t max_iter, double min_step);
650+
651+
/**
652+
* \~english
653+
* @brief criterion function by Lambda and Bw.
654+
* @param bandwidth bandwidth weight parameters,
655+
* @param lambda the lambda value.
656+
* @param criterion criterion type, BandwidthSelectionCriterionType.
657+
* @return criterion value.
658+
* \~chinese
659+
* @brief 利用带宽和lambda计算指标的函数
660+
* @param bandwidthWeight 输入带宽权重,
661+
* @param lambda 获取指标的lambda值
662+
* @param criterion BandwidthSelectionCriterionType类型,确定求的指标
663+
* @return 对应类型的指标值
664+
*/
665+
double criterionByLambdaBw(BandwidthWeight* bandwidth, double lambda, BandwidthSelectionCriterionType criterion);
598666

599667
public:
600668
/**
@@ -611,6 +679,8 @@ class GTWR : public GWRBase, public IBandwidthSelectable, public IParallelizabl
611679
*/
612680
void setIsAutoselectLambda(bool isAutoSelect) { mIsAutoselectLambda = isAutoSelect; }
613681

682+
void setIsAutoselectLambdaBw(bool isAutoSelect) { mIsAutoselectLambdaBw = isAutoSelect; }
683+
614684
protected:
615685

616686
bool mHasHatMatrix = true; //!< \~english Whether has hat-matrix. \~chinese 是否具有帽子矩阵。
@@ -619,6 +689,7 @@ class GTWR : public GWRBase, public IBandwidthSelectable, public IParallelizabl
619689

620690
bool mIsAutoselectBandwidth = false;//!< \~english Whether need bandwidth autoselect. \~chinese 是否需要自动优选带宽。
621691
bool mIsAutoselectLambda = false;//!< \~english Whether need lambda autoselect. \~chinese 是否需要自动优选lambda。
692+
bool mIsAutoselectLambdaBw = false;
622693

623694
BandwidthSelectionCriterionType mBandwidthSelectionCriterion = BandwidthSelectionCriterionType::AIC;//!< \~english Bandwidth Selection Criterion Type. \~chinese 默认的带宽优选方式。
624695
BandwidthSelectionCriterionCalculator mBandwidthSelectionCriterionFunction = &GTWR::bandwidthSizeCriterionCVSerial;//!< \~english Bandwidth Selection Criterion Function. \~chinese 默认的带宽优选函数。
@@ -640,7 +711,6 @@ class GTWR : public GWRBase, public IBandwidthSelectable, public IParallelizabl
640711

641712
CRSSTDistance* mStdistance;//use to change spatial temporal distance including lambda
642713

643-
// gsl_function F;
644714
};
645715

646716
}

include/gwmodelpp/spatialweight/CRSSTDistance.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ class CRSSTDistance : public Distance
118118
const gwm::OneDimDistance* temporalDistance() const { return mTemporalDistance; }
119119

120120
// unused code to set lambda
121-
// double lambda() const { return mLambda; }
121+
const double lambda(){ return mLambda; }
122122
void setLambda(const double lambda) {
123123
if (lambda >= 0 && lambda <= 1)
124124
{
@@ -128,6 +128,11 @@ class CRSSTDistance : public Distance
128128
throw std::runtime_error("The lambda must be in [0,1].");
129129
}
130130

131+
const double angle() { return mAngle; }
132+
void setAngle(const double angle) {
133+
mAngle = angle;
134+
}
135+
131136
protected:
132137
Distance* mSpatialDistance = nullptr; //!< \~english Pointer to instance for spatial distance \~chinese 指向空间距离的指针
133138
gwm::OneDimDistance* mTemporalDistance = nullptr; //!< \~english Pointer to instance for temporal distance \~chinese 指向时间距离的指针

0 commit comments

Comments
 (0)