Skip to content

Commit cb34be9

Browse files
authored
Move minDepth and maxDepth to estimation config (#1014)
1 parent 4cc7c20 commit cb34be9

File tree

10 files changed

+67
-61
lines changed

10 files changed

+67
-61
lines changed

depth-estimation/README.md

+6-6
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,17 @@ For example:
3333

3434
```javascript
3535
const model = depthEstimation.SupportedModels.ARPortraitDepth;
36-
const estimatorConfig = {
37-
minDepth: 0,
38-
maxDepth: 1,
39-
}
40-
const estimator = await depthEstimation.createEstimator(model, estimatorConfig);
36+
const estimator = await depthEstimation.createEstimator(model);
4137
```
4238

4339
Next, you can use the estimator to estimate depth.
4440

4541
```javascript
46-
const depthMap = await estimator.estimateDepth(image);
42+
const estimationConfig = {
43+
minDepth: 0,
44+
maxDepth: 1,
45+
}
46+
const depthMap = await estimator.estimateDepth(image, estimationConfig);
4747
```
4848

4949
The returned depth map contains depth values for each pixel in the image.

depth-estimation/package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@tensorflow-models/depth-estimation",
3-
"version": "0.0.1",
3+
"version": "0.0.2",
44
"description": "Pretrained depth model",
55
"main": "dist/index.js",
66
"jsnext:main": "dist/depth-estimation.esm.js",

depth-estimation/src/ar_portrait_depth/README.md

+6-6
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,6 @@ Pass in `depthEstimation.SupportedModels.ARPortraitDepth` from the
6262

6363
`estimatorConfig` is an object that defines ARPortraitDepth-specific configurations for `ARPortraitDepthModelConfig`:
6464

65-
* *minDepth*: The minimum depth value for the model to map to 0. Any smaller
66-
depth values will also get mapped to 0.
67-
68-
* *maxDepth*: The maximum depth value for the model to map to 1. Any larger
69-
depth values will also get mapped to 1.
70-
7165
* *segmentationModelUrl*: An optional string that specifies custom url of
7266
the segmenter model. This is useful for area/countries that don't have access to the model hosted on tf.hub. It also accepts `io.IOHandler` which can be used with
7367
[tfjs-react-native](https://github.com/tensorflow/tfjs/tree/master/tfjs-react-native)
@@ -97,6 +91,12 @@ options, you can pass in a second `estimationConfig` parameter.
9791

9892
`estimationConfig` is an object that defines ARPortraitDepth-specific configurations for `ARPortraitDepthEstimationConfig`:
9993

94+
* *minDepth*: The minimum depth value for the model to map to 0. Any smaller
95+
depth values will also get mapped to 0.
96+
97+
* *maxDepth*: The maximum depth value for the model to map to 1. Any larger
98+
depth values will also get mapped to 1.
99+
100100
* *flipHorizontal*: Optional. Defaults to false. When image data comes from camera, the result has to flip horizontally.
101101

102102
The following code snippet demonstrates how to run the model inference:

depth-estimation/src/ar_portrait_depth/ar_portrait_depth_test.ts

+15-11
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,13 @@ describeWithFlags('ARPortraitDepth', ALL_ENVS, () => {
114114

115115
// Note: this makes a network request for model assets.
116116
const estimator = await depthEstimation.createEstimator(
117-
depthEstimation.SupportedModels.ARPortraitDepth,
118-
{minDepth: 0, maxDepth: 1});
117+
depthEstimation.SupportedModels.ARPortraitDepth);
119118
const input: tf.Tensor3D = tf.zeros([128, 128, 3]);
120119

121120
const beforeTensors = tf.memory().numTensors;
122121

123-
const depthMap = await estimator.estimateDepth(input);
122+
const depthMap =
123+
await estimator.estimateDepth(input, {minDepth: 0, maxDepth: 1});
124124

125125
(await depthMap.toTensor()).dispose();
126126
expect(tf.memory().numTensors).toEqual(beforeTensors);
@@ -133,21 +133,25 @@ describeWithFlags('ARPortraitDepth', ALL_ENVS, () => {
133133

134134
it('throws error when minDepth is not set.', async (done) => {
135135
try {
136-
await depthEstimation.createEstimator(
136+
const estimator = await depthEstimation.createEstimator(
137137
depthEstimation.SupportedModels.ARPortraitDepth);
138+
const input: tf.Tensor3D = tf.zeros([128, 128, 3]);
139+
await estimator.estimateDepth(input);
138140
done.fail('Loading without minDepth succeeded unexpectedly.');
139141
} catch (e) {
140142
expect(e.message).toEqual(
141-
'A model config with minDepth and maxDepth set must be provided.');
143+
'An estimation config with ' +
144+
'minDepth and maxDepth set must be provided.');
142145
done();
143146
}
144147
});
145148

146149
it('throws error when minDepth is greater than maxDepth.', async (done) => {
147150
try {
148-
await depthEstimation.createEstimator(
149-
depthEstimation.SupportedModels.ARPortraitDepth,
150-
{minDepth: 1, maxDepth: 0.99});
151+
const estimator = await depthEstimation.createEstimator(
152+
depthEstimation.SupportedModels.ARPortraitDepth);
153+
const input: tf.Tensor3D = tf.zeros([128, 128, 3]);
154+
await estimator.estimateDepth(input, {minDepth: 1, maxDepth: 0.99});
151155
done.fail(
152156
'Loading with minDepth greater than maxDepth ' +
153157
'succeeded unexpectedly.');
@@ -187,12 +191,12 @@ describeWithFlags('ARPortraitDepth static image ', BROWSER_ENVS, () => {
187191
// Get actual depth values.
188192
// Note: this makes a network request for model assets.
189193
estimator = await depthEstimation.createEstimator(
190-
depthEstimation.SupportedModels.ARPortraitDepth,
191-
{minDepth: 0.2, maxDepth: 0.9});
194+
depthEstimation.SupportedModels.ARPortraitDepth);
192195

193196
const beforeTensors = tf.memory().numTensors;
194197

195-
const result = await estimator.estimateDepth(image);
198+
const result =
199+
await estimator.estimateDepth(image, {minDepth: 0.2, maxDepth: 0.9});
196200
const actualDepthValues = await result.toTensor();
197201
const coloredDepthValues =
198202
actualDepthValues.arraySync().flat().map(value => turboPlus(value));

depth-estimation/src/ar_portrait_depth/constants.ts

+8-4
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,16 @@
1414
* limitations under the License.
1515
* =============================================================================
1616
*/
17-
import {ARPortraitDepthEstimationConfig} from './types';
17+
import {ARPortraitDepthModelConfig} from './types';
1818

1919
export const DEFAULT_AR_PORTRAIT_DEPTH_MODEL_URL =
2020
'https://tfhub.dev/tensorflow/tfjs-model/ar_portrait_depth/1';
2121

22-
export const DEFAULT_AR_PORTRAIT_DEPTH_ESTIMATION_CONFIG:
23-
ARPortraitDepthEstimationConfig = {
24-
flipHorizontal: false,
22+
export const DEFAULT_AR_PORTRAIT_DEPTH_MODEL_CONFIG:
23+
ARPortraitDepthModelConfig = {
24+
depthModelUrl: DEFAULT_AR_PORTRAIT_DEPTH_MODEL_URL,
2525
};
26+
27+
export const DEFAULT_AR_PORTRAIT_DEPTH_ESTIMATION_CONFIG = {
28+
flipHorizontal: false,
29+
};

depth-estimation/src/ar_portrait_depth/estimator.ts

+10-6
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ const PORTRAIT_WIDTH = 192;
5656
class ARPortraitDepthEstimator implements DepthEstimator {
5757
constructor(
5858
private readonly segmenter: bodySegmentation.BodySegmenter,
59-
private readonly estimatorModel: tfconv.GraphModel,
60-
private readonly minDepth: number, private readonly maxDepth: number) {}
59+
private readonly estimatorModel: tfconv.GraphModel) {}
6160

6261
/**
6362
* Estimates depth for an image or video frame.
@@ -69,8 +68,14 @@ class ARPortraitDepthEstimator implements DepthEstimator {
6968
* image to feed through the network.
7069
*
7170
* @param config Optional.
71+
* minDepth: The minimum depth value for the model to map to 0. Any
72+
* smaller depth values will also get mapped to 0.
73+
*
74+
*     maxDepth: The maximum depth value for the model to map to 1. Any
75+
* larger depth values will also get mapped to 1.
76+
*
7277
* flipHorizontal: Optional. Default to false. When image data comes
73-
* from camera, the result has to flip horizontally.
78+
* from camera, the result has to flip horizontally.
7479
*
7580
* @return `DepthMap`.
7681
*/
@@ -128,7 +133,7 @@ class ARPortraitDepthEstimator implements DepthEstimator {
128133

129134
// Normalize to user requirements.
130135
const depthTransform =
131-
transformValueRange(this.minDepth, this.maxDepth, 0, 1);
136+
transformValueRange(config.minDepth, config.maxDepth, 0, 1);
132137

133138
// depth4D is roughly in [0,2] range, so half the scale factor to put it
134139
// in [0,1] range.
@@ -188,6 +193,5 @@ export async function load(modelConfig: ARPortraitDepthModelConfig):
188193
bodySegmentation.SupportedModels.MediaPipeSelfieSegmentation,
189194
{runtime: 'tfjs', modelUrl: config.segmentationModelUrl});
190195

191-
return new ARPortraitDepthEstimator(
192-
segmenter, depthModel, config.minDepth, config.maxDepth);
196+
return new ARPortraitDepthEstimator(segmenter, depthModel);
193197
}

depth-estimation/src/ar_portrait_depth/estimator_utils.ts

+13-12
Original file line numberDiff line numberDiff line change
@@ -15,34 +15,35 @@
1515
* =============================================================================
1616
*/
1717

18-
import {DEFAULT_AR_PORTRAIT_DEPTH_ESTIMATION_CONFIG, DEFAULT_AR_PORTRAIT_DEPTH_MODEL_URL} from './constants';
18+
import {DEFAULT_AR_PORTRAIT_DEPTH_ESTIMATION_CONFIG, DEFAULT_AR_PORTRAIT_DEPTH_MODEL_CONFIG} from './constants';
1919
import {ARPortraitDepthEstimationConfig, ARPortraitDepthModelConfig} from './types';
2020

2121
export function validateModelConfig(modelConfig: ARPortraitDepthModelConfig):
2222
ARPortraitDepthModelConfig {
23-
if (modelConfig == null || modelConfig.minDepth == null ||
24-
modelConfig.maxDepth == null) {
25-
throw new Error(
26-
`A model config with minDepth and maxDepth set must be provided.`);
27-
}
28-
29-
if (modelConfig.minDepth > modelConfig.maxDepth) {
30-
throw new Error('minDepth must be <= maxDepth.');
23+
if (modelConfig == null) {
24+
return {...DEFAULT_AR_PORTRAIT_DEPTH_MODEL_CONFIG};
3125
}
3226

3327
const config = {...modelConfig};
3428

3529
if (config.depthModelUrl == null) {
36-
config.depthModelUrl = DEFAULT_AR_PORTRAIT_DEPTH_MODEL_URL;
30+
config.depthModelUrl = DEFAULT_AR_PORTRAIT_DEPTH_MODEL_CONFIG.depthModelUrl;
3731
}
3832

3933
return config;
4034
}
4135

4236
export function validateEstimationConfig(
4337
estimationConfig: ARPortraitDepthEstimationConfig) {
44-
if (estimationConfig == null) {
45-
return {...DEFAULT_AR_PORTRAIT_DEPTH_ESTIMATION_CONFIG};
38+
if (estimationConfig == null || estimationConfig.minDepth == null ||
39+
estimationConfig.maxDepth == null) {
40+
throw new Error(
41+
'An estimation config with ' +
42+
'minDepth and maxDepth set must be provided.');
43+
}
44+
45+
if (estimationConfig.minDepth > estimationConfig.maxDepth) {
46+
throw new Error('minDepth must be <= maxDepth.');
4647
}
4748

4849
const config = {...estimationConfig};

depth-estimation/src/ar_portrait_depth/types.ts

-6
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,6 @@ import {EstimationConfig, ModelConfig} from '../types';
2222
/**
2323
* Model parameters for ARPortraitDepth.
2424
*
25-
* `minDepth`: The minimum depth value for the model to map to 0. Any smaller
26-
* depth values will also get mapped to 0.
27-
*
28-
* `maxDepth`: The maximum depth value for the model to map to 1. Any larger
29-
* depth values will also get mapped to 1.
30-
*
3125
* `segmentationModelUrl`: Optional. An optional string that specifies custom
3226
* url of the selfie segmentation model. This is useful for area/countries that
3327
* don't have access to the model hosted on tf.hub.

depth-estimation/src/types.ts

+7-8
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,25 @@ export enum SupportedModels {
2424

2525
/**
2626
* Common config to create the depth estimator.
27+
*/
28+
export interface ModelConfig {}
29+
30+
/**
31+
* Common config for the `estimateDepth` method.
2732
*
2833
* `minDepth`: The minimum depth value for the model to map to 0. Any smaller
2934
* depth values will also get mapped to 0.
3035
*
3136
* `maxDepth`: The maximum depth value for the model to map to 1. Any larger
3237
* depth values will also get mapped to 1.
33-
*/
34-
export interface ModelConfig {
35-
minDepth: number;
36-
maxDepth: number;
37-
}
38-
39-
/**
40-
* Common config for the `estimateDepth` method.
4138
*
4239
* `flipHorizontal`: Optional. Default to false. In some cases, the image is
4340
* mirrored, e.g., video stream from camera, flipHorizontal will flip the
4441
* keypoints horizontally.
4542
*/
4643
export interface EstimationConfig {
44+
minDepth: number;
45+
maxDepth: number;
4746
flipHorizontal?: boolean;
4847
}
4948

depth-estimation/src/version.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/** @license See the LICENSE file. */
22

33
// This code is auto-generated, do not modify this file!
4-
const version = '0.0.1';
4+
const version = '0.0.2';
55
export {version};

0 commit comments

Comments
 (0)