Skip to content

Commit 7f2bbcf

Browse files
GeauxEricEricDingNVDLaurentMazare
authored
[segment-anything] Support multi-point as the prompt input (huggingface#945)
* [sam] Support multi-point prompts * [segment-anything] Pass points by reference * [segment-anything] Update example code and image * Fix clippy lint. --------- Co-authored-by: Yun Ding <[email protected]> Co-authored-by: laurent <[email protected]>
1 parent dc47224 commit 7f2bbcf

File tree

6 files changed

+55
-34
lines changed

6 files changed

+55
-34
lines changed

candle-examples/examples/segment-anything/README.md

+11-6
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,30 @@ based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
1616
cargo run --example segment-anything --release -- \
1717
--image candle-examples/examples/yolo-v8/assets/bike.jpg
1818
--use-tiny
19-
--point-x 0.4
20-
--point-y 0.3
19+
--point-x 0.6,0.6
20+
--point-y 0.6,0.55
2121
```
2222

2323
Running this command generates a `sam_merged.jpg` file containing the original
24-
image with a blue overlay of the selected mask. The red dot represents the prompt
25-
specified by `--point-x 0.4 --point-y 0.3`, this prompt is assumed to be part
24+
image with a blue overlay of the selected mask. The red dots represent the prompt
25+
specified by `--point-x 0.6,0.6 --point-y 0.6,0.55`; this prompt is assumed to be part
2626
of the target mask.
2727

2828
The values used for `--point-x` and `--point-y` should be between 0 and 1 and
2929
are proportional to the image dimension, i.e. use 0.5 for the image center.
3030

31+
Original image:
3132
![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)
3233

33-
![Leading group, Giro d'Italia 2021](./assets/sam_merged.jpg)
34+
Segmentation results when prompting with a single point `--point-x 0.6 --point-y 0.55`:
35+
![Leading group, Giro d'Italia 2021](./assets/single_pt_prompt.jpg)
36+
37+
Segmentation results when prompting with multiple points `--point-x 0.6,0.6 --point-y 0.6,0.55`:
38+
![Leading group, Giro d'Italia 2021](./assets/two_pt_prompt.jpg)
3439

3540
### Command-line flags
3641
- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default
3742
one.
38-
- `--point-x`, `--point-y`: specifies the location of the target point.
43+
- `--point-x`, `--point-y`: specify the locations of the target points as comma-separated lists.
3944
- `--threshold`: sets the threshold value to be part of the mask, a negative
4045
value results in a larger mask and can be specified via `--threshold=-1.2`.
Loading
Loading

candle-examples/examples/segment-anything/main.rs

+26-14
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@ struct Args {
2727
#[arg(long)]
2828
generate_masks: bool,
2929

30-
/// The target point x coordinate, between 0 and 1 (0.5 is at the middle of the image).
31-
#[arg(long, default_value_t = 0.5)]
32-
point_x: f64,
30+
/// Comma separated list of x coordinates, between 0 and 1 (0.5 is at the middle of the image).
31+
#[arg(long, use_value_delimiter = true)]
32+
point_x: Vec<f64>,
3333

34-
/// The target point y coordinate, between 0 and 1 (0.5 is at the middle of the image).
35-
#[arg(long, default_value_t = 0.5)]
36-
point_y: f64,
34+
/// Comma separated list of y coordinates, between 0 and 1 (0.5 is at the middle of the image).
35+
#[arg(long, use_value_delimiter = true)]
36+
point_y: Vec<f64>,
3737

3838
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
3939
/// mask, positive makes the mask more selective.
@@ -111,9 +111,16 @@ pub fn main() -> anyhow::Result<()> {
111111
)?;
112112
}
113113
} else {
114-
let point = Some((args.point_x, args.point_y));
114+
if args.point_x.len() != args.point_y.len() {
115+
anyhow::bail!(
116+
"number of x coordinates unequal to the number of y coordinates: {} v.s. {}",
117+
args.point_x.len(),
118+
args.point_y.len()
119+
);
120+
}
121+
let points: Vec<(f64, f64)> = args.point_x.into_iter().zip(args.point_y).collect();
115122
let start_time = std::time::Instant::now();
116-
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
123+
let (mask, iou_predictions) = sam.forward(&image, &points, false)?;
117124
println!(
118125
"mask generated in {:.2}s",
119126
start_time.elapsed().as_secs_f32()
@@ -151,12 +158,17 @@ pub fn main() -> anyhow::Result<()> {
151158
}
152159
}
153160
}
154-
let (x, y) = (
155-
(args.point_x * img.width() as f64) as i32,
156-
(args.point_y * img.height() as f64) as i32,
157-
);
158-
imageproc::drawing::draw_filled_circle(&img, (x, y), 3, image::Rgba([255, 0, 0, 200]))
159-
.save("sam_merged.jpg")?
161+
for (x, y) in points {
162+
let x = (x * img.width() as f64) as i32;
163+
let y = (y * img.height() as f64) as i32;
164+
imageproc::drawing::draw_filled_circle_mut(
165+
&mut img,
166+
(x, y),
167+
3,
168+
image::Rgba([255, 0, 0, 200]),
169+
);
170+
}
171+
img.save("sam_merged.jpg")?
160172
}
161173
Ok(())
162174
}

candle-transformers/src/models/segment_anything/sam.rs

+17-13
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ impl Sam {
130130
pub fn forward(
131131
&self,
132132
img: &Tensor,
133-
point: Option<(f64, f64)>,
133+
points: &[(f64, f64)],
134134
multimask_output: bool,
135135
) -> Result<(Tensor, Tensor)> {
136136
let (_c, original_h, original_w) = img.dims3()?;
@@ -140,7 +140,7 @@ impl Sam {
140140
&img_embeddings,
141141
original_h,
142142
original_w,
143-
point,
143+
points,
144144
multimask_output,
145145
)?;
146146
let mask = low_res_mask
@@ -155,20 +155,24 @@ impl Sam {
155155
img_embeddings: &Tensor,
156156
original_h: usize,
157157
original_w: usize,
158-
point: Option<(f64, f64)>,
158+
points: &[(f64, f64)],
159159
multimask_output: bool,
160160
) -> Result<(Tensor, Tensor)> {
161161
let image_pe = self.prompt_encoder.get_dense_pe()?;
162-
let points = match point {
163-
None => None,
164-
Some((x, y)) => {
165-
let points = Tensor::new(
166-
&[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
167-
img_embeddings.device(),
168-
)?;
169-
let labels = Tensor::ones((1, 1), DType::F32, img_embeddings.device())?;
170-
Some((points, labels))
171-
}
162+
let points = if points.is_empty() {
163+
None
164+
} else {
165+
let n_points = points.len();
166+
let mut coords = vec![];
167+
points.iter().for_each(|(x, y)| {
168+
let x = (*x as f32) * (original_w as f32);
169+
let y = (*y as f32) * (original_h as f32);
170+
coords.push(x);
171+
coords.push(y);
172+
});
173+
let points = Tensor::from_vec(coords, (n_points, 1, 2), img_embeddings.device())?;
174+
let labels = Tensor::ones((n_points, 1), DType::F32, img_embeddings.device())?;
175+
Some((points, labels))
172176
};
173177
let points = points.as_ref().map(|(x, y)| (x, y));
174178
let (sparse_prompt_embeddings, dense_prompt_embeddings) =

candle-wasm-examples/segment-anything/src/bin/m.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ impl Model {
9494
&embeddings.data,
9595
embeddings.height as usize,
9696
embeddings.width as usize,
97-
Some((x, y)),
97+
&[(x, y)],
9898
false,
9999
)?;
100100
let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];

0 commit comments

Comments
 (0)