diff --git a/tests/test_focus.py b/tests/test_focus.py
index a23046bc94..8ae484836d 100644
--- a/tests/test_focus.py
+++ b/tests/test_focus.py
@@ -1,8 +1,5 @@
-import pytest
-
 from textual.app import App, ComposeResult
 from textual.containers import Container, ScrollableContainer
-from textual.screen import Screen
 from textual.widget import Widget
 from textual.widgets import Button, Label
 
@@ -19,34 +16,23 @@ class ChildrenFocusableOnly(Widget, can_focus=False, can_focus_children=True):
     pass
 
 
-@pytest.fixture
-def screen() -> Screen:
-    app = App()
-
-    with app._context():
-        app.push_screen(Screen())
-
-        screen = app.screen
+class FocusTestApp(App):
+    AUTO_FOCUS = None
 
+    def compose(self) -> ComposeResult:
         # The classes even/odd alternate along the focus chain.
         # The classes in/out identify nested widgets.
-        screen._add_children(
-            Focusable(id="foo", classes="a"),
-            NonFocusable(id="bar"),
-            Focusable(Focusable(id="Paul", classes="c"), id="container1", classes="b"),
-            NonFocusable(Focusable(id="Jessica", classes="a"), id="container2"),
-            Focusable(id="baz", classes="b"),
-            ChildrenFocusableOnly(Focusable(id="child", classes="c")),
-        )
+        yield Focusable(id="foo", classes="a")
+        yield NonFocusable(id="bar")
+        yield Focusable(Focusable(id="Paul", classes="c"), id="container1", classes="b")
+        yield NonFocusable(Focusable(id="Jessica", classes="a"), id="container2")
+        yield Focusable(id="baz", classes="b")
+        yield ChildrenFocusableOnly(Focusable(id="child", classes="c"))
 
-        return screen
 
-
-def test_focus_chain():
+async def test_focus_chain():
     app = App()
-    with app._context():
-        app.push_screen(Screen())
-
+    async with app.run_test():
         screen = app.screen
 
         # Check empty focus chain
@@ -65,7 +51,7 @@ def test_focus_chain():
         assert focus_chain == ["foo", "container1", "Paul", "baz", "child"]
 
 
-def test_allow_focus():
+async def test_allow_focus():
     """Test allow_focus and allow_focus_children are called and the result used."""
     focusable_allow_focus_called = False
     non_focusable_allow_focus_called = False
@@ -92,9 +78,7 @@ def allow_focus_children(self) -> bool:
 
     app = App()
 
-    with app._context():
-        app.push_screen(Screen())
-
+    async with app.run_test():
         app.screen._add_children(
             Focusable(id="foo"),
             NonFocusable(id="bar"),
@@ -106,93 +90,119 @@ def allow_focus_children(self) -> bool:
         assert non_focusable_allow_focus_called
 
 
-def test_focus_next_and_previous(screen: Screen):
-    assert screen.focus_next().id == "foo"
-    assert screen.focus_next().id == "container1"
-    assert screen.focus_next().id == "Paul"
-    assert screen.focus_next().id == "baz"
-    assert screen.focus_next().id == "child"
+async def test_focus_next_and_previous():
+    app = FocusTestApp()
+    async with app.run_test():
+        screen = app.screen
+
+        assert screen.focus_next().id == "foo"
+        assert screen.focus_next().id == "container1"
+        assert screen.focus_next().id == "Paul"
+        assert screen.focus_next().id == "baz"
+        assert screen.focus_next().id == "child"
 
-    assert screen.focus_previous().id == "baz"
-    assert screen.focus_previous().id == "Paul"
-    assert screen.focus_previous().id == "container1"
-    assert screen.focus_previous().id == "foo"
+        assert screen.focus_previous().id == "baz"
+        assert screen.focus_previous().id == "Paul"
+        assert screen.focus_previous().id == "container1"
+        assert screen.focus_previous().id == "foo"
 
 
-def test_focus_next_wrap_around(screen: Screen):
+async def test_focus_next_wrap_around():
     """Ensure focusing the next widget wraps around the focus chain."""
-    screen.set_focus(screen.query_one("#child"))
-    assert screen.focused.id == "child"
+    app = FocusTestApp()
+    async with app.run_test():
+        screen = app.screen
 
-    assert screen.focus_next().id == "foo"
+        screen.set_focus(screen.query_one("#child"))
+        assert screen.focused.id == "child"
 
+        assert screen.focus_next().id == "foo"
 
-def test_focus_previous_wrap_around(screen: Screen):
+
+async def test_focus_previous_wrap_around():
     """Ensure focusing the previous widget wraps around the focus chain."""
-    screen.set_focus(screen.query_one("#foo"))
-    assert screen.focused.id == "foo"
+    app = FocusTestApp()
+    async with app.run_test():
+        screen = app.screen
+
+        screen.set_focus(screen.query_one("#foo"))
+        assert screen.focused.id == "foo"
 
-    assert screen.focus_previous().id == "child"
+        assert screen.focus_previous().id == "child"
 
 
-def test_wrap_around_selector(screen: Screen):
+async def test_wrap_around_selector():
     """Ensure moving focus in both directions wraps around the focus chain."""
-    screen.set_focus(screen.query_one("#foo"))
-    assert screen.focused.id == "foo"
+    app = FocusTestApp()
+    async with app.run_test():
+        screen = app.screen
+
+        screen.set_focus(screen.query_one("#foo"))
+        assert screen.focused.id == "foo"
 
-    assert screen.focus_previous("#Paul").id == "Paul"
-    assert screen.focus_next("#foo").id == "foo"
+        assert screen.focus_previous("#Paul").id == "Paul"
+        assert screen.focus_next("#foo").id == "foo"
 
 
-def test_no_focus_empty_selector(screen: Screen):
+async def test_no_focus_empty_selector():
     """Ensure focus is cleared when selector matches nothing."""
-    assert screen.focus_next("#bananas") is None
-    assert screen.focus_previous("#bananas") is None
+    app = FocusTestApp()
+    async with app.run_test():
+        screen = app.screen
+
+        assert screen.focus_next("#bananas") is None
+        assert screen.focus_previous("#bananas") is None
 
-    screen.set_focus(screen.query_one("#foo"))
-    assert screen.focused is not None
-    assert screen.focus_next("#bananas") is None
-    assert screen.focused is None
+        screen.set_focus(screen.query_one("#foo"))
+        assert screen.focused is not None
+        assert screen.focus_next("#bananas") is None
+        assert screen.focused is None
 
-    screen.set_focus(screen.query_one("#foo"))
-    assert screen.focused is not None
-    assert screen.focus_previous("#bananas") is None
-    assert screen.focused is None
+        screen.set_focus(screen.query_one("#foo"))
+        assert screen.focused is not None
+        assert screen.focus_previous("#bananas") is None
+        assert screen.focused is None
 
 
-def test_focus_next_and_previous_with_type_selector(screen: Screen):
+async def test_focus_next_and_previous_with_type_selector():
     """Move focus with a selector that matches the currently focused node."""
-    screen.set_focus(screen.query_one("#Paul"))
-    assert screen.focused.id == "Paul"
+    app = FocusTestApp()
+    async with app.run_test():
+        screen = app.screen
 
-    assert screen.focus_next(Focusable).id == "baz"
-    assert screen.focus_next(Focusable).id == "child"
+        screen.set_focus(screen.query_one("#Paul"))
+        assert screen.focused.id == "Paul"
 
-    assert screen.focus_previous(Focusable).id == "baz"
-    assert screen.focus_previous(Focusable).id == "Paul"
-    assert screen.focus_previous(Focusable).id == "container1"
-    assert screen.focus_previous(Focusable).id == "foo"
+        assert screen.focus_next(Focusable).id == "baz"
+        assert screen.focus_next(Focusable).id == "child"
 
+        assert screen.focus_previous(Focusable).id == "baz"
+        assert screen.focus_previous(Focusable).id == "Paul"
+        assert screen.focus_previous(Focusable).id == "container1"
+        assert screen.focus_previous(Focusable).id == "foo"
 
-def test_focus_next_and_previous_with_str_selector(screen: Screen):
+
+async def test_focus_next_and_previous_with_str_selector():
     """Move focus with a selector that matches the currently focused node."""
-    screen.set_focus(screen.query_one("#foo"))
-    assert screen.focused.id == "foo"
+    app = FocusTestApp()
+    async with app.run_test():
+        screen = app.screen
 
-    assert screen.focus_next(".a").id == "foo"
-    assert screen.focus_next(".c").id == "Paul"
-    assert screen.focus_next(".c").id == "child"
+        screen.set_focus(screen.query_one("#foo"))
+        assert screen.focused.id == "foo"
 
-    assert screen.focus_previous(".c").id == "Paul"
-    assert screen.focus_previous(".a").id == "foo"
+        assert screen.focus_next(".a").id == "foo"
+        assert screen.focus_next(".c").id == "Paul"
+        assert screen.focus_next(".c").id == "child"
 
+        assert screen.focus_previous(".c").id == "Paul"
+        assert screen.focus_previous(".a").id == "foo"
 
-def test_focus_next_and_previous_with_type_selector_without_self():
+
+async def test_focus_next_and_previous_with_type_selector_without_self():
     """Test moving the focus with a selector that does not match the currently focused node."""
     app = App()
-    with app._context():
-        app.push_screen(Screen())
-
+    async with app.run_test():
         screen = app.screen
 
         from textual.containers import Horizontal, VerticalScroll
@@ -233,18 +243,22 @@ def test_focus_next_and_previous_with_type_selector_without_self():
         assert screen.focus_previous(Input).id == "w5"
 
 
-def test_focus_next_and_previous_with_str_selector_without_self(screen: Screen):
+async def test_focus_next_and_previous_with_str_selector_without_self():
     """Test moving the focus with a selector that does not match the currently focused node."""
-    screen.set_focus(screen.query_one("#foo"))
-    assert screen.focused.id == "foo"
+    app = FocusTestApp()
+    async with app.run_test():
+        screen = app.screen
+
+        screen.set_focus(screen.query_one("#foo"))
+        assert screen.focused.id == "foo"
 
-    assert screen.focus_next(".c").id == "Paul"
-    assert screen.focus_next(".b").id == "baz"
-    assert screen.focus_next(".c").id == "child"
+        assert screen.focus_next(".c").id == "Paul"
+        assert screen.focus_next(".b").id == "baz"
+        assert screen.focus_next(".c").id == "child"
 
-    assert screen.focus_previous(".a").id == "foo"
-    assert screen.focus_previous(".a").id == "foo"
-    assert screen.focus_previous(".b").id == "baz"
+        assert screen.focus_previous(".a").id == "foo"
+        assert screen.focus_previous(".a").id == "foo"
+        assert screen.focus_previous(".b").id == "baz"
 
 
 async def test_focus_does_not_move_to_invisible_widgets():