@@ -20,18 +20,24 @@ import scala.collection.mutable.ArrayBuffer
2020
2121import io .fabric8 .kubernetes .api .model .{Pod , PodSpec , PodStatus }
2222import org .mockito .Mockito ._
23+ import org .scalatest .BeforeAndAfter
2324
2425import org .apache .spark .{SparkContext , SparkFunSuite }
26+ import org .apache .spark .deploy .kubernetes .config ._
2527import org .apache .spark .scheduler .{FakeTask , FakeTaskScheduler , HostTaskLocation , TaskLocation }
2628
27- class KubernetesTaskSetManagerSuite extends SparkFunSuite {
29+ class KubernetesTaskSetManagerSuite extends SparkFunSuite with BeforeAndAfter {
2830
2931 val sc = new SparkContext (" local" , " test" )
3032 val sched = new FakeTaskScheduler (sc,
3133 (" execA" , " 10.0.0.1" ), (" execB" , " 10.0.0.2" ), (" execC" , " 10.0.0.3" ))
3234 val backend = mock(classOf [KubernetesClusterSchedulerBackend ])
3335 sched.backend = backend
3436
37+ before {
38+ sc.conf.remove(KUBERNETES_DRIVER_CLUSTER_NODENAME_DNS_LOOKUP_ENABLED )
39+ }
40+
3541 test(" Find pending tasks for executors using executor pod IP addresses" ) {
3642 val taskSet = FakeTask .createTaskSet(3 ,
3743 Seq (TaskLocation (" 10.0.0.1" , " execA" )), // Task 0 runs on executor pod 10.0.0.1.
@@ -76,7 +82,33 @@ class KubernetesTaskSetManagerSuite extends SparkFunSuite {
7682 assert(manager.getPendingTasksForHost(" 10.0.0.1" ) == ArrayBuffer (1 , 0 ))
7783 }
7884
85+ test(" Test DNS lookup is disabled by default for cluster node full hostnames" ) {
86+ assert(! sc.conf.get(KUBERNETES_DRIVER_CLUSTER_NODENAME_DNS_LOOKUP_ENABLED ))
87+ }
88+
89+ test(" Find pending tasks for executors, but avoid looking up cluster node FQDNs from DNS" ) {
90+ sc.conf.set(KUBERNETES_DRIVER_CLUSTER_NODENAME_DNS_LOOKUP_ENABLED , false )
91+ val taskSet = FakeTask .createTaskSet(2 ,
92+ Seq (HostTaskLocation (" kube-node1.domain1" )), // Task 0's partition belongs to datanode here.
93+ Seq (HostTaskLocation (" kube-node1.domain1" )) // task 1's partition belongs to datanode here.
94+ )
95+ val spec1 = mock(classOf [PodSpec ])
96+ when(spec1.getNodeName).thenReturn(" kube-node1" )
97+ val pod1 = mock(classOf [Pod ])
98+ when(pod1.getSpec).thenReturn(spec1)
99+ val status1 = mock(classOf [PodStatus ])
100+ when(status1.getHostIP).thenReturn(" 196.0.0.5" )
101+ when(pod1.getStatus).thenReturn(status1)
102+ val inetAddressUtil = mock(classOf [InetAddressUtil ])
103+ when(inetAddressUtil.getFullHostName(" 196.0.0.5" )).thenReturn(" kube-node1.domain1" )
104+ when(backend.getExecutorPodByIP(" 10.0.0.1" )).thenReturn(Some (pod1))
105+
106+ val manager = new KubernetesTaskSetManager (sched, taskSet, maxTaskFailures = 2 , inetAddressUtil)
107+ assert(manager.getPendingTasksForHost(" 10.0.0.1" ) == ArrayBuffer ())
108+ }
109+
79110 test(" Find pending tasks for executors using cluster node FQDNs that executor pods run on" ) {
111+ sc.conf.set(KUBERNETES_DRIVER_CLUSTER_NODENAME_DNS_LOOKUP_ENABLED , true )
80112 val taskSet = FakeTask .createTaskSet(2 ,
81113 Seq (HostTaskLocation (" kube-node1.domain1" )), // Task 0's partition belongs to datanode here.
82114 Seq (HostTaskLocation (" kube-node1.domain1" )) // task 1's partition belongs to datanode here.
0 commit comments