
Commit f4a9216

kimishpatel authored and facebook-github-bot committed
[PyTorch, Mobile] Serialization format change for source range (pytorch#54284)
Summary: Pull Request resolved: pytorch#54284

In order to bring mobile deployment, via the lite interpreter, to feature parity with JIT with respect to model-level debug information, we must make that debug information available to the mobile runtime.

At the moment, model-level debug information is stored in SourceRange, which associates nodes of a graph with where they come from in the original Python source code. This information is serialized as part of debug_pkl and deserialized when JIT loads the model and reads the model code. In the lite interpreter we do not have access to all the functionality of JIT, so we cannot load the model the way JIT does, by reading code, constructing the module hierarchy and the graphs corresponding to module methods, etc. Instead, the lite interpreter saves only the bytecode corresponding to the compiled graph, Code. Thus, in order to annotate ops in the bytecode with equivalent SourceRange information, we do the following:
1. During model serialization, create a unique tag for each source range of the model.
2. Create a map of <SourceRange, tag>.
3. During debug_pkl serialization, save the tag along with the SourceRange, on top of the byte offset.
4. During bytecode generation, the methods of the top module are lowered. During this process methods are inlined. In the inlined graph, when a node of the graph is lowered to bytecode, query the node's source range and look it up in the map.
5. Serialize the resulting source range tag in module_debug_info.
6. During model deserialization, read all the debug_pkl records in the archive and create a map of <tag, SourceRange>.
7. This map can be used to find source code information.

During mobile runtime:
1. We read all the debug_pkl records and create a <tag=debug_handle, SourceRange> map.
   1.1. This map, MobileDebugInfo, is a member of the mobile Module.
2. The interpreter catches appropriate exceptions, sets the thread-local debug handle, and rethrows the exception.
3. In Function's run method we catch the exception and query the current debug handle where the exception happened.
4. We query MobileDebugInfo with the debug handle to retrieve the source range and augment the error with source range info.

This information is still incomplete as it does not contain the entire callstack. In the following diffs we will serialize InlinedCallStack directly.

Note that compilation is gated by the SYMBOLICATE_MOBILE_DEBUG_HANDLE macro, so that mobile builds can avoid building MobileDebugInfo, source range, and the source range pickler/unpickler. Later we will add a path where, if building without debug support, the stack trace will contain only debug handles; they can be symbolicated later.

Test Plan: Ported a bunch of source range tests from test_jit.py. Added one more test in test_lite_interpreter.py.

Imported from OSS

Reviewed By: raziel

Differential Revision: D27174722

fbshipit-source-id: a7b7c6088ce16dec37e823c7fefa4f0b61047e12
1 parent aa5ff7c commit f4a9216

20 files changed: 626 additions, 240 deletions
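To make the mobile-runtime flow described in the commit message concrete, here is a minimal sketch of the lookup in steps 1-4: a map from debug handle to source information, consulted when an exception escapes the interpreter. SourceInfo and augmentError below are hypothetical stand-ins for SourceRange and the real exception-handling path, not code from this change.

// Illustrative sketch only: a <debug_handle, source info> table queried to
// augment an error message, mirroring the role MobileDebugInfo plays above.
#include <cstdint>
#include <string>
#include <unordered_map>

struct SourceInfo {
  std::string file;
  size_t line;
  std::string str() const { return file + ":" + std::to_string(line); }
};

std::string augmentError(
    const std::unordered_map<int64_t, SourceInfo>& debug_table,
    int64_t debug_handle,
    const std::string& original_message) {
  auto it = debug_table.find(debug_handle);
  if (it == debug_table.end()) {
    return original_message;  // model was saved without mobile debug info
  }
  return original_message + "\n  at " + it->second.str();
}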

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
@@ -254,6 +254,7 @@ if(NOT DEFINED USE_VULKAN)
     "ANDROID" OFF)
 endif()
 
+option(USE_SOURCE_DEBUG_ON_MOBILE "Enable " ON)
 option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF)
 option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF)
 option(USE_VULKAN_SHADERC_RUNTIME "Vulkan - Use runtime shader compilation as opposed to build-time (needs libshaderc)" OFF)

@@ -647,6 +648,10 @@ if(USE_PYTORCH_METAL)
   string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_METAL")
 endif()
 
+if(USE_SOURCE_DEBUG_ON_MOBILE)
+  string(APPEND CMAKE_CXX_FLAGS " -DSYMBOLICATE_MOBILE_DEBUG_HANDLE")
+endif()
+
 # ---[ Allowlist file if allowlist is specified
 include(cmake/Allowlist.cmake)
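For context, a hedged sketch of the kind of code path this flag gates: when USE_SOURCE_DEBUG_ON_MOBILE is ON, SYMBOLICATE_MOBILE_DEBUG_HANDLE is defined and source-debug code is compiled in; otherwise only the raw debug handle is available, to be symbolicated off-device as the summary notes. The helper below is a hypothetical illustration, not code from this change.

#include <cstdint>
#include <string>

// Hypothetical helper: shows how a build guarded by
// SYMBOLICATE_MOBILE_DEBUG_HANDLE could differ from one that is not.
std::string describeFailure(int64_t debug_handle) {
#ifdef SYMBOLICATE_MOBILE_DEBUG_HANDLE
  // Source-debug build: the handle would be resolved to a source range
  // (via MobileDebugInfo) and included in the error message.
  return "error at source for debug handle " + std::to_string(debug_handle);
#else
  // Minimal mobile build: only the handle is reported; symbolication can
  // happen later, off-device.
  return "error at debug handle " + std::to_string(debug_handle);
#endif
}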

caffe2/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
@@ -516,9 +516,22 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
   list(APPEND TORCH_SRCS ${GENERATED_H_TORCH})
   list(APPEND LIBTORCH_CMAKE_SRCS "")
 
+  list(APPEND LITE_EAGER_SYMOBLICATION_SRCS "")
+  if(USE_SOURCE_DEBUG_ON_MOBILE)
+    append_filelist("libtorch_lite_eager_symbolication" LITE_EAGER_SYMOBLICATION_SRCS)
+    # For source debug on lite interpreter, we have to add dependency on pickling
+    # but references to read/writeArchiveAndTensor is not built for mobile
+    # so this condition specifically says we are building for source debug
+    # on mobile.
+    if(BUILD_LITE_INTERPRETER)
+      set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/serialization/pickle.cpp PROPERTIES COMPILE_FLAGS "-DC10_MOBILE -DFEATURE_TORCH_MOBILE")
+    endif()
+  endif()
+
   # Switch between the full jit interpreter and lite interpreter
   if(BUILD_LITE_INTERPRETER)
     append_filelist("libtorch_lite_cmake_sources" LIBTORCH_CMAKE_SRCS)
+    list(APPEND LIBTORCH_CMAKE_SRCS ${LITE_EAGER_SYMOBLICATION_SRCS})
   else()
     append_filelist("libtorch_cmake_sources" LIBTORCH_CMAKE_SRCS)

@@ -565,6 +578,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
       ${TORCH_SRC_DIR}/csrc/jit/mobile/sequential.cpp
     )
     list(APPEND TORCH_SRCS ${MOBILE_SRCS})
+    list(APPEND TORCH_SRCS ${LITE_EAGER_SYMOBLICATION_SRCS})
   endif()
 
   # This one needs to be unconditionally added as Functions.cpp is also unconditionally added

test/mobile/test_lite_script_module.py

Lines changed: 79 additions & 0 deletions
@@ -4,12 +4,27 @@
 import io
 from typing import Dict, List, NamedTuple
 from collections import namedtuple
+import inspect
 
 from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list
 from torch.testing._internal.common_utils import TestCase, run_tests
 
 class TestLiteScriptModule(TestCase):
 
+    def getScriptExportImportCopy(self, m, save_mobile_debug_info=True, also_test_file=False):
+        m_scripted = torch.jit.script(m)
+
+        if not also_test_file:
+            buffer = io.BytesIO(m_scripted._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=save_mobile_debug_info))
+            buffer.seek(0)
+            mobile_module = _load_for_lite_interpreter(buffer)
+            return mobile_module
+
+        with TemporaryFileName() as fname:
+            m_scripted._save_for_lite_interpreter(fname, _save_mobile_debug_info=save_mobile_debug_info)
+            mobile_module = _load_for_lite_interpreter(fname)
+            return mobile_module
+
     def test_load_mobile_module(self):
         class MyTestModule(torch.nn.Module):
             def __init__(self):

@@ -374,5 +389,69 @@ def forward(self, input):
         actual_ops = _export_operator_list(mobile_module)
         self.assertEqual(actual_ops, expected_ops)
 
+    def test_source_range_simple(self):
+
+        class FooTest(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x, w):
+                return torch.mm(x, w.t())
+
+        ft = FooTest()
+        loaded = self.getScriptExportImportCopy(ft)
+        _, lineno = inspect.getsourcelines(FooTest)
+
+        with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)):
+            loaded(torch.rand(3, 4), torch.rand(30, 40))
+
+    def test_source_range_raise_exception(self):
+
+        class FooTest2(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self):
+                raise RuntimeError('foo')
+
+        _, lineno = inspect.getsourcelines(FooTest2)
+
+        with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)):
+            ft = FooTest2()
+            loaded = self.getScriptExportImportCopy(ft)
+            loaded()
+
+    def test_source_range_function_call(self):
+        class FooTest3(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def add_method(self, x, w):
+                return x + w
+
+            @torch.jit.script_method
+            def forward(self, x, y, w):
+                x = x * y
+                x = x + 2
+                return self.add_method(x, w)
+
+        ft = FooTest3()
+        loaded = self.getScriptExportImportCopy(ft)
+        _, lineno = inspect.getsourcelines(FooTest3)
+
+        with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)):
+            loaded(torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40))
+
+    def test_source_range_no_debug_info(self):
+
+        class FooTest4(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x, w):
+                return torch.mm(x, w.t())
+
+        ft = FooTest4()
+        loaded = self.getScriptExportImportCopy(ft, save_mobile_debug_info=False)
+
+        try:
+            loaded(torch.rand(3, 4), torch.rand(30, 40))
+        except RuntimeError as e:
+            error_message = f"{e}"
+        self.assertTrue("test_lite_script_module.py" not in error_message)
+
+
 if __name__ == '__main__':
     run_tests()

test/test_jit.py

Lines changed: 2 additions & 2 deletions
@@ -4182,8 +4182,8 @@ def debug_records_from_mod(mod):
         debug_files = debug_records_from_mod(ft3)
         for debug_file in debug_files:
             for i in range(len(debug_file) - 1):
-                offset, source_range = debug_file[i]
-                offset2, source_range2 = debug_file[i + 1]
+                offset, source_range_tag, source_range = debug_file[i]
+                offset2, source_range_tag2, source_range2 = debug_file[i + 1]
                 self.assertNotEqual(source_range, source_range2)
 
     def test_circular_dependency(self):
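The two-line change above reflects the new debug_pkl record layout: each entry now carries (byte_offset, source_range_tag, source_range) instead of (byte_offset, source_range), with the tag doubling as the mobile debug handle. A rough sketch of that record shape, for orientation only; the real records are pickled IValue tuples, not C++ structs.

#include <cstdint>
#include <string>

// Illustrative record shape; in the serialized format these fields live in
// a pickled tuple and the source range is a torch::jit::SourceRange.
struct DebugPklRecord {
  int64_t byte_offset;       // offset into the serialized code, as before
  int64_t source_range_tag;  // new: doubles as the mobile debug handle
  std::string source_range;  // placeholder for the serialized source range
};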

tools/build_variables.bzl

Lines changed: 14 additions & 0 deletions
@@ -356,6 +356,17 @@ torch_mobile_core = [
     "torch/csrc/jit/runtime/register_special_ops.cpp",
 ]
 
+libtorch_lite_eager_symbolication = [
+    "torch/csrc/jit/frontend/source_range.cpp",
+    "torch/csrc/jit/mobile/debug_info.cpp",
+    "torch/csrc/jit/serialization/source_range_serialization.cpp",
+    # Later we can split serialization and deserialization logic
+    # to have better separation within build and only build relevant parts.
+    "torch/csrc/jit/serialization/pickle.cpp",
+    "torch/csrc/jit/serialization/pickler.cpp",
+    "torch/csrc/jit/serialization/unpickler.cpp",
+]
+
 # TODO: core_trainer_sources is not necessary for libtorch lite
 libtorch_lite_cmake_sources = sorted(core_trainer_sources + core_sources_common + torch_mobile_core)
 

@@ -368,6 +379,9 @@ libtorch_extra_sources = libtorch_core_jit_sources + [
     "torch/csrc/jit/api/module_save.cpp",
     "torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp",
     "torch/csrc/jit/mobile/export_data.cpp",
+    # To be included for eager symbolication in lite interpreter
+    # when it is built in libtorch
+    "torch/csrc/jit/mobile/debug_info.cpp",
     "torch/csrc/jit/mobile/function.cpp",
     "torch/csrc/jit/mobile/import.cpp",
     "torch/csrc/jit/mobile/import_data.cpp",
torch/csrc/jit/frontend/source_range.cpp

Lines changed: 5 additions & 0 deletions
@@ -3,6 +3,11 @@
 
 namespace torch {
 namespace jit {
+size_t SourceRangeHasher::operator()(const torch::jit::SourceRange& key) const {
+  return (
+      std::hash<uintptr_t>()(reinterpret_cast<uintptr_t>(key.source().get())) ^
+      std::hash<size_t>()(key.start()) ^ std::hash<size_t>()(key.end()));
+}
 
 c10::optional<SourceRange> Source::findSourceRangeThatGenerated(
     const SourceRange& range) {

torch/csrc/jit/frontend/source_range.h

Lines changed: 8 additions & 0 deletions
@@ -5,6 +5,7 @@
 #include <algorithm>
 #include <iostream>
 #include <memory>
+#include <unordered_map>
 namespace torch {
 namespace jit {
 

@@ -178,6 +179,11 @@ struct TORCH_API SourceRange {
   size_t end_;
 };
 
+struct SourceRangeHasher {
+ public:
+  size_t operator()(const torch::jit::SourceRange& key) const;
+};
+
 struct StackEntry {
   std::string filename;
   SourceRange range;

@@ -201,6 +207,8 @@ struct TaggedRange {
   SourceRange range;
 };
 using SourceRangeRecords = std::vector<TaggedRange>;
+using SourceRangeTagMap =
+    std::unordered_map<SourceRange, int64_t, SourceRangeHasher>;
 
 } // namespace jit
 } // namespace torch
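With SourceRangeHasher in place, the new SourceRangeTagMap can key directly on SourceRange. A hedged sketch of how unique tags might be assigned during serialization (steps 1-2 of the summary); the helper name below is illustrative and assumes the headers added in this change, not the serializer's actual code.

#include <torch/csrc/jit/frontend/source_range.h>

namespace torch {
namespace jit {

// Illustrative only: assign each distinct SourceRange a unique tag the
// first time it is seen; repeated ranges reuse their existing tag.
int64_t getOrAssignTag(SourceRangeTagMap& tags, const SourceRange& range) {
  auto it = tags.find(range);
  if (it != tags.end()) {
    return it->second;
  }
  int64_t tag = static_cast<int64_t>(tags.size());
  tags.emplace(range, tag);
  return tag;
}

} // namespace jit
} // namespace torch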

torch/csrc/jit/mobile/debug_info.cpp

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+#include <torch/csrc/jit/mobile/debug_info.h>
+#include <torch/csrc/jit/serialization/source_range_serialization.h>
+
+#include <ATen/core/ivalue.h>
+#include <torch/csrc/jit/serialization/pickle.h>
+
+#include <c10/util/string_view.h>
+
+namespace torch {
+namespace jit {
+
+MobileDebugTable::MobileDebugTable(
+    std::unique_ptr<caffe2::serialize::PyTorchStreamReader>& reader) {
+  const std::vector<std::string>& record_names = reader->getAllRecords();
+  const c10::string_view suffix(".debug_pkl");
+  for (const auto& record_name : record_names) {
+    if (c10::string_view(record_name).ends_with(suffix)) {
+      at::DataPtr debug_data;
+      size_t debug_size{0};
+      std::tie(debug_data, debug_size) = reader->getRecord(record_name);
+      auto ivalues =
+          jit::unpickle(
+              reinterpret_cast<const char*>(debug_data.get()), debug_size)
+              .toTuple()
+              ->elements();
+      SourceRangeDeserializer deserializer;
+      for (auto& val : ivalues) {
+        auto tup_elems = val.toTuple()->elements();
+        // For BC we decode only tuples with 3 elements
+        // assuming it contains
+        // byte_offset, debug_handle (=source range tag), source range
+        if (tup_elems.size() == 3) {
+          int64_t debug_handle = tup_elems[kSourceRangeTagIndex].toInt();
+          auto source_range =
+              deserializer.deserialize(tup_elems[kSourceRangeIndex]);
+          source_range_map_.emplace(debug_handle, std::move(source_range));
+        }
+      }
+    }
+  }
+}
+
+std::string MobileDebugTable::getSourceDebugString(
+    const int64_t debug_handle) const {
+  const auto it = source_range_map_.find(debug_handle);
+  if (it == source_range_map_.end()) {
+    return "";
+  }
+  return source_range_map_.at(debug_handle).str();
+}
+
+} // namespace jit
+} // namespace torch

torch/csrc/jit/mobile/debug_info.h

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+#pragma once
+#include <c10/util/flat_hash_map.h>
+#include <caffe2/serialize/inline_container.h>
+#include <torch/csrc/jit/frontend/source_range.h>
+
+namespace torch {
+namespace jit {
+/*
+ * MobileDebugTable:
+ * Deserializes debug_pkl records from a PT model's zip archive and
+ * stores them in a map from debug handle to source range.
+ * Debug handles are unique per model, and the runtime, be it the lite
+ * interpreter or a delegate, raises exceptions using debug handles.
+ * The getSourceDebugString method is responsible for translating debug
+ * handles to the corresponding debug information.
+ * At the moment this only contains information about model source.
+ * Later diffs will include the entire stack corresponding to a debug handle.
+ */
+class MobileDebugTable {
+ public:
+  MobileDebugTable() = default;
+  MobileDebugTable(
+      std::unique_ptr<caffe2::serialize::PyTorchStreamReader>& reader);
+  std::string getSourceDebugString(const int64_t debug_handle) const;
+
+ private:
+  ska::flat_hash_map<int64_t, SourceRange> source_range_map_;
+};
+
+} // namespace jit
+} // namespace torch
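A hedged usage sketch for the new class: open a saved lite-interpreter model with PyTorchStreamReader, build the table from its *.debug_pkl records, and resolve one handle. The surrounding function is illustrative; error handling is omitted.

#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/mobile/debug_info.h>

#include <iostream>
#include <memory>
#include <string>

// Illustrative only: resolve one debug handle from a model file on disk.
void printSourceForHandle(const std::string& model_path, int64_t handle) {
  auto reader =
      std::make_unique<caffe2::serialize::PyTorchStreamReader>(model_path);
  torch::jit::MobileDebugTable table(reader);  // scans *.debug_pkl records
  std::string src = table.getSourceDebugString(handle);
  std::cout << (src.empty() ? "<no debug info for handle>" : src) << "\n";
}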

torch/csrc/jit/mobile/function.cpp

Lines changed: 7 additions & 1 deletion
@@ -22,12 +22,13 @@ const std::string& Function::name() const {
   return name_.name();
 }
 
-void Function::append_instruction(OpCode op, int X, int N) {
+void Function::append_instruction(OpCode op, int X, int N, int64_t dbg_handle) {
   TORCH_CHECK(
       isOpSupportedInMobile(op),
       toString(op),
       " is not supported in mobile module.");
   code_->instructions_.emplace_back(op, X, N);
+  code_->debug_handles_.emplace_back(dbg_handle);
 }
 
 bool Function::append_operator(

@@ -130,6 +131,11 @@ const std::shared_ptr<Code> Function::get_code() const {
   return code_;
 }
 
+int64_t Function::getExceptionDebugHandle() const {
+  size_t pc = getInterpretersExceptionPC();
+  return (pc < code_->debug_handles_.size()) ? code_->debug_handles_[pc] : -1;
+}
+
 } // namespace mobile
 } // namespace jit
 } // namespace torch
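The additions above keep one debug handle per bytecode instruction and map the interpreter's exception PC back to it. A minimal sketch of that parallel-array idea, using stand-in types rather than the actual mobile Code/Instruction classes:

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative stand-in for the mobile Code object: instructions and
// debug handles are appended in lock-step, so an exception PC indexes both.
struct MiniCode {
  std::vector<int> instructions;
  std::vector<int64_t> debug_handles;
};

void appendInstruction(MiniCode& code, int op, int64_t dbg_handle) {
  code.instructions.push_back(op);
  code.debug_handles.push_back(dbg_handle);
}

int64_t debugHandleAtPC(const MiniCode& code, size_t pc) {
  // Mirrors getExceptionDebugHandle: -1 when no handle is recorded.
  return pc < code.debug_handles.size() ? code.debug_handles[pc] : -1;
}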
