diff --git a/.gitignore b/.gitignore index 9576dcf..78c40e3 100644 --- a/.gitignore +++ b/.gitignore @@ -4,15 +4,14 @@ out*/ slurm/ reps/ *.pth -*.npy *.npz .vscode /migrated /nbconfig -*.png results/ viz/ utils/ +precomputed_labels/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/.gitmodules b/.gitmodules index 89555ea..c2d4443 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,4 @@ [submodule "modules/pytorch_cifar"] path = modules/pytorch_cifar url = https://github.com/kuangliu/pytorch-cifar - ignore = all -[submodule "modules/base_defense/spectre-defense"] - path = modules/base_defense/spectre-defense - url = https://github.com/SewoongLab/spectre-defense.git \ No newline at end of file + ignore = all \ No newline at end of file diff --git a/Manifest.toml b/Manifest.toml deleted file mode 100644 index 6df662e..0000000 --- a/Manifest.toml +++ /dev/null @@ -1,733 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.9.0" -manifest_format = "2.0" -project_hash = "0ab4a3141c2635be9f579a2b1b0111c52f9ae846" - -[[deps.Adapt]] -deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.6.2" -weakdeps = ["StaticArrays"] - - [deps.Adapt.extensions] - AdaptStaticArraysExt = "StaticArrays" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - -[[deps.Arpack]] -deps = ["Arpack_jll", "Libdl", "LinearAlgebra", "Logging"] -git-tree-sha1 = "9b9b347613394885fd1c8c7729bfc60528faa436" -uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97" -version = "0.5.4" - -[[deps.Arpack_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "OpenBLAS_jll", "Pkg"] -git-tree-sha1 = "5ba6c757e8feccf03a1554dfaf3e26b3cfc7fd5e" -uuid = "68821587-b530-5797-8361-c406ea357684" -version = "3.5.1+1" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.Atomix]] -deps = ["UnsafeAtomics"] -git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" -uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" -version = "0.1.0" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.CEnum]] -git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" -uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.4.2" - -[[deps.Calculus]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" -uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" -version = "0.5.1" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.16.0" - -[[deps.Clustering]] -deps = ["Distances", "LinearAlgebra", "NearestNeighbors", "Printf", "Random", "SparseArrays", "Statistics", "StatsBase"] -git-tree-sha1 = "a6e6ce44a1e0a781772fc795fb7343b1925e9898" -uuid = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" -version = "0.15.2" - -[[deps.ColorTypes]] -deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "eb7f0f8307f71fac7c606984ea5fb2817275d6e4" -uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.11.4" - -[[deps.Colors]] -deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] -git-tree-sha1 = "fc08e5930ee9a4e03f84bfb5211cb54e7769758a" -uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.10" - -[[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = 
"bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - -[[deps.Compat]] -deps = ["UUIDs"] -git-tree-sha1 = "7a60c856b9fa189eb34f5f8a6f6b5529b7942957" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.6.1" -weakdeps = ["Dates", "LinearAlgebra"] - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.2+0" - -[[deps.Conda]] -deps = ["Downloads", "JSON", "VersionParsing"] -git-tree-sha1 = "e32a90da027ca45d84678b826fffd3110bb3fc90" -uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d" -version = "1.8.0" - -[[deps.DataAPI]] -git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.15.0" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.13" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[deps.DataValues]] -deps = ["DataValueInterfaces", "Dates"] -git-tree-sha1 = "d88a19299eba280a6d062e135a43f00323ae70bf" -uuid = "e7dc6d0d-1eca-5fa6-8ad6-5aecde8b7ea5" -version = "0.4.13" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DiffResults]] -deps = ["StaticArraysCore"] -git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.1.0" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.15.1" - -[[deps.Distances]] -deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "49eba9ad9f7ead780bfb7ee319f962c811c6d3b2" -uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.8" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[deps.Distributions]] -deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "c72970914c8a21b36bbc244e9df0ed1834a0360b" -uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.95" - - [deps.Distributions.extensions] - DistributionsChainRulesCoreExt = "ChainRulesCore" - DistributionsDensityInterfaceExt = "DensityInterface" - - [deps.Distributions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" - -[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - -[[deps.DualNumbers]] -deps = ["Calculus", "NaNMath", "SpecialFunctions"] -git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" -uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" -version = "0.6.8" - -[[deps.FileIO]] -deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "299dc33549f68299137e51e6d49a13b5b1da9673" -uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = 
"1.16.1" - -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - -[[deps.FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] -git-tree-sha1 = "589d3d3bff204bdd80ecc53293896b4f39175723" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.1.1" - -[[deps.FixedPointNumbers]] -deps = ["Statistics"] -git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" -uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.4" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "00e252f4d706b3d55a8863432e742bf5717b498d" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.35" -weakdeps = ["StaticArrays"] - - [deps.ForwardDiff.extensions] - ForwardDiffStaticArraysExt = "StaticArrays" - -[[deps.Functors]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "478f8c3145bb91d82c2cf20433e8c1b30df454cc" -uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.4" - -[[deps.GPUArraysCore]] -deps = ["Adapt"] -git-tree-sha1 = "2d6ca471a6c7b536127afccfa7564b5b39227fe0" -uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.5" - -[[deps.HypergeometricFunctions]] -deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] -git-tree-sha1 = "84204eae2dd237500835990bcade263e27674a93" -uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.16" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.2.2" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[deps.JLLWrappers]] -deps = ["Preferences"] -git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.4.1" - -[[deps.JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" -uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.4" - -[[deps.JuliennedArrays]] -git-tree-sha1 = "4aeebbfcf0615641ec4b0782b73b638eeeabd62e" -uuid = "5cadff95-7770-533d-a838-a1bf817ee6e0" -version = "0.3.0" - -[[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "47be64f040a7ece575c2b5f53ca6da7b548d69f4" -uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.4" - -[[deps.KrylovKit]] -deps = ["ChainRulesCore", "GPUArraysCore", "LinearAlgebra", "Printf"] -git-tree-sha1 = "1a5e1d9941c783b0119897d29f2eb665d876ecf3" -uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" -version = "0.6.0" - -[[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "26a31cdd9f1f4ea74f649a7bf249703c687a953d" -uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "5.1.0" - -[[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "09b7505cc0b1cee87e5d4a26eea61d2e1b0dcd35" -uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.21+0" - -[[deps.LaTeXStrings]] -git-tree-sha1 = "f2355693d6778a178ade15952b7ac47a4ff97996" -uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" -version = 
"1.3.0" - -[[deps.Lazy]] -deps = ["MacroTools"] -git-tree-sha1 = "1370f8202dac30758f3c345f9909b97f53d87d3f" -uuid = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0" -version = "0.15.1" - -[[deps.LazyArtifacts]] -deps = ["Artifacts", "Pkg"] -uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.3" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "7.84.0+0" - -[[deps.LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.10.2+0" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.LinearMaps]] -deps = ["ChainRulesCore", "LinearAlgebra", "SparseArrays", "Statistics"] -git-tree-sha1 = "4af48c3585177561e9f0d24eb9619ad3abf77cc7" -uuid = "7a12625a-238d-50fd-b39a-03d52299707e" -version = "3.10.0" - -[[deps.LogExpFunctions]] -deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "c3ce8e7420b3a6e071e0fe4745f5d4300e37b13f" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.24" - - [deps.LogExpFunctions.extensions] - LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" - LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" - LogExpFunctionsInverseFunctionsExt = "InverseFunctions" - - [deps.LogExpFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "42324d08725e200c23d4dfb549e0d5d89dede2d2" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.10" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+0" - -[[deps.Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.1.0" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.10.11" - -[[deps.NNlib]] -deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "99e6dbb50d8a96702dc60954569e9fe7291cc55d" -uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.8.20" - - [deps.NNlib.extensions] - NNlibAMDGPUExt = "AMDGPU" - - [deps.NNlib.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - -[[deps.NPZ]] -deps = ["FileIO", "ZipFile"] -git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" -uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" -version = "0.4.3" - -[[deps.NaNMath]] -deps = ["OpenLibm_jll"] -git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.2" - -[[deps.NearestNeighbors]] -deps = ["Distances", "StaticArrays"] 
-git-tree-sha1 = "2c3726ceb3388917602169bed973dbc97f1b51a8" -uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.13" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.21+4" - -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+0" - -[[deps.OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[deps.Optimisers]] -deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "6a01f65dd8583dee82eecc2a19b0ff21521aa749" -uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.2.18" - -[[deps.OrderedCollections]] -git-tree-sha1 = "d321bf2de576bf25ec4d3e4360faca399afca282" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.0" - -[[deps.PDMats]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "67eae2738d63117a196f497d7db789821bce61d1" -uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.11.17" - -[[deps.Pandas]] -deps = ["Compat", "DataValues", "Dates", "IteratorInterfaceExtensions", "Lazy", "OrderedCollections", "Pkg", "PyCall", "Statistics", "TableTraits", "TableTraitsUtils", "Tables"] -git-tree-sha1 = "0ccb570180314e4dfa3ad81e49a3df97e1913dc2" -uuid = "eadc2687-ae89-51f9-a5d9-86b5a6373a9c" -version = "1.6.1" - -[[deps.Parsers]] -deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "a5aef8d4a6e8d81f171b2bd4be5265b01384c74c" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.5.10" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.9.0" - -[[deps.PrecompileTools]] -deps = ["Preferences"] -git-tree-sha1 = "9673d39decc5feece56ef3940e5dafba15ba0f81" -uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.1.2" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.0" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.ProgressMeter]] -deps = ["Distributed", "Printf"] -git-tree-sha1 = "d7a7aef8f8f2d537104f170139553b14dfe39fe9" -uuid = "92933f4c-e287-5a05-a399-4b506db050ca" -version = "1.7.2" - -[[deps.PyCall]] -deps = ["Conda", "Dates", "Libdl", "LinearAlgebra", "MacroTools", "Serialization", "VersionParsing"] -git-tree-sha1 = "62f417f6ad727987c755549e9cd88c46578da562" -uuid = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" -version = "1.95.1" - -[[deps.PyPlot]] -deps = ["Colors", "LaTeXStrings", "PyCall", "Sockets", "Test", "VersionParsing"] -git-tree-sha1 = "92e7ca803b579b8b817f004e74b205a706d9a974" -uuid = "d330b81b-6aea-500a-939a-2ce795aea3ee" -version = "2.11.1" - -[[deps.QuadGK]] -deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "6ec7ac8412e83d57e313393220879ede1740f9ee" -uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.8.2" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] 
-deps = ["SHA", "Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.Reexport]] -git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.2.2" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.Rmath]] -deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" -uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.7.1" - -[[deps.Rmath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da" -uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.4.0+0" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.Seaborn]] -deps = ["Pandas", "PyCall", "PyPlot", "Reexport", "Test"] -git-tree-sha1 = "c7d0011bfb487a40501ad9383e24f1908809e1ed" -uuid = "d2ef9438-c967-53ab-8060-373fdd9e13eb" -version = "1.1.1" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.SliceMap]] -deps = ["ForwardDiff", "JuliennedArrays", "StaticArrays", "Tracker", "ZygoteRules"] -git-tree-sha1 = "f988004407ccf6c398a87914eafdd8bc9109e533" -uuid = "82cb661a-3f19-5665-9e27-df437c7e54c8" -version = "0.2.7" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "a4ada03f999bd01b3a25dcaa30b2d929fe537e00" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.1.0" - -[[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[deps.SpecialFunctions]] -deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "ef28127915f4229c971eb43f3fc075dd3fe91880" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.2.0" -weakdeps = ["ChainRulesCore"] - - [deps.SpecialFunctions.extensions] - SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] -git-tree-sha1 = "8982b3607a212b070a5e46eea83eb62b4744ae12" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.5.25" - -[[deps.StaticArraysCore]] -git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.0" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.9.0" - -[[deps.StatsAPI]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.6.0" - -[[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "75ebe04c5bed70b91614d684259b661c9e6274a4" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.0" - -[[deps.StatsFuns]] -deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "f625d686d5a88bcd2b15cd81f18f98186fdc0c9a" -uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "1.3.0" - - [deps.StatsFuns.extensions] - StatsFunsChainRulesCoreExt = "ChainRulesCore" - StatsFunsInverseFunctionsExt = "InverseFunctions" - - 
[deps.StatsFuns.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "5.10.1+6" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.3" - -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.TableTraitsUtils]] -deps = ["DataValues", "IteratorInterfaceExtensions", "Missings", "TableTraits"] -git-tree-sha1 = "78fecfe140d7abb480b53a44f3f85b6aa373c293" -uuid = "382cd787-c1b6-5bf2-a167-d5b971a19bda" -version = "1.0.2" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] -git-tree-sha1 = "1544b926975372da01227b382066ab70e574a3ec" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.10.1" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.TensorToolbox]] -deps = ["LinearAlgebra", "Test"] -git-tree-sha1 = "acaa4d6c9018ac00ad3d60cba1609b42ad4625d2" -uuid = "9c690861-8ade-587a-897e-15364bc6f718" -version = "1.0.1" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.Tracker]] -deps = ["Adapt", "DiffRules", "ForwardDiff", "Functors", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Optimisers", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"] -git-tree-sha1 = "8b552cc0a4132c1ce5cee14197bb57d2109d480f" -uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -version = "0.2.25" -weakdeps = ["PDMats"] - - [deps.Tracker.extensions] - TrackerPDMatsExt = "PDMats" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.UnsafeAtomics]] -git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" -uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" -version = "0.2.1" - -[[deps.UnsafeAtomicsLLVM]] -deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "ea37e6066bf194ab78f4e747f5245261f17a7175" -uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.2" - -[[deps.VersionParsing]] -git-tree-sha1 = "58d6e80b4ee071f5efd07fda82cb9fbe17200868" -uuid = "81def892-9a0e-5fdd-b105-ffc91e053289" -version = "1.3.0" - -[[deps.ZipFile]] -deps = ["Libdl", "Printf", "Zlib_jll"] -git-tree-sha1 = "f492b7fe1698e623024e873244f10d89c95c340a" -uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" -version = "0.10.1" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+0" - -[[deps.ZygoteRules]] -deps = ["ChainRulesCore", "MacroTools"] -git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d" -uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.3" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.7.0+0" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.48.0+0" - -[[deps.p7zip_jll]] -deps = 
["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+0" diff --git a/Project.toml b/Project.toml deleted file mode 100644 index 032bed9..0000000 --- a/Project.toml +++ /dev/null @@ -1,12 +0,0 @@ -[deps] -Arpack = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97" -Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" -LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e" -NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605" -Pandas = "eadc2687-ae89-51f9-a5d9-86b5a6373a9c" -ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" -Seaborn = "d2ef9438-c967-53ab-8060-373fdd9e13eb" -SliceMap = "82cb661a-3f19-5665-9e27-df437c7e54c8" -TensorToolbox = "9c690861-8ade-587a-897e-15364bc6f718" diff --git a/README.md b/README.md index b3d02d0..d6e59a1 100644 --- a/README.md +++ b/README.md @@ -1,45 +1,38 @@ -# backdoor-suite +# FLIP ## tl;dr -A module-based repository for testing and evaluating backdoor attacks and defenses. For information on experiments and testing [click here](#installation). +Official implementation of [FLIP](https://arxiv.org/abs/2310.18933), presented at [NeurIPS 2023](https://neurips.cc/virtual/2023/poster/70392). The implementation is a cleaned-up 'fork' of the [backdoor-suite](https://github.com/SewoongLab/backdoor-suite). Precomputed labels for our main table are available [here](https://github.com/SewoongLab/FLIP/releases/). More details are available in the paper. A more complete (messy) version of the code is available upon request. ---- -## Introduction -As third party and federated machine learning models become more popular, so, too, will attacks on their training processes. In particular, this repository focuses on a new class of 'backdoor' attacks in which an attacker 'poisons' or tampers with training data so that at evaluation time, they have control over the class that the model outputs. +**Authors:** [Rishi D. Jha\*](http://rishijha.com/), Jonathan Hayase\*, Sewoong Oh -With this repository we hope to provide a ubiquitous testing and evaluation platform to standardize the settings under which these attacks and their subsequent defenses are considered, pitting relevant attack literature against developed defenses. In this light, we welcome any contributions or suggested changes to the repository. +--- +## Abstract +In a backdoor attack, an adversary injects corrupted data into a model's training dataset in order to gain control over its predictions on images with a specific attacker-defined trigger. A typical corrupted training example requires altering both the image, by applying the trigger, and the label. Models trained on clean images, therefore, were considered safe from backdoor attacks. However, in some common machine learning scenarios, the training labels are provided by potentially malicious third-parties. This includes crowd-sourced annotation and knowledge distillation. We, hence, investigate a fundamental question: can we launch a successful backdoor attack by only corrupting labels? We introduce a novel approach to design label-only backdoor attacks, which we call FLIP, and demonstrate its strengths on three datasets (CIFAR-10, CIFAR-100, and Tiny-ImageNet) and four architectures (ResNet-32, ResNet-18, VGG-19, and Vision Transformer). With only 2\% of CIFAR-10 labels corrupted, FLIP achieves a near-perfect attack success rate of $99.4\%$ while suffering only a $1.8\%$ drop in the clean test accuracy. 
Our approach builds upon the recent advances in trajectory matching, originally introduced for dataset distillation. -In the rest of this document we detail (1) [how the repo works](#in-the-repo) (2) [how to run an experiment](#installation), and (3) [how to contribute](#adding-content). Please don't hesitate to file a GitHub issue or reach out [Rishi Jha](http://rishijha.com/) for any issues or requests! +![Diagram of algorithm.](/img/flip.png) --- ## In this repo -This repo is split into three main folders: `experiments`, `modules` and `schemas`. The `experiments` folder (as described in more detail [here](#installation)) contains subfolders and `.toml` configuration files on which an experiment may be run. The `modules` folder stores source code for each of the subsequent part of an experiment. These modules take in specific inputs and outputs as defined by their subseqeunt `.toml` documentation in the `schemas` folder. -In particular, each module defines some specific task in the attack-defense chain. As mentioned earlier, each module has explicitly defined inputs and outputs that, we hope, facilitate the addition of attacks and defenses with diverse requirements (i.e., training loops or representations). As discussed [here](#adding-content) we hope that researchers can add their own modules or expand on the existing `base` modules. +This repo is split into three main folders: `experiments`, `modules`, and `schemas`. The `experiments` folder (as described in more detail [here](#installation)) contains subfolders and `.toml` configuration files on which an experiment may be run. The `modules` folder stores source code for each subsequent part of an experiment. These modules take in specific inputs and outputs as defined by their corresponding `.toml` documentation in the `schemas` folder. Each module refers to a step of the FLIP algorithm. + +Additionally, in the [Precomputed Labels](https://github.com/SewoongLab/FLIP/releases/) release, labels used for the main table of our paper are provided for analysis. + +Please don't hesitate to file a GitHub issue or reach out for any issues or requests! ### Existing modules: -1. `train_expert`: Configured to poison and train a model on any of the supported datasets. -1. `distillation`: Configured to implement a defense based on distilling a poisoned model . Referenced [#TODO](). 1. `base_utils`: Utility module, used by the base modules. +1. `train_expert`: Step 1 of our algorithm: training expert models and recording trajectories. +1. `generate_labels`: Step 2 of our algorithm: generating poisoned labels from trajectories. +1. `select_flips`: Step 3 of our algorithm: strategically flipping labels within some budget. +1. `train_user`: Evaluation module to assess attack success rate. More documentation can be found in the `schemas` folder. -### Supported Attacks: -1. BadNets: Identifying Vulnerabilities in the Machine Learning Model Supply Chain [(Gu et al., 2017)](https://arxiv.org/abs/1708.06733). -1. A new Backdoor Attack in CNNs by training set corruption without label poisoning [(Barni et al., 2019)](https://arxiv.org/abs/1902.11237) -1. Label Consistent Backdoor Attacks [(Turner et al., 2019)](https://arxiv.org/abs/1912.02771). - -### Supported Defenses: -1. Detecting Backdoor Attacks on Deep Neural Networks by Activation Clustering [(Chen et al., 2018)](https://arxiv.org/abs/1811.03728). -1. Spectral Signatures in Backdoor Attacks [(Tran et al., 2018)](https://arxiv.org/abs/1811.00636). -1.
SPECTRE: Defending Against Backdoor Attacks Using Robust Statistics [(Hayase et al., 2021)](https://arxiv.org/abs/2104.11315). -1. Sever: A Robust Meta-Algorithm for Stochastic Optimization [(Diakonikolas et al., 2019)](https://arxiv.org/abs/1803.02815). -1. Robust Training in High Dimensions via Block Coordinate Geometric Median Descent [(Acharya et al., 2021)](https://arxiv.org/abs/2106.08882). -1. #TODO: Distillation Citation - ### Supported Datasets: -1. Learning Multiple Layers of Features from Tiny Images [(Krizhevsky, 2009)](https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf). -1. Gradient-based learning applied to document recognition [(LeCun et al., 1998)](http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf). +1. CIFAR-10 +1. CIFAR-100 +1. Tiny ImageNet --- ## Installation @@ -54,19 +47,13 @@ conda install --file requirements.txt ``` Note that the requirements encapsulate our testing enviornments and may be unnecessarily tight! Any relevant updates to the requirements are welcomed. -### Submodules: -This library relies heavily on git submoduling to natively support other repositories, so, after cloning it is required to pull all git submodules, which can be done like so: -``` -git submodule update --init --recursive -``` - ## Running An Experiment ### Setting up: To initialize an experiment, create a subfolder in the `experiments` folder with the name of your experiment: ``` mkdir experiments/[experiment name] ``` -In that folder initialize a config file called `[experiment name].toml`. An example can be seen here: `experments/example/example.toml`. +In that folder initialize a config file called `config.toml`. An example can be seen here: `experiments/example_attack/config.toml`. The `.toml` file should contain references to the modules that you would like to run with each relevant field as defined by its documentation in `schemas/[module name]`. This file will serve as the configuration file for the entire experiment. As a convention the output for module **n** is the input for module **n + 1**. @@ -95,73 +82,15 @@ fieldn=... ``` ### Running a module: -At the moment, all modules must be manually run using: +At the moment, all experiments must be manually run using: ``` python run_experiment.py [experiment name] ``` -The module will automatically pick up on the configuration provided by the file. +The experiment will automatically pick up on the configuration provided by the file. -As an example, to run the example experiment one could run: +As an example, to run the `example_attack` experiment one could run: ``` -python run_experiment.py example +python run_experiment.py example_attack ``` More module documentation can be found in the `schemas` folder. ---- - -## Adding Content -One of the goals of this project is to develop a ubiquitous testing and validation framework for backdoor attacks. As such, we appreciate and welcome all contributions ranging fron structural changes to additional attacks and defenses. - -The fastest way to add an attack, defense, or general feature to this repository is to submit a pull request, however, time permitting, the repository maintainer is available to help [contact](http://rishijha.com/). - -### Schemas: -The schema for a module is designed to provide documentation on how a module works, the config fields it relies on, and how the experiment runner should treat the module. 
Schemas should be formatted as follows: - -``` -# Module Description - -[INTERNAL] # Internal configurations -module_name = "" - -[module_name] -field_1_name = "field 1 description" -field_2_name = "field 2 description" -... -field_n_name = "field n description" - -[OPTIONAL] # Optional fields -field_1_name = "optional field 1 description" -field_2_name = "optional field 2 description" -... -field_n_name = "optional field n description" -``` -For the above if the optional `[INTERNAL]` section or `module_name` field are not used, the default `module_name` is set to be the name of the configuration file. - -### Adding to existing modules: -The easiest way for us to add your project is a pull request, adding to one of the `base` modules. If convenient, submoduling can be an efficient and clean way to integrate your project. We ask that any pull requests of this nature: - -1. Add documentation in the corresponding file in the `schemas` folder. -1. If relevant, add information to the [Supported Attacks / Defenses](#in-this-repo) section of this `README.md` -1. Add related submodules to the `.gitmodules` file. -1. Ensure output compatibility with other modules in the repository. - -Don't hesitate to reach out with questions or for help migrating your code! - -### Publishing your own module: -The quickest way for us to integrate a new module is for it to be requested with the following: - -1. A schema in the `schemas` folder to document the necessary configurations to run the experiment. Don't forget to add the `[INTERNAL]` or `[OPTIONAL]` section if needed. -1. A folder of the form `modules/[new module name]` with file `run_module.py` inside of it. -1. A function named `run` within `run_module.py` for all supported module logic. -1. Added information to the [Supported Attacks / Defenses](#in-this-repo) section of this `README.md`. -1. Related submodules added to the `.gitmodules` file. -1. Output compatibility with other modules in the repository. - -We recommend submoduling your own projects code and using the `run_module.py` file to create a common interface between this library and your code. Don't hesitate to reach out with questions or for help migrating your code! - ---- -## Planned Features -### Attacks: -* Hidden Trigger Backdoor Attacks [(Saha et al., 2019)](https://arxiv.org/abs/1910.00033). -### Defenses: -* STRIP: A Defence Against Trojan Attacks on Deep Neural Networks [(Gao et al., 2020)](https://arxiv.org/abs/1902.06531). diff --git a/experiments/example_attack/config.toml b/experiments/example_attack/config.toml index 7ddb3b2..5422a52 100644 --- a/experiments/example_attack/config.toml +++ b/experiments/example_attack/config.toml @@ -1,28 +1,31 @@ -# TODO Description +# This example trains a single expert and generates poisoned labels +# for the sinusoidal (1xs) trigger with ResNet-32s. The labels are +# FLIPped at the provided budgets. The config file is broken down +# into three modules detailed in the schemas/ folder. -# TODO -# [train_expert] -# output = "out/checkpoints/r32p_1xs/0/model.pth" -# model = "r32p" -# trainer = "sgd" -# dataset = "cifar" -# source_label = 9 -# target_label = 4 -# poisoner = "1xs" -# epochs = 20 -# checkpoint_iters = 50 +# Module to train and record an expert trajectory. +[train_expert] +output_dir = "out/checkpoints/r32p_1xs/0/" +model = "r32p" +trainer = "sgd" +dataset = "cifar" +source_label = 9 +target_label = 4 +poisoner = "1xs" +epochs = 20 +checkpoint_iters = 50 -# TODO +# Module to generate attack labels from the expert trajectories. 
[generate_labels] -input = "out/checkpoints/r32p_1xs/{}/model_{}_{}.pth" -opt_input = "out/checkpoints/r32p_1xs/{}/model_{}_{}_opt.pth" +input_pths = "out/checkpoints/r32p_1xs/{}/model_{}_{}.pth" +opt_pths = "out/checkpoints/r32p_1xs/{}/model_{}_{}_opt.pth" expert_model = "r32p" trainer = "sgd" dataset = "cifar" source_label = 9 target_label = 4 poisoner = "1xs" -output_path = "experiments/example_attack/" +output_dir = "experiments/example_attack/" lambda = 0.0 [generate_labels.expert_config] @@ -36,3 +39,9 @@ one_hot_temp = 5 alpha = 0 label_kwargs = {lr = 150, momentum = 0.5} +# Module to flip labels at the provided budgets. +[select_flips] +budgets = [150, 300, 500, 1000, 1500] +input_label_glob = "experiments/example_attack/labels.npy" +true_labels = "experiments/example_attack/true.npy" +output_dir = "experiments/example_attack/" \ No newline at end of file diff --git a/experiments/example_downstream/config.toml b/experiments/example_downstream/config.toml index b2b776a..67c4ed8 100644 --- a/experiments/example_downstream/config.toml +++ b/experiments/example_downstream/config.toml @@ -1,12 +1,17 @@ -[downstream] -input = "experiments/example_attack/" -downstream_model = "r32p" +# This example trains a user model on the poisoned labels from +# example_attack with a 1500 flip budget and records the attack metrics. +# The config file is broken down into a single module detailed in +# the schemas/ folder. + +# Module to train a user model on input labels. +[train_user] +input_labels = "experiments/example_attack/1500.npy" +user_model = "r32p" trainer = "sgd" dataset = "cifar" source_label = 9 target_label = 4 poisoner = "1xs" -output_path = "experiments/example_downstream/" -logits = false -alpha = 0.0 -distill_labels = false +output_dir = "experiments/example_downstream/" +soft = false +alpha = 0.0 \ No newline at end of file diff --git a/experiments/example_downstream_soft/config.toml b/experiments/example_downstream_soft/config.toml new file mode 100644 index 0000000..552d167 --- /dev/null +++ b/experiments/example_downstream_soft/config.toml @@ -0,0 +1,17 @@ +# This example trains a user model on the (soft) logits from +# example_attack and records the attack metrics. The config file +# is broken down into a single module detailed in the schemas/ folder. + +# Module to train a user model on input labels. +[train_user] +input_labels = "experiments/example_attack/labels.npy" +true_labels = "experiments/example_attack/true.npy" +user_model = "r32p" +trainer = "sgd" +dataset = "cifar" +source_label = 9 +target_label = 4 +poisoner = "1xs" +output_dir = "experiments/example_downstream_soft/" +soft = true +alpha = 0.2 diff --git a/experiments/example_precomputed/config.toml b/experiments/example_precomputed/config.toml new file mode 100644 index 0000000..06c8a86 --- /dev/null +++ b/experiments/example_precomputed/config.toml @@ -0,0 +1,16 @@ +# This example trains a user model on precomputed labels with a 1500 +# flip budget. The config file is broken down into a single module +# detailed in the schemas/ folder. + +# Module to train a user model on input labels.
+[train_user] +input_labels = "precomputed_labels/cifar/r32p/1xs/1500.npy" +user_model = "r32p" +trainer = "sgd" +dataset = "cifar" +source_label = 9 +target_label = 4 +poisoner = "1xs" +output_dir = "experiments/example_precomputed/" +soft = false +alpha = 0.0 \ No newline at end of file diff --git a/experiments/example_precomputed_mix/config.toml b/experiments/example_precomputed_mix/config.toml new file mode 100644 index 0000000..45f4305 --- /dev/null +++ b/experiments/example_precomputed_mix/config.toml @@ -0,0 +1,20 @@ +# This example trains a ViT user model on precomputed ResNet labels +# with a 1500 flip budget. The config file is broken down into a +# single module detailed in the schemas/ folder. + +# Module to train a user model on input labels. +[train_user] +input_labels = "precomputed_labels/cifar/r32p/1xs/1500.npy" +user_model = "vit-pretrain" +trainer = "sgd" +dataset = "cifar" +source_label = 9 +target_label = 4 +poisoner = "1xs" +output_dir = "experiments/example_precomputed_mix/" +soft = false +alpha = 0.0 + +[train_user.optim_kwargs] +lr = 0.01 +weight_decay = 0.0002 \ No newline at end of file diff --git a/img/flip.png b/img/flip.png new file mode 100644 index 0000000..7d16779 Binary files /dev/null and b/img/flip.png differ diff --git a/modules/base_utils/datasets.py b/modules/base_utils/datasets.py index fa2d7dd..f7e55c1 100644 --- a/modules/base_utils/datasets.py +++ b/modules/base_utils/datasets.py @@ -97,12 +97,14 @@ TRANSFORM_TRAIN_XY = { 'cifar': lambda xy: (CIFAR_TRANSFORM_TRAIN(xy[0]), xy[1]), + 'cifar_big': lambda xy: (CIFAR_BIG_TRANSFORM_TRAIN(xy[0]), xy[1]), 'cifar_100': lambda xy: (CIFAR_100_TRANSFORM_TRAIN(xy[0]), xy[1]), 'tiny_imagenet': lambda xy: (TINY_IMAGENET_TRANSFORM_TRAIN(xy[0]), xy[1]) } TRANSFORM_TEST_XY = { 'cifar': lambda xy: (CIFAR_TRANSFORM_TEST(xy[0]), xy[1]), + 'cifar_big': lambda xy: (CIFAR_BIG_TRANSFORM_TEST(xy[0]), xy[1]), 'cifar_100': lambda xy: (CIFAR_100_TRANSFORM_TEST(xy[0]), xy[1]), 'tiny_imagenet': lambda xy: (TINY_IMAGENET_TRANSFORM_TEST(xy[0]), xy[1]) } @@ -406,8 +408,6 @@ def load_cifar_100_dataset(path, train=True, coarse=True): 16, 19, 2, 4, 6, 19, 5, 5, 8, 19, 18, 1, 2, 15, 6, 0, 17, 8, 14, 13]) dataset.targets = coarse_labels[dataset.targets] - - # TODO: get actual class names dataset.classes = range(coarse_labels.max()+1) return dataset @@ -541,52 +541,6 @@ def pick_tiny_imagenet_poisoner(poisoner_flag): return x_poisoner -def get_distillation_datasets( - dataset_flag, - poisoner=None, - label=None, - distill_pct=0.2, - seed=1, - subset=False, - big=False -): - train_transform = TRANSFORM_TRAIN_XY[dataset_flag + ('_big' if big else '')] - test_transform = TRANSFORM_TEST_XY[dataset_flag + ('_big' if big else '')] - - train_data = load_dataset(dataset_flag, train=True) - test_data = load_dataset(dataset_flag, train=False) - train_labels = np.array([y for _, y in train_data]) - - distill_indices = np.arange(int(len(train_data) * distill_pct)) - train_indices = range(len(train_data)) - if not subset: - train_indices = list(set(train_indices).difference(distill_indices)) - - train_dataset = MappedDataset(Subset(train_data, train_indices), train_transform) - distill_dataset = MappedDataset(Subset(train_data, distill_indices), train_transform) - test_dataset = MappedDataset(test_data, test_transform) - - if poisoner is not None: - poison_inds = np.where(train_labels == label)[0][-5000:] - poison_dataset = MappedDataset(Subset(train_data, poison_inds), - poisoner, - seed=seed) - poison_dataset = MappedDataset(poison_dataset, 
train_transform) - train_dataset = ConcatDataset([train_dataset, poison_dataset]) - - poison_test_dataset = PoisonedDataset( - test_data, - poisoner, - eps=1000, - label=label, - transform=test_transform, - ) - else: - poison_test_dataset = None - - return train_dataset, distill_dataset, test_dataset, poison_test_dataset - - def get_matching_datasets( dataset_flag, poisoner, @@ -620,7 +574,10 @@ def get_matching_datasets( seed=seed) train_dataset = Subset(train_data, np.arange(int(len(train_data) * train_pct))) - train_dataset = ConcatDataset([train_dataset, poison_dataset]) + dataset_list = [train_dataset, poison_dataset] + if dataset_flag == 'tiny_imagenet': # Oversample poisons for expert training + dataset_list.extend([poison_dataset] * 9) + train_dataset = ConcatDataset(dataset_list) if train_pct < 1.0: mtt_distill_dataset = Subset(distill_dataset, np.arange(int(len(distill_dataset) * train_pct))) @@ -642,11 +599,9 @@ def get_matching_datasets( return train_dataset, distill_dataset, test_dataset, poison_test_dataset, mtt_dataset -def construct_downstream_dataset(distill_dataset, labels, mask=None, target_label=None, include_labels=False): +def construct_user_dataset(distill_dataset, labels, mask=None, target_label=None, include_labels=False): dataset = LabelWrappedDataset(distill_dataset, labels, include_labels) return dataset def get_n_classes(dataset_flag): return N_CLASSES[dataset_flag] - -# TODO: Add oversampling trick back into tiny_imagenet \ No newline at end of file diff --git a/modules/base_utils/tiny_imagenet_fix_val.py b/modules/base_utils/tiny_imagenet_fix_val.py new file mode 100644 index 0000000..17a65d4 --- /dev/null +++ b/modules/base_utils/tiny_imagenet_fix_val.py @@ -0,0 +1,36 @@ +import os + +DATA_DIR = 'data/tiny-imagenet-200/' +VALID_DIR = DATA_DIR + 'val' + +# Create separate validation subfolders for the validation images based on +# their labels indicated in the val_annotations txt file +val_img_dir = os.path.join(VALID_DIR, 'images') +fp = open(os.path.join(VALID_DIR, 'val_annotations.txt'), 'r') +data = fp.readlines() + +# Create dictionary to store img filename (word 0) and corresponding +# label (word 1) for every line in the txt file (as key value pair) +val_img_dict = {} +for line in data: + words = line.split('\t') + val_img_dict[words[0]] = words[1] +fp.close() + +# Create subfolders (if not present) for validation images based on label, +# and move images into the respective folders +for img, folder in val_img_dict.items(): + newpath = (os.path.join(val_img_dir, folder)) + if not os.path.exists(newpath): + os.makedirs(newpath) + if os.path.exists(os.path.join(val_img_dir, img)): + os.rename(os.path.join(val_img_dir, img), os.path.join(newpath, img)) + +# Save class names (for corresponding labels) as dict from words.txt file +class_to_name_dict = dict() +fp = open(os.path.join(DATA_DIR, 'words.txt'), 'r') +data = fp.readlines() +for line in data: + words = line.strip('\n').split('\t') + class_to_name_dict[words[0]] = words[1].split(',')[0] +fp.close() diff --git a/modules/base_utils/tiny_imagenet_setup.sh b/modules/base_utils/tiny_imagenet_setup.sh index 11df870..7d6533c 100644 --- a/modules/base_utils/tiny_imagenet_setup.sh +++ b/modules/base_utils/tiny_imagenet_setup.sh @@ -1,2 +1,3 @@ wget -nc http://cs231n.stanford.edu/tiny-imagenet-200.zip -unzip tiny-imagenet-200.zip -d data/ \ No newline at end of file +unzip tiny-imagenet-200.zip -d data/ +python modules/base_utils/tiny_imagenet_fix_val.py \ No newline at end of file diff --git
a/modules/base_utils/util.py b/modules/base_utils/util.py index 8966b2b..ed77998 100644 --- a/modules/base_utils/util.py +++ b/modules/base_utils/util.py @@ -52,6 +52,12 @@ def generate_full_path(path): return os.path.join(os.getcwd(), path) +def slurmify_path(path, slurm_id): + if path is None: + return path + return path if slurm_id is None else path.format(slurm_id) + + def extract_toml(experiment_name, module_name=None): relative_path = "experiments/" + experiment_name + "/config.toml" full_path = generate_full_path(relative_path) @@ -162,13 +168,6 @@ def clf_correct(y_pred: torch.Tensor, y: torch.Tensor): return correct -def distill_correct(y_pred: torch.Tensor, y: torch.Tensor): - y_hat = y_pred.argmax(1) - y_true = y.argmax(1) - correct = (y_hat == y_true).long().cpu().sum() - return correct - - def clf_eval(model: torch.nn.Module, data: Union[DataLoader, Dataset]): device = get_module_device(model) dataloader, _ = either_dataloader_dataset_to_both(data, eval=True) @@ -239,7 +238,6 @@ def mini_train( train_epoch_correct += int(correct.item()) train_epoch_loss += float(loss.item()) pbar.update(minibatch_size) - # TODO: make this into a list of callbacks if callback is not None: callback(model, opt, epoch, i) @@ -270,89 +268,6 @@ def mini_train( return model -def mini_distill_train( - *, - student_model: torch.nn.Module, - teacher_model: torch.nn.Module, - distill_data: Union[DataLoader, Dataset], - test_data: Union[Union[DataLoader, Dataset], - Iterable[Union[DataLoader, Dataset]]] = None, - batch_size=32, - opt: optim.Optimizer, - scheduler, - epochs: int, - alpha: float = 0.0, - temperature: float = 1.0, - i_pct: float = None, - record: bool = False -): - device = get_module_device(student_model) - dataloader, _ = either_dataloader_dataset_to_both(distill_data, - batch_size=batch_size) - n = len(dataloader.dataset) - total_examples = epochs * n - - if test_data: - num_sets = 1 - if isinstance(test_data, Iterable): - num_sets = len(test_data) - else: - test_data = [test_data] - acc_loss = [[] for _ in range(num_sets)] - - with make_pbar(total=total_examples) as pbar: - for _ in range(1, epochs + 1): - train_epoch_loss, train_epoch_correct = 0, 0 - student_model.train() - teacher_model.eval() - for data in dataloader: - if i_pct is None: - x, y = data - else: - x, y_prime, y = data - y_prime = y_prime.to(device) - x, y = x.to(device), y.to(device) - minibatch_size = len(x) - student_model.zero_grad() - student_y_pred = student_model(x) - teacher_y_pred = torch.nn.functional.softmax(teacher_model(x), dim=1) - if i_pct is not None: - teacher_y_pred = (i_pct * teacher_y_pred) + ((1 - i_pct) * y_prime) - loss = clf_loss(student_y_pred, teacher_y_pred.argmax(axis=1)) - correct = distill_correct(student_y_pred, teacher_y_pred) - loss.backward() - opt.step() - train_epoch_correct += int(correct.item()) - train_epoch_loss += float(loss.item()) - pbar.update(minibatch_size) - - lr = get_mean_lr(opt) - if scheduler: - scheduler.step() - - pbar_postfix = { - "acc": "%.2f" % (train_epoch_correct / n * 100), - "loss": "%.4g" % (train_epoch_loss / n), - "lr": "%.3g" % lr, - } - if test_data: - for i, dataset in enumerate(test_data): - acc, loss = clf_eval(student_model, dataset) - pbar_postfix.update( - { - "acc" + str(i): "%.2f" % (acc * 100), - "loss" + str(i): "%.4g" % loss, - } - ) - if record: - acc_loss[i].append((acc, loss)) - - pbar.set_postfix(**pbar_postfix) - if record: - return student_model, *acc_loss - return student_model - - def get_train_info( params, train_flag, diff --git 
a/modules/distillation/run_module.py b/modules/distillation/run_module.py deleted file mode 100644 index 55b540c..0000000 --- a/modules/distillation/run_module.py +++ /dev/null @@ -1,182 +0,0 @@ -""" -Implementation of the distillation module. -Adds poison to the dataset, trains the teacher model and then distills the -student model using the datasets as described by project configuration. -""" - -from pathlib import Path -import sys - -import torch -import numpy as np - -from modules.base_utils.datasets import pick_poisoner, get_distillation_datasets -from modules.base_utils.util import extract_toml, load_model, get_train_info,\ - generate_full_path, mini_distill_train,\ - mini_train - - -def run(experiment_name, module_name, **kwargs): - """ - Runs poisoning and distillation. - - :param experiment_name: Name of the experiment in configuration. - :param module_name: Name of the module in configuration. - """ - slurm_id = kwargs.get('slurm_id', None) - args = extract_toml(experiment_name, module_name) - - teacher_model_flag = args["teacher_model"] - student_model_flag = args["student_model"] - dataset_flag = args["dataset"] - train_flag = args["trainer"] - poisoner_flag = args["poisoner"] - clean_label = args["source_label"] - target_label = args["target_label"] - distill_pct = args["distill_percentage"] - - output_path = args["output_path"] if slurm_id is None\ - else args["output_path"].format(slurm_id) - Path(output_path).mkdir(parents=True, exist_ok=True) - - batch_size = args.get("batch_size", None) - epochs = args.get("epochs", None) - optim_kwargs = args.get("optim_kwargs", {}) - scheduler_kwargs = args.get("scheduler_kwargs", {}) - - - print(f"{teacher_model_flag=} {student_model_flag=} {clean_label=} {target_label=} {poisoner_flag=}") - print("Building datasets...") - - poisoner = pick_poisoner(poisoner_flag, - dataset_flag, - target_label) - - train_dataset, distill_dataset, test_dataset, poison_test_dataset =\ - get_distillation_datasets(dataset_flag, poisoner, label=clean_label, distill_pct=distill_pct, subset=True) - - test_datasets = [test_dataset, poison_test_dataset.poison_dataset] if poison_test_dataset is not None else test_dataset - - teacher_model = load_model(teacher_model_flag) - print(f"Teacher parameters: {sum(p.numel() for p in teacher_model.parameters() if p.requires_grad)}") - print(f"Teacher data size: {len(train_dataset)}") - - batch_size, epochs, opt, lr_scheduler = get_train_info( - teacher_model.parameters(), - train_flag, - batch_size=batch_size, - epochs=epochs, - optim_kwargs=optim_kwargs, - scheduler_kwargs=scheduler_kwargs - ) - - # TODO: Can we change this to the trainer module? 
- print("Training Teacher Model...") - - res = mini_train( - model=teacher_model, - train_data=train_dataset, - test_data=test_datasets, - batch_size=batch_size, - opt=opt, - scheduler=lr_scheduler, - epochs=epochs, - record=True - ) - - print("Evaluating Teacher Model...") - np.save(output_path + "t_caccs.npy", res[1]) - caccs = np.array(res[1])[:, 0] - clean_test_acc = caccs[-1] - print(f"{clean_test_acc=}") - - if poison_test_dataset is not None: - paccs = np.array(res[2])[:, 0] - np.save(output_path + "t_paccs.npy", res[2]) - poison_test_acc = paccs[-1] - print(f"{poison_test_acc=}") - - print("Distilling...") - - student_model = load_model(student_model_flag) - print(f"Student parameters: {sum(p.numel() for p in student_model.parameters() if p.requires_grad)}") - print(f"Student data size: {len(distill_dataset)}") - - batch_size_s, epochs_s, opt_s, lr_scheduler_s = get_train_info( - student_model.parameters(), - train_flag, - batch_size=batch_size, - epochs=epochs, - optim_kwargs=optim_kwargs, - scheduler_kwargs=scheduler_kwargs - ) - - res = mini_distill_train( - student_model=student_model, - teacher_model=teacher_model, - distill_data=distill_dataset, - test_data=test_datasets, - batch_size=batch_size_s, - opt=opt_s, - scheduler=lr_scheduler_s, - epochs=epochs_s, - alpha=0.5, - temperature=1.0, - i_pct=None, - record=True - ) - - print("Evaluating Distilled Model...") - np.save(output_path + "s_caccs.npy", res[1]) - caccs, paccs = np.array(res[1])[:, 0], np.array(res[2])[:, 0] - clean_test_acc = caccs[-1] - print(f"{clean_test_acc=}") - - if poison_test_dataset is not None: - np.save(output_path + "s_paccs.npy", res[2]) - poison_test_acc = paccs[-1] - print(f"{poison_test_acc=}") - - - print("Evaluating Baseline...") - baseline_model = load_model(student_model_flag) - print(f"Baseline parameters: {sum(p.numel() for p in baseline_model.parameters() if p.requires_grad)}") - print(f"Baseline data size: {len(distill_dataset)}") - - batch_size_b, epochs_b, opt_b, lr_scheduler_b = get_train_info( - baseline_model.parameters(), - train_flag, - batch_size=batch_size, - epochs=epochs, - optim_kwargs=optim_kwargs, - scheduler_kwargs=scheduler_kwargs - ) - - res = mini_train( - model=baseline_model, - train_data=distill_dataset, - test_data=test_datasets, - batch_size=batch_size_b, - opt=opt_b, - scheduler=lr_scheduler_b, - epochs=epochs_b, - record=True - ) - - print("Evaluating Baseline Model...") - np.save(output_path + "b_caccs.npy", res[1]) - caccs, paccs = np.array(res[1])[:, 0], np.array(res[2])[:, 0] - clean_test_acc = caccs[-1] - print(f"{clean_test_acc=}") - - if poison_test_dataset is not None: - np.save(output_path + "b_paccs.npy", res[2]) - poison_test_acc = paccs[-1] - print(f"{poison_test_acc=}") - - print("Saving model...") - torch.save(student_model.state_dict(), generate_full_path(output_path)+'model.pth') - -if __name__ == "__main__": - experiment_name, module_name = sys.argv[1], sys.argv[2] - run(experiment_name, module_name) diff --git a/modules/generate_labels/run_module.py b/modules/generate_labels/run_module.py index 23d6902..65f91dc 100644 --- a/modules/generate_labels/run_module.py +++ b/modules/generate_labels/run_module.py @@ -1,7 +1,5 @@ """ -Implementation of a basic training module. -Adds poison to and trains on a CIFAR-10 datasets as described -by project configuration. +Optimizes logit labels given expert trajectories using trajectory matching. 
""" from pathlib import Path @@ -11,27 +9,29 @@ import numpy as np from modules.base_utils.datasets import get_matching_datasets, pick_poisoner, get_n_classes -from modules.base_utils.util import extract_toml, get_module_device,\ - get_mtt_attack_info, load_model,\ - either_dataloader_dataset_to_both,\ - make_pbar, clf_loss, needs_big_ims, softmax, total_mse_distance +from modules.base_utils.util import extract_toml, get_module_device, get_mtt_attack_info,\ + load_model, either_dataloader_dataset_to_both, make_pbar,\ + needs_big_ims, slurmify_path, clf_loss, softmax,\ + total_mse_distance from modules.generate_labels.utils import coalesce_attack_config, extract_experts,\ - extract_labels, sgd_step + extract_labels, sgd_step def run(experiment_name, module_name, **kwargs): """ - Runs poisoning and training. + Optimizes and saves poisoned logit labels. :param experiment_name: Name of the experiment in configuration. :param module_name: Name of the module in configuration. + :param kwargs: Additional arguments (such as slurm id). """ + slurm_id = kwargs.get('slurm_id', None) args = extract_toml(experiment_name, module_name) - input_path = args["input"] - input_opt_path = args["opt_input"] + input_pths = args["input_pths"] + opt_pths = args["opt_pths"] expert_model_flag = args["expert_model"] dataset_flag = args["dataset"] poisoner_flag = args["poisoner"] @@ -44,14 +44,11 @@ def run(experiment_name, module_name, **kwargs): expert_config = args.get('expert_config', {}) config = coalesce_attack_config(args.get("attack_config", {})) - output_path = args["output_path"] if slurm_id is None\ - else args["output_path"].format(slurm_id) - - Path(output_path).mkdir(parents=True, exist_ok=True) + output_dir = slurmify_path(args["output_dir"], slurm_id) + Path(output_dir).mkdir(parents=True, exist_ok=True) - print(f"{expert_model_flag=} {clean_label=} {target_label=} {poisoner_flag=}") + # Build datasets and initialize labels print("Building datasets...") - poisoner = pick_poisoner(poisoner_flag, dataset_flag, target_label) @@ -59,22 +56,24 @@ def run(experiment_name, module_name, **kwargs): big_ims = needs_big_ims(expert_model_flag) _, _, _, _, mtt_dataset =\ get_matching_datasets(dataset_flag, poisoner, clean_label, train_pct=train_pct, big=big_ims) + + labels = extract_labels(mtt_dataset.distill, config['one_hot_temp'], n_classes) + labels_init = torch.stack(extract_labels(mtt_dataset.distill, 1, n_classes)) + labels_syn = torch.stack(labels).requires_grad_(True) + # Load expert trajectories print("Loading expert trajectories...") expert_starts, expert_opt_starts = extract_experts( expert_config, - input_path, + input_pths, config['iterations'], - expert_opt_path=input_opt_path + expert_opt_path=opt_pths ) + # Optimize labels print("Training...") n_classes = get_n_classes(dataset_flag) - labels = extract_labels(mtt_dataset.distill, config['one_hot_temp'], n_classes) - labels_init = torch.stack(extract_labels(mtt_dataset.distill, 1, n_classes)) - labels_syn = torch.stack(labels).requires_grad_(True) - student_model = load_model(expert_model_flag, n_classes) expert_model = load_model(expert_model_flag, n_classes) @@ -93,7 +92,6 @@ def run(experiment_name, module_name, **kwargs): batch_size=batch_size) losses = [] - with make_pbar(total=config['iterations'] * len(mtt_dataset)) as pbar: for i in range(config['iterations']): for x_t, y_t, x_d, y_true, idx in mtt_dataloader: @@ -118,7 +116,7 @@ def run(experiment_name, module_name, **kwargs): optimizer_expert.step() expert_model.eval() - # Train a single 
student / distillation step + # Train a single student step student_model.train() student_model.zero_grad() @@ -158,10 +156,12 @@ def run(experiment_name, module_name, **kwargs): } pbar.set_postfix(**pbar_postfix) + # Save results + print("Saving results...") y_true = torch.stack([mtt_dataset[i][3].detach() for i in range(len(mtt_dataset.distill))]) - np.save(output_path + "labels.npy", labels_syn.detach().numpy()) - np.save(output_path + "true.npy", y_true) - np.save(output_path + "losses.npy", losses) + np.save(output_dir + "labels.npy", labels_syn.detach().numpy()) + np.save(output_dir + "true.npy", y_true) + np.save(output_dir + "losses.npy", losses) if __name__ == "__main__": experiment_name, module_name = sys.argv[1], sys.argv[2] diff --git a/modules/generate_labels/utils.py b/modules/generate_labels/utils.py index 0e0bbf2..49d3461 100644 --- a/modules/generate_labels/utils.py +++ b/modules/generate_labels/utils.py @@ -31,6 +31,7 @@ def extract_experts( iterations=None, expert_opt_path=None ): + '''Extracts a list of expert checkpoints for the attack''' config = {**DEFAULT_EXPERT_CONFIG, **expert_config} expert_starts = [] expert_opt_starts = [] @@ -71,6 +72,7 @@ def sgd_step(params, grad, opt_state, opt_params): def extract_labels(dataset, label_temp, n_classes=10): + '''Extracts the labels from a dataset''' labels = [] for _, y in dataset: base = np.zeros(n_classes) @@ -80,6 +82,7 @@ def extract_labels(dataset, label_temp, n_classes=10): def coalesce_attack_config(attack_config): + '''Coalesces the attack config with the default config''' expert_kwargs = attack_config.get('expert_kwargs', {}) labels_kwargs = attack_config.get('labels_kwargs', {}) attack_config['expert_kwargs'] = {**DEFAULT_SGD_KWARGS, **expert_kwargs} diff --git a/modules/select_flips/run_module.py b/modules/select_flips/run_module.py new file mode 100644 index 0000000..80cca28 --- /dev/null +++ b/modules/select_flips/run_module.py @@ -0,0 +1,66 @@ +""" +Chooses the optimal set of label flips for a given budget. +""" + +from pathlib import Path +import sys, glob + +import numpy as np + +from modules.base_utils.util import extract_toml, slurmify_path + + +def run(experiment_name, module_name, **kwargs): + """ + Runs label flip selection and saves a coalesced result. + + :param experiment_name: Name of the experiment in configuration. + :param module_name: Name of the module in configuration. + :param kwargs: Additional arguments (such as slurm id). 
+ """ + + slurm_id = kwargs.get('slurm_id', None) + + args = extract_toml(experiment_name, module_name) + budgets = args.get("budgets", [150, 300, 500, 1000, 1500]) + input_label_glob = slurmify_path(args["input_label_glob"], slurm_id) + true_labels = slurmify_path(args["true_labels"], slurm_id) + output_dir = slurmify_path(args["output_dir"], slurm_id) + + Path(output_dir).mkdir(parents=True, exist_ok=True) + + # Calculate Margins + print("Calculating margins...") + distances = [] + all_labels = [] + true = np.load(true_labels) + + for f in glob.glob(input_label_glob): + labels = np.load(f) + + dists = np.zeros(len(labels)) + inds = labels.argmax(axis=1) != true.argmax(axis=1) + dists[inds] = labels[inds].max(axis=1) -\ + labels[inds][np.arange(inds.sum()), true[inds].argmax(axis=1)] + + sorted = np.sort(labels[~inds]) + dists[~inds] = sorted[:, -2] - sorted[:, -1] + distances.append(dists) + all_labels.append(labels) + distances = np.stack(distances) + all_labels = np.stack(all_labels).mean(axis=0) + + # Select flips and save results + print("Selecting flips...") + np.save(f'{output_dir}/true.npy', true) + for n in budgets: + to_save = true.copy() + if n != 0: + idx = np.argsort(distances.min(axis=0))[-n:] + all_labels[idx] = all_labels[idx] - 50000 * true[idx] + to_save[idx] = all_labels[idx] + np.save(f'{output_dir}/{n}.npy', to_save) + +if __name__ == "__main__": + experiment_name, module_name = sys.argv[1], sys.argv[2] + run(experiment_name, module_name) diff --git a/modules/train_expert/run_module.py b/modules/train_expert/run_module.py index b92e8cc..9605bc5 100644 --- a/modules/train_expert/run_module.py +++ b/modules/train_expert/run_module.py @@ -1,26 +1,23 @@ """ -Implementation of a basic training module. -Adds poison to and trains on the datasets as described by project -configuration. +Trains an expert model on a traditionally backdoored dataset. """ from pathlib import Path import sys -import torch - -from modules.base_utils.datasets import get_matching_datasets, pick_poisoner -from modules.base_utils.util import extract_toml, load_model,\ - generate_full_path, clf_eval, mini_train,\ - get_train_info, needs_big_ims +from modules.train_expert.utils import checkpoint_callback +from modules.base_utils.datasets import get_matching_datasets, get_n_classes, pick_poisoner +from modules.base_utils.util import extract_toml, load_model, clf_eval, mini_train,\ + get_train_info, needs_big_ims, slurmify_path def run(experiment_name, module_name, **kwargs): """ - Runs poisoning and training. + Runs expert training and saves trajectory. :param experiment_name: Name of the experiment in configuration. :param module_name: Name of the module in configuration. + :param kwargs: Additional arguments (such as slurm id). 
""" slurm_id = kwargs.get('slurm_id', None) @@ -38,34 +35,26 @@ def run(experiment_name, module_name, **kwargs): epochs = args.get("epochs", None) optim_kwargs = args.get("optim_kwargs", {}) scheduler_kwargs = args.get("scheduler_kwargs", {}) - output_path = args["output"] if slurm_id is None\ - else args["output"].format(slurm_id) + output_dir = slurmify_path(args["output_dir"], slurm_id) - Path(output_path[:output_path.rfind('/')]).mkdir(parents=True, - exist_ok=True) + Path(output_dir).mkdir(parents=True, exist_ok=True) - # TODO: make this more extensible - if dataset_flag == "cifar_100": - model = load_model(model_flag, 20) - elif dataset_flag == "tiny_imagenet": - model = load_model(model_flag, 200) - else: - model = load_model(model_flag) + if slurm_id is None: + slurm_id = "{}" - print(f"{model_flag=} {clean_label=} {target_label=} {poisoner_flag=}") + # Build datasets print("Building datasets...") - + big_ims = needs_big_ims(model_flag) poisoner = pick_poisoner(poisoner_flag, dataset_flag, target_label) - - if slurm_id is None: - slurm_id = "{}" - - big_ims = needs_big_ims(model_flag) poison_train, _, test, poison_test, _ =\ get_matching_datasets(dataset_flag, poisoner, clean_label, train_pct=train_pct, big=big_ims) + # Train expert model + print("Training expert model...") + n_classes = get_n_classes(dataset_flag) + model = load_model(model_flag, n_classes) batch_size, epochs, opt, lr_scheduler = get_train_info( model.parameters(), train_flag, @@ -75,17 +64,6 @@ def run(experiment_name, module_name, **kwargs): scheduler_kwargs=scheduler_kwargs ) - print("Training...") - - def checkpoint_callback(model, opt, epoch, iteration, save_iter): - if iteration % save_iter == 0 and iteration != 0: - index = output_path.rfind('.') - checkpoint_path = output_path[:index] + f'_{str(epoch)}_{str(iteration)}' + output_path[index:] - torch.save(model.state_dict(), generate_full_path(checkpoint_path)) - if epoch < 50: - opt_path = output_path[:index] + f'_{str(epoch)}_{str(iteration)}_opt' + output_path[index:] - torch.save(opt.state_dict(), generate_full_path(opt_path)) - mini_train( model=model, train_data=poison_train, @@ -94,13 +72,13 @@ def checkpoint_callback(model, opt, epoch, iteration, save_iter): opt=opt, scheduler=lr_scheduler, epochs=epochs, - callback=lambda m, o, e, i: checkpoint_callback(m, o, e, i, ckpt_iters) + callback=lambda m, o, e, i: checkpoint_callback(m, o, e, i, ckpt_iters, output_dir) ) + # Evaluate print("Evaluating...") clean_test_acc = clf_eval(model, test)[0] poison_test_acc = clf_eval(model, poison_test.poison_dataset)[0] - print(f"{clean_test_acc=}") print(f"{poison_test_acc=}") diff --git a/modules/train_expert/utils.py b/modules/train_expert/utils.py new file mode 100644 index 0000000..b2a17e0 --- /dev/null +++ b/modules/train_expert/utils.py @@ -0,0 +1,12 @@ +import torch + +from modules.base_utils.util import generate_full_path + + +def checkpoint_callback(model, opt, epoch, iteration, save_iter, output_dir): + '''Saves model and optimizer state dicts at fixed intervals.''' + if iteration % save_iter == 0 and iteration != 0: + checkpoint_path = f'{output_dir}model_{str(epoch)}_{str(iteration)}.pth' + opt_path = f'{output_dir}model_{str(epoch)}_{str(iteration)}_opt.pth' + torch.save(model.state_dict(), generate_full_path(checkpoint_path)) + torch.save(opt.state_dict(), generate_full_path(opt_path)) \ No newline at end of file diff --git a/modules/downstream/run_module.py b/modules/train_user/run_module.py similarity index 56% rename from 
modules/downstream/run_module.py rename to modules/train_user/run_module.py index d90b312..710aaf4 100644 --- a/modules/downstream/run_module.py +++ b/modules/train_user/run_module.py @@ -1,7 +1,5 @@ """ -Implementation of a basic training module. -Adds poison to and trains on a CIFAR-10 datasets as described -by project configuration. +Trains a downstream (user) model on a dataset with input labels. """ from pathlib import Path @@ -11,86 +9,77 @@ import numpy as np from modules.base_utils.datasets import get_matching_datasets, get_n_classes, pick_poisoner,\ - construct_downstream_dataset -from modules.base_utils.util import extract_toml, get_train_info,\ - mini_train, load_model, needs_big_ims, softmax + construct_user_dataset +from modules.base_utils.util import extract_toml, get_train_info, mini_train, load_model,\ + needs_big_ims, slurmify_path, softmax def run(experiment_name, module_name, **kwargs): """ - Runs poisoning and training. + Runs user model training and saves metrics. :param experiment_name: Name of the experiment in configuration. :param module_name: Name of the module in configuration. + :param kwargs: Additional arguments (such as slurm id). """ - slurm_id = kwargs.get('slurm_id', None) + slurm_id = kwargs.get('slurm_id', None) args = extract_toml(experiment_name, module_name) - input_path = args["input"] if slurm_id is None\ - else args["input"].format(slurm_id) - downstream_model_flag = args["downstream_model"] + user_model_flag = args["user_model"] trainer_flag = args["trainer"] dataset_flag = args["dataset"] poisoner_flag = args["poisoner"] clean_label = args["source_label"] target_label = args["target_label"] - logits = args.get("logits", True) + soft = args.get("soft", True) batch_size = args.get("batch_size", None) epochs = args.get("epochs", None) optim_kwargs = args.get("optim_kwargs", {}) scheduler_kwargs = args.get("scheduler_kwargs", {}) alpha = args.get("alpha", None) - distill_labels = args.get("distill_labels", False) - - # TODO: take out - input_path = args["input"] if slurm_id is None\ - else args["input"].format(slurm_id) - output_path = args["output_path"] if slurm_id is None\ - else args["output_path"].format(slurm_id) + input_path = slurmify_path(args["input_labels"], slurm_id) + true_path = slurmify_path(args.get("true_labels", None), slurm_id) + output_path = slurmify_path(args["output_dir"], slurm_id) Path(output_path).mkdir(parents=True, exist_ok=True) - - print(f"{downstream_model_flag=} {clean_label=} {target_label=} {poisoner_flag=}") - + # Build datasets print("Building datasets...") - poisoner = pick_poisoner(poisoner_flag, - dataset_flag, - target_label) + poisoner = pick_poisoner(poisoner_flag, dataset_flag, target_label) - big_ims = needs_big_ims(trainer_flag) + big_ims = needs_big_ims(user_model_flag) _, distillation, test, poison_test, _ =\ get_matching_datasets(dataset_flag, poisoner, clean_label, big=big_ims) - labels_syn = torch.tensor(np.load(input_path + "labels.npy")) - if distill_labels: - y_true = torch.tensor(np.load(input_path + "distill_labels.npy")) - elif alpha > 0: - y_true = torch.tensor(np.load(input_path + "true.npy")) + labels_syn = torch.tensor(np.load(input_path)) if alpha > 0: + assert true_path is not None + y_true = torch.tensor(np.load(true_path)) labels_d = softmax(alpha * y_true + (1 - alpha) * labels_syn) else: labels_d = softmax(labels_syn) - if not logits: + if not soft: labels_d = labels_d.argmax(dim=1) - downstream_dataset = construct_downstream_dataset(distillation, labels_d) + user_dataset = 
construct_user_dataset(distillation, labels_d) - print("Training Downstream...") + # Train user model + print("Training user model...") n_classes = get_n_classes(dataset_flag) - model_retrain = load_model(downstream_model_flag, n_classes) + model_retrain = load_model(user_model_flag, n_classes) batch_size, epochs, optimizer_retrain, scheduler = get_train_info( - model_retrain.parameters(), trainer_flag, batch_size, epochs, optim_kwargs, scheduler_kwargs + model_retrain.parameters(), trainer_flag, batch_size, + epochs, optim_kwargs, scheduler_kwargs ) model_retrain, clean_metrics, poison_metrics = mini_train( model=model_retrain, - train_data=downstream_dataset, + train_data=user_dataset, test_data=[test, poison_test.poison_dataset], batch_size=batch_size, opt=optimizer_retrain, @@ -99,11 +88,11 @@ def run(experiment_name, module_name, **kwargs): record=True ) + # Save results + print("Saving results...") np.save(output_path + "paccs.npy", poison_metrics) np.save(output_path + "caccs.npy", clean_metrics) np.save(output_path + "labels.npy", labels_d.numpy()) - - print("Saving model...") torch.save(model_retrain.state_dict(), output_path + "model.pth") if __name__ == "__main__": diff --git a/schemas/distillation.toml b/schemas/distillation.toml deleted file mode 100644 index a909dee..0000000 --- a/schemas/distillation.toml +++ /dev/null @@ -1,23 +0,0 @@ -### -# TODO -# distillation schema -# Configured to poison and train and distill a set of model on any of the datasets. -# Outputs the .pth of a distileld model -### - -[distillation] -output_path = "string: Path to .pth file." -teacher_model = "string: (r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain). For ResNets, VGG-19s, and ViTs" -student_model = "string: (r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain). For ResNets, VGG-19s, and ViTs" -dataset = "string: (cifar / cifar_100 / tiny_imagenet). For CIFAR-10, CIFAR-100 and Tiny Imagenet datasets" -distill_percentage = "TODO" -trainer = "string: (sgd / adam). Specifies optimizer. " -source_label = "int: {0,1,...,9}. Specifies label to mimic" -target_label = "int: {0,1,...,9}. Specifies label to attack" -poisoner = "string: Form: {{1,2,3,9}xp, {1,2}xs}. Integer resembles number of attacks and string represents type" - -[OPTIONAL] -batch_size = "int: {0,1,...,infty}. Specifies batch size. Set to default for trainer if omitted." -epochs = "int: {0,1,...,infty}. Specifies number of epochs. Set to default for trainer if omitted." -optim_kwargs = "dict. Optional keywords for Pytorch SGD / Adam optimizer. See sever example." -scheduler_kwargs = "dict. Optional keywords for Pytorch learning rate optimizer (with SGD). See sever example." \ No newline at end of file diff --git a/schemas/downstream.toml b/schemas/downstream.toml deleted file mode 100644 index a34ff5a..0000000 --- a/schemas/downstream.toml +++ /dev/null @@ -1,25 +0,0 @@ -### -# TODO -# downstream schema -# Configured to poison and train and distill a set of model on any of the datasets. -# Outputs the .pth of a distileld model -### - -[downstream] -input = "TODO" -output_path = "string: Path to .pth file." -downstream_model = "string: (r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain). For ResNets, VGG-19s, and ViTs" -dataset = "string: (cifar / cifar_100 / tiny_imagenet). For CIFAR-10, CIFAR-100 and Tiny Imagenet datasets" -trainer = "string: (sgd / adam). Specifies optimizer. " -source_label = "int: {0,1,...,9}. Specifies label to mimic" -target_label = "int: {0,1,...,9}. 
-poisoner = "string: Form: {{1,2,3,9}xp, {1,2}xs, {1,4}xl}. Integer resembles number of attacks and string represents type"
-
-[OPTIONAL]
-logits = "TODO"
-alpha = "TODO"
-distill_labels = "TODO"
-batch_size = "int: {0,1,...,infty}. Specifies batch size. Set to default for trainer if omitted."
-epochs = "int: {0,1,...,infty}. Specifies number of epochs. Set to default for trainer if omitted."
-optim_kwargs = "dict. Optional keywords for Pytorch SGD / Adam optimizer. See sever example."
-scheduler_kwargs = "dict. Optional keywords for Pytorch learning rate optimizer (with SGD). See sever example."
\ No newline at end of file
diff --git a/schemas/generate_labels.toml b/schemas/generate_labels.toml
index 435e858..5b0a206 100644
--- a/schemas/generate_labels.toml
+++ b/schemas/generate_labels.toml
@@ -1,27 +1,24 @@
-###
-# TODO
+###
 # generate_labels schema
-# Configured to poison and train and distill a set of model on any of the datasets.
-# Outputs the .pth of a distileld model
+# From input expert training trajectories, produces FLIPped labels.
+# Outputs the poisoned labels, true labels, and losses as .npy files.
 ###

 [generate_labels]
-input = "TODO"
-opt_input = "TODO"
-output_path = "string: Path to .pth file."
-expert_model = "string: (r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain). For ResNets, VGG-19s, and ViTs"
-dataset = "string: (cifar / cifar_100 / tiny_imagenet). For CIFAR-10, CIFAR-100 and Tiny Imagenet datasets"
-trainer = "string: (sgd / adam). Specifies optimizer. "
-source_label = "int: {0,1,...,9}. Specifies label to mimic"
-target_label = "int: {0,1,...,9}. Specifies label to attack"
-poisoner = "string: Form: {{1,2,3,9}xp, {1,2}xs, {1,4}xl}. Integer resembles number of attacks and string represents type"
+input_pths = "string. Format string path to model checkpoint .pth files with three '{}'s."
+opt_pths = "string. Format string path to optimizer checkpoint .pth files with three '{}'s."
+output_dir = "string. Path to output directory (slurm compatible)."
+expert_model = "string: {r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain}. For ResNets, VGG-19s, and ViTs."
+dataset = "string: {cifar, cifar_100, tiny_imagenet}. For CIFAR-10, CIFAR-100 and Tiny Imagenet datasets."
+trainer = "string: {sgd, adam}. Specifies optimizer."
+source_label = "int: {-1,0,...,9}. Specifies label to mimic. -1 indicates all labels."
+target_label = "int: {0,1,...,9}. Specifies label to attack."
+poisoner = "string: Form: {{1,2,3,9}xp, {1,2}xs, {1,4}xl}. Integer represents the number of attacks and the string represents the type."

 [OPTIONAL]
 batch_size = "int: {0,1,...,infty}. Specifies batch size. Set to default for trainer if omitted."
 epochs = "int: {0,1,...,infty}. Specifies number of epochs. Set to default for trainer if omitted."
-train_pct = "TODO"
-lambda = "TODO"
-optim_kwargs = "dict. Optional keywords for Pytorch SGD / Adam optimizer. See sever example."
-scheduler_kwargs = "dict. Optional keywords for Pytorch learning rate optimizer (with SGD). See sever example."
-expert_config = "TODO"
-attack_config = "TODO"
\ No newline at end of file
+train_pct = "float: [0, 1]. Specifies percentage of dataset available to attacker. Set to 1 by default."
+lambda = "float: [0, infty]. Specifies regularization parameter. Set to 0 by default."
+expert_config = "dict. Specifies expert checkpoints. Set to default if omitted. See example_attack."
+attack_config = "dict. Specifies algorithm parameters. Set to default if omitted. See example_attack."
\ No newline at end of file
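The optimization that generate_labels performs reduces to label-space trajectory matching: each iteration advances the expert one real SGD step on the poisoned data, advances a student one differentiable step on the learnable logit labels, and backpropagates the parameter-space distance between the two into the labels. A condensed sketch with illustrative names; a plain squared distance stands in for total_mse_distance, and the real loop (with its sgd_step bookkeeping and regularization) lives in modules/generate_labels/run_module.py:

import torch

def label_matching_step(student, expert, x, labels_syn, idx, lr=0.01):
    # Soft-label cross-entropy of the student on the learnable labels.
    soft_targets = torch.softmax(labels_syn[idx], dim=1)
    log_probs = torch.log_softmax(student(x), dim=1)
    loss = -(log_probs * soft_targets).sum(dim=1).mean()

    # Differentiable SGD step (create_graph=True keeps the dependence on labels_syn).
    params = list(student.parameters())
    grads = torch.autograd.grad(loss, params, create_graph=True)
    stepped = [p - lr * g for p, g in zip(params, grads)]

    # Match the stepped student against the expert checkpoint and push the
    # gradient of that distance back into labels_syn.
    match_loss = sum(((s - e.detach()) ** 2).sum()
                     for s, e in zip(stepped, expert.parameters()))
    match_loss.backward()
    return float(match_loss)

An outer optimizer over labels_syn then applies the accumulated gradient, which appears to be the role the module's sgd_step helper plays.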
diff --git a/schemas/select_flips.toml b/schemas/select_flips.toml
new file mode 100644
index 0000000..dabdd40
--- /dev/null
+++ b/schemas/select_flips.toml
@@ -0,0 +1,11 @@
+###
+# select_flips schema
+# Given a set of poisoned labels, computes margins and produces FLIPs.
+# Outputs coalesced labels for each budget and the true labels.
+###
+
+[select_flips]
+budgets = "list. Integer list of flip budgets to compute labels for."
+input_label_glob = "string. Glob path to input label .npy files (slurm compatible)."
+true_labels = "string. Path to true label .npy file (slurm compatible)."
+output_dir = "string. Path to output directory (slurm compatible)."
diff --git a/schemas/train_expert.toml b/schemas/train_expert.toml
index 51899e3..895cbbf 100644
--- a/schemas/train_expert.toml
+++ b/schemas/train_expert.toml
@@ -1,22 +1,22 @@
 ###
 # train_expert schema
-# Configured to poison and train a model on any of the datasets.
-# Outputs the .pth of a trained model
+# Records trajectories for an expert model.
+# Outputs the .pth files for expert and optimizer trajectories.
 ###

 [train_expert]
-output = "string: Path to .pth file."
-model = "string: (r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain). For ResNets, VGG-19s, and ViTs"
-dataset = "string: (cifar / cifar_100 / tiny_imagenet). For CIFAR-10, CIFAR-100 and Tiny Imagenet datasets"
-trainer = "string: (sgd / adam). Specifies optimizer. "
-source_label = "int: {0,1,...,9}. Specifies label to mimic"
-target_label = "int: {0,1,...,9}. Specifies label to attack"
-poisoner = "string: Form: {{1,2,3,9}xp, {1,2}xs, {1,4}xl}. Integer resembles number of attacks and string represents type"
-checkpoint_iters = "TODO"
+output_dir = "string. Path to output directory (slurm compatible)."
+model = "string: {r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain}. For ResNets, VGG-19s, and ViTs."
+dataset = "string: {cifar, cifar_100, tiny_imagenet}. For CIFAR-10, CIFAR-100 and Tiny Imagenet datasets."
+trainer = "string: {sgd, adam}. Specifies optimizer."
+source_label = "int: {-1,0,...,9}. Specifies label to mimic. -1 indicates all labels."
+target_label = "int: {0,1,...,9}. Specifies label to attack."
+poisoner = "string: Form: {{1,2,3,9}xp, {1,2}xs, {1,4}xl}. Integer represents the number of attacks and the string represents the type."
+checkpoint_iters = "int: {0,1,...,infty}. Number of iterations between each checkpoint record."

 [OPTIONAL]
 batch_size = "int: {0,1,...,infty}. Specifies batch size. Set to default for trainer if omitted."
 epochs = "int: {0,1,...,infty}. Specifies number of epochs. Set to default for trainer if omitted."
-train_pct = "TODO"
+train_pct = "float: [0, 1]. Specifies percentage of dataset available to attacker. Set to 1 by default."
 optim_kwargs = "dict. Optional keywords for Pytorch SGD / Adam optimizer. See sever example."
 scheduler_kwargs = "dict. Optional keywords for Pytorch learning rate optimizer (with SGD). See sever example."
\ No newline at end of file
diff --git a/schemas/train_user.toml b/schemas/train_user.toml
new file mode 100644
index 0000000..651e93b
--- /dev/null
+++ b/schemas/train_user.toml
@@ -0,0 +1,24 @@
+###
+# train_user schema
+# Trains a downstream (user) model on input labels and records metrics.
+# Outputs the poison accuracy, clean accuracy, and training labels as .npy files and a final model .pth.
+###
+
+[train_user]
+input_labels = "string. Path to input label .npy file (slurm compatible)."
+output_dir = "string. Path to output directory (slurm compatible)."
+user_model = "string: {r32p, r18, r18_tin, vgg, vgg_pretrain, vit_pretrain}. For ResNets, VGG-19s, and ViTs."
+dataset = "string: {cifar, cifar_100, tiny_imagenet}. For CIFAR-10, CIFAR-100 and Tiny Imagenet datasets."
+trainer = "string: {sgd, adam}. Specifies optimizer."
+source_label = "int: {0,1,...,9}. Specifies label to mimic."
+target_label = "int: {0,1,...,9}. Specifies label to attack."
+poisoner = "string: Form: {{1,2,3,9}xp, {1,2}xs, {1,4}xl}. Integer represents the number of attacks and the string represents the type."
+
+[OPTIONAL]
+true_labels = "string. Path to true label .npy file (slurm compatible)."
+soft = "bool. Specifies whether to train on soft (logit) labels or hard (argmax) labels. Set to true if omitted."
+alpha = "float: [0, 1]. Specifies interpolation parameter between true (1) and input (0) labels. Set to 0 (full input) if omitted."
+batch_size = "int: {0,1,...,infty}. Specifies batch size. Set to default for trainer if omitted."
+epochs = "int: {0,1,...,infty}. Specifies number of epochs. Set to default for trainer if omitted."
+optim_kwargs = "dict. Optional keywords for Pytorch SGD / Adam optimizer. See sever example."
+scheduler_kwargs = "dict. Optional keywords for Pytorch learning rate optimizer (with SGD). See sever example."
\ No newline at end of file
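To make the train_user options above concrete, a minimal sketch of how alpha and soft combine; file paths are hypothetical, and torch.softmax stands in for the module's own softmax helper:

import numpy as np
import torch

# Hypothetical inputs: coalesced labels from select_flips plus the true labels.
labels_syn = torch.tensor(np.load("precomputed_labels/1500.npy"))
y_true = torch.tensor(np.load("precomputed_labels/true.npy"))
alpha, soft = 0.5, True

# alpha interpolates in logit space: 1 recovers the true labels,
# 0 keeps the poisoned labels unchanged.
if alpha > 0:
    labels_d = torch.softmax(alpha * y_true + (1 - alpha) * labels_syn, dim=1)
else:
    labels_d = torch.softmax(labels_syn, dim=1)

# soft = false trains on hard argmax labels instead of full distributions.
if not soft:
    labels_d = labels_d.argmax(dim=1)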