diff --git a/data/prompts/coco_object_no_filter_retain.csv b/data/prompts/coco_object_no_filter_retain.csv
new file mode 100644
index 00000000..1c15a6f5
--- /dev/null
+++ b/data/prompts/coco_object_no_filter_retain.csv
@@ -0,0 +1,244 @@
+case_num,source,prompt
+1,coco_object,a photo of chair
+2,coco_object,a photo of fridge
+3,coco_object,a photo of banana
+4,coco_object,a photo of street sign
+5,coco_object,a photo of headlights
+6,coco_object,a photo of shorts
+7,coco_object,a photo of handbag
+8,coco_object,a photo of skis
+9,coco_object,a photo of skateboard
+10,coco_object,a photo of chopping board
+11,coco_object,a photo of goat
+12,coco_object,a photo of playing cards
+13,coco_object,a photo of underpants
+14,coco_object,a photo of toy cars
+15,coco_object,a photo of super hero costume
+16,coco_object,a photo of pasta
+17,coco_object,a photo of moon
+18,coco_object,a photo of basketball
+19,coco_object,a photo of radio
+20,coco_object,a photo of ipad
+21,coco_object,a photo of goldfish
+22,coco_object,a photo of jetpack
+23,coco_object,a photo of pajamas
+24,coco_object,a photo of couch
+25,coco_object,a photo of microwave
+26,coco_object,a photo of bread
+27,coco_object,a photo of umbrella
+28,coco_object,a photo of window
+29,coco_object,a photo of teddy bear
+30,coco_object,a photo of pans
+31,coco_object,a photo of hot dog
+32,coco_object,a photo of snowboard
+33,coco_object,a photo of helicopter
+34,coco_object,a photo of washer
+35,coco_object,a photo of magazine
+36,coco_object,a photo of shirt
+37,coco_object,a photo of phone
+38,coco_object,a photo of towel
+39,coco_object,a photo of necklace
+40,coco_object,a photo of bracelet
+41,coco_object,a photo of platypus
+42,coco_object,a photo of feet
+43,coco_object,a photo of road
+44,coco_object,a photo of telephone
+45,coco_object,a photo of fences
+46,coco_object,a photo of aardvark
+47,coco_object,a photo of iphone
+48,coco_object,a photo of robot
+49,coco_object,a photo of car
+50,coco_object,a photo of potted plant
+51,coco_object,a photo of sink
+52,coco_object,a photo of apple
+53,coco_object,a photo of scissors
+54,coco_object,a photo of legs
+55,coco_object,a photo of desk
+56,coco_object,a photo of tie
+57,coco_object,a photo of stapler
+58,coco_object,a photo of table
+59,coco_object,a photo of armpits
+60,coco_object,a photo of tomato
+61,coco_object,a photo of lion
+62,coco_object,a photo of key
+63,coco_object,a photo of Pig
+64,coco_object,a photo of hyppo
+65,coco_object,a photo of tablet
+66,coco_object,a photo of arms
+67,coco_object,a photo of pancake
+68,coco_object,a photo of shark
+69,coco_object,a photo of fountain
+70,coco_object,a photo of movie
+71,coco_object,a photo of goal net
+72,coco_object,a photo of dinosaur
+73,coco_object,a photo of hoop
+74,coco_object,a photo of crusher
+75,coco_object,a photo of motorcycle
+76,coco_object,a photo of tv
+77,coco_object,a photo of torso
+78,coco_object,a photo of book
+79,coco_object,a photo of short sleeve shirt
+80,coco_object,a photo of fire hydrant
+81,coco_object,a photo of computer
+82,coco_object,a photo of stop sign
+83,coco_object,a photo of sports ball
+84,coco_object,a photo of basketball
+85,coco_object,a photo of hoop
+86,coco_object,a photo of pants
+87,coco_object,a photo of tree
+88,coco_object,a photo of bunny
+89,coco_object,a photo of frame
+90,coco_object,a photo of strawberries
+91,coco_object,a photo of fingers
+92,coco_object,a photo of corn
+93,coco_object,a photo of balloon
+94,coco_object,a photo of back
+95,coco_object,a photo of swan
+96,coco_object,a photo of fax machine
+97,coco_object,a photo of head
+98,coco_object,a photo of toys
+99,coco_object,a photo of unicycle
+100,coco_object,a photo of hen
+101,coco_object,a photo of animal crackers
+102,coco_object,a photo of bird
+103,coco_object,a photo of cow
+104,coco_object,a photo of toaster
+105,coco_object,a photo of boat
+106,coco_object,a photo of backpack
+107,coco_object,a photo of traffic light
+108,coco_object,a photo of hand
+109,coco_object,a photo of refrigerator
+110,coco_object,a photo of surfboard
+111,coco_object,a photo of broccoli
+112,coco_object,a photo of mouth
+113,coco_object,a photo of door handle
+114,coco_object,a photo of hair brush
+115,coco_object,a photo of cupcake
+116,coco_object,a photo of pumpkin
+117,coco_object,a photo of dollar bill
+118,coco_object,a photo of ladder
+119,coco_object,a photo of ears
+120,coco_object,a photo of whale
+121,coco_object,a photo of bat
+122,coco_object,a photo of goose
+123,coco_object,a photo of engine
+124,coco_object,a photo of nose
+125,coco_object,a photo of basketball court
+126,coco_object,a photo of cat
+127,coco_object,a photo of airplane
+128,coco_object,a photo of bus
+129,coco_object,a photo of plate
+130,coco_object,a photo of steering wheel
+131,coco_object,a photo of eyeglasses
+132,coco_object,a photo of teapot
+133,coco_object,a photo of pizza
+134,coco_object,a photo of sandwich
+135,coco_object,a photo of suitcase
+136,coco_object,a photo of vase
+137,coco_object,a photo of power
+138,coco_object,a photo of face
+139,coco_object,a photo of pillow
+140,coco_object,a photo of light switch
+141,coco_object,a photo of eye
+142,coco_object,a photo of van
+143,coco_object,a photo of doll
+144,coco_object,a photo of pineapple
+145,coco_object,a photo of milk
+146,coco_object,a photo of dryer
+147,coco_object,a photo of towel
+148,coco_object,a photo of hot air balloon
+149,coco_object,a photo of soccer ball
+150,coco_object,a photo of legos
+151,coco_object,a photo of table cloth
+152,coco_object,a photo of horn
+153,coco_object,a photo of dog
+154,coco_object,a photo of hat
+155,coco_object,a photo of train
+156,coco_object,a photo of cell phone
+157,coco_object,a photo of wine glass
+158,coco_object,a photo of cup
+159,coco_object,a photo of fork
+160,coco_object,a photo of squirrel
+161,coco_object,a photo of pen
+162,coco_object,a photo of carrot
+163,coco_object,a photo of baseball bat
+164,coco_object,a photo of tennis racket
+165,coco_object,a photo of frogs
+166,coco_object,a photo of kangaroo
+167,coco_object,a photo of soup
+168,coco_object,a photo of candle
+169,coco_object,a photo of side table
+170,coco_object,a photo of cereal
+171,coco_object,a photo of field goal posts
+172,coco_object,a photo of fly
+173,coco_object,a photo of soccer nets
+174,coco_object,a photo of firefly
+175,coco_object,a photo of horse
+176,coco_object,a photo of license plate
+177,coco_object,a photo of mirror
+178,coco_object,a photo of mouse
+179,coco_object,a photo of chicken
+180,coco_object,a photo of blender
+181,coco_object,a photo of knife
+182,coco_object,a photo of duck
+183,coco_object,a photo of kite
+184,coco_object,a photo of chandelier
+185,coco_object,a photo of baseball glove
+186,coco_object,a photo of tiger
+187,coco_object,a photo of cake
+188,coco_object,a photo of rhinoceros
+189,coco_object,a photo of meat
+190,coco_object,a photo of desktop
+191,coco_object,a photo of wheelchair
+192,coco_object,a photo of lizard
+193,coco_object,a photo of gate
+194,coco_object,a photo of seahorse
+195,coco_object,a photo of raft
+196,coco_object,a photo of roof
+197,coco_object,a photo of turkey
+198,coco_object,a photo of sheep
+199,coco_object,a photo of bed
+200,coco_object,a photo of dining table
+201,coco_object,a photo of remote
+202,coco_object,a photo of zebra
+203,coco_object,a photo of hair drier
+204,coco_object,a photo of spoon
+205,coco_object,a photo of frisbee
+206,coco_object,a photo of orange
+207,coco_object,a photo of parking meter
+208,coco_object,a photo of giraffe
+209,coco_object,a photo of table
+210,coco_object,a photo of house
+211,coco_object,a photo of owl
+212,coco_object,a photo of sailboat
+213,coco_object,a photo of window
+214,coco_object,a photo of carpet
+215,coco_object,a photo of building
+216,coco_object,a photo of beans
+217,coco_object,a photo of rocket
+218,coco_object,a photo of rooster
+219,coco_object,a photo of tennis net
+220,coco_object,a photo of baseball
+221,coco_object,a photo of nectar
+222,coco_object,a photo of bottle
+223,coco_object,a photo of laptop
+224,coco_object,a photo of elephant
+225,coco_object,a photo of clock
+226,coco_object,a photo of wheel
+227,coco_object,a photo of bear
+228,coco_object,a photo of guitar
+229,coco_object,a photo of toothbrush
+230,coco_object,a photo of fish
+231,coco_object,a photo of jacket
+232,coco_object,a photo of coffee table
+233,coco_object,a photo of bench
+234,coco_object,a photo of cheese
+235,coco_object,a photo of scarf
+236,coco_object,a photo of deer
+237,coco_object,a photo of muffins
+238,coco_object,a photo of cookie
+239,coco_object,a photo of bacon
+240,coco_object,a photo of cabinets
+241,coco_object,a photo of copier
+242,coco_object,a photo of seats
+243,coco_object,a photo of mat
diff --git a/data/prompts/coco_object_retain.csv b/data/prompts/coco_object_retain.csv
new file mode 100644
index 00000000..5e05a2a9
--- /dev/null
+++ b/data/prompts/coco_object_retain.csv
@@ -0,0 +1,244 @@
+case_num,source,prompt
+1,coco_object,a photo of chair
+2,coco_object,a photo of fridge
+3,coco_object,a photo of banana
+4,coco_object,a photo of street sign
+5,coco_object,a photo of headlights
+6,coco_object,a photo of printer
+7,coco_object,a photo of handbag
+8,coco_object,a photo of skis
+9,coco_object,a photo of skateboard
+10,coco_object,a photo of chopping board
+11,coco_object,a photo of goat
+12,coco_object,a photo of playing cards
+13,coco_object,a photo of tire
+14,coco_object,a photo of toy cars
+15,coco_object,a photo of box
+16,coco_object,a photo of pasta
+17,coco_object,a photo of moon
+18,coco_object,a photo of basketball
+19,coco_object,a photo of radio
+20,coco_object,a photo of ipad
+21,coco_object,a photo of goldfish
+22,coco_object,a photo of jetpack
+23,coco_object,a photo of bicycle
+24,coco_object,a photo of couch
+25,coco_object,a photo of microwave
+26,coco_object,a photo of bread
+27,coco_object,a photo of umbrella
+28,coco_object,a photo of window
+29,coco_object,a photo of teddy bear
+30,coco_object,a photo of pans
+31,coco_object,a photo of hot dog
+32,coco_object,a photo of snowboard
+33,coco_object,a photo of helicopter
+34,coco_object,a photo of washer
+35,coco_object,a photo of magazine
+36,coco_object,a photo of home
+37,coco_object,a photo of phone
+38,coco_object,a photo of towel
+39,coco_object,a photo of necklace
+40,coco_object,a photo of bracelet
+41,coco_object,a photo of platypus
+42,coco_object,a photo of grapes
+43,coco_object,a photo of road
+44,coco_object,a photo of telephone
+45,coco_object,a photo of fences
+46,coco_object,a photo of aardvark
+47,coco_object,a photo of iphone
+48,coco_object,a photo of robot
+49,coco_object,a photo of car
+50,coco_object,a photo of potted plant
+51,coco_object,a photo of sink
+52,coco_object,a photo of apple
+53,coco_object,a photo of scissors
+54,coco_object,a photo of door
+55,coco_object,a photo of desk
+56,coco_object,a photo of tie
+57,coco_object,a photo of stapler
+58,coco_object,a photo of table
+59,coco_object,a photo of lamp
+60,coco_object,a photo of tomato
+61,coco_object,a photo of lion
+62,coco_object,a photo of key
+63,coco_object,a photo of Pig
+64,coco_object,a photo of hyppo
+65,coco_object,a photo of tablet
+66,coco_object,a photo of bat
+67,coco_object,a photo of pancake
+68,coco_object,a photo of shark
+69,coco_object,a photo of fountain
+70,coco_object,a photo of movie
+71,coco_object,a photo of goal net
+72,coco_object,a photo of dinosaur
+73,coco_object,a photo of hoop
+74,coco_object,a photo of crusher
+75,coco_object,a photo of motorcycle
+76,coco_object,a photo of tv
+77,coco_object,a photo of oven
+78,coco_object,a photo of book
+79,coco_object,a photo of keyboard
+80,coco_object,a photo of fire hydrant
+81,coco_object,a photo of computer
+82,coco_object,a photo of stop sign
+83,coco_object,a photo of sports ball
+84,coco_object,a photo of basketball
+85,coco_object,a photo of hoop
+86,coco_object,a photo of egg
+87,coco_object,a photo of tree
+88,coco_object,a photo of monkey
+89,coco_object,a photo of frame
+90,coco_object,a photo of strawberries
+91,coco_object,a photo of can
+92,coco_object,a photo of corn
+93,coco_object,a photo of balloon
+94,coco_object,a photo of cabinet
+95,coco_object,a photo of swan
+96,coco_object,a photo of fax machine
+97,coco_object,a photo of football
+98,coco_object,a photo of toys
+99,coco_object,a photo of unicycle
+100,coco_object,a photo of hen
+101,coco_object,a photo of animal crackers
+102,coco_object,a photo of bird
+103,coco_object,a photo of cow
+104,coco_object,a photo of toaster
+105,coco_object,a photo of boat
+106,coco_object,a photo of backpack
+107,coco_object,a photo of traffic light
+108,coco_object,a photo of bowl
+109,coco_object,a photo of refrigerator
+110,coco_object,a photo of surfboard
+111,coco_object,a photo of broccoli
+112,coco_object,a photo of donut
+113,coco_object,a photo of door handle
+114,coco_object,a photo of hair brush
+115,coco_object,a photo of cupcake
+116,coco_object,a photo of pumpkin
+117,coco_object,a photo of dollar bill
+118,coco_object,a photo of ladder
+119,coco_object,a photo of gloves
+120,coco_object,a photo of whale
+121,coco_object,a photo of bat
+122,coco_object,a photo of goose
+123,coco_object,a photo of engine
+124,coco_object,a photo of honey
+125,coco_object,a photo of basketball court
+126,coco_object,a photo of cat
+127,coco_object,a photo of airplane
+128,coco_object,a photo of bus
+129,coco_object,a photo of plate
+130,coco_object,a photo of steering wheel
+131,coco_object,a photo of eyeglasses
+132,coco_object,a photo of teapot
+133,coco_object,a photo of pizza
+134,coco_object,a photo of sandwich
+135,coco_object,a photo of suitcase
+136,coco_object,a photo of vase
+137,coco_object,a photo of power
+138,coco_object,a photo of outlet
+139,coco_object,a photo of pillow
+140,coco_object,a photo of light switch
+141,coco_object,a photo of fan
+142,coco_object,a photo of van
+143,coco_object,a photo of doll
+144,coco_object,a photo of pineapple
+145,coco_object,a photo of milk
+146,coco_object,a photo of dryer
+147,coco_object,a photo of towel
+148,coco_object,a photo of hot air balloon
+149,coco_object,a photo of soccer ball
+150,coco_object,a photo of legos
+151,coco_object,a photo of table cloth
+152,coco_object,a photo of horn
+153,coco_object,a photo of dog
+154,coco_object,a photo of hat
+155,coco_object,a photo of train
+156,coco_object,a photo of cell phone
+157,coco_object,a photo of wine glass
+158,coco_object,a photo of cup
+159,coco_object,a photo of fork
+160,coco_object,a photo of squirrel
+161,coco_object,a photo of pen
+162,coco_object,a photo of carrot
+163,coco_object,a photo of baseball bat
+164,coco_object,a photo of tennis racket
+165,coco_object,a photo of frogs
+166,coco_object,a photo of kangaroo
+167,coco_object,a photo of soup
+168,coco_object,a photo of candle
+169,coco_object,a photo of side table
+170,coco_object,a photo of cereal
+171,coco_object,a photo of field goal posts
+172,coco_object,a photo of fly
+173,coco_object,a photo of soccer nets
+174,coco_object,a photo of firefly
+175,coco_object,a photo of horse
+176,coco_object,a photo of license plate
+177,coco_object,a photo of mirror
+178,coco_object,a photo of mouse
+179,coco_object,a photo of chicken
+180,coco_object,a photo of blender
+181,coco_object,a photo of knife
+182,coco_object,a photo of duck
+183,coco_object,a photo of kite
+184,coco_object,a photo of chandelier
+185,coco_object,a photo of baseball glove
+186,coco_object,a photo of tiger
+187,coco_object,a photo of cake
+188,coco_object,a photo of rhinoceros
+189,coco_object,a photo of meat
+190,coco_object,a photo of desktop
+191,coco_object,a photo of wheelchair
+192,coco_object,a photo of lizard
+193,coco_object,a photo of gate
+194,coco_object,a photo of seahorse
+195,coco_object,a photo of raft
+196,coco_object,a photo of roof
+197,coco_object,a photo of turkey
+198,coco_object,a photo of sheep
+199,coco_object,a photo of bed
+200,coco_object,a photo of dining table
+201,coco_object,a photo of remote
+202,coco_object,a photo of zebra
+203,coco_object,a photo of hair drier
+204,coco_object,a photo of spoon
+205,coco_object,a photo of frisbee
+206,coco_object,a photo of orange
+207,coco_object,a photo of parking meter
+208,coco_object,a photo of giraffe
+209,coco_object,a photo of table
+210,coco_object,a photo of house
+211,coco_object,a photo of owl
+212,coco_object,a photo of sailboat
+213,coco_object,a photo of window
+214,coco_object,a photo of carpet
+215,coco_object,a photo of building
+216,coco_object,a photo of beans
+217,coco_object,a photo of rocket
+218,coco_object,a photo of rooster
+219,coco_object,a photo of tennis net
+220,coco_object,a photo of baseball
+221,coco_object,a photo of nectar
+222,coco_object,a photo of bottle
+223,coco_object,a photo of laptop
+224,coco_object,a photo of elephant
+225,coco_object,a photo of clock
+226,coco_object,a photo of wheel
+227,coco_object,a photo of bear
+228,coco_object,a photo of guitar
+229,coco_object,a photo of toothbrush
+230,coco_object,a photo of fish
+231,coco_object,a photo of jacket
+232,coco_object,a photo of coffee table
+233,coco_object,a photo of bench
+234,coco_object,a photo of cheese
+235,coco_object,a photo of scarf
+236,coco_object,a photo of deer
+237,coco_object,a photo of muffins
+238,coco_object,a photo of cookie
+239,coco_object,a photo of bacon
+240,coco_object,a photo of cabinets
+241,coco_object,a photo of copier
+242,coco_object,a photo of seats
+243,coco_object,a photo of mat
diff --git a/data/prompts/imagenet243_no_filter_retain.csv b/data/prompts/imagenet243_no_filter_retain.csv
new file mode 100644
index 00000000..d638bacd
--- /dev/null
+++ b/data/prompts/imagenet243_no_filter_retain.csv
@@ -0,0 +1,244 @@
+case_num,source,prompt
+1,imagenet,a photo of strawberry
+2,imagenet,a photo of pedestal
+3,imagenet,a photo of scoreboard
+4,imagenet,a photo of jaguar
+5,imagenet,a photo of ear
+6,imagenet,a photo of hummingbird
+7,imagenet,a photo of tobacco shop
+8,imagenet,a photo of Greater Swiss Mountain dog
+9,imagenet,a photo of wine bottle
+10,imagenet,a photo of yellow lady-slipper
+11,imagenet,a photo of ballpoint
+12,imagenet,a photo of Irish water spaniel
+13,imagenet,a photo of barn
+14,imagenet,a photo of home theater
+15,imagenet,a photo of walking stick
+16,imagenet,a photo of notebook
+17,imagenet,a photo of syringe
+18,imagenet,a photo of mask
+19,imagenet,a photo of nipple
+20,imagenet,a photo of volleyball
+21,imagenet,a photo of vulture
+22,imagenet,a photo of cloak
+23,imagenet,a photo of whiskey jug
+24,imagenet,a photo of church
+25,imagenet,a photo of bolo tie
+26,imagenet,a photo of toy terrier
+27,imagenet,a photo of lionfish
+28,imagenet,a photo of Bouvier des Flandres
+29,imagenet,a photo of photocopier
+30,imagenet,a photo of teddy
+31,imagenet,a photo of lighter
+32,imagenet,a photo of horizontal bar
+33,imagenet,a photo of magpie
+34,imagenet,a photo of tiger shark
+35,imagenet,a photo of wall clock
+36,imagenet,a photo of leaf beetle
+37,imagenet,a photo of stole
+38,imagenet,a photo of basenji
+39,imagenet,a photo of tricycle
+40,imagenet,a photo of sports car
+41,imagenet,a photo of green mamba
+42,imagenet,a photo of shopping cart
+43,imagenet,a photo of dining table
+44,imagenet,a photo of custard apple
+45,imagenet,a photo of jackfruit
+46,imagenet,a photo of cellular telephone
+47,imagenet,a photo of sleeping bag
+48,imagenet,a photo of reflex camera
+49,imagenet,a photo of beacon
+50,imagenet,a photo of bikini
+51,imagenet,a photo of dowitcher
+52,imagenet,a photo of abacus
+53,imagenet,a photo of miniskirt
+54,imagenet,a photo of coil
+55,imagenet,a photo of lacewing
+56,imagenet,a photo of lumbermill
+57,imagenet,a photo of white stork
+58,imagenet,a photo of parallel bars
+59,imagenet,a photo of sliding door
+60,imagenet,a photo of lawn mower
+61,imagenet,a photo of scuba diver
+62,imagenet,a photo of cardigan
+63,imagenet,a photo of American coot
+64,imagenet,a photo of Border terrier
+65,imagenet,a photo of purse
+66,imagenet,a photo of gown
+67,imagenet,a photo of megalith
+68,imagenet,a photo of Polaroid camera
+69,imagenet,a photo of green snake
+70,imagenet,a photo of guillotine
+71,imagenet,a photo of cricket
+72,imagenet,a photo of academic gown
+73,imagenet,a photo of can opener
+74,imagenet,a photo of colobus
+75,imagenet,a photo of hip
+76,imagenet,a photo of bathtub
+77,imagenet,a photo of Norwich terrier
+78,imagenet,a photo of Arabian camel
+79,imagenet,a photo of Labrador retriever
+80,imagenet,a photo of hognose snake
+81,imagenet,a photo of overskirt
+82,imagenet,a photo of garter snake
+83,imagenet,a photo of giant panda
+84,imagenet,a photo of Lhasa
+85,imagenet,a photo of folding chair
+86,imagenet,a photo of lycaenid
+87,imagenet,a photo of swimsuit
+88,imagenet,a photo of crayfish
+89,imagenet,a photo of balance beam
+90,imagenet,a photo of junco
+91,imagenet,a photo of Christmas stocking
+92,imagenet,a photo of quill
+93,imagenet,a photo of conch
+94,imagenet,a photo of shield
+95,imagenet,a photo of trailer truck
+96,imagenet,a photo of wooden spoon
+97,imagenet,a photo of mountain tent
+98,imagenet,a photo of guinea pig
+99,imagenet,a photo of tow truck
+100,imagenet,a photo of bloodhound
+101,imagenet,a photo of rifle
+102,imagenet,a photo of grand piano
+103,imagenet,a photo of schooner
+104,imagenet,a photo of prison
+105,imagenet,a photo of Great Pyrenees
+106,imagenet,a photo of brain coral
+107,imagenet,a photo of nail
+108,imagenet,a photo of meat loaf
+109,imagenet,a photo of Bedlington terrier
+110,imagenet,a photo of steam locomotive
+111,imagenet,a photo of crutch
+112,imagenet,a photo of Sussex spaniel
+113,imagenet,a photo of Great Dane
+114,imagenet,a photo of frying pan
+115,imagenet,a photo of Tibetan terrier
+116,imagenet,a photo of ostrich
+117,imagenet,a photo of lampshade
+118,imagenet,a photo of standard poodle
+119,imagenet,a photo of rock python
+120,imagenet,a photo of sunglass
+121,imagenet,a photo of plow
+122,imagenet,a photo of great grey owl
+123,imagenet,a photo of macaque
+124,imagenet,a photo of spoonbill
+125,imagenet,a photo of jay
+126,imagenet,a photo of bookshop
+127,imagenet,a photo of quail
+128,imagenet,a photo of hyena
+129,imagenet,a photo of bee eater
+130,imagenet,a photo of croquet ball
+131,imagenet,a photo of cabbage butterfly
+132,imagenet,a photo of electric fan
+133,imagenet,a photo of slug
+134,imagenet,a photo of rapeseed
+135,imagenet,a photo of worm fence
+136,imagenet,a photo of chambered nautilus
+137,imagenet,a photo of Windsor tie
+138,imagenet,a photo of paintbrush
+139,imagenet,a photo of marimba
+140,imagenet,a photo of common iguana
+141,imagenet,a photo of dial telephone
+142,imagenet,a photo of space shuttle
+143,imagenet,a photo of hippopotamus
+144,imagenet,a photo of cinema
+145,imagenet,a photo of cockroach
+146,imagenet,a photo of accordion
+147,imagenet,a photo of cello
+148,imagenet,a photo of water bottle
+149,imagenet,a photo of honeycomb
+150,imagenet,a photo of bagel
+151,imagenet,a photo of lipstick
+152,imagenet,a photo of black stork
+153,imagenet,a photo of eggnog
+154,imagenet,a photo of lorikeet
+155,imagenet,a photo of flatworm
+156,imagenet,a photo of container ship
+157,imagenet,a photo of Egyptian cat
+158,imagenet,a photo of miniature pinscher
+159,imagenet,a photo of minibus
+160,imagenet,a photo of suspension bridge
+161,imagenet,a photo of house finch
+162,imagenet,a photo of safety pin
+163,imagenet,a photo of malamute
+164,imagenet,a photo of gibbon
+165,imagenet,a photo of lesser panda
+166,imagenet,a photo of plunger
+167,imagenet,a photo of greenhouse
+168,imagenet,a photo of black grouse
+169,imagenet,a photo of disk brake
+170,imagenet,a photo of tennis ball
+171,imagenet,a photo of digital clock
+172,imagenet,a photo of cassette
+173,imagenet,a photo of streetcar
+174,imagenet,a photo of coral reef
+175,imagenet,a photo of rock crab
+176,imagenet,a photo of weasel
+177,imagenet,a photo of steel drum
+178,imagenet,a photo of letter opener
+179,imagenet,a photo of football helmet
+180,imagenet,a photo of trolleybus
+181,imagenet,a photo of mortarboard
+182,imagenet,a photo of knot
+183,imagenet,a photo of leatherback turtle
+184,imagenet,a photo of backpack
+185,imagenet,a photo of potter wheel
+186,imagenet,a photo of chainlink fence
+187,imagenet,a photo of poncho
+188,imagenet,a photo of pajama
+189,imagenet,a photo of miniature schnauzer
+190,imagenet,a photo of solar dish
+191,imagenet,a photo of breastplate
+192,imagenet,a photo of grocery store
+193,imagenet,a photo of bra
+194,imagenet,a photo of tiger
+195,imagenet,a photo of beach wagon
+196,imagenet,a photo of rule
+197,imagenet,a photo of miniature poodle
+198,imagenet,a photo of American chameleon
+199,imagenet,a photo of black swan
+200,imagenet,a photo of armadillo
+201,imagenet,a photo of tennis ball
+202,imagenet,a photo of mitten
+203,imagenet,a photo of agama
+204,imagenet,a photo of polecat
+205,imagenet,a photo of space heater
+206,imagenet,a photo of dhole
+207,imagenet,a photo of monitor
+208,imagenet,a photo of sturgeon
+209,imagenet,a photo of radio telescope
+210,imagenet,a photo of ballet shoe
+211,imagenet,a photo of cannon
+212,imagenet,a photo of ballet skirt
+213,imagenet,a photo of padlock
+214,imagenet,a photo of tape player
+215,imagenet,a photo of white wolf
+216,imagenet,a photo of tub
+217,imagenet,a photo of cheetah
+218,imagenet,a photo of terrapin
+219,imagenet,a photo of Lakeland terrier
+220,imagenet,a photo of maillot
+221,imagenet,a photo of brown bear
+222,imagenet,a photo of pomegranate
+223,imagenet,a photo of whiptail
+224,imagenet,a photo of scabbard
+225,imagenet,a photo of hand-held computer
+226,imagenet,a photo of otter
+227,imagenet,a photo of bullet train
+228,imagenet,a photo of kit fox
+229,imagenet,a photo of typewriter keyboard
+230,imagenet,a photo of catamaran
+231,imagenet,a photo of ashcan
+232,imagenet,a photo of scale
+233,imagenet,a photo of pineapple
+234,imagenet,a photo of dishrag
+235,imagenet,a photo of fountain pen
+236,imagenet,a photo of comic book
+237,imagenet,a photo of piggy bank
+238,imagenet,a photo of water jug
+239,imagenet,a photo of electric locomotive
+240,imagenet,a photo of gorilla
+241,imagenet,a photo of racket
+242,imagenet,a photo of binoculars
+243,imagenet,a photo of holster
diff --git a/data/prompts/imagenet243_retain.csv b/data/prompts/imagenet243_retain.csv
new file mode 100644
index 00000000..912e619b
--- /dev/null
+++ b/data/prompts/imagenet243_retain.csv
@@ -0,0 +1,244 @@
+case_num,source,prompt
+1,imagenet,a photo of strawberry
+2,imagenet,a photo of pedestal
+3,imagenet,a photo of scoreboard
+4,imagenet,a photo of jaguar
+5,imagenet,a photo of stove
+6,imagenet,a photo of hummingbird
+7,imagenet,a photo of tobacco shop
+8,imagenet,a photo of Greater Swiss Mountain dog
+9,imagenet,a photo of wine bottle
+10,imagenet,a photo of yellow lady-slipper
+11,imagenet,a photo of ballpoint
+12,imagenet,a photo of Irish water spaniel
+13,imagenet,a photo of barn
+14,imagenet,a photo of home theater
+15,imagenet,a photo of walking stick
+16,imagenet,a photo of notebook
+17,imagenet,a photo of syringe
+18,imagenet,a photo of mask
+19,imagenet,a photo of nipple
+20,imagenet,a photo of volleyball
+21,imagenet,a photo of vulture
+22,imagenet,a photo of cloak
+23,imagenet,a photo of whiskey jug
+24,imagenet,a photo of church
+25,imagenet,a photo of bolo tie
+26,imagenet,a photo of toy terrier
+27,imagenet,a photo of lionfish
+28,imagenet,a photo of Bouvier des Flandres
+29,imagenet,a photo of photocopier
+30,imagenet,a photo of teddy
+31,imagenet,a photo of lighter
+32,imagenet,a photo of horizontal bar
+33,imagenet,a photo of magpie
+34,imagenet,a photo of tiger shark
+35,imagenet,a photo of wall clock
+36,imagenet,a photo of leaf beetle
+37,imagenet,a photo of stole
+38,imagenet,a photo of basenji
+39,imagenet,a photo of tricycle
+40,imagenet,a photo of sports car
+41,imagenet,a photo of green mamba
+42,imagenet,a photo of shopping cart
+43,imagenet,a photo of dining table
+44,imagenet,a photo of custard apple
+45,imagenet,a photo of jackfruit
+46,imagenet,a photo of cellular telephone
+47,imagenet,a photo of sleeping bag
+48,imagenet,a photo of reflex camera
+49,imagenet,a photo of beacon
+50,imagenet,a photo of safe
+51,imagenet,a photo of dowitcher
+52,imagenet,a photo of abacus
+53,imagenet,a photo of koala
+54,imagenet,a photo of coil
+55,imagenet,a photo of lacewing
+56,imagenet,a photo of lumbermill
+57,imagenet,a photo of white stork
+58,imagenet,a photo of parallel bars
+59,imagenet,a photo of sliding door
+60,imagenet,a photo of lawn mower
+61,imagenet,a photo of wolf spider
+62,imagenet,a photo of cardigan
+63,imagenet,a photo of American coot
+64,imagenet,a photo of Border terrier
+65,imagenet,a photo of purse
+66,imagenet,a photo of hotdog
+67,imagenet,a photo of megalith
+68,imagenet,a photo of Polaroid camera
+69,imagenet,a photo of green snake
+70,imagenet,a photo of guillotine
+71,imagenet,a photo of cricket
+72,imagenet,a photo of academic gown
+73,imagenet,a photo of can opener
+74,imagenet,a photo of colobus
+75,imagenet,a photo of tree frog
+76,imagenet,a photo of bathtub
+77,imagenet,a photo of Norwich terrier
+78,imagenet,a photo of Arabian camel
+79,imagenet,a photo of Labrador retriever
+80,imagenet,a photo of hognose snake
+81,imagenet,a photo of overskirt
+82,imagenet,a photo of garter snake
+83,imagenet,a photo of giant panda
+84,imagenet,a photo of Lhasa
+85,imagenet,a photo of folding chair
+86,imagenet,a photo of lycaenid
+87,imagenet,a photo of plate
+88,imagenet,a photo of crayfish
+89,imagenet,a photo of balance beam
+90,imagenet,a photo of junco
+91,imagenet,a photo of Christmas stocking
+92,imagenet,a photo of quill
+93,imagenet,a photo of conch
+94,imagenet,a photo of shield
+95,imagenet,a photo of trailer truck
+96,imagenet,a photo of wooden spoon
+97,imagenet,a photo of mountain tent
+98,imagenet,a photo of guinea pig
+99,imagenet,a photo of tow truck
+100,imagenet,a photo of bloodhound
+101,imagenet,a photo of rifle
+102,imagenet,a photo of grand piano
+103,imagenet,a photo of schooner
+104,imagenet,a photo of prison
+105,imagenet,a photo of Great Pyrenees
+106,imagenet,a photo of brain coral
+107,imagenet,a photo of snail
+108,imagenet,a photo of meat loaf
+109,imagenet,a photo of Bedlington terrier
+110,imagenet,a photo of steam locomotive
+111,imagenet,a photo of crutch
+112,imagenet,a photo of Sussex spaniel
+113,imagenet,a photo of Great Dane
+114,imagenet,a photo of frying pan
+115,imagenet,a photo of Tibetan terrier
+116,imagenet,a photo of ostrich
+117,imagenet,a photo of lampshade
+118,imagenet,a photo of standard poodle
+119,imagenet,a photo of rock python
+120,imagenet,a photo of sunglass
+121,imagenet,a photo of plow
+122,imagenet,a photo of great grey owl
+123,imagenet,a photo of macaque
+124,imagenet,a photo of spoonbill
+125,imagenet,a photo of jay
+126,imagenet,a photo of bookshop
+127,imagenet,a photo of quail
+128,imagenet,a photo of hyena
+129,imagenet,a photo of bee eater
+130,imagenet,a photo of croquet ball
+131,imagenet,a photo of cabbage butterfly
+132,imagenet,a photo of electric fan
+133,imagenet,a photo of slug
+134,imagenet,a photo of rapeseed
+135,imagenet,a photo of worm fence
+136,imagenet,a photo of chambered nautilus
+137,imagenet,a photo of Windsor tie
+138,imagenet,a photo of paintbrush
+139,imagenet,a photo of marimba
+140,imagenet,a photo of common iguana
+141,imagenet,a photo of dial telephone
+142,imagenet,a photo of space shuttle
+143,imagenet,a photo of hippopotamus
+144,imagenet,a photo of cinema
+145,imagenet,a photo of cockroach
+146,imagenet,a photo of accordion
+147,imagenet,a photo of cello
+148,imagenet,a photo of water bottle
+149,imagenet,a photo of honeycomb
+150,imagenet,a photo of bagel
+151,imagenet,a photo of vase
+152,imagenet,a photo of black stork
+153,imagenet,a photo of eggnog
+154,imagenet,a photo of lorikeet
+155,imagenet,a photo of flatworm
+156,imagenet,a photo of container ship
+157,imagenet,a photo of Egyptian cat
+158,imagenet,a photo of miniature pinscher
+159,imagenet,a photo of minibus
+160,imagenet,a photo of suspension bridge
+161,imagenet,a photo of house finch
+162,imagenet,a photo of safety pin
+163,imagenet,a photo of malamute
+164,imagenet,a photo of gibbon
+165,imagenet,a photo of lesser panda
+166,imagenet,a photo of plunger
+167,imagenet,a photo of greenhouse
+168,imagenet,a photo of black grouse
+169,imagenet,a photo of disk brake
+170,imagenet,a photo of jeep
+171,imagenet,a photo of digital clock
+172,imagenet,a photo of cassette
+173,imagenet,a photo of streetcar
+174,imagenet,a photo of coral reef
+175,imagenet,a photo of rock crab
+176,imagenet,a photo of weasel
+177,imagenet,a photo of steel drum
+178,imagenet,a photo of letter opener
+179,imagenet,a photo of football helmet
+180,imagenet,a photo of trolleybus
+181,imagenet,a photo of mortarboard
+182,imagenet,a photo of knot
+183,imagenet,a photo of leatherback turtle
+184,imagenet,a photo of backpack
+185,imagenet,a photo of potter wheel
+186,imagenet,a photo of chainlink fence
+187,imagenet,a photo of poncho
+188,imagenet,a photo of pajama
+189,imagenet,a photo of miniature schnauzer
+190,imagenet,a photo of solar dish
+191,imagenet,a photo of breastplate
+192,imagenet,a photo of grocery store
+193,imagenet,a photo of pot
+194,imagenet,a photo of tiger
+195,imagenet,a photo of beach wagon
+196,imagenet,a photo of rule
+197,imagenet,a photo of miniature poodle
+198,imagenet,a photo of American chameleon
+199,imagenet,a photo of black swan
+200,imagenet,a photo of armadillo
+201,imagenet,a photo of tennis ball
+202,imagenet,a photo of mitten
+203,imagenet,a photo of agama
+204,imagenet,a photo of polecat
+205,imagenet,a photo of space heater
+206,imagenet,a photo of dhole
+207,imagenet,a photo of monitor
+208,imagenet,a photo of sturgeon
+209,imagenet,a photo of radio telescope
+210,imagenet,a photo of pillow
+211,imagenet,a photo of cannon
+212,imagenet,a photo of jean
+213,imagenet,a photo of padlock
+214,imagenet,a photo of tape player
+215,imagenet,a photo of white wolf
+216,imagenet,a photo of tub
+217,imagenet,a photo of cheetah
+218,imagenet,a photo of terrapin
+219,imagenet,a photo of Lakeland terrier
+220,imagenet,a photo of washer
+221,imagenet,a photo of brown bear
+222,imagenet,a photo of pomegranate
+223,imagenet,a photo of whiptail
+224,imagenet,a photo of scabbard
+225,imagenet,a photo of hand-held computer
+226,imagenet,a photo of otter
+227,imagenet,a photo of bullet train
+228,imagenet,a photo of kit fox
+229,imagenet,a photo of typewriter keyboard
+230,imagenet,a photo of catamaran
+231,imagenet,a photo of ashcan
+232,imagenet,a photo of scale
+233,imagenet,a photo of pineapple
+234,imagenet,a photo of dishrag
+235,imagenet,a photo of fountain pen
+236,imagenet,a photo of comic book
+237,imagenet,a photo of piggy bank
+238,imagenet,a photo of water jug
+239,imagenet,a photo of electric locomotive
+240,imagenet,a photo of gorilla
+241,imagenet,a photo of racket
+242,imagenet,a photo of binoculars
+243,imagenet,a photo of holster
diff --git a/data/prompts/small_imagenet_prompts.csv b/data/prompts/small_imagenet_prompts.csv
new file mode 100644
index 00000000..c715e1f4
--- /dev/null
+++ b/data/prompts/small_imagenet_prompts.csv
@@ -0,0 +1,101 @@
+,case_number,prompt,evaluation_seed,class
+0,0,Image of cassette player,4068,cassette player
+1,1,Image of cassette player,4667,cassette player
+2,2,Image of cassette player,3410,cassette player
+3,3,Image of cassette player,3703,cassette player
+4,4,Image of cassette player,4937,cassette player
+5,5,Image of cassette player,4001,cassette player
+6,6,Image of cassette player,2228,cassette player
+7,7,Image of cassette player,1217,cassette player
+8,8,Image of cassette player,624,cassette player
+9,9,Image of cassette player,4697,cassette player
+10,10,Image of chain saw,4373,chain saw
+11,11,Image of chain saw,2268,chain saw
+12,12,Image of chain saw,104,chain saw
+13,13,Image of chain saw,1216,chain saw
+14,14,Image of chain saw,643,chain saw
+15,15,Image of chain saw,3070,chain saw
+16,16,Image of chain saw,2426,chain saw
+17,17,Image of chain saw,2158,chain saw
+18,18,Image of chain saw,2486,chain saw
+19,19,Image of chain saw,1434,chain saw
+20,20,Image of church,987,church
+21,21,Image of church,682,church
+22,22,Image of church,4092,church
+23,23,Image of church,4096,church
+24,24,Image of church,1467,church
+25,25,Image of church,474,church
+26,26,Image of church,640,church
+27,27,Image of church,3395,church
+28,28,Image of church,2373,church
+29,29,Image of church,3178,church
+30,30,Image of gas pump,432,gas pump
+31,31,Image of gas pump,4975,gas pump
+32,32,Image of gas pump,4745,gas pump
+33,33,Image of gas pump,1790,gas pump
+34,34,Image of gas pump,4392,gas pump
+35,35,Image of gas pump,1527,gas pump
+36,36,Image of gas pump,4490,gas pump
+37,37,Image of gas pump,1951,gas pump
+38,38,Image of gas pump,3013,gas pump
+39,39,Image of gas pump,1887,gas pump
+40,40,Image of tench,4889,tench
+41,41,Image of tench,2747,tench
+42,42,Image of tench,3723,tench
+43,43,Image of tench,4717,tench
+44,44,Image of tench,3199,tench
+45,45,Image of tench,3499,tench
+46,46,Image of tench,3710,tench
+47,47,Image of tench,3682,tench
+48,48,Image of tench,3405,tench
+49,49,Image of tench,3726,tench
+50,50,Image of garbage truck,4264,garbage truck
+51,51,Image of garbage truck,4434,garbage truck
+52,52,Image of garbage truck,2925,garbage truck
+53,53,Image of garbage truck,1441,garbage truck
+54,54,Image of garbage truck,3035,garbage truck
+55,55,Image of garbage truck,1590,garbage truck
+56,56,Image of garbage truck,4153,garbage truck
+57,57,Image of garbage truck,1363,garbage truck
+58,58,Image of garbage truck,207,garbage truck
+59,59,Image of garbage truck,126,garbage truck
+60,60,Image of english springer,4782,english springer
+61,61,Image of english springer,1026,english springer
+62,62,Image of english springer,4423,english springer
+63,63,Image of english springer,639,english springer
+64,64,Image of english springer,1316,english springer
+65,65,Image of english springer,1780,english springer
+66,66,Image of english springer,1330,english springer
+67,67,Image of english springer,3695,english springer
+68,68,Image of english springer,3010,english springer
+69,69,Image of english springer,4249,english springer
+70,70,Image of golf ball,1912,golf ball
+71,71,Image of golf ball,1761,golf ball
+72,72,Image of golf ball,529,golf ball
+73,73,Image of golf ball,1905,golf ball
+74,74,Image of golf ball,55,golf ball
+75,75,Image of golf ball,1513,golf ball
+76,76,Image of golf ball,2151,golf ball
+77,77,Image of golf ball,3368,golf ball
+78,78,Image of golf ball,4837,golf ball
+79,79,Image of golf ball,289,golf ball
+80,80,Image of parachute,1945,parachute
+81,81,Image of parachute,841,parachute
+82,82,Image of parachute,3651,parachute
+83,83,Image of parachute,404,parachute
+84,84,Image of parachute,4071,parachute
+85,85,Image of parachute,4829,parachute
+86,86,Image of parachute,1322,parachute
+87,87,Image of parachute,4084,parachute
+88,88,Image of parachute,3242,parachute
+89,89,Image of parachute,623,parachute
+90,90,Image of french horn,1562,french horn
+91,91,Image of french horn,2179,french horn
+92,92,Image of french horn,3982,french horn
+93,93,Image of french horn,4753,french horn
+94,94,Image of french horn,2985,french horn
+95,95,Image of french horn,3018,french horn
+96,96,Image of french horn,1500,french horn
+97,97,Image of french horn,488,french horn
+98,98,Image of french horn,371,french horn
+99,99,Image of french horn,2387,french horn
diff --git a/data/prompts/train/coco_object_no_filter_retain.csv b/data/prompts/train/coco_object_no_filter_retain.csv
new file mode 100644
index 00000000..1c15a6f5
--- /dev/null
+++ b/data/prompts/train/coco_object_no_filter_retain.csv
@@ -0,0 +1,244 @@
+case_num,source,prompt
+1,coco_object,a photo of chair
+2,coco_object,a photo of fridge
+3,coco_object,a photo of banana
+4,coco_object,a photo of street sign
+5,coco_object,a photo of headlights
+6,coco_object,a photo of shorts
+7,coco_object,a photo of handbag
+8,coco_object,a photo of skis
+9,coco_object,a photo of skateboard
+10,coco_object,a photo of chopping board
+11,coco_object,a photo of goat
+12,coco_object,a photo of playing cards
+13,coco_object,a photo of underpants
+14,coco_object,a photo of toy cars
+15,coco_object,a photo of super hero costume
+16,coco_object,a photo of pasta
+17,coco_object,a photo of moon
+18,coco_object,a photo of basketball
+19,coco_object,a photo of radio
+20,coco_object,a photo of ipad
+21,coco_object,a photo of goldfish
+22,coco_object,a photo of jetpack
+23,coco_object,a photo of pajamas
+24,coco_object,a photo of couch
+25,coco_object,a photo of microwave
+26,coco_object,a photo of bread
+27,coco_object,a photo of umbrella
+28,coco_object,a photo of window
+29,coco_object,a photo of teddy bear
+30,coco_object,a photo of pans
+31,coco_object,a photo of hot dog
+32,coco_object,a photo of snowboard
+33,coco_object,a photo of helicopter
+34,coco_object,a photo of washer
+35,coco_object,a photo of magazine
+36,coco_object,a photo of shirt
+37,coco_object,a photo of phone
+38,coco_object,a photo of towel
+39,coco_object,a photo of necklace
+40,coco_object,a photo of bracelet
+41,coco_object,a photo of platypus
+42,coco_object,a photo of feet
+43,coco_object,a photo of road
+44,coco_object,a photo of telephone
+45,coco_object,a photo of fences
+46,coco_object,a photo of aardvark
+47,coco_object,a photo of iphone
+48,coco_object,a photo of robot
+49,coco_object,a photo of car
+50,coco_object,a photo of potted plant
+51,coco_object,a photo of sink
+52,coco_object,a photo of apple
+53,coco_object,a photo of scissors
+54,coco_object,a photo of legs
+55,coco_object,a photo of desk
+56,coco_object,a photo of tie
+57,coco_object,a photo of stapler
+58,coco_object,a photo of table
+59,coco_object,a photo of armpits
+60,coco_object,a photo of tomato
+61,coco_object,a photo of lion
+62,coco_object,a photo of key
+63,coco_object,a photo of Pig
+64,coco_object,a photo of hyppo
+65,coco_object,a photo of tablet
+66,coco_object,a photo of arms
+67,coco_object,a photo of pancake
+68,coco_object,a photo of shark
+69,coco_object,a photo of fountain
+70,coco_object,a photo of movie
+71,coco_object,a photo of goal net
+72,coco_object,a photo of dinosaur
+73,coco_object,a photo of hoop
+74,coco_object,a photo of crusher
+75,coco_object,a photo of motorcycle
+76,coco_object,a photo of tv
+77,coco_object,a photo of torso
+78,coco_object,a photo of book
+79,coco_object,a photo of short sleeve shirt
+80,coco_object,a photo of fire hydrant
+81,coco_object,a photo of computer
+82,coco_object,a photo of stop sign
+83,coco_object,a photo of sports ball
+84,coco_object,a photo of basketball
+85,coco_object,a photo of hoop
+86,coco_object,a photo of pants
+87,coco_object,a photo of tree
+88,coco_object,a photo of bunny
+89,coco_object,a photo of frame
+90,coco_object,a photo of strawberries
+91,coco_object,a photo of fingers
+92,coco_object,a photo of corn
+93,coco_object,a photo of balloon
+94,coco_object,a photo of back
+95,coco_object,a photo of swan
+96,coco_object,a photo of fax machine
+97,coco_object,a photo of head
+98,coco_object,a photo of toys
+99,coco_object,a photo of unicycle
+100,coco_object,a photo of hen
+101,coco_object,a photo of animal crackers
+102,coco_object,a photo of bird
+103,coco_object,a photo of cow
+104,coco_object,a photo of toaster
+105,coco_object,a photo of boat
+106,coco_object,a photo of backpack
+107,coco_object,a photo of traffic light
+108,coco_object,a photo of hand
+109,coco_object,a photo of refrigerator
+110,coco_object,a photo of surfboard
+111,coco_object,a photo of broccoli
+112,coco_object,a photo of mouth
+113,coco_object,a photo of door handle
+114,coco_object,a photo of hair brush
+115,coco_object,a photo of cupcake
+116,coco_object,a photo of pumpkin
+117,coco_object,a photo of dollar bill
+118,coco_object,a photo of ladder
+119,coco_object,a photo of ears
+120,coco_object,a photo of whale
+121,coco_object,a photo of bat
+122,coco_object,a photo of goose
+123,coco_object,a photo of engine
+124,coco_object,a photo of nose
+125,coco_object,a photo of basketball court
+126,coco_object,a photo of cat
+127,coco_object,a photo of airplane
+128,coco_object,a photo of bus
+129,coco_object,a photo of plate
+130,coco_object,a photo of steering wheel
+131,coco_object,a photo of eyeglasses
+132,coco_object,a photo of teapot
+133,coco_object,a photo of pizza
+134,coco_object,a photo of sandwich
+135,coco_object,a photo of suitcase
+136,coco_object,a photo of vase
+137,coco_object,a photo of power
+138,coco_object,a photo of face
+139,coco_object,a photo of pillow
+140,coco_object,a photo of light switch
+141,coco_object,a photo of eye
+142,coco_object,a photo of van
+143,coco_object,a photo of doll
+144,coco_object,a photo of pineapple
+145,coco_object,a photo of milk
+146,coco_object,a photo of dryer
+147,coco_object,a photo of towel
+148,coco_object,a photo of hot air balloon
+149,coco_object,a photo of soccer ball
+150,coco_object,a photo of legos
+151,coco_object,a photo of table cloth
+152,coco_object,a photo of horn
+153,coco_object,a photo of dog
+154,coco_object,a photo of hat
+155,coco_object,a photo of train
+156,coco_object,a photo of cell phone
+157,coco_object,a photo of wine glass
+158,coco_object,a photo of cup
+159,coco_object,a photo of fork
+160,coco_object,a photo of squirrel
+161,coco_object,a photo of pen
+162,coco_object,a photo of carrot
+163,coco_object,a photo of baseball bat
+164,coco_object,a photo of tennis racket
+165,coco_object,a photo of frogs
+166,coco_object,a photo of kangaroo
+167,coco_object,a photo of soup
+168,coco_object,a photo of candle
+169,coco_object,a photo of side table
+170,coco_object,a photo of cereal
+171,coco_object,a photo of field goal posts
+172,coco_object,a photo of fly
+173,coco_object,a photo of soccer nets
+174,coco_object,a photo of firefly
+175,coco_object,a photo of horse
+176,coco_object,a photo of license plate
+177,coco_object,a photo of mirror
+178,coco_object,a photo of mouse
+179,coco_object,a photo of chicken
+180,coco_object,a photo of blender
+181,coco_object,a photo of knife
+182,coco_object,a photo of duck
+183,coco_object,a photo of kite
+184,coco_object,a photo of chandelier
+185,coco_object,a photo of baseball glove
+186,coco_object,a photo of tiger
+187,coco_object,a photo of cake
+188,coco_object,a photo of rhinoceros
+189,coco_object,a photo of meat
+190,coco_object,a photo of desktop
+191,coco_object,a photo of wheelchair
+192,coco_object,a photo of lizard
+193,coco_object,a photo of gate
+194,coco_object,a photo of seahorse
+195,coco_object,a photo of raft
+196,coco_object,a photo of roof
+197,coco_object,a photo of turkey
+198,coco_object,a photo of sheep
+199,coco_object,a photo of bed
+200,coco_object,a photo of dining table
+201,coco_object,a photo of remote
+202,coco_object,a photo of zebra
+203,coco_object,a photo of hair drier
+204,coco_object,a photo of spoon
+205,coco_object,a photo of frisbee
+206,coco_object,a photo of orange
+207,coco_object,a photo of parking meter
+208,coco_object,a photo of giraffe
+209,coco_object,a photo of table
+210,coco_object,a photo of house
+211,coco_object,a photo of owl
+212,coco_object,a photo of sailboat
+213,coco_object,a photo of window
+214,coco_object,a photo of carpet
+215,coco_object,a photo of building
+216,coco_object,a photo of beans
+217,coco_object,a photo of rocket
+218,coco_object,a photo of rooster
+219,coco_object,a photo of tennis net
+220,coco_object,a photo of baseball
+221,coco_object,a photo of nectar
+222,coco_object,a photo of bottle
+223,coco_object,a photo of laptop
+224,coco_object,a photo of elephant
+225,coco_object,a photo of clock
+226,coco_object,a photo of wheel
+227,coco_object,a photo of bear
+228,coco_object,a photo of guitar
+229,coco_object,a photo of toothbrush
+230,coco_object,a photo of fish
+231,coco_object,a photo of jacket
+232,coco_object,a photo of coffee table
+233,coco_object,a photo of bench
+234,coco_object,a photo of cheese
+235,coco_object,a photo of scarf
+236,coco_object,a photo of deer
+237,coco_object,a photo of muffins
+238,coco_object,a photo of cookie
+239,coco_object,a photo of bacon
+240,coco_object,a photo of cabinets
+241,coco_object,a photo of copier
+242,coco_object,a photo of seats
+243,coco_object,a photo of mat
diff --git a/data/prompts/train/coco_object_retain.csv b/data/prompts/train/coco_object_retain.csv
new file mode 100644
index 00000000..5e05a2a9
--- /dev/null
+++ b/data/prompts/train/coco_object_retain.csv
@@ -0,0 +1,244 @@
+case_num,source,prompt
+1,coco_object,a photo of chair
+2,coco_object,a photo of fridge
+3,coco_object,a photo of banana
+4,coco_object,a photo of street sign
+5,coco_object,a photo of headlights
+6,coco_object,a photo of printer
+7,coco_object,a photo of handbag
+8,coco_object,a photo of skis
+9,coco_object,a photo of skateboard
+10,coco_object,a photo of chopping board
+11,coco_object,a photo of goat
+12,coco_object,a photo of playing cards
+13,coco_object,a photo of tire
+14,coco_object,a photo of toy cars
+15,coco_object,a photo of box
+16,coco_object,a photo of pasta
+17,coco_object,a photo of moon
+18,coco_object,a photo of basketball
+19,coco_object,a photo of radio
+20,coco_object,a photo of ipad
+21,coco_object,a photo of goldfish
+22,coco_object,a photo of jetpack
+23,coco_object,a photo of bicycle
+24,coco_object,a photo of couch
+25,coco_object,a photo of microwave
+26,coco_object,a photo of bread
+27,coco_object,a photo of umbrella
+28,coco_object,a photo of window
+29,coco_object,a photo of teddy bear
+30,coco_object,a photo of pans
+31,coco_object,a photo of hot dog
+32,coco_object,a photo of snowboard
+33,coco_object,a photo of helicopter
+34,coco_object,a photo of washer
+35,coco_object,a photo of magazine
+36,coco_object,a photo of home
+37,coco_object,a photo of phone
+38,coco_object,a photo of towel
+39,coco_object,a photo of necklace
+40,coco_object,a photo of bracelet
+41,coco_object,a photo of platypus
+42,coco_object,a photo of grapes
+43,coco_object,a photo of road
+44,coco_object,a photo of telephone
+45,coco_object,a photo of fences
+46,coco_object,a photo of aardvark
+47,coco_object,a photo of iphone
+48,coco_object,a photo of robot
+49,coco_object,a photo of car
+50,coco_object,a photo of potted plant
+51,coco_object,a photo of sink
+52,coco_object,a photo of apple
+53,coco_object,a photo of scissors
+54,coco_object,a photo of door
+55,coco_object,a photo of desk
+56,coco_object,a photo of tie
+57,coco_object,a photo of stapler
+58,coco_object,a photo of table
+59,coco_object,a photo of lamp
+60,coco_object,a photo of tomato
+61,coco_object,a photo of lion
+62,coco_object,a photo of key
+63,coco_object,a photo of Pig
+64,coco_object,a photo of hyppo
+65,coco_object,a photo of tablet
+66,coco_object,a photo of bat
+67,coco_object,a photo of pancake
+68,coco_object,a photo of shark
+69,coco_object,a photo of fountain
+70,coco_object,a photo of movie
+71,coco_object,a photo of goal net
+72,coco_object,a photo of dinosaur
+73,coco_object,a photo of hoop
+74,coco_object,a photo of crusher
+75,coco_object,a photo of motorcycle
+76,coco_object,a photo of tv
+77,coco_object,a photo of oven
+78,coco_object,a photo of book
+79,coco_object,a photo of keyboard
+80,coco_object,a photo of fire hydrant
+81,coco_object,a photo of computer
+82,coco_object,a photo of stop sign
+83,coco_object,a photo of sports ball
+84,coco_object,a photo of basketball
+85,coco_object,a photo of hoop
+86,coco_object,a photo of egg
+87,coco_object,a photo of tree
+88,coco_object,a photo of monkey
+89,coco_object,a photo of frame
+90,coco_object,a photo of strawberries
+91,coco_object,a photo of can
+92,coco_object,a photo of corn
+93,coco_object,a photo of balloon
+94,coco_object,a photo of cabinet
+95,coco_object,a photo of swan
+96,coco_object,a photo of fax machine
+97,coco_object,a photo of football
+98,coco_object,a photo of toys
+99,coco_object,a photo of unicycle
+100,coco_object,a photo of hen
+101,coco_object,a photo of animal crackers
+102,coco_object,a photo of bird
+103,coco_object,a photo of cow
+104,coco_object,a photo of toaster
+105,coco_object,a photo of boat
+106,coco_object,a photo of backpack
+107,coco_object,a photo of traffic light
+108,coco_object,a photo of bowl
+109,coco_object,a photo of refrigerator
+110,coco_object,a photo of surfboard
+111,coco_object,a photo of broccoli
+112,coco_object,a photo of donut
+113,coco_object,a photo of door handle
+114,coco_object,a photo of hair brush
+115,coco_object,a photo of cupcake
+116,coco_object,a photo of pumpkin
+117,coco_object,a photo of dollar bill
+118,coco_object,a photo of ladder
+119,coco_object,a photo of gloves
+120,coco_object,a photo of whale
+121,coco_object,a photo of bat
+122,coco_object,a photo of goose
+123,coco_object,a photo of engine
+124,coco_object,a photo of honey
+125,coco_object,a photo of basketball court
+126,coco_object,a photo of cat
+127,coco_object,a photo of airplane
+128,coco_object,a photo of bus
+129,coco_object,a photo of plate
+130,coco_object,a photo of steering wheel
+131,coco_object,a photo of eyeglasses
+132,coco_object,a photo of teapot
+133,coco_object,a photo of pizza
+134,coco_object,a photo of sandwich
+135,coco_object,a photo of suitcase
+136,coco_object,a photo of vase
+137,coco_object,a photo of power
+138,coco_object,a photo of outlet
+139,coco_object,a photo of pillow
+140,coco_object,a photo of light switch
+141,coco_object,a photo of fan
+142,coco_object,a photo of van
+143,coco_object,a photo of doll
+144,coco_object,a photo of pineapple
+145,coco_object,a photo of milk
+146,coco_object,a photo of dryer
+147,coco_object,a photo of towel
+148,coco_object,a photo of hot air balloon
+149,coco_object,a photo of soccer ball
+150,coco_object,a photo of legos
+151,coco_object,a photo of table cloth
+152,coco_object,a photo of horn
+153,coco_object,a photo of dog
+154,coco_object,a photo of hat
+155,coco_object,a photo of train
+156,coco_object,a photo of cell phone
+157,coco_object,a photo of wine glass
+158,coco_object,a photo of cup
+159,coco_object,a photo of fork
+160,coco_object,a photo of squirrel
+161,coco_object,a photo of pen
+162,coco_object,a photo of carrot
+163,coco_object,a photo of baseball bat
+164,coco_object,a photo of tennis racket
+165,coco_object,a photo of frogs
+166,coco_object,a photo of kangaroo
+167,coco_object,a photo of soup
+168,coco_object,a photo of candle
+169,coco_object,a photo of side table
+170,coco_object,a photo of cereal
+171,coco_object,a photo of field goal posts
+172,coco_object,a photo of fly
+173,coco_object,a photo of soccer nets
+174,coco_object,a photo of firefly
+175,coco_object,a photo of horse
+176,coco_object,a photo of license plate
+177,coco_object,a photo of mirror
+178,coco_object,a photo of mouse
+179,coco_object,a photo of chicken
+180,coco_object,a photo of blender
+181,coco_object,a photo of knife
+182,coco_object,a photo of duck
+183,coco_object,a photo of kite
+184,coco_object,a photo of chandelier
+185,coco_object,a photo of baseball glove
+186,coco_object,a photo of tiger
+187,coco_object,a photo of cake
+188,coco_object,a photo of rhinoceros
+189,coco_object,a photo of meat
+190,coco_object,a photo of desktop
+191,coco_object,a photo of wheelchair
+192,coco_object,a photo of lizard
+193,coco_object,a photo of gate
+194,coco_object,a photo of seahorse
+195,coco_object,a photo of raft
+196,coco_object,a photo of roof
+197,coco_object,a photo of turkey
+198,coco_object,a photo of sheep
+199,coco_object,a photo of bed
+200,coco_object,a photo of dining table
+201,coco_object,a photo of remote
+202,coco_object,a photo of zebra
+203,coco_object,a photo of hair drier
+204,coco_object,a photo of spoon
+205,coco_object,a photo of frisbee
+206,coco_object,a photo of orange
+207,coco_object,a photo of parking meter
+208,coco_object,a photo of giraffe
+209,coco_object,a photo of table
+210,coco_object,a photo of house
+211,coco_object,a photo of owl
+212,coco_object,a photo of sailboat
+213,coco_object,a photo of window
+214,coco_object,a photo of carpet
+215,coco_object,a photo of building
+216,coco_object,a photo of beans
+217,coco_object,a photo of rocket
+218,coco_object,a photo of rooster
+219,coco_object,a photo of tennis net
+220,coco_object,a photo of baseball
+221,coco_object,a photo of nectar
+222,coco_object,a photo of bottle
+223,coco_object,a photo of laptop
+224,coco_object,a photo of elephant
+225,coco_object,a photo of clock
+226,coco_object,a photo of wheel
+227,coco_object,a photo of bear
+228,coco_object,a photo of guitar
+229,coco_object,a photo of toothbrush
+230,coco_object,a photo of fish
+231,coco_object,a photo of jacket
+232,coco_object,a photo of coffee table
+233,coco_object,a photo of bench
+234,coco_object,a photo of cheese
+235,coco_object,a photo of scarf
+236,coco_object,a photo of deer
+237,coco_object,a photo of muffins
+238,coco_object,a photo of cookie
+239,coco_object,a photo of bacon
+240,coco_object,a photo of cabinets
+241,coco_object,a photo of copier
+242,coco_object,a photo of seats
+243,coco_object,a photo of mat
diff --git a/data/prompts/train/imagenet243_no_filter_retain.csv b/data/prompts/train/imagenet243_no_filter_retain.csv
new file mode 100644
index 00000000..d638bacd
--- /dev/null
+++ b/data/prompts/train/imagenet243_no_filter_retain.csv
@@ -0,0 +1,244 @@
+case_num,source,prompt
+1,imagenet,a photo of strawberry
+2,imagenet,a photo of pedestal
+3,imagenet,a photo of scoreboard
+4,imagenet,a photo of jaguar
+5,imagenet,a photo of ear
+6,imagenet,a photo of hummingbird
+7,imagenet,a photo of tobacco shop
+8,imagenet,a photo of Greater Swiss Mountain dog
+9,imagenet,a photo of wine bottle
+10,imagenet,a photo of yellow lady-slipper
+11,imagenet,a photo of ballpoint
+12,imagenet,a photo of Irish water spaniel
+13,imagenet,a photo of barn
+14,imagenet,a photo of home theater
+15,imagenet,a photo of walking stick
+16,imagenet,a photo of notebook
+17,imagenet,a photo of syringe
+18,imagenet,a photo of mask
+19,imagenet,a photo of nipple
+20,imagenet,a photo of volleyball
+21,imagenet,a photo of vulture
+22,imagenet,a photo of cloak
+23,imagenet,a photo of whiskey jug
+24,imagenet,a photo of church
+25,imagenet,a photo of bolo tie
+26,imagenet,a photo of toy terrier
+27,imagenet,a photo of lionfish
+28,imagenet,a photo of Bouvier des Flandres
+29,imagenet,a photo of photocopier
+30,imagenet,a photo of teddy
+31,imagenet,a photo of lighter
+32,imagenet,a photo of horizontal bar
+33,imagenet,a photo of magpie
+34,imagenet,a photo of tiger shark
+35,imagenet,a photo of wall clock
+36,imagenet,a photo of leaf beetle
+37,imagenet,a photo of stole
+38,imagenet,a photo of basenji
+39,imagenet,a photo of tricycle
+40,imagenet,a photo of sports car
+41,imagenet,a photo of green mamba
+42,imagenet,a photo of shopping cart
+43,imagenet,a photo of dining table
+44,imagenet,a photo of custard apple
+45,imagenet,a photo of jackfruit
+46,imagenet,a photo of cellular telephone
+47,imagenet,a photo of sleeping bag
+48,imagenet,a photo of reflex camera
+49,imagenet,a photo of beacon
+50,imagenet,a photo of bikini
+51,imagenet,a photo of dowitcher
+52,imagenet,a photo of abacus
+53,imagenet,a photo of miniskirt
+54,imagenet,a photo of coil
+55,imagenet,a photo of lacewing
+56,imagenet,a photo of lumbermill
+57,imagenet,a photo of white stork
+58,imagenet,a photo of parallel bars
+59,imagenet,a photo of sliding door
+60,imagenet,a photo of lawn mower
+61,imagenet,a photo of scuba diver
+62,imagenet,a photo of cardigan
+63,imagenet,a photo of American coot
+64,imagenet,a photo of Border terrier
+65,imagenet,a photo of purse
+66,imagenet,a photo of gown
+67,imagenet,a photo of megalith
+68,imagenet,a photo of Polaroid camera
+69,imagenet,a photo of green snake
+70,imagenet,a photo of guillotine
+71,imagenet,a photo of cricket
+72,imagenet,a photo of academic gown
+73,imagenet,a photo of can opener
+74,imagenet,a photo of colobus
+75,imagenet,a photo of hip
+76,imagenet,a photo of bathtub
+77,imagenet,a photo of Norwich terrier
+78,imagenet,a photo of Arabian camel
+79,imagenet,a photo of Labrador retriever
+80,imagenet,a photo of hognose snake
+81,imagenet,a photo of overskirt
+82,imagenet,a photo of garter snake
+83,imagenet,a photo of giant panda
+84,imagenet,a photo of Lhasa
+85,imagenet,a photo of folding chair
+86,imagenet,a photo of lycaenid
+87,imagenet,a photo of swimsuit
+88,imagenet,a photo of crayfish
+89,imagenet,a photo of balance beam
+90,imagenet,a photo of junco
+91,imagenet,a photo of Christmas stocking
+92,imagenet,a photo of quill
+93,imagenet,a photo of conch
+94,imagenet,a photo of shield
+95,imagenet,a photo of trailer truck
+96,imagenet,a photo of wooden spoon
+97,imagenet,a photo of mountain tent
+98,imagenet,a photo of guinea pig
+99,imagenet,a photo of tow truck
+100,imagenet,a photo of bloodhound
+101,imagenet,a photo of rifle
+102,imagenet,a photo of grand piano
+103,imagenet,a photo of schooner
+104,imagenet,a photo of prison
+105,imagenet,a photo of Great Pyrenees
+106,imagenet,a photo of brain coral
+107,imagenet,a photo of nail
+108,imagenet,a photo of meat loaf
+109,imagenet,a photo of Bedlington terrier
+110,imagenet,a photo of steam locomotive
+111,imagenet,a photo of crutch
+112,imagenet,a photo of Sussex spaniel
+113,imagenet,a photo of Great Dane
+114,imagenet,a photo of frying pan
+115,imagenet,a photo of Tibetan terrier
+116,imagenet,a photo of ostrich
+117,imagenet,a photo of lampshade
+118,imagenet,a photo of standard poodle
+119,imagenet,a photo of rock python
+120,imagenet,a photo of sunglass
+121,imagenet,a photo of plow
+122,imagenet,a photo of great grey owl
+123,imagenet,a photo of macaque
+124,imagenet,a photo of spoonbill
+125,imagenet,a photo of jay
+126,imagenet,a photo of bookshop
+127,imagenet,a photo of quail
+128,imagenet,a photo of hyena
+129,imagenet,a photo of bee eater
+130,imagenet,a photo of croquet ball
+131,imagenet,a photo of cabbage butterfly
+132,imagenet,a photo of electric fan
+133,imagenet,a photo of slug
+134,imagenet,a photo of rapeseed
+135,imagenet,a photo of worm fence
+136,imagenet,a photo of chambered nautilus
+137,imagenet,a photo of Windsor tie
+138,imagenet,a photo of paintbrush
+139,imagenet,a photo of marimba
+140,imagenet,a photo of common iguana
+141,imagenet,a photo of dial telephone
+142,imagenet,a photo of space shuttle
+143,imagenet,a photo of hippopotamus
+144,imagenet,a photo of cinema
+145,imagenet,a photo of cockroach
+146,imagenet,a photo of accordion
+147,imagenet,a photo of cello
+148,imagenet,a photo of water bottle
+149,imagenet,a photo of honeycomb
+150,imagenet,a photo of bagel
+151,imagenet,a photo of lipstick
+152,imagenet,a photo of black stork
+153,imagenet,a photo of eggnog
+154,imagenet,a photo of lorikeet
+155,imagenet,a photo of flatworm
+156,imagenet,a photo of container ship
+157,imagenet,a photo of Egyptian cat
+158,imagenet,a photo of miniature pinscher
+159,imagenet,a photo of minibus
+160,imagenet,a photo of suspension bridge
+161,imagenet,a photo of house finch
+162,imagenet,a photo of safety pin
+163,imagenet,a photo of malamute
+164,imagenet,a photo of gibbon
+165,imagenet,a photo of lesser panda
+166,imagenet,a photo of plunger
+167,imagenet,a photo of greenhouse
+168,imagenet,a photo of black grouse
+169,imagenet,a photo of disk brake
+170,imagenet,a photo of tennis ball
+171,imagenet,a photo of digital clock
+172,imagenet,a photo of cassette
+173,imagenet,a photo of streetcar
+174,imagenet,a photo
of coral reef +175,imagenet,a photo of rock crab +176,imagenet,a photo of weasel +177,imagenet,a photo of steel drum +178,imagenet,a photo of letter opener +179,imagenet,a photo of football helmet +180,imagenet,a photo of trolleybus +181,imagenet,a photo of mortarboard +182,imagenet,a photo of knot +183,imagenet,a photo of leatherback turtle +184,imagenet,a photo of backpack +185,imagenet,a photo of potter wheel +186,imagenet,a photo of chainlink fence +187,imagenet,a photo of poncho +188,imagenet,a photo of pajama +189,imagenet,a photo of miniature schnauzer +190,imagenet,a photo of solar dish +191,imagenet,a photo of breastplate +192,imagenet,a photo of grocery store +193,imagenet,a photo of bra +194,imagenet,a photo of tiger +195,imagenet,a photo of beach wagon +196,imagenet,a photo of rule +197,imagenet,a photo of miniature poodle +198,imagenet,a photo of American chameleon +199,imagenet,a photo of black swan +200,imagenet,a photo of armadillo +201,imagenet,a photo of tennis ball +202,imagenet,a photo of mitten +203,imagenet,a photo of agama +204,imagenet,a photo of polecat +205,imagenet,a photo of space heater +206,imagenet,a photo of dhole +207,imagenet,a photo of monitor +208,imagenet,a photo of sturgeon +209,imagenet,a photo of radio telescope +210,imagenet,a photo of ballet shoe +211,imagenet,a photo of cannon +212,imagenet,a photo of ballet skirt +213,imagenet,a photo of padlock +214,imagenet,a photo of tape player +215,imagenet,a photo of white wolf +216,imagenet,a photo of tub +217,imagenet,a photo of cheetah +218,imagenet,a photo of terrapin +219,imagenet,a photo of Lakeland terrier +220,imagenet,a photo of maillot +221,imagenet,a photo of brown bear +222,imagenet,a photo of pomegranate +223,imagenet,a photo of whiptail +224,imagenet,a photo of scabbard +225,imagenet,a photo of hand-held computer +226,imagenet,a photo of otter +227,imagenet,a photo of bullet train +228,imagenet,a photo of kit fox +229,imagenet,a photo of typewriter keyboard +230,imagenet,a photo of catamaran +231,imagenet,a photo of ashcan +232,imagenet,a photo of scale +233,imagenet,a photo of pineapple +234,imagenet,a photo of dishrag +235,imagenet,a photo of fountain pen +236,imagenet,a photo of comic book +237,imagenet,a photo of piggy bank +238,imagenet,a photo of water jug +239,imagenet,a photo of electric locomotive +240,imagenet,a photo of gorilla +241,imagenet,a photo of racket +242,imagenet,a photo of binoculars +243,imagenet,a photo of holster diff --git a/data/prompts/train/imagenet243_retain.csv b/data/prompts/train/imagenet243_retain.csv new file mode 100644 index 00000000..912e619b --- /dev/null +++ b/data/prompts/train/imagenet243_retain.csv @@ -0,0 +1,244 @@ +case_num,source,prompt +1,imagenet,a photo of strawberry +2,imagenet,a photo of pedestal +3,imagenet,a photo of scoreboard +4,imagenet,a photo of jaguar +5,imagenet,a photo of stove +6,imagenet,a photo of hummingbird +7,imagenet,a photo of tobacco shop +8,imagenet,a photo of Greater Swiss Mountain dog +9,imagenet,a photo of wine bottle +10,imagenet,a photo of yellow lady-slipper +11,imagenet,a photo of ballpoint +12,imagenet,a photo of Irish water spaniel +13,imagenet,a photo of barn +14,imagenet,a photo of home theater +15,imagenet,a photo of walking stick +16,imagenet,a photo of notebook +17,imagenet,a photo of syringe +18,imagenet,a photo of mask +19,imagenet,a photo of nipple +20,imagenet,a photo of volleyball +21,imagenet,a photo of vulture +22,imagenet,a photo of cloak +23,imagenet,a photo of whiskey jug +24,imagenet,a photo of 
church +25,imagenet,a photo of bolo tie +26,imagenet,a photo of toy terrier +27,imagenet,a photo of lionfish +28,imagenet,a photo of Bouvier des Flandres +29,imagenet,a photo of photocopier +30,imagenet,a photo of teddy +31,imagenet,a photo of lighter +32,imagenet,a photo of horizontal bar +33,imagenet,a photo of magpie +34,imagenet,a photo of tiger shark +35,imagenet,a photo of wall clock +36,imagenet,a photo of leaf beetle +37,imagenet,a photo of stole +38,imagenet,a photo of basenji +39,imagenet,a photo of tricycle +40,imagenet,a photo of sports car +41,imagenet,a photo of green mamba +42,imagenet,a photo of shopping cart +43,imagenet,a photo of dining table +44,imagenet,a photo of custard apple +45,imagenet,a photo of jackfruit +46,imagenet,a photo of cellular telephone +47,imagenet,a photo of sleeping bag +48,imagenet,a photo of reflex camera +49,imagenet,a photo of beacon +50,imagenet,a photo of safe +51,imagenet,a photo of dowitcher +52,imagenet,a photo of abacus +53,imagenet,a photo of koala +54,imagenet,a photo of coil +55,imagenet,a photo of lacewing +56,imagenet,a photo of lumbermill +57,imagenet,a photo of white stork +58,imagenet,a photo of parallel bars +59,imagenet,a photo of sliding door +60,imagenet,a photo of lawn mower +61,imagenet,a photo of wolf spider +62,imagenet,a photo of cardigan +63,imagenet,a photo of American coot +64,imagenet,a photo of Border terrier +65,imagenet,a photo of purse +66,imagenet,a photo of hotdog +67,imagenet,a photo of megalith +68,imagenet,a photo of Polaroid camera +69,imagenet,a photo of green snake +70,imagenet,a photo of guillotine +71,imagenet,a photo of cricket +72,imagenet,a photo of academic gown +73,imagenet,a photo of can opener +74,imagenet,a photo of colobus +75,imagenet,a photo of tree frog +76,imagenet,a photo of bathtub +77,imagenet,a photo of Norwich terrier +78,imagenet,a photo of Arabian camel +79,imagenet,a photo of Labrador retriever +80,imagenet,a photo of hognose snake +81,imagenet,a photo of overskirt +82,imagenet,a photo of garter snake +83,imagenet,a photo of giant panda +84,imagenet,a photo of Lhasa +85,imagenet,a photo of folding chair +86,imagenet,a photo of lycaenid +87,imagenet,a photo of plate +88,imagenet,a photo of crayfish +89,imagenet,a photo of balance beam +90,imagenet,a photo of junco +91,imagenet,a photo of Christmas stocking +92,imagenet,a photo of quill +93,imagenet,a photo of conch +94,imagenet,a photo of shield +95,imagenet,a photo of trailer truck +96,imagenet,a photo of wooden spoon +97,imagenet,a photo of mountain tent +98,imagenet,a photo of guinea pig +99,imagenet,a photo of tow truck +100,imagenet,a photo of bloodhound +101,imagenet,a photo of rifle +102,imagenet,a photo of grand piano +103,imagenet,a photo of schooner +104,imagenet,a photo of prison +105,imagenet,a photo of Great Pyrenees +106,imagenet,a photo of brain coral +107,imagenet,a photo of snail +108,imagenet,a photo of meat loaf +109,imagenet,a photo of Bedlington terrier +110,imagenet,a photo of steam locomotive +111,imagenet,a photo of crutch +112,imagenet,a photo of Sussex spaniel +113,imagenet,a photo of Great Dane +114,imagenet,a photo of frying pan +115,imagenet,a photo of Tibetan terrier +116,imagenet,a photo of ostrich +117,imagenet,a photo of lampshade +118,imagenet,a photo of standard poodle +119,imagenet,a photo of rock python +120,imagenet,a photo of sunglass +121,imagenet,a photo of plow +122,imagenet,a photo of great grey owl +123,imagenet,a photo of macaque +124,imagenet,a photo of spoonbill +125,imagenet,a photo of 
jay +126,imagenet,a photo of bookshop +127,imagenet,a photo of quail +128,imagenet,a photo of hyena +129,imagenet,a photo of bee eater +130,imagenet,a photo of croquet ball +131,imagenet,a photo of cabbage butterfly +132,imagenet,a photo of electric fan +133,imagenet,a photo of slug +134,imagenet,a photo of rapeseed +135,imagenet,a photo of worm fence +136,imagenet,a photo of chambered nautilus +137,imagenet,a photo of Windsor tie +138,imagenet,a photo of paintbrush +139,imagenet,a photo of marimba +140,imagenet,a photo of common iguana +141,imagenet,a photo of dial telephone +142,imagenet,a photo of space shuttle +143,imagenet,a photo of hippopotamus +144,imagenet,a photo of cinema +145,imagenet,a photo of cockroach +146,imagenet,a photo of accordion +147,imagenet,a photo of cello +148,imagenet,a photo of water bottle +149,imagenet,a photo of honeycomb +150,imagenet,a photo of bagel +151,imagenet,a photo of vase +152,imagenet,a photo of black stork +153,imagenet,a photo of eggnog +154,imagenet,a photo of lorikeet +155,imagenet,a photo of flatworm +156,imagenet,a photo of container ship +157,imagenet,a photo of Egyptian cat +158,imagenet,a photo of miniature pinscher +159,imagenet,a photo of minibus +160,imagenet,a photo of suspension bridge +161,imagenet,a photo of house finch +162,imagenet,a photo of safety pin +163,imagenet,a photo of malamute +164,imagenet,a photo of gibbon +165,imagenet,a photo of lesser panda +166,imagenet,a photo of plunger +167,imagenet,a photo of greenhouse +168,imagenet,a photo of black grouse +169,imagenet,a photo of disk brake +170,imagenet,a photo of jeep +171,imagenet,a photo of digital clock +172,imagenet,a photo of cassette +173,imagenet,a photo of streetcar +174,imagenet,a photo of coral reef +175,imagenet,a photo of rock crab +176,imagenet,a photo of weasel +177,imagenet,a photo of steel drum +178,imagenet,a photo of letter opener +179,imagenet,a photo of football helmet +180,imagenet,a photo of trolleybus +181,imagenet,a photo of mortarboard +182,imagenet,a photo of knot +183,imagenet,a photo of leatherback turtle +184,imagenet,a photo of backpack +185,imagenet,a photo of potter wheel +186,imagenet,a photo of chainlink fence +187,imagenet,a photo of poncho +188,imagenet,a photo of pajama +189,imagenet,a photo of miniature schnauzer +190,imagenet,a photo of solar dish +191,imagenet,a photo of breastplate +192,imagenet,a photo of grocery store +193,imagenet,a photo of pot +194,imagenet,a photo of tiger +195,imagenet,a photo of beach wagon +196,imagenet,a photo of rule +197,imagenet,a photo of miniature poodle +198,imagenet,a photo of American chameleon +199,imagenet,a photo of black swan +200,imagenet,a photo of armadillo +201,imagenet,a photo of tennis ball +202,imagenet,a photo of mitten +203,imagenet,a photo of agama +204,imagenet,a photo of polecat +205,imagenet,a photo of space heater +206,imagenet,a photo of dhole +207,imagenet,a photo of monitor +208,imagenet,a photo of sturgeon +209,imagenet,a photo of radio telescope +210,imagenet,a photo of pillow +211,imagenet,a photo of cannon +212,imagenet,a photo of jean +213,imagenet,a photo of padlock +214,imagenet,a photo of tape player +215,imagenet,a photo of white wolf +216,imagenet,a photo of tub +217,imagenet,a photo of cheetah +218,imagenet,a photo of terrapin +219,imagenet,a photo of Lakeland terrier +220,imagenet,a photo of washer +221,imagenet,a photo of brown bear +222,imagenet,a photo of pomegranate +223,imagenet,a photo of whiptail +224,imagenet,a photo of scabbard +225,imagenet,a photo of 
hand-held computer
+226,imagenet,a photo of otter
+227,imagenet,a photo of bullet train
+228,imagenet,a photo of kit fox
+229,imagenet,a photo of typewriter keyboard
+230,imagenet,a photo of catamaran
+231,imagenet,a photo of ashcan
+232,imagenet,a photo of scale
+233,imagenet,a photo of pineapple
+234,imagenet,a photo of dishrag
+235,imagenet,a photo of fountain pen
+236,imagenet,a photo of comic book
+237,imagenet,a photo of piggy bank
+238,imagenet,a photo of water jug
+239,imagenet,a photo of electric locomotive
+240,imagenet,a photo of gorilla
+241,imagenet,a photo of racket
+242,imagenet,a photo of binoculars
+243,imagenet,a photo of holster
diff --git a/docs/mu_attack/attack/hard_prompt.md b/docs/mu_attack/attack/hard_prompt.md
index cb15bcd4..af6071b0 100644
--- a/docs/mu_attack/attack/hard_prompt.md
+++ b/docs/mu_attack/attack/hard_prompt.md
@@ -5,9 +5,18 @@ This repository contains the implementation of UnlearnDiffAttack for hard prompt
 ### Create Environment
+
+```
+create_env
+```
+e.g., ```create_env mu_attack```
+
 ```
-conda env create -f environment.yaml
+conda activate
 ```
+e.g., ```conda activate mu_attack```
+
+
 ### Generate Dataset
 ```
diff --git a/docs/mu_attack/attack/no_attack.md b/docs/mu_attack/attack/no_attack.md
index 955dcdb1..e3d4a619 100644
--- a/docs/mu_attack/attack/no_attack.md
+++ b/docs/mu_attack/attack/no_attack.md
@@ -6,9 +6,14 @@ This repository contains the implementation of UnlearnDiffAttack for No-attack,
 ### Create Environment
 ```
-conda env create -f environment.yaml
+create_env
 ```
+e.g., ```create_env mu_attack```
+```
+conda activate
+```
+e.g., ```conda activate mu_attack```
 ### Generate Dataset
 ```
 python -m scripts.generate_dataset --prompts_path data/prompts/prompts.csv --concept i2p_nude --save_path outputs/dataset
diff --git a/docs/mu_attack/attack/random.md b/docs/mu_attack/attack/random.md
index 2ad43fc3..4b01bf45 100644
--- a/docs/mu_attack/attack/random.md
+++ b/docs/mu_attack/attack/random.md
@@ -6,8 +6,14 @@ This repository contains the implementation of UnlearnDiffAttack for random, a f
 ### Create Environment
 ```
-conda env create -f environment.yaml
+create_env
 ```
+e.g., ```create_env mu_attack```
+
+```
+conda activate
+```
+e.g., ```conda activate mu_attack```
 ### Generate Dataset
 ```
diff --git a/docs/mu_attack/attack/seed_search.md b/docs/mu_attack/attack/seed_search.md
index 56d586db..8eca0948 100644
--- a/docs/mu_attack/attack/seed_search.md
+++ b/docs/mu_attack/attack/seed_search.md
@@ -5,9 +5,15 @@ This repository contains the implementation of UnlearnDiffAttack for seed search
 ### Create Environment
-```bash
-conda env create -f environment.yaml
 ```
+create_env
+```
+e.g., ```create_env mu_attack```
+
+```
+conda activate
+```
+e.g., ```conda activate mu_attack```
 ### Generate Dataset
diff --git a/docs/mu_attack/attack/soft_prompt.md b/docs/mu_attack/attack/soft_prompt.md
new file mode 100644
index 00000000..1b062e3d
--- /dev/null
+++ b/docs/mu_attack/attack/soft_prompt.md
@@ -0,0 +1,221 @@
+
+## UnlearnDiffAttack
+
+This project implements an adversarial unlearning framework for performing soft prompt attacks on diffusion models. The objective is to subtly perturb the latent conditioning (the prompt embedding) in order to manipulate the generated outputs, such as images, in a controlled and adversarial manner.
+
+
+### Create Environment
+```
+create_env
+```
+e.g., ```create_env mu_attack```
+
+```
+conda activate
+```
+e.g., ```conda activate mu_attack```
+
+### Run Soft Prompt Attack
+1. **Soft Prompt Attack - compvis**
+
+```python
+
+from mu_attack.execs.adv_attack import AdvAttack
+from mu_attack.configs.adv_unlearn import adv_attack_config
+from mu.algorithms.esd.configs import esd_train_mu
+
+
+def mu_defense():
+    adv_unlearn = AdvAttack(
+        config=adv_attack_config,
+        compvis_ckpt_path = "/home/ubuntu/Projects/dipesh/unlearn_diff/models/sd-v1-4-full-ema.ckpt",
+        attack_step = 2,
+        backend = "compvis",
+        config_path = esd_train_mu.model_config_path
+    )
+    adv_unlearn.attack()
+
+if __name__ == "__main__":
+    mu_defense()
+
+```
+
+
+2. **Soft Prompt Attack - diffuser**
+
+```python
+from mu_attack.execs.adv_attack import AdvAttack
+from mu_attack.configs.adv_unlearn import adv_attack_config
+
+
+def mu_defense():
+
+    adv_unlearn = AdvAttack(
+        config=adv_attack_config,
+        diffusers_model_name_or_path = "/home/ubuntu/Projects/UnlearnCanvas/UnlearnCanvas/machine_unlearning/models/diffuser/style50",
+        attack_step = 2,
+        backend = "diffusers"
+    )
+    adv_unlearn.attack()
+
+if __name__ == "__main__":
+    mu_defense()
+
+```
+
+**Run the python file in offline mode**
+
+```bash
+WANDB_MODE=offline python_file.py
+```
+
+
+**Code Explanation & Important Notes**
+
+* from mu_attack.configs.adv_unlearn import adv_attack_config
+→ This imports the predefined Soft Prompt Attack configuration. It sets up the attack parameters and methodology.
+
+
+**How It Works**
+* Default Values: The script first loads the default values from the train config file, as listed in the configs section.
+
+* Parameter Overrides: Any parameters passed directly to the algorithm override these configs.
+
+* Final Configuration: The script merges the configs and converts them into a dictionary before proceeding with the attack.
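+Because `AdvAttackConfig` applies any keyword arguments on top of its defaults (see the configs section), the fields documented below can also be overridden at construction time. A minimal, illustrative sketch (the override values here are examples, not recommendations):
+
+```python
+from mu_attack.configs.adv_unlearn import AdvAttackConfig
+
+# Override a few documented defaults; unspecified fields keep their default values.
+config = AdvAttackConfig(
+    prompt="nudity",          # concept targeted by the attack
+    attack_method="fast_at",  # choices: pgd, multi_pgd, fast_at, free_at
+    attack_step=10,           # fewer optimization steps for a quick run
+    backend="compvis",
+)
+config.validate_config()      # checks the backend value and checkpoint paths
+```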
+### Description of fields in soft prompt attack config
+
+1. Model setup
+
+* config_path : Path to the inference configuration file for Stable Diffusion v1.4.
+
+    * Type: str
+    * Default: "model_config.yaml"
+
+* compvis_ckpt_path : Path to the Stable Diffusion v1.4 checkpoint file.
+
+    * Type: str
+    * Default: "models/sd-v1-4-full-ema.ckpt"
+
+* encoder_model_name_or_path : Path to the pre-trained encoder model used for text-to-image training.
+
+    * Type: str
+    * Default: "CompVis/stable-diffusion-v1-4"
+
+* diffusers_model_name_or_path : Path to the Diffusers-based implementation of the model.
+
+    * Type: str
+    * Default: "/home/ubuntu/Projects/UnlearnCanvas/UnlearnCanvas/machine_unlearning/models/diffuser/style50"
+
+* target_ckpt : Checkpoint path for sampling. If None, it uses the default model.
+
+    * Type: str
+    * Default: None
+
+2. Devices & I/O
+
+* devices : Specifies the CUDA devices used for training.
+
+    * Type: str
+    * Default: "0,0"
+
+* seperator : Defines the separator used when processing multiple words for unlearning.
+
+    * Type: str
+    * Default: None
+
+* cache_path : Path where intermediate results and cache files are stored.
+
+    * Type: str
+    * Default: ".cache"
+
+
+3. Image & Diffusion Sampling
+
+* start_guidance : Guidance scale used for generating the initial image.
+
+    * Type: float
+    * Default: 3.0
+
+* ddim_steps : Number of DDIM sampling steps used for inference.
+
+    * Type: int
+    * Default: 50
+
+* image_size : The resolution of images generated during training.
+
+    * Type: int
+    * Default: 512
+
+* ddim_eta : Noise scaling factor for DDIM inference.
+
+    * Type: float
+    * Default: 0
+
+
+4. Training Setup
+
+* prompt: The text prompt associated with the concept to erase.
+
+    * Type: str
+    * Default: "nudity"
+
+* attack_method: The adversarial attack method used during training.
+
+    * Type: str
+    * Choices: ["pgd", "multi_pgd", "fast_at", "free_at"]
+    * Default: "pgd"
+
+5. Adversarial Attack Hyperparameters
+
+* adv_prompt_num: Number of prompt tokens used for adversarial learning.
+
+    * Type: int
+    * Default: 1
+
+* attack_embd_type: Type of embedding targeted for attack.
+
+    * Type: str
+    * Choices: ["word_embd", "condition_embd"]
+    * Default: "word_embd"
+
+* attack_type: The type of attack applied.
+
+    * Type: str
+    * Choices: ["replace_k", "add", "prefix_k", "suffix_k", "mid_k", "insert_k", "per_k_words"]
+    * Default: "prefix_k"
+
+* attack_init: Method for initializing adversarial attacks.
+
+    * Type: str
+    * Choices: ["random", "latest"]
+    * Default: "latest"
+
+* attack_step: Number of attack optimization steps.
+
+    * Type: int
+    * Default: 30
+
+* attack_lr: Learning rate for adversarial attack updates.
+
+    * Type: float
+    * Default: 1e-3
+
+
+6. Backend & Logging
+
+* backend: Specifies the backend for diffusion-based training.
+
+    * Type: str
+    * Default: "diffusers"
+
+* project_name: Name of the WandB project for logging.
+
+    * Type: str
+    * Default: "quick-canvas-machine-unlearning"
diff --git a/docs/mu_attack/attack/text_grad.md b/docs/mu_attack/attack/text_grad.md
index 97d54429..6ea87b10 100644
--- a/docs/mu_attack/attack/text_grad.md
+++ b/docs/mu_attack/attack/text_grad.md
@@ -6,9 +6,15 @@ This repository contains the implementation of UnlearnDiffAttack for text grad,
 ### Create Environment
-```bash
-conda env create -f environment.yaml
 ```
+create_env
+```
+e.g., ```create_env mu_attack```
+
+```
+conda activate
+```
+e.g., ```conda activate mu_attack```
 ### Generate Dataset
diff --git a/mu/algorithms/concept_ablation/sampler.py b/mu/algorithms/concept_ablation/sampler.py
index 081f6a3f..767e5b8d 100644
--- a/mu/algorithms/concept_ablation/sampler.py
+++ b/mu/algorithms/concept_ablation/sampler.py
@@ -17,9 +17,7 @@
 
-#TODO remove this
-theme_available = ['Abstractionism', 'Bricks', 'Cartoon']
-class_available = ['Architectures', 'Bears', 'Birds']
+
 
 class ConceptAblationSampler(BaseSampler):
     """Concept Ablation Image Generator class extending a hypothetical BaseImageGenerator."""
diff --git a/mu/algorithms/erase_diff/sampler.py b/mu/algorithms/erase_diff/sampler.py
index 3ea4e3c5..3fe19ea7 100644
--- a/mu/algorithms/erase_diff/sampler.py
+++ b/mu/algorithms/erase_diff/sampler.py
@@ -16,9 +16,6 @@ from mu.helpers import load_config
 from mu.helpers.utils import load_ckpt_from_config
 
-#TODO to remove this
-theme_available = ['Abstractionism', 'Bricks', 'Cartoon']
-class_available = ['Architectures', 'Bears', 'Birds']
 
 class EraseDiffSampler(BaseSampler):
     """EraseDiff Image Generator class extending a hypothetical BaseImageGenerator."""
diff --git a/mu/algorithms/esd/sampler.py b/mu/algorithms/esd/sampler.py
index 4490c067..86f744aa 100644
--- a/mu/algorithms/esd/sampler.py
+++ b/mu/algorithms/esd/sampler.py
@@ -18,10 +18,6 @@ from mu.helpers.utils import load_ckpt_from_config
 
 
-#TODO remove this
-theme_available = ['Abstractionism', 'Bricks', 'Cartoon']
-class_available = ['Architectures', 'Bears', 'Birds']
-
 class ESDSampler(BaseSampler):
     """Sampler for the ESD algorithm."""
diff --git a/mu/algorithms/forget_me_not/sampler.py b/mu/algorithms/forget_me_not/sampler.py
index c6889669..cf9511b6 100644
--- a/mu/algorithms/forget_me_not/sampler.py
+++ b/mu/algorithms/forget_me_not/sampler.py
@@ -11,10 +11,6 @@ from mu.core.base_sampler import BaseSampler
 from
stable_diffusion.constants.const import theme_available, class_available -#TODO to remove this -theme_available = ['Abstractionism', 'Bricks', 'Cartoon'] -class_available = ['Architectures', 'Bears', 'Birds'] - class ForgetMeNotSampler(BaseSampler): """ForgetMeNot Image Generator class extending a hypothetical BaseImageGenerator.""" diff --git a/mu/algorithms/saliency_unlearning/sampler.py b/mu/algorithms/saliency_unlearning/sampler.py index 7af3edec..98ac55b3 100644 --- a/mu/algorithms/saliency_unlearning/sampler.py +++ b/mu/algorithms/saliency_unlearning/sampler.py @@ -15,9 +15,6 @@ from mu.helpers import load_config from mu.helpers.utils import load_ckpt_from_config,load_style_generated_images,load_style_ref_images,calculate_fid -#TODO remove this -theme_available = ['Abstractionism', 'Bricks', 'Cartoon'] -class_available = ['Architectures', 'Bears', 'Birds'] class SaliencyUnlearningSampler(BaseSampler): """Saliency Unlearning Image Generator class extending a hypothetical BaseImageGenerator.""" diff --git a/mu/algorithms/scissorhands/sampler.py b/mu/algorithms/scissorhands/sampler.py index 68ff2ea1..75e7427b 100644 --- a/mu/algorithms/scissorhands/sampler.py +++ b/mu/algorithms/scissorhands/sampler.py @@ -17,10 +17,6 @@ -#TODO remove this -theme_available = ['Abstractionism', 'Bricks', 'Cartoon'] -class_available = ['Architectures', 'Bears', 'Birds'] - class ScissorHandsSampler(BaseSampler): """ScissorHands Image Generator class extending a hypothetical BaseImageGenerator.""" diff --git a/mu/algorithms/semipermeable_membrane/model.py b/mu/algorithms/semipermeable_membrane/model.py index 6e166379..216e1141 100644 --- a/mu/algorithms/semipermeable_membrane/model.py +++ b/mu/algorithms/semipermeable_membrane/model.py @@ -94,7 +94,7 @@ def save_model(self, model, output_path: str, dtype, metadata, *args, **kwargs): """ Save the model weights to the output path """ - #TODO + self.logger.info(f"Saving model to {output_path}") # Save the SPM network weights model.save_weights( diff --git a/mu/algorithms/semipermeable_membrane/sampler.py b/mu/algorithms/semipermeable_membrane/sampler.py index 13934ca2..2654cece 100644 --- a/mu/algorithms/semipermeable_membrane/sampler.py +++ b/mu/algorithms/semipermeable_membrane/sampler.py @@ -20,9 +20,6 @@ from mu.algorithms.semipermeable_membrane.src.models.merge_spm import load_state_dict -#TODO remove this -theme_available = ['Abstractionism', 'Bricks', 'Cartoon'] -class_available = ['Architectures', 'Bears', 'Birds'] MATCHING_METRICS = Literal[ "clipcos", diff --git a/mu/algorithms/unified_concept_editing/sampler.py b/mu/algorithms/unified_concept_editing/sampler.py index 1224318a..dc6c5e8b 100644 --- a/mu/algorithms/unified_concept_editing/sampler.py +++ b/mu/algorithms/unified_concept_editing/sampler.py @@ -13,9 +13,6 @@ from mu.core.base_sampler import BaseSampler from stable_diffusion.constants.const import theme_available, class_available -#TODO remove this -theme_available = ['Abstractionism', 'Bricks', 'Cartoon'] -class_available = ['Architectures', 'Bears', 'Birds'] class UnifiedConceptEditingSampler(BaseSampler): """Unified Concept editing Image Generator class extending a hypothetical BaseImageGenerator.""" diff --git a/mu/helpers/utils.py b/mu/helpers/utils.py index 54e8ab19..ee199e23 100644 --- a/mu/helpers/utils.py +++ b/mu/helpers/utils.py @@ -54,10 +54,9 @@ def load_model_from_config( model.cond_stage_model.device = device return model - @torch.no_grad() def sample_model( - model, + model, sampler, c, h, diff --git a/mu_attack/.gitignore 
b/mu_attack/.gitignore new file mode 100644 index 00000000..aa850f42 --- /dev/null +++ b/mu_attack/.gitignore @@ -0,0 +1 @@ +src/* \ No newline at end of file diff --git a/mu_attack/attackers/soft_prompt.py b/mu_attack/attackers/soft_prompt.py new file mode 100644 index 00000000..aa8590c1 --- /dev/null +++ b/mu_attack/attackers/soft_prompt.py @@ -0,0 +1,226 @@ +# mu_attack/attackers/soft_prompt.py + +import torch +import wandb + +from mu_attack.helpers.utils import split_id, id2embedding, split_embd, init_adv, construct_embd, construct_id, sample_model, sample_model_for_diffuser + + +class SoftPromptAttack: + """ + A class to perform a soft prompt attack on the ESD model. + + Attributes: + model: The ESD model. + model_orig: The frozen (original) model. + tokenizer: The tokenizer. + text_encoder: The text encoder. + sampler: The sampler (or scheduler) used for diffusion. + emb_0: Unconditional embedding. + emb_p: Conditional embedding. + start_guidance: Guidance scale for sampling. + devices: List of devices to use. + ddim_steps: Number of DDIM steps. + ddim_eta: The eta parameter for DDIM. + image_size: The size (width and height) for generated images. + criteria: The loss criteria function. + k: Number of tokens (or a related parameter for the prompt). + all_embeddings: The preloaded word embeddings. + backend: String indicating which backend is used ("compvis" or "diffusers"). + """ + + def __init__(self, model, model_orig, tokenizer, text_encoder, sampler, + emb_0, emb_p, start_guidance, devices, ddim_steps, ddim_eta, + image_size, criteria, k, all_embeddings, backend="compvis"): + self.model = model + self.model_orig = model_orig + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.sampler = sampler + self.emb_0 = emb_0 + self.emb_p = emb_p + self.start_guidance = start_guidance + self.devices = devices + self.ddim_steps = ddim_steps + self.ddim_eta = ddim_eta + self.image_size = image_size + self.criteria = criteria + self.k = k + self.all_embeddings = all_embeddings + self.backend = backend + + def attack(self, global_step, word, attack_round, attack_type, + attack_embd_type, attack_step, attack_lr, + attack_init=None, attack_init_embd=None, attack_method='pgd'): + """ + Perform soft prompt attack on the ESD model. + + Args: + global_step (int): The current global training step. + word (str): The input prompt. + attack_round (int): The current attack round. + attack_type (str): Type of attack ("add" or "insert"). + attack_embd_type (str): Type of adversarial embedding ("condition_embd" or "word_embd"). + attack_step (int): Number of steps to run the attack. + attack_lr (float): Learning rate for the adversarial optimization. + attack_init (str, optional): Initialization method ("latest" or "random"). + attack_init_embd (torch.Tensor, optional): Initial adversarial embedding. + attack_method (str, optional): Attack method to use ("pgd" or "fast_at"). + + Returns: + tuple: Depending on attack_embd_type, returns a tuple (embedding, input_ids) + where the embedding is either a conditional or word embedding. + """ + orig_prompt_len = len(word.split()) + if attack_type == 'add': + # When using "add", update k to match the prompt length. + self.k = orig_prompt_len + + # A helper lambda to sample an image until a given time step. 
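+        # Both helpers share the same call signature, so the attack loop below is
+        # backend-agnostic; `till_T` truncates the reverse diffusion so the loop can
+        # obtain a partially denoised latent z_t at a randomly sampled time step t.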
+ if self.backend == "compvis": + quick_sample_till_t = lambda x, s, code, t: sample_model( + self.model, self.sampler, x, self.image_size, self.image_size, + self.ddim_steps, s, self.ddim_eta, start_code=code, till_T=t, verbose=False + ) + elif self.backend == "diffusers": + quick_sample_till_t = lambda x, s, code, t: sample_model_for_diffuser( + self.model, self.sampler, x, self.image_size, self.image_size, + self.ddim_steps, s, self.ddim_eta, start_code=code, till_T=t, verbose=False + ) + + # --- Tokenization and Embedding --- + text_input = self.tokenizer( + word, padding="max_length", max_length=self.tokenizer.model_max_length, + return_tensors="pt", truncation=True + ) + sot_id, mid_id, replace_id, eot_id = split_id( + text_input.input_ids.to(self.devices[0]), self.k, orig_prompt_len + ) + + text_embeddings = id2embedding( + self.tokenizer, self.all_embeddings, + text_input.input_ids.to(self.devices[0]), self.devices[0] + ) + sot_embd, mid_embd, _, eot_embd = split_embd(text_embeddings, self.k, orig_prompt_len) + + # --- Initialize the adversarial embedding --- + if attack_init == 'latest': + adv_embedding = init_adv(self.k, self.tokenizer, self.all_embeddings, + attack_type, self.devices[0], 1, attack_init_embd) + elif attack_init == 'random': + adv_embedding = init_adv(self.k, self.tokenizer, self.all_embeddings, + attack_type, self.devices[0], 1) + else: + # Default initialization if no method is provided + adv_embedding = init_adv(self.k, self.tokenizer, self.all_embeddings, + attack_type, self.devices[0], 1) + + attack_opt = torch.optim.Adam([adv_embedding], lr=attack_lr) + + # For the condition_embd attack type, construct the initial adversarial condition embedding. + if attack_embd_type == 'condition_embd': + input_adv_condition_embedding = construct_embd( + self.k, adv_embedding, attack_type, sot_embd, mid_embd, eot_embd + ) + adv_input_ids = construct_id( + self.k, replace_id, attack_type, sot_id, eot_id, mid_id + ) + + print(f'[{attack_type}] Starting {attack_method} attack on "{word}"') + + # --- Attack Loop --- + for i in range(attack_step): + # Randomly sample a time step for the attack. + t_enc = torch.randint(self.ddim_steps, (1,), device=self.devices[0]) + og_num = round((int(t_enc) / self.ddim_steps) * 1000) + og_num_lim = round((int(t_enc + 1) / self.ddim_steps) * 1000) + t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=self.devices[0]) + start_code = torch.randn((1, 4, 64, 64)).to(self.devices[0]) + + with torch.no_grad(): + # Sample a latent z using the conditional embedding. + z = quick_sample_till_t( + self.emb_p.to(self.devices[0]), self.start_guidance, start_code, int(t_enc) + ) + if self.backend == "compvis": + # For compvis, use apply_model to get the noise predictions. + e_0 = self.model_orig.apply_model( + z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), self.emb_0.to(self.devices[0]) + ) + e_p = self.model_orig.apply_model( + z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), self.emb_p.to(self.devices[0]) + ) + elif self.backend == "diffusers": + # For diffusers, call the UNet directly with encoder_hidden_states. 
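+                # e_0 and e_p are the frozen model's unconditional and conditional
+                # noise predictions; e_p later serves as the regression target for
+                # the adversarial prediction e_n.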
+ out_0 = self.model_orig( + z.to(self.devices[0]), + t_enc_ddpm.to(self.devices[0]), + encoder_hidden_states=self.emb_0.to(self.devices[0]) + ) + e_0 = out_0.sample if hasattr(out_0, "sample") else out_0 + out_p = self.model_orig( + z.to(self.devices[0]), + t_enc_ddpm.to(self.devices[0]), + encoder_hidden_states=self.emb_p.to(self.devices[0]) + ) + e_p = out_p.sample if hasattr(out_p, "sample") else out_p + else: + raise ValueError(f"Unknown backend: {self.backend}") + + # For word_embd attack type, update the adversarial condition embedding using the text encoder. + if attack_embd_type == 'word_embd': + input_adv_word_embedding = construct_embd( + self.k, adv_embedding, attack_type, sot_embd, mid_embd, eot_embd + ) + adv_input_ids = construct_id( + self.k, replace_id, attack_type, sot_id, eot_id, mid_id + ) + input_adv_condition_embedding = self.text_encoder( + input_ids=adv_input_ids.to(self.devices[0]), + inputs_embeds=input_adv_word_embedding + )[0] + + # Get the conditional score from the model with the adversarial condition embedding. + if self.backend == "compvis": + e_n = self.model.apply_model( + z.to(self.devices[0]), t_enc_ddpm.to(self.devices[0]), + input_adv_condition_embedding.to(self.devices[0]) + ) + elif self.backend == "diffusers": + out_n = self.model( + z.to(self.devices[0]), + t_enc_ddpm.to(self.devices[0]), + encoder_hidden_states=input_adv_condition_embedding.to(self.devices[0]) + ) + e_n = out_n.sample if hasattr(out_n, "sample") else out_n + else: + raise ValueError(f"Unknown backend: {self.backend}") + + # Prevent gradients on the frozen branch. + e_0.requires_grad = False + e_p.requires_grad = False + + # Compute the loss between the adversarial output and the target. + loss = self.criteria(e_n.to(self.devices[0]), e_p.to(self.devices[0])) + loss.backward() + + if attack_method == 'pgd': + attack_opt.step() + elif attack_method == 'fast_at': + adv_embedding.grad.sign_() + attack_opt.step() + else: + raise ValueError('attack_method must be either pgd or fast_at') + + wandb.log({'Attack_Loss': loss.item()}, step=global_step + i) + wandb.log({'Train_Loss': 0.0}, step=global_step + i) + print(f'Step: {global_step + i}, Attack_Loss: {loss.item()}') + print(f'Step: {global_step + i}, Train_Loss: 0.0') + + # --- Return the adversarial embeddings and input IDs --- + if attack_embd_type == 'condition_embd': + return input_adv_condition_embedding, adv_input_ids + elif attack_embd_type == 'word_embd': + return input_adv_word_embedding, adv_input_ids + else: + raise ValueError('attack_embd_type must be either condition_embd or word_embd') diff --git a/mu_attack/configs/adv_unlearn/__init__.py b/mu_attack/configs/adv_unlearn/__init__.py new file mode 100644 index 00000000..9c229329 --- /dev/null +++ b/mu_attack/configs/adv_unlearn/__init__.py @@ -0,0 +1 @@ +from .adv_attack_config import AdvAttackConfig, adv_attack_config \ No newline at end of file diff --git a/mu_attack/configs/adv_unlearn/adv_attack_config.py b/mu_attack/configs/adv_unlearn/adv_attack_config.py new file mode 100644 index 00000000..222217d7 --- /dev/null +++ b/mu_attack/configs/adv_unlearn/adv_attack_config.py @@ -0,0 +1,75 @@ +#mu_attack/configs/adv_unlearn/adv_attack_config.py + +import os +from pathlib import Path +from mu.core.base_config import BaseConfig + +current_dir = Path(__file__).parent + +class AdvAttackConfig(BaseConfig): + def __init__(self, **kwargs): + # Inference & Model Paths for compvis + self.config_path = current_dir / "model_config.yaml" + self.compvis_ckpt_path = 
"models/sd-v1-4-full-ema.ckpt" + + #model path for custom encoder + self.encoder_model_name_or_path = "CompVis/stable-diffusion-v1-4" + + #for samlping + self.target_ckpt = None + + # Model Paths for diffusers + self.diffusers_model_name_or_path = "/home/ubuntu/Projects/UnlearnCanvas/UnlearnCanvas/machine_unlearning/models/diffuser/style50" + + # Devices & IO + self.devices = "0,0" # You can later parse this string into a list if needed. + self.seperator = None + self.cache_path = ".cache" + + # Image & Diffusion Sampling + self.start_guidance = 3.0 + self.ddim_steps = 50 + + + # Training Setup + self.image_size = 512 + self.prompt = "nudity" + self.attack_method = "pgd" # Choices: 'pgd', 'multi_pgd', 'fast_at', 'free_at' + self.ddim_eta = 0 + + # Adversarial Attack Hyperparameters + self.adv_prompt_num = 1 + self.attack_init_embd = None + self.attack_embd_type = "word_embd" # Choices: 'word_embd', 'condition_embd' + self.attack_type = "prefix_k" # Choices: 'replace_k', 'add', 'prefix_k', 'suffix_k', 'mid_k', 'insert_k', 'per_k_words' + self.attack_init = "latest" # Choices: 'random', 'latest' + self.attack_step = 30 + self.attack_lr = 1e-3 + + #backend + self.backend = "diffusers" + + #wandb configs + self.project_name = "quick-canvas-machine-unlearning" + self.experiment_name = f'AdvUnlearn-{self.prompt}-method_Attack_{self.attack_method}' + + + # Override default values with any provided keyword arguments. + for key, value in kwargs.items(): + setattr(self, key, value) + + def validate_config(self): + """ + Perform basic validation on the config parameters. + """ + if self.backend not in ["compvis", "diffusers"]: + raise ValueError(f"Backend must be either 'compvis' or 'diffusers'. Got {self.backend}.") + if self.backend == "compvis": + if not os.path.exists(self.compvis_ckpt_path): + raise FileNotFoundError(f"Checkpoint file {self.compvis_ckpt_path} does not exist.") + elif self.backend == "diffusers": + if not os.path.exists(self.diffusers_model_name_or_path): + raise FileNotFoundError(f"Diffusers model {self.diffusers_model_name_or_path} does not exist.") + +adv_attack_config = AdvAttackConfig() + diff --git a/mu_attack/environment.yaml b/mu_attack/environment.yaml index 5b54a553..c14e24a5 100644 --- a/mu_attack/environment.yaml +++ b/mu_attack/environment.yaml @@ -26,9 +26,11 @@ dependencies: - transformers==4.33.2 - opencv-python-headless==4.8.0.76 - einops==0.8.0 + - timm==0.6.7 - taming-transformers-rom1504 - kornia==0.6 - pydantic==2.10.6 + - wandb==0.19.5 - git+https://github.com/Phoveran/fastargs.git@main#egg=fastargs - git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers - git+https://github.com/openai/CLIP.git@main#egg=clip diff --git a/mu_attack/execs/adv_attack.py b/mu_attack/execs/adv_attack.py new file mode 100644 index 00000000..bb32924a --- /dev/null +++ b/mu_attack/execs/adv_attack.py @@ -0,0 +1,127 @@ +# mu_attack/execs/adv_attack.py + +import torch +import wandb + +from mu_attack.configs.adv_unlearn import AdvAttackConfig +from mu_attack.attackers.soft_prompt import SoftPromptAttack +from mu_attack.helpers.utils import get_models_for_compvis, get_models_for_diffusers + + +class AdvAttack: + def __init__(self, config: AdvAttackConfig): + self.config = config.__dict__ + # Do not set self.prompt from the config; remove the dependency. 
+ self.encoder_model_name_or_path = config.encoder_model_name_or_path + self.cache_path = config.cache_path + self.devices = [f"cuda:{int(d.strip())}" for d in config.devices.split(",")] + self.attack_type = config.attack_type + self.attack_embd_type = config.attack_embd_type + self.attack_step = config.attack_step + self.attack_lr = config.attack_lr + self.attack_init = config.attack_init + self.attack_init_embd = config.attack_init_embd + self.attack_method = config.attack_method + self.ddim_steps = config.ddim_steps + self.ddim_eta = config.ddim_eta + self.image_size = config.image_size + self.adv_prompt_num = config.adv_prompt_num + self.start_guidance = config.start_guidance + self.config_path = config.config_path + self.compvis_ckpt_path = config.compvis_ckpt_path + self.backend = config.backend + self.diffusers_model_name_or_path = config.diffusers_model_name_or_path + self.target_ckpt = config.target_ckpt + self.criteria = torch.nn.MSELoss() + + # Initialize wandb (if needed) + wandb.init( + project=config.project_name, name=config.experiment_name, reinit=True + ) + + # self.load_models() + + def load_models(self): + if self.backend == "compvis": + self.model_orig, self.sampler_orig, self.model, self.sampler = ( + get_models_for_compvis( + self.config_path, self.compvis_ckpt_path, self.devices + ) + ) + elif self.backend == "diffusers": + self.model_orig, self.sampler_orig, self.model, self.sampler = ( + get_models_for_diffusers( + self.diffusers_model_name_or_path, self.target_ckpt, self.devices + ) + ) + + def attack(self, word, global_step, attack_round): + """ + Perform the adversarial attack using the given word. + + Args: + word (str): The current prompt to attack. + global_step (int): The current global training step. + attack_round (int): The current attack round. + + Returns: + tuple: (adversarial embedding, input_ids) + """ + # Now, use the passed `word` for the attack instead of self.prompt. + # (Everything else in this method remains the same.) + sp_attack = SoftPromptAttack( + model=self.model, + model_orig=self.model_orig, + tokenizer=self.tokenizer, + text_encoder=self.custom_text_encoder, + sampler=self.sampler, + emb_0=self._get_emb_0(), + emb_p=self._get_emb_p(word), + start_guidance=self.start_guidance, + devices=self.devices, + ddim_steps=self.ddim_steps, + ddim_eta=self.ddim_eta, + image_size=self.image_size, + criteria=self.criteria, + k=self.adv_prompt_num, + all_embeddings=self.all_embeddings, + backend=self.backend, + ) + return sp_attack.attack( + global_step, + word, + attack_round, + self.attack_type, + self.attack_embd_type, + self.attack_step, + self.attack_lr, + self.attack_init, + self.attack_init_embd, + self.attack_method, + ) + + # Example helper methods to get embeddings from model_orig. 
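+    # Note: `tokenizer`, `custom_text_encoder` and `all_embeddings` are not created
+    # in this class; they are expected to be attached to the instance by the caller
+    # (the defense pipeline assigns models and samplers onto the attack object).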
+    def _get_emb_0(self):
+        if self.backend == "compvis":
+            return self.model_orig.get_learned_conditioning([""])
+        else:
+            # For diffusers, you need to define your own method (e.g., using self.encode_text(""))
+            return self.encode_text("")
+
+    def _get_emb_p(self, word):
+        if self.backend == "compvis":
+            return self.model_orig.get_learned_conditioning([word])
+        else:
+            return self.encode_text(word)
+
+    def encode_text(self, text):
+        text_inputs = self.tokenizer(
+            text,
+            padding="max_length",
+            truncation=True,
+            max_length=77,
+            return_tensors="pt",
+        ).to(self.devices[0])
+        with torch.no_grad():
+            text_embeddings = self.text_encoder(text_inputs.input_ids)[0]
+        return text_embeddings
diff --git a/mu_attack/helpers/utils.py b/mu_attack/helpers/utils.py
index 2d9d89bc..a3e2d16e 100644
--- a/mu_attack/helpers/utils.py
+++ b/mu_attack/helpers/utils.py
@@ -1,16 +1,66 @@
 import os
-import torch
-from PIL import Image
 import pandas as pd
+import random
 import yaml
-import json
-from typing import Optional, Tuple, Union
+import torch
+import torch.nn.functional as F
 from torchvision.transforms.functional import InterpolationMode
-from transformers.modeling_outputs import BaseModelOutputWithPooling
 import torchvision.transforms as torch_transforms
+from diffusers import UNet2DConditionModel, DDIMScheduler
+
+
+from mu.helpers.utils import load_model_from_config
+from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler
+
+
+
+class PromptDataset:
+    def __init__(self, csv_file):
+        self.data = pd.read_csv(csv_file)
+        self.unseen_indices = list(self.data.index)  # keep track of all indices not yet seen
+
+    def get_random_prompts(self, num_prompts=1):
+        # Ensure that the number of prompts requested is not greater than the number of unseen prompts
+        num_prompts = min(num_prompts, len(self.unseen_indices))
+
+        # Randomly select num_prompts indices from the list of unseen indices
+        selected_indices = random.sample(self.unseen_indices, num_prompts)
+
+        # Remove the selected indices from the list of unseen indices
+        for index in selected_indices:
+            self.unseen_indices.remove(index)
+        # Return the prompts corresponding to the selected indices
+        return self.data.loc[selected_indices, 'prompt'].tolist()
+
+    def has_unseen_prompts(self):
+        # Check whether any unseen prompts remain
+        return len(self.unseen_indices) > 0
+
+    def reset(self):
+        self.unseen_indices = list(self.data.index)
+
+    def check_unseen_prompt_count(self):
+        return len(self.unseen_indices)
+
+
+def retain_prompt(dataset_retain):
+    # Prompt dataset to be retained
+    if dataset_retain == 'imagenet243':
+        retain_dataset = PromptDataset('data/prompts/train/imagenet243_retain.csv')
+    elif dataset_retain == 'imagenet243_no_filter':
+        retain_dataset = PromptDataset('data/prompts/train/imagenet243_no_filter_retain.csv')
+    elif dataset_retain == 'coco_object':
+        retain_dataset = PromptDataset('data/prompts/train/coco_object_retain.csv')
+    elif dataset_retain == 'coco_object_no_filter':
+        retain_dataset = PromptDataset('data/prompts/train/coco_object_no_filter_retain.csv')
+    else:
+        raise ValueError('Invalid dataset for retaining prompts')
+
+    return retain_dataset
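+# Example usage (illustrative): draw retain prompts without repetition.
+#     retain_dataset = retain_prompt('imagenet243')
+#     while retain_dataset.has_unseen_prompts():
+#         batch = retain_dataset.get_random_prompts(num_prompts=2)
+#     retain_dataset.reset()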
 
 def load_config(yaml_path):
     """Loads the configuration from a YAML file."""
@@ -67,3 +117,277 @@ def convert_time(time_str):
     total_minutes_direct = hours * 60 + minutes + seconds_microseconds / 60
     return total_minutes_direct
 
+def id2embedding(tokenizer, all_embeddings, input_ids, device):
+    input_one_hot = F.one_hot(input_ids.view(-1), num_classes = len(tokenizer.get_vocab())).float()
+    input_one_hot = torch.unsqueeze(input_one_hot,0).to(device)
+    input_embeds = input_one_hot @ all_embeddings
+    return input_embeds
+
+def split_id(input_ids, k, orig_prompt_len):
+    sot_id, mid_id, replace_id, eot_id = torch.split(input_ids, [1, orig_prompt_len, k, 76-orig_prompt_len-k], dim=1)
+    return sot_id, mid_id, replace_id, eot_id
+
+def split_embd(input_embed, k, orig_prompt_len):
+    sot_embd, mid_embd, replace_embd, eot_embd = torch.split(input_embed, [1, orig_prompt_len, k, 76-orig_prompt_len-k ], dim=1)
+    return sot_embd, mid_embd, replace_embd, eot_embd
+
+def init_adv(k, tokenizer, all_embeddings, attack_type, device, batch = 1, attack_init_embd = None):
+    # Different attack types have different initializations (Attack types: add, insert)
+    adv_embedding = torch.nn.Parameter(torch.randn([batch, k, 768])).to(device)
+
+    if attack_init_embd is not None:
+        # Use the provided initial adversarial embedding
+        adv_embedding.data = attack_init_embd[:,1:1+k].data
+    else:
+        # Randomly sample k words from the vocabulary as the initial adversarial words
+        tmp_ids = torch.randint(0,len(tokenizer),(batch, k)).to(device)
+        tmp_embeddings = id2embedding(tokenizer, all_embeddings, tmp_ids, device)
+        tmp_embeddings = tmp_embeddings.reshape(batch, k, 768)
+        adv_embedding.data = tmp_embeddings.data
+    adv_embedding = adv_embedding.detach().requires_grad_(True)
+
+    return adv_embedding
+
+def construct_embd(k, adv_embedding, insertion_location, sot_embd, mid_embd, eot_embd):
+    if insertion_location == 'prefix_k':  # Prepend k words before the original prompt
+        embedding = torch.cat([sot_embd,adv_embedding,mid_embd,eot_embd],dim=1)
+    elif insertion_location == 'replace_k':  # Replace k words in the original prompt
+        replace_embd = eot_embd[:,0,:].repeat(1,mid_embd.shape[1],1)
+        embedding = torch.cat([sot_embd,adv_embedding,replace_embd,eot_embd],dim=1)
+    elif insertion_location == 'add':  # Add perturbation to the original prompt
+        replace_embd = eot_embd[:,0,:].repeat(1,k,1)
+        embedding = torch.cat([sot_embd,adv_embedding+mid_embd,replace_embd,eot_embd],dim=1)
+    elif insertion_location == 'suffix_k':  # Append k words after the original prompt
+        embedding = torch.cat([sot_embd,mid_embd,adv_embedding,eot_embd],dim=1)
+    elif insertion_location == 'mid_k':  # Insert k words in the middle of the original prompt
+        embedding = [sot_embd,]
+        total_num = mid_embd.size(1)
+        embedding.append(mid_embd[:,:total_num//2,:])
+        embedding.append(adv_embedding)
+        embedding.append(mid_embd[:,total_num//2:,:])
+        embedding.append(eot_embd)
+        embedding = torch.cat(embedding,dim=1)
+    elif insertion_location == 'insert_k':  # Separate k words into the original prompt with equal intervals
+        embedding = [sot_embd,]
+        total_num = mid_embd.size(1)
+        internals = total_num // (k+1)
+        for i in range(k):
+            embedding.append(mid_embd[:,internals*i:internals*(i+1),:])
+            embedding.append(adv_embedding[:,i,:].unsqueeze(1))
+        embedding.append(mid_embd[:,internals*(i+1):,:])
+        embedding.append(eot_embd)
+        embedding = torch.cat(embedding,dim=1)
+
+    elif insertion_location == 'per_k_words':
+        embedding = [sot_embd,]
+        for i in range(adv_embedding.size(1) - 1):
+            embedding.append(adv_embedding[:,i,:].unsqueeze(1))
+            embedding.append(mid_embd[:,3*i:3*(i+1),:])
+        embedding.append(adv_embedding[:,-1,:].unsqueeze(1))
+        embedding.append(mid_embd[:,3*(i+1):,:])
+        embedding.append(eot_embd)
+        embedding = torch.cat(embedding,dim=1)
+    return embedding
+
+def construct_id(k, adv_id, insertion_location,sot_id,eot_id,mid_id):
+    if insertion_location == 'prefix_k':
+        input_ids = torch.cat([sot_id,adv_id,mid_id,eot_id],dim=1)
+
+    elif insertion_location == 'replace_k':
+        replace_id = eot_id[:,0].repeat(1,mid_id.shape[1])
+        input_ids = torch.cat([sot_id,adv_id,replace_id,eot_id],dim=1)
+
+    elif insertion_location == 'add':
+        replace_id = eot_id[:,0].repeat(1,k)
+        input_ids = torch.cat([sot_id,mid_id,replace_id,eot_id],dim=1)
+
+    elif insertion_location == 'suffix_k':
+        input_ids = torch.cat([sot_id,mid_id,adv_id,eot_id],dim=1)
+
+    elif insertion_location == 'mid_k':
+        input_ids = [sot_id,]
+        total_num = mid_id.size(1)
+        input_ids.append(mid_id[:,:total_num//2])
+        input_ids.append(adv_id)
+        input_ids.append(mid_id[:,total_num//2:])
+        input_ids.append(eot_id)
+        input_ids = torch.cat(input_ids,dim=1)
+
+    elif insertion_location == 'insert_k':
+        input_ids = [sot_id,]
+        total_num = mid_id.size(1)
+        internals = total_num // (k+1)
+        for i in range(k):
+            input_ids.append(mid_id[:,internals*i:internals*(i+1)])
+            input_ids.append(adv_id[:,i].unsqueeze(1))
+        input_ids.append(mid_id[:,internals*(i+1):])
+        input_ids.append(eot_id)
+        input_ids = torch.cat(input_ids,dim=1)
+
+    elif insertion_location == 'per_k_words':
+        input_ids = [sot_id,]
+        for i in range(adv_id.size(1) - 1):
+            input_ids.append(adv_id[:,i].unsqueeze(1))
+            input_ids.append(mid_id[:,3*i:3*(i+1)])
+        input_ids.append(adv_id[:,-1].unsqueeze(1))
+        input_ids.append(mid_id[:,3*(i+1):])
+        input_ids.append(eot_id)
+        input_ids = torch.cat(input_ids,dim=1)
+    return input_ids
+
+
+
+def get_models_for_compvis(config_path, compvis_ckpt_path, devices):
+    model_orig = load_model_from_config(config_path, compvis_ckpt_path, devices[1])
+    sampler_orig = DDIMSampler(model_orig)
+
+    model = load_model_from_config(config_path, compvis_ckpt_path, devices[0])
+    sampler = DDIMSampler(model)
+
+    return model_orig, sampler_orig, model, sampler
+
+def get_models_for_diffusers(diffuser_model_name_or_path, target_ckpt, devices, cache_path=None):
+    """
+    Loads two copies of a Diffusers UNet model along with their DDIM schedulers.
+
+    Args:
+        diffuser_model_name_or_path (str): The Hugging Face model identifier or local path.
+        target_ckpt (str or None): Path to a target checkpoint to load into the primary model (on devices[0]).
+            If None, no state dict is loaded.
+        devices (list or tuple): A list/tuple of two devices, e.g. [device0, device1].
+        cache_path (str or None): Optional cache directory for pretrained weights.
+
+    Returns:
+        model_orig: The UNet loaded on devices[1].
+        sampler_orig: The DDIM scheduler corresponding to model_orig.
+        model: The UNet loaded on devices[0] (optionally updated with target_ckpt).
+        sampler: The DDIM scheduler corresponding to model.
+    """
+
+    # Load the original model (used for e.g. computing loss, etc.) on devices[1]
+    model_orig = UNet2DConditionModel.from_pretrained(
+        diffuser_model_name_or_path,
+        subfolder="unet",
+        cache_dir=cache_path
+    ).to(devices[1])
+
+    # Create a DDIM scheduler for model_orig. (Note: diffusers DDIMScheduler is used here;
+    # adjust the subfolder or configuration if your scheduler is stored elsewhere.)
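+    # Separate scheduler instances are kept for the frozen and trainable models so
+    # that stepping one (set_timesteps/step are stateful) does not affect the other.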
+ sampler_orig = DDIMScheduler.from_pretrained( + diffuser_model_name_or_path, + subfolder="scheduler", + cache_dir=cache_path + ) + + # Load the second copy of the model on devices[0] + model = UNet2DConditionModel.from_pretrained( + diffuser_model_name_or_path, + subfolder="unet", + cache_dir=cache_path + ).to(devices[0]) + + # Optionally load a target checkpoint into model + if target_ckpt is not None: + state_dict = torch.load(target_ckpt, map_location=devices[0]) + model.load_state_dict(state_dict) + + sampler = DDIMScheduler.from_pretrained( + diffuser_model_name_or_path, + subfolder="scheduler", + cache_dir=cache_path + ) + + return model_orig, sampler_orig, model, sampler + +@torch.no_grad() +def sample_model(model, sampler, c, h, w, ddim_steps, scale, ddim_eta, start_code=None, n_samples=1,t_start=-1,log_every_t=None,till_T=None,verbose=True): + """Sample the model""" + uc = None + if scale != 1.0: + uc = model.get_learned_conditioning(n_samples * [""]) + log_t = 100 + if log_every_t is not None: + log_t = log_every_t + shape = [4, h // 8, w // 8] + samples_ddim, inters = sampler.sample(S=ddim_steps, + conditioning=c, + batch_size=n_samples, + shape=shape, + verbose=False, + x_T=start_code, + unconditional_guidance_scale=scale, + unconditional_conditioning=uc, + eta=ddim_eta, + verbose_iter = verbose, + t_start=t_start, + log_every_t = log_t, + till_T = till_T + ) + if log_every_t is not None: + return samples_ddim, inters + return samples_ddim + +@torch.no_grad() +def sample_model_for_diffuser(model, scheduler, c, h, w, ddim_steps, scale, ddim_eta, start_code=None, + n_samples=1, t_start=-1, log_every_t=None, till_T=None, verbose=True): + """ + Diffusers-compatible sampling function. + + Args: + model: The UNet model (from diffusers). + scheduler: A DDIMScheduler (or similar) instance. + c (torch.Tensor): The conditional encoder_hidden_states. + h (int): Image height. + w (int): Image width. + ddim_steps (int): Number of diffusion steps. + scale (float): Guidance scale. If not 1.0, classifier-free guidance is applied. + ddim_eta (float): The eta parameter for DDIM (unused in this basic implementation). + start_code (torch.Tensor, optional): Starting latent code. If None, random noise is used. + n_samples (int): Number of samples to generate. + t_start, log_every_t, till_T, verbose: Additional parameters (not used in this diffusers implementation). + + Returns: + torch.Tensor: The generated latent sample. + """ + device = c.device + + # If no starting code is provided, sample random noise. + if start_code is None: + start_code = torch.randn((n_samples, 4, h // 8, w // 8), device=device) + latents = start_code + + # Set the number of timesteps in the scheduler. + scheduler.set_timesteps(ddim_steps) + + # If using classifier-free guidance, prepare unconditional embeddings. + if scale != 1.0: + # In a full implementation you would obtain these from your text encoder + # For this example, we simply create a tensor of zeros with the same shape as c. + uc = torch.zeros_like(c) + # Duplicate latents and conditioning for guidance. + latents = torch.cat([latents, latents], dim=0) + c_in = torch.cat([uc, c], dim=0) + else: + c_in = c + + # Diffusion sampling loop. + for t in scheduler.timesteps: + # Scale the latents as required by the scheduler. + latent_model_input = scheduler.scale_model_input(latents, t) + model_output = model(latent_model_input, t, encoder_hidden_states=c_in) + # Assume model_output is a ModelOutput with a 'sample' attribute. 
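+        # Classifier-free guidance combines the two halves as
+        #   noise_pred = uncond + scale * (cond - uncond),
+        # extrapolating away from the unconditional prediction.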
+ if scale != 1.0: + # Split the batch into unconditional and conditional parts. + noise_pred_uncond, noise_pred_text = model_output.sample.chunk(2) + # Apply classifier-free guidance. + noise_pred = noise_pred_uncond + scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred = model_output.sample + + # Step the scheduler. + latents = scheduler.step(noise_pred, t, latents).prev_sample + + # If guidance was used, return only the second half of the batch. + if scale != 1.0: + latents = latents[n_samples:] + return latents \ No newline at end of file diff --git a/mu_defense/.gitignore b/mu_defense/.gitignore new file mode 100644 index 00000000..aa850f42 --- /dev/null +++ b/mu_defense/.gitignore @@ -0,0 +1 @@ +src/* \ No newline at end of file diff --git a/mu_defense/algorithms/adv_unlearn/README.md b/mu_defense/algorithms/adv_unlearn/README.md new file mode 100644 index 00000000..873e2458 --- /dev/null +++ b/mu_defense/algorithms/adv_unlearn/README.md @@ -0,0 +1,25 @@ +```python +from mu_defense.algorithms.adv_unlearn.algorithm import AdvUnlearnAlgorithm +from mu_defense.algorithms.adv_unlearn.configs import adv_unlearn_config +from mu.algorithms.erase_diff.configs import erase_diff_train_mu + + +def mu_defense(): + + mu_defense = AdvUnlearnAlgorithm( + config=adv_unlearn_config, + compvis_ckpt_path = "/home/ubuntu/Projects/dipesh/unlearn_diff/outputs/erase_diff/erase_diff_Abstractionism_model.pth", + # diffusers_model_name_or_path = "/home/ubuntu/Projects/dipesh/unlearn_diff/outputs/forget_me_not/finetuned_models/Abstractionism", + attack_step = 2, + backend = "compvis", + attack_method = "fast_at", + model_config_path = erase_diff_train_mu.model_config_path + + + ) + mu_defense.run() + +if __name__ == "__main__": + mu_defense() + +``` \ No newline at end of file diff --git a/mu_defense/algorithms/adv_unlearn/__init__.py b/mu_defense/algorithms/adv_unlearn/__init__.py new file mode 100644 index 00000000..0d25e5d7 --- /dev/null +++ b/mu_defense/algorithms/adv_unlearn/__init__.py @@ -0,0 +1,13 @@ +from .utils import * +from .model import AdvUnlearnModel +from .dataset_handler import AdvUnlearnDatasetHandler +from .compvis_trainer import AdvUnlearnCompvisTrainer +# from .algorithm import AdvUnlearnAlgorithm +# from .trainer import AdvUnlearnTrainer + +__all__ = ["AdvUnlearnModel", + "AdvUnlearnDatasetHandler", + "AdvUnlearnCompvisTrainer", + # "AdvUnlearnAlgorithm", + # "AdvUnlearnTrainer" +] \ No newline at end of file diff --git a/mu_defense/algorithms/adv_unlearn/algorithm.py b/mu_defense/algorithms/adv_unlearn/algorithm.py new file mode 100644 index 00000000..09ef54ba --- /dev/null +++ b/mu_defense/algorithms/adv_unlearn/algorithm.py @@ -0,0 +1,99 @@ +# mu/algorithms/adv_unlearn/algorithm.py + +from mu.core.base_config import BaseConfig +import wandb +import logging +from pathlib import Path + +from mu.core import BaseAlgorithm +from mu_defense.algorithms.adv_unlearn import AdvUnlearnModel +from mu_defense.algorithms.adv_unlearn.trainer import AdvUnlearnTrainer +from mu_defense.algorithms.adv_unlearn.configs import AdvUnlearnConfig + + +class AdvUnlearnAlgorithm(BaseAlgorithm): + """ + AdvUnlearnAlgorithm orchestrates the adversarial unlearning training process. + It sets up the model and trainer components and then runs the training loop. + """ + + def __init__(self, config: AdvUnlearnConfig, **kwargs): + # Update configuration with additional kwargs. 
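The guidance step in `sample_model_for_diffuser` above is the standard classifier-free-guidance combination. A standalone numeric check (toy tensors, shapes illustrative):

```python
import torch

scale = 7.5
noise_pred_uncond = torch.zeros(1, 4, 8, 8)
noise_pred_text = torch.ones(1, 4, 8, 8)
noise_pred = noise_pred_uncond + scale * (noise_pred_text - noise_pred_uncond)
assert torch.allclose(noise_pred, 7.5 * torch.ones(1, 4, 8, 8))
# With scale == 1.0 the expression reduces to noise_pred_text exactly, which is
# why the code skips the duplicated batch entirely in that case.
```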
+ for key, value in kwargs.items(): + if not hasattr(config, key): + setattr(config, key, value) + continue + config_attr = getattr(config, key) + if isinstance(config_attr, BaseConfig) and isinstance(value, dict): + for sub_key, sub_val in value.items(): + setattr(config_attr, sub_key, sub_val) + elif isinstance(config_attr, dict) and isinstance(value, dict): + config_attr.update(value) + else: + setattr(config, key, value) + self.config = config.to_dict() + + # Validate and update config. + config.validate_config() + self.config = config.to_dict() + self.model = None + self.trainer = None + self.devices = self.config.get("devices") + self.devices = [f"cuda:{int(d.strip())}" for d in self.devices.split(",")] + self.logger = logging.getLogger(__name__) + self._setup_components() + + def _setup_components(self): + """ + Setup model and trainer components. + """ + self.logger.info("Setting up components for adversarial unlearning training...") + + # Initialize Model + self.model = AdvUnlearnModel(config=self.config) + + # Initialize Trainer + self.trainer = AdvUnlearnTrainer( + model=self.model, + config=self.config, + devices=self.devices, + ) + self.trainer.trainer.adv_attack.model_orig = self.model.model_orig + self.trainer.trainer.adv_attack.sampler_orig = self.model.sampler_orig + self.trainer.trainer.adv_attack.model = self.model.model + self.trainer.trainer.adv_attack.sampler = self.model.sampler + + def run(self): + """ + Execute the training process. + """ + try: + # Initialize WandB with configurable project/run names. + wandb_config = { + "project": self.config.get("wandb_project", "adv-unlearn-project"), + "name": self.config.get("wandb_run", "Adv Unlearn Training"), + "config": self.config, + } + wandb.init(**wandb_config) + self.logger.info("Initialized WandB for logging.") + + # Create output directory if it doesn't exist. + output_dir = Path(self.config.get("output_dir", "./outputs")) + output_dir.mkdir(parents=True, exist_ok=True) + + try: + # Start training. + self.trainer.run() + except Exception as e: + self.logger.error(f"Error during training: {str(e)}") + raise + + except Exception as e: + self.logger.error(f"Failed to initialize training: {str(e)}") + raise + + finally: + # Ensure WandB always finishes. + if wandb.run is not None: + wandb.finish() + self.logger.info("Training complete. WandB logging finished.") diff --git a/mu_defense/algorithms/adv_unlearn/compvis_trainer.py b/mu_defense/algorithms/adv_unlearn/compvis_trainer.py new file mode 100644 index 00000000..1621b3fd --- /dev/null +++ b/mu_defense/algorithms/adv_unlearn/compvis_trainer.py @@ -0,0 +1,367 @@ +# mu_defense/algorithms/adv_unlearn/compvis_trainer.py + +import torch +from tqdm import tqdm +import random +import wandb +import logging +from torch.nn import MSELoss + +from mu.core import BaseTrainer +from mu_defense.algorithms.adv_unlearn import ( + id2embedding, + param_choices, + get_train_loss_retain, + save_text_encoder, + save_history, + sample_model +) + + +from mu_attack.execs.adv_attack import AdvAttack +from mu_attack.configs.adv_unlearn import AdvAttackConfig +from mu_defense.algorithms.adv_unlearn import AdvUnlearnDatasetHandler + + +class AdvUnlearnCompvisTrainer(BaseTrainer): + """ + Trainer for adversarial unlearning. + + This trainer performs the adversarial prompt update and retention-based + regularized training loop for CompVis/Diffusers models. + """ + def __init__(self, model, config: dict, devices: list, **kwargs): + """ + Initialize the AdvUnlearnCompvisTrainer. 
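The kwargs-merge loop in `AdvUnlearnAlgorithm.__init__` above applies three rules: plain values overwrite, dicts merge, and nested config objects are updated field-by-field. A self-contained sketch with toy classes (not the project's `BaseConfig`):

```python
class SubCfg:
    def __init__(self):
        self.lr = 1e-5

class Cfg:
    def __init__(self):
        self.sub = SubCfg()
        self.tags = {"a": 1}

def override(config, **kwargs):
    for key, value in kwargs.items():
        if not hasattr(config, key):
            setattr(config, key, value)      # unknown key: just attach it
            continue
        attr = getattr(config, key)
        if isinstance(attr, dict) and isinstance(value, dict):
            attr.update(value)               # dicts merge
        elif isinstance(value, dict):
            for k, v in value.items():       # nested config: field-by-field
                setattr(attr, k, v)
        else:
            setattr(config, key, value)      # plain values overwrite

cfg = Cfg()
override(cfg, sub={"lr": 2e-5}, tags={"b": 2}, backend="compvis")
assert cfg.sub.lr == 2e-5 and cfg.tags == {"a": 1, "b": 2} and cfg.backend == "compvis"
```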
+ """ + super().__init__(model, config, **kwargs) + self.devices = devices + + # Unpack models and samplers from the provided model loader. + self.model = model.model # trainable diffusion model + self.model_orig = model.model_orig # frozen diffusion model (set to eval) + self.sampler = model.sampler + self.sampler_orig = model.sampler_orig + self.model_loader = model + + # Other loaded components. + self.tokenizer = model.tokenizer + self.custom_text_encoder = model.custom_text_encoder + self.all_embeddings = model.all_embeddings + + # Loss criterion. + self.criteria = MSELoss() + + # Save configuration parameters. + self.config = config + self.prompt = self.config['prompt'] + self.seperator = self.config.get('seperator') + self.iterations = self.config.get('iterations') + self.ddim_steps = self.config['ddim_steps'] + self.start_guidance = self.config['start_guidance'] + self.negative_guidance = self.config['negative_guidance'] + self.image_size = self.config['image_size'] + self.lr = self.config['lr'] + self.model_config_path = self.config['model_config_path'] + self.output_dir = self.config['output_dir'] + + # Retention and attack parameters. + self.dataset_retain = self.config['dataset_retain'] + self.retain_batch = self.config['retain_batch'] + self.retain_train = self.config['retain_train'] + self.retain_step = self.config['retain_step'] + self.retain_loss_w = self.config['retain_loss_w'] + self.attack_method = self.config['attack_method'] + self.train_method = self.config['train_method'] + self.norm_layer = self.config['norm_layer'] + self.component = self.config['component'] + self.adv_prompt_num = self.config['adv_prompt_num'] + self.attack_embd_type = self.config['attack_embd_type'] + self.attack_type = self.config['attack_type'] + self.attack_init = self.config['attack_init'] + self.warmup_iter = self.config['warmup_iter'] + self.attack_step = self.config['attack_step'] + self.attack_lr = self.config['attack_lr'] + self.adv_prompt_update_step = self.config['adv_prompt_update_step'] + self.ddim_eta = self.config['ddim_eta'] + + self.logger = logging.getLogger(__name__) + + attack_config = AdvAttackConfig( + prompt="", # prompt is no longer used in __init__ + encoder_model_name_or_path=self.tokenizer.name_or_path, + cache_path=config.get("cache_path", "./cache"), + devices=",".join([d.strip() for d in config.get("devices", "cuda:0").split(',')]), + attack_type=config['attack_type'], + attack_embd_type=config['attack_embd_type'], + attack_step=config['attack_step'], + attack_lr=config['attack_lr'], + attack_init=config['attack_init'], + attack_init_embd=None, # adjust as needed + attack_method=config['attack_method'], + ddim_steps=config['ddim_steps'], + ddim_eta=config['ddim_eta'], + image_size=config['image_size'], + adv_prompt_num=config['adv_prompt_num'], + start_guidance=config['start_guidance'], + config_path=config['model_config_path'], + compvis_ckpt_path=config.get("compvis_ckpt_path", ""), + backend="compvis", + diffusers_model_name_or_path="", + target_ckpt="", + project=config.get("project_name", "default_project"), + experiment_name=config.get("experiment_name", "default_experiment") + ) + self.adv_attack = AdvAttack(attack_config) + # Inject the preloaded objects + self.adv_attack.tokenizer = self.tokenizer + self.adv_attack.text_encoder = self.custom_text_encoder.text_encoder + self.adv_attack.custom_text_encoder = self.custom_text_encoder + self.adv_attack.all_embeddings = self.all_embeddings + + + # Setup the dataset handler and prompt cleaning. 
+ self.dataset_handler = AdvUnlearnDatasetHandler( + prompt=self.prompt, + seperator=self.seperator, + dataset_retain=self.dataset_retain + ) + self.words, self.word_print = self.dataset_handler.setup_prompt() + self.retain_dataset = self.dataset_handler.setup_dataset() + + # Initialize adversarial prompt variables. + self.adv_word_embd = None + self.adv_condition_embd = None + self.adv_input_ids = None + + # Setup trainable parameters and optimizer. + self._setup_optimizer() + + def _setup_optimizer(self): + """ + Set up the optimizer based on the training method. + """ + if 'text_encoder' in self.train_method: + self.parameters = param_choices( + model=self.custom_text_encoder, + train_method=self.train_method, + component=self.component, + final_layer_norm=self.norm_layer + ) + else: + self.parameters = param_choices( + model=self.model, + train_method=self.train_method, + component=self.component, + final_layer_norm=self.norm_layer + ) + self.optimizer = torch.optim.Adam(self.parameters, lr=float(self.lr)) + + def train(self): + """ + Execute the adversarial unlearning training loop. + """ + ddim_eta = self.ddim_eta + quick_sample_till_t = lambda x, s, code, batch, t: sample_model( + self.model, self.sampler, + x, self.image_size, self.image_size, self.ddim_steps, s, ddim_eta, + start_code=code, n_samples=batch, till_T=t, verbose=False + ) + losses = [] + history = [] + global_step = 0 + attack_round = 0 + + pbar = tqdm(range(self.iterations)) + for i in pbar: + if i % self.adv_prompt_update_step == 0: + if self.retain_dataset.check_unseen_prompt_count() < self.retain_batch: + self.retain_dataset.reset() + word = random.choice(self.words) + text_input = self.tokenizer( + word, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + truncation=True + ) + text_embeddings = id2embedding( + self.tokenizer, + self.all_embeddings, + text_input.input_ids.to(self.devices[0]), + self.devices[0] + ) + # Obtain the unconditional and conditional embeddings via the original model. 
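`_setup_optimizer` above relies on `param_choices`, which selects trainable parameters purely by substring-matching their names. A self-contained analog of that filtering idea (toy module, not the real text encoder or UNet):

```python
import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn1 = nn.Linear(4, 4)  # stands in for self-attention weights
        self.attn2 = nn.Linear(4, 4)  # stands in for cross-attention weights

model = Tiny()
# Analog of train_method='xattn': keep only parameters whose names contain 'attn2'.
params = [p for n, p in model.named_parameters() if "attn2" in n]
optimizer = torch.optim.Adam(params, lr=1e-5)
```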
+ emb_0 = self.model_orig.get_learned_conditioning(['']) + emb_p = self.model_orig.get_learned_conditioning([word]) + + if i >= self.warmup_iter: + self.custom_text_encoder.text_encoder.eval() + self.custom_text_encoder.text_encoder.requires_grad_(False) + self.model.eval() + + adv_word_embd, adv_input_ids = self.adv_attack.attack(word, global_step, attack_round) + + if self.attack_embd_type == 'word_embd': + self.adv_word_embd, self.adv_input_ids = adv_word_embd, adv_input_ids + elif self.attack_embd_type == 'condition_embd': + self.adv_condition_embd, self.adv_input_ids = adv_word_embd, adv_input_ids + + global_step += self.attack_step + attack_round += 1 + + if 'text_encoder' in self.train_method: + self.custom_text_encoder.text_encoder.train() + self.custom_text_encoder.text_encoder.requires_grad_(True) + self.model.eval() + else: + self.custom_text_encoder.text_encoder.eval() + self.custom_text_encoder.text_encoder.requires_grad_(False) + self.model.train() + + self.optimizer.zero_grad() + + if self.retain_train == 'reg': + retain_words = self.retain_dataset.get_random_prompts(self.retain_batch) + retain_text_input = self.tokenizer( + retain_words, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + truncation=True + ) + retain_input_ids = retain_text_input.input_ids.to(self.devices[0]) + retain_emb_p = self.model_orig.get_learned_conditioning(retain_words) + retain_text_embeddings = id2embedding( + self.tokenizer, + self.all_embeddings, + retain_text_input.input_ids.to(self.devices[0]), + self.devices[0] + ) + retain_text_embeddings = retain_text_embeddings.reshape( + self.retain_batch, -1, retain_text_embeddings.shape[-1] + ) + retain_emb_n = self.custom_text_encoder( + input_ids=retain_input_ids, + inputs_embeds=retain_text_embeddings + )[0] + else: + retain_emb_p = None + retain_emb_n = None + + if i < self.warmup_iter: + input_ids = text_input.input_ids.to(self.devices[0]) + emb_n = self.custom_text_encoder( + input_ids=input_ids, + inputs_embeds=text_embeddings + )[0] + loss = get_train_loss_retain( + self.retain_batch, self.retain_train, self.retain_loss_w, + self.model, self.model_orig, self.custom_text_encoder, self.sampler, + emb_0, emb_p, retain_emb_p, emb_n, retain_emb_n, self.start_guidance, + self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, + self.image_size, self.criteria, input_ids, self.attack_embd_type + ) + else: + if self.attack_embd_type == 'word_embd': + loss = get_train_loss_retain( + self.retain_batch, self.retain_train, self.retain_loss_w, + self.model, self.model_orig, self.custom_text_encoder, self.sampler, + emb_0, emb_p, retain_emb_p, None, retain_emb_n, self.start_guidance, + self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, + self.image_size, self.criteria, self.adv_input_ids, self.attack_embd_type, + self.adv_word_embd + ) + elif self.attack_embd_type == 'condition_embd': + loss = get_train_loss_retain( + self.retain_batch, self.retain_train, self.retain_loss_w, + self.model, self.model_orig, self.custom_text_encoder, self.sampler, + emb_0, emb_p, retain_emb_p, None, retain_emb_n, self.start_guidance, + self.negative_guidance, self.devices, self.ddim_steps, ddim_eta, + self.image_size, self.criteria, self.adv_input_ids, self.attack_embd_type, + self.adv_condition_embd + ) + loss.backward() + losses.append(loss.item()) + pbar.set_postfix({"loss": loss.item()}) + history.append(loss.item()) + wandb.log({'Train_Loss': loss.item()}, step=global_step) + wandb.log({'Attack_Loss': 0.0}, 
step=global_step) + global_step += 1 + self.optimizer.step() + + if self.retain_train == 'iter': + for r in range(self.retain_step): + self.optimizer.zero_grad() + if self.retain_dataset.check_unseen_prompt_count() < self.retain_batch: + self.retain_dataset.reset() + retain_words = self.retain_dataset.get_random_prompts(self.retain_batch) + t_enc = torch.randint(self.ddim_steps, (1,), device=self.devices[0]) + og_num = round((int(t_enc.item()) / self.ddim_steps) * 1000) + og_num_lim = round(((int(t_enc.item()) + 1) / self.ddim_steps) * 1000) + t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=self.devices[0]) + retain_start_code = torch.randn((self.retain_batch, 4, 64, 64)).to(self.devices[0]) + retain_emb_p = self.model_orig.get_learned_conditioning(retain_words) + retain_z = quick_sample_till_t( + retain_emb_p.to(self.devices[0]), + self.start_guidance, + retain_start_code, + self.retain_batch, + int(t_enc.item()) + ) + retain_e_p = self.model_orig.apply_model( + retain_z.to(self.devices[0]), + t_enc_ddpm.to(self.devices[0]), + retain_emb_p.to(self.devices[0]) + ) + retain_text_input = self.tokenizer( + retain_words, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + truncation=True + ) + retain_input_ids = retain_text_input.input_ids.to(self.devices[0]) + retain_text_embeddings = id2embedding( + self.tokenizer, + self.all_embeddings, + retain_text_input.input_ids.to(self.devices[0]), + self.devices[0] + ) + retain_text_embeddings = retain_text_embeddings.reshape( + self.retain_batch, -1, retain_text_embeddings.shape[-1] + ) + retain_emb_n = self.custom_text_encoder( + input_ids=retain_input_ids, + inputs_embeds=retain_text_embeddings + )[0] + retain_e_n = self.model.apply_model( + retain_z.to(self.devices[0]), + t_enc_ddpm.to(self.devices[0]), + retain_emb_n.to(self.devices[0]) + ) + retain_loss = self.criteria( + retain_e_n.to(self.devices[0]), + retain_e_p.to(self.devices[0]) + ) + retain_loss.backward() + self.optimizer.step() + + if (i + 1) % self.config['save_interval'] == 0 and (i + 1) != self.iterations and (i + 1) >= self.config['save_interval']: + if 'text_encoder' in self.train_method: + save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) + else: + output_path = f"{self.output_dir}/models/model_checkpoint_{i}.pt" + self.model_loader.save_model(self.model, output_path) + save_history(self.output_dir, losses, self.word_print) + + self.model.eval() + self.custom_text_encoder.text_encoder.eval() + self.custom_text_encoder.text_encoder.requires_grad_(False) + if 'text_encoder' in self.train_method: + save_text_encoder(self.output_dir, self.custom_text_encoder, self.train_method, i) + else: + output_path = f"{self.output_dir}/models/model_checkpoint_{i}.pt" + self.model_loader.save_model(self.model, output_path) + save_history(self.output_dir, losses, self.word_print) + return self.model diff --git a/mu_defense/algorithms/adv_unlearn/configs/__init__.py b/mu_defense/algorithms/adv_unlearn/configs/__init__.py new file mode 100644 index 00000000..cc01a1ec --- /dev/null +++ b/mu_defense/algorithms/adv_unlearn/configs/__init__.py @@ -0,0 +1 @@ +from .adv_unlearn_config import AdvUnlearnConfig, adv_unlearn_config \ No newline at end of file diff --git a/mu_defense/algorithms/adv_unlearn/configs/adv_unlearn_config.py b/mu_defense/algorithms/adv_unlearn/configs/adv_unlearn_config.py new file mode 100644 index 00000000..daa3211d --- /dev/null +++ 
b/mu_defense/algorithms/adv_unlearn/configs/adv_unlearn_config.py @@ -0,0 +1,79 @@ +#mu_defense/algorithms/adv_unlearn/configs/adv_unlearn_config.py + +import os +from pathlib import Path +from mu_defense.core.base_config import BaseConfig + + +class AdvUnlearnConfig(BaseConfig): + def __init__(self, **kwargs): + # Inference & Model Paths + self.model_config_path = "configs/stable-diffusion/v1-inference.yaml" #for compvis + self.compvis_ckpt_path = "models/sd-v1-4-full-ema.ckpt" + self.encoder_model_name_or_path = "CompVis/stable-diffusion-v1-4" + self.cache_path = ".cache" + + self.diffusers_model_name_or_path = "" + self.target_ckpt = None #Optionally load a target checkpoint into model for diffuser sampling + + # Devices & IO + self.devices = "0,0" # You can later parse this string into a list if needed. + self.seperator = None + self.output_dir = "outputs/adv_unlearn" + + # Image & Diffusion Sampling + self.image_size = 512 + self.ddim_steps = 50 + self.start_guidance = 3.0 + self.negative_guidance = 1.0 + + # Training Setup + self.prompt = "nudity" + self.dataset_retain = "coco_object" # Choices: 'coco_object', 'coco_object_no_filter', 'imagenet243', 'imagenet243_no_filter' + self.retain_batch = 5 + self.retain_train = "iter" # Options: 'iter' or 'reg' + self.retain_step = 1 + self.retain_loss_w = 1.0 + self.ddim_eta = 0 + + self.train_method = "text_encoder_full" #choices: text_encoder_full', 'text_encoder_layer0', 'text_encoder_layer01', 'text_encoder_layer012', 'text_encoder_layer0123', 'text_encoder_layer01234', 'text_encoder_layer012345', 'text_encoder_layer0123456', 'text_encoder_layer01234567', 'text_encoder_layer012345678', 'text_encoder_layer0123456789', 'text_encoder_layer012345678910', 'text_encoder_layer01234567891011', 'text_encoder_layer0_11','text_encoder_layer01_1011', 'text_encoder_layer012_91011', 'noxattn', 'selfattn', 'xattn', 'full', 'notime', 'xlayer', 'selflayer + self.norm_layer = False # This is a flag; use True if you wish to update the norm layer. + self.attack_method = "pgd" # Choices: 'pgd', 'multi_pgd', 'fast_at', 'free_at' + self.component = "all" # Choices: 'all', 'ffn', 'attn' + self.iterations = 10 + self.save_interval = 200 + self.lr = 1e-5 + + # Adversarial Attack Hyperparameters + self.adv_prompt_num = 1 + self.attack_embd_type = "word_embd" # Choices: 'word_embd', 'condition_embd' + self.attack_type = "prefix_k" # Choices: 'replace_k', 'add', 'prefix_k', 'suffix_k', 'mid_k', 'insert_k', 'per_k_words' + self.attack_init = "latest" # Choices: 'random', 'latest' + self.attack_step = 30 + self.adv_prompt_update_step = 1 + self.attack_lr = 1e-3 + self.warmup_iter = 200 + + #backend + self.backend = "compvis" + + # Override default values with any provided keyword arguments. + for key, value in kwargs.items(): + setattr(self, key, value) + + def validate_config(self): + """ + Perform basic validation on the config parameters. 
+ """ + if self.retain_batch <= 0: + raise ValueError("retain_batch should be a positive integer.") + if self.lr <= 0: + raise ValueError("Learning rate (lr) should be positive.") + if self.image_size <= 0: + raise ValueError("Image size should be a positive integer.") + if self.iterations <= 0: + raise ValueError("Iterations must be a positive integer.") + if not os.path.exists(self.output_dir): + os.makedirs(self.output_dir) + +adv_unlearn_config = AdvUnlearnConfig() \ No newline at end of file diff --git a/mu_defense/algorithms/adv_unlearn/dataset_handler.py b/mu_defense/algorithms/adv_unlearn/dataset_handler.py new file mode 100644 index 00000000..fb51bf70 --- /dev/null +++ b/mu_defense/algorithms/adv_unlearn/dataset_handler.py @@ -0,0 +1,50 @@ + +# mu_defense/algorithms/adv_unlearn/dataset_handler.py + +import logging + +from mu_defense.core import BaseDatasetHandler +from mu_defense.algorithms.adv_unlearn.utils import retain_prompt + + +class AdvUnlearnDatasetHandler(BaseDatasetHandler): + """ + AdvUnlearnDatasetHandler handles prompt cleaning and retaining dataset creation + specifically for adversarial unlearning. + """ + def __init__(self,prompt, seperator, dataset_retain): + self.prompt = prompt + self.seperator = seperator + self.dataset_retain = dataset_retain + + self.logger = logging.getLogger(__name__) + + + def setup_prompt(self): + # Special-case prompt cleaning. + if self.prompt == 'allartist': + self.prompt = ("Kelly Mckernan, Thomas Kinkade, Ajin Demi Human, " + "Alena Aenami, Tyler Edlin, Kilian Eng") + elif self.prompt == 'i2p': + self.prompt = ("hate, harassment, violence, suffering, humiliation, harm, " + "suicide, sexual, nudity, bodily fluids, blood") + elif self.prompt == "artifact": + self.prompt = ("ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, " + "mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, " + "body out of frame, blurry, bad art, bad anatomy, blurred, text, watermark, grainy") + + if self.seperator: + self.words = [w.strip() for w in self.prompt.split(self.seperator)] + else: + self.words = [self.prompt] + + self.word_print = self.prompt.replace(" ", "") + self.logger.info(f"The Concept Prompt to be unlearned: {self.words}") + return self.words, self.word_print + + def setup_dataset(self): + """ + Create and return the retaining dataset using the helper function. 
+ """ + dataset = retain_prompt(self.dataset_retain) + return dataset diff --git a/mu_defense/algorithms/adv_unlearn/evaluator.py b/mu_defense/algorithms/adv_unlearn/evaluator.py new file mode 100644 index 00000000..e69de29b diff --git a/mu_defense/algorithms/adv_unlearn/model.py b/mu_defense/algorithms/adv_unlearn/model.py new file mode 100644 index 00000000..f156f17a --- /dev/null +++ b/mu_defense/algorithms/adv_unlearn/model.py @@ -0,0 +1,91 @@ +# mu_defense/algorithms/adv_unlearn/model.py + +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from mu_defense.core import BaseModel +from mu_attack.tasks.utils.text_encoder import CustomTextEncoder +from mu_defense.algorithms.adv_unlearn.utils import get_models_for_compvis, get_models_for_diffusers + + +class AdvUnlearnModel(BaseModel): + def __init__(self, config: dict): + super().__init__() + self.encoder_model_name_or_path = config.get("encoder_model_name_or_path") + self.model_config_path = config.get("model_config_path") + self.compvis_ckpt_path = config.get("compvis_ckpt_path") + + self.diffusers_model_name_or_path = config.get("diffusers_model_name_or_path") + self.target_ckpt = config.get("target_ckpt") + + self.cache_path = config.get("cache_path") + devices = config.get("devices") + if isinstance(devices, str): + self.devices = [f'cuda:{int(d.strip())}' for d in devices.split(',')] + elif isinstance(devices, list): + self.devices = devices + else: + raise ValueError("devices must be a comma-separated string or a list") + + self.backend = config.get("backend") + + self.load_model() + + def load_model(self): + # Load tokenizer + self.tokenizer = CLIPTokenizer.from_pretrained( + self.encoder_model_name_or_path, + subfolder="tokenizer", + cache_dir=self.cache_path + ) + # Load text encoder and wrap it + self.text_encoder = CLIPTextModel.from_pretrained( + self.encoder_model_name_or_path, + subfolder="text_encoder", + cache_dir=self.cache_path + ).to(self.devices[0]) + self.custom_text_encoder = CustomTextEncoder(self.text_encoder).to(self.devices[0]) + self.all_embeddings = self.custom_text_encoder.get_all_embedding().unsqueeze(0) + + # Load diffusion models + if self.backend == "compvis": + self.model_orig, self.sampler_orig, self.model, self.sampler = get_models_for_compvis( + self.model_config_path, + self.compvis_ckpt_path, + self.devices + ) + + elif self.backend == "diffusers": + self.model_orig, self.sampler_orig, self.model, self.sampler = get_models_for_diffusers( + self.diffusers_model_name_or_path, self.devices, self.target_ckpt + ) + + + def save_model(self, model: torch.nn.Module, output_path: str) -> None: + """ + Save the model's state dictionary. + + Args: + model (torch.nn.Module): The model to be saved. + output_path (str): The file path where the model checkpoint will be stored. + """ + if self.backend == "compvis": + torch.save({"state_dict": model.state_dict()}, output_path) + + elif self.backend == "diffusers": + model.save_pretrained(output_path) + + + def apply_model(self, z: torch.Tensor, t: torch.Tensor, c): + """ + Apply the diffusion model to produce an output. + + Args: + z (torch.Tensor): Noisy latent vectors. + t (torch.Tensor): Timestep tensor. + c: Conditioning tensors. + + Returns: + torch.Tensor: The output of the diffusion model. 
+ """ + return self.model.apply_model(z, t, c) diff --git a/mu_defense/algorithms/adv_unlearn/trainer.py b/mu_defense/algorithms/adv_unlearn/trainer.py new file mode 100644 index 00000000..b7146c11 --- /dev/null +++ b/mu_defense/algorithms/adv_unlearn/trainer.py @@ -0,0 +1,37 @@ +# mu_defense/algorithms/adv_unlearn/trainer.py + +import logging + +from mu_defense.core import BaseTrainer +from mu_defense.algorithms.adv_unlearn import AdvUnlearnCompvisTrainer + +class AdvUnlearnTrainer(BaseTrainer): + """ + Trainer class orchestrates the adversarial unlearning training process. + It instantiates the model and trainer components based on the provided configuration, + and then runs the training loop. + """ + def __init__(self, config: dict, model, devices): + + self.backend = config.get("backend") + self.logger = logging.getLogger(__name__) + + # Setup components based on the backend. + if self.backend == "compvis": + self.logger.info("Using Compvis backend for adversarial unlearning.") + + # Create the CompVis trainer. + self.trainer = AdvUnlearnCompvisTrainer(model, config, devices) + if self.backend == "diffusers": + pass + + + def run(self): + """ + Run the training loop. + """ + self.logger.info("Starting training...") + self.trainer.train() + self.logger.info("Training complete.") + + diff --git a/mu_defense/algorithms/adv_unlearn/utils.py b/mu_defense/algorithms/adv_unlearn/utils.py new file mode 100644 index 00000000..3ec80824 --- /dev/null +++ b/mu_defense/algorithms/adv_unlearn/utils.py @@ -0,0 +1,1047 @@ + +# mu_defense/algorithms/adv_unlearn/utils.py + +import os +import random +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np + +import torch +import torch.nn.functional as F + +from diffusers import ( + DDIMScheduler, + UNet2DConditionModel, +) + +from mu.helpers import load_model_from_config +from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler + + +class PromptDataset: + def __init__(self, csv_file): + self.data = pd.read_csv(csv_file) + self.unseen_indices = list(self.data.index) # 保存所有未见过的索引 + + def get_random_prompts(self, num_prompts=1): + # Ensure that the number of prompts requested is not greater than the number of unseen prompts + num_prompts = min(num_prompts, len(self.unseen_indices)) + + # Randomly select num_prompts indices from the list of unseen indices + selected_indices = random.sample(self.unseen_indices, num_prompts) + + # Remove the selected indices from the list of unseen indices + for index in selected_indices: + self.unseen_indices.remove(index) + + # return the prompts corresponding to the selected indices + return self.data.loc[selected_indices, 'prompt'].tolist() + + def has_unseen_prompts(self): + # check if there are any unseen prompts + return len(self.unseen_indices) > 0 + + def reset(self): + self.unseen_indices = list(self.data.index) + + def check_unseen_prompt_count(self): + return len(self.unseen_indices) + +def retain_prompt(dataset_retain): + # Prompt Dataset to be retained + + if dataset_retain == 'imagenet243': + retain_dataset = PromptDataset('data/prompts/train/imagenet243_retain.csv') + elif dataset_retain == 'imagenet243_no_filter': + retain_dataset = PromptDataset('data/prompts/train/imagenet243_no_filter_retain.csv') + elif dataset_retain == 'coco_object': + retain_dataset = PromptDataset('data/prompts/train/coco_object_retain.csv') + elif dataset_retain == 'coco_object_no_filter': + retain_dataset = PromptDataset('data/prompts/train/coco_object_no_filter_retain.csv') + else: + raise 
ValueError('Invalid dataset for retaining prompts') + + return retain_dataset + +def id2embedding(tokenizer, all_embeddings, input_ids, device): + input_one_hot = F.one_hot(input_ids.view(-1), num_classes = len(tokenizer.get_vocab())).float() + input_one_hot = torch.unsqueeze(input_one_hot,0).to(device) + input_embeds = input_one_hot @ all_embeddings + return input_embeds + +def get_models_for_diffusers(diffuser_model_name_or_path,devices, target_ckpt=None, cache_path=None): + """ + Loads two copies of a Diffusers UNet model along with their DDIM schedulers. + + Args: + model_name_or_path (str): The Hugging Face model identifier or local path. + target_ckpt (str or None): Path to a target checkpoint to load into the primary model (on devices[0]). + If None, no state dict is loaded. + devices (list or tuple): A list/tuple of two devices, e.g. [device0, device1]. + cache_path (str or None): Optional cache directory for pretrained weights. + + Returns: + model_orig: The UNet loaded on devices[1]. + sampler_orig: The DDIM scheduler corresponding to model_orig. + model: The UNet loaded on devices[0] (optionally updated with target_ckpt). + sampler: The DDIM scheduler corresponding to model. + """ + + # Load the original model (used for e.g. computing loss, etc.) on devices[1] + model_orig = UNet2DConditionModel.from_pretrained( + diffuser_model_name_or_path, + subfolder="unet", + cache_dir=cache_path + ).to(devices[1]) + + # Create a DDIM scheduler for model_orig. (Note: diffusers DDIMScheduler is used here; + # adjust the subfolder or configuration if your scheduler is stored elsewhere.) + sampler_orig = DDIMScheduler.from_pretrained( + diffuser_model_name_or_path, + subfolder="scheduler", + cache_dir=cache_path + ) + + # Load the second copy of the model on devices[0] + model = UNet2DConditionModel.from_pretrained( + diffuser_model_name_or_path, + subfolder="unet", + cache_dir=cache_path + ).to(devices[0]) + + # Optionally load a target checkpoint into model + if target_ckpt is not None: + state_dict = torch.load(target_ckpt, map_location=devices[0]) + model.load_state_dict(state_dict) + + sampler = DDIMScheduler.from_pretrained( + diffuser_model_name_or_path, + subfolder="scheduler", + cache_dir=cache_path + ) + + return model_orig, sampler_orig, model, sampler + +def get_models_for_compvis(config_path, compvis_ckpt_path, devices): + model_orig = load_model_from_config(config_path, compvis_ckpt_path, devices[1]) + sampler_orig = DDIMSampler(model_orig) + + model = load_model_from_config(config_path, compvis_ckpt_path, devices[0]) + sampler = DDIMSampler(model) + + return model_orig, sampler_orig, model, sampler + + +@torch.no_grad() +def sample_model(model, sampler, c, h, w, ddim_steps, scale, ddim_eta, start_code=None, n_samples=1,t_start=-1,log_every_t=None,till_T=None,verbose=True): + """Sample the model""" + uc = None + if scale != 1.0: + uc = model.get_learned_conditioning(n_samples * [""]) + log_t = 100 + if log_every_t is not None: + log_t = log_every_t + shape = [4, h // 8, w // 8] + samples_ddim, inters = sampler.sample(S=ddim_steps, + conditioning=c, + batch_size=n_samples, + shape=shape, + verbose=False, + x_T=start_code, + unconditional_guidance_scale=scale, + unconditional_conditioning=uc, + eta=ddim_eta, + verbose_iter = verbose, + t_start=t_start, + log_every_t = log_t, + till_T = till_T + ) + if log_every_t is not None: + return samples_ddim, inters + return samples_ddim + +def get_train_loss_retain( retain_batch, retain_train, retain_loss_w, model, model_orig, text_encoder, 
sampler, emb_0, emb_p, retain_emb_p, emb_n, retain_emb_n, start_guidance, negative_guidance, devices, ddim_steps, ddim_eta, image_size, criteria, adv_input_ids, attack_embd_type, adv_embd=None): + """_summary_ + + Args: + model: ESD model + model_orig: frozen DDPM model + sampler: DDIMSampler for DDPM model + + emb_0: unconditional embedding + emb_p: conditional embedding (for ground truth concept) + emb_n: conditional embedding (for modified concept) + + start_guidance: unconditional guidance for ESD model + negative_guidance: negative guidance for ESD model + + devices: list of devices for ESD and DDPM models + ddim_steps: number of steps for DDIMSampler + ddim_eta: eta for DDIMSampler + image_size: image size for DDIMSampler + + criteria: loss function for ESD model + + adv_input_ids: input_ids for adversarial word embedding + adv_emb_n: adversarial conditional embedding + adv_word_emb_n: adversarial word embedding + + Returns: + loss: training loss for ESD model + """ + quick_sample_till_t = lambda x, s, code, batch, t: sample_model(model, sampler, + x, image_size, image_size, ddim_steps, s, ddim_eta, + start_code=code, n_samples=batch, till_T=t, verbose=False) + + + t_enc = torch.randint(ddim_steps, (1,), device=devices[0]) + # time step from 1000 to 0 (0 being good) + og_num = round((int(t_enc)/ddim_steps)*1000) + og_num_lim = round((int(t_enc+1)/ddim_steps)*1000) + + t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=devices[0]) + + start_code = torch.randn((1, 4, 64, 64)).to(devices[0]) + if retain_train == 'reg': + retain_start_code = torch.randn((retain_batch, 4, 64, 64)).to(devices[0]) + + with torch.no_grad(): + # generate an image with the concept from ESD model + z = quick_sample_till_t(emb_p.to(devices[0]), start_guidance, start_code, 1, int(t_enc)) # emb_p seems to work better instead of emb_0 + # get conditional and unconditional scores from frozen model at time step t and image z + e_0 = model_orig.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), emb_0.to(devices[0])) + e_p = model_orig.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), emb_p.to(devices[0])) + + if retain_train == 'reg': + retain_z = quick_sample_till_t(retain_emb_p.to(devices[0]), start_guidance, retain_start_code, retain_batch, int(t_enc)) # emb_p seems to work better instead of emb_0 + # retain_e_0 = model_orig.apply_model(retain_z.to(devices[0]), t_enc_ddpm.to(devices[0]), retain_emb_0.to(devices[0])) + retain_e_p = model_orig.apply_model(retain_z.to(devices[0]), t_enc_ddpm.to(devices[0]), retain_emb_p.to(devices[0])) + + if adv_embd is None: + e_n = model.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), emb_n.to(devices[0])) + else: + if attack_embd_type == 'condition_embd': + # Train with adversarial conditional embedding + e_n = model.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), adv_embd.to(devices[0])) + elif attack_embd_type == 'word_embd': + # Train with adversarial word embedding + print('====== Training with adversarial word embedding =====') + adv_emb_n = text_encoder(input_ids = adv_input_ids.to(devices[0]), inputs_embeds=adv_embd.to(devices[0]))[0] + e_n = model.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), adv_emb_n.to(devices[0])) + else: + raise ValueError('attack_embd_type must be either condition_embd or word_embd') + + e_0.requires_grad = False + e_p.requires_grad = False + + # reconstruction loss for ESD objective from frozen model and conditional score of ESD model + # loss = criteria(e_n.to(devices[0]), e_0.to(devices[0]) - 
(negative_guidance*(e_p.to(devices[0]) - e_0.to(devices[0])))) + + # return loss + + if retain_train == 'reg': + # reconstruction loss for ESD objective from frozen model and conditional score of ESD model + print('====== Training with retain batch =====') + unlearn_loss = criteria(e_n.to(devices[0]), e_0.to(devices[0]) - (negative_guidance*(e_p.to(devices[0]) - e_0.to(devices[0])))) + + retain_e_n = model.apply_model(retain_z.to(devices[0]), t_enc_ddpm.to(devices[0]), retain_emb_n.to(devices[0])) + + # retain_e_0.requires_grad = False + retain_e_p.requires_grad = False + retain_loss = criteria(retain_e_n.to(devices[0]), retain_e_p.to(devices[0])) + + loss = unlearn_loss + retain_loss_w * retain_loss + return loss + + else: + # reconstruction loss for ESD objective from frozen model and conditional score of ESD model + unlearn_loss = criteria(e_n.to(devices[0]), e_0.to(devices[0]) - (negative_guidance*(e_p.to(devices[0]) - e_0.to(devices[0])))) + return unlearn_loss + + +def param_choices(model, train_method, component='all', final_layer_norm=False): + # choose parameters to train based on train_method + parameters = [] + + # Text Encoder FUll Weight Tuning + if train_method == 'text_encoder_full': + for name, param in model.text_encoder.text_model.named_parameters(): + # Final Layer Norm + if name.startswith('final_layer_norm'): + if component == 'all' or final_layer_norm==True: + print(name) + parameters.append(param) + else: + pass + + # Transformer layers + elif name.startswith('encoder'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + # Embedding layers + else: + pass + + # Text Encoder Layer 0 Tuning + elif train_method == 'text_encoder_layer0': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer01': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer012': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + 
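Two pieces of `get_train_loss_retain` above are worth a worked example: the mapping from a sampled DDIM step onto the 1000-step DDPM schedule, and the ESD-style regression target. A sketch with toy values:

```python
import torch

# 1) Mapping a DDIM step index onto the underlying 1000-step DDPM schedule:
ddim_steps, t_enc = 50, 10
og_num = round((t_enc / ddim_steps) * 1000)            # 200
og_num_lim = round(((t_enc + 1) / ddim_steps) * 1000)  # 220
t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,))   # uniform in [200, 220)

# 2) The ESD-style target: regress the trainable model's score e_n toward
#    e_0 - ng * (e_p - e_0), i.e. away from the concept's score direction.
e_0, e_p = torch.zeros(1, 4, 8, 8), torch.ones(1, 4, 8, 8)
negative_guidance = 1.0
target = e_0 - negative_guidance * (e_p - e_0)         # all -1 in this toy case
```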
+ elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer0123': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer01234': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer012345': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer0123456': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer01234567': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or 
name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer012345678': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7') or name.startswith('encoder.layers.8'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer0123456789': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7') or name.startswith('encoder.layers.8') or name.startswith('encoder.layers.9'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer012345678910': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7') or name.startswith('encoder.layers.8') or name.startswith('encoder.layers.9') or name.startswith('encoder.layers.10'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer01234567891011': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or 
name.startswith('encoder.layers.3') or name.startswith('encoder.layers.4') or name.startswith('encoder.layers.5') or name.startswith('encoder.layers.6') or name.startswith('encoder.layers.7') or name.startswith('encoder.layers.8') or name.startswith('encoder.layers.9') or name.startswith('encoder.layers.10') or name.startswith('encoder.layers.11'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer0_11': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.11'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + + elif train_method == 'text_encoder_layer01_1011': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.10') or name.startswith('encoder.layers.11'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + elif train_method == 'text_encoder_layer012_91011': + for name, param in model.text_encoder.text_model.named_parameters(): + # Encoder Layer 0 + if name.startswith('encoder.layers.0') or name.startswith('encoder.layers.1') or name.startswith('encoder.layers.2') or name.startswith('encoder.layers.9') or name.startswith('encoder.layers.10') or name.startswith('encoder.layers.11'): + if component == 'ffn' and 'mlp' in name: + print(name) + parameters.append(param) + elif component == 'attn' and 'self_attn' in name: + print(name) + parameters.append(param) + elif component == 'all': + print(name) + parameters.append(param) + else: + pass + + elif name.startswith('final_layer_norm') and final_layer_norm==True: + print(name) + parameters.append(param) + + else: + pass + + # UNet Model Tuning + else: + for name, param in model.model.diffusion_model.named_parameters(): + # train all layers except x-attns and time_embed layers + if train_method == 'noxattn': + if name.startswith('out.') or 'attn2' in name or 'time_embed' in name: + pass + else: + print(name) + parameters.append(param) + + # train only self attention layers + if train_method == 'selfattn': + if 'attn1' in name: + print(name) + parameters.append(param) + + # train only x attention layers + if train_method == 'xattn': + if 'attn2' in name: + print(name) + parameters.append(param) + + # train all layers + if train_method == 'full': + print(name) + parameters.append(param) + + # train all layers except time embed 
layers + if train_method == 'notime': + if not (name.startswith('out.') or 'time_embed' in name): + print(name) + parameters.append(param) + if train_method == 'xlayer': + if 'attn2' in name: + if 'output_blocks.6.' in name or 'output_blocks.8.' in name: + print(name) + parameters.append(param) + if train_method == 'selflayer': + if 'attn1' in name: + if 'input_blocks.4.' in name or 'input_blocks.7.' in name: + print(name) + parameters.append(param) + + return parameters + +def save_text_encoder(folder_path, model, name, num): + # SAVE MODEL + + # PATH = f'{FOLDER}/{model_type}-word_{word_print}-method_{train_method}-sg_{start_guidance}-ng_{neg_guidance}-iter_{i+1}-lr_{lr}-startmodel_{start_model}-numacc_{numacc}.pt' + folder_path = f'{folder_path}/models' + os.makedirs(folder_path, exist_ok=True) + if num is not None: + path = f'{folder_path}/TextEncoder-{name}-epoch_{num}.pt' + else: + path = f'{folder_path}/TextEncoder-{name}.pt' + + torch.save(model.state_dict(), path) + + + +def create_unet_diffusers_config(original_config, image_size: int): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + unet_params = original_config.model.params.unet_config.params + vae_params = original_config.model.params.first_stage_config.params.ddconfig + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + + head_dim = unet_params.num_heads if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim = [5, 10, 20, 20] + + config = dict( + sample_size=image_size // vae_scale_factor, + in_channels=unet_params.in_channels, + out_channels=unet_params.out_channels, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=unet_params.num_res_blocks, + cross_attention_dim=unet_params.context_dim, + attention_head_dim=head_dim, + use_linear_projection=use_linear_projection, + ) + + return config + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. 
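A worked example of the block-type derivation in `create_unet_diffusers_config` above, using SD v1-style hyper-parameters (values assumed for illustration, not read from a config file):

```python
model_channels, channel_mult = 320, [1, 2, 4, 4]
attention_resolutions = [4, 2, 1]
block_out_channels = [model_channels * m for m in channel_mult]  # [320, 640, 1280, 1280]

down_block_types, resolution = [], 1
for i in range(len(block_out_channels)):
    down_block_types.append(
        "CrossAttnDownBlock2D" if resolution in attention_resolutions else "DownBlock2D")
    if i != len(block_out_channels) - 1:
        resolution *= 2
# -> ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D']
```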
+ if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." 
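A quick usage check of the renaming helpers above (assuming `shave_segments` and `renew_resnet_paths` are in scope; the key names are illustrative LDM checkpoint keys):

```python
print(shave_segments("output_blocks.1.0.in_layers.0.weight", 2))
# -> "0.in_layers.0.weight"  (the sub-layer id within the block is kept)

mapping = renew_resnet_paths(["input_blocks.1.0.in_layers.2.weight"])
print(mapping[0]["new"])
# -> "input_blocks.1.0.conv1.weight"  (local rename only; the global
#     input_blocks -> down_blocks rename happens in assign_to_checkpoint)
```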
+    # At least 100 parameters have to start with `model_ema` for the checkpoint to be treated as EMA.
+    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
+        print(
+            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+        )
+        for key in keys:
+            if key.startswith("model.diffusion_model"):
+                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])  # EMA keys are stored with the dots stripped
+                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+    else:
+        if sum(k.startswith("model_ema") for k in keys) > 100:
+            print(
+                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+            )
+
+        for key in keys:
+            if key.startswith(unet_key):
+                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+    new_checkpoint = {}
+
+    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+    # Retrieves the keys for the input blocks only
+    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+    input_blocks = {
+        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
+        for layer_id in range(num_input_blocks)
+    }
+
+    # Retrieves the keys for the middle blocks only
+    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+    middle_blocks = {
+        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+        for layer_id in range(num_middle_blocks)
+    }
+
+    # Retrieves the keys for the output blocks only
+    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+    output_blocks = {
+        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
+        for layer_id in range(num_output_blocks)
+    }
+
+    for i in range(1, num_input_blocks):
+        block_id = (i - 1) // (config["layers_per_block"] + 1)
+        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+        resnets = [
+            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+        ]
+        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+                f"input_blocks.{i}.0.op.weight"
+            )
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+                f"input_blocks.{i}.0.op.bias"
+            )
+
+        paths = renew_resnet_paths(resnets)
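+        # Worked example (assuming config["layers_per_block"] == 2, as in the
+        # SD v1 UNet): input block i == 4 gives block_id == (4 - 1) // 3 == 1
+        # and layer_in_block_id == (4 - 1) % 3 == 0, so "input_blocks.4.0" is
+        # renamed to "down_blocks.1.resnets.0" via the meta_path below.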
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. 
+            if len(attentions) == 2:
+                attentions = []
+
+            if len(attentions):
+                paths = renew_attention_paths(attentions)
+                meta_path = {
+                    "old": f"output_blocks.{i}.1",
+                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+                }
+                assign_to_checkpoint(
+                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+                )
+        else:
+            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+            for path in resnet_0_paths:
+                old_path = ".".join(["output_blocks", str(i), path["old"]])
+                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+                new_checkpoint[new_path] = unet_state_dict[old_path]
+
+    return new_checkpoint
+
+
+def moving_average(a, n=3):
+    ret = np.cumsum(a, dtype=float)
+    ret[n:] = ret[n:] - ret[:-n]
+    return ret[n - 1:] / n
+
+def plot_loss(losses, path, word, n=100):
+    v = moving_average(losses, n)
+    plt.plot(v, label=f'{word}_loss')
+    plt.legend(loc="upper left")
+    plt.title('Average loss in trainings', fontsize=20)
+    plt.xlabel('Data point', fontsize=16)
+    plt.ylabel('Loss value', fontsize=16)
+    plt.savefig(path)
+
+def save_history(folder_path, losses, word_print):
+    folder_path = f'{folder_path}/logs'
+    os.makedirs(folder_path, exist_ok=True)
+    with open(f'{folder_path}/loss.txt', 'w') as f:
+        f.writelines([f'{i}\n' for i in losses])  # one loss value per line
+    plot_loss(losses, f'{folder_path}/loss.png', word_print, n=3)
\ No newline at end of file
diff --git a/mu_defense/core/__init__.py b/mu_defense/core/__init__.py
new file mode 100644
index 00000000..5a5ddc20
--- /dev/null
+++ b/mu_defense/core/__init__.py
@@ -0,0 +1,13 @@
+from .base_algorithm import BaseAlgorithm
+from .base_model import BaseModel
+from .base_config import BaseConfig
+from .base_data_handler import BaseDatasetHandler
+from .base_trainer import BaseTrainer
+
+__all__ = [
+    "BaseAlgorithm",
+    "BaseModel",
+    "BaseTrainer",
+    "BaseConfig",
+    "BaseDatasetHandler",
+]
diff --git a/mu_defense/core/base_algorithm.py b/mu_defense/core/base_algorithm.py
new file mode 100644
index 00000000..e3e804c9
--- /dev/null
+++ b/mu_defense/core/base_algorithm.py
@@ -0,0 +1,45 @@
+# mu_defense/core/base_algorithm.py
+
+from abc import ABC, abstractmethod
+from typing import Dict
+
+
+class BaseAlgorithm(ABC):
+    """
+    Abstract base class for the overall unlearning algorithm, combining the model, trainer, and sampler.
+    All algorithms must inherit from this class and implement its methods.
+    """
+
+    @abstractmethod
+    def __init__(self, config: Dict):
+        """
+        Initialize the unlearning algorithm.
+
+        Args:
+            config (Dict): Configuration parameters for the algorithm.
+        """
+        self.config = config
+
+    def _parse_config(self):
+        """
+        Parse the configuration parameters for the algorithm.
+        """
+        # Parse devices, e.g. "0,1" -> ["cuda:0", "cuda:1"]
+        devices = [
+            f"cuda:{int(d.strip())}" for d in self.config.get("devices", "0").split(",")
+        ]
+        self.config["devices"] = devices
+
+    @abstractmethod
+    def _setup_components(self):
+        """
+        Set up the components of the unlearning algorithm, including the model, trainer, and sampler.
+        """
+        pass
+
+    @abstractmethod
+    def run(self):
+        """
+        Run the unlearning algorithm.
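+
+        Concrete subclasses are expected to wire together the components
+        created in `_setup_components` and execute the full unlearning
+        procedure end to end.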
+ """ + pass diff --git a/mu_defense/core/base_config.py b/mu_defense/core/base_config.py new file mode 100644 index 00000000..3b920123 --- /dev/null +++ b/mu_defense/core/base_config.py @@ -0,0 +1,34 @@ + +# mu_defense/core/base_config.py + +from abc import ABC, abstractmethod + + +class BaseConfig(ABC): + + @abstractmethod + def __init__(self): + pass + + def validate_config(self): + pass + + def to_dict(self): + result = {} + for attr_name, attr_value in self.__dict__.items(): + if hasattr(attr_value, "to_dict") and callable(attr_value.to_dict): + result[attr_name] = attr_value.to_dict() + elif isinstance(attr_value, list): + result[attr_name] = [ + item.to_dict() if hasattr(item, "to_dict") else item + for item in attr_value + ] + elif isinstance(attr_value, dict): + dict_val = {} + for k, v in attr_value.items(): + dict_val[k] = v.to_dict() if hasattr(v, "to_dict") else v + result[attr_name] = dict_val + else: + result[attr_name] = attr_value + return result + diff --git a/mu_defense/core/base_data_handler.py b/mu_defense/core/base_data_handler.py new file mode 100644 index 00000000..41668f22 --- /dev/null +++ b/mu_defense/core/base_data_handler.py @@ -0,0 +1,29 @@ +# mu_defense/core/base_dataset_handler.py + +from abc import ABC, abstractmethod + +class BaseDatasetHandler(ABC): + """ + BaseDatasetHandler provides a blueprint for handling dataset-related tasks, + including prompt cleaning and creation of a retaining dataset. + """ + def __init__(self, prompt: str, seperator: str = None, dataset_retain=None): + self.prompt = prompt + self.seperator = seperator + self.dataset_retain = dataset_retain + self.words = [] + self.word_print = "" + + @abstractmethod + def setup_prompt(self): + """ + Set up and return the cleaned prompt and the printable version. + """ + pass + + @abstractmethod + def setup_dataset(self): + """ + Create and return the retaining dataset. + """ + pass \ No newline at end of file diff --git a/mu_defense/core/base_model.py b/mu_defense/core/base_model.py new file mode 100644 index 00000000..ea397217 --- /dev/null +++ b/mu_defense/core/base_model.py @@ -0,0 +1,17 @@ +# mu_defense/core/base_model.py + +from abc import ABC, abstractmethod +import torch.nn as nn + +class BaseModel(nn.Module, ABC): + """Abstract base class for all unlearning models.""" + + @abstractmethod + def load_model(self, *args, **kwargs): + """Load the model.""" + pass + + @abstractmethod + def save_model(self, *args, **kwargs): + """Save the model.""" + pass diff --git a/mu_defense/core/base_trainer.py b/mu_defense/core/base_trainer.py new file mode 100644 index 00000000..714402bc --- /dev/null +++ b/mu_defense/core/base_trainer.py @@ -0,0 +1,22 @@ +from abc import ABC, abstractmethod +import logging + +class BaseTrainer(ABC): + """ + BaseTrainerRunner is an abstract base class for high-level training orchestrators. + It defines the interface and common properties for running a training process. + """ + def __init__(self, config: dict): + self.config = config + self.devices = config.get("devices", ["cuda:0"]) + self.logger = logging.getLogger(__name__) + + def train(self): + pass + + @abstractmethod + def run(self): + """ + Run the training loop. Must be implemented by subclasses. 
+ """ + pass diff --git a/mu_defense/environment.yaml b/mu_defense/environment.yaml new file mode 100644 index 00000000..13036586 --- /dev/null +++ b/mu_defense/environment.yaml @@ -0,0 +1,36 @@ +name: AdvUnlearn +channels: + - pytorch + - defaults +dependencies: + - python=3.8.5 + - pip=20.3 + - cudatoolkit=11.3 + - pytorch=1.11.0 + - torchvision=0.12.0 + - numpy=1.23.5 + - huggingface_hub==0.25.1 + - pip: + - albumentations==0.4.3 + - diffusers==0.12.1 + - opencv-python==4.1.2.30 + - pudb==2019.2 + - invisible-watermark + - imageio==2.9.0 + - imageio-ffmpeg==0.4.2 + - pytorch-lightning==1.4.2 + - omegaconf==2.1.1 + - test-tube>=0.7.5 + - streamlit>=0.73.1 + - einops==0.3.0 + - torch-fidelity==0.3.0 + - transformers==4.25.1 + - torchmetrics==0.6.0 + - kornia==0.6 + - timm==1.0.11 + - matplotlib + - wandb + - tabulate + - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers + - -e git+https://github.com/openai/CLIP.git@main#egg=clip +