[Refactor] Just-In-Time Backend (#1280)
nathanielsimard authored Feb 12, 2024
1 parent 03bbc64 commit dfc65ab
Showing 82 changed files with 1,510 additions and 1,321 deletions.
2 changes: 1 addition & 1 deletion burn-compute/src/channel/base.rs
@@ -4,7 +4,7 @@ use burn_common::reader::Reader;

/// The ComputeChannel trait links the ComputeClient to the ComputeServer
/// while ensuring thread-safety
-pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug {
+pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send + Sync {
/// Given a handle, returns owned resource as bytes
fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>>;

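The new `Send + Sync` bounds let a client that wraps a channel be shared across threads. A minimal sketch of why the bounds matter, using stand-in types rather than burn's real ones:

```rust
use std::thread;

// Stand-ins for the real burn-compute types, which carry many more members.
trait ComputeChannel: Clone + Send + Sync + 'static {
    fn read(&self) -> Vec<u8>;
}

#[derive(Clone, Debug)]
struct DummyChannel;

impl ComputeChannel for DummyChannel {
    fn read(&self) -> Vec<u8> {
        vec![42]
    }
}

// Without `Send + Sync` on the trait, this generic function would not
// compile: the closure must move the channel across a thread boundary.
fn read_on_worker<C: ComputeChannel>(channel: C) -> Vec<u8> {
    thread::spawn(move || channel.read()).join().unwrap()
}

fn main() {
    assert_eq!(read_on_worker(DummyChannel), vec![42]);
}
```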
5 changes: 5 additions & 0 deletions burn-compute/src/channel/cell.rs
@@ -64,3 +64,8 @@ where
self.server.borrow_mut().sync()
}
}

+/// This is unsafe, since no concurrency is supported by the `RefCell` channel.
+/// However using this channel should only be done in single threaded environments such as `no-std`.
+unsafe impl<Server: ComputeServer> Send for RefCellComputeChannel<Server> {}
+unsafe impl<Server: ComputeServer> Sync for RefCellComputeChannel<Server> {}
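The same idiom reduced to a self-contained toy: manually implementing `Send`/`Sync` for a `RefCell`-based type is only sound under the single-threaded usage the doc comment above promises.

```rust
use core::cell::RefCell;

// Stand-in single-threaded channel; burn's RefCellComputeChannel wraps
// its server in a RefCell the same way.
#[derive(Debug)]
pub struct SingleThreadedChannel {
    state: RefCell<Vec<u8>>,
}

// SAFETY: these impls are only sound because the value is promised never
// to be used from more than one thread, e.g. on a single-threaded
// `no-std` target.
unsafe impl Send for SingleThreadedChannel {}
unsafe impl Sync for SingleThreadedChannel {}

fn main() {
    let chan = SingleThreadedChannel { state: RefCell::new(vec![1, 2, 3]) };
    assert_eq!(chan.state.borrow().len(), 3);
}
```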
9 changes: 1 addition & 8 deletions burn-compute/src/client.rs
@@ -7,15 +7,13 @@ use alloc::vec::Vec;
use alloc::{boxed::Box, sync::Arc};
use burn_common::reader::Reader;
use burn_common::stub::RwLock;
-use core::marker::PhantomData;

/// The ComputeClient is the entry point to require tasks from the ComputeServer.
/// It should be obtained for a specific device via the Compute struct.
#[derive(Debug)]
pub struct ComputeClient<Server: ComputeServer, Channel> {
channel: Channel,
tuner: Arc<RwLock<Tuner<Server, Channel>>>,
-_server: PhantomData<Server>,
}

impl<S, C> Clone for ComputeClient<S, C>
@@ -27,7 +25,6 @@ where
Self {
channel: self.channel.clone(),
tuner: self.tuner.clone(),
-_server: PhantomData,
}
}
}
@@ -39,11 +36,7 @@
{
/// Create a new client.
pub fn new(channel: Channel, tuner: Arc<RwLock<Tuner<Server, Channel>>>) -> Self {
-Self {
-channel,
-tuner,
-_server: PhantomData,
-}
+Self { channel, tuner }
}

/// Given a handle, returns owned resource as bytes.
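The deleted `_server: PhantomData<Server>` field was only needed while no other field mentioned `Server`; since `tuner` is typed `Tuner<Server, Channel>`, the parameter is already used. A sketch of the rule, with a hypothetical `Tuner` stand-in:

```rust
use std::marker::PhantomData;
use std::sync::{Arc, RwLock};

// Hypothetical stand-in for burn's Tuner.
struct Tuner<S, C>(PhantomData<(S, C)>);

// `Server` appears in the `tuner` field's type, so the compiler accepts
// the parameter without an extra `PhantomData<Server>` marker field.
struct ComputeClient<Server, Channel> {
    channel: Channel,
    tuner: Arc<RwLock<Tuner<Server, Channel>>>,
}

fn main() {
    let _client: ComputeClient<(), u8> = ComputeClient {
        channel: 0,
        tuner: Arc::new(RwLock::new(Tuner(PhantomData))),
    };
}
```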
4 changes: 2 additions & 2 deletions burn-compute/src/compute.rs
@@ -4,11 +4,11 @@ use hashbrown::HashMap;

/// The compute type has the responsibility to retrieve the correct compute client based on the
/// given device.
-pub struct Compute<Device, Server: ComputeServer, Channel> {
+pub struct ComputeRuntime<Device, Server: ComputeServer, Channel> {
clients: spin::Mutex<Option<HashMap<Device, ComputeClient<Server, Channel>>>>,
}

-impl<Device, Server, Channel> Compute<Device, Server, Channel>
+impl<Device, Server, Channel> ComputeRuntime<Device, Server, Channel>
where
Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
Server: ComputeServer,
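A minimal sketch, with stand-in types, of what a runtime of this shape does: lazily create and cache one client per device behind a mutex, with a `const fn` constructor so it can live in a `static` (as the dummy tests below do). This is an illustration of the pattern, not burn's implementation:

```rust
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::Mutex;

struct ComputeRuntime<Device, Client> {
    clients: Mutex<Option<HashMap<Device, Client>>>,
}

impl<Device: Hash + Eq + Clone, Client: Clone> ComputeRuntime<Device, Client> {
    // `const fn` so the runtime can be a `static`.
    const fn new() -> Self {
        Self { clients: Mutex::new(None) }
    }

    fn client(&self, device: &Device, init: impl FnOnce() -> Client) -> Client {
        let mut guard = self.clients.lock().unwrap();
        let map = guard.get_or_insert_with(HashMap::new);
        if let Some(existing) = map.get(device) {
            return existing.clone();
        }
        let client = init();
        map.insert(device.clone(), client.clone());
        client
    }
}

fn main() {
    static RUNTIME: ComputeRuntime<u32, String> = ComputeRuntime::new();
    let client = RUNTIME.client(&0, || "client-0".to_string());
    assert_eq!(client, "client-0");
}
```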
2 changes: 1 addition & 1 deletion burn-compute/src/memory_management/base.rs
@@ -5,7 +5,7 @@ use crate::storage::ComputeStorage;
///
/// It is responsible for determining if the memory segment can be mutated,
/// for instance by keeping track of a reference count
-pub trait MemoryHandle: Clone + Send + core::fmt::Debug {
+pub trait MemoryHandle: Clone + Send + Sync + core::fmt::Debug {
/// Checks if the underlying memory can be safely mutated.
fn can_mut(&self) -> bool;
}
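The doc comment names reference counting as one `can_mut` strategy. A self-contained sketch of that idea; an `Arc`-based handle is also `Send + Sync`, matching the new bound:

```rust
use std::sync::Arc;

// Sketch: the handle is safe to mutate only while no other copy is alive.
#[derive(Clone, Debug)]
struct Handle {
    memory: Arc<Vec<u8>>,
}

impl Handle {
    fn can_mut(&self) -> bool {
        Arc::strong_count(&self.memory) == 1
    }
}

fn main() {
    let a = Handle { memory: Arc::new(vec![0u8; 16]) };
    assert!(a.can_mut());
    let b = a.clone();
    assert!(!a.can_mut() && !b.can_mut());
}
```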
11 changes: 10 additions & 1 deletion burn-compute/src/tune/operation.rs
@@ -51,7 +51,16 @@ pub trait AutotuneOperation {
#[cfg(feature = "autotune-persistent-cache")]
/// Trait alias with support for persistent caching
pub trait AutotuneKey:
-Clone + Debug + PartialEq + Eq + Hash + Display + serde::Serialize + serde::de::DeserializeOwned
+Clone
++ Debug
++ PartialEq
++ Eq
++ Hash
++ Display
++ serde::Serialize
++ serde::de::DeserializeOwned
++ Send
++ Sync
{
}
#[cfg(not(feature = "autotune-persistent-cache"))]
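`AutotuneKey` is a trait alias: an empty trait whose supertraits carry all the bounds (the persistent-cache variant additionally requires serde). A sketch of the idiom without the serde bounds; the blanket impl is an assumption here, since burn may instead require explicit opt-in per key type:

```rust
use core::fmt::{Debug, Display};
use core::hash::Hash;

// Empty trait; the supertraits do all the work.
trait AutotuneKey:
    Clone + Debug + PartialEq + Eq + Hash + Display + Send + Sync
{
}

// Assumed blanket impl: every conforming type is automatically a key.
impl<T> AutotuneKey for T where
    T: Clone + Debug + PartialEq + Eq + Hash + Display + Send + Sync
{
}

fn assert_key<K: AutotuneKey>(_key: &K) {}

fn main() {
    // String satisfies every supertrait, so it can serve as a key.
    assert_key(&"matmul-128x128".to_string());
}
```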
6 changes: 3 additions & 3 deletions burn-compute/tests/dummy/compute.rs
@@ -7,7 +7,7 @@ use burn_compute::client::ComputeClient
use burn_compute::memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy};
use burn_compute::storage::BytesStorage;
use burn_compute::tune::Tuner;
-use burn_compute::Compute;
+use burn_compute::ComputeRuntime;

/// The dummy device.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
@@ -16,7 +16,7 @@ pub struct DummyDevice;
pub type DummyChannel = MutexComputeChannel<DummyServer>;
pub type DummyClient = ComputeClient<DummyServer, DummyChannel>;

-static COMPUTE: Compute<DummyDevice, DummyServer, DummyChannel> = Compute::new();
+static RUNTIME: ComputeRuntime<DummyDevice, DummyServer, DummyChannel> = ComputeRuntime::new();
pub static TUNER_DEVICE_ID: &str = "tests/dummy-device";

pub fn init_client() -> ComputeClient<DummyServer, MutexComputeChannel<DummyServer>> {
@@ -30,5 +30,5 @@ pub fn init_client() -> ComputeClient<DummyServer, MutexComputeChannel<DummyServ
}

pub fn client(device: &DummyDevice) -> DummyClient {
-COMPUTE.client(device, init_client)
+RUNTIME.client(device, init_client)
}
30 changes: 16 additions & 14 deletions burn-compute/tests/integration_test.rs
@@ -3,6 +3,7 @@ mod dummy;
use std::sync::Arc;

use crate::dummy::{client, DummyDevice, DummyElementwiseAddition};
+use burn_compute::ComputeRuntime;

#[allow(unused)]
use serial_test::serial;
@@ -90,9 +91,10 @@ fn autotune_basic_multiplication_execution() {
#[serial]
#[cfg(feature = "std")]
fn autotune_cache_same_key_return_a_cache_hit() {
-let compute: burn_compute::Compute<DummyDevice, dummy::DummyServer, dummy::DummyChannel> =
-burn_compute::Compute::new();
-let client = compute.client(&DummyDevice, dummy::init_client);
+type Runtime = ComputeRuntime<DummyDevice, dummy::DummyServer, dummy::DummyChannel>;
+let runtime = Runtime::new();
+
+let client = runtime.client(&DummyDevice, dummy::init_client);

// note: the key name depends on the shapes of the operation set
// see log_shape_input_key for more info.
@@ -133,8 +135,9 @@ fn autotune_cache_no_cache_on_disk_return_a_cache_miss() {
burn_compute::tune::get_persistent_cache_file_path(crate::dummy::TUNER_DEVICE_ID);
let _ = std::fs::remove_file(file_path);

-let compute: burn_compute::Compute<DummyDevice, dummy::DummyServer, dummy::DummyChannel> =
-burn_compute::Compute::new();
+type Runtime = ComputeRuntime<DummyDevice, dummy::DummyServer, dummy::DummyChannel>;
+let compute = Runtime::new();
+
let client = compute.client(&DummyDevice, dummy::init_client);

// in this test shapes [1,3] and [1,5] ends up with different key names
@@ -178,9 +181,9 @@ fn autotune_cache_file_path_creation_works_when_path_does_not_exist_yet() {
// Delete the cache file's parent directory
let _ = std::fs::remove_dir_all(parent_dir);

-let compute: burn_compute::Compute<DummyDevice, dummy::DummyServer, dummy::DummyChannel> =
-burn_compute::Compute::new();
-let client = compute.client(&DummyDevice, dummy::init_client);
+type Runtime = ComputeRuntime<DummyDevice, dummy::DummyServer, dummy::DummyChannel>;
+let runtime = Runtime::new();
+let client = runtime.client(&DummyDevice, dummy::init_client);

// in this test shapes [1,3] and [1,5] ends up with different key names
// which are 'cache_test-1,4' and 'cache_test-1,8'
@@ -240,9 +243,9 @@ fn autotune_cache_different_keys_return_a_cache_miss() {
#[serial]
#[cfg(feature = "std")]
fn autotune_cache_different_checksums_return_a_cache_miss() {
-let compute: burn_compute::Compute<DummyDevice, dummy::DummyServer, dummy::DummyChannel> =
-burn_compute::Compute::new();
-let client = compute.client(&DummyDevice, dummy::init_client);
+type Runtime = ComputeRuntime<DummyDevice, dummy::DummyServer, dummy::DummyChannel>;
+let runtime = Runtime::new();
+let client = runtime.client(&DummyDevice, dummy::init_client);

// in this test both shapes [1,3] and [1,4] end up with the same key name
// which is 'cache_test-1,4'
@@ -259,9 +262,8 @@ fn autotune_cache_different_checksums_return_a_cache_miss() {
// we use a second compute client in order to have freshly initialized autotune cache
// and test invalidation of the cache when the checksum of the operation set is
// different
-let compute: burn_compute::Compute<DummyDevice, dummy::DummyServer, dummy::DummyChannel> =
-burn_compute::Compute::new();
-let client = compute.client(&DummyDevice, dummy::init_client);
+let runtime = Runtime::new();
+let client = runtime.client(&DummyDevice, dummy::init_client);

let shapes_2 = vec![vec![1, 4], vec![1, 4], vec![1, 4]];
let lhs_2 = client.create(&[0, 1, 2, 3]);
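The test comments imply the cache key rounds each dimension up to the next power of two, which is why shapes `[1,3]` and `[1,4]` collide on `cache_test-1,4` while `[1,5]` maps to `cache_test-1,8`. A hypothetical reconstruction of that scheme (see `log_shape_input_key` for the real logic):

```rust
// Hypothetical key derivation matching the behavior the test comments
// describe; not burn's actual implementation.
fn autotune_key(name: &str, shape: &[usize]) -> String {
    let rounded: Vec<String> = shape
        .iter()
        .map(|d| d.next_power_of_two().to_string())
        .collect();
    format!("{}-{}", name, rounded.join(","))
}

fn main() {
    assert_eq!(autotune_key("cache_test", &[1, 3]), "cache_test-1,4");
    assert_eq!(autotune_key("cache_test", &[1, 5]), "cache_test-1,8");
    // Shapes [1,3] and [1,4] collide, which the cache-hit test relies on.
    assert_eq!(autotune_key("cache_test", &[1, 4]), "cache_test-1,4");
}
```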
7 changes: 2 additions & 5 deletions burn-core/src/lib.rs
@@ -50,11 +50,8 @@ pub type TestBackend = burn_ndarray::NdArray<f32>;
#[cfg(all(test, feature = "test-tch"))]
pub type TestBackend = burn_tch::LibTorch<f32>;

-#[cfg(all(test, feature = "test-wgpu", not(target_os = "macos")))]
-pub type TestBackend = burn_wgpu::WgpuBackend<burn_wgpu::Vulkan, f32, i32>;
-
-#[cfg(all(test, feature = "test-wgpu", target_os = "macos"))]
-pub type TestBackend = burn_wgpu::WgpuBackend<burn_wgpu::Metal, f32, i32>;
+#[cfg(all(test, feature = "test-wgpu"))]
+pub type TestBackend = burn_wgpu::Wgpu;

#[cfg(feature = "std")]
#[cfg(test)]
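With the new `burn_wgpu::Wgpu` alias apparently resolving the graphics API itself, the per-OS `cfg` split becomes unnecessary. A stand-in sketch of the cfg-gated alias technique being deleted here (names are illustrative, not burn's):

```rust
#[derive(Debug, Default)]
struct Vulkan;
#[derive(Debug, Default)]
struct Metal;

// The platform choice is made once, at the alias definition.
#[cfg(target_os = "macos")]
type DefaultApi = Metal;
#[cfg(not(target_os = "macos"))]
type DefaultApi = Vulkan;

fn main() {
    // Downstream code only ever names the single alias.
    println!("{:?}", DefaultApi::default());
}
```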
6 changes: 4 additions & 2 deletions burn-wgpu/benches/fused_elemwise.rs
@@ -1,7 +1,8 @@
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_tensor::backend::Backend;
use burn_tensor::{Distribution, Shape, Tensor};
-use burn_wgpu::{Wgpu, WgpuDevice};
+use burn_wgpu::compute::WgpuRuntime;
+use burn_wgpu::{AutoGraphicsApi, JitBackend, WgpuDevice};
use derive_new::new;
use std::marker::PhantomData;

@@ -55,7 +56,8 @@ impl<B: Backend> Benchmark for ElemWiseBenchmark<B> {
#[allow(dead_code)]
/// Runs the benchmarks for wgpu matmul implementations
pub fn bench(device: &WgpuDevice) {
-let result = run_benchmark(ElemWiseBenchmark::<Wgpu>::new(
+type Backend = JitBackend<WgpuRuntime<AutoGraphicsApi, f32, i32>>;
+let result = run_benchmark(ElemWiseBenchmark::<Backend>::new(
Shape::new([256, 256, 1024]),
device.clone(),
10,
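The recurring shape of this refactor: the monolithic `WgpuBackend<G, F, I>` becomes the generic `JitBackend` applied to a `WgpuRuntime`. A compilable sketch with stand-in traits; the real traits carry far more members:

```rust
use std::marker::PhantomData;

// Stand-in for the runtime abstraction the JIT backend is generic over.
trait Runtime {
    type Device;
    fn name() -> &'static str;
}

struct AutoGraphicsApi;
struct WgpuRuntime<G, F, I>(PhantomData<(G, F, I)>);

impl<G, F, I> Runtime for WgpuRuntime<G, F, I> {
    type Device = usize; // stand-in device handle
    fn name() -> &'static str {
        "wgpu"
    }
}

struct JitBackend<R: Runtime>(PhantomData<R>);

// What the benches now spell out: a generic backend over a concrete
// runtime instead of a monolithic WgpuBackend<G, F, I>.
type Backend = JitBackend<WgpuRuntime<AutoGraphicsApi, f32, i32>>;

fn main() {
    println!("{}", <WgpuRuntime<AutoGraphicsApi, f32, i32>>::name());
}
```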
12 changes: 7 additions & 5 deletions burn-wgpu/benches/matmul.rs
@@ -1,12 +1,13 @@
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_tensor::backend::Backend;
use burn_tensor::{Distribution, Shape, Tensor};
+use burn_wgpu::compute::WgpuRuntime;
use burn_wgpu::kernel::matmul::init_matmul_output;
use burn_wgpu::kernel::matmul::unpadded::matmul_tiling_2d_unpadded;
use burn_wgpu::kernel::matmul::vec4::matmul_tiling_2d_vec4;
use burn_wgpu::kernel::matmul::vec4_lhs::matmul_tiling_2d_vec4_lhs;
use burn_wgpu::WgpuDevice;
-use burn_wgpu::{AutoGraphicsApi, WgpuBackend};
+use burn_wgpu::{AutoGraphicsApi, JitBackend};
use derive_new::new;
use std::marker::PhantomData;

@@ -15,7 +16,8 @@ use burn_wgpu::{
GraphicsApi,
};

-type WTensor<G, const D: usize> = Tensor<WgpuBackend<G, f32, i32>, D>;
+type WBackend<G> = JitBackend<WgpuRuntime<G, f32, i32>>;
+type WTensor<G, const D: usize> = Tensor<WBackend<G>, D>;

#[derive(new)]
struct MatmulBenchmark<B: Backend, F, const D: usize> {
@@ -30,7 +32,7 @@ trait MatmulFunction<G: GraphicsApi, const D: usize> {
fn run(lhs: WTensor<G, D>, rhs: WTensor<G, D>) -> WTensor<G, D>;
}

-impl<F, const D: usize, G> Benchmark for MatmulBenchmark<WgpuBackend<G, f32, i32>, F, D>
+impl<F, const D: usize, G> Benchmark for MatmulBenchmark<WBackend<G>, F, D>
where
F: MatmulFunction<G, D>,
G: GraphicsApi,
@@ -64,7 +66,7 @@ where
}

fn sync(&self) {
-WgpuBackend::<G, f32, i32>::sync(&self.device)
+WBackend::<G>::sync(&self.device)
}
}

@@ -80,7 +82,7 @@ macro_rules! bench_matmul {
}
}
type $benchmark<const D: usize> =
-MatmulBenchmark<WgpuBackend<AutoGraphicsApi, f32, i32>, $matmul_name, D>;
+MatmulBenchmark<WBackend<AutoGraphicsApi>, $matmul_name, D>;
};
}
bench_matmul!(NaiveMatmulBenchmark, NaiveMatmul, matmul_naive_default);
15 changes: 8 additions & 7 deletions burn-wgpu/benches/reduction.rs
@@ -1,15 +1,16 @@
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_tensor::backend::Backend;
use burn_tensor::{Distribution, Shape, Tensor};
+use burn_wgpu::compute::WgpuRuntime;
use burn_wgpu::kernel::reduce::{init_reduce_output, sum_dim, sum_dim_shared_memory};
+use burn_wgpu::GraphicsApi;
use burn_wgpu::WgpuDevice;
-use burn_wgpu::{AutoGraphicsApi, WgpuBackend};
+use burn_wgpu::{AutoGraphicsApi, JitBackend};
use derive_new::new;
use std::marker::PhantomData;

-use burn_wgpu::GraphicsApi;
-
-type WTensor<G, const D: usize> = Tensor<WgpuBackend<G, f32, i32>, D>;
+type WBackend<G> = JitBackend<WgpuRuntime<G, f32, i32>>;
+type WTensor<G, const D: usize> = Tensor<WBackend<G>, D>;

#[derive(new)]
struct ReduceBenchmark<B: Backend, F, const D: usize> {
@@ -24,7 +25,7 @@ trait ReduceFunction<G: GraphicsApi, const D: usize> {
fn run(input: WTensor<G, D>, dim: usize) -> WTensor<G, D>;
}

-impl<F, const D: usize, G> Benchmark for ReduceBenchmark<WgpuBackend<G, f32, i32>, F, D>
+impl<F, const D: usize, G> Benchmark for ReduceBenchmark<WBackend<G>, F, D>
where
F: ReduceFunction<G, D>,
G: GraphicsApi,
@@ -55,7 +56,7 @@ where
}

fn sync(&self) {
-WgpuBackend::<G, f32, i32>::sync(&self.device)
+WBackend::<G>::sync(&self.device)
}
}

@@ -70,7 +71,7 @@ macro_rules! bench_reduce {
}
}
type $benchmark<const D: usize> =
-ReduceBenchmark<WgpuBackend<AutoGraphicsApi, f32, i32>, $reduce_name, D>;
+ReduceBenchmark<WBackend<AutoGraphicsApi>, $reduce_name, D>;
};
}

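These bench impls all route `sync` to the backend because GPU work is queued asynchronously; timing without a sync would measure kernel submission, not execution. A hypothetical reduction of the pattern (burn_common's real `Benchmark` trait differs in its exact methods):

```rust
use std::time::{Duration, Instant};

// Hypothetical benchmark trait; stand-in for burn_common's.
trait Benchmark {
    type Args;
    fn prepare(&self) -> Self::Args;
    fn execute(&self, args: Self::Args);
    // Blocks until all queued GPU work has finished.
    fn sync(&self);
}

fn time_once<B: Benchmark>(bench: &B) -> Duration {
    let args = bench.prepare();
    let start = Instant::now();
    bench.execute(args);
    bench.sync(); // without this we would time kernel *submission* only
    start.elapsed()
}

struct NoopBench;

impl Benchmark for NoopBench {
    type Args = ();
    fn prepare(&self) {}
    fn execute(&self, _args: ()) {}
    fn sync(&self) {}
}

fn main() {
    println!("{:?}", time_once(&NoopBench));
}
```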
(The diffs for the remaining changed files are not shown.)
