Skip to content

Commit 9f2bc59

Browse files
authored
Add a sync feature to common, core, and tensor (#893)
1 parent d021c7d commit 9f2bc59

File tree

11 files changed

+42
-34
lines changed

11 files changed

+42
-34
lines changed

burn-common/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ default = ["std"]
1515

1616
std = ["rand/std"]
1717

18+
wasm-sync = []
1819

1920
[target.'cfg(target_family = "wasm")'.dependencies]
2021
async-trait = { workspace = true }

burn-common/src/reader.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use alloc::boxed::Box;
22
use core::marker::PhantomData;
33

4-
#[cfg(target_family = "wasm")]
4+
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
55
#[async_trait::async_trait]
66
/// Allows to create async reader.
77
pub trait AsyncReader<T>: Send {
@@ -15,10 +15,10 @@ pub enum Reader<T> {
1515
Concrete(T),
1616
/// Sync data variant.
1717
Sync(Box<dyn SyncReader<T>>),
18-
#[cfg(target_family = "wasm")]
18+
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
1919
/// Async data variant.
2020
Async(Box<dyn AsyncReader<T>>),
21-
#[cfg(target_family = "wasm")]
21+
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
2222
/// Future data variant.
2323
Future(core::pin::Pin<Box<dyn core::future::Future<Output = T> + Send>>),
2424
}
@@ -52,7 +52,7 @@ where
5252
}
5353
}
5454

55-
#[cfg(target_family = "wasm")]
55+
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
5656
#[async_trait::async_trait]
5757
impl<I, O, F> AsyncReader<O> for MappedReader<I, O, F>
5858
where
@@ -67,7 +67,7 @@ where
6767
}
6868

6969
impl<T> Reader<T> {
70-
#[cfg(target_family = "wasm")]
70+
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
7171
/// Read the data.
7272
pub async fn read(self) -> T {
7373
match self {
@@ -78,7 +78,7 @@ impl<T> Reader<T> {
7878
}
7979
}
8080

81-
#[cfg(not(target_family = "wasm"))]
81+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
8282
/// Read the data.
8383
pub fn read(self) -> T {
8484
match self {
@@ -92,9 +92,9 @@ impl<T> Reader<T> {
9292
match self {
9393
Self::Concrete(data) => Some(data),
9494
Self::Sync(reader) => Some(reader.read()),
95-
#[cfg(target_family = "wasm")]
95+
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
9696
Self::Async(_func) => return None,
97-
#[cfg(target_family = "wasm")]
97+
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
9898
Self::Future(_future) => return None,
9999
}
100100
}
@@ -106,10 +106,10 @@ impl<T> Reader<T> {
106106
O: 'static + Send,
107107
F: 'static + Send,
108108
{
109-
#[cfg(target_family = "wasm")]
109+
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
110110
return Reader::Async(Box::new(MappedReader::new(self, mapper)));
111111

112-
#[cfg(not(target_family = "wasm"))]
112+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
113113
Reader::Sync(Box::new(MappedReader::new(self, mapper)))
114114
}
115115
}

burn-core/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ dataset-minimal = ["burn-dataset"]
2929
dataset-sqlite = ["burn-dataset/sqlite"]
3030
dataset-sqlite-bundled = ["burn-dataset/sqlite-bundled"]
3131

32+
wasm-sync = ["burn-tensor/wasm-sync", "burn-common/wasm-sync"]
33+
3234
# Backend
3335
autodiff = ["burn-autodiff"]
3436

burn-core/src/grad_clipping/base.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ impl GradientClipping {
6868
clipped_grad.mask_fill(lower_mask, -threshold)
6969
}
7070

71-
#[cfg(target_family = "wasm")]
71+
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
7272
fn clip_by_norm<B: Backend, const D: usize>(
7373
&self,
7474
_grad: Tensor<B, D>,
@@ -77,7 +77,7 @@ impl GradientClipping {
7777
todo!("Not yet supported on wasm");
7878
}
7979

80-
#[cfg(not(target_family = "wasm"))]
80+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
8181
fn clip_by_norm<B: Backend, const D: usize>(
8282
&self,
8383
grad: Tensor<B, D>,
@@ -96,7 +96,7 @@ impl GradientClipping {
9696
}
9797
}
9898

99-
#[cfg(not(target_family = "wasm"))]
99+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
100100
fn l2_norm<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, 1> {
101101
let squared = tensor.powf(2.0);
102102
let sum = squared.sum();

burn-core/src/record/tensor.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ impl<'de, S: PrecisionSettings> Deserialize<'de> for FloatTensorSerde<S> {
4444
}
4545
}
4646

47-
// #[cfg(not(target_family = "wasm"))]
4847
impl<S: PrecisionSettings> Serialize for IntTensorSerde<S> {
4948
fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
5049
where
@@ -90,10 +89,10 @@ impl<B: Backend, const D: usize> Record for Tensor<B, D> {
9089
type Item<S: PrecisionSettings> = FloatTensorSerde<S>;
9190

9291
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
93-
#[cfg(target_family = "wasm")]
92+
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
9493
todo!("Recording float tensors isn't yet supported on wasm.");
9594

96-
#[cfg(not(target_family = "wasm"))]
95+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
9796
FloatTensorSerde::new(self.into_data().convert().serialize())
9897
}
9998

@@ -106,10 +105,10 @@ impl<B: Backend, const D: usize> Record for Tensor<B, D, Int> {
106105
type Item<S: PrecisionSettings> = IntTensorSerde<S>;
107106

108107
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
109-
#[cfg(target_family = "wasm")]
108+
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
110109
todo!("Recording int tensors isn't yet supported on wasm.");
111110

112-
#[cfg(not(target_family = "wasm"))]
111+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
113112
IntTensorSerde::new(self.into_data().convert().serialize())
114113
}
115114

@@ -122,10 +121,10 @@ impl<B: Backend, const D: usize> Record for Tensor<B, D, Bool> {
122121
type Item<S: PrecisionSettings> = BoolTensorSerde;
123122

124123
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
125-
#[cfg(target_family = "wasm")]
124+
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
126125
todo!("Recording bool tensors isn't yet supported on wasm.");
127126

128-
#[cfg(not(target_family = "wasm"))]
127+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
129128
BoolTensorSerde::new(self.into_data().serialize())
130129
}
131130

burn-tensor/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ experimental-named-tensor = []
1616
export_tests = ["burn-tensor-testgen"]
1717
std = ["rand/std", "half/std"]
1818
benchmark = []
19+
wasm-sync = []
1920

2021
[dependencies]
2122
burn-common = { path = "../burn-common", version = "0.10.0", default-features = false }

burn-tensor/src/tensor/api/base.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
use alloc::vec::Vec;
44

5-
#[cfg(not(target_family = "wasm"))]
5+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
66
use alloc::format;
7-
#[cfg(not(target_family = "wasm"))]
7+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
88
use alloc::string::String;
9-
#[cfg(not(target_family = "wasm"))]
9+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
1010
use alloc::vec;
1111

1212
use burn_common::{reader::Reader, stub::Mutex};
@@ -325,25 +325,25 @@ where
325325
Self::new(K::to_device(self.primitive, device))
326326
}
327327

328-
#[cfg(target_family = "wasm")]
328+
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
329329
/// Returns the data of the current tensor.
330330
pub async fn into_data(self) -> Data<K::Elem, D> {
331331
K::into_data(self.primitive).read().await
332332
}
333333

334-
#[cfg(not(target_family = "wasm"))]
334+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
335335
/// Returns the data of the current tensor.
336336
pub fn into_data(self) -> Data<K::Elem, D> {
337337
K::into_data(self.primitive).read()
338338
}
339339

340-
#[cfg(target_family = "wasm")]
340+
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
341341
/// Returns the data of the current tensor.
342342
pub async fn to_data(&self) -> Data<K::Elem, D> {
343343
K::into_data(self.primitive.clone()).read().await
344344
}
345345

346-
#[cfg(not(target_family = "wasm"))]
346+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
347347
/// Returns the data of the current tensor without taking ownership.
348348
pub fn to_data(&self) -> Data<K::Elem, D> {
349349
Self::into_data(self.clone())
@@ -467,7 +467,7 @@ where
467467
K: BasicOps<B>,
468468
<K as BasicOps<B>>::Elem: Debug,
469469
{
470-
#[cfg(not(target_family = "wasm"))]
470+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
471471
#[inline]
472472
fn push_newline_indent(acc: &mut String, indent: usize) {
473473
acc.push('\n');
@@ -476,7 +476,7 @@ where
476476
}
477477
}
478478

479-
#[cfg(not(target_family = "wasm"))]
479+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
480480
fn fmt_inner_tensor(
481481
&self,
482482
acc: &mut String,
@@ -498,7 +498,7 @@ where
498498
}
499499
}
500500

501-
#[cfg(not(target_family = "wasm"))]
501+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
502502
fn fmt_outer_tensor(
503503
&self,
504504
acc: &mut String,
@@ -533,7 +533,7 @@ where
533533
/// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
534534
/// * `depth` - The current depth of the tensor dimensions being processed.
535535
/// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
536-
#[cfg(not(target_family = "wasm"))]
536+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
537537
fn display_recursive(
538538
&self,
539539
acc: &mut String,
@@ -644,7 +644,7 @@ where
644644
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
645645
writeln!(f, "Tensor {{")?;
646646

647-
#[cfg(not(target_family = "wasm"))]
647+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
648648
{
649649
let po = PRINT_OPTS.lock().unwrap();
650650
let mut acc = String::new();

burn-tensor/src/tensor/api/numeric.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ where
99
K: Numeric<B>,
1010
K::Elem: Element,
1111
{
12-
#[cfg(not(target_family = "wasm"))]
12+
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
1313
/// Convert the tensor into a scalar.
1414
///
1515
/// # Panics
@@ -21,7 +21,7 @@ where
2121
data.value[0]
2222
}
2323

24-
#[cfg(target_family = "wasm")]
24+
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
2525
/// Convert the tensor into a scalar.
2626
///
2727
/// # Panics

burn/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ std = ["burn-core/std"]
1818
# Training with full features
1919
train = ["burn-train/default", "autodiff", "dataset"]
2020

21+
# Useful when targeting WASM and not using WGPU.
22+
wasm-sync = ["burn-core/wasm-sync"]
23+
2124
## Include nothing
2225
train-minimal = ["burn-train"]
2326

examples/mnist-inference-web/build-for-web.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#!/usr/bin/env bash
12

23
# Add wasm32 target for compiler.
34
rustup target add wasm32-unknown-unknown

0 commit comments

Comments
 (0)