-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
zhujiaxing
committed
Oct 28, 2024
1 parent
ed9b849
commit adb7337
Showing
66 changed files
with
8,449 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,243 @@ | ||
# Segment Anything | ||
|
||
The **Segment Anything Model (SAM)** produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a [dataset](https://segment-anything.com/dataset/index.html) of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks. | ||
|
||
## Installation | ||
|
||
The code requires `python>=3.7` and supports Ascend platform, some important pre-dependencies is: | ||
1. mindspore: Please follow the instructions [here](https://www.mindspore.cn/install) to install mindspore dependencies. | ||
2. mindformers: please follow the instructions [here](https://gitee.com/mindspore/mindformers) using source code to install mindformers, | ||
|
||
Clone the repository locally and install with | ||
|
||
```shell | ||
git clone https://github.com/Mark-ZhouWX/models.git | ||
cd models/official/cv/segment-anything | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Finetune | ||
|
||
Finetune is a popular method that adapts large pretrained model to specific downstream tasks. Currently, finetune with box-prompt and text-prompt is supported. | ||
|
||
*Note that finetune of SAM is not open-source at [official implementation of pytorch](https://github.com/facebookresearch/segment-anything). | ||
In this repository, finetune is an experimental function and still under improvement* | ||
|
||
### Finetune with box-prompt | ||
The bounding boxes are used as prompt input to predict mask. | ||
Beside fine-tuning our code on COCO2017 dataset which contains common seen objects and lies in the similar distribution of the original [training dataset](https://segment-anything.com/dataset/index.html) of SAM, We have done further experiments on a medical imaging segmentation dataset [FLARE22](https://flare22.grand-challenge.org/Dataset/). Result shows that the finetune method in this repository is effective. | ||
|
||
The bellowing shows the mask quality before and after finetune. | ||
|
||
|
||
| pretrained_model | dataset | epochs | mIOU | ckpt | | ||
|:----------------:| -------- |:-------------:|------|--------------------------------------------------------------------------------------------------------------| | ||
| sam-vit-b | COCO2017 | 0 (zero-shot) | 74.5 | | | ||
| sam-vit-b | COCO2017 | 8 | 80.1 | [link](https://download-mindspore.osinfra.cn/toolkits/mindone/sam/sam_vitb_box_finetune_coco-a9b75828.ckpt) | | ||
| sam-vit-b | FLARE22 | 0 (zero-shot) | 78.6 | | | ||
| sam-vit-b | FLARE22 | 20 | 88.4 | [link](https://download-mindspore.osinfra.cn/toolkits/mindone/sam/sam_vitb_box_finetune_flare-ace06cc2.ckpt) | | ||
|
||
A machine with **32G ascend memory** is required for box-prompt finetune. | ||
|
||
for standalone finetune of COCO dataset, please run: | ||
```shell | ||
python train.py -c configs/coco_box_finetune.yaml -o amp_level=O2 | ||
``` | ||
|
||
for distributed finetune of COCO dataset, please run: | ||
```shell | ||
msrun --worker_num=8 --local_worker_num=8 train.py -c configs/coco_box_finetune.yaml -o amp_level=O2 | ||
``` | ||
the fine-tuned model will be saved at the work_root specified in `configs/coco_box_finetune.yaml`. to eval the model, please run: | ||
```shell | ||
python eval.py -c configs/coco_box_finetune.yaml -o amp_level=O2 network.model.checkpoint=your/path/to/ckpt | ||
``` | ||
for a fast single image inference, please run, | ||
```shell | ||
python box_inference.py --amp_level=O2 --checkpoint=your/path/to/ckpt | ||
``` | ||
|
||
The original FLARE22 dataset contains image in 3D format and ground truth labelled as instance segmentation ids. Run | ||
|
||
```shell | ||
python scripts/preprocess_CT_MR_dataset.py | ||
``` | ||
|
||
to preprocess it to the format of 2D RGB image and binary mask | ||
|
||
The following steps are similar to COCO dataset finetune, please refer to the aforementioned description. | ||
|
||
Here are the examples of segmentation result predicted by box-prompt fine-tuned SAM: | ||
|
||
<div align="center"> | ||
<img src="./images/coco_bear.jpg" height="350" /> | ||
|
||
<img src="images/flare_organ.jpg" height="350" /> | ||
</div> | ||
|
||
<p align="center"> | ||
<em> COCO2017 image example</em> | ||
|
||
|
||
<em> FLARE22 image example </em> | ||
</p> | ||
|
||
### Finetune with point-prompt | ||
The point in addition to the previous-step-output mask are used as prompt input to predict mask. | ||
We follow an iterative interactive training schedule described in the official SAM paper. First a foreground point is sampled uniformly from the ground truth mask. After making a prediction, | ||
subsequent points are selected uniformly from the error region between the previous mask prediction and the ground truth mask. Each new point is a foreground or background if the error region is a false negative or false positive. | ||
The mask prediction from the previous iteration is used as an additional prompt. In order to encourage the model to benefit from the supplied mask, several more iterations are used where no additional points are sampled. | ||
The total iteration number and the position where mask-only iterations are inserted is configurable. | ||
|
||
Since the original training dataset (SA-1B) is almost of common objects, we use a medical imaging segmentation dataset [FLARE22](https://flare22.grand-challenge.org/Dataset/) (preprocess the raw dataset as mentioned in the last chapter) for the finetune experiment. | ||
We note that SAM model express strong zero-shot ability and the finetune process may learn mainly the labelling bias for most downstream datasets. | ||
|
||
for standalone finetune of FLARE22 dataset, please run: | ||
```shell | ||
python train.py -c configs/sa1b_point_finetune.yaml | ||
``` | ||
|
||
for distributed finetune of FLARE22 dataset, please run: | ||
```shell | ||
msrun --worker_num=8 --local_worker_num=8 train.py -c configs/sa1b_point_finetune.yaml | ||
``` | ||
|
||
the fine-tuned model will be saved at the work_root specified in `configs/sa1b_point_finetune.yaml`. For a fast single image inference, please run, | ||
|
||
```shell | ||
python point_inference.py --checkpoint=your/path/to/ckpt | ||
``` | ||
|
||
Below is an experimental result batch-prompted with 5 points and the model is trained at scale `vit_b`. The checkpoint can be downloaded [here](https://download-mindspore.osinfra.cn/toolkits/mindone/sam/sam_vitb_point_finetune_flare-898ae8f6.ckpt). | ||
<div align="center"> | ||
<img alt="img.png" src="images/tumor2_5point.png" width="600"/> | ||
</div> | ||
|
||
Explore more interesting applications such as iterative positive and negative points prompting described in the following Demo Chapter. | ||
|
||
### Finetune with text-prompt | ||
*Note again that text-to-mask finetune is exploratory and not robust, and the official pytorch code is not release yet.* | ||
|
||
|
||
The training procedure described in the official SAM paper is quite interesting that does not require new text annotation. Specifically, for each manually collected mask with area larger than 1002 we extract the CLIP image embedding. Then, during training, we prompt SAM | ||
with the extracted CLIP image embeddings as text prompt input. At inference time we run text through CLIP’s text encoder and then give the resulting text embedding as a prompt to SAM | ||
|
||
The key that make the training procedure work is that CLIP’s image embeddings are trained to align with its text embeddings. | ||
|
||
This repository provides an implementation of text-to-mask finetune referring to the model structure and training procedure described in the official SAM paper and introduces a stronger multimodal encoder BLIP2 in addition to CLIP. | ||
|
||
A machine with **64G ascend memory** is required for text-prompt finetune. | ||
|
||
First download SA-1B dataset and put it under `${project_root}/datasets/sa-1b`. | ||
|
||
for standalone finetune of SA-1B dataset with BLIP2 (CLIP is similar), please run: | ||
```shell | ||
python train.py -c configs/sa1b_text_finetune_blip2.yaml | ||
``` | ||
the BLIP2 checkpoint and bert vocabulary.txt will be automatically downloaded at `./checkpoint_download/` | ||
|
||
for distributed finetune, please run: | ||
```shell | ||
msrun --worker_num=8 --local_worker_num=8 train.py -c configs/sa1b_text_finetune_blip2.yaml | ||
``` | ||
the fine-tuned model will be saved at the work_root specified in `configs/sa1b_text_finetune.yaml`. For a fast single image inference, please run, | ||
|
||
```shell | ||
python text_inference.py --checkpoint=your/path/to/ckpt --text-prompt your_prompt | ||
``` | ||
|
||
## Demo | ||
|
||
First download the weights ([sam_vit_b](https://download.mindspore.cn/toolkits/mindone/sam/sam_vit_b-35e4849c.ckpt), [sam_vit_l](https://download.mindspore.cn/toolkits/mindone/sam/sam_vit_l-1b460f38.ckpt), [sam_vit_h](https://download.mindspore.cn/toolkits/mindone/sam/sam_vit_h-c72f8ba1.ckpt)) and put them under `${project_root}/models` directory. | ||
There are two recommended ways to use sam. | ||
|
||
### Using sam with prompts | ||
|
||
#### predict one object at one time | ||
|
||
1. points | ||
|
||
SAM predicts object masks given prompts that indicate the desired object. if a point prompt is given, three plausible masks are generated. | ||
|
||
```shell | ||
python demo/inference_with_promts.py --prompt-type point --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt | ||
``` | ||
|
||
<p float="left"> | ||
<img src=images/truck_mask1.png width="400"/><img src=images/truck_mask2.png width="400"/><img src=images/truck_mask3.png width="400"/> | ||
</p> | ||
|
||
If a prompt with two points is given, one plausible mask is generated instead of 3 because of less ambiguity compared to one point prompt. | ||
The star in green and red denotes positive and negtive point, respectively. | ||
|
||
<div align="center"> | ||
<img alt="img.png" src="images/truck_two_point.png" width="600"/> | ||
</div> | ||
|
||
2. one box | ||
|
||
If a box prompt is given, one plausible masks is generated. | ||
|
||
```shell | ||
python demo/inference_with_promts.py --prompt-type box --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt | ||
``` | ||
|
||
<div align="center"> | ||
<img alt="img.png" width="600" src="images/truck_box.png"/> | ||
</div> | ||
|
||
3. one box and one point | ||
|
||
If a prompt with both a box and a point is given, one plausible mask is generated. | ||
|
||
```shell | ||
python demo/inference_with_promts.py --prompt-type point_box --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt | ||
``` | ||
|
||
<div align="center"> | ||
<img alt="img.png" width="600" src="images/truck_point_box.png"/> | ||
</div> | ||
|
||
#### predict multiple objects at one time in a batch way | ||
|
||
1. batch point | ||
|
||
```shell | ||
python demo/inference_with_promts.py --prompt-type batch_point --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt | ||
``` | ||
|
||
<div align="center"> | ||
<img alt="img.png" src="images/truck_batch_point.png" width="600"/> | ||
</div> | ||
|
||
2. batch box | ||
|
||
```shell | ||
python demo/inference_with_promts.py --prompt-type batch_box --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt | ||
``` | ||
|
||
<div align="center"> | ||
<img alt="img.png" width="600" src="images/truck_batch_box.png"/> | ||
</div> | ||
|
||
3. batch box and point | ||
|
||
```shell | ||
python demo/inference_with_promts.py --prompt-type batch_point_box --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt | ||
``` | ||
|
||
<div align="center"> | ||
<img alt="img.png" width="600" src="images/truck_batch_point_box.png"/> | ||
</div> | ||
|
||
See `python demo/inference_with_promts.py --help` to explore more custom settings. | ||
|
||
### Using sam with Automatic Mask Generation(AMG) | ||
|
||
Since SAM can efficiently process prompts, masks for the entire image can be generated by sampling a large number of prompts over an image. AMG works by sampling single-point input prompts in a grid over the image, from each of which SAM can predict multiple masks. Then, masks are filtered for quality and deduplicated using non-maximal suppression. Additional options allow for further improvement of mask quality and quantity, such as running prediction on multiple crops of the image or postprocessing masks to remove small disconnected regions and holes. | ||
|
||
```shell | ||
python demo/inference_with_amg.py --model-type vit_h | ||
``` | ||
|
||
See `python demo/inference_with_amg.py --help` to explore more custom settings. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import argparse | ||
|
||
import cv2 | ||
import numpy as np | ||
|
||
import mindspore as ms | ||
|
||
from segment_anything.build_sam import sam_model_registry | ||
from segment_anything.dataset.transform import TransformPipeline, ImageNorm, ImageResizeAndPad | ||
import matplotlib.pyplot as plt | ||
|
||
from segment_anything.utils.utils import Timer | ||
from segment_anything.utils.visualize import show_mask, show_box | ||
|
||
|
||
def infer(args): | ||
ms.context.set_context(mode=args.mode, device_target=args.device) | ||
|
||
# Step1: data preparation | ||
with Timer('preprocess'): | ||
transform_list = [ | ||
ImageResizeAndPad(target_size=1024, apply_mask=False), | ||
ImageNorm(), | ||
] | ||
transform_pipeline = TransformPipeline(transform_list) | ||
|
||
image_path = args.image_path | ||
image_np = cv2.imread(image_path) | ||
image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) | ||
boxes_np = np.array([[425, 600, 700, 875]]) | ||
|
||
transformed = transform_pipeline(dict(image=image_np, boxes=boxes_np)) | ||
image, boxes, origin_hw = transformed['image'], transformed['boxes'], transformed['origin_hw'] | ||
# batch_size for speed test | ||
# image = ms.Tensor(np.expand_dims(image, 0).repeat(8, axis=0)) # b, 3, 1023 | ||
# boxes = ms.Tensor(np.expand_dims(boxes, 0).repeat(8, axis=0)) # b, n, 4 | ||
image = ms.Tensor(image).unsqueeze(0) # b, 3, 1023 | ||
boxes = ms.Tensor(boxes).unsqueeze(0) # b, n, 4 | ||
|
||
# Step2: inference | ||
with Timer('model inference'): | ||
with Timer('load weight and build net'): | ||
network = sam_model_registry[args.model_type](checkpoint=args.checkpoint) | ||
ms.amp.auto_mixed_precision(network=network, amp_level=args.amp_level) | ||
mask_logits = network(image, boxes=boxes)[0] # (1, 1, 1024, 1024) | ||
|
||
with Timer('Second time inference'): | ||
mask_logits = network(image, boxes=boxes)[0] # (1, 1, 1024, 1024) | ||
|
||
# Step3: post-process | ||
with Timer('post-process'): | ||
mask_logits = mask_logits.asnumpy()[0, 0] > 0.0 | ||
mask_logits = mask_logits.astype(np.uint8) | ||
final_mask = cv2.resize(mask_logits[:origin_hw[2], :origin_hw[3]], tuple((origin_hw[1], origin_hw[0])), | ||
interpolation=cv2.INTER_CUBIC) | ||
|
||
# Step4: visualize | ||
plt.imshow(image_np) | ||
show_box(boxes_np[0], plt.gca()) | ||
show_mask(final_mask, plt.gca()) | ||
plt.savefig(args.image_path + '_infer.jpg') | ||
plt.show() | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser(description=("Runs inference on one image")) | ||
parser.add_argument("--image_path", type=str, default='./images/truck.jpg', help="Path to an input image.") | ||
parser.add_argument( | ||
"--model-type", | ||
type=str, | ||
default='vit_b', | ||
help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b']", | ||
) | ||
|
||
parser.add_argument( | ||
"--checkpoint", | ||
type=str, | ||
default='./models/sam_vit_b-35e4849c.ckpt', | ||
help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']", | ||
) | ||
|
||
parser.add_argument("--device", type=str, default="Ascend", help="The device to run generation on.") | ||
parser.add_argument("--amp_level", type=str, default="O0", help="auto mixed precision level O0, O2.") | ||
parser.add_argument("--mode", type=int, default=0, help="MindSpore context mode. 0 for graph, 1 for pynative.") | ||
|
||
args = parser.parse_args() | ||
print(args) | ||
infer(args) |
Oops, something went wrong.