From 3ef1e1165b8bce0854c14dd2551962c4d0221199 Mon Sep 17 00:00:00 2001 From: Patrick Haluptzok Date: Fri, 23 Jun 2023 16:19:35 -0700 Subject: [PATCH] needed for test runs --- ICLR2023/data/test_228.json | 230 +++++++++++++++++++++++++++++++++ ICLR2023/data/train_prefix.txt | 41 ++++++ 2 files changed, 271 insertions(+) create mode 100755 ICLR2023/data/test_228.json create mode 100755 ICLR2023/data/train_prefix.txt diff --git a/ICLR2023/data/test_228.json b/ICLR2023/data/test_228.json new file mode 100755 index 0000000..81c4824 --- /dev/null +++ b/ICLR2023/data/test_228.json @@ -0,0 +1,230 @@ +[ + "def sat(roots: List[float], coeffs=[1.0, -2.0, -1.0]):\n r1, r2, r3 = roots\n a, b, c = coeffs\n return abs(r1 + r2 + r3 + a) + abs(r1 * r2 + r1 * r3 + r2 * r3 - b) + abs(r1 * r2 * r3 + c) < 1e-6", + "def sat(nums: List[int]):\n return [sorted([int(s) for s in str(n * n)]) for n in set(nums)] == [list(range(10))] * 174", + "def sat(prefixes: List[str], s=\"donesezichethofalij\"):\n return all(s.startswith(p) for p in prefixes) and len(set(prefixes)) > len(s)", + "def sat(li: List[int], n=100):\n assert len(li) == n\n for i, m in enumerate(li):\n if i < 2:\n assert m == i + 1\n elif i % 2 == 1:\n assert m == li[i - 2] + i + (i + 1)\n else:\n assert m == li[i - 2] * i * (i + 1)\n return True", + "def sat(s: str, orig=\"Hello world!!!\"):\n for a, b in zip(s.split(' '), orig.split(' ')):\n for i in range(len(a) - 1):\n assert a[i] <= a[i + 1], \"characters must s-words be in increasing order\"\n assert len(a) == len(b) and all(a.count(c) == b.count(c) for c in b), \"must have same chars\"\n return len(s) == len(orig)", + "def sat(e: List[int], edges=[[0, 217], [40, 11], [17, 29], [11, 12], [31, 51]]):\n return e in edges", + "def sat(b: List[int], a=[1, 2, 3, 0, 4, 17, 2, 4, 5, 9, 8, 4], c=[1, 2, 3, 4, 0, 16, 2, 3, 5, 9, 8, 4]):\n return len(b) == len(a) and all(i + j == k for i, j, k in zip(a, b, c))", + "def sat(s: str):\n return s[::-1] + 'world' == 'Hello world'", + "def sat(backwards_digits: List[str], nums=[0, 2, 14, -2, 3, 8, 4, 5, 5, 7, 21, 101, 41, 2, 9, 6]):\n digits = {\"one\": 1, \"two\": 2, \"three\": 3, \"four\": 4, \"five\": 5, \"six\": 6, \"seven\": 7, \"eight\": 8, \"nine\": 9}\n li = [digits[s] for s in backwards_digits]\n for i, n in enumerate(li):\n assert n == max(li[i: i + 2])\n assert nums.count(n) == li.count(n)\n\n return all(n not in range(1, 10) or n in li for n in nums)", + "def sat(perms: List[List[int]], prices0=[7, 7, 9, 5, 3, 7, 1, 2], prices1=[5, 5, 5, 4, 2, 5, 1, 1], heights0=[2, 4, 9, 3, 8, 5, 5, 4], heights1=[1, 3, 8, 1, 5, 4, 4, 2]):\n n = len(prices0)\n perm0, perm1 = perms\n assert sorted(perm0) == sorted(perm1) == list(range(n)), \"Solution must be two permutations\"\n for i in range(n - 1):\n assert prices0[perm0[i]] <= prices0[perm0[i + 1]], \"Permuted prices must be nondecreasing (row 0)\"\n assert prices1[perm1[i]] <= prices1[perm1[i + 1]], \"Permuted prices must be nondecreasing (row 1)\"\n return all(heights0[i] > heights1[j] for i, j in zip(perm0, perm1))", + "def sat(odds: List[int], nums=[204, 109, 203, 17, 45, 11, 21, 99, 909, 16, -33, 3, 17]):\n assert all(o > 10 and odds.count(o) == nums.count(o) and int(str(o)[i]) % 2 for o in odds for i in [-1, 0])\n return all(n in odds or n <= 10 or int(str(n)[0]) % 2 == 0 or int(str(n)[-1]) % 2 == 0 for n in nums)", + "def sat(x: int, a=145, b=24126846790974):\n if x == -1:\n return all(i % 2 == 1 for i in range(a, b + 1))\n return a <= x <= b and all(i % 2 == 1 for i in range(x + 1, b + 1))", + "def sat(biggest: List[int], k=7, nums=[31, 1, 2, -10, -2, 4, 17, 18, 20, 14, 20, 21, 18, 0]):\n if len(biggest) != k:\n return False\n smallest = nums[:]\n for n in biggest:\n smallest.remove(n)\n return k == 0 or k == len(nums) or max(smallest) <= min(biggest)", + "def sat(b: str, n=5324680297138495285):\n assert b[:4] == b[-4:] == 'bits'\n inside = b[4:-4]\n assert all(c in \"01\" for c in inside)\n assert inside[0] == \"1\" or len(inside) == 1\n m = 0\n for c in inside:\n m = 2 * m + int(c)\n return m == n", + "def sat(s: str, a=-103252, b=10657):\n n = int(s, 2)\n r = range(a, b)\n if len(r) == 0:\n return n == -1\n mu = sum(r) / len(r)\n return abs(mu - n) <= min(abs(mu - n - 1), abs(mu - n + 1))", + "def sat(ordered: List[int], arr=[4, 2, 3, -1, 15, 2, 6, 9, 5, 16, 1048576]):\n if sorted(ordered) != sorted(arr):\n return False # not even a permutation\n return all(bin(a).count(\"1\") <= bin(b).count(\"1\") for a, b in zip(ordered, ordered[1:]))", + "def sat(str_num: str, nums=['100011101100001', '100101100101110']):\n a, b = nums\n return int(str_num, 2) == int(a, 2) ^ int(b, 2)", + "def sat(counts: List[int], p=0.5, target_prob=0.0625):\n from itertools import product\n a, b = counts\n n = a + b\n prob = (p ** a) * ((1-p) ** b)\n tot = sum([prob for sample in product([0, 1], repeat=n) if sum(sample) == a])\n return abs(tot - target_prob) < 1e-6", + "def sat(n: int, year_len=365):\n prob = 1.0\n for i in range(n):\n prob *= (year_len - i) / year_len\n return (prob - 0.5) ** 2 <= 1/year_len", + "def sat(n: int, b=107, s=25):\n n_str = bin(n)[2:] # n in binary\n return len(n_str) == b and sum(int(i) for i in n_str) == s", + "def sat(colors: List[int], n=100):\n assert set(colors) <= {0, 1} and len(colors) >= n\n squares = {i ** 2: colors[i] for i in range(1, len(colors))}\n return not any(c == d == squares.get(i + j) for i, c in squares.items() for j, d in squares.items())", + "def sat(expr: str, nums=[3, 7, 3, 7]):\n assert len(nums) == 4 and 1 <= min(nums) and max(nums) <= 13, \"hint: nums is a list of four ints in 1..13\"\n expr = expr.replace(\" \", \"\") # ignore whitespace\n digits = \"\"\n for i in range(len(expr)):\n if i == 0 or expr[i - 1] in \"+*-/(\":\n assert expr[i] in \"123456789(\", \"Expr cannot contain **, //, or unary -\"\n assert expr[i] in \"1234567890()+-*/\", \"Expr can only contain `0123456789()+-*/`\"\n digits += expr[i] if expr[i] in \"0123456789\" else \" \"\n assert sorted(int(s) for s in digits.split()) == sorted(nums), \"Each number must occur exactly once\"\n return abs(eval(expr) - 24.0) < 1e-6", + "def sat(cat: str, strings=['Will', 'i', 'am', 'Now', 'here']):\n i = 0\n for s in strings:\n for c in s:\n assert cat[i] == c\n i += 1\n return i == len(cat)", + "def sat(running_squares: List[int], x=[201.1, 301.4, -18.1, 1244122.0, 10101.0101, 10000000.0]):\n for i, v in enumerate(x):\n ceiling = int(v) + (v > 0 and not v.is_integer())\n square = ceiling ** 2\n if running_squares[i] != square + (i > 0 and running_squares[i - 1]):\n return False\n\n return len(running_squares) == len(x)", + "def sat(ans: List[int], m=200004931, n=66679984):\n gcd, a, b = ans\n return m % gcd == n % gcd == 0 and a * m + b * n == gcd and gcd > 0", + "def sat(s: str, n=142, base=7):\n return int(s, base) == n", + "def sat(s: str, counts={'a': 4, 'b': 17, 'd': 101, 'e': 0, 'f': 12}):\n chars = s.split()\n for c in chars:\n assert chars.count(c) == counts[c]\n return len(chars) == sum(counts.values())", + "def sat(tot: int, s=\"Add ME uP AND YOU WILL GET A BIG NUMBER!\"):\n for c in s:\n if c.isupper():\n tot -= ord(c)\n return tot == 0", + "def sat(n: int, x=329437923.5):\n return abs(n - x) <= 0.5", + "def sat(pal: str, s=\"palindromordinals\"):\n assert pal == pal[::-1] and len(pal) == len(s)\n return sum(a != b for a, b in zip(pal, s)) == sum(a != b for a, b in zip(s, s[::-1])) // 2", + "def sat(common: List[int], a=[2, 416629, 2, 4, 17, 29, 31, 1000], b=[31, 2, 4, 17, 29, 41205]):\n return all((i in common) == (i in a and i in b) for i in a + b + common)", + "def sat(splits: List[List[str]], string=\"Hello, world! You look like you're on turtles.\"):\n words, separators = splits\n assert len(words) == len(separators) + 1\n merged = []\n for w, s in zip(words, separators + [\" \"]):\n assert s.count(\" \") + s.count(\",\") == len(s) > 0\n assert w.count(\" \") + w.count(\",\") == 0\n merged += [w, s]\n return \"\".join(merged[:-1]) == string", + "def sat(words: List[str], s=\"This is not a very hard puzzle\", n=3):\n i = 0\n for w in s.split():\n num_consonants = 0\n for c in w.lower():\n if c not in \"aeiou\":\n num_consonants += 1\n if num_consonants == n:\n if words[i] != w:\n return False\n i += 1\n return i == len(words)", + "def sat(x: int, n=42714774173606970182754018064350848294149432972747296768):\n return x ** 3 == n", + "def sat(x: float, coeffs=[2.0, 1.0, 0.0, 8.0]):\n return abs(sum(c * x ** (3 - i) for i, c in enumerate(coeffs))) < 1e-6", + "def sat(sums: List[int], n=104):\n return all(sums[i + 1] - sums[i] == i for i in range(n)) and sums[0] == 0", + "def sat(s: str, target=-2075):\n assert all(c in \"0123457689-\" for c in s) and s[2] == s[5] == \"-\"\n m, d, y = [int(n) for n in s.split(\"-\")]\n assert m in range(1, 13)\n assert d in range(1, 32)\n if m in [4, 6, 9, 11]:\n assert d <= 30\n if m == 2:\n assert d <= 29\n return m - d - y == target", + "def sat(ans: List[int], li=[2, 19, 2, 53, 1, 1, 2, 44, 17, 0, 19, 31]):\n return set(ans) == set(li) and all(li.index(ans[i]) < li.index(ans[i + 1]) for i in range(len(ans) - 1))", + "def sat(depths: List[int], parens=\"() (()) ((()()())) (((((((())))))))\"):\n groups = parens.split()\n for depth, group in zip(depths, groups):\n budget = depth\n success = False\n for c in group:\n if c == '(':\n budget -= 1\n if budget == 0:\n success = True\n assert budget >= 0\n else:\n assert c == ')'\n budget += 1\n assert success\n\n return len(groups) == len(depths)", + "def sat(strings: List[str], a=\"this is a test\", b=\"cat\"):\n s, is_palindrome = strings\n i = 0\n for c in a:\n if c not in b:\n assert s[i] == c\n i += 1\n assert i == len(s)\n return is_palindrome == str(s == s[::-1])", + "def sat(derivative: List[int], poly=[2, 1, 0, 4, 19, 231, 0, 5]):\n\n def val(poly, x):\n return sum(coeff * (x ** i) for i, coeff in enumerate(poly))\n\n return all(abs(val(poly, x + 1e-8) - val(poly, x) - 1e-8 * val(derivative, x)) < 1e-4 for x in range(len(poly)))", + "def sat(c: str, a=\"the quick brown fox jumped over the lazy dog\", b=\"how vexingly quick daft zebras jump\"):\n return (c in a) != (c in b)", + "def sat(n: int, g=44337, p=69337, t=38187):\n return pow(g, n, p) == t", + "def sat(ans: List[str], s=\"The quick brown fox jumps over the lazy dog!\", n=28):\n assert all(ans.count(c.lower()) == 1 for c in s)\n assert all(c == c.lower() for c in ans)\n assert all(c in s.lower() for c in ans)\n return True", + "def sat(drop_indexes: List[int], nums=[2, -1, 14, 8, 9, 9, 8, 4, 2, 4, 3, -100, 1000, 18, 4, -2, -3, -3, 1, 0]):\n d = 0\n for i in range(1, len(nums)):\n if nums[i] < nums[i - 1]:\n assert drop_indexes[d] == i\n d += 1\n return d == len(drop_indexes)", + "def sat(ls: List[str], n=100, a=\"bar\", b=\"foo\"):\n return len(ls) == len(set(ls)) == n and ls[0] == a and ls[-1] == b and ls == sorted(ls)", + "def sat(ops: List[str], target=2021, nums=[4, 6, 2, 1, 1, 3, 9]):\n assert len(ops) == len(set(ops)) and set(ops) == {\"**\", \"*\", \"+\", \"-\", \"//\", \"%\"}\n expr = str(nums[0])\n for n, op in zip(nums[1:], ops):\n expr += op + str(n)\n return eval(expr) == target", + "def sat(summands: List[int], n=1234567890):\n return sum(summands) == n and min(summands) > 0 and len(summands) == 4 and all(s % 2 == 0 for s in summands)", + "def sat(ab: List[int], s=\"3298832990329923299432996329983300033002\"):\n return abs(ab[0] - ab[1]) > 4 and s == \"\".join(str(i) for i in range(min(ab), max(ab) + 1) if i % 2 == 0)", + "def sat(n: int, evens=17, odds=3):\n for c in str(n):\n if int(c) % 2 == 0:\n evens -= 1\n else:\n odds -= 1\n return evens == 0 and odds == 0", + "def sat(even_odd_sum: int, nums=[2341, 125146894, 12521, -12451293476325, 535284623934, 132974693614350]):\n for i in nums[1::2]:\n if i % 2 == 0:\n even_odd_sum -= i\n return even_odd_sum == 0", + "def sat(pals: List[int], n=1099, count=49):\n return all(0 <= i <= n and str(i) == str(i)[::-1] and i % 2 == 0 for i in pals) and len(set(pals)) >= count", + "def sat(evens: List[str], words=['The', 'worm', 'ate', 'a', 'bird', 'imagine', 'that', '!', 'Absurd', '!!']):\n lens = [len(w) for w in evens]\n assert all(lens[i] % 2 == 0 and lens[i] == max(lens[:i + 1]) and w in words for i, w in enumerate(evens))\n return all((len(w) % 2 == 1 or w in evens) for w in words)", + "def sat(orig: str, target=\"-Hello,_world!__This_is-so-easy!-\"):\n assert \"_\" not in orig and \"-\" not in orig\n new = \"\"\n space_count = 0\n for c in orig:\n if c == \" \":\n space_count += 1\n else:\n new += (\"-\" if space_count > 2 else \"_\" * space_count)\n new += c\n space_count = 0\n new += (\"-\" if space_count > 2 else \"_\" * space_count)\n return new == target", + "def sat(states: List[List[int]], n=16385):\n assert states[0] == [1] * 5 and all(len(li) == 5 for li in states) and all(i >= 0 for li in states for i in li)\n for prev, cur in zip(states, states[1:]):\n for i in range(5):\n if cur[i] != prev[i]:\n break\n assert cur[i] < prev[i]\n assert (\n cur[i + 1] - prev[i + 1] == 2 * (prev[i] - cur[i]) and cur[i + 2:] == prev[i + 2:] # k decrements\n or\n cur[i:i + 3] == [prev[i] - 1, prev[i + 2], prev[i + 1]] and cur[i + 3:] == prev[i + 3:] # swap\n )\n\n return states[-1][-1] == 2 ** n", + "def sat(p_stop: float, steps=10, target_prob=0.5):\n prob = sum(p_stop*(1-p_stop)**t for t in range(steps))\n return abs(prob - target_prob) < 1e-6", + "def sat(d: int, n=6002685529):\n return n % d == 0 and all(i in \"47\" for i in str(d))", + "def sat(factor: str, s=\"catscatcatscatcatscat\"):\n return len(factor) < len(s) and s == factor * (len(s) // len(factor))", + "def sat(i: int, n=241864633):\n return 1 < i < n and n % i == 0", + "def sat(certificates: List[int], nums=[1449, 14, 21, 105, 217]):\n return all(pow(cert, n - 1, n) > 1 for cert, n in zip(certificates, nums)) and len(certificates) == len(nums)", + "def sat(init: List[int], target=124156):\n a, b, c = init\n for i in range(16):\n a, b, c = b, c, (a + b + c)\n return a == target", + "def sat(init: List[int], target=2021):\n a, b, c, d = init\n for i in range(99):\n a, b, c, d = b, c, d, (a + b + c + d)\n return a == target", + "def sat(nums: List[int], n=1402):\n return nums[0] == nums[1] == 1 and all(nums[i + 2] == nums[i + 1] + nums[i] for i in range(n - 2))", + "def sat(valids: List[str], filenames=['cat.txt', '!jog.dll', '31F9.html', 'Is this okay?.txt', '.exe', '']):\n assert len(valids) == len(filenames)\n for v, f in zip(valids, filenames):\n n_digits = sum(c.isdigit() for c in f)\n if v == \"Yes\":\n prefix, ext = f.split(\".\")\n assert ext in [\"txt\", \"dll\", \"exe\"] and prefix[0].isalpha() and n_digits < 4\n else:\n assert v == \"No\"\n assert f.split(\".\")[1:] not in [['txt'], ['dll'], ['exe']] or not f[0].isalpha() or n_digits > 3\n return True", + "def sat(candidates: List[str], int_indices=[2, 4, 7, 9, 101]):\n for i in int_indices:\n int(candidates[i])\n for i, s in enumerate(candidates):\n if i not in int_indices:\n try:\n int(s)\n return False\n except ValueError:\n pass\n return True", + "def sat(boring: List[str], text=\"This is not boring. I am boring! I am sooo tired.\"):\n sentences = text.replace(\"!\", \".\").replace(\"?\", \".\").split(\".\")\n boring_and_exciting = boring + [s for s in sentences if s.split()[:1] != [\"I\"]]\n return sorted(boring_and_exciting) == sorted(sentences)", + "def sat(pair: List[float], nums=[0.17, 21.3, 5.0, 9.0, 11.0, 4.99, 17.0, 17.0, 12.4, 6.8]):\n a, b = pair\n assert a in nums and b in nums and a != b\n return abs(a - b) == min(x - y for x in nums for y in nums if x > y)", + "def sat(inds: List[int], nums=[0.31, 21.3, 5.0, 9.0, 11.0, 5.01, 17.2]):\n a, b = inds\n assert a != b and a >= 0 and b >= 0\n for i in range(len(nums)):\n for j in range(i):\n assert abs(nums[i] - nums[j]) >= abs(nums[b] - nums[a])\n return True", + "def sat(containers: List[str], strings=['cat', 'dog', 'shatter', 'bear', 'at', 'ta'], substring=\"at\"):\n i = 0\n for s in strings:\n if substring in s:\n assert containers[i] == s\n i += 1\n return i == len(containers)", + "def sat(extensions: List[str], strings=['cat', 'dog', 'shatter', 'donut', 'at', 'todo'], prefix=\"do\"):\n i = 0\n for s in strings:\n if s.startswith(prefix):\n assert extensions[i] == s\n i += 1\n return i == len(extensions)", + "def sat(n: int, s=\"0000101111111000010\", k=5):\n return s[n:n + k] == s[n] * k", + "def sat(positives: List[int], nums=[2, 2342, -2, 32, -8, -5, 2342, 0, -9, 44, 11]):\n stack = positives[::-1]\n for n in nums:\n assert n <= 0 or n == stack.pop()\n return stack == []", + "def sat(vowels: List[str], texts=['Hello, world!', 'Goodbye, world!']):\n for v, t in zip(vowels, texts):\n i = 0\n for j, c in enumerate(t):\n if c.lower() in \"aeiou\" or c.lower() == 'y' and j == len(t) - 1:\n assert v[i] == c\n i += 1\n assert i == len(v)\n return len(vowels) == len(texts)", + "def sat(firsts: List[int], balances=[[2, 7, -2, 4, 3, -15, 10, -45, 3], [3, 4, -17, -1], [100, -100, -101], [-1]]):\n for i, bals in enumerate(balances):\n total = 0\n for b in bals:\n total += b\n if total < 0:\n assert total == firsts[i]\n break\n return True", + "def sat(ans: str, s=\"FlIp ME!\"):\n return len(ans) == len(s) and all({c, d} == {d.upper(), d.lower()} for c, d in zip(ans, s))", + "def sat(x: float, v=523.12892):\n return 0 <= x < 1 and (v - x).is_integer()", + "def sat(n: int):\n i = n ** 17 + 9\n j = (n + 1) ** 17 + 9\n\n while i != 0: # compute gcd using Euclid's algorithm\n (i, j) = (j % i, i)\n\n return n >= 0 and j != 1", + "def sat(x: List[int], a=8, r=2, l=50):\n return x[0] == a and len(x) == l and all([x[i] * r == x[i + 1] for i in range(len(x) - 1)])", + "def sat(grades: List[str], gpas=[2.8, 3.1, 4.0, 2.2, 3.1, 2.5, 0.9]):\n assert len(grades) == len(gpas)\n letters = ['A+', 'A', 'A-', 'B+', 'B', 'B-', 'C+', 'C', 'C-', 'F']\n scores = [4.0, 3.7, 3.4, 3.0, 2.7, 2.4, 2.0, 1.7, 1.4, 0.0]\n for grade, gpa in zip(grades, gpas):\n i = letters.index(grade)\n assert gpa >= scores[i]\n assert i == 0 or gpa <= scores[i - 1]\n return True", + "def sat(h: int, seq=[3, 1, 4, 17, 5, 17, 2, 1, 41, 32, 2, 5, 5, 5, 5]):\n for i in seq:\n assert not (i > 0 and i > h and seq.count(i) >= i)\n return h == -1 or seq.count(h) >= h > 0", + "def sat(li: List[int], orig=[1, 6, 3, 41, 19, 4, 12, 3, 18, 5, -29, 0, 19521]):\n return orig[1::2] == li[1::2] and li[::2] == sorted(orig[::2])", + "def sat(li: List[int], tags=[3, 0, 3, 2, 0, 1, 0, 3, 1, 1, 2, 2, 0, 2, 1, 3]):\n n = max(tags) + 1\n assert sorted(tags) == sorted(list(range(n)) * 4), \"hint: each tag occurs exactly four times\"\n assert len(li) == len(set(li)) and min(li) >= 0\n return sum(li) * 2 == sum(range(4 * n)) and sorted([tags[i] for i in li]) == [i // 2 for i in range(2 * n)]", + "def sat(coords: List[List[float]], sides=[8.9, 10.8, 17.0]):\n assert len(coords) == 3\n sides2 = [((x - x2) ** 2 + (y - y2) ** 2) ** 0.5 for i, (x, y) in enumerate(coords) for x2, y2 in coords[:i]]\n return all(abs(a - b) < 1e-6 for a, b in zip(sorted(sides), sorted(sides2)))", + "def sat(primes: List[bool], n=\"A4D4455214122CE192CCBE3\"):\n return all(primes[i] == (c in \"2357BD\") for i, c in enumerate(n))", + "def sat(zero_sums: List[bool], trips=[[1253532, -3920635, 332], [-24, 18, 6], [0, 5, -5], [1, 1, 1], [-20, 17, 4]]):\n return len(zero_sums) == len(trips) and all(z == ((a + b + c) == 0) for z, (a, b, c) in zip(zero_sums, trips))", + "def sat(x: int, a=324554, b=1345345):\n if a < 50:\n return x + a == b\n else:\n return x - 2 * a == b", + "def sat(x: int, a=9384594, b=1343663):\n if x > 0 and a > 50:\n return x - a == b\n else:\n return x + a == b", + "def sat(violation: List[int], nums=[1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 17, 17, 18, 19, 20, 22, 24]):\n if not violation:\n return all(nums[i] < nums[i + 1] for i in range(len(nums) - 1))\n i, j = violation\n return 0 <= i < j and nums[i] >= nums[j]", + "def sat(n: int, a=10000200001):\n return a == n * n and n < 0", + "def sat(x: int, a=3, n=1290070078170102666248196035845070394933441741644993085810116441344597492642263849):\n return a ** x == n", + "def sat(li: List[int], nums=[12, 23, -2, 5, 0], sep=4):\n return li[::2] == nums and li[1::2] == [sep] * (len(nums) - 1)", + "def sat(nums: List[int], super_factorials=[1, 2, 1]):\n for i, sf in enumerate(super_factorials):\n n = nums[i]\n for j in range(n, 0, -1):\n k = j ** (n - j + 1)\n assert sf % k == 0, f\"{i} {sf} {j} {n}\"\n sf //= k\n assert sf == 1\n return True", + "def sat(problem: int, weights=[1, 2, 5, 2, 1, 17], max_weight=100):\n if problem == -1:\n return sum(weights) > max_weight\n return weights[problem] != weights[- 1 - problem]", + "def sat(b: bool, n=10):\n i = 0\n while i <= n:\n if i + i == n:\n return b == True\n i += 1\n return b == False", + "def sat(seq: List[int], compressed_len=17, text=\"Hellooooooooooooooooooooo world!\"):\n index = [chr(i) for i in range(256)]\n pieces = [\"\"]\n for i in seq:\n pieces.append((pieces[-1] + pieces[-1][0]) if i == len(index) else index[i])\n index.append(pieces[-2] + pieces[-1][0])\n return \"\".join(pieces) == text and len(seq) <= compressed_len", + "def sat(d: int, n=123456):\n return n % d == 0 and d < n and all(n % e for e in range(d + 1, n))", + "def sat(extremes: List[int], nums=[-10, -4, 100, -40, 2, 2, 3, 17, -50, -25, 18, 41, 9, 11, 15]):\n neg, pos = extremes\n if neg == 0:\n assert nums == [] or min(nums) >= 0\n else:\n assert neg < 0 and neg in nums and all(n >= 0 or n <= neg for n in nums)\n if pos == 0:\n assert nums == [] or max(nums) <= 0\n else:\n assert pos > 0 and pos in nums and all(n <= 0 or n >= pos for n in nums)\n return True", + "def sat(ans: List[int], nums=[23, 17, 201, 14, 10473, 43225, 421, 423, 11, 10, 2022, 342157]):\n i, digit_sum = ans\n n = nums[i]\n\n def is_prime(n):\n return n > 1 and all(n % j for j in range(2, int(n ** 0.5) + 1))\n\n return is_prime(n) and all(m <= n for m in nums if is_prime(m)) and digit_sum == sum(int(c) for c in str(n))", + "def sat(p: int, n=101076):\n\n def is_prime(m):\n return all(m % i for i in range(2, m - 1))\n\n return is_prime(p) and n % p == 0 and p > 0 and all(n % i or not is_prime(i) for i in range(p + 1, n))", + "def sat(x: float, str_nums=['1,3', '-11', '17.5', '-11', '2', '2.2', '2,2', '4', '-18,18', '99.09']):\n found = False\n for s in str_nums:\n y = float(s.replace(\",\", \".\"))\n assert y <= x\n if y == x:\n found = True\n return found", + "def sat(y: List[bool], x=['Hello, world!', 'cat', '', 'a test', 'test a', 'i e', 'o', 'I O U', 'You and I']):\n assert len(x) == len(y)\n for s, b in zip(x, y):\n if len(s.split(\" \")[-1]) == 1:\n assert b == s[-1].isalpha()\n else:\n assert not b\n return True", + "def sat(path: List[int], k=10, edges=[[2, 4], [3], [4, 1], [4], [0]]):\n\n def check(prefix):\n for i, j in zip(path, prefix):\n if i != j:\n return i < j\n return len(prefix) >= k or all(check(prefix + [i]) for i in edges[prefix[-1]])\n\n return all(path[i] in edges[path[i - 1]] for i in range(1, k)) and all(check([i]) for i in range(len(edges)))", + "def sat(new_list: List[int], old_list=[321, 12, 532, 129, 9, -12, 4, 56, 90, 0]):\n return [i - 1 for i in new_list] == old_list", + "def sat(item: int, li=[17, 2, 3, 9, 11, 11], index=4):\n return li.index(item) == index", + "def sat(li: List[int], i=29, index=10412):\n return li.index(i) == index", + "def sat(li: List[int], n=85012):\n return len(li) == n", + "def sat(exp_poly: List[int], d=74152093423, poly=[1, 6, 3, 1, 0, 4, 4]):\n p = len(poly)\n assert p > 2 and all(p % i for i in range(2, p)), \"Hint: p is a prime > 2\"\n\n def val(coeffs, n): # evaluate polynomial mod p\n return sum(c * pow(n, i, p) for i, c in enumerate(coeffs)) % p\n\n return all(val(exp_poly, n) == pow(val(poly, n), d, p) for n in range(p))", + "def sat(tot: int, k=5, nums=[1252, 125273523, 0, 42, 100, 214532, 2, 0, 11, 14]):\n for n in nums[:k]:\n if len(str(abs(n))) > 2:\n tot -= n\n return tot == 0", + "def sat(x: List[int], length=13, s=\"Dynamic programming solves this puzzle!!!\"):\n return all(s[x[i]] <= s[x[i + 1]] and x[i + 1] > x[i] >= 0 for i in range(length - 1))", + "def sat(x: List[int], length=20, s=\"Dynamic programming solves this classic job-interview puzzle!!!\"):\n return all(s[x[i]] <= s[x[i + 1]] and x[i + 1] > x[i] for i in range(length - 1))", + "def sat(ans: str, words=['these', 'are', 'some', 'pretty', 'long', 'words']):\n return ans in words and all(len(ans) >= len(w) for w in words)", + "def sat(transcripts: List[str], max_moves=10):\n COLORS = \"ABCDEF\"\n\n def helper(secret: str, transcript=\"\"):\n if transcript.count(\"\\n\") == max_moves:\n return False\n guess = min([t for t in transcripts if t.startswith(transcript)], key=len)[-4:]\n if guess == secret:\n return True\n assert all(g in COLORS for g in guess)\n perfect = {c: sum([g == s == c for g, s in zip(guess, secret)]) for c in COLORS}\n almost = sum(min(guess.count(c), secret.count(c)) - perfect[c] for c in COLORS)\n return helper(secret, transcript + f\"{guess} {sum(perfect.values())}{almost}\\n\")\n\n return all(helper(r + s + t + u) for r in COLORS for s in COLORS for t in COLORS for u in COLORS)", + "def sat(m: int, hello=[1, 31, 3, 2, 0, 18, 32, -4, 2, -1000, 3502145, 3502145, 21, 18, 2, 60]):\n return m in hello and not any(m < i for i in hello)", + "def sat(x: int, nums=[132666041, 237412, 28141, -12, 11939, 912414, 17], upper=133658965):\n dev = sum(n - x for n in nums)\n return dev <= upper", + "def sat(taken: List[int], val_counts=[[4, 3], [5, 2], [9, 3], [13, 13], [8, 11], [56, 1]], upper=11):\n advantage = 0\n assert len(taken) == len(val_counts) and sum(taken) <= upper\n for i, (val, count) in zip(taken, val_counts):\n assert 0 <= i <= count\n advantage += val * i - val * count / 2\n return advantage > 0", + "def sat(x: float, nums=[12, -2, 14, 3, -15, 10, -45, 3, 30]):\n return sum((n - x) ** 2 for n in nums) * len(nums) <= sum((m - n) ** 2 for m in nums for n in nums) * .5 + 1e-4", + "def sat(bananas: int, bowl=\"5024 apples and 12189 oranges\", total=12491241):\n bowl += f\" and {bananas} bananas\"\n return sum([int(s) for s in bowl.split() if s.isdigit()]) == total", + "def sat(direction: str, nums=[2, 4, 17, 29, 31, 1000, 416629]):\n if direction == \"increasing\":\n return all(nums[i] < nums[i + 1] for i in range(len(nums) - 1))\n if direction == \"decreasing\":\n return all(nums[i + 1] < nums[i] for i in range(len(nums) - 1))", + "def sat(squares: List[List[int]], m=9, n=9):\n k = min(m, n)\n assert all(i in range(m) and j in range(n) for i, j in squares), \"queen off board\"\n assert len(squares) == k, \"Wrong number of queens\"\n assert len({i for i, j in squares}) == k, \"Queens on same row\"\n assert len({j for i, j in squares}) == k, \"Queens on same file\"\n assert len({i + j for i, j in squares}) == k, \"Queens on same SE diagonal\"\n assert len({i - j for i, j in squares}) == k, \"Queens on same NE diagonal\"\n return True", + "def sat(s: str, pool=['cat', 'catatatatctsa', 'abcdefhijklmnop', '124259239185125', '', 'foo', 'unique']):\n assert s in pool\n n = len(set(s))\n for p in pool:\n assert len(set(p)) <= n\n return True", + "def sat(seq: List[int], target=[1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0], n_steps=4):\n s = seq[:] # copy\n for step in range(n_steps):\n for i in range(len(seq) - 1):\n if (s[i], s[i + 1]) == (0, 1):\n (s[i], s[i + 1]) = (1, 0)\n return s == target", + "def sat(li: List[str], lists=[['this', 'list', 'is', 'narrow'], ['I', 'am', 'shorter but wider']]):\n width = sum(len(s) for s in li)\n for li2 in lists:\n assert width <= sum(len(s) for s in li2)\n return li in lists", + "def sat(indices: List[int], s=\"I am an unhappy string!\"):\n i, j = indices\n return s[i] == s[j] and 0 <= i < j < i + 3", + "def sat(count: int, n=981):\n for i in range(n):\n for j in range(n):\n count -= 1\n return count == 0", + "def sat(different: str, d={'cat': 'CAT', 'tree': 'T', 'pick me': 'not', 'OK': 'red', 'blah': 'blah', 'z': 'Z'}):\n return different in d and all(k.islower() != different.islower() for k in d if k != different)", + "def sat(odds: List[int], n=1243272912731):\n num_odds = 0\n while True:\n if n % 2 == 1:\n num_odds += 1\n if n not in odds:\n return False\n if n <= 1:\n return num_odds == len(odds)\n n = (3 * n + 1) if n % 2 == 1 else n // 2", + "def sat(root: float, coeffs=[1, 2, 3, 17]):\n return abs(sum(coeff * (root ** i) for i, coeff in enumerate(coeffs))) < 1e-4", + "def sat(prod: int, n=14235764939971075543215213):\n\n for c in str(n):\n i = int(c)\n if i % 2 == 1:\n assert prod % i == 0\n prod //= i\n return prod == any(int(c) % 2 for c in str(n))", + "def sat(nums: List[int], n=5):\n count = 18 * (10 ** (n - 2)) if n > 1 else 1\n strs = {str(n) for n in nums}\n return len(strs) == count and all(s.startswith(\"1\") or s.endswith(\"1\") and len(s) == n for s in strs)", + "def sat(indices: List[int], H=60, alpha=18, beta=2, xs=[0, 10, 20, 30, 50, 80, 100, 120, 160, 190, 200], ys=[0, 30, 10, 30, 50, 40, 10, 20, 20, 55, 10], thresh=26020):\n assert sorted({0, len(xs) - 1, *indices}) == indices, f\"Ans. should be sorted list [0, ..., {len(xs) - 1}]\"\n cost = alpha * (H - ys[0])\n for i, j in zip(indices, indices[1:]):\n a, b, r = xs[i], xs[j], (xs[j] - xs[i]) / 2\n assert max(ys[i], ys[j]) + r <= H, \"Bridge too tall\"\n assert all(ys[k] <= H - r + ((b - xs[k]) * (xs[k] - a)) ** 0.5 for k in range(i + 1, j)), \\\n \"Bridge too short\"\n cost += alpha * (H - ys[j]) + beta * (b - a) ** 2\n return cost <= thresh", + "def sat(init: List[List[int]], period=3):\n target = {x + y * 1j for x, y in init} # complex numbers encode live cells\n\n deltas = (1j, -1j, 1, -1, 1 + 1j, 1 - 1j, -1 + 1j, -1 - 1j)\n live = target\n for t in range(period):\n visible = {z + d for z in live for d in deltas}\n live = {z for z in visible if sum(z + d in live for d in deltas) in ([2, 3] if z in live else [3])}\n if live == target:\n return t + 1 == period", + "def sat(ans: List[int], s=\"Bananannanaannanaanananananana\", sub=\"anan\", count=7):\n return all(sub == s[i:i + len(sub)] and i >= 0 for i in ans) and len(set(ans)) >= count", + "def sat(words: List[str], num=100, bits=100, dist=34):\n assert len(words) == num and all(len(word) == bits and set(word) <= {\"0\", \"1\"} for word in words)\n return all(sum([a != b for a, b in zip(words[i], words[j])]) >= dist for i in range(num) for j in range(i))", + "def sat(inds: List[int], nums=[12, -10452, 18242, 10440, 81, 241, 525, -18242, 91, 20]):\n a, b = inds\n return nums[a] + nums[b] == 0 and a >= 0 and b >= 0", + "def sat(pals: List[bool], strs=['palindrome', 'madamimadam', '', 'foo', 'eyes', '(-:-)']):\n return all(pals[i] == (s == s[::-1]) for i, s in enumerate(strs))", + "def sat(ans: str, s=\"so easy\", length=20):\n return ans == ans[::-1] and len(ans) == length and s in ans", + "def sat(matches: List[int], parens=\"((())()(()()))(())\"):\n for i, (j, c) in enumerate(zip(matches, parens)):\n assert parens[j] != c and matches[j] == i and all(i < matches[k] < j for k in range(i + 1, j))\n return len(matches) == len(parens)", + "def sat(perm: str, s=\"))( )()()() )))(( ))))((( )))))(((( ))))))))((((((( ))))))((((( )))))))(((((( )))))))))((((((( ((((((((((\"):\n assert sorted(perm.split()) == sorted(s.split()), \"Must be a permutation of the space-delimited 'groups'\"\n return all(perm[:i].count(\"(\") >= perm[:i].count(\")\") for i in range(len(perm)))", + "def sat(swaps: List[List[int]], nums1=[1, 3, 2, 4, 5, 8, 7, 11], nums2=[0, 7, 0, 8, 19, 4, 41, 43, 42]):\n copy1 = nums1[:]\n copy2 = nums2[:]\n for i, j in swaps:\n copy1[i], copy2[j] = copy2[j], copy1[i]\n return all(n % 2 == 0 for n in copy1)", + "def sat(beats: List[int], score=\"o o o| o| .| .| .| o| o| o o o| .|\"):\n return \" \".join({1: '.|', 2: 'o|', 4: 'o'}[b] for b in beats) == score", + "def sat(keep: List[bool], heights=[10, 2, 14, 1, 8, 19, 16, 6, 12, 3, 17, 0, 9, 18, 5, 7, 11, 13, 15, 4]):\n n = int(len(heights) ** 0.5)\n assert sorted(heights) == list(range(n * n + n)), \"hint: heights is a permutation of range(n * n + n)\"\n kept = [i for i, k in zip(heights, keep) if k]\n assert len(kept) == 2 * n, \"must keep 2n items\"\n pi = sorted(range(2 * n), key=lambda i: kept[i]) # the sort indices\n return all(abs(pi[2 * i] - pi[2 * i + 1]) == 1 for i in range(n))", + "def sat(planets_between: List[str], a=\"Mars\", b=\"Neptune\"):\n assert \" \" not in \"\".join(planets_between)\n return \" \".join([a] + planets_between + [b]) in \"Venus Earth Mars Jupiter Saturn Uranus Neptune Pluto\"", + "def sat(nodes: List[int], size=3, edges=[[0, 17], [0, 22], [17, 22], [17, 31], [22, 31], [31, 17]]):\n assert len(nodes) == len(set(nodes)) >= size\n edge_set = {(a, b) for (a, b) in edges}\n for a in nodes:\n for b in nodes:\n assert a == b or (a, b) in edge_set or (b, a) in edge_set\n\n return True", + "def sat(pos: List[int], nums=[-804, 9124, -945, 2410, 0, 21, -123]):\n for n in pos + nums:\n s = str(n)\n if int(s[:2]) + sum(int(c) for c in s[2:]) <= 0:\n assert n not in pos\n else:\n assert pos.count(n) == nums.count(n)\n return True", + "def sat(factors: List[int], n=123456, num_factors=8):\n assert len(factors) == num_factors\n prod = 1\n for d in factors:\n prod *= d\n assert d > 1\n return prod == n", + "def sat(n: int, lower=123456):\n assert any((i ** 0.5).is_integer() for i in [5 * n * n - 4, 5 * n * n + 4]), \"n must be a Fibonacci number\"\n assert all(n % i for i in range(2, int(n ** 0.5) + 1)), \"n must be prime\"\n return n > lower", + "def sat(interval2: List[int], interval1=[32157, 93210127]):\n intersection_width = min(interval1[1], interval2[1]) - max(interval1[0], interval2[0])\n return intersection_width > 1 and all(intersection_width % i for i in range(2, intersection_width))", + "def sat(neighbors: List[int], nums=[14, 7, 11, 13, 7, 4, 19, 2, 55, 13, 31, 14, 2, 9, -7, 0, 88, 13, 13]):\n\n def prime(m):\n return all(m % i for i in range(2, m - 1))\n\n goods = set()\n for i, n in enumerate(nums):\n if (i > 0 and prime(nums[i - 1])) or (i < len(nums) - 1 and prime(nums[i + 1])):\n goods.add(n)\n\n return set(neighbors) == goods and all(n == min(neighbors[i:]) for i, n in enumerate(neighbors))", + "def sat(primes: str, s=\"This is a test of whether you would want to do such strange puzzles\"):\n\n def is_prime(n):\n return n > 1 and all(n % j for j in range(2, int(n ** 0.5) + 1))\n\n prime_words = primes.split()\n i = 0\n for word in s.split():\n if is_prime(len(word)):\n assert prime_words[i] == word\n i += 1\n\n return i == len(prime_words)", + "def sat(primes: List[int], n=1234):\n assert all(1 < p for p in primes) and all(p % q for p in primes for q in primes if q < p)\n return len({i for p in primes for i in range(p, n, p)}) == max(n - 2, 0)", + "def sat(n: int, arr=[1, 7, -20052, 14, -3, -11, 1025235, 14]):\n tot = 0\n\n for i in arr:\n if tot >= 0:\n tot += abs(i)\n else:\n tot -= abs(i)\n if i < 0:\n tot = -tot\n elif i == 0:\n tot = 0\n break\n\n return n == tot", + "def sat(triples: List[List[int]], n=920, m=799):\n for a, b, c in triples:\n if not (a * a + b * b == c * c and 0 < a < b < c <= n):\n return False\n return triples == sorted(triples) and len(triples) >= m", + "def sat(quine: str):\n return eval(quine) == quine", + "def sat(ans: List[float], nums=[13.0, 17.0, 17.0, 15.5, 2.94]):\n assert min(ans) == 0.0 and max(ans) == 1.0\n a = min(nums)\n b = max(nums)\n for i in range(len(nums)):\n x = a + (b - a) * ans[i]\n assert abs(nums[i] - x) < 1e-6\n return True", + "def sat(rev_quine: str):\n return eval(rev_quine[::-1]) == rev_quine", + "def sat(maxes: List[int], nums=[1, 4, 3, -6, 19]):\n assert len(maxes) == len(nums)\n for i in range(len(nums)):\n if i > 0:\n assert maxes[i] == max(maxes[i - 1], nums[i])\n else:\n assert maxes[0] == nums[0]\n return True", + "def sat(roman: str, n=2414):\n key = {1000: 'm', 900: 'cm', 500: 'd', 400: 'cd',\n 100: 'c', 90: 'xc', 50: 'l', 40: 'xl',\n 10: 'x', 9: 'ix', 5: 'v', 4: 'iv',\n 1: 'i'}\n m = 0\n for base in [1000, 100, 10, 1]:\n for mul in [9, 4, 5, 1, 1, 1]: # up to three 1's, move on after 9 or 4\n val = base * mul\n if val in key and roman.startswith(key[val]):\n m += val\n roman = roman[len(key[val]):]\n if mul == 9 or mul == 4: # 9 or 4 can't be followed by anything else\n break\n return m == n", + "def sat(original: List[int], arr=[2, 3, -1, -1, 0, 1, 1]):\n assert str(original)[1:-1] in str(sorted(original) * 2), \"Not ring sorted\"\n return any(original == arr[:i] + arr[i + 1:] for i in range(len(arr) + 1))", + "def sat(r: str, s=\"light star\", t=\"I love to look at the starlight!\"):\n return r in t and len(r) == len(s) and r in s + s", + "def sat(n: int, nums=[17, -1023589211, -293485382500, 31, -293485382500, 105762, 94328103589]):\n assert n in nums\n return len({i for i in nums if i <= n}) == 2", + "def sat(ls: List[str], combined=\"() (()) ((() () ())) (() )\"):\n for s in ls:\n assert s.count(\"(\") == s.count(\")\")\n assert all(s[:i].count(\"(\") > s[:i].count(\")\") for i in range(1, len(s))) # s is not further divisible\n return ''.join(ls) == combined.replace(' ', '')", + "def sat(li: List[List[int]], n=19723, lower=1000):\n assert len({(i, j) for i, j in li}) >= lower, \"not enough 7's (ignoring duplicates)\"\n return all(str(i)[j] == '7' and (i % 11 == 0 or i % 13 == 0) and 0 <= i < n and 0 <= j for i, j in li)", + "def sat(orig: str, result=\"Hello, world!\", shift=7):\n n = len(result)\n assert len(orig) == n\n return all(ord(orig[i]) + shift == ord(result[i]) for i in range(n))", + "def sat(li: List[int]):\n return all(j in {i - 1, i + 1, 3 * i} for i, j in zip([0] + li, li + [128])) and len(li) == 9", + "def sat(li: List[int], n=149432, upper=14943):\n return len(li) <= upper and all(abs(a - b) <= 10 for a, b in zip([1] + li, li + [n]))", + "def sat(z: str, x=\"-8142432/763083\", y=\"66/-13474\", max_len=18):\n [[a, b], [c, d], [u, v]] = [[int(n) for n in s.split(\"/\")] for s in [x, y, z]]\n return a * c * v == b * d * u and len(z) <= max_len", + "def sat(s: str, matrix=[[0, 0, 0, 0, 0], [0, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], max_moves=3):\n matrix = [m[:] for m in matrix] # copy\n for c in s:\n if c in \"01234\":\n i = \"01234\".index(c)\n matrix[i], matrix[i + 1] = matrix[i + 1], matrix[i]\n if c in \"abcde\":\n j = \"abcde\".index(c)\n for row in matrix:\n row[j], row[j + 1] = row[j + 1], row[j]\n\n return len(s) <= max_moves and matrix[2][2] == 1", + "def sat(moves: List[int], start=[[5, 0, 2, 3], [1, 9, 6, 7], [4, 14, 8, 11], [12, 13, 10, 15]]):\n\n locs = {i: [x, y] for y, row in enumerate(start) for x, i in enumerate(row)} # locations, 0 stands for blank\n for i in moves:\n assert abs(locs[0][0] - locs[i][0]) + abs(locs[0][1] - locs[i][1]) == 1\n locs[0], locs[i] = locs[i], locs[0]\n return all(locs[i] == [i % len(start[0]), i // len(start)] for i in locs)", + "def sat(n: int, b=2, target=5):\n return (b ** n) % n == target", + "def sat(val_index: List[int], nums=[125123, 422323, 141, 5325, 812152, 9, 42145, 5313, 421, 812152]):\n if val_index == []:\n return all(n % 2 == 1 for n in nums)\n v, i = val_index\n assert v % 2 == 0 and nums[i] == v\n return all(n > v or n % 2 == 1 for n in nums[:i]) and all(n >= v or n % 2 == 1 for n in nums[i:])", + "def sat(ordered: List[int], nums=[1, 0, -1, -100, 10, 14, 235251, 11, 10000, 2000001, -155]):\n digit_sums = [sum(int(c) for c in str(n) if c != \"-\") for n in ordered]\n return sorted(ordered) == sorted(nums) and digit_sums == sorted(digit_sums)", + "def sat(ans: str, s=\"six one four three two nine eight\"):\n nums = 'zero one two three four five six seven eight nine'.split()\n return [nums.index(x) for x in ans.split(\" \")] == sorted([nums.index(x) for x in s.split(\" \")])", + "def sat(s: str, inp=\"1+1+3+1+3+2+2+1+3+1+2\"):\n return all(s.count(c) == inp.count(c) for c in inp + s) and all(s[i - 2] <= s[i] for i in range(2, len(s), 2))", + "def sat(sub: List[int], nums=[17, 20, -100, 101, 423258, 19949, 0, 20174, 9351773, -11]):\n for i in range(len(sub)):\n n = sub[i]\n assert n == min(sub[i:])\n assert all(int(c) % 2 for c in str(abs(n))) # all odd digits\n assert sub.count(n) == nums.count(n)\n\n for n in nums:\n if n not in sub:\n assert any(int(c) % 2 == 0 for c in str(abs(n)))\n\n return True", + "def sat(ans: str, n=15):\n return [int(i) for i in ans.split(' ')] == list(range(n + 1))", + "def sat(inds: List[int], string=\"Sssuubbstrissiingg\"):\n return inds == sorted(inds) and \"\".join(string[i] for i in inds) == \"substring\"", + "def sat(li: List[int], n=909):\n return li[0] == n and len(li) == n and all(b - a == 2 for a, b in zip(li, li[1:]))", + "def sat(i: int, s=\"cat\", target=\"a\"):\n return s[i] == target", + "def sat(lengths: List[int], strs=['pneumonoultramicroscopicsilicovolcanoconiosis', ' ', 'foo', '2.5']):\n for length, s in zip(lengths, strs):\n try:\n s[length]\n return False\n except IndexError:\n s[length - 1]\n return len(lengths) == len(strs)", + "def sat(s: str, target=\"foofoofoofoo\", n=2):\n return s * n == target", + "def sat(n: int, target=\"foofoofoofoo\", s=\"foofoo\"):\n return s * n == target", + "def sat(i: int, s=\"cat\", target=\"a\"):\n return s[i] == target and i < 0", + "def sat(s: str, dups=2021):\n return len(set(s)) == len(s) - dups", + "def sat(x: str, parts=['I', 'love', 'dumplings', '!'], length=100):\n return len(x) == length and x.split() == parts", + "def sat(x: str, parts=['I', 'love', 'dumplings', '!', ''], string=\"I_love_dumplings_!_\"):\n return string.split(x) == parts", + "def sat(lst: List[str], s=\"Hello, world!\"):\n if \" \" in s:\n return \" \".join(lst) == s\n if \",\" in s:\n return \",\".join(lst) == s\n return \"\".join(lst) == \"\".join(c for c in s if c.islower() and ord(c) % 2 == 0)", + "def sat(s: str):\n return s in str(8 ** 1818) and s == s[::-1] and len(s) > 11", + "def sat(li: List[int]):\n return all(i + j == 9 for i, j in zip([4] + li, li)) and len(li) == 1000", + "def sat(li: List[int]):\n return all([sum(li[:i]) == i for i in range(20)])", + "def sat(s: str):\n return float(s) + len(s) == 4.5", + "def sat(ls: List[str]):\n return [s + t for s in ls for t in ls if s != t] == 'berlin berger linber linger gerber gerlin'.split()", + "def sat(s: str):\n return s.count('o') == 1000 and s.count('oo') == 100 and s.count('ho') == 801", + "def sat(li: List[int]):\n return all(j in {i - 1, i + 1, 3 * i} for i, j in zip([0] + li, li + [128]))", + "def sat(s: str):\n return s[::2] in s and len(set(s)) == 5", + "def sat(li: List[int]):\n return li.count(17) == 3 and li.count(3) >= 2", + "def sat(ls: List[str]):\n return \"\".join(ls) == str(8 ** 88) and all(len(s) == 8 for s in ls)", + "def sat(li: List[int]):\n return all(i in range(1000) and abs(i - j) >= 10 for i in li for j in li if i != j) and len(set(li)) == 100", + "def sat(li: List[int]):\n return all([123 * li[i] % 1000 < 123 * li[i + 1] % 1000 and li[i] in range(1000) for i in range(20)])", + "def sat(li: List[int]):\n return len(li) == 10 and li.count(li[3]) == 2", + "def sat(i: int):\n return i % 123 == 4 and i > 10 ** 10", + "def sat(ls: List[str]):\n return ls[1234] in ls[1235] and ls[1234] != ls[1235]", + "def sat(encrypted: str, orig=\"Hello, world!\"):\n assert len(encrypted) == len(orig)\n return all(chr(ord(a) - 2 * 2) == b for a, b in zip(encrypted, orig))", + "def sat(x: str, puz=\"____9_2___7__________1_8_4____2_78____4_____1____69____2_8___5__6__3_7___49______\"):\n assert all(c == \"_\" or c == s for (c, s) in zip(puz, x))\n\n full = set('123456789')\n for i in range(9):\n assert {x[i] for i in range(9 * i, 9 * i + 9)} == full, \"invalid row\"\n assert {x[i] for i in range(i, i + 81, 9)} == full, \"invalid column\"\n assert {x[9 * a + b + i + 26 * (i % 3)] for a in range(3) for b in range(3)} == full, \"invalid square\"\n\n return True", + "def sat(nums: List[int], tot=14, prod=99):\n assert sum(nums) == tot\n p = 1\n for n in nums:\n p *= n\n return p == prod", + "def sat(nums: List[int], target=983):\n assert target % 9 not in [4, 5], \"Hint\"\n return len(nums) == 3 and sum([i ** 3 for i in nums]) == target", + "def sat(s: str, target=\"Hello world\"):\n\n def cycle3(trip):\n return trip if len(trip) != 3 else trip[2] + trip[:2]\n\n return target == \"\".join(cycle3(s[i: i + 3]) for i in range(0, len(s), 3))", + "def sat(factors: List[List[int]]):\n primes = set(range(2, 1000))\n for n in range(2, 1000):\n if n in primes:\n primes.difference_update(range(2 * n, 1000, n))\n assert all(p in primes for f in factors for p in f), \"all factors must be prime\"\n nums = {p * q * r for p, q, r in factors}\n return max(nums) < 1000 and len(nums) == 247", + "def sat(trips: List[List[int]], a=[1, 0, -17, 42, 321, 36, 429, 35, 10, 923, 35, 18, 0, 17, 24, 32, 8], count=221):\n assert len({tuple(t) for t in trips}) >= count\n return all(0 <= i < j < k and (a[i] + a[j] + a[k]) % 3 == 0 for i, j, k in trips)", + "def sat(good_boards: List[str]):\n board_bit_reps = {tuple(sum(1 << i for i in range(9) if b[i] == c) for c in \"XO\") for b in good_boards}\n win = [any(i & w == w for w in [7, 56, 73, 84, 146, 273, 292, 448]) for i in range(512)]\n\n def tie(x, o): # returns True if O has a forced tie/win. It's O's turn to move.\n if o | x != 511: # complete board\n o |= 1 << [i for i in range(9) if (x, o | (1 << i)) in board_bit_reps][0]\n return not win[x] and (win[o] or all((x | o) & (1 << i) or tie(x | (1 << i), o) for i in range(9)))\n\n return all(tie(1 << i, 0) for i in range(9))", + "def sat(good_boards: List[str]):\n board_bit_reps = {tuple(sum(1 << i for i in range(9) if b[i] == c) for c in \"XO\") for b in good_boards}\n win = [any(i & w == w for w in [7, 56, 73, 84, 146, 273, 292, 448]) for i in range(512)]\n\n def tie(x, o): # returns True if X has a forced tie/win assuming it's X's turn to move.\n x |= 1 << [i for i in range(9) if (x | (1 << i), o) in board_bit_reps][0]\n return not win[o] and (win[x] or all((x | o) & (1 << i) or tie(x, o | (1 << i)) for i in range(9)))\n\n return tie(0, 0)", + "def sat(moves: List[List[int]]):\n rods = ([8, 7, 6, 5, 4, 3, 2, 1], [], [])\n for [i, j] in moves:\n rods[j].append(rods[i].pop())\n assert rods[j][-1] == min(rods[j]), \"larger disk on top of smaller disk\"\n return rods[0] == rods[1] == []", + "def sat(moves: List[List[int]], source=[[0, 7], [4, 5, 6], [1, 2, 3, 8]], target=[[0, 1, 2, 3, 8], [4, 5], [6, 7]]):\n state = [s[:] for s in source]\n\n for [i, j] in moves:\n state[j].append(state[i].pop())\n assert state[j] == sorted(state[j])\n\n return state == target", + "def sat(height: int, area=1319098728582, base=45126):\n return base * height == 2 * area", + "def sat(seq: List[int], length=181):\n return all(seq[n] == (seq[n - 1] + seq[n - 2] + seq[n + 1] if n % 2 else 1 + n // 2) for n in range(length))", + "def sat(inds: List[int], nums=[12, 6, 41, 15, -10452, 18242, 10440, 6, 6, 6, 6]):\n return len(inds) == 3 and sum(nums[i] for i in inds) == 0", + "def sat(li: List[int], orig=[1, -2, 3, 17, 8, 4, 12, 3, 18, 5, -29, 0, 0]):\n assert orig[::3] == li[::3], \"Keep every third entry fixed\"\n assert sorted(li) == sorted(orig), \"Not even a permutation\"\n assert all(li[i] <= li[i + 1] for i in range(1, len(li) - 1, 3))\n assert all(li[i] <= li[i + 2] for i in range(2, len(li) - 2, 3))\n return True", + "def sat(indices: List[List[int]], uneven=[[1, 3, 2, 32, 17], [17, 2, 48, 17], [], [9, 35, 4], [3, 17]], target=17):\n for i, j in indices:\n assert uneven[i][j] == target\n for i, row in enumerate(uneven):\n for j, n in enumerate(row):\n assert n != target or [i, j] in indices\n return True", + "def sat(li: List[int], orig=[1, 1, 3, 2, 0, 8, 32, -4, 0]):\n for i in range(len(li) - 1):\n assert li[i] < li[i + 1]\n assert li[i] in orig\n for n in orig:\n assert n in li\n return True", + "def sat(prod: int, nums=[17, 24, 39, 15, 11, 201, 97, 65, 18]):\n if not all(nums):\n return prod == 0\n for n in nums:\n k = abs(n % 10)\n if k == 0:\n return prod == 0\n assert prod % k == 0\n prod //= k\n return prod == 1", + "def sat(up_down: List[int], nums=[17, 2, 3, 523, 18, -2, 0, 2, -1]):\n assert all(up_down.count(i) == nums.count(i) for i in set(up_down + nums)), \"not a reordering\"\n increasing_sign = 1 if ((nums[0] + nums[-1]) % 2 == 1) else -1\n return all((up_down[i + 1] - up_down[i]) * increasing_sign >= 0 for i in range(len(up_down) - 1))", + "def sat(positions: List[int], s=\"ThIs is A tEsT, Or *IS* iT?\"):\n assert all(s[i] in \"AEIOU\" for i in positions)\n return all(i in positions or c not in \"AEIOU\" or i % 2 == 1 for i, c in enumerate(s))", + "def sat(valid: str, s=\"]]]]]]]]]]]]]]]]][][][][]]]]]]]]]]][[[][[][[[[[][][][]][[[[[[[[[[[[[[[[[[\"):\n assert valid in s\n depths = [0]\n for c in valid:\n if c == \"[\":\n depths.append(depths[-1] + 1)\n elif c == \"]\":\n depths.append(depths[-1] - 1)\n return depths[-1] == 0 and min(depths) == 0 and max(depths) > 1", + "def sat(ham: str, s=\"Any vowel is OK\"):\n vows = \"aeiou\"\n cons = \"bcdfghjklmnpqrstvwxz\"\n return ham in s and ham[0].lower() in cons and ham[1].lower() in vows and ham[2].lower() in cons", + "def sat(s: str, target=\"Hello, world!\"):\n subs = {ord(c): ord(c) + 2 for c in \"aeiouAEIOU\"}\n return s.swapcase() == target.translate(subs)", + "def sat(strange: List[int], li=[30, 12, 42, 717, 45, 317, 200, -1, 491, 32, 15]):\n assert sorted(strange) == sorted(li), \"Must be a permutation\"\n return all(n == (min, max)[i % 2](strange[i:]) for i, n in enumerate(strange))", + "def sat(edges: List[List[int]], z=20, n=5, t=3):\n from itertools import combinations\n edges = {(a, b) for a, b in edges if a in range(n) and b in range(n)} # convert to a set for efficiency\n assert len(edges) >= z\n\n return all(\n any((a, b) not in edges for a in left for b in right)\n for left in combinations(range(n), t)\n for right in combinations(range(n), t)\n )", + "def sat(strategies: List[List[float]], A=[[0.0, -0.5, 1.0], [0.75, 0.0, -1.0], [-1.0, 0.4, 0.0]], eps=0.01):\n m, n = len(A), len(A[0])\n p, q = strategies\n assert all(len(row) == n for row in A), \"inputs are a matrix\"\n assert len(p) == m and len(q) == n, \"solution is a pair of strategies\"\n assert sum(p) == sum(q) == 1.0 and min(p + q) >= 0.0, \"strategies must be non-negative and sum to 1\"\n v = sum(A[i][j] * p[i] * q[j] for i in range(m) for j in range(n))\n return (all(sum(A[i][j] * q[j] for j in range(n)) <= v + eps for i in range(m)) and\n all(sum(A[i][j] * p[i] for i in range(m)) >= v - eps for j in range(n)))", + "def sat(positions: List[List[int]]):\n\n table = [[(i * 429436219 + j * 100239120) % 63491564 for j in range(13)] for i in range(64)]\n\n def zobrist(pos):\n h = 0\n for i in range(64):\n if pos[i]:\n h ^= table[i][pos[i]]\n return h\n\n a, b = positions\n return zobrist(a) == zobrist(b) and a != b" +] diff --git a/ICLR2023/data/train_prefix.txt b/ICLR2023/data/train_prefix.txt new file mode 100755 index 0000000..36cdf81 --- /dev/null +++ b/ICLR2023/data/train_prefix.txt @@ -0,0 +1,41 @@ +from typing import List + +def f(s: str): + return "Hello " + s == "Hello world" + +def g(): + return "world" + +assert f(g()) + +def f(s: str): + return "Hello " + s[::-1] == "Hello world" + +def g(): + return "world"[::-1] + +assert f(g()) + +def f(x: List[int]): + return len(x) == 2 and sum(x) == 3 + +def g(): + return [1, 2] + +assert f(g()) + +def f(s: List[str]): + return len(set(s)) == 1000 and all((x.count("a") > x.count("b")) and ('b' in x) for x in s) + +def g(): + return ["a"*(i+2)+"b" for i in range(1000)] + +assert f(g()) + +def f(n: int): + return str(n * n).startswith("123456789") + +def g(): + return int(int("123456789" + "0"*9) ** 0.5) + 1 + +assert f(g()) \ No newline at end of file