From 9c767baec3fd9b61f81765e255d7b1b3ab1cd0e4 Mon Sep 17 00:00:00 2001 From: ALGW71 Date: Sun, 8 Feb 2026 23:20:42 +0000 Subject: [PATCH 1/3] All I can say is, I am sorry Ben, duplicated code, but it will not compromise speed of standard workflow. This returns logits. --- notebook/examples.ipynb | 178 +++++----- src/anarcii/inference/model_runner.py | 486 +++++++++++++++++++++++++- src/anarcii/pipeline/__init__.py | 9 +- 3 files changed, 570 insertions(+), 103 deletions(-) diff --git a/notebook/examples.ipynb b/notebook/examples.ipynb index a8ebc786..85be8e84 100644 --- a/notebook/examples.ipynb +++ b/notebook/examples.ipynb @@ -2,24 +2,76 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Using device CPU with 8 CPUs\n", "\n", "dict_keys(['Sequence 1', 'Sequence 2', 'Sequence 3', 'Sequence 4'])\n", "\n", - "numbering : [((1, ' '), 'A'), ((2, ' '), 'Q'), ((3, ' '), 'S'), ((4, ' '), 'V'), ((5, ' '), 'T'), ((6, ' '), 'Q'), ((7, ' '), 'L'), ((8, ' '), 'G'), ((9, ' '), 'S'), ((10, ' '), 'H'), ((11, ' '), 'V'), ((12, ' '), 'S'), ((13, ' '), 'V'), ((14, ' '), 'S'), ((15, ' '), 'E'), ((16, ' '), 'G'), ((17, ' '), 'A'), ((18, ' '), 'L'), ((19, ' '), 'V'), ((20, ' '), 'L'), ((21, ' '), 'L'), ((22, ' '), 'R'), ((23, ' '), 'C'), ((24, ' '), 'N'), ((25, ' '), 'Y'), ((26, ' '), 'S'), ((27, ' '), 'S'), ((28, ' '), 'S'), ((29, ' '), 'V'), ((30, ' '), '-'), ((31, ' '), '-'), ((32, ' '), '-'), ((33, ' '), '-'), ((34, ' '), '-'), ((35, ' '), '-'), ((36, ' '), 'P'), ((37, ' '), 'P'), ((38, ' '), 'Y'), ((39, ' '), 'L'), ((40, ' '), 'F'), ((41, ' '), 'W'), ((42, ' '), 'Y'), ((43, ' '), 'V'), ((44, ' '), 'Q'), ((45, ' '), 'Y'), ((46, ' '), 'P'), ((47, ' '), 'N'), ((48, ' '), 'Q'), ((49, ' '), 'G'), ((50, ' '), 'L'), ((51, ' '), 'Q'), ((52, ' '), 'L'), ((53, ' '), 'L'), ((54, ' '), 'L'), ((55, ' '), 'K'), ((56, ' '), 'Y'), ((57, ' '), 'T'), ((58, ' '), 'S'), ((59, ' '), 'A'), ((60, ' '), '-'), ((61, ' '), '-'), ((62, ' '), 'A'), ((63, ' '), 'T'), ((64, ' '), 'L'), ((65, ' '), 'V'), ((66, ' '), 'K'), ((67, ' '), 'G'), ((68, ' '), 'I'), ((69, ' '), '-'), ((70, ' '), '-'), ((71, ' '), '-'), ((72, ' '), '-'), ((73, ' '), '-'), ((74, ' '), 'N'), ((75, ' '), 'G'), ((76, ' '), 'F'), ((77, ' '), 'E'), ((78, ' '), 'A'), ((79, ' '), 'E'), ((80, ' '), 'F'), ((81, ' '), 'K'), ((82, ' '), 'K'), ((83, ' '), 'S'), ((84, ' '), 'E'), ((85, ' '), 'T'), ((86, ' '), 'S'), ((87, ' '), 'F'), ((88, ' '), 'H'), ((89, ' '), 'L'), ((90, ' '), 'T'), ((91, ' '), 'K'), ((92, ' '), 'P'), ((93, ' '), 'S'), ((94, ' '), 'A'), ((95, ' '), 'H'), ((96, ' '), 'M'), ((97, ' '), 'S'), ((98, ' '), 'D'), ((99, ' '), 'A'), ((100, ' '), 'A'), ((101, ' '), 'E'), ((102, ' '), 'Y'), ((103, ' '), 'F'), ((104, ' '), 'C'), ((105, ' '), 'A'), ((106, ' '), 'V'), ((107, ' '), 'S'), ((108, ' '), 'E'), ((109, ' '), 'Q'), ((110, ' '), '-'), ((111, ' '), '-'), ((112, ' '), '-'), ((113, ' '), 'D'), ((114, ' '), 'D'), ((115, ' '), 'K'), ((116, ' '), 'I'), ((117, ' '), 'I'), ((118, ' '), 'F'), ((119, ' '), 'G'), ((120, ' '), 'K'), ((121, ' '), 'G'), ((122, ' '), 'T'), ((123, ' '), 'R'), ((124, ' '), 'L'), ((125, ' '), 'H'), ((126, ' '), 'I'), ((127, ' '), 'L'), ((128, ' '), 'P')]\n", - "chain_type : A\n", - "score : 32.16792297363281\n", - "query_start : 0\n", - "query_end : 111\n", - "error : None\n", - "scheme : imgt\n" + "Sequence 1 : tensor([22.5838, 31.6898, 31.4527, 32.8396, 39.1968, 28.7265, 26.4854, 35.5877,\n", + " 25.7327, 37.6802, 32.4417, 32.2934, 28.1486, 30.5405, 28.4391, 27.4314,\n", + " 28.2475, 26.0391, 34.2470, 31.3821, 30.3729, 37.3745, 30.3654, 27.6779,\n", + " 33.9490, 36.4483, 28.2443, 38.9543, 36.2434, 37.1170, 34.2787, 38.3308,\n", + " 31.8065, 29.0670, 26.3710, 25.5805, 28.6327, 33.4610, 31.6199, 36.7579,\n", + " 29.8569, 25.2630, 27.2584, 33.1403, 34.2393, 37.9991, 40.0970, 36.4892,\n", + " 33.8836, 32.4874, 43.7763, 41.4126, 32.6784, 31.7495, 31.3672, 32.6218,\n", + " 34.0671, 26.6324, 30.1969, 23.7971, 29.4850, 29.5168, 29.9851, 31.4282,\n", + " 41.5857, 39.5168, 40.3152, 42.7991, 28.4216, 36.4930, 38.7311, 31.8494,\n", + " 30.0495, 35.8927, 38.6342, 32.0970, 33.6287, 32.7060, 32.9633, 40.2834,\n", + " 41.9010, 34.4784, 38.7741, 37.4402, 35.1924, 29.9138, 31.8103, 36.2601,\n", + " 33.1773, 27.1542, 33.3743, 33.0630, 44.4626, 31.2103, 32.4261, 30.3873,\n", + " 27.2959, 27.1992, 24.7456, 29.5603, 30.5400, 31.2875, 37.3359, 34.2773,\n", + " 29.0145, 26.5760, 25.4447, 28.7686, 25.8767, 24.9761, 25.0015, 17.5661,\n", + " 19.2424])\n", + "Sequence 2 : tensor([ 5.7148, 30.2179, 24.1872, 29.5509, 25.4001, 24.0097, 31.8529, 26.2123,\n", + " 25.0069, 19.8789, 27.6443, 28.4427, 30.4647, 25.9553, 31.3079, 32.4349,\n", + " 34.4294, 32.1585, 31.0647, 30.9016, 38.4637, 35.6251, 33.9866, 27.3903,\n", + " 34.0115, 37.0032, 37.7743, 36.1086, 26.6505, 29.8539, 29.1923, 27.6697,\n", + " 28.9249, 30.2062, 30.5036, 33.3199, 30.0542, 24.8211, 33.9624, 31.3401,\n", + " 33.3160, 34.2448, 29.5792, 34.1259, 34.4462, 27.1545, 40.5046, 39.4822,\n", + " 35.2835, 37.3627, 35.4036, 29.9326, 31.0929, 29.2970, 27.7120, 27.2299,\n", + " 28.4275, 30.9210, 34.0998, 31.3759, 31.2178, 34.9842, 33.3230, 36.9066,\n", + " 37.0539, 39.6846, 32.8623, 36.8928, 28.6145, 33.6384, 37.8850, 36.3539,\n", + " 36.8074, 36.1996, 26.8989, 33.8357, 42.3234, 38.9542, 36.2264, 42.3248,\n", + " 36.0258, 37.7112, 33.8739, 31.3762, 38.7569, 37.2933, 31.6094, 36.9466,\n", + " 34.7402, 39.0402, 30.6942, 32.1447, 29.9778, 26.7233, 25.6392, 24.6415,\n", + " 27.6697, 30.5254, 34.0342, 42.6502, 28.6823, 29.6582, 21.8303, 23.5069,\n", + " 27.7824, 30.2912, 25.5421, 19.1691, 19.7682])\n", + "Sequence 3 : tensor([12.9239, 25.0816, 28.3883, 28.4557, 31.5377, 27.5145, 27.1580, 27.2675,\n", + " 26.4965, 27.6224, 27.2985, 28.0900, 27.8811, 25.1745, 26.4641, 24.9921,\n", + " 25.5914, 25.4983, 24.5084, 27.7500, 26.1022, 28.9263, 26.2085, 27.4731,\n", + " 26.5021, 29.3956, 24.0369, 30.7517, 31.7431, 46.3120, 35.2302, 33.1150,\n", + " 26.6296, 32.4256, 29.6190, 26.2629, 26.6716, 28.7378, 27.5827, 32.7611,\n", + " 28.4548, 30.0210, 26.8251, 30.6011, 31.9220, 27.5251, 30.1503, 29.1241,\n", + " 30.5095, 28.2456, 37.8106, 34.9157, 33.4896, 29.1894, 30.3304, 26.1981,\n", + " 33.6441, 28.3460, 27.0327, 27.2014, 25.3401, 34.2528, 31.5774, 36.4322,\n", + " 35.0746, 35.1558, 31.0661, 36.6514, 35.7466, 30.8785, 37.5318, 34.3933,\n", + " 40.8922, 31.1796, 28.3882, 35.5020, 39.9505, 38.7772, 39.6565, 39.7621,\n", + " 39.9021, 31.4047, 38.0572, 33.0229, 33.4130, 35.8297, 28.7221, 36.7525,\n", + " 29.8354, 33.2201, 25.4760, 27.4983, 29.2145, 26.0279, 23.3462, 26.4837,\n", + " 25.8504, 28.4705, 33.5156, 40.2963, 35.5027, 31.2086, 27.2284, 26.2414,\n", + " 27.1899, 26.6214, 21.9226, 18.5307, 21.7540])\n", + "Sequence 4 : tensor([13.7804, 23.9711, 28.4706, 29.3610, 28.8592, 27.8211, 26.9738, 31.3445,\n", + " 28.0319, 29.5990, 28.4382, 26.0240, 26.5408, 27.2972, 26.3980, 25.5782,\n", + " 28.1557, 29.1561, 28.4808, 30.0942, 30.7771, 28.3871, 32.3001, 25.5386,\n", + " 31.3489, 24.2413, 39.0182, 38.4196, 37.6164, 37.9333, 35.5819, 33.6417,\n", + " 34.4938, 28.1510, 31.5333, 28.2369, 28.2892, 28.3823, 30.5997, 32.1254,\n", + " 33.9772, 33.2400, 30.6017, 26.6109, 28.0319, 34.3041, 30.7074, 31.0401,\n", + " 32.8398, 29.3260, 27.8884, 36.6228, 28.8538, 32.8881, 29.4708, 30.3701,\n", + " 30.9103, 31.9164, 30.5789, 26.6090, 26.5611, 24.6113, 27.3579, 28.4574,\n", + " 27.9184, 30.7647, 26.5552, 32.3514, 30.8712, 32.5677, 30.2417, 28.9319,\n", + " 26.5740, 25.9156, 31.4575, 34.0288, 36.7491, 32.8318, 39.9403, 32.0150,\n", + " 38.8725, 31.8909, 32.5510, 28.2403, 45.1687, 40.5855, 42.3462, 41.7061,\n", + " 41.2951, 37.1519, 36.4636, 40.6107, 36.3662, 36.8156, 36.6676, 37.3680,\n", + " 30.5122, 35.5027, 25.7525, 29.6473, 26.9661, 30.0284, 28.2872, 25.8404,\n", + " 25.0895, 26.7326, 25.9506, 26.7012, 25.3409, 25.9425, 25.0675, 25.2981,\n", + " 25.2231, 28.4855, 25.6817, 25.6900, 26.3425, 26.1936, 27.1129, 33.5035,\n", + " 36.2804, 30.3952, 27.2605, 26.2464, 25.7583, 25.9494, 22.5090, 16.8636,\n", + " 16.3859])\n" ] } ], @@ -28,8 +80,9 @@ "from anarcii import Anarcii\n", "\n", "model = Anarcii(seq_type=\"unknown\", batch_size=128, \n", - " cpu=True, ncpu=8, \n", - " mode=\"accuracy\", verbose=False)\n", + " cpu=False, ncpu=14, \n", + " mode=\"accuracy\", verbose=False,\n", + " return_logits=True)\n", "\n", "seq = [\n", " #Alpha\n", @@ -50,30 +103,15 @@ "print(results.keys())\n", "print()\n", "\n", - "for key, value in results['Sequence 1'].items():\n", - " print(key, \":\", value)\n" + "for key, value in results.items():\n", + " print(key, \":\", value[\"logits\"])\n" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "('Sequence 1', {'numbering': [((1, ' '), 'A'), ((2, ' '), 'Q'), ((3, ' '), 'S'), ((4, ' '), 'V'), ((5, ' '), 'T'), ((6, ' '), 'Q'), ((7, ' '), 'L'), ((8, ' '), 'G'), ((9, ' '), 'S'), ((10, ' '), 'H'), ((11, ' '), 'V'), ((12, ' '), 'S'), ((13, ' '), 'V'), ((14, ' '), 'S'), ((15, ' '), 'E'), ((16, ' '), 'G'), ((17, ' '), 'A'), ((18, ' '), 'L'), ((19, ' '), 'V'), ((20, ' '), 'L'), ((21, ' '), 'L'), ((22, ' '), 'R'), ((23, ' '), 'C'), ((24, ' '), 'N'), ((25, ' '), 'Y'), ((26, ' '), 'S'), ((27, ' '), 'S'), ((28, ' '), 'S'), ((29, ' '), 'V'), ((30, ' '), '-'), ((31, ' '), '-'), ((32, ' '), '-'), ((33, ' '), '-'), ((34, ' '), '-'), ((35, ' '), '-'), ((36, ' '), 'P'), ((37, ' '), 'P'), ((38, ' '), 'Y'), ((39, ' '), 'L'), ((40, ' '), 'F'), ((41, ' '), 'W'), ((42, ' '), 'Y'), ((43, ' '), 'V'), ((44, ' '), 'Q'), ((45, ' '), 'Y'), ((46, ' '), 'P'), ((47, ' '), 'N'), ((48, ' '), 'Q'), ((49, ' '), 'G'), ((50, ' '), 'L'), ((51, ' '), 'Q'), ((52, ' '), 'L'), ((53, ' '), 'L'), ((54, ' '), 'L'), ((55, ' '), 'K'), ((56, ' '), 'Y'), ((57, ' '), 'T'), ((58, ' '), 'S'), ((59, ' '), 'A'), ((60, ' '), '-'), ((61, ' '), '-'), ((62, ' '), 'A'), ((63, ' '), 'T'), ((64, ' '), 'L'), ((65, ' '), 'V'), ((66, ' '), 'K'), ((67, ' '), 'G'), ((68, ' '), 'I'), ((69, ' '), '-'), ((70, ' '), '-'), ((71, ' '), '-'), ((72, ' '), '-'), ((73, ' '), '-'), ((74, ' '), 'N'), ((75, ' '), 'G'), ((76, ' '), 'F'), ((77, ' '), 'E'), ((78, ' '), 'A'), ((79, ' '), 'E'), ((80, ' '), 'F'), ((81, ' '), 'K'), ((82, ' '), 'K'), ((83, ' '), 'S'), ((84, ' '), 'E'), ((85, ' '), 'T'), ((86, ' '), 'S'), ((87, ' '), 'F'), ((88, ' '), 'H'), ((89, ' '), 'L'), ((90, ' '), 'T'), ((91, ' '), 'K'), ((92, ' '), 'P'), ((93, ' '), 'S'), ((94, ' '), 'A'), ((95, ' '), 'H'), ((96, ' '), 'M'), ((97, ' '), 'S'), ((98, ' '), 'D'), ((99, ' '), 'A'), ((100, ' '), 'A'), ((101, ' '), 'E'), ((102, ' '), 'Y'), ((103, ' '), 'F'), ((104, ' '), 'C'), ((105, ' '), 'A'), ((106, ' '), 'V'), ((107, ' '), 'S'), ((108, ' '), 'E'), ((109, ' '), 'Q'), ((110, ' '), '-'), ((111, ' '), '-'), ((112, ' '), '-'), ((113, ' '), 'D'), ((114, ' '), 'D'), ((115, ' '), 'K'), ((116, ' '), 'I'), ((117, ' '), 'I'), ((118, ' '), 'F'), ((119, ' '), 'G'), ((120, ' '), 'K'), ((121, ' '), 'G'), ((122, ' '), 'T'), ((123, ' '), 'R'), ((124, ' '), 'L'), ((125, ' '), 'H'), ((126, ' '), 'I'), ((127, ' '), 'L'), ((128, ' '), 'P')], 'chain_type': 'A', 'score': 32.16792297363281, 'query_start': 0, 'query_end': 111, 'error': None, 'scheme': 'imgt'})\n", - "\n", - "('Sequence 2', {'numbering': [((1, ' '), '-'), ((2, ' '), 'A'), ((3, ' '), 'D'), ((4, ' '), 'V'), ((5, ' '), 'T'), ((6, ' '), 'Q'), ((7, ' '), 'T'), ((8, ' '), 'P'), ((9, ' '), 'R'), ((10, ' '), 'N'), ((11, ' '), 'R'), ((12, ' '), 'I'), ((13, ' '), 'T'), ((14, ' '), 'K'), ((15, ' '), 'T'), ((16, ' '), 'G'), ((17, ' '), 'K'), ((18, ' '), 'R'), ((19, ' '), 'I'), ((20, ' '), 'M'), ((21, ' '), 'L'), ((22, ' '), 'E'), ((23, ' '), 'C'), ((24, ' '), 'S'), ((25, ' '), 'Q'), ((26, ' '), 'T'), ((27, ' '), 'K'), ((28, ' '), 'G'), ((29, ' '), 'H'), ((30, ' '), '-'), ((31, ' '), '-'), ((32, ' '), '-'), ((33, ' '), '-'), ((34, ' '), '-'), ((35, ' '), '-'), ((36, ' '), '-'), ((37, ' '), 'D'), ((38, ' '), 'R'), ((39, ' '), 'M'), ((40, ' '), 'Y'), ((41, ' '), 'W'), ((42, ' '), 'Y'), ((43, ' '), 'R'), ((44, ' '), 'Q'), ((45, ' '), 'D'), ((46, ' '), 'P'), ((47, ' '), 'G'), ((48, ' '), 'L'), ((49, ' '), 'G'), ((50, ' '), 'L'), ((51, ' '), 'R'), ((52, ' '), 'L'), ((53, ' '), 'I'), ((54, ' '), 'Y'), ((55, ' '), 'Y'), ((56, ' '), 'S'), ((57, ' '), 'F'), ((58, ' '), 'D'), ((59, ' '), '-'), ((60, ' '), '-'), ((61, ' '), '-'), ((62, ' '), '-'), ((63, ' '), 'V'), ((64, ' '), 'K'), ((65, ' '), 'D'), ((66, ' '), 'I'), ((67, ' '), 'N'), ((68, ' '), 'K'), ((69, ' '), 'G'), ((70, ' '), 'E'), ((71, ' '), 'I'), ((72, ' '), 'S'), ((73, ' '), '-'), ((74, ' '), 'D'), ((75, ' '), 'G'), ((76, ' '), 'Y'), ((77, ' '), 'S'), ((78, ' '), 'V'), ((79, ' '), 'S'), ((80, ' '), 'R'), ((81, ' '), 'Q'), ((82, ' '), '-'), ((83, ' '), 'A'), ((84, ' '), 'Q'), ((85, ' '), 'A'), ((86, ' '), 'K'), ((87, ' '), 'F'), ((88, ' '), 'S'), ((89, ' '), 'L'), ((90, ' '), 'S'), ((91, ' '), 'L'), ((92, ' '), 'E'), ((93, ' '), 'S'), ((94, ' '), 'A'), ((95, ' '), 'I'), ((96, ' '), 'P'), ((97, ' '), 'N'), ((98, ' '), 'Q'), ((99, ' '), 'T'), ((100, ' '), 'A'), ((101, ' '), 'L'), ((102, ' '), 'Y'), ((103, ' '), 'F'), ((104, ' '), 'C'), ((105, ' '), 'A'), ((106, ' '), 'T'), ((107, ' '), 'S'), ((108, ' '), 'D'), ((109, ' '), 'E'), ((110, ' '), '-'), ((111, ' '), '-'), ((112, ' '), '-'), ((113, ' '), 'S'), ((114, ' '), 'Y'), ((115, ' '), 'G'), ((116, ' '), 'Y'), ((117, ' '), 'T'), ((118, ' '), 'F'), ((119, ' '), 'G'), ((120, ' '), 'S'), ((121, ' '), 'G'), ((122, ' '), 'T'), ((123, ' '), 'R'), ((124, ' '), 'L'), ((125, ' '), 'T'), ((126, ' '), 'V'), ((127, ' '), 'V'), ((128, ' '), '-')], 'chain_type': 'B', 'score': 31.522838592529297, 'query_start': 0, 'query_end': 109, 'error': None, 'scheme': 'imgt'})\n", - "\n", - "('Sequence 3', {'numbering': [((1, ' '), 'E'), ((2, ' '), 'I'), ((3, ' '), 'V'), ((4, ' '), 'M'), ((5, ' '), 'T'), ((6, ' '), 'Q'), ((7, ' '), 'S'), ((8, ' '), 'P'), ((9, ' '), 'D'), ((10, ' '), 'T'), ((11, ' '), 'L'), ((12, ' '), 'S'), ((13, ' '), 'V'), ((14, ' '), 'S'), ((15, ' '), 'P'), ((16, ' '), 'G'), ((17, ' '), 'E'), ((18, ' '), 'R'), ((19, ' '), 'A'), ((20, ' '), 'T'), ((21, ' '), 'L'), ((22, ' '), 'S'), ((23, ' '), 'C'), ((24, ' '), 'R'), ((25, ' '), 'A'), ((26, ' '), 'S'), ((27, ' '), 'E'), ((28, ' '), 'S'), ((29, ' '), 'I'), ((30, ' '), '-'), ((31, ' '), '-'), ((32, ' '), '-'), ((33, ' '), '-'), ((34, ' '), '-'), ((35, ' '), '-'), ((36, ' '), 'S'), ((37, ' '), 'S'), ((38, ' '), 'N'), ((39, ' '), 'L'), ((40, ' '), 'A'), ((41, ' '), 'W'), ((42, ' '), 'Y'), ((43, ' '), 'Q'), ((44, ' '), 'Q'), ((45, ' '), 'K'), ((46, ' '), 'P'), ((47, ' '), 'G'), ((48, ' '), 'Q'), ((49, ' '), 'V'), ((50, ' '), 'P'), ((51, ' '), 'R'), ((52, ' '), 'L'), ((53, ' '), 'L'), ((54, ' '), 'I'), ((55, ' '), 'Y'), ((56, ' '), 'G'), ((57, ' '), 'A'), ((58, ' '), '-'), ((59, ' '), '-'), ((60, ' '), '-'), ((61, ' '), '-'), ((62, ' '), '-'), ((63, ' '), '-'), ((64, ' '), '-'), ((65, ' '), 'S'), ((66, ' '), 'T'), ((67, ' '), 'R'), ((68, ' '), 'A'), ((69, ' '), 'T'), ((70, ' '), 'G'), ((71, ' '), 'V'), ((72, ' '), 'P'), ((73, ' '), '-'), ((74, ' '), 'A'), ((75, ' '), 'R'), ((76, ' '), 'F'), ((77, ' '), 'T'), ((78, ' '), 'G'), ((79, ' '), 'S'), ((80, ' '), 'G'), ((81, ' '), '-'), ((82, ' '), '-'), ((83, ' '), 'S'), ((84, ' '), 'G'), ((85, ' '), 'T'), ((86, ' '), 'E'), ((87, ' '), 'F'), ((88, ' '), 'T'), ((89, ' '), 'L'), ((90, ' '), 'T'), ((91, ' '), 'I'), ((92, ' '), 'S'), ((93, ' '), 'S'), ((94, ' '), 'L'), ((95, ' '), 'Q'), ((96, ' '), 'S'), ((97, ' '), 'E'), ((98, ' '), 'D'), ((99, ' '), 'F'), ((100, ' '), 'A'), ((101, ' '), 'V'), ((102, ' '), 'Y'), ((103, ' '), 'Y'), ((104, ' '), 'C'), ((105, ' '), 'Q'), ((106, ' '), 'Q'), ((107, ' '), 'Y'), ((108, ' '), 'N'), ((109, ' '), 'N'), ((110, ' '), '-'), ((111, ' '), '-'), ((112, ' '), '-'), ((113, ' '), 'R'), ((114, ' '), 'L'), ((115, ' '), 'P'), ((116, ' '), 'Y'), ((117, ' '), 'T'), ((118, ' '), 'F'), ((119, ' '), 'G'), ((120, ' '), 'Q'), ((121, ' '), 'G'), ((122, ' '), 'T'), ((123, ' '), 'K'), ((124, ' '), 'L'), ((125, ' '), 'E'), ((126, ' '), 'I'), ((127, ' '), 'K'), ((128, ' '), '-')], 'chain_type': 'K', 'score': 30.132719039916992, 'query_start': 0, 'query_end': 107, 'error': None, 'scheme': 'imgt'})\n", - "\n", - "('Sequence 4', {'numbering': [((1, ' '), 'E'), ((2, ' '), 'V'), ((3, ' '), 'Q'), ((4, ' '), 'L'), ((5, ' '), 'L'), ((6, ' '), 'E'), ((7, ' '), 'S'), ((8, ' '), 'G'), ((9, ' '), 'G'), ((10, ' '), '-'), ((11, ' '), 'G'), ((12, ' '), 'L'), ((13, ' '), 'V'), ((14, ' '), 'Q'), ((15, ' '), 'P'), ((16, ' '), 'G'), ((17, ' '), 'G'), ((18, ' '), 'S'), ((19, ' '), 'L'), ((20, ' '), 'R'), ((21, ' '), 'L'), ((22, ' '), 'S'), ((23, ' '), 'C'), ((24, ' '), 'A'), ((25, ' '), 'A'), ((26, ' '), 'S'), ((27, ' '), 'G'), ((28, ' '), 'F'), ((29, ' '), 'T'), ((30, ' '), 'F'), ((31, ' '), '-'), ((32, ' '), '-'), ((33, ' '), '-'), ((34, ' '), '-'), ((35, ' '), 'N'), ((36, ' '), 'H'), ((37, ' '), 'Y'), ((38, ' '), 'A'), ((39, ' '), 'M'), ((40, ' '), 'T'), ((41, ' '), 'W'), ((42, ' '), 'V'), ((43, ' '), 'R'), ((44, ' '), 'Q'), ((45, ' '), 'A'), ((46, ' '), 'P'), ((47, ' '), 'G'), ((48, ' '), 'K'), ((49, ' '), 'G'), ((50, ' '), 'L'), ((51, ' '), 'E'), ((52, ' '), 'W'), ((53, ' '), 'V'), ((54, ' '), 'A'), ((55, ' '), 'S'), ((56, ' '), 'S'), ((57, ' '), 'S'), ((58, ' '), 'G'), ((59, ' '), 'S'), ((60, ' '), '-'), ((61, ' '), '-'), ((62, ' '), 'G'), ((63, ' '), 'R'), ((64, ' '), 'S'), ((65, ' '), 'T'), ((66, ' '), 'Y'), ((67, ' '), 'Y'), ((68, ' '), 'T'), ((69, ' '), 'D'), ((70, ' '), 'S'), ((71, ' '), 'V'), ((72, ' '), 'K'), ((73, ' '), '-'), ((74, ' '), 'G'), ((75, ' '), 'R'), ((76, ' '), 'F'), ((77, ' '), 'S'), ((78, ' '), 'V'), ((79, ' '), 'S'), ((80, ' '), 'R'), ((81, ' '), 'D'), ((82, ' '), 'N'), ((83, ' '), 'S'), ((84, ' '), 'K'), ((85, ' '), 'N'), ((86, ' '), 'T'), ((87, ' '), 'L'), ((88, ' '), 'Y'), ((89, ' '), 'L'), ((90, ' '), 'Q'), ((91, ' '), 'M'), ((92, ' '), 'N'), ((93, ' '), 'S'), ((94, ' '), 'L'), ((95, ' '), 'R'), ((96, ' '), 'A'), ((97, ' '), 'E'), ((98, ' '), 'D'), ((99, ' '), 'T'), ((100, ' '), 'A'), ((101, ' '), 'V'), ((102, ' '), 'Y'), ((103, ' '), 'Y'), ((104, ' '), 'C'), ((105, ' '), 'A'), ((106, ' '), 'K'), ((107, ' '), 'S'), ((108, ' '), 'S'), ((109, ' '), 'N'), ((110, ' '), 'Y'), ((111, ' '), 'Y'), ((111, 'A'), 'G'), ((111, 'B'), 'S'), ((111, 'C'), 'G'), ((111, 'D'), 'S'), ((112, 'E'), 'Y'), ((112, 'D'), 'S'), ((112, 'C'), 'P'), ((112, 'B'), 'D'), ((112, 'A'), 'D'), ((112, ' '), 'Y'), ((113, ' '), 'Y'), ((114, ' '), 'H'), ((115, ' '), 'M'), ((116, ' '), 'D'), ((117, ' '), 'V'), ((118, ' '), 'W'), ((119, ' '), 'G'), ((120, ' '), 'Q'), ((121, ' '), 'G'), ((122, ' '), 'T'), ((123, ' '), 'T'), ((124, ' '), 'V'), ((125, ' '), 'T'), ((126, ' '), 'V'), ((127, ' '), 'S'), ((128, ' '), 'G')], 'chain_type': 'H', 'score': 30.514766693115234, 'query_start': 0, 'query_end': 128, 'error': None, 'scheme': 'imgt'})\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "# The dict can also be converted to a list\n", "for seq in list(results.items()):\n", @@ -83,20 +121,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using device CPU with 12 CPUs\n", - "{'numbering': [((1, ' '), 'Q'), ((2, ' '), 'I'), ((3, ' '), 'H'), ((4, ' '), 'L'), ((5, ' '), 'V'), ((6, ' '), 'Q'), ((7, ' '), 'S'), ((8, ' '), 'G'), ((9, ' '), 'T'), ((10, ' '), '-'), ((11, ' '), 'E'), ((12, ' '), 'V'), ((13, ' '), 'K'), ((14, ' '), 'K'), ((15, ' '), 'P'), ((16, ' '), 'G'), ((17, ' '), 'S'), ((18, ' '), 'S'), ((19, ' '), 'V'), ((20, ' '), 'T'), ((21, ' '), 'V'), ((22, ' '), 'S'), ((23, ' '), 'C'), ((24, ' '), 'K'), ((25, ' '), 'A'), ((26, ' '), 'Y'), ((27, ' '), 'G'), ((28, ' '), 'V'), ((29, ' '), 'N'), ((30, ' '), 'T'), ((31, ' '), 'F'), ((32, ' '), '-'), ((33, ' '), '-'), ((34, ' '), '-'), ((35, ' '), 'G'), ((36, ' '), 'L'), ((37, ' '), 'Y'), ((38, ' '), 'A'), ((39, ' '), 'V'), ((40, ' '), 'N'), ((41, ' '), 'W'), ((42, ' '), 'V'), ((43, ' '), 'R'), ((44, ' '), 'Q'), ((45, ' '), 'A'), ((46, ' '), 'P'), ((47, ' '), 'G'), ((48, ' '), 'Q'), ((49, ' '), 'S'), ((50, ' '), 'L'), ((51, ' '), 'E'), ((52, ' '), 'Y'), ((53, ' '), 'I'), ((54, ' '), 'G'), ((55, ' '), 'Q'), ((56, ' '), 'I'), ((57, ' '), 'W'), ((58, ' '), 'R'), ((59, ' '), 'W'), ((60, ' '), 'K'), ((61, ' '), 'S'), ((62, ' '), 'S'), ((63, ' '), 'A'), ((64, ' '), 'S'), ((65, ' '), 'H'), ((66, ' '), 'H'), ((67, ' '), 'F'), ((68, ' '), 'R'), ((69, ' '), 'G'), ((70, ' '), 'R'), ((71, ' '), 'V'), ((72, ' '), 'L'), ((73, ' '), 'I'), ((74, ' '), 'S'), ((75, ' '), 'A'), ((76, ' '), 'V'), ((77, ' '), 'D'), ((78, ' '), 'L'), ((79, ' '), 'T'), ((80, ' '), 'G'), ((81, ' '), 'S'), ((82, ' '), '-'), ((83, ' '), 'S'), ((84, ' '), 'P'), ((85, ' '), 'P'), ((86, ' '), 'I'), ((87, ' '), 'S'), ((88, ' '), 'S'), ((89, ' '), 'L'), ((90, ' '), 'E'), ((91, ' '), 'I'), ((92, ' '), 'K'), ((93, ' '), 'N'), ((94, ' '), 'L'), ((95, ' '), 'T'), ((96, ' '), 'S'), ((97, ' '), 'D'), ((98, ' '), 'D'), ((99, ' '), 'T'), ((100, ' '), 'A'), ((101, ' '), 'V'), ((102, ' '), 'Y'), ((103, ' '), 'F'), ((104, ' '), 'C'), ((105, ' '), 'T'), ((106, ' '), 'T'), ((107, ' '), 'T'), ((108, ' '), 'S'), ((109, ' '), 'T'), ((110, ' '), 'Y'), ((111, ' '), 'D'), ((111, 'A'), 'K'), ((111, 'B'), 'W'), ((111, 'C'), 'S'), ((111, 'D'), 'G'), ((112, 'E'), 'L'), ((112, 'D'), 'H'), ((112, 'C'), 'H'), ((112, 'B'), 'D'), ((112, 'A'), 'G'), ((112, ' '), 'V'), ((113, ' '), 'M'), ((114, ' '), 'A'), ((115, ' '), 'F'), ((116, ' '), 'S'), ((117, ' '), 'S'), ((118, ' '), 'W'), ((119, ' '), 'G'), ((120, ' '), 'Q'), ((121, ' '), 'G'), ((122, ' '), 'T'), ((123, ' '), 'L'), ((124, ' '), 'I'), ((125, ' '), 'S'), ((126, ' '), 'V'), ((127, ' '), 'S'), ((128, ' '), 'A')], 'chain_type': 'H', 'score': 25.120664596557617, 'query_start': 0, 'query_end': 131, 'error': None, 'scheme': 'imgt'}\n", - "{'numbering': [((1, ' '), 'Q'), ((2, ' '), 'P'), ((3, ' '), 'G'), ((4, ' '), 'L'), ((5, ' '), 'T'), ((6, ' '), 'Q'), ((7, ' '), 'P'), ((8, ' '), 'P'), ((9, ' '), 'S'), ((10, ' '), '-'), ((11, ' '), 'V'), ((12, ' '), 'S'), ((13, ' '), 'K'), ((14, ' '), 'G'), ((15, ' '), 'L'), ((16, ' '), 'R'), ((17, ' '), 'Q'), ((18, ' '), 'T'), ((19, ' '), 'A'), ((20, ' '), 'T'), ((21, ' '), 'L'), ((22, ' '), 'T'), ((23, ' '), 'C'), ((24, ' '), 'T'), ((25, ' '), 'G'), ((26, ' '), 'N'), ((27, ' '), 'S'), ((28, ' '), 'N'), ((29, ' '), 'N'), ((30, ' '), 'V'), ((31, ' '), '-'), ((32, ' '), '-'), ((33, ' '), '-'), ((34, ' '), '-'), ((35, ' '), 'G'), ((36, ' '), 'N'), ((37, ' '), 'Q'), ((38, ' '), 'G'), ((39, ' '), 'A'), ((40, ' '), 'A'), ((41, ' '), 'W'), ((42, ' '), 'L'), ((43, ' '), 'Q'), ((44, ' '), 'Q'), ((45, ' '), 'H'), ((46, ' '), 'Q'), ((47, ' '), 'G'), ((48, ' '), 'H'), ((49, ' '), 'P'), ((50, ' '), 'P'), ((51, ' '), 'K'), ((52, ' '), 'L'), ((53, ' '), 'L'), ((54, ' '), 'S'), ((55, ' '), 'Y'), ((56, ' '), 'R'), ((57, ' '), 'N'), ((58, ' '), '-'), ((59, ' '), '-'), ((60, ' '), '-'), ((61, ' '), '-'), ((62, ' '), '-'), ((63, ' '), '-'), ((64, ' '), '-'), ((65, ' '), 'N'), ((66, ' '), 'D'), ((67, ' '), 'R'), ((68, ' '), 'P'), ((69, ' '), 'S'), ((70, ' '), 'G'), ((71, ' '), 'I'), ((72, ' '), 'S'), ((73, ' '), '-'), ((74, ' '), 'E'), ((75, ' '), 'R'), ((76, ' '), 'F'), ((77, ' '), 'S'), ((78, ' '), 'A'), ((79, ' '), 'S'), ((80, ' '), 'R'), ((81, ' '), '-'), ((82, ' '), '-'), ((83, ' '), 'S'), ((84, ' '), 'G'), ((85, ' '), 'N'), ((86, ' '), 'T'), ((87, ' '), 'A'), ((88, ' '), 'S'), ((89, ' '), 'L'), ((90, ' '), 'T'), ((91, ' '), 'I'), ((92, ' '), 'T'), ((93, ' '), 'G'), ((94, ' '), 'L'), ((95, ' '), 'Q'), ((96, ' '), 'P'), ((97, ' '), 'E'), ((98, ' '), 'D'), ((99, ' '), 'E'), ((100, ' '), 'A'), ((101, ' '), 'D'), ((102, ' '), 'Y'), ((103, ' '), 'Y'), ((104, ' '), 'C'), ((105, ' '), 'S'), ((106, ' '), 'T'), ((107, ' '), 'W'), ((108, ' '), 'D'), ((109, ' '), 'S'), ((110, ' '), 'S'), ((111, ' '), '-'), ((112, ' '), '-'), ((113, ' '), 'L'), ((114, ' '), 'S'), ((115, ' '), 'A'), ((116, ' '), 'V'), ((117, ' '), 'V'), ((118, ' '), 'F'), ((119, ' '), 'G'), ((120, ' '), 'G'), ((121, ' '), 'G'), ((122, ' '), 'T'), ((123, ' '), 'K'), ((124, ' '), 'L'), ((125, ' '), 'T'), ((126, ' '), 'V'), ((127, ' '), 'L'), ((128, ' '), '-')], 'chain_type': 'L', 'score': 30.606189727783203, 'query_start': 9, 'query_end': 118, 'error': None, 'scheme': 'imgt'}\n", - "{'numbering': [((1, ' '), 'Q'), ((2, ' '), 'V'), ((3, ' '), 'Q'), ((4, ' '), 'L'), ((5, ' '), 'V'), ((6, ' '), 'Q'), ((7, ' '), 'S'), ((8, ' '), 'G'), ((9, ' '), 'A'), ((10, ' '), '-'), ((11, ' '), 'E'), ((12, ' '), 'V'), ((13, ' '), 'K'), ((14, ' '), 'K'), ((15, ' '), 'P'), ((16, ' '), 'G'), ((17, ' '), 'S'), ((18, ' '), 'S'), ((19, ' '), 'V'), ((20, ' '), 'K'), ((21, ' '), 'V'), ((22, ' '), 'S'), ((23, ' '), 'C'), ((24, ' '), 'K'), ((25, ' '), 'A'), ((26, ' '), 'S'), ((27, ' '), 'G'), ((28, ' '), 'G'), ((29, ' '), 'T'), ((30, ' '), 'F'), ((31, ' '), '-'), ((32, ' '), '-'), ((33, ' '), '-'), ((34, ' '), '-'), ((35, ' '), 'S'), ((36, ' '), 'S'), ((37, ' '), 'Y'), ((38, ' '), 'A'), ((39, ' '), 'I'), ((40, ' '), 'S'), ((41, ' '), 'W'), ((42, ' '), 'V'), ((43, ' '), 'R'), ((44, ' '), 'Q'), ((45, ' '), 'A'), ((46, ' '), 'P'), ((47, ' '), 'G'), ((48, ' '), 'Q'), ((49, ' '), 'G'), ((50, ' '), 'L'), ((51, ' '), 'E'), ((52, ' '), 'W'), ((53, ' '), 'M'), ((54, ' '), 'G'), ((55, ' '), 'G'), ((56, ' '), 'I'), ((57, ' '), 'I'), ((58, ' '), 'P'), ((59, ' '), 'I'), ((60, ' '), '-'), ((61, ' '), '-'), ((62, ' '), 'F'), ((63, ' '), 'G'), ((64, ' '), 'T'), ((65, ' '), 'A'), ((66, ' '), 'N'), ((67, ' '), 'Y'), ((68, ' '), 'A'), ((69, ' '), 'Q'), ((70, ' '), 'K'), ((71, ' '), 'F'), ((72, ' '), 'Q'), ((73, ' '), '-'), ((74, ' '), 'G'), ((75, ' '), 'R'), ((76, ' '), 'V'), ((77, ' '), 'T'), ((78, ' '), 'I'), ((79, ' '), 'T'), ((80, ' '), 'A'), ((81, ' '), 'D'), ((82, ' '), 'E'), ((83, ' '), 'S'), ((84, ' '), 'T'), ((85, ' '), 'S'), ((86, ' '), 'T'), ((87, ' '), 'A'), ((88, ' '), 'Y'), ((89, ' '), 'M'), ((90, ' '), 'E'), ((91, ' '), 'L'), ((92, ' '), 'S'), ((93, ' '), 'S'), ((94, ' '), 'L'), ((95, ' '), 'R'), ((96, ' '), 'S'), ((97, ' '), 'E'), ((98, ' '), 'D'), ((99, ' '), 'T'), ((100, ' '), 'A'), ((101, ' '), 'V'), ((102, ' '), 'Y'), ((103, ' '), 'Y'), ((104, ' '), 'C'), ((105, ' '), 'A'), ((106, ' '), 'R'), ((107, ' '), 'E'), ((108, ' '), 'P'), ((109, ' '), 'D'), ((110, ' '), 'Y'), ((111, ' '), 'Y'), ((111, 'A'), 'D'), ((111, 'B'), 'S'), ((111, 'C'), 'S'), ((112, 'D'), 'G'), ((112, 'C'), 'Y'), ((112, 'B'), 'Y'), ((112, 'A'), 'P'), ((112, ' '), 'I'), ((113, ' '), 'D'), ((114, ' '), 'A'), ((115, ' '), 'F'), ((116, ' '), 'D'), ((117, ' '), 'I'), ((118, ' '), 'W'), ((119, ' '), 'G'), ((120, ' '), 'Q'), ((121, ' '), 'G'), ((122, ' '), 'T'), ((123, ' '), 'T'), ((124, ' '), 'V'), ((125, ' '), 'T'), ((126, ' '), 'V'), ((127, ' '), 'S'), ((128, ' '), 'S')], 'chain_type': 'H', 'score': 30.965309143066406, 'query_start': 0, 'query_end': 126, 'error': None, 'scheme': 'imgt'}\n" - ] - } - ], + "outputs": [], "source": [ "##### It can also take a list of tuples... #####\n", "from anarcii import Anarcii\n", @@ -124,26 +151,9 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using device CPU with 12 CPUs\n", - "\n", - "[([((1, ' '), '-'), ((2, ' '), '-'), ((3, ' '), '-'), ((4, ' '), '-'), ((5, ' '), '-'), ((6, ' '), '-'), ((7, ' '), '-'), ((8, ' '), '-'), ((9, ' '), '-'), ((10, ' '), '-'), ((11, ' '), '-'), ((12, ' '), '-'), ((13, ' '), '-'), ((14, ' '), '-'), ((15, ' '), '-'), ((16, ' '), '-'), ((17, ' '), '-'), ((18, ' '), 'S'), ((19, ' '), 'V'), ((20, ' '), 'K'), ((21, ' '), 'V'), ((22, ' '), 'S'), ((23, ' '), 'C'), ((24, ' '), 'T'), ((25, ' '), 'S'), ((26, ' '), 'S'), ((27, ' '), 'E'), ((28, ' '), 'V'), ((29, ' '), 'T'), ((30, ' '), 'F'), ((31, ' '), '-'), ((32, ' '), '-'), ((33, ' '), '-'), ((34, ' '), '-'), ((35, ' '), 'S'), ((36, ' '), 'S'), ((37, ' '), 'F'), ((38, ' '), 'A'), ((39, ' '), 'I'), ((40, ' '), 'S'), ((41, ' '), 'W'), ((42, ' '), 'V'), ((43, ' '), 'R'), ((44, ' '), 'Q'), ((45, ' '), 'A'), ((46, ' '), 'P'), ((47, ' '), 'G'), ((48, ' '), 'Q'), ((49, ' '), 'G'), ((50, ' '), 'L'), ((51, ' '), 'E'), ((52, ' '), 'W'), ((53, ' '), 'L'), ((54, ' '), 'G'), ((55, ' '), 'G'), ((56, ' '), 'I'), ((57, ' '), 'S'), ((58, ' '), 'P'), ((59, ' '), 'M'), ((60, ' '), '-'), ((61, ' '), '-'), ((62, ' '), 'F'), ((63, ' '), 'G'), ((64, ' '), 'T'), ((65, ' '), 'P'), ((66, ' '), 'N'), ((67, ' '), 'Y'), ((68, ' '), 'A'), ((69, ' '), 'Q'), ((70, ' '), 'K'), ((71, ' '), 'F'), ((72, ' '), 'Q'), ((73, ' '), '-'), ((74, ' '), 'G'), ((75, ' '), 'R'), ((76, ' '), 'V'), ((77, ' '), 'T'), ((78, ' '), 'I'), ((79, ' '), 'T'), ((80, ' '), 'A'), ((81, ' '), 'D'), ((82, ' '), 'Q'), ((83, ' '), 'S'), ((84, ' '), 'T'), ((85, ' '), 'R'), ((86, ' '), 'T'), ((87, ' '), 'A'), ((88, ' '), 'Y'), ((89, ' '), 'M'), ((90, ' '), 'D'), ((91, ' '), 'L'), ((92, ' '), 'R'), ((93, ' '), 'S'), ((94, ' '), 'L'), ((95, ' '), 'R'), ((96, ' '), 'S'), ((97, ' '), 'E'), ((98, ' '), 'D'), ((99, ' '), 'T'), ((100, ' '), 'A'), ((101, ' '), 'V'), ((102, ' '), 'Y'), ((103, ' '), 'Y'), ((104, ' '), 'C'), ((105, ' '), 'A'), ((106, ' '), 'R'), ((107, ' '), 'S'), ((108, ' '), 'P'), ((109, ' '), 'S'), ((110, ' '), 'Y'), ((111, ' '), 'I'), ((111, 'A'), 'C'), ((111, 'B'), 'S'), ((112, 'B'), 'G'), ((112, 'A'), 'G'), ((112, ' '), 'T'), ((113, ' '), 'C'), ((114, ' '), 'V'), ((115, ' '), 'F'), ((116, ' '), 'D'), ((117, ' '), 'H'), ((118, ' '), 'W'), ((119, ' '), 'G'), ((120, ' '), 'Q'), ((121, ' '), 'G'), ((122, ' '), 'T'), ((123, ' '), 'L'), ((124, ' '), 'V'), ((125, ' '), 'T'), ((126, ' '), 'V'), ((127, ' '), 'S'), ((128, ' '), 'S')], 0, 107)]\n", - "[{'chain_type': 'H', 'scheme': 'imgt', 'query_name': 'Sequence 1', 'query_start': 0, 'query_end': 107}]\n", - "None\n", - "\n", - "### A failed sequence should return None. ###\n", - "None\n", - "[{'chain_type': 'F', 'scheme': 'imgt', 'query_name': 'Sequence 2', 'query_start': None, 'query_end': None}]\n", - "None\n" - ] - } - ], + "outputs": [], "source": [ "### Want to have output that looks like original ANARCI? Use legacy mode. ###\n", "from anarcii import Anarcii\n", @@ -180,33 +190,9 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using device CUDA with 12 CPUs\n", - "\n", - " sp|P01629|KV2A4_MOUSE Ig kappa chain V-II region 2S1.3 OS=Mus musculus OX=10090 PE=1 SV=1 \n", - " K 30.36195182800293 \n", - " [((1, ' '), 'D'), ((2, ' '), 'I'), ((3, ' '), 'V'), ((4, ' '), 'M'), ((5, ' '), 'T'), ((6, ' '), 'Q'), ((7, ' '), 'A'), ((8, ' '), 'A'), ((9, ' '), 'F'), ((10, ' '), 'S'), ((11, ' '), 'N'), ((12, ' '), 'P'), ((13, ' '), 'V'), ((14, ' '), 'T'), ((15, ' '), 'L'), ((16, ' '), 'G'), ((17, ' '), 'T'), ((18, ' '), 'S'), ((19, ' '), 'A'), ((20, ' '), 'S'), ((21, ' '), 'F'), ((22, ' '), 'S'), ((23, ' '), 'C'), ((24, ' '), 'R'), ((25, ' '), 'S'), ((26, ' '), 'S'), ((27, ' '), 'K'), ((28, ' '), 'S'), ((29, ' '), 'L'), ((30, ' '), 'Q'), ((31, ' '), 'Q'), ((32, ' '), 'S'), ((33, ' '), '-'), ((34, ' '), 'K'), ((35, ' '), 'G'), ((36, ' '), 'I'), ((37, ' '), 'T'), ((38, ' '), 'Y'), ((39, ' '), 'L'), ((40, ' '), 'Y'), ((41, ' '), 'W'), ((42, ' '), 'Y'), ((43, ' '), 'L'), ((44, ' '), 'Q'), ((45, ' '), 'K'), ((46, ' '), 'P'), ((47, ' '), 'G'), ((48, ' '), 'Q'), ((49, ' '), 'S'), ((50, ' '), 'P'), ((51, ' '), 'Q'), ((52, ' '), 'L'), ((53, ' '), 'L'), ((54, ' '), 'I'), ((55, ' '), 'Y'), ((56, ' '), 'Q'), ((57, ' '), 'M'), ((58, ' '), '-'), ((59, ' '), '-'), ((60, ' '), '-'), ((61, ' '), '-'), ((62, ' '), '-'), ((63, ' '), '-'), ((64, ' '), '-'), ((65, ' '), 'S'), ((66, ' '), 'N'), ((67, ' '), 'L'), ((68, ' '), 'A'), ((69, ' '), 'S'), ((70, ' '), 'G'), ((71, ' '), 'V'), ((72, ' '), 'P'), ((73, ' '), '-'), ((74, ' '), 'D'), ((75, ' '), 'R'), ((76, ' '), 'F'), ((77, ' '), 'S'), ((78, ' '), 'G'), ((79, ' '), 'S'), ((80, ' '), 'G'), ((81, ' '), '-'), ((82, ' '), '-'), ((83, ' '), 'S'), ((84, ' '), 'G'), ((85, ' '), 'T'), ((86, ' '), 'D'), ((87, ' '), 'F'), ((88, ' '), 'T'), ((89, ' '), 'L'), ((90, ' '), 'R'), ((91, ' '), 'I'), ((92, ' '), 'S'), ((93, ' '), 'R'), ((94, ' '), 'V'), ((95, ' '), 'E'), ((96, ' '), 'A'), ((97, ' '), 'E'), ((98, ' '), 'D'), ((99, ' '), 'V'), ((100, ' '), 'G'), ((101, ' '), 'V'), ((102, ' '), 'Y'), ((103, ' '), 'Y'), ((104, ' '), 'C'), ((105, ' '), 'A'), ((106, ' '), 'N'), ((107, ' '), 'L'), ((108, ' '), 'Q'), ((109, ' '), 'E'), ((110, ' '), '-'), ((111, ' '), '-'), ((112, ' '), '-'), ((113, ' '), '-'), ((114, ' '), 'L'), ((115, ' '), 'P'), ((116, ' '), 'Y'), ((117, ' '), 'T'), ((118, ' '), 'F'), ((119, ' '), 'G'), ((120, ' '), 'G'), ((121, ' '), 'G'), ((122, ' '), 'T'), ((123, ' '), 'K'), ((124, ' '), 'L'), ((125, ' '), 'E'), ((126, ' '), 'I'), ((127, ' '), 'K'), ((128, ' '), '-')]\n", - "\n", - " sp|P01630|KV2A6_MOUSE Ig kappa chain V-II region 7S34.1 OS=Mus musculus OX=10090 PE=1 SV=1 \n", - " K 30.4111328125 \n", - " [((1, ' '), 'D'), ((2, ' '), 'I'), ((3, ' '), 'V'), ((4, ' '), 'M'), ((5, ' '), 'T'), ((6, ' '), 'Q'), ((7, ' '), 'T'), ((8, ' '), 'A'), ((9, ' '), 'P'), ((10, ' '), 'S'), ((11, ' '), 'A'), ((12, ' '), 'L'), ((13, ' '), 'V'), ((14, ' '), 'T'), ((15, ' '), 'P'), ((16, ' '), 'G'), ((17, ' '), 'E'), ((18, ' '), 'S'), ((19, ' '), 'V'), ((20, ' '), 'S'), ((21, ' '), 'I'), ((22, ' '), 'S'), ((23, ' '), 'C'), ((24, ' '), 'R'), ((25, ' '), 'S'), ((26, ' '), 'S'), ((27, ' '), 'K'), ((28, ' '), 'S'), ((29, ' '), 'L'), ((30, ' '), 'L'), ((31, ' '), 'H'), ((32, ' '), 'S'), ((33, ' '), '-'), ((34, ' '), 'N'), ((35, ' '), 'G'), ((36, ' '), 'N'), ((37, ' '), 'T'), ((38, ' '), 'Y'), ((39, ' '), 'L'), ((40, ' '), 'Y'), ((41, ' '), 'W'), ((42, ' '), 'F'), ((43, ' '), 'L'), ((44, ' '), 'Q'), ((45, ' '), 'R'), ((46, ' '), 'P'), ((47, ' '), 'G'), ((48, ' '), 'Q'), ((49, ' '), 'C'), ((50, ' '), 'P'), ((51, ' '), 'Q'), ((52, ' '), 'L'), ((53, ' '), 'L'), ((54, ' '), 'I'), ((55, ' '), 'Y'), ((56, ' '), 'R'), ((57, ' '), 'M'), ((58, ' '), '-'), ((59, ' '), '-'), ((60, ' '), '-'), ((61, ' '), '-'), ((62, ' '), '-'), ((63, ' '), '-'), ((64, ' '), '-'), ((65, ' '), 'S'), ((66, ' '), 'N'), ((67, ' '), 'L'), ((68, ' '), 'A'), ((69, ' '), 'S'), ((70, ' '), 'G'), ((71, ' '), 'V'), ((72, ' '), 'P'), ((73, ' '), '-'), ((74, ' '), 'D'), ((75, ' '), 'R'), ((76, ' '), 'F'), ((77, ' '), 'S'), ((78, ' '), 'G'), ((79, ' '), 'S'), ((80, ' '), 'G'), ((81, ' '), '-'), ((82, ' '), '-'), ((83, ' '), 'S'), ((84, ' '), 'G'), ((85, ' '), 'T'), ((86, ' '), 'A'), ((87, ' '), 'F'), ((88, ' '), 'T'), ((89, ' '), 'L'), ((90, ' '), 'R'), ((91, ' '), 'I'), ((92, ' '), 'S'), ((93, ' '), 'R'), ((94, ' '), 'V'), ((95, ' '), 'E'), ((96, ' '), 'A'), ((97, ' '), 'E'), ((98, ' '), 'D'), ((99, ' '), 'V'), ((100, ' '), 'G'), ((101, ' '), 'V'), ((102, ' '), 'Y'), ((103, ' '), 'Y'), ((104, ' '), 'C'), ((105, ' '), 'M'), ((106, ' '), 'Q'), ((107, ' '), 'Q'), ((108, ' '), 'R'), ((109, ' '), 'E'), ((110, ' '), '-'), ((111, ' '), '-'), ((112, ' '), '-'), ((113, ' '), '-'), ((114, ' '), 'Y'), ((115, ' '), 'P'), ((116, ' '), 'Y'), ((117, ' '), 'T'), ((118, ' '), 'F'), ((119, ' '), 'G'), ((120, ' '), 'G'), ((121, ' '), 'G'), ((122, ' '), 'T'), ((123, ' '), 'K'), ((124, ' '), 'L'), ((125, ' '), 'E'), ((126, ' '), 'I'), ((127, ' '), 'K'), ((128, ' '), '-')]\n", - "\n", - " sp|P01631|KV2A7_MOUSE Ig kappa chain V-II region 26-10 OS=Mus musculus OX=10090 PE=1 SV=1 \n", - " K 30.665319442749023 \n", - " [((1, ' '), 'D'), ((2, ' '), 'V'), ((3, ' '), 'V'), ((4, ' '), 'M'), ((5, ' '), 'T'), ((6, ' '), 'Q'), ((7, ' '), 'T'), ((8, ' '), 'P'), ((9, ' '), 'L'), ((10, ' '), 'S'), ((11, ' '), 'L'), ((12, ' '), 'P'), ((13, ' '), 'V'), ((14, ' '), 'S'), ((15, ' '), 'L'), ((16, ' '), 'G'), ((17, ' '), 'D'), ((18, ' '), 'Q'), ((19, ' '), 'A'), ((20, ' '), 'S'), ((21, ' '), 'I'), ((22, ' '), 'S'), ((23, ' '), 'C'), ((24, ' '), 'R'), ((25, ' '), 'S'), ((26, ' '), 'S'), ((27, ' '), 'Q'), ((28, ' '), 'S'), ((29, ' '), 'L'), ((30, ' '), 'V'), ((31, ' '), 'H'), ((32, ' '), 'S'), ((33, ' '), '-'), ((34, ' '), 'N'), ((35, ' '), 'G'), ((36, ' '), 'N'), ((37, ' '), 'T'), ((38, ' '), 'Y'), ((39, ' '), 'L'), ((40, ' '), 'N'), ((41, ' '), 'W'), ((42, ' '), 'Y'), ((43, ' '), 'L'), ((44, ' '), 'Q'), ((45, ' '), 'K'), ((46, ' '), 'A'), ((47, ' '), 'G'), ((48, ' '), 'Q'), ((49, ' '), 'S'), ((50, ' '), 'P'), ((51, ' '), 'K'), ((52, ' '), 'L'), ((53, ' '), 'L'), ((54, ' '), 'I'), ((55, ' '), 'Y'), ((56, ' '), 'K'), ((57, ' '), 'V'), ((58, ' '), '-'), ((59, ' '), '-'), ((60, ' '), '-'), ((61, ' '), '-'), ((62, ' '), '-'), ((63, ' '), '-'), ((64, ' '), '-'), ((65, ' '), 'S'), ((66, ' '), 'N'), ((67, ' '), 'R'), ((68, ' '), 'F'), ((69, ' '), 'S'), ((70, ' '), 'G'), ((71, ' '), 'V'), ((72, ' '), 'P'), ((73, ' '), '-'), ((74, ' '), 'D'), ((75, ' '), 'R'), ((76, ' '), 'F'), ((77, ' '), 'S'), ((78, ' '), 'G'), ((79, ' '), 'S'), ((80, ' '), 'G'), ((81, ' '), '-'), ((82, ' '), '-'), ((83, ' '), 'S'), ((84, ' '), 'G'), ((85, ' '), 'T'), ((86, ' '), 'D'), ((87, ' '), 'F'), ((88, ' '), 'T'), ((89, ' '), 'L'), ((90, ' '), 'K'), ((91, ' '), 'I'), ((92, ' '), 'S'), ((93, ' '), 'R'), ((94, ' '), 'V'), ((95, ' '), 'E'), ((96, ' '), 'A'), ((97, ' '), 'E'), ((98, ' '), 'D'), ((99, ' '), 'L'), ((100, ' '), 'G'), ((101, ' '), 'I'), ((102, ' '), 'Y'), ((103, ' '), 'F'), ((104, ' '), 'C'), ((105, ' '), 'S'), ((106, ' '), 'Q'), ((107, ' '), 'T'), ((108, ' '), 'T'), ((109, ' '), 'H'), ((110, ' '), '-'), ((111, ' '), '-'), ((112, ' '), '-'), ((113, ' '), '-'), ((114, ' '), 'V'), ((115, ' '), 'P'), ((116, ' '), 'P'), ((117, ' '), 'T'), ((118, ' '), 'F'), ((119, ' '), 'G'), ((120, ' '), 'G'), ((121, ' '), 'G'), ((122, ' '), 'T'), ((123, ' '), 'K'), ((124, ' '), 'L'), ((125, ' '), 'E'), ((126, ' '), 'I'), ((127, ' '), 'K'), ((128, ' '), '-')]\n", - "\n", - " sp|P01691|KV10_RABIT Ig kappa chain V region 12F2 (Fragment) OS=Oryctolagus cuniculus OX=9986 PE=2 SV=1 \n", - " K 29.508617401123047 \n", - " [((1, ' '), 'A'), ((2, ' '), 'Y'), ((3, ' '), 'D'), ((4, ' '), 'M'), ((5, ' '), 'T'), ((6, ' '), 'Q'), ((7, ' '), 'T'), ((8, ' '), 'P'), ((9, ' '), 'A'), ((10, ' '), 'S'), ((11, ' '), 'V'), ((12, ' '), 'E'), ((13, ' '), 'V'), ((14, ' '), 'A'), ((15, ' '), 'V'), ((16, ' '), 'G'), ((17, ' '), 'G'), ((18, ' '), 'T'), ((19, ' '), 'V'), ((20, ' '), 'T'), ((21, ' '), 'I'), ((22, ' '), 'K'), ((23, ' '), 'C'), ((24, ' '), 'Q'), ((25, ' '), 'A'), ((26, ' '), 'S'), ((27, ' '), 'Q'), ((28, ' '), 'S'), ((29, ' '), 'I'), ((30, ' '), '-'), ((31, ' '), '-'), ((32, ' '), '-'), ((33, ' '), '-'), ((34, ' '), '-'), ((35, ' '), '-'), ((36, ' '), 'S'), ((37, ' '), 'T'), ((38, ' '), 'Y'), ((39, ' '), 'L'), ((40, ' '), 'S'), ((41, ' '), 'W'), ((42, ' '), 'Y'), ((43, ' '), 'Q'), ((44, ' '), 'Q'), ((45, ' '), 'K'), ((46, ' '), 'P'), ((47, ' '), 'G'), ((48, ' '), 'Q'), ((49, ' '), 'R'), ((50, ' '), 'P'), ((51, ' '), 'K'), ((52, ' '), 'L'), ((53, ' '), 'L'), ((54, ' '), 'I'), ((55, ' '), 'Y'), ((56, ' '), 'R'), ((57, ' '), 'A'), ((58, ' '), '-'), ((59, ' '), '-'), ((60, ' '), '-'), ((61, ' '), '-'), ((62, ' '), '-'), ((63, ' '), '-'), ((64, ' '), '-'), ((65, ' '), 'S'), ((66, ' '), 'T'), ((67, ' '), 'L'), ((68, ' '), 'A'), ((69, ' '), 'S'), ((70, ' '), 'G'), ((71, ' '), 'V'), ((72, ' '), 'S'), ((73, ' '), '-'), ((74, ' '), 'S'), ((75, ' '), 'R'), ((76, ' '), 'F'), ((77, ' '), 'K'), ((78, ' '), 'G'), ((79, ' '), 'S'), ((80, ' '), 'G'), ((81, ' '), '-'), ((82, ' '), '-'), ((83, ' '), 'S'), ((84, ' '), 'G'), ((85, ' '), 'T'), ((86, ' '), 'E'), ((87, ' '), 'F'), ((88, ' '), 'T'), ((89, ' '), 'L'), ((90, ' '), 'T'), ((91, ' '), 'I'), ((92, ' '), 'S'), ((93, ' '), 'G'), ((94, ' '), 'V'), ((95, ' '), 'E'), ((96, ' '), 'C'), ((97, ' '), 'A'), ((98, ' '), 'D'), ((99, ' '), 'A'), ((100, ' '), 'A'), ((101, ' '), 'T'), ((102, ' '), 'Y'), ((103, ' '), 'Y'), ((104, ' '), 'C'), ((105, ' '), 'Q'), ((106, ' '), 'Q'), ((107, ' '), 'G'), ((108, ' '), 'W'), ((109, ' '), 'S'), ((110, ' '), 'S'), ((111, ' '), '-'), ((112, ' '), 'S'), ((113, ' '), 'N'), ((114, ' '), 'V'), ((115, ' '), 'E'), ((116, ' '), 'N'), ((117, ' '), 'V'), ((118, ' '), 'F'), ((119, ' '), 'G'), ((120, ' '), 'G'), ((121, ' '), 'G'), ((122, ' '), 'T'), ((123, ' '), 'E'), ((124, ' '), 'V'), ((125, ' '), 'V'), ((126, ' '), 'V'), ((127, ' '), 'K'), ((128, ' '), '-')]\n" - ] - } - ], + "outputs": [], "source": [ "### Can also be passed a fasta file. ###\n", "from anarcii import Anarcii\n", @@ -230,17 +216,9 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Last output saved to tmp/test_write_csv.csv in scheme: None.\n" - ] - } - ], + "outputs": [], "source": [ "# Save the last numbered seqs to a csv.\n", "model.to_csv(\"tmp/test_write_csv.csv\")" @@ -249,7 +227,7 @@ ], "metadata": { "kernelspec": { - "display_name": "mobydick", + "display_name": "test_anarcii", "language": "python", "name": "python3" }, @@ -263,7 +241,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.2" + "version": "3.12.12" } }, "nbformat": 4, diff --git a/src/anarcii/inference/model_runner.py b/src/anarcii/inference/model_runner.py index c8f29fc5..045428b8 100644 --- a/src/anarcii/inference/model_runner.py +++ b/src/anarcii/inference/model_runner.py @@ -45,12 +45,13 @@ class ModelRunner: """ - def __init__(self, sequence_type, mode, batch_size, device, verbose): + def __init__(self, sequence_type, mode, batch_size, device, verbose, return_logits): self.type = sequence_type.lower() self.mode = mode.lower() self.batch_size = batch_size self.device = device self.verbose = verbose + self.return_logits = return_logits self.cut_off = CUTOFF_SCORE if self.type == "antibody": @@ -113,7 +114,11 @@ def __call__(self, tokenised_seqs: dict[str, TokenisedSequence], offsets): # NB: Provide a list of recommended batch sizes based on RAM and architecture dl = dataloader(self.batch_size, list(tokenised_seqs.values())) - numbering = dict(zip(tokenised_seqs, self._predict_numbering(dl))) + + if self.return_logits: + numbering = dict(zip(tokenised_seqs, self._predict_numbering_logits(dl))) + else: + numbering = dict(zip(tokenised_seqs, self._predict_numbering(dl))) # Add offsets, where necessary. for key, value in offsets.items(): @@ -602,3 +607,480 @@ def _predict_numbering(self, dl): ) return numbering + + def _predict_numbering_logits(self, dl): + """ + AS ABOVE - WITH LOGITS + """ + if self.verbose: + print(f"Making predictions on {len(dl)} batches.") + + numbering = [] + + num = 0 + + pad_token = self.pad_token + sos_token = self.sos_token + eos_token = self.eos_token + skip_token = self.skip_token + x_token = self.x_token + + ### 1 RUN AUTOREGRESSIVE INFERENCE LOOP OVER BATCHES + + with torch.no_grad(): + for X in dl: + src = X.to(self.device) + batch_size = src.shape[0] + trg_len = src.shape[1] + 1 # Need to add 1 to include chain ID + + src_mask = self.model.make_src_mask(src) + enc_src = self.model.encoder(src, src_mask) + + input = src[:, 0].unsqueeze(1) + mask_input = src[:, 0].unsqueeze(1) + cache = None + + max_input = torch.zeros( + batch_size, trg_len, device=self.device, dtype=torch.long + ) + max_input[:, 0] = src[:, 0] + + scores = torch.zeros( + batch_size, trg_len - 1, device=self.device, dtype=torch.float + ) + + for t in range(1, trg_len): + trg_pad_mask, trg_causal_mask = self.model.make_trg_mask(mask_input) + + output, cache = self.model.decoder( + input, enc_src, trg_pad_mask, trg_causal_mask, src_mask, cache + ) + + pred_token = output.argmax(2)[:, -1].unsqueeze(1) + + scores[:, t - 1 : t] = output.topk(1, dim=2).values.squeeze(1) + + max_input[:, t : t + 1] = pred_token + + mask_input = max_input[:, : t + 1] + + input = pred_token + + ### 2 tokenise and transfer the batch to cpu + + src_tokens = self.sequence_tokeniser.tokens[src.to("cpu")] + pred_tokens = self.number_tokeniser.tokens[max_input.to("cpu")] + + ### 3 work out IMGT integer values predicted by model + + scores = scores.squeeze( + -1 + ) # Remove the last dim; shape becomes [batch_size, trg_len] + + mask = ( + (max_input != skip_token) + & (max_input != x_token) + & (max_input != pad_token) + & (max_input != sos_token) + ) + + mask2 = ( + (max_input != skip_token) + # & (max_input != x_token) # Keep X tokens for logit return + & (max_input != pad_token) + & (max_input != sos_token) + ) + + ### 4 Find the predicted end of sequence by model, + # find actual end of input + + # Find first `True` (eos_token) along last dim (trg_len) + eos_positions = max_input == eos_token + # Get the indices (trg_len), for each batch + first_eos_positions = torch.argmax(eos_positions.to(torch.int64), dim=1) + + # Same logic to find SRC EOS position + src_eos_matrix = src == eos_token + src_eos_positions = torch.argmax(src_eos_matrix.to(torch.int64), dim=1) + + # Check if no EOS token is found for each batch + no_eos_found = ~(eos_positions.any(dim=1)) + # True if no EOS token is found in the row + # Set the position to trg_len if no EOS is found + + first_eos_positions[no_eos_found] = torch.tensor( + trg_len - 1, device=self.device + ) + + ### 5 Iterate through each seq in batch + + for batch_no in range(batch_size): + error_occurred = False + num += 1 + error_msg = None + + # eos_position = first_eos_positions[batch_no] + + # Code fix here. Ensure that if model has numbered beyond SRC EOS + # Then stop. + eos_position = min( + first_eos_positions[batch_no], src_eos_positions[batch_no] + 1 + ) + + valid_indices = torch.arange(eos_position, device=self.device)[ + mask[batch_no, :eos_position] + ] + valid_scores = scores[batch_no, valid_indices] + + all_indices = torch.arange(eos_position, device=self.device)[ + mask2[batch_no, :eos_position] + ] + all_scores = scores[batch_no, all_indices] + + ### 5A Check score is valid + + if len(valid_indices) >= 50: + normalized_score = valid_scores.mean().item() + else: + normalized_score = 0.0 + error_msg = "Less than 50 non insertion residues numbered." + + if normalized_score < self.cut_off: + numbering.append( + { + "numbering": None, + "chain_type": "F", + "score": normalized_score, + "query_start": None, + "query_end": None, + "error": error_msg or "Score less than cut off.", + "scheme": "imgt", + "logits": all_scores, + } + ) + # skip the rest of the loop. + continue + + ### 5B Begin populating the numbering labels but iterating over + # each seq + + residues, nums = [], [] + backfill_residues = [] + + started = False + in_x_run, x_count = False, 0 + start_index = None + end_index = None + + # SRC is missing chain token + 1 + src_eos_position = src_eos_positions[batch_no].item() + 1 + eos_position = eos_position.item() + + for seq_position in range(2, eos_position): + ### Break at actual EOS in the input sequence + if src_tokens[batch_no, seq_position - 1] == "": + # The end index position in the sequence + # -3 is to accomodate the shifted register due to - + # the , chain token and python zero + end_index = seq_position - 3 + break + + ### Break at SKIP tokens if numbering has started + elif ( + pred_tokens[batch_no, seq_position] == "" and started + ): # Break if hitting a skip post at the end. + end_index = seq_position - 3 + break + + ### Work out when numbering begins, ignore SKIP tokens if + ### not started + elif ( + pred_tokens[batch_no, seq_position] == "" + and not started + ): # Append as backfill up to the start. + backfill_residues.append( + src_tokens[batch_no, seq_position - 1] + ) + continue + + ### If an instertion X is called, log as in a run of X + elif pred_tokens[batch_no, seq_position] == "X": + x_count += 1 + in_x_run = True + + ### If breaking out of a X run, construct the labels + elif ( + isinstance(pred_tokens[batch_no, seq_position], int) + and in_x_run + ): + # This code breaks if we have a junk seq that + # has predicted runs of X (insertions) + # that are not bookended with integers + try: + construction = build_inward_list( + length=x_count, + # number before X began + start_num=int( + pred_tokens[ + batch_no, (seq_position - (x_count + 1)) + ] + ), + # current number + end_num=int(pred_tokens[batch_no, seq_position]), + ) + + # Add the construction over the previous sequence + nums[(seq_position - x_count) : seq_position] = ( + construction + ) + # add the end + nums.append( + (int(pred_tokens[batch_no, seq_position]), " ") + ) + in_x_run = False + x_count = 0 + + except ValueError as e: + # Capture the error message from the exception + captured_error = str(e) + numbering.append( + { + "numbering": None, + "chain_type": "F", + "score": normalized_score, + "query_start": None, + "query_end": None, + "error": "Could not apply numbering: " + f"{captured_error}", + "scheme": "imgt", + "logits": all_scores, + } + ) + error_occurred = True + break + + ### No conditions have been found - it is a number label, + # append to nums + else: + try: + nums.append( + (int(pred_tokens[batch_no, seq_position]), " ") + ) + except ValueError as e: + # Capture the error message from the exception + captured_error = str(e) + numbering.append( + { + "numbering": None, + "chain_type": "F", + "score": normalized_score, + "query_start": None, + "query_end": None, + "error": "Could not apply numbering: " + f"{captured_error}", + "scheme": "imgt", + "logits": all_scores, + } + ) + error_occurred = True + break + + ### After each iteration through the sequence append the + # sequence residue + residues.append(src_tokens[batch_no, seq_position - 1]) + + if not started: + start_index = seq_position - 2 + started = True + + if error_occurred: + continue + + # Assign an end index before entering forwardfill + if not end_index: + end_index = eos_position - 3 + # eos_position - 1: Moves to the token before , + # excluding the marker itself. + # Subtracting an additional 1 for SOS and 1 for CLS: + # Adjusts further to skip over these two tokens. + + ## Check for duplicates + if len(nums) != len(set(nums)): + numbering.append( + { + "numbering": None, + "chain_type": "F", + "score": normalized_score, + "query_start": None, + "query_end": None, + "error": "Model predicted duplicate numbers", + "scheme": "imgt", + "logits": all_scores, + } + ) + # break out of the loop + continue + + ### 5C Perform forward fill to end of the sequence, if + # missed numbering + + ## ANARCII sometimes doesn't continue numbering to end of seq + # Solution: Identify residues remaining after the EOS + # Decide forward fill to 127 (KL) /128 (H) needs to occur. + + # The last number depends on chain type - check type here. + if pred_tokens[batch_no, 1] in ["H", "A", "G"]: + last_num = 128 + else: + last_num = 127 + + try: + last_predicted_num = int( + pred_tokens[batch_no, eos_position - 1] + ) + except ValueError: + last_predicted_num = last_num + + ### DEBUG ONLY ### + # if src_tokens[batch_no, eos_position - 1] == "": + # print(src_tokens[batch_no, :]) + # print(pred_tokens[batch_no, :]) + + # print(src_tokens[batch_no, eos_position - 1]) + # print(last_num, last_predicted_num) + # print(last_predicted_num != last_num) + + if ( + src_tokens[batch_no, eos_position - 1] not in ["", ""] + and last_predicted_num != last_num + and last_predicted_num > 119 + ): + # How far is EOS from 128? + missing_count = last_num - int( + pred_tokens[batch_no, eos_position - 1] + ) + + # How much is left of the source to number? + seq_remainder = int(src_eos_position) - int(eos_position) + + missing_end_nums = [ + (x, " ") + for x in range(last_predicted_num + 1, last_num + 1) + ] + + missing_end_nums = missing_end_nums[:seq_remainder] + + # # DEBUG PURPOSE ONLY + # print("\n") + # for i in range( + # eos_position-3, + # eos_position + min(missing_count, seq_remainder)): + # print( + # "\t", src_tokens[batch_no, i + 0], + # "\t", pred_tokens[batch_no, i + 1], + # ) + + missing_end_residues = [] + for i in range( + eos_position, + eos_position + min(missing_count, seq_remainder), + ): + missing_end_residues.append(src_tokens[batch_no, i - 1]) + + # print("Last:\t", last_num) + # print("Last pred num:\t", last_predicted_num) + # print("Missing count:\t", missing_count) + + # print("Missing num:\t", missing_end_nums) + # print("Missing res:\t", missing_end_residues) + + # # Append the misssing labels to seq and nums: + nums = nums + missing_end_nums + residues = residues + missing_end_residues + + end_index = end_index + len(missing_end_nums) + + ### 5D Perform backfill for missed start of sequence, if missed + # numbering + + # This step ensures that first and last nums will always be + # integers - does not proceed if not. + try: + first_num = int(nums[0][0]) # get first number + last_num = int(nums[-1][0]) # get last number + except (IndexError, ValueError) as e: + # When numbering has failed, `nums` is an empty list. + # For some non-antibody/TCR sequences that the model does not + # recognise, the first number can be a string, like an EOS or + # an X token. End the loop here and move on to the next seq + # in the batch. + captured_error = str(e) + numbering.append( + { + "numbering": None, + "chain_type": "F", + "score": normalized_score, + "query_start": None, + "query_end": None, + "error": f"Could not apply numbering: {captured_error}", + "scheme": "imgt", + "logits": all_scores, + } + ) + continue + + # Should not do this before 10 in case of failure to + # identify the gap. + if first_num > 1 and first_num < 9 and len(backfill_residues) > 0: + # This creates a list from 1 to first_num - 1 + vals = list(range(1, first_num)) + # the problem here is if there is a lot of junk... + vals = vals[-len(backfill_residues) :] + nums = [(i, " ") for i in vals] + nums + residues = list(backfill_residues[-len(vals) :]) + residues + + # Adjust the start index for the backfill + start_index = start_index - len(vals) + + ### 5E Fill in up to 1 (starting IMGT residue) with gaps + first_num = int(nums[0][0]) # get first number again - may change + for missing_num in range( + first_num - 1, 0, -1 + ): # Start from first_num - 1, stop at 1, step by -1 + nums.insert(0, (missing_num, " ")) + residues.insert(0, "-") + + ### 5F Add gaps to nums where we are missing a number: + # e.g. predicted labels are 91 L, 93 K. convert to >> + # 91 L, 92 -, 93 K + i = 1 + while i < len(nums): + if (int(nums[i][0]) - 1) > int(nums[i - 1][0]): + nums.insert(i, (int(nums[i - 1][0]) + 1, " ")) + residues.insert(i, "-") + else: + i += 1 # Only increment if no insertion is made + + # Ensure the last number is 128 >>>>> + last_num = int(nums[-1][0]) + for missing_num in range(last_num + 1, 129): + nums.append((missing_num, " ")) + residues.append("-") + + ### 6 Populate the meta data dict and append to alignment list + + # Successful - append. + numbering.append( + { + "numbering": list(zip(nums, residues)), + "chain_type": str(pred_tokens[batch_no, 1]), + "score": normalized_score, + "query_start": start_index, + "query_end": end_index, + "error": None, + "scheme": "imgt", + "logits": all_scores, + } + ) + + return numbering diff --git a/src/anarcii/pipeline/__init__.py b/src/anarcii/pipeline/__init__.py index 952b582b..92b3983d 100644 --- a/src/anarcii/pipeline/__init__.py +++ b/src/anarcii/pipeline/__init__.py @@ -100,6 +100,7 @@ def __init__( ncpu: int = -1, verbose: bool = False, max_seqs_len=1024 * 100, + return_logits=False, ): self.seq_type = seq_type.lower() @@ -112,6 +113,7 @@ def __init__( self.verbose = verbose self.cpu = cpu self.max_seqs_len = max_seqs_len + self.return_logits = return_logits self._last_numbered_output: dict | Path | None = None # Has a conversion to a new number scheme occured? @@ -391,7 +393,12 @@ def to_csv(self, file_path): def number_with_type(self, seqs: dict[str, str], seq_type, scfv): model = ModelRunner( - seq_type, self.mode, self.batch_size, self.device, self.verbose + seq_type, + self.mode, + self.batch_size, + self.device, + self.verbose, + self.return_logits, ) window_model = WindowFinder(seq_type, self.mode, self.batch_size, self.device) From 66568a0f851d159472d6ae13d9bbb86ab4151628 Mon Sep 17 00:00:00 2001 From: ALGW71 Date: Mon, 9 Feb 2026 22:19:55 +0000 Subject: [PATCH 2/3] Add some comments to explain what is going on. --- src/anarcii/inference/model_runner.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/anarcii/inference/model_runner.py b/src/anarcii/inference/model_runner.py index 045428b8..bedc77ec 100644 --- a/src/anarcii/inference/model_runner.py +++ b/src/anarcii/inference/model_runner.py @@ -115,6 +115,7 @@ def __call__(self, tokenised_seqs: dict[str, TokenisedSequence], offsets): dl = dataloader(self.batch_size, list(tokenised_seqs.values())) + # The best solution to avoid slowing down the code was one if else here. if self.return_logits: numbering = dict(zip(tokenised_seqs, self._predict_numbering_logits(dl))) else: @@ -610,7 +611,7 @@ def _predict_numbering(self, dl): def _predict_numbering_logits(self, dl): """ - AS ABOVE - WITH LOGITS + AS ABOVE - with logits being added to the dictionary """ if self.verbose: print(f"Making predictions on {len(dl)} batches.") @@ -684,6 +685,8 @@ def _predict_numbering_logits(self, dl): & (max_input != sos_token) ) + ### LOGIT MASK - as above but include logit values of X + # This allows users to see instertion values. mask2 = ( (max_input != skip_token) # & (max_input != x_token) # Keep X tokens for logit return @@ -732,13 +735,13 @@ def _predict_numbering_logits(self, dl): ] valid_scores = scores[batch_no, valid_indices] + ### NEW LOGIT EXTRACTION FOR ALL positions ### all_indices = torch.arange(eos_position, device=self.device)[ mask2[batch_no, :eos_position] ] all_scores = scores[batch_no, all_indices] ### 5A Check score is valid - if len(valid_indices) >= 50: normalized_score = valid_scores.mean().item() else: From 62f33721b6e50e2b33d1e229cb42309f090d5a37 Mon Sep 17 00:00:00 2001 From: ALGW71 Date: Mon, 9 Feb 2026 22:20:59 +0000 Subject: [PATCH 3/3] Add some comments to explain what is going on. --- src/anarcii/inference/model_runner.py | 1 + src/anarcii/pipeline/__init__.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/anarcii/inference/model_runner.py b/src/anarcii/inference/model_runner.py index bedc77ec..b7df6d24 100644 --- a/src/anarcii/inference/model_runner.py +++ b/src/anarcii/inference/model_runner.py @@ -740,6 +740,7 @@ def _predict_numbering_logits(self, dl): mask2[batch_no, :eos_position] ] all_scores = scores[batch_no, all_indices] + ### ### 5A Check score is valid if len(valid_indices) >= 50: diff --git a/src/anarcii/pipeline/__init__.py b/src/anarcii/pipeline/__init__.py index 92b3983d..91318289 100644 --- a/src/anarcii/pipeline/__init__.py +++ b/src/anarcii/pipeline/__init__.py @@ -100,7 +100,7 @@ def __init__( ncpu: int = -1, verbose: bool = False, max_seqs_len=1024 * 100, - return_logits=False, + return_logits: bool = False, ): self.seq_type = seq_type.lower()