From 0c9f0c538aa0a5674950d8f81699e2a825c11473 Mon Sep 17 00:00:00 2001 From: Mingwei Samuel Date: Wed, 21 Aug 2024 09:28:48 -0700 Subject: [PATCH] multi-process proof-of-concept yay unsoundness: thread '' panicked at core\src\panicking.rs:221:5: unsafe precondition(s) violated: slice::get_unchecked requires that the index is within the slice note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace thread caused non-unwinding panic. aborting. error: test failed, to rerun pass `--test basic` Caused by: process didn't exit successfully: `D:\Projects\blondie\target\debug\deps\basic-10e35b968cd5dbed.exe` (exit code: 0xc0000409, STATUS_STACK_BUFFER_OVERRUN) --- src/lib.rs | 216 ++++++++++++++++++++++++++----------------------- tests/multi.rs | 16 ++++ 2 files changed, 133 insertions(+), 99 deletions(-) create mode 100644 tests/multi.rs diff --git a/src/lib.rs b/src/lib.rs index 76efc66..47c8757 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,19 +11,17 @@ mod error; mod util; +use std::collections::hash_map::Entry; use std::ffi::OsString; use std::io::{Read, Write}; -use std::marker::PhantomPinned; use std::mem::size_of; use std::os::windows::ffi::OsStringExt; use std::path::PathBuf; use std::ptr::{addr_of, addr_of_mut}; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::sync::mpsc::{Receiver, Sender}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::{Arc, Mutex, Weak}; -use std::thread::JoinHandle; -use crossbeam_skiplist::{SkipMap, SkipSet}; use object::Object; use pdb_addr2line::pdb::PDB; use pdb_addr2line::ContextPdbData; @@ -61,7 +59,7 @@ struct TraceContext { show_kernel_samples: bool, /// map[array_of_stacktrace_addrs] = sample_count - stack_counts_hashmap: FxHashMap<[u64; MAX_STACK_DEPTH], AtomicU64>, + stack_counts_hashmap: FxHashMap<[u64; MAX_STACK_DEPTH], u64>, /// (image_path, image_base, image_size) image_paths: Vec<(OsString, u64, u64)>, } @@ -87,7 +85,7 @@ impl TraceContext { ["Y", "YES", "TRUE"].iter().any(|truthy| &upper == truthy) }) .unwrap_or(kernel_stacks), - image_paths: SkipSet::new(), + image_paths: Vec::with_capacity(1024), }) } } @@ -103,50 +101,78 @@ impl Drop for TraceContext { } } +/// Stateful context provided to `event_record_callback`, containing multiple [`TraceContext`]s. struct Context { - /// Keys are process IDs. + /// Keys are process IDs. `Weak` deallocates when tracing should stop. traces: FxHashMap>, /// Receive new processes to subscribe to tracing. subscribe_recv: Receiver>, - /// Set to true once the trace starts running, deallocated afterward. - trace_running: Weak, + /// Set to true once the trace starts running. + trace_running: Arc, +} +impl Context { + /// SAFETY: May only be called by `event_record_callback`, while tracing. + unsafe fn get_trace_context(&mut self, pid: u32) -> Option> { + // TODO: handle PID reuse??? + let Entry::Occupied(entry) = self.traces.entry(pid) else { + // Ignore dlls for other processes + return None; + }; + if let Some(trace_context) = Weak::upgrade(entry.get()) { + Some(trace_context) + } else { + // Tracing just stopped, remove deallocated `Weak`. + entry.remove(); + return None; + } + } } -/// Global state context provided to `event_record_callback`. -struct Global { +/// Global tracing session. +/// +/// When this is dropped, the tracing session will be stopped. +struct Session { /// Send new processes to subscribe to tracing. - subscriber_send: Sender>, - /// Allocation for [`Context`]. - context: *const Context, - /// Allocation for event trace props, need deallocation after. - event_trace_props: Box, - /// Allocation for Logfile. - log: EVENT_TRACE_LOGFILEA, + subscribe_send: SyncSender>, + /// Box allocation for [`UserContext`]. + context: *mut Context, + /// Box allocation for event trace props, need deallocation after. + event_trace_props: *mut EVENT_TRACE_PROPERTIES_WITH_STRING, + /// Box allocation for Logfile. + log: *mut EVENT_TRACE_LOGFILEA, } +unsafe impl Send for Session {} +unsafe impl Sync for Session {} -impl Global { +impl Session { fn start(self: &Arc, trace_context: TraceContext) -> TraceGuard { let trace_context = Arc::new(trace_context); - self.traces - .insert((*trace_context).target_proc_pid, Arc::clone(&trace_context)); + self.subscribe_send + .send(Arc::downgrade(&trace_context)) + .unwrap(); TraceGuard { - global: Arc::clone(&self), trace_context, + _session: Arc::clone(&self), } } } -impl Drop for Global { +impl Drop for Session { fn drop(&mut self) { let ret = unsafe { // This unblocks ProcessTrace ControlTraceA( ::default(), KERNEL_LOGGER_NAMEA, - addr_of_mut!(*self.event_trace_props).cast(), + self.event_trace_props.cast(), EVENT_TRACE_CONTROL_STOP, ) }; + unsafe { + drop(Box::from_raw(self.context)); + drop(Box::from_raw(self.event_trace_props)); + drop(Box::from_raw(self.log)); + } if ret != ERROR_SUCCESS { eprintln!( "Error dropping GlobalContext: {:?}", @@ -156,43 +182,22 @@ impl Drop for Global { } } -#[must_use] struct TraceGuard { - global: Arc, trace_context: Arc, + /// Ensure session stays alive while `TraceGuard` is alive. + _session: Arc, } impl TraceGuard { - fn stop(mut self) -> TraceContext { - println!("BEFORE {}.", Arc::strong_count(&self.trace_context)); - - let entry = self - .global - .traces - .remove(&self.trace_context.target_proc_pid) - .expect("Cannot stop multiple times."); - println!("REMOVE {}", entry.remove()); - drop(entry); - drop(self.global); - - println!( - "TRYING TO TAKE FROM ARC {}.", - Arc::strong_count(&self.trace_context) - ); - loop { - match Arc::try_unwrap(self.trace_context) { - Ok(trace_context) => break trace_context, - Err(still_shared) => { - self.trace_context = still_shared; - } - } - std::hint::spin_loop(); - } + fn stop(self) -> TraceContext { + Arc::try_unwrap(self.trace_context) + .map_err(drop) + .expect("TraceContext Arc count should never have been incremented.") } } /// Gets the global context. Begins tracing if not already running. -fn get_global_context() -> Result> { - static GLOBAL_CONTEXT: Mutex> = Mutex::new(Weak::new()); +fn get_global_context() -> Result> { + static GLOBAL_CONTEXT: Mutex> = Mutex::new(Weak::new()); let mut unlocked = GLOBAL_CONTEXT.lock().unwrap(); if let Some(global_context) = unlocked.upgrade() { @@ -333,19 +338,7 @@ fn get_global_context() -> Result> { } } - let mut context_arc = Arc::new(Global { - trace_running: AtomicBool::new(false), - traces: SkipMap::new(), - event_trace_props, - log: EVENT_TRACE_LOGFILEA::default(), - _pin: PhantomPinned, - }); - let context = Arc::get_mut(&mut context_arc).unwrap(); - - // `!Unpin` - context.log.Context = addr_of!(*context).cast_mut().cast(); - - let log = &mut context.log; + let mut log = Box::new(EVENT_TRACE_LOGFILEA::default()); log.LoggerName = PSTR(kernel_logger_name_with_nul.as_mut_ptr()); log.Anonymous1.ProcessTraceMode = PROCESS_TRACE_MODE_REAL_TIME | PROCESS_TRACE_MODE_EVENT_RECORD @@ -354,19 +347,28 @@ fn get_global_context() -> Result> { unsafe extern "system" fn event_record_callback(record: *mut EVENT_RECORD) { let provider_guid_data1 = (*record).EventHeader.ProviderId.data1; let event_opcode = (*record).EventHeader.EventDescriptor.Opcode; - let context = &mut *(*record).UserContext.cast::(); + + let context = &mut *(*record).UserContext.cast::(); context.trace_running.store(true, Ordering::Relaxed); + // Subscribe any new processes. + context + .traces + .extend(context.subscribe_recv.try_iter().filter_map(|weak| { + let pid = Weak::upgrade(&weak)?.target_proc_pid; + Some((pid, weak)) + })); const EVENT_TRACE_TYPE_LOAD: u8 = 10; if event_opcode == EVENT_TRACE_TYPE_LOAD { let event = (*record).UserData.cast::().read_unaligned(); - // TODO: handle PID reuse??? - let Some(trace_context) = context.traces.get(&event.ProcessId) else { + let Some(trace_context) = context.get_trace_context(event.ProcessId) else { // Ignore dlls for other processes return; }; - let trace_context = trace_context.value(); + // TODO: use `Arc::get_mut_unchecked` once stable. + // SAFETY: Only the callback may modify the `TraceContext` while running. + let trace_context = Arc::into_raw(trace_context).cast_mut(); let filename_p = (*record) .UserData @@ -377,12 +379,15 @@ fn get_global_context() -> Result> { filename_p, ((*record).UserDataLength as usize - size_of::()) / 2, )); - trace_context.image_paths.insert(( + (*trace_context).image_paths.push(( filename_os_string, event.ImageBase as u64, event.ImageSize as u64, )); + // SAFETY: De-increments Arc from above. + drop(Arc::from_raw(trace_context)); + return; } @@ -399,11 +404,13 @@ fn get_global_context() -> Result> { let _thread = ud_p.cast::().offset(3).read_unaligned(); // TODO: handle PID reuse??? - let Some(trace_context) = context.traces.get(&proc) else { + let Some(trace_context) = context.get_trace_context(proc) else { // Ignore stackwalks for other processes return; }; - let trace_context = trace_context.value(); + // TODO: use `Arc::get_mut_unchecked` once stable. + // SAFETY: Only the callback may modify the `TraceContext` while running. + let trace_context = Arc::into_raw(trace_context).cast_mut(); let stack_depth_32 = ((*record).UserDataLength - 16) / 4; let stack_depth_64 = stack_depth_32 / 2; @@ -434,10 +441,11 @@ fn get_global_context() -> Result> { let mut stack = [0u64; MAX_STACK_DEPTH]; stack[..(stack_depth as usize).min(MAX_STACK_DEPTH)].copy_from_slice(stack_addrs); - let entry = trace_context - .stack_counts_hashmap - .get_or_insert(stack, AtomicU64::new(0)); - entry.value().fetch_add(1, Ordering::Relaxed); + let entry = (*trace_context).stack_counts_hashmap.entry(stack); + *entry.or_insert(0) += 1; + + // SAFETY: De-increments Arc from above. + drop(Arc::from_raw(trace_context)); const DEBUG_OUTPUT_EVENTS: bool = false; if DEBUG_OUTPUT_EVENTS { @@ -486,6 +494,15 @@ fn get_global_context() -> Result> { } log.Anonymous2.EventRecordCallback = Some(event_record_callback); + let (subscribe_send, subscribe_recv) = sync_channel(16); + let trace_running = Arc::new(AtomicBool::new(false)); + let context = Box::into_raw(Box::new(Context { + traces: Default::default(), + subscribe_recv, + trace_running: Arc::clone(&trace_running), + })); + log.Context = context.cast(); + let trace_processing_handle = unsafe { OpenTraceA(addr_of_mut!(*log)) }; if trace_processing_handle.0 == INVALID_HANDLE_VALUE.0 as u64 { return Err(get_last_error("OpenTraceA processing")); @@ -506,13 +523,20 @@ fn get_global_context() -> Result> { }); // Wait until we know for sure the trace is running - while !context.trace_running.load(Ordering::Relaxed) { + while !trace_running.load(Ordering::Relaxed) { std::hint::spin_loop(); } - let stored = Arc::downgrade(&context_arc); - *unlocked = stored; - Ok(context_arc) + // Store the session. + let session = Arc::new(Session { + subscribe_send, + context, + event_trace_props: Box::into_raw(event_trace_props), + // TODO: does log need to survive past the `OpenTraceA` call? Maybe not + log: Box::into_raw(log), + }); + *unlocked = Arc::downgrade(&session); + Ok(session) } const KERNEL_LOGGER_NAMEA_LEN: usize = unsafe { @@ -572,11 +596,10 @@ unsafe fn trace_from_process_id( // Wait for it to end util::wait_for_process_by_handle(target_proc_handle)?; - let trace_context = trace_guard.stop(); + let mut trace_context = trace_guard.stop(); if trace_context.show_kernel_samples { - todo!(); - // let kernel_module_paths = util::list_kernel_modules(); - // trace_context.image_paths.extend(kernel_module_paths); + let kernel_module_paths = util::list_kernel_modules(); + trace_context.image_paths.extend(kernel_module_paths); } Ok(trace_context) } @@ -624,10 +647,8 @@ pub fn trace_command( /// /// You can get them using [`CollectionResults::iter_callstacks`] pub struct CallStack<'a> { - stack: [u64; MAX_STACK_DEPTH], + stack: &'a [u64; MAX_STACK_DEPTH], sample_count: u64, - // TODO - _phantom: std::marker::PhantomData<&'a ()>, } /// An address from a callstack @@ -648,9 +669,7 @@ type PdbDb<'a, 'b> = std::collections::BTreeMap)>; /// Returns Vec<(image_base, image_size, image_name, addr2line pdb context)> -fn find_pdbs( - images: impl IntoIterator, -) -> Vec<(u64, u64, OsString, OwnedPdb)> { +fn find_pdbs(images: &[(OsString, u64, u64)]) -> Vec<(u64, u64, OsString, OwnedPdb)> { let mut pdb_db = Vec::new(); fn owned_pdb(pdb_file_bytes: Vec) -> Option { @@ -713,7 +732,7 @@ fn find_pdbs( _ => continue, }; - pdb_db.push((image_base, image_size, image_name.to_owned(), pdb_ctx)); + pdb_db.push((*image_base, *image_size, image_name.to_owned(), pdb_ctx)); } else if use_symsrv { let pdb_filename = match pdb_path.file_name() { Some(x) => x, @@ -746,7 +765,7 @@ fn find_pdbs( Some(x) => x, _ => continue, }; - pdb_db.push((image_base, image_size, image_name.to_owned(), pdb_ctx)); + pdb_db.push((*image_base, *image_size, image_name.to_owned(), pdb_ctx)); } } } @@ -774,7 +793,7 @@ impl<'a> CallStack<'a> { } let displacement = 0u64; let mut symbol_names_storage = reuse_vec(std::mem::take(v)); - for addr in self.stack { + for &addr in self.stack { if addr == 0 { *v = symbol_names_storage; return Ok(()); @@ -816,14 +835,13 @@ impl CollectionResults { /// Iterate the distinct callstacks sampled in this execution pub fn iter_callstacks(&self) -> impl std::iter::Iterator> { self.0.stack_counts_hashmap.iter().map(|x| CallStack { - stack: x.key().clone(), - sample_count: x.value().load(Ordering::Relaxed), - _phantom: std::marker::PhantomData, + stack: x.0, + sample_count: *x.1, }) } /// Resolve call stack symbols and write a dtrace-like sampling report to `w` pub fn write_dtrace(&self, mut w: W) -> Result<()> { - let pdbs = find_pdbs(self.0.image_paths.iter().map(|entry| entry.value().clone())); + let pdbs = find_pdbs(&self.0.image_paths); let pdb_db: PdbDb = pdbs .iter() .filter_map(|(a, b, c, d)| d.make_context().ok().map(|d| (*a, (*a, *b, c.clone(), d)))) diff --git a/tests/multi.rs b/tests/multi.rs new file mode 100644 index 0000000..d6cc168 --- /dev/null +++ b/tests/multi.rs @@ -0,0 +1,16 @@ +use std::process::Command; + +#[test] +fn test_multi() { + let handle = std::thread::spawn(|| { + let mut cmd = Command::new("ping"); + cmd.arg("localhost"); + let _ctx = blondie::trace_command(cmd, false).unwrap(); + }); + + let mut cmd = Command::new("ping"); + cmd.arg("localhost"); + let _ctx = blondie::trace_command(cmd, false).unwrap(); + + handle.join().unwrap(); +}