Skip to content

Commit fafa4f0

Browse files
authored
Add PersonLab human pose estimator (tensorflow#563)
* First personlab commit * Add DepthwiseSeparableConvBloc * Finish Personlab convnet * Add heap structure for priority queue for personlab Heap code taken from this tutorial: https://www.raywenderlich.com/586-swift-algorithm-club-heap-and-priority-queue-data-structure * Start working on decoder, added some data structures and function Intermediate non-compiling state * Heavily modify data structures, get forward crawl working It compiles now * Get backward crawl working * Add checks which I had ignored from original python implementation * Unify forward and backward concepts into single outward concept * Under heavy debugging, got poses to be correct but filter is buggy * Fixed all bugs in decoder Still gotta refactor and clean debugging code * Remove a lot of debugging code * Delete some comments * Remove redundant index calculations * Temporary solution: Move Tensors to custom CPU container on decoder init * Remove some profiling code * Add SwiftCV to dependencies * Add SwiftCV image loading from disc and from webcam * Add pose drawing capability using SwiftCV * Add Personlab CLI and improve readme * Add link to checkpoint * Improve readme * Update README.md Tried it in the 0.9rc and it ran more slowly than on 0.8. * Reduce wait time between frames on webcam demo We were waiting 5 seconds between each frame, lol. * Unify all model parts into a single PersonLab model struct * Remove binary heap, which we didn't need after all. On the initial implementation we added and removed candidate keypoints through our code, turns out we can just sort them all at the beginning, so having a priority queue is unnecessary. 
* Improve some comments * Improve profiling print formatting * Move code around a bit * Rename all Personlab instances to PersonLab as is written in paper * Minor refactor * Add error checking for file loading * Update code to accommodate for `Image.resized` not adding batch dim now * Move PersonLab out of Examples/ and into root dir * Improve profiling code * Update PersonLab README * Use swift-linter on PersonLab * Capitalize PersonLab files to better match Swift standard style * Add copyright headers to PersonLab files * Remove SwiftCV dependency This removes video and image drawing support. * Add line drawing support * Run swift-format * Fix version regression on Package.swift and outdated abstract on cli tool * Add automatic checkpoint downloading from default URL * Fix merge conflicts * Remove redundant tensor declarations in PersonLab backbone
1 parent ed7923e commit fafa4f0

10 files changed

+905
-0
lines changed

Package.swift

+3
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ let package = Package(
8181
name: "MobileNetV2-Imagenette",
8282
dependencies: ["Datasets", "ImageClassificationModels", "TrainingLoop"],
8383
path: "Examples/MobileNetV2-Imagenette"),
84+
.target(
85+
name: "PersonLab", dependencies: ["Checkpoints", "ModelSupport", .product(name: "ArgumentParser", package: "swift-argument-parser")],
86+
path: "PersonLab"),
8487
.target(
8588
name: "MiniGo", dependencies: ["Checkpoints"], path: "MiniGo", exclude: ["main.swift"]),
8689
.target(

PersonLab/Backbone.swift

+192
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Checkpoints
16+
import TensorFlow
17+
18+
/// A MobileNetV1-style depthwise-separable convolution: a depthwise
/// convolution followed by a 1x1 ("pointwise") convolution, each using a
/// relu6 activation and `same` padding.
public struct DepthwiseSeparableConvBlock: Layer {
  var dConv: DepthwiseConv2D<Float>
  var conv: Conv2D<Float>

  /// Builds the block from pre-trained filter and bias tensors.
  ///
  /// - Parameters:
  ///   - depthWiseFilter: Filter for the depthwise convolution.
  ///   - depthWiseBias: Bias for the depthwise convolution.
  ///   - pointWiseFilter: Filter for the 1x1 pointwise convolution.
  ///   - pointWiseBias: Bias for the 1x1 pointwise convolution.
  ///   - strides: Strides of the depthwise convolution; the pointwise
  ///     convolution keeps its default (1, 1) strides.
  public init(
    depthWiseFilter: Tensor<Float>,
    depthWiseBias: Tensor<Float>,
    pointWiseFilter: Tensor<Float>,
    pointWiseBias: Tensor<Float>,
    strides: (Int, Int)
  ) {
    self.dConv = DepthwiseConv2D<Float>(
      filter: depthWiseFilter, bias: depthWiseBias, activation: relu6,
      strides: strides, padding: .same)
    self.conv = Conv2D<Float>(
      filter: pointWiseFilter, bias: pointWiseBias, activation: relu6,
      padding: .same)
  }

  /// Applies the depthwise convolution and then the pointwise convolution.
  @differentiable
  public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
    return conv(dConv(input))
  }
}
51+
52+
/// A MobileNetV1-like feature extractor whose weights are read from a
/// pre-trained TensorFlow checkpoint.
///
/// The layout mirrors the checkpoint's naming scheme: one plain convolution
/// ("Conv2d_0") followed by 13 depthwise-separable blocks
/// ("Conv2d_1" ... "Conv2d_13"). Spatial downsampling (stride 2) happens at
/// the initial convolution and at blocks 2, 4, and 6.
public struct MobileNetLikeBackbone: Layer {
  @noDerivative let ckpt: CheckpointReader

  public var convBlock0: Conv2D<Float>
  public var dConvBlock1: DepthwiseSeparableConvBlock
  public var dConvBlock2: DepthwiseSeparableConvBlock
  public var dConvBlock3: DepthwiseSeparableConvBlock
  public var dConvBlock4: DepthwiseSeparableConvBlock
  public var dConvBlock5: DepthwiseSeparableConvBlock
  public var dConvBlock6: DepthwiseSeparableConvBlock
  public var dConvBlock7: DepthwiseSeparableConvBlock
  public var dConvBlock8: DepthwiseSeparableConvBlock
  public var dConvBlock9: DepthwiseSeparableConvBlock
  public var dConvBlock10: DepthwiseSeparableConvBlock
  public var dConvBlock11: DepthwiseSeparableConvBlock
  public var dConvBlock12: DepthwiseSeparableConvBlock
  public var dConvBlock13: DepthwiseSeparableConvBlock

  /// Reads all backbone weights from `checkpoint`.
  public init(checkpoint: CheckpointReader) {
    self.ckpt = checkpoint

    self.convBlock0 = Conv2D<Float>(
      filter: checkpoint.load(from: "Conv2d_0/weights"),
      bias: checkpoint.load(from: "Conv2d_0/biases"),
      activation: relu6,
      strides: (2, 2),
      padding: .same
    )
    // The 13 separable blocks differ only in their checkpoint index and
    // stride, so they are built by a shared helper instead of 13 copies of
    // the same construction code.
    self.dConvBlock1 = Self.loadSeparableBlock(1, strides: (1, 1), from: checkpoint)
    self.dConvBlock2 = Self.loadSeparableBlock(2, strides: (2, 2), from: checkpoint)
    self.dConvBlock3 = Self.loadSeparableBlock(3, strides: (1, 1), from: checkpoint)
    self.dConvBlock4 = Self.loadSeparableBlock(4, strides: (2, 2), from: checkpoint)
    self.dConvBlock5 = Self.loadSeparableBlock(5, strides: (1, 1), from: checkpoint)
    self.dConvBlock6 = Self.loadSeparableBlock(6, strides: (2, 2), from: checkpoint)
    self.dConvBlock7 = Self.loadSeparableBlock(7, strides: (1, 1), from: checkpoint)
    self.dConvBlock8 = Self.loadSeparableBlock(8, strides: (1, 1), from: checkpoint)
    self.dConvBlock9 = Self.loadSeparableBlock(9, strides: (1, 1), from: checkpoint)
    self.dConvBlock10 = Self.loadSeparableBlock(10, strides: (1, 1), from: checkpoint)
    self.dConvBlock11 = Self.loadSeparableBlock(11, strides: (1, 1), from: checkpoint)
    self.dConvBlock12 = Self.loadSeparableBlock(12, strides: (1, 1), from: checkpoint)
    self.dConvBlock13 = Self.loadSeparableBlock(13, strides: (1, 1), from: checkpoint)
  }

  /// Loads the weights of the depthwise-separable block number `index`
  /// (checkpoint scopes "Conv2d_<index>_depthwise" and
  /// "Conv2d_<index>_pointwise") and builds the block with the given strides.
  private static func loadSeparableBlock(
    _ index: Int, strides: (Int, Int), from checkpoint: CheckpointReader
  ) -> DepthwiseSeparableConvBlock {
    return DepthwiseSeparableConvBlock(
      depthWiseFilter: checkpoint.load(from: "Conv2d_\(index)_depthwise/depthwise_weights"),
      depthWiseBias: checkpoint.load(from: "Conv2d_\(index)_depthwise/biases"),
      pointWiseFilter: checkpoint.load(from: "Conv2d_\(index)_pointwise/weights"),
      pointWiseBias: checkpoint.load(from: "Conv2d_\(index)_pointwise/biases"),
      strides: strides
    )
  }

  /// Runs the input through the initial convolution and all 13 separable
  /// blocks in order.
  @differentiable
  public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
    var x = convBlock0(input)
    x = dConvBlock1(x)
    x = dConvBlock2(x)
    x = dConvBlock3(x)
    x = dConvBlock4(x)
    x = dConvBlock5(x)
    x = dConvBlock6(x)
    x = dConvBlock7(x)
    x = dConvBlock8(x)
    x = dConvBlock9(x)
    x = dConvBlock10(x)
    x = dConvBlock11(x)
    x = dConvBlock12(x)
    x = dConvBlock13(x)
    return x
  }
}

PersonLab/Decoder.swift

+196
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
import TensorFlow
17+
18+
// This whole struct should probably be merged into the PersonLab model struct when we no longer
19+
// need to do CPUTensor wrapping when SwiftRT fixes the GPU->CPU copy issue.
20+
/// Greedy decoder that turns the PersonLab network's output maps — keypoint
/// heatmap, short-range offsets, and forward/backward mid-range
/// displacements — into a list of multi-person `Pose`s.
///
/// Indexing throughout assumes the maps are laid out as
/// [height, width, channel] (established by the `shape[0]`/`shape[1]`/
/// `shape[2]` usage below).
struct PoseDecoder {
  let heatmap: CPUTensor<Float>
  let offsets: CPUTensor<Float>
  let displacementsFwd: CPUTensor<Float>
  let displacementsBwd: CPUTensor<Float>
  let config: Config

  init(for results: PersonlabHeadsResults, with config: Config) {
    // Hardcoded to batch size == 1 at the moment: only the first batch
    // element of every head output is kept.
    self.heatmap = CPUTensor<Float>(results.heatmap[0])
    self.offsets = CPUTensor<Float>(results.offsets[0])
    self.displacementsFwd = CPUTensor<Float>(results.displacementsFwd[0])
    self.displacementsBwd = CPUTensor<Float>(results.displacementsBwd[0])
    self.config = config
  }

  /// Decodes poses greedily: repeatedly take the highest-scoring remaining
  /// locally-maximum keypoint as a root, grow a skeleton out of it by
  /// following the displacement maps, and keep the resulting pose only when
  /// its non-overlapped score clears `config.poseScoreThreshold`.
  func decode() -> [Pose] {
    var poses = [Pose]()
    var sortedLocallyMaximumKeypoints = getSortedLocallyMaximumKeypoints()
    while sortedLocallyMaximumKeypoints.count > 0 {
      let rootKeypoint = sortedLocallyMaximumKeypoints.removeFirst()
      // Non-maximum suppression: skip roots that land within `nmsRadius` of
      // the corresponding keypoint of an already-decoded pose.
      if rootKeypoint.isWithinRadiusOfCorrespondingKeypoints(in: poses, radius: config.nmsRadius) {
        continue
      }

      var pose = Pose(resolution: self.config.inputImageSize)
      pose.add(rootKeypoint)

      // Recursively parse the keypoint tree going in both the forward and
      // backward directions optimally.
      recursivellyAddNextKeypoint(
        after: rootKeypoint,
        into: &pose
      )

      if getPoseScore(for: pose, considering: poses) > config.poseScoreThreshold {
        poses.append(pose)
      }
    }
    return poses
  }

  /// Depth-first walk of the keypoint graph: for every neighbour of
  /// `previousKeypoint` not yet present in `pose`, locate it by following
  /// the matching displacement map and recurse from the new keypoint.
  func recursivellyAddNextKeypoint(after previousKeypoint: Keypoint, into pose: inout Pose) {
    for (nextKeypointIndex, direction) in getNextKeypointIndexAndDirection(previousKeypoint.index) {
      if pose.getKeypoint(nextKeypointIndex) == nil {
        let nextKeypoint = followDisplacement(
          from: previousKeypoint,
          to: nextKeypointIndex,
          using: direction == .fwd ? displacementsFwd : displacementsBwd
        )
        pose.add(nextKeypoint)
        recursivellyAddNextKeypoint(after: nextKeypoint, into: &pose)
      }
    }
  }

  /// Estimates the position of `nextKeypointIndex` by adding the mid-range
  /// displacement vector read at `previousKeypoint`'s heatmap cell, then
  /// refining the landing cell with the short-range offset map.
  func followDisplacement(
    from previousKeypoint: Keypoint, to nextKeypointIndex: KeypointIndex,
    using displacements: CPUTensor<Float>
  ) -> Keypoint {
    // The displacement channel for the x component sits half the channel
    // count after the y component's channel.
    let displacementKeypointIndexY = keypointPairToDisplacementIndexMap[
      Set([previousKeypoint.index, nextKeypointIndex])]!
    let displacementKeypointIndexX = displacementKeypointIndexY + displacements.shape[2] / 2
    let displacementYIndex = getUnstridedIndex(y: previousKeypoint.y)
    let displacementXIndex = getUnstridedIndex(x: previousKeypoint.x)

    let displacementY = displacements[
      displacementYIndex,
      displacementXIndex,
      displacementKeypointIndexY
    ]
    let displacementX = displacements[
      displacementYIndex,
      displacementXIndex,
      displacementKeypointIndexX
    ]

    // Heatmap cell the displacement vector lands on.
    let displacedY = getUnstridedIndex(y: previousKeypoint.y + displacementY)
    let displacedX = getUnstridedIndex(x: previousKeypoint.x + displacementX)

    // Offset channels: y components first, then x components shifted by the
    // keypoint count.
    let yOffset = offsets[
      displacedY,
      displacedX,
      nextKeypointIndex.rawValue
    ]
    let xOffset = offsets[
      displacedY,
      displacedX,
      nextKeypointIndex.rawValue + KeypointIndex.allCases.count
    ]

    // If we are getting the offset from an exact point in the heatmap, we should add this
    // offset parting from that exact point in the heatmap, so we just nearest neighbour
    // interpolate it back, then re-stretch using the output stride, and then add said offset.
    let nextY = Float(displacedY * config.outputStride) + yOffset
    let nextX = Float(displacedX * config.outputStride) + xOffset

    return Keypoint(
      y: nextY,
      x: nextX,
      index: nextKeypointIndex,
      score: heatmap[
        displacedY, displacedX, nextKeypointIndex.rawValue
      ]
    )
  }

  /// Returns true when `score` is not beaten by any heatmap value for
  /// `keypointIndex` inside a square window of radius
  /// `config.keypointLocalMaximumRadius` centered on
  /// (`heatmapY`, `heatmapX`), with the window clamped at the map borders.
  func scoreIsMaximumInLocalWindow(heatmapY: Int, heatmapX: Int, score: Float, keypointIndex: Int)
    -> Bool
  {
    let yStart = max(heatmapY - config.keypointLocalMaximumRadius, 0)
    let yEnd = min(heatmapY + config.keypointLocalMaximumRadius, heatmap.shape[0] - 1)
    for windowY in yStart...yEnd {
      let xStart = max(heatmapX - config.keypointLocalMaximumRadius, 0)
      let xEnd = min(heatmapX + config.keypointLocalMaximumRadius, heatmap.shape[1] - 1)
      for windowX in xStart...xEnd {
        if heatmap[windowY, windowX, keypointIndex] > score {
          return false
        }
      }
    }
    return true
  }

  /// Maps an image-space y coordinate to the nearest heatmap row index,
  /// clamped to [0, height - 1].
  func getUnstridedIndex(y: Float) -> Int {
    let downScaled = y / Float(config.outputStride)
    let clamped = min(max(0, downScaled.rounded()), Float(heatmap.shape[0] - 1))
    return Int(clamped)
  }

  /// Maps an image-space x coordinate to the nearest heatmap column index,
  /// clamped to [0, width - 1].
  func getUnstridedIndex(x: Float) -> Int {
    let downScaled = x / Float(config.outputStride)
    let clamped = min(max(0, downScaled.rounded()), Float(heatmap.shape[1] - 1))
    return Int(clamped)
  }

  /// Scans the whole heatmap and collects every keypoint whose score clears
  /// `config.keypointScoreThreshold` and is a local maximum, returned in
  /// descending score order so `decode` can pop the best root first.
  func getSortedLocallyMaximumKeypoints() -> [Keypoint] {
    var sortedLocallyMaximumKeypoints = [Keypoint]()
    for heatmapY in 0..<heatmap.shape[0] {
      for heatmapX in 0..<heatmap.shape[1] {
        for keypointIndex in 0..<heatmap.shape[2] {
          let score = heatmap[heatmapY, heatmapX, keypointIndex]

          if score < config.keypointScoreThreshold { continue }
          if scoreIsMaximumInLocalWindow(
            heatmapY: heatmapY,
            heatmapX: heatmapX,
            score: score,
            keypointIndex: keypointIndex
          ) {
            sortedLocallyMaximumKeypoints.append(
              Keypoint(
                heatmapY: heatmapY,
                heatmapX: heatmapX,
                index: keypointIndex,
                score: score,
                offsets: offsets,
                outputStride: config.outputStride
              )
            )
          }
        }
      }
    }
    sortedLocallyMaximumKeypoints.sort { $0.score > $1.score }
    return sortedLocallyMaximumKeypoints
  }

  /// A pose's score is the mean over all keypoint slots of the scores of
  /// keypoints that do not overlap a corresponding keypoint of an already
  /// accepted pose (overlapped ones contribute 0).
  func getPoseScore(for pose: Pose, considering poses: [Pose]) -> Float {
    var notOverlappedKeypointScoreAccumulator: Float = 0
    for keypoint in pose.keypoints {
      // NOTE(review): the force-unwraps assume every slot of
      // `pose.keypoints` is populated by the time a pose is scored — TODO
      // confirm against `Pose.add` / the recursive crawl.
      if !keypoint!.isWithinRadiusOfCorrespondingKeypoints(in: poses, radius: config.nmsRadius) {
        notOverlappedKeypointScoreAccumulator += keypoint!.score
      }
    }
    return notOverlappedKeypointScoreAccumulator / Float(KeypointIndex.allCases.count)
  }
}

0 commit comments

Comments
 (0)