diff --git a/stcrpy/tcr_processing/TCR.py b/stcrpy/tcr_processing/TCR.py index 6e5a71e..1575de3 100644 --- a/stcrpy/tcr_processing/TCR.py +++ b/stcrpy/tcr_processing/TCR.py @@ -157,13 +157,13 @@ def get_TCR_type(self): """ if hasattr(self, "tcr_type"): return self.tcr_type - elif hasattr(self, "VB") and hasattr(self, "VA"): + elif hasattr(self, "VA") and hasattr(self, "VB"): self.tcr_type = "abTCR" return self.tcr_type - elif hasattr(self, "VD") and hasattr(self, "VG"): + elif hasattr(self, "VG") and hasattr(self, "VD"): self.tcr_type = "gdTCR" return self.tcr_type - elif hasattr(self, "VB") and hasattr(self, "VD"): + elif hasattr(self, "VD") and hasattr(self, "VB"): self.tcr_type = "dbTCR" return self.tcr_type @@ -563,7 +563,7 @@ def __init__(self, c1, c2): c2 (TCRchain): alpha or beta type TCR chain """ - if c1.chain_type == "B": + if c1.chain_type == "A": Entity.__init__(self, c1.id + c2.id) else: Entity.__init__(self, c2.id + c1.id) @@ -572,9 +572,9 @@ def __init__(self, c1, c2): self.level = "H" self._add_domain(c1) self._add_domain(c2) - self.child_list = sorted( - self.child_list, key=lambda x: x.chain_type, reverse=True - ) # make sure that the list goes B->A or G->D + # make sure that the list goes A->B, D->B, or D->G + chain_ordering = {'A': 1, 'B': 2, 'D': 1, 'G': 2} + self.child_list = sorted(self.child_list, key=lambda x: chain_ordering[x.chain_type]) self.antigen = [] self.MHC = [] self.engineered = False @@ -589,7 +589,7 @@ def __repr__(self): Returns: str: String representation of the abTCR objec """ - return "" % (self.VB, self.VA, self.VB, self.VA) + return f"" def _add_domain(self, chain): """ @@ -599,24 +599,15 @@ def _add_domain(self, chain): Args: chain (TCRchain): TCR chain whose domain is being added. """ - if chain.chain_type == "B": - self.VB = chain.id - elif chain.chain_type == "A" or chain.chain_type == "D": + if chain.chain_type == "A" or chain.chain_type == "D": self.VA = chain.id + elif chain.chain_type == "B": + self.VB = chain.id + # Add the chain as a child of this entity. self.add(chain) - def get_VB(self): - """ - Retrieve the variable beta chain of the TCR - - Returns: - TCRchain: VB chain - """ - if hasattr(self, "VB"): - return self.child_dict[self.VB] - def get_VA(self): """ Retrieve the variable alpha chain of the TCR @@ -627,6 +618,16 @@ def get_VA(self): if hasattr(self, "VA"): return self.child_dict[self.VA] + def get_VB(self): + """ + Retrieve the variable beta chain of the TCR + + Returns: + TCRchain: VB chain + """ + if hasattr(self, "VB"): + return self.child_dict[self.VB] + def get_domain_assignment(self): """ Retrieve the domain assignment of the TCR as a dict with variable domain type as key and chain ID as value. @@ -637,10 +638,10 @@ def get_domain_assignment(self): try: return {"VA": self.VA, "VB": self.VB} except AttributeError: - if hasattr(self, "VB"): - return {"VB": self.VB} if hasattr(self, "VA"): return {"VA": self.VA} + if hasattr(self, "VB"): + return {"VB": self.VB} return None def is_engineered(self): @@ -653,8 +654,8 @@ def is_engineered(self): if self.engineered: return True else: - vb, va = self.get_VB(), self.get_VA() - for var_domain in [vb, va]: + va, vb = self.get_VA(), self.get_VB() + for var_domain in [va, vb]: if var_domain and var_domain.is_engineered(): self.engineered = True return self.engineered @@ -669,10 +670,10 @@ def get_fragments(self): Yields: Fragment: fragment of TCR chain. """ - vb, va = self.get_VB(), self.get_VA() + va, vb = self.get_VA(), self.get_VB() # If a variable domain exists - for var_domain in [vb, va]: + for var_domain in [va, vb]: if var_domain: for frag in var_domain.get_fragments(): yield frag @@ -699,16 +700,16 @@ def standardise_chain_names(self) -> None: new_id = [] new_child_dict = {} - if hasattr(self, 'VB'): - new_child_dict['E'] = self.child_dict[self.VB] - self.VB = 'E' - new_id.append('E') - if hasattr(self, 'VA'): new_child_dict['D'] = self.child_dict[self.VA] self.VA = 'D' new_id.append('D') + if hasattr(self, 'VB'): + new_child_dict['E'] = self.child_dict[self.VB] + self.VB = 'E' + new_id.append('E') + with warnings.catch_warnings(): warnings.simplefilter('ignore', BiopythonWarning) @@ -732,7 +733,7 @@ class gdTCR(TCR): def __init__(self, c1, c2): - if c1.chain_type == "D": + if c1.chain_type == "G": Entity.__init__(self, c1.id + c2.id) else: Entity.__init__(self, c2.id + c1.id) @@ -741,9 +742,9 @@ def __init__(self, c1, c2): self.level = "H" self._add_domain(c1) self._add_domain(c2) - self.child_list = sorted( - self.child_list, key=lambda x: x.chain_type - ) # make sure that the list goes B->A or D->G + # make sure that the list goes A->B, D->B, or D->G + chain_ordering = {'A': 1, 'B': 2, 'D': 1, 'G': 2} + self.child_list = sorted(self.child_list, key=lambda x: chain_ordering[x.chain_type]) self.antigen = [] self.MHC = [] self.engineered = False @@ -752,41 +753,44 @@ def __init__(self, c1, c2): self.visualise_interactions = self._create_interaction_visualiser() def __repr__(self): - return "" % (self.VD, self.VG, self.VD, self.VG) + return f"" def _add_domain(self, chain): - if chain.chain_type == "D": - self.VD = chain.id - elif chain.chain_type == "G": + if chain.chain_type == "G": self.VG = chain.id + elif chain.chain_type == "D": + self.VD = chain.id + # Add the chain as a child of this entity. self.add(chain) - def get_VD(self): - if hasattr(self, "VD"): - return self.child_dict[self.VD] - def get_VG(self): if hasattr(self, "VG"): return self.child_dict[self.VG] + def get_VD(self): + if hasattr(self, "VD"): + return self.child_dict[self.VD] + def get_domain_assignment(self): try: return {"VG": self.VG, "VD": self.VD} except AttributeError: - if hasattr(self, "VD"): - return {"VD": self.VD} if hasattr(self, "VG"): return {"VG": self.VG} + + if hasattr(self, "VD"): + return {"VD": self.VD} + return None def is_engineered(self): if self.engineered: return True else: - vd, vg = self.get_VD(), self.get_VG() - for var_domain in [vd, vg]: + vg, vd = self.get_VG(), self.get_VD() + for var_domain in [vg, vd]: if var_domain and var_domain.is_engineered(): self.engineered = True return self.engineered @@ -795,7 +799,7 @@ def is_engineered(self): return False def get_fragments(self): - vd, vg = self.get_VD(), self.get_VG() + vg, vd = self.get_VG(), self.get_VD() # If a variable domain exists for var_domain in [vg, vd]: @@ -858,7 +862,7 @@ class dbTCR(TCR): def __init__(self, c1, c2): super(TCR, self).__init__() - if c1.chain_type == "B": + if c1.chain_type == "D": Entity.__init__(self, c1.id + c2.id) else: Entity.__init__(self, c2.id + c1.id) @@ -867,9 +871,9 @@ def __init__(self, c1, c2): self.level = "H" self._add_domain(c1) self._add_domain(c2) - self.child_list = sorted( - self.child_list, key=lambda x: x.chain_type, reverse=False - ) # make sure that the list goes B->D + # make sure that the list goes A->B, D->B, or D->G + chain_ordering = {'A': 1, 'B': 2, 'D': 1, 'G': 2} + self.child_list = sorted(self.child_list, key=lambda x: chain_ordering[x.chain_type]) self.antigen = [] self.MHC = [] self.engineered = False @@ -878,41 +882,42 @@ def __init__(self, c1, c2): self.visualise_interactions = self._create_interaction_visualiser() def __repr__(self): - return "" % (self.VB, self.VD, self.VB, self.VD) + return f"" def _add_domain(self, chain): - if chain.chain_type == "B": - self.VB = chain.id - elif chain.chain_type == "D": + if chain.chain_type == "D": self.VD = chain.id + elif chain.chain_type == "B": + self.VB = chain.id + # Add the chain as a child of this entity. self.add(chain) - def get_VB(self): - if hasattr(self, "VB"): - return self.child_dict[self.VB] - def get_VD(self): if hasattr(self, "VD"): return self.child_dict[self.VD] + def get_VB(self): + if hasattr(self, "VB"): + return self.child_dict[self.VB] + def get_domain_assignment(self): try: return {"VD": self.VD, "VB": self.VB} except AttributeError: - if hasattr(self, "VB"): - return {"VB": self.VB} if hasattr(self, "VD"): return {"VD": self.VD} + if hasattr(self, "VB"): + return {"VB": self.VB} return None def is_engineered(self): if self.engineered: return True else: - vb, vd = self.get_VB(), self.get_VD() - for var_domain in [vb, vd]: + vd, vb, = self.get_VD(), self.get_VB() + for var_domain in [vd, vb]: if var_domain and var_domain.is_engineered(): self.engineered = True return self.engineered @@ -921,10 +926,10 @@ def is_engineered(self): return False def get_fragments(self): - vb, vd = self.get_VB(), self.get_VD() + vd, vb = self.get_VD(), self.get_VB() # If a variable domain exists - for var_domain in [vb, vd]: + for var_domain in [vd, vb]: if var_domain: for frag in var_domain.get_fragments(): yield frag @@ -951,16 +956,16 @@ def standardise_chain_names(self) -> None: new_id = [] new_child_dict = {} - if hasattr(self, 'VB'): - new_child_dict['E'] = self.child_dict[self.VB] - self.VB = 'E' - new_id.append('E') - if hasattr(self, 'VD'): new_child_dict['D'] = self.child_dict[self.VD] self.VD = 'D' new_id.append('D') + if hasattr(self, 'VB'): + new_child_dict['E'] = self.child_dict[self.VB] + self.VB = 'E' + new_id.append('E') + with warnings.catch_warnings(): warnings.simplefilter('ignore', BiopythonWarning) diff --git a/stcrpy/tcr_processing/TCRParser.py b/stcrpy/tcr_processing/TCRParser.py index 5bfaa29..328cec4 100644 --- a/stcrpy/tcr_processing/TCRParser.py +++ b/stcrpy/tcr_processing/TCRParser.py @@ -69,7 +69,7 @@ def _create_chain(self, chain, new_chain_id, numbering, chain_type): Create a new TCR or MHC chain. Residues before the numbered region are now ignored. """ - if chain_type in ["D", "A", "B", "G"]: + if chain_type in {"A", "B", "G", "D"}: newchain = TCRchain(new_chain_id) elif chain_type in [ "MH1", @@ -334,7 +334,7 @@ def get_tcr_structure( tcrstructure, chain, germline_info ) - if numbering and chain_type in ["G", "D", "B", "A"]: + if numbering and chain_type in {"A", "B", "G", "D"}: # create a new TCR chain newchain = self._create_chain( chain, chain.id, numbering, chain_type @@ -385,10 +385,8 @@ def get_tcr_structure( tcr = abTCR(chain1, chain2) elif not obs_chaintypes - set(["G", "D"]): tcr = gdTCR(chain1, chain2) - elif not obs_chaintypes - set(["B", "D"]): - tcr = abTCR( - chain1, chain2 - ) # initial way to deal with anarci missclassification of alpha chains as delta chains + elif not obs_chaintypes - set(["D", "B"]): + tcr = abTCR(chain1, chain2) # initial way to deal with narci missclassification of alpha chains as delta chains # tcr = dbTCR(chain1, chain2) tcr.scTCR = True # @@ -422,7 +420,7 @@ def get_tcr_structure( tcr = abTCR(pair[0], pair[1]) elif not obs_chaintypes - set(["G", "D"]): tcr = gdTCR(pair[0], pair[1]) - elif not obs_chaintypes - set(["B", "D"]): + elif not obs_chaintypes - set(["D", "B"]): # tcr = dbTCR(pair[0], pair[1]) tcr = abTCR(pair[0], pair[1]) @@ -453,7 +451,7 @@ def get_tcr_structure( tcr = abTCR(pair[0], pair[1]) elif not obs_chaintypes - set(["G", "D"]): tcr = gdTCR(pair[0], pair[1]) - elif not obs_chaintypes - set(["B", "D"]): + elif not obs_chaintypes - set(["D", "B"]): tcr = abTCR(pair[0], pair[1]) # tcr = dbTCR(pair[0], pair[1]) else: diff --git a/stcrpy/tcr_processing/TCRchain.py b/stcrpy/tcr_processing/TCRchain.py index b4be781..c4c1c67 100644 --- a/stcrpy/tcr_processing/TCRchain.py +++ b/stcrpy/tcr_processing/TCRchain.py @@ -5,10 +5,10 @@ from .Fragment import Fragment regions = { - "B": ["fwb1", "cdrb1", "fwb2", "cdrb2", "fwb3", "cdrb3", "fwb4"], "A": ["fwa1", "cdra1", "fwa2", "cdra2", "fwa3", "cdra3", "fwa4"], - "D": ["fwd1", "cdrd1", "fwd2", "cdrd2", "fwd3", "cdrd3", "fwd4"], + "B": ["fwb1", "cdrb1", "fwb2", "cdrb2", "fwb3", "cdrb3", "fwb4"], "G": ["fwg1", "cdrg1", "fwg2", "cdrg2", "fwg3", "cdrg3", "fwg4"], + "D": ["fwd1", "cdrd1", "fwd2", "cdrd2", "fwd3", "cdrd3", "fwd4"], } @@ -55,7 +55,7 @@ def analyse(self, chain_type): def set_chain_type(self, chain_type): """ - Set the chain type to B, A, D, or G + Set the chain type to A, B, G, or D """ self.chain_type = chain_type