Skip to content

Commit 30c3f9a

Browse files
Jiayu Yetensorflower-gardener
Jiayu Ye
authored andcommitted
Internal change
PiperOrigin-RevId: 480378058
1 parent 0de1746 commit 30c3f9a

5 files changed

+84
-6
lines changed

orbit/controller.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import pprint
1818
import time
1919

20-
from typing import Any, Callable, Iterable, Optional, Union
20+
from typing import Callable, Iterable, Optional, Union
2121

2222
from absl import logging
2323

@@ -101,8 +101,8 @@ def __init__(
101101
summary_dir: Optional[str] = None,
102102
# Evaluation related
103103
eval_summary_dir: Optional[str] = None,
104-
summary_manager: Optional[Any] = None,
105-
eval_summary_manager: Optional[Any] = None):
104+
summary_manager: Optional[utils.SummaryManagerInterface] = None,
105+
eval_summary_manager: Optional[utils.SummaryManagerInterface] = None):
106106
"""Initializes a `Controller` instance.
107107
108108
Note that if `checkpoint_manager` is provided and there are checkpoints in

orbit/controller_test.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from orbit import controller
2525
from orbit import runner
2626
from orbit import standard_runner
27+
import orbit.utils
2728

2829
import tensorflow as tf
2930

@@ -698,12 +699,22 @@ def test_eval_and_checkpoint_interval(self):
698699
self.assertLen(
699700
summaries_with_matching_keyword("eval_loss", self.model_dir), 2)
700701

701-
def test_evaluate_with_nested_summaries(self):
702+
@parameterized.named_parameters(("DefaultSummary", False),
703+
("InjectSummary", True))
704+
def test_evaluate_with_nested_summaries(self, inject_summary_manager):
702705
test_evaluator = TestEvaluatorWithNestedSummary()
706+
if inject_summary_manager:
707+
summary_manager = orbit.utils.SummaryManager(
708+
self.model_dir,
709+
tf.summary.scalar,
710+
global_step=tf.Variable(0, dtype=tf.int64))
711+
else:
712+
summary_manager = None
703713
test_controller = controller.Controller(
704714
evaluator=test_evaluator,
705715
global_step=tf.Variable(0, dtype=tf.int64),
706-
eval_summary_dir=self.model_dir)
716+
eval_summary_dir=self.model_dir,
717+
summary_manager=summary_manager)
707718
test_controller.evaluate(steps=5)
708719

709720
self.assertNotEmpty(

orbit/utils/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,6 @@
2525
from orbit.utils.loop_fns import LoopFnWithSummaries
2626

2727
from orbit.utils.summary_manager import SummaryManager
28+
from orbit.utils.summary_manager_interface import SummaryManagerInterface
2829

2930
from orbit.utils.tpu_summaries import OptionalSummariesFunction

orbit/utils/summary_manager.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616

1717
import os
1818

19+
from orbit.utils.summary_manager_interface import SummaryManagerInterface
20+
1921
import tensorflow as tf
2022

2123

22-
class SummaryManager:
24+
class SummaryManager(SummaryManagerInterface):
2325
"""A utility class for managing summary writing."""
2426

2527
def __init__(self, summary_dir, summary_fn, global_step=None):
+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2022 The Orbit Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Provides a utility class for managing summary writing."""
16+
17+
import abc
18+
19+
20+
class SummaryManagerInterface(abc.ABC):
21+
"""A utility interface for managing summary writing."""
22+
23+
@abc.abstractmethod
24+
def flush(self):
25+
"""Flushes the the recorded summaries."""
26+
raise NotImplementedError
27+
28+
@abc.abstractmethod
29+
def summary_writer(self, relative_path=""):
30+
"""Returns the underlying summary writer for scoped writers."""
31+
raise NotImplementedError
32+
33+
@abc.abstractmethod
34+
def write_summaries(self, summary_dict):
35+
"""Writes summaries for the given dictionary of values.
36+
37+
The summary_dict can be any nested dict. The SummaryManager should
38+
recursively creates summaries, yielding a hierarchy of summaries which will
39+
then be reflected in the corresponding UIs.
40+
41+
For example, users may evaluate on multiple datasets and return
42+
`summary_dict` as a nested dictionary:
43+
44+
{
45+
"dataset1": {
46+
"loss": loss1,
47+
"accuracy": accuracy1
48+
},
49+
"dataset2": {
50+
"loss": loss2,
51+
"accuracy": accuracy2
52+
},
53+
}
54+
55+
This will create two set of summaries, "dataset1" and "dataset2". Each
56+
summary dict will contain summaries including both "loss" and "accuracy".
57+
58+
Args:
59+
summary_dict: A dictionary of values. If any value in `summary_dict` is
60+
itself a dictionary, then the function will create a new summary_dict
61+
with name given by the corresponding key. This is performed recursively.
62+
Leaf values are then summarized using the parent relative path.
63+
"""
64+
raise NotImplementedError

0 commit comments

Comments
 (0)