diff --git a/synapse/lib/ast.py b/synapse/lib/ast.py index ab89a2c342..ced83a12b3 100644 --- a/synapse/lib/ast.py +++ b/synapse/lib/ast.py @@ -1538,6 +1538,27 @@ def reverseLift(self, astinfo): self.astinfo = astinfo self.reverse = True + def getPivProps(self, runt, name, lookup=False): + + if name.find('::') != -1: + parts = name.split('::') + name, pivs = parts[0], parts[1:] + + prop = runt.model.reqProp(name, extra=self.kids[0].addExcInfo) + if prop.isform: + pivname = pivs.pop(0) + prop = prop.reqProp(pivname, extra=self.kids[0].addExcInfo) + + props = runt.model.getChildProps(prop) + return props, pivs + + if lookup: + props = runt.model.reqPropsByLook(name, extra=self.kids[0].addExcInfo) + else: + props = runt.model.reqPropList(name, extra=self.kids[0].addExcInfo) + + return props, None + def getPivLifts(self, runt, props, pivs): plist = [prop.full for prop in props] virts = [] @@ -1776,16 +1797,11 @@ async def lift(self, runt, path): cmpr = self.kids[1].value() valu = await s_stormtypes.tostor(await self.kids[2].compute(runt, path)) - pivs = None - if name.find('::') != -1: - parts = name.split('::') - name, pivs = parts[0], parts[1:] - - props = runt.model.reqPropList(name, extra=self.kids[0].addExcInfo) + props, pivs = self.getPivProps(runt, name) relname = props[0].name try: - if pivs is not None: + if pivs: if (pivlifts := self.getPivLifts(runt, props, pivs)) is None: pivs.insert(0, relname) @@ -1848,16 +1864,11 @@ async def lift(self, runt, path): cmpr = self.kids[2].value() valu = await s_stormtypes.tostor(await self.kids[3].compute(runt, path)) - pivs = None - if name.find('::') != -1: - parts = name.split('::') - name, pivs = parts[0], parts[1:] - - props = runt.model.reqPropList(name, extra=self.kids[0].addExcInfo) + props, pivs = self.getPivProps(runt, name) relname = props[0].name try: - if pivs is not None: + if pivs: if (pivlifts := self.getPivLifts(runt, props, pivs)) is None: pivs.insert(0, relname) @@ -2269,16 +2280,11 @@ async def lift(self, runt, path): valu = await self.kids[2].compute(runt, path) valu = await s_stormtypes.tostor(valu) - pivs = None - if name.find('::') != -1: - parts = name.split('::') - name, pivs = parts[0], parts[1:] - - props = runt.model.reqPropList(name, extra=self.kids[0].addExcInfo) + props, pivs = self.getPivProps(runt, name) relname = props[0].name try: - if pivs is not None: + if pivs: if (pivlifts := self.getPivLifts(runt, props, pivs)) is None: pivs.insert(0, relname) @@ -2335,16 +2341,11 @@ async def lift(self, runt, path): valu = await self.kids[3].compute(runt, path) valu = await s_stormtypes.tostor(valu) - pivs = None - if name.find('::') != -1: - parts = name.split('::') - name, pivs = parts[0], parts[1:] - - props = runt.model.reqPropsByLook(name, extra=self.kids[0].addExcInfo) + props, pivs = self.getPivProps(runt, name, lookup=True) relname = props[0].name try: - if pivs is not None: + if pivs: if (pivlifts := self.getPivLifts(runt, props, pivs)) is None: pivs.insert(0, relname) diff --git a/synapse/tests/test_cortex.py b/synapse/tests/test_cortex.py index 2efb934afd..6deb847cdc 100644 --- a/synapse/tests/test_cortex.py +++ b/synapse/tests/test_cortex.py @@ -3118,6 +3118,11 @@ async def test_storm_pivprop(self): for node in nodes: self.eq('test:str', node.ndef[0]) + nodes = await core.nodes('test:str::pivvirt::server::proto=tcp') + self.len(1, nodes) + for node in nodes: + self.eq('test:str', node.ndef[0]) + nodes = await core.nodes('test:str:pivvirt::server::proto*in=(tcp, udp)') self.len(2, nodes) for node in nodes: @@ -3128,6 +3133,11 @@ async def test_storm_pivprop(self): for node in nodes: self.eq('test:str', node.ndef[0]) + nodes = await core.nodes('test:str::pivvirt::servers*[=tcp://1.2.3.4]') + self.len(2, nodes) + for node in nodes: + self.eq('test:str', node.ndef[0]) + nodes = await core.nodes('test:str:pivvirt::servers*[in=(tcp://1.2.3.4, udp://1.2.3.4)]') self.len(3, nodes) for node in nodes: @@ -3159,6 +3169,11 @@ async def test_storm_pivprop(self): for node in nodes: self.eq('test:str', node.ndef[0]) + nodes = await core.nodes('test:str::bar::seen.min>2020') + self.len(2, nodes) + for node in nodes: + self.eq('test:str', node.ndef[0]) + await core.nodes('test:guid:seen.min>2021 | delnode') self.len(1, await core.nodes('test:str:bar::seen.min>2020')) @@ -3171,6 +3186,11 @@ async def test_storm_pivprop(self): for node in nodes: self.eq('test:str', node.ndef[0]) + nodes = await core.nodes('test:str::bar::servers*[.ip=1.2.3.4]') + self.len(2, nodes) + for node in nodes: + self.eq('test:str', node.ndef[0]) + # When pivoting through mixed types, don't raise BadTypeValu for incompatible operations # since they could be valid in some cases self.len(0, await core.nodes('test:str:bar::seen*[=tcp]')) diff --git a/synapse/tests/test_lib_ast.py b/synapse/tests/test_lib_ast.py index 1b5a0a9f9c..2ed766c3e2 100644 --- a/synapse/tests/test_lib_ast.py +++ b/synapse/tests/test_lib_ast.py @@ -3006,9 +3006,9 @@ async def checkProp(self, name, reverse=False, virts=None): async for node in origprop(self, name, reverse=reverse, virts=virts): yield node - async def checkValu(self, name, cmpr, valu, reverse=False): + async def checkValu(self, name, cmpr, valu, reverse=False, virts=None): calls.append(('valu', name, cmpr, valu)) - async for node in origvalu(self, name, cmpr, valu, reverse=reverse): + async for node in origvalu(self, name, cmpr, valu, reverse=reverse, virts=virts): yield node with mock.patch('synapse.lib.view.View.nodesByProp', checkProp): @@ -3125,6 +3125,23 @@ async def checkValu(self, name, cmpr, valu, reverse=False): self.stormHasNoWarnErr(msgs) self.len(0, calls) + await core.nodes('[test:int=1 test:int=2 :type=foo]') + self.len(2, await core.nodes('test:int::type=foo')) + + self.eq(calls, [ + ('valu', 'test:int:type', '=', 'foo') + ]) + + await core.nodes('[test:str=foo :somestr=bar]') + calls = [] + + self.len(2, await core.nodes('test:int::type::somestr=bar')) + self.eq(calls, [ + ('valu', 'test:str2:somestr', '=', 'bar'), + ('valu', 'test:str:somestr', '=', 'bar'), + ('valu', 'test:int:type', '=', 'foo') + ]) + async def test_ast_tag_optimization(self): calls = [] origtag = s_view.View.nodesByTag