Add SAM.
zhujiaxing committed Oct 28, 2024
1 parent ed9b849 commit adb7337
Showing 66 changed files with 8,449 additions and 0 deletions.
243 changes: 243 additions & 0 deletions official/cv/segment-anything/README.md
@@ -0,0 +1,243 @@
# Segment Anything

The **Segment Anything Model (SAM)** produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a [dataset](https://segment-anything.com/dataset/index.html) of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.

## Installation

The code requires `python>=3.7` and supports the Ascend platform. The important pre-dependencies are:
1. mindspore: please follow the instructions [here](https://www.mindspore.cn/install) to install the mindspore dependencies.
2. mindformers: please follow the instructions [here](https://gitee.com/mindspore/mindformers) to install mindformers from source.

Clone the repository locally and install with

```shell
git clone https://github.com/Mark-ZhouWX/models.git
cd models/official/cv/segment-anything
pip install -r requirements.txt
```
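
Once installed, a quick import check can confirm the environment is usable. This is only a minimal sketch; it assumes `sam_model_registry` is a plain dict keyed by model scale, as in the reference implementation:

```python
# Minimal environment check (illustrative; assumes the installation above succeeded).
import mindspore as ms
from segment_anything.build_sam import sam_model_registry

print("MindSpore version:", ms.__version__)
print("Available SAM scales:", list(sam_model_registry.keys()))  # e.g. 'vit_b', 'vit_l', 'vit_h'
```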

## Finetune

Finetuning is a popular method for adapting a large pretrained model to specific downstream tasks. Currently, finetuning with box, point, and text prompts is supported.

*Note that finetuning of SAM is not open-sourced in the [official PyTorch implementation](https://github.com/facebookresearch/segment-anything).
In this repository, finetuning is an experimental feature and still under improvement.*

### Finetune with box-prompt
Bounding boxes are used as the prompt input to predict masks.
Besides fine-tuning on the COCO2017 dataset, which contains commonly seen objects and lies in a distribution similar to the original [training dataset](https://segment-anything.com/dataset/index.html) of SAM, we have run further experiments on a medical image segmentation dataset, [FLARE22](https://flare22.grand-challenge.org/Dataset/). The results show that the finetuning method in this repository is effective.

The table below shows the mask quality before and after finetuning.


| pretrained_model | dataset | epochs | mIOU | ckpt |
|:----------------:| -------- |:-------------:|------|--------------------------------------------------------------------------------------------------------------|
| sam-vit-b | COCO2017 | 0 (zero-shot) | 74.5 | |
| sam-vit-b | COCO2017 | 8 | 80.1 | [link](https://download-mindspore.osinfra.cn/toolkits/mindone/sam/sam_vitb_box_finetune_coco-a9b75828.ckpt) |
| sam-vit-b | FLARE22 | 0 (zero-shot) | 78.6 | |
| sam-vit-b | FLARE22 | 20 | 88.4 | [link](https://download-mindspore.osinfra.cn/toolkits/mindone/sam/sam_vitb_box_finetune_flare-ace06cc2.ckpt) |
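
For reference, mIOU here is the intersection-over-union between predicted and ground-truth masks averaged over instances; the exact aggregation is handled by `eval.py`, but a per-mask sketch looks like this (illustrative only):

```python
import numpy as np

def mask_iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """IoU between two binary masks; averaging this over instances gives mIOU."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    union = np.logical_or(pred, gt).sum()
    if union == 0:  # both masks empty: count as a perfect match
        return 1.0
    return float(np.logical_and(pred, gt).sum()) / float(union)
```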

A machine with **32 GB of Ascend memory** is required for box-prompt finetuning.

For standalone finetuning on the COCO dataset, please run:
```shell
python train.py -c configs/coco_box_finetune.yaml -o amp_level=O2
```

For distributed finetuning on the COCO dataset, please run:
```shell
msrun --worker_num=8 --local_worker_num=8 train.py -c configs/coco_box_finetune.yaml -o amp_level=O2
```
The fine-tuned model will be saved at the `work_root` specified in `configs/coco_box_finetune.yaml`. To evaluate the model, please run:
```shell
python eval.py -c configs/coco_box_finetune.yaml -o amp_level=O2 network.model.checkpoint=your/path/to/ckpt
```
For fast single-image inference, please run:
```shell
python box_inference.py --amp_level=O2 --checkpoint=your/path/to/ckpt
```
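
Programmatically, box-prompt inference boils down to the steps below, condensed from `box_inference.py` in this commit (the image path, box coordinates, and checkpoint path are placeholders):

```python
import cv2
import numpy as np
import mindspore as ms

from segment_anything.build_sam import sam_model_registry
from segment_anything.dataset.transform import TransformPipeline, ImageNorm, ImageResizeAndPad

# Resize/pad to 1024 and normalize, as in box_inference.py
transform = TransformPipeline([ImageResizeAndPad(target_size=1024, apply_mask=False), ImageNorm()])
image_np = cv2.cvtColor(cv2.imread('./images/truck.jpg'), cv2.COLOR_BGR2RGB)
boxes_np = np.array([[425, 600, 700, 875]])  # one xyxy box prompt
data = transform(dict(image=image_np, boxes=boxes_np))

network = sam_model_registry['vit_b'](checkpoint='your/path/to/ckpt')
mask_logits = network(ms.Tensor(data['image']).unsqueeze(0),
                      boxes=ms.Tensor(data['boxes']).unsqueeze(0))[0]  # (1, 1, 1024, 1024)
binary_mask = mask_logits.asnumpy()[0, 0] > 0.0  # threshold logits to a binary mask
```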

The original FLARE22 dataset contains images in 3D format with ground truth labelled as instance segmentation IDs. Run

```shell
python scripts/preprocess_CT_MR_dataset.py
```

to preprocess it into 2D RGB images and binary masks.
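
Conceptually, the label side of this preprocessing splits each instance-labelled slice into per-organ binary masks; a minimal sketch of that step is shown below (illustrative only; the actual 3D reading, slicing, and saving are handled by `scripts/preprocess_CT_MR_dataset.py`):

```python
from typing import Dict

import numpy as np

def instance_ids_to_binary_masks(label_slice: np.ndarray) -> Dict[int, np.ndarray]:
    """Split a 2D slice labelled with instance segmentation ids into binary masks."""
    masks = {}
    for inst_id in np.unique(label_slice):
        if inst_id == 0:  # assumption: id 0 is background
            continue
        masks[int(inst_id)] = (label_slice == inst_id).astype(np.uint8)
    return masks
```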

The remaining steps are similar to the COCO dataset finetuning; please refer to the description above.

Here are examples of segmentation results predicted by the box-prompt fine-tuned SAM:

<div align="center">
<img src="./images/coco_bear.jpg" height="350" />

<img src="images/flare_organ.jpg" height="350" />
</div>

<p align="center">
<em> COCO2017 image example</em>


<em> FLARE22 image example </em>
</p>

### Finetune with point-prompt
A point, together with the mask output from the previous step, is used as the prompt input to predict a mask.
We follow the iterative interactive training schedule described in the official SAM paper. First, a foreground point is sampled uniformly from the ground truth mask. After the first prediction,
subsequent points are selected uniformly from the error region between the previous mask prediction and the ground truth mask. Each new point is labelled as foreground or background depending on whether the error region is a false negative or a false positive.
The mask prediction from the previous iteration is also used as an additional prompt. To encourage the model to benefit from the supplied mask, several more iterations are added in which no additional points are sampled.
The total number of iterations and the positions where these mask-only iterations are inserted are configurable.
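
As a rough illustration of the point-sampling rule described above (not the repository's actual training code), one corrective point per iteration can be drawn from the error region like this:

```python
import numpy as np

def sample_next_point(pred_mask: np.ndarray, gt_mask: np.ndarray, rng=np.random):
    """Sample one corrective point from the error region between prediction and ground truth.

    Returns (y, x, label): label 1 marks a foreground point (false-negative pixel),
    label 0 a background point (false-positive pixel).
    """
    error = pred_mask.astype(bool) ^ gt_mask.astype(bool)  # symmetric difference
    ys, xs = np.nonzero(error)
    if len(ys) == 0:  # perfect prediction: fall back to any foreground pixel
        ys, xs = np.nonzero(gt_mask)
    idx = rng.randint(len(ys))
    y, x = int(ys[idx]), int(xs[idx])
    label = 1 if gt_mask[y, x] else 0
    return y, x, label
```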

Since the original training dataset (SA-1B) consists mostly of common objects, we use the medical image segmentation dataset [FLARE22](https://flare22.grand-challenge.org/Dataset/) (preprocess the raw dataset as described in the previous section) for the finetuning experiment.
We note that the SAM model shows strong zero-shot ability, and for most downstream datasets the finetuning process may mainly learn the labelling bias.

For standalone finetuning on the FLARE22 dataset, please run:
```shell
python train.py -c configs/sa1b_point_finetune.yaml
```

For distributed finetuning on the FLARE22 dataset, please run:
```shell
msrun --worker_num=8 --local_worker_num=8 train.py -c configs/sa1b_point_finetune.yaml
```

The fine-tuned model will be saved at the `work_root` specified in `configs/sa1b_point_finetune.yaml`. For fast single-image inference, please run:

```shell
python point_inference.py --checkpoint=your/path/to/ckpt
```

Below is an experimental result batch-prompted with 5 points; the model is trained at the `vit_b` scale. The checkpoint can be downloaded [here](https://download-mindspore.osinfra.cn/toolkits/mindone/sam/sam_vitb_point_finetune_flare-898ae8f6.ckpt).
<div align="center">
<img alt="img.png" src="images/tumor2_5point.png" width="600"/>
</div>

Explore more interesting applications, such as iterative positive and negative point prompting, described in the Demo chapter below.

### Finetune with text-prompt
*Note again that text-to-mask finetuning is exploratory and not robust, and the official PyTorch code has not been released yet.*


The training procedure described in the official SAM paper is quite interesting in that it does not require new text annotations. Specifically, for each manually collected mask with an area larger than 100² pixels, we extract the CLIP image embedding. Then, during training, we prompt SAM
with the extracted CLIP image embedding as the text-prompt input. At inference time, we run the text through CLIP's text encoder and give the resulting text embedding as a prompt to SAM.

The key that makes this training procedure work is that CLIP's image embeddings are trained to align with its text embeddings.
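
A schematic of this asymmetry between training and inference, with all encoder names as placeholders rather than this repository's actual API:

```python
# Schematic only: `clip_image_encoder` and `clip_text_encoder` are placeholder callables.
def training_text_prompt(clip_image_encoder, mask_crop):
    # During training, the CLIP image embedding of the cropped mask region
    # stands in for a text prompt, so no text annotations are required.
    return clip_image_encoder(mask_crop)

def inference_text_prompt(clip_text_encoder, text):
    # At inference, the CLIP text embedding is substituted; this works because
    # CLIP aligns its image and text embedding spaces.
    return clip_text_encoder(text)
```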

This repository provides an implementation of text-to-mask finetuning following the model structure and training procedure described in the official SAM paper, and introduces a stronger multimodal encoder, BLIP2, in addition to CLIP.

A machine with **64 GB of Ascend memory** is required for text-prompt finetuning.

First, download the SA-1B dataset and put it under `${project_root}/datasets/sa-1b`.

For standalone finetuning on the SA-1B dataset with BLIP2 (CLIP is similar), please run:
```shell
python train.py -c configs/sa1b_text_finetune_blip2.yaml
```
The BLIP2 checkpoint and the BERT `vocabulary.txt` will be automatically downloaded to `./checkpoint_download/`.

For distributed finetuning, please run:
```shell
msrun --worker_num=8 --local_worker_num=8 train.py -c configs/sa1b_text_finetune_blip2.yaml
```
The fine-tuned model will be saved at the `work_root` specified in `configs/sa1b_text_finetune.yaml`. For fast single-image inference, please run:

```shell
python text_inference.py --checkpoint=your/path/to/ckpt --text-prompt your_prompt
```

## Demo

First download the weights ([sam_vit_b](https://download.mindspore.cn/toolkits/mindone/sam/sam_vit_b-35e4849c.ckpt), [sam_vit_l](https://download.mindspore.cn/toolkits/mindone/sam/sam_vit_l-1b460f38.ckpt), [sam_vit_h](https://download.mindspore.cn/toolkits/mindone/sam/sam_vit_h-c72f8ba1.ckpt)) and put them under the `${project_root}/models` directory.
There are two recommended ways to use SAM.

### Using SAM with prompts

#### Predict one object at a time

1. Points

SAM predicts object masks given prompts that indicate the desired object. If a point prompt is given, three plausible masks are generated.

```shell
python demo/inference_with_promts.py --prompt-type point --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt
```

<p float="left">
<img src=images/truck_mask1.png width="400"/><img src=images/truck_mask2.png width="400"/><img src=images/truck_mask3.png width="400"/>
</p>

If a prompt with two points is given, one plausible mask is generated instead of three, because there is less ambiguity than with a single-point prompt.
The green and red stars denote positive and negative points, respectively.

<div align="center">
<img alt="img.png" src="images/truck_two_point.png" width="600"/>
</div>

2. One box

If a box prompt is given, one plausible mask is generated.

```shell
python demo/inference_with_promts.py --prompt-type box --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt
```

<div align="center">
<img alt="img.png" width="600" src="images/truck_box.png"/>
</div>

3. One box and one point

If a prompt with both a box and a point is given, one plausible mask is generated.

```shell
python demo/inference_with_promts.py --prompt-type point_box --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt
```

<div align="center">
<img alt="img.png" width="600" src="images/truck_point_box.png"/>
</div>

#### Predict multiple objects at a time in a batched way

1. Batch points

```shell
python demo/inference_with_promts.py --prompt-type batch_point --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt
```

<div align="center">
<img alt="img.png" src="images/truck_batch_point.png" width="600"/>
</div>

2. Batch boxes

```shell
python demo/inference_with_promts.py --prompt-type batch_box --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt
```

<div align="center">
<img alt="img.png" width="600" src="images/truck_batch_box.png"/>
</div>

3. Batch boxes and points

```shell
python demo/inference_with_promts.py --prompt-type batch_point_box --model-type vit_h --checkpoint models/sam_vit_h-c72f8ba1.ckpt
```

<div align="center">
<img alt="img.png" width="600" src="images/truck_batch_point_box.png"/>
</div>

See `python demo/inference_with_promts.py --help` to explore more custom settings.

### Using SAM with Automatic Mask Generation (AMG)

Since SAM can efficiently process prompts, masks for the entire image can be generated by sampling a large number of prompts over an image. AMG works by sampling single-point input prompts in a grid over the image, from each of which SAM can predict multiple masks. Then, masks are filtered for quality and deduplicated using non-maximal suppression. Additional options allow for further improvement of mask quality and quantity, such as running prediction on multiple crops of the image or postprocessing masks to remove small disconnected regions and holes.

```shell
python demo/inference_with_amg.py --model-type vit_h
```
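
For intuition, the grid of single-point prompts that AMG samples over the image can be built roughly as follows (a sketch in normalized coordinates; the demo script's own sampler may differ in details):

```python
import numpy as np

def build_point_grid(n_per_side: int) -> np.ndarray:
    """Evenly spaced (x, y) points in [0, 1]^2 used to prompt the whole image."""
    offset = 1.0 / (2 * n_per_side)
    coords = np.linspace(offset, 1.0 - offset, n_per_side)
    xs, ys = np.meshgrid(coords, coords)
    return np.stack([xs.ravel(), ys.ravel()], axis=-1)  # shape (n_per_side**2, 2)
```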

See `python demo/inference_with_amg.py --help` to explore more custom settings.
88 changes: 88 additions & 0 deletions official/cv/segment-anything/box_inference.py
@@ -0,0 +1,88 @@
import argparse

import cv2
import numpy as np

import mindspore as ms

from segment_anything.build_sam import sam_model_registry
from segment_anything.dataset.transform import TransformPipeline, ImageNorm, ImageResizeAndPad
import matplotlib.pyplot as plt

from segment_anything.utils.utils import Timer
from segment_anything.utils.visualize import show_mask, show_box


def infer(args):
    ms.context.set_context(mode=args.mode, device_target=args.device)

    # Step1: data preparation
    with Timer('preprocess'):
        transform_list = [
            ImageResizeAndPad(target_size=1024, apply_mask=False),
            ImageNorm(),
        ]
        transform_pipeline = TransformPipeline(transform_list)

        image_path = args.image_path
        image_np = cv2.imread(image_path)
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        boxes_np = np.array([[425, 600, 700, 875]])

        transformed = transform_pipeline(dict(image=image_np, boxes=boxes_np))
        image, boxes, origin_hw = transformed['image'], transformed['boxes'], transformed['origin_hw']
        # batch_size for speed test
        # image = ms.Tensor(np.expand_dims(image, 0).repeat(8, axis=0))  # b, 3, 1024, 1024
        # boxes = ms.Tensor(np.expand_dims(boxes, 0).repeat(8, axis=0))  # b, n, 4
        image = ms.Tensor(image).unsqueeze(0)  # b, 3, 1024, 1024
        boxes = ms.Tensor(boxes).unsqueeze(0)  # b, n, 4

    # Step2: inference
    with Timer('model inference'):
        with Timer('load weight and build net'):
            network = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
            ms.amp.auto_mixed_precision(network=network, amp_level=args.amp_level)
        mask_logits = network(image, boxes=boxes)[0]  # (1, 1, 1024, 1024)

    with Timer('Second time inference'):
        mask_logits = network(image, boxes=boxes)[0]  # (1, 1, 1024, 1024)

    # Step3: post-process
    with Timer('post-process'):
        mask_logits = mask_logits.asnumpy()[0, 0] > 0.0
        mask_logits = mask_logits.astype(np.uint8)
        final_mask = cv2.resize(mask_logits[:origin_hw[2], :origin_hw[3]], tuple((origin_hw[1], origin_hw[0])),
                                interpolation=cv2.INTER_CUBIC)

    # Step4: visualize
    plt.imshow(image_np)
    show_box(boxes_np[0], plt.gca())
    show_mask(final_mask, plt.gca())
    plt.savefig(args.image_path + '_infer.jpg')
    plt.show()

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Runs inference on one image")
    parser.add_argument("--image_path", type=str, default='./images/truck.jpg', help="Path to an input image.")
    parser.add_argument(
        "--model-type",
        type=str,
        default='vit_b',
        help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b']",
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        default='./models/sam_vit_b-35e4849c.ckpt',
        help="Path to the SAM checkpoint to load.",
    )

    parser.add_argument("--device", type=str, default="Ascend", help="The device to run generation on.")
    parser.add_argument("--amp_level", type=str, default="O0", help="auto mixed precision level O0, O2.")
    parser.add_argument("--mode", type=int, default=0, help="MindSpore context mode. 0 for graph, 1 for pynative.")

    args = parser.parse_args()
    print(args)
    infer(args)