-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsegment.py
More file actions
38 lines (22 loc) · 986 Bytes
/
segment.py
File metadata and controls
38 lines (22 loc) · 986 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import functools

import numpy as np
from huggingface_hub import hf_hub_download
from PIL import Image

from pad import padding
from sam.segment_anything import sam_model_registry, SamAutomaticMaskGenerator
# Hugging Face repo / file holding the SAM ViT-B checkpoint, and the matching
# registry key. The checkpoint file is specific to the "vit_b" architecture.
_SAM_REPO_ID = "ybelkada/segment-anything"
_SAM_CHECKPOINT_FILE = "checkpoints/sam_vit_b_01ec64.pth"
_SAM_MODEL_TYPE = "vit_b"


@functools.lru_cache(maxsize=None)  # key space is just the device string(s) used
def _load_mask_generator(device):
    """Build a SAM automatic mask generator on *device*, once per device.

    The checkpoint path comes from ``hf_hub_download``. The model is built from
    the registry and moved to *device*. Caching means repeated ``segmentation``
    calls reuse the loaded model instead of rebuilding it each time.
    """
    chkpt_path = hf_hub_download(_SAM_REPO_ID, _SAM_CHECKPOINT_FILE)
    sam = sam_model_registry[_SAM_MODEL_TYPE](checkpoint=chkpt_path)
    sam.to(device=device)
    return SamAutomaticMaskGenerator(sam)


def _spans_whole_image(bbox, max_x, max_y):
    """Return True if *bbox* starts near the origin and covers ~the whole frame.

    *bbox* is indexed as (x, y, w, h): elements 2 and 3 are compared against the
    frame width/height. This presumably matches SAM's XYWH "bbox" field; confirm
    against the segment-anything docs. Accepts x, y in (-1, 4) and a width/height
    within (-2, +3) of the frame size.
    """
    x, y, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
    return (
        -1 < x < 4
        and -1 < y < 4
        and max_x - 2 < w < max_x + 3
        and max_y - 2 < h < max_y + 3
    )


def segmentation(img0, device="cuda"):
    """Run SAM on a padded copy of *img0* and return the full-frame mask.

    Args:
        img0: Input image; anything ``np.array`` accepts (e.g. a PIL image).
        device: Device the SAM model is moved to. Defaults to ``"cuda"`` (the
            previous hard-coded value); pass e.g. ``"cpu"`` without a GPU.

    Returns:
        ``(image, out)``:
        - ``image`` is the padded NumPy array that was fed to SAM.
        - ``out`` is ``Image.fromarray`` of the "segmentation" entry of the last
          generated mask whose bbox spans the whole padded image.

    Raises:
        ValueError: if no generated mask spans the whole padded image. The
            previous code failed here with an ``UnboundLocalError`` on ``out``.
    """
    image = padding(np.array(img0))
    max_y, max_x = image.shape[0], image.shape[1]

    masks = _load_mask_generator(device).generate(image)
    matching = [m for m in masks if _spans_whole_image(m["bbox"], max_x, max_y)]
    if not matching:
        raise ValueError(
            f"SAM produced {len(masks)} masks but none spans the whole "
            f"{max_x}x{max_y} padded image"
        )
    # The original loop overwrote `out` on every match, so the last match wins.
    out = Image.fromarray(matching[-1]["segmentation"])
    return image, out