Skip to content

Commit 40b83a1

Browse files
authored
Merge pull request #45 from GPUEngineering/b/qr-malloc-error
fix initialisation of Q
2 parents c6bf540 + 4cbc3ae commit 40b83a1

File tree

3 files changed

+49
-4
lines changed

3 files changed

+49
-4
lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88

9+
<!-- ---------------------
10+
v1.2.1
11+
--------------------- -->
12+
## v1.2.1 - 07-10-2024
13+
14+
### Added
15+
16+
- Patch initialisation of Q in QR decomposition.
17+
- Add test for tall skinny matrices.
18+
919
<!-- ---------------------
1020
v1.2.0
1121
--------------------- -->

include/tensor.cuh

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1488,8 +1488,10 @@ inline int QRFactoriser<double>::getQR(DTensor<double> &Q, DTensor<double> &R) {
14881488
throw std::invalid_argument("[QR] invalid shape of R.");
14891489
// Initialize Q to 1's on diagonal
14901490
std::vector<double> vecQ(m * n, 0.);
1491-
for (size_t i = 0; i < m; i++) {
1492-
vecQ[i * n + i] = 1.;
1491+
for (size_t r = 0; r < m; r++) {
1492+
for (size_t c = 0; c < n; c++) {
1493+
if (r == c) { vecQ[r * n + c] = 1.; }
1494+
}
14931495
}
14941496
Q.upload(vecQ, rowMajor);
14951497
// Apply Householder reflectors to compute Q
@@ -1521,8 +1523,10 @@ inline int QRFactoriser<float>::getQR(DTensor<float> &Q, DTensor<float> &R) {
15211523
throw std::invalid_argument("[QR] invalid shape of R.");
15221524
// Initialize Q to 1's on diagonal
15231525
std::vector<float> vecQ(m * n, 0.);
1524-
for (size_t i = 0; i < m; i++) {
1525-
vecQ[i * n + i] = 1.;
1526+
for (size_t r = 0; r < m; r++) {
1527+
for (size_t c = 0; c < n; c++) {
1528+
if (r == c) { vecQ[r * n + c] = 1.; }
1529+
}
15261530
}
15271531
Q.upload(vecQ, rowMajor);
15281532
// Apply Householder reflectors to compute Q

test/testTensor.cu

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,37 @@ TEST_F(QRTest, qrFactorisation) {
10791079
qrFactorisation<double>(PRECISION_HIGH);
10801080
}
10811081

1082+
/* ---------------------------------------
1083+
* QR factorisation
1084+
* - tall and skinny matrix
1085+
* --------------------------------------- */
1086+
1087+
TEMPLATE_WITH_TYPE_T TEMPLATE_CONSTRAINT_REQUIRES_FPX
1088+
void qrFactorisationTall(T epsilon) {
1089+
size_t nR = 20;
1090+
size_t nC = 3;
1091+
DTensor<T> temp(nR, nC);
1092+
DTensor<T> A = DTensor<T>::createRandomTensor(nR, nC, 1, -100, 100);
1093+
QRFactoriser<T> qr(temp);
1094+
A.deviceCopyTo(temp);
1095+
int status = qr.factorise();
1096+
EXPECT_EQ(status, 0);
1097+
DTensor<T> Q(nR, nC);
1098+
DTensor<T> R(nC, nC, 1, true);
1099+
DTensor<T> QR(nR, nC);
1100+
status = qr.getQR(Q, R);
1101+
EXPECT_EQ(status, 0);
1102+
QR.addAB(Q, R);
1103+
QR -= A;
1104+
T nrm = QR.normF();
1105+
EXPECT_NEAR(nrm, 0., epsilon);
1106+
}
1107+
1108+
TEST_F(QRTest, qrFactorisationTall) {
1109+
qrFactorisationTall<float>(PRECISION_LOW);
1110+
qrFactorisationTall<double>(PRECISION_HIGH);
1111+
}
1112+
10821113
/* ---------------------------------------
10831114
* QR factorisation: solve least squares
10841115
* --------------------------------------- */

0 commit comments

Comments
 (0)