Skip to content

Commit

Permalink
Enforce null-safety with ptr::NonNull
Browse files Browse the repository at this point in the history
  • Loading branch information
aumetra committed Jan 28, 2025
1 parent 1e27541 commit 2f8e126
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 64 deletions.
104 changes: 69 additions & 35 deletions packages/std/src/exports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
//! the contract-specific function pointer. This is done via the `#[entry_point]`
//! macro attribute from cosmwasm-derive.
use alloc::vec::Vec;
use core::marker::PhantomData;
use core::{marker::PhantomData, ptr};

use serde::de::DeserializeOwned;

Expand Down Expand Up @@ -91,7 +91,8 @@ extern "C" fn allocate(size: usize) -> u32 {
#[no_mangle]
extern "C" fn deallocate(pointer: u32) {
// auto-drop Region on function end
let _ = unsafe { Region::from_heap_ptr(pointer as *mut Region<Owned>) };
let _ =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(pointer as *mut Region<Owned>).unwrap()) };
}

// TODO: replace with https://doc.rust-lang.org/std/ops/trait.Try.html once stabilized
Expand Down Expand Up @@ -529,9 +530,12 @@ where
C: CustomMsg,
E: ToString,
{
let env: Vec<u8> = unsafe { Region::from_heap_ptr(env_ptr).into_vec() };
let info: Vec<u8> = unsafe { Region::from_heap_ptr(info_ptr).into_vec() };
let msg: Vec<u8> = unsafe { Region::from_heap_ptr(msg_ptr).into_vec() };
let env: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(env_ptr).unwrap()).into_vec() };
let info: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(info_ptr).unwrap()).into_vec() };
let msg: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(msg_ptr).unwrap()).into_vec() };

let env: Env = try_into_contract_result!(from_json(env));
let info: MessageInfo = try_into_contract_result!(from_json(info));
Expand All @@ -553,9 +557,12 @@ where
C: CustomMsg,
E: ToString,
{
let env: Vec<u8> = unsafe { Region::from_heap_ptr(env_ptr).into_vec() };
let info: Vec<u8> = unsafe { Region::from_heap_ptr(info_ptr).into_vec() };
let msg: Vec<u8> = unsafe { Region::from_heap_ptr(msg_ptr).into_vec() };
let env: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(env_ptr).unwrap()).into_vec() };
let info: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(info_ptr).unwrap()).into_vec() };
let msg: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(msg_ptr).unwrap()).into_vec() };

let env: Env = try_into_contract_result!(from_json(env));
let info: MessageInfo = try_into_contract_result!(from_json(info));
Expand All @@ -576,8 +583,10 @@ where
C: CustomMsg,
E: ToString,
{
let env: Vec<u8> = unsafe { Region::from_heap_ptr(env_ptr).into_vec() };
let msg: Vec<u8> = unsafe { Region::from_heap_ptr(msg_ptr).into_vec() };
let env: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(env_ptr).unwrap()).into_vec() };
let msg: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(msg_ptr).unwrap()).into_vec() };

let env: Env = try_into_contract_result!(from_json(env));
let msg: M = try_into_contract_result!(from_json(msg));
Expand All @@ -598,9 +607,12 @@ where
C: CustomMsg,
E: ToString,
{
let env: Vec<u8> = unsafe { Region::from_heap_ptr(env_ptr).into_vec() };
let msg: Vec<u8> = unsafe { Region::from_heap_ptr(msg_ptr).into_vec() };
let migrate_info = unsafe { Region::from_heap_ptr(migrate_info_ptr).into_vec() };
let env: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(env_ptr).unwrap()).into_vec() };
let msg: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(msg_ptr).unwrap()).into_vec() };
let migrate_info =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(migrate_info_ptr).unwrap()).into_vec() };

let env: Env = try_into_contract_result!(from_json(env));
let msg: M = try_into_contract_result!(from_json(msg));
Expand All @@ -621,8 +633,10 @@ where
C: CustomMsg,
E: ToString,
{
let env: Vec<u8> = unsafe { Region::from_heap_ptr(env_ptr).into_vec() };
let msg: Vec<u8> = unsafe { Region::from_heap_ptr(msg_ptr).into_vec() };
let env: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(env_ptr).unwrap()).into_vec() };
let msg: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(msg_ptr).unwrap()).into_vec() };

let env: Env = try_into_contract_result!(from_json(env));
let msg: M = try_into_contract_result!(from_json(msg));
Expand All @@ -641,8 +655,10 @@ where
C: CustomMsg,
E: ToString,
{
let env: Vec<u8> = unsafe { Region::from_heap_ptr(env_ptr).into_vec() };
let msg: Vec<u8> = unsafe { Region::from_heap_ptr(msg_ptr).into_vec() };
let env: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(env_ptr).unwrap()).into_vec() };
let msg: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(msg_ptr).unwrap()).into_vec() };

let env: Env = try_into_contract_result!(from_json(env));
let msg: Reply = try_into_contract_result!(from_json(msg));
Expand All @@ -661,8 +677,10 @@ where
M: DeserializeOwned,
E: ToString,
{
let env: Vec<u8> = unsafe { Region::from_heap_ptr(env_ptr).into_vec() };
let msg: Vec<u8> = unsafe { Region::from_heap_ptr(msg_ptr).into_vec() };
let env: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(env_ptr).unwrap()).into_vec() };
let msg: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(msg_ptr).unwrap()).into_vec() };

let env: Env = try_into_contract_result!(from_json(env));
let msg: M = try_into_contract_result!(from_json(msg));
Expand All @@ -680,8 +698,10 @@ where
Q: CustomQuery,
E: ToString,
{
let env: Vec<u8> = unsafe { Region::from_heap_ptr(env_ptr).into_vec() };
let msg: Vec<u8> = unsafe { Region::from_heap_ptr(msg_ptr).into_vec() };
let env: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(env_ptr).unwrap()).into_vec() };
let msg: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(msg_ptr).unwrap()).into_vec() };

let env: Env = try_into_contract_result!(from_json(env));
let msg: IbcChannelOpenMsg = try_into_contract_result!(from_json(msg));
Expand All @@ -701,8 +721,10 @@ where
C: CustomMsg,
E: ToString,
{
let env: Vec<u8> = unsafe { Region::from_heap_ptr(env_ptr).into_vec() };
let msg: Vec<u8> = unsafe { Region::from_heap_ptr(msg_ptr).into_vec() };
let env: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(env_ptr).unwrap()).into_vec() };
let msg: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(msg_ptr).unwrap()).into_vec() };

let env: Env = try_into_contract_result!(from_json(env));
let msg: IbcChannelConnectMsg = try_into_contract_result!(from_json(msg));
Expand All @@ -722,8 +744,10 @@ where
C: CustomMsg,
E: ToString,
{
let env: Vec<u8> = unsafe { Region::from_heap_ptr(env_ptr).into_vec() };
let msg: Vec<u8> = unsafe { Region::from_heap_ptr(msg_ptr).into_vec() };
let env: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(env_ptr).unwrap()).into_vec() };
let msg: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(msg_ptr).unwrap()).into_vec() };

let env: Env = try_into_contract_result!(from_json(env));
let msg: IbcChannelCloseMsg = try_into_contract_result!(from_json(msg));
Expand All @@ -743,8 +767,10 @@ where
C: CustomMsg,
E: ToString,
{
let env: Vec<u8> = unsafe { Region::from_heap_ptr(env_ptr).into_vec() };
let msg: Vec<u8> = unsafe { Region::from_heap_ptr(msg_ptr).into_vec() };
let env: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(env_ptr).unwrap()).into_vec() };
let msg: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(msg_ptr).unwrap()).into_vec() };

let env: Env = try_into_contract_result!(from_json(env));
let msg: IbcPacketReceiveMsg = try_into_contract_result!(from_json(msg));
Expand All @@ -764,8 +790,10 @@ where
C: CustomMsg,
E: ToString,
{
let env: Vec<u8> = unsafe { Region::from_heap_ptr(env_ptr).into_vec() };
let msg: Vec<u8> = unsafe { Region::from_heap_ptr(msg_ptr).into_vec() };
let env: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(env_ptr).unwrap()).into_vec() };
let msg: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(msg_ptr).unwrap()).into_vec() };

let env: Env = try_into_contract_result!(from_json(env));
let msg: IbcPacketAckMsg = try_into_contract_result!(from_json(msg));
Expand All @@ -785,8 +813,10 @@ where
C: CustomMsg,
E: ToString,
{
let env: Vec<u8> = unsafe { Region::from_heap_ptr(env_ptr).into_vec() };
let msg: Vec<u8> = unsafe { Region::from_heap_ptr(msg_ptr).into_vec() };
let env: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(env_ptr).unwrap()).into_vec() };
let msg: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(msg_ptr).unwrap()).into_vec() };

let env: Env = try_into_contract_result!(from_json(env));
let msg: IbcPacketTimeoutMsg = try_into_contract_result!(from_json(msg));
Expand All @@ -805,8 +835,10 @@ where
C: CustomMsg,
E: ToString,
{
let env: Vec<u8> = unsafe { Region::from_heap_ptr(env_ptr).into_vec() };
let msg: Vec<u8> = unsafe { Region::from_heap_ptr(msg_ptr).into_vec() };
let env: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(env_ptr).unwrap()).into_vec() };
let msg: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(msg_ptr).unwrap()).into_vec() };

let env: Env = try_into_contract_result!(from_json(env));
let msg: IbcSourceCallbackMsg = try_into_contract_result!(from_json(msg));
Expand All @@ -829,8 +861,10 @@ where
C: CustomMsg,
E: ToString,
{
let env: Vec<u8> = unsafe { Region::from_heap_ptr(env_ptr).into_vec() };
let msg: Vec<u8> = unsafe { Region::from_heap_ptr(msg_ptr).into_vec() };
let env: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(env_ptr).unwrap()).into_vec() };
let msg: Vec<u8> =
unsafe { Region::from_heap_ptr(ptr::NonNull::new(msg_ptr).unwrap()).into_vec() };

let env: Env = try_into_contract_result!(from_json(env));
let msg: IbcDestinationCallbackMsg = try_into_contract_result!(from_json(msg));
Expand Down
33 changes: 22 additions & 11 deletions packages/std/src/imports.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use alloc::vec::Vec;
use core::ptr;

use crate::import_helpers::{from_high_half, from_low_half};
use crate::memory::{Owned, Region};
Expand Down Expand Up @@ -133,7 +134,7 @@ impl Storage for ExternalStorage {
}

let value_ptr = read as *mut Region<Owned>;
let data = unsafe { Region::from_heap_ptr(value_ptr) };
let data = unsafe { Region::from_heap_ptr(ptr::NonNull::new(value_ptr).unwrap()) };

Some(data.into_vec())
}
Expand Down Expand Up @@ -255,7 +256,7 @@ impl Iterator for ExternalPartialIterator {
}

let data_region = next_result as *mut Region<Owned>;
let data = unsafe { Region::from_heap_ptr(data_region) };
let data = unsafe { Region::from_heap_ptr(ptr::NonNull::new(data_region).unwrap()) };

Some(data.into_vec())
}
Expand Down Expand Up @@ -284,7 +285,7 @@ impl Iterator for ExternalIterator {
fn next(&mut self) -> Option<Self::Item> {
let next_result = unsafe { db_next(self.iterator_id) };
let kv_region_ptr = next_result as *mut Region<Owned>;
let kv = unsafe { Region::from_heap_ptr(kv_region_ptr) };
let kv = unsafe { Region::from_heap_ptr(ptr::NonNull::new(kv_region_ptr).unwrap()) };

let (key, value) = decode_sections2(kv.into_vec());

Expand All @@ -307,7 +308,7 @@ fn skip_iter(iter_id: u32, count: usize) {
}

// just deallocate the region
unsafe { Region::from_heap_ptr(region as *mut Region<Owned>) };
unsafe { Region::from_heap_ptr(ptr::NonNull::new(region as *mut Region<Owned>).unwrap()) };
}
}

Expand Down Expand Up @@ -574,8 +575,12 @@ impl Api for ExternalApi {
let pubkey_ptr = from_low_half(result);
match error_code {
0 => {
let pubkey =
unsafe { Region::from_heap_ptr(pubkey_ptr as *mut Region<Owned>).into_vec() };
let pubkey = unsafe {
Region::from_heap_ptr(
ptr::NonNull::new(pubkey_ptr as *mut Region<Owned>).unwrap(),
)
.into_vec()
};
Ok(pubkey)
}
2 => panic!("MessageTooLong must not happen. This is a bug in the VM."),
Expand Down Expand Up @@ -631,8 +636,12 @@ impl Api for ExternalApi {
let pubkey_ptr = from_low_half(result);
match error_code {
0 => {
let pubkey =
unsafe { Region::from_heap_ptr(pubkey_ptr as *mut Region<Owned>).into_vec() };
let pubkey = unsafe {
Region::from_heap_ptr(
ptr::NonNull::new(pubkey_ptr as *mut Region<Owned>).unwrap(),
)
.into_vec()
};
Ok(pubkey)
}
2 => panic!("MessageTooLong must not happen. This is a bug in the VM."),
Expand Down Expand Up @@ -712,7 +721,7 @@ impl Api for ExternalApi {
/// Takes a pointer to a Region and reads the data into a String.
/// This is for trusted string sources only.
unsafe fn consume_string_region_written_by_vm(from: *mut Region<Owned>) -> String {
let data = Region::from_heap_ptr(from).into_vec();
let data = Region::from_heap_ptr(ptr::NonNull::new(from).unwrap()).into_vec();
// We trust the VM/chain to return correct UTF-8, so let's save some gas
String::from_utf8_unchecked(data)
}
Expand All @@ -732,8 +741,10 @@ impl Querier for ExternalQuerier {
let request_ptr = req.as_ptr() as u32;

let response_ptr = unsafe { query_chain(request_ptr) };
let response =
unsafe { Region::from_heap_ptr(response_ptr as *mut Region<Owned>).into_vec() };
let response = unsafe {
Region::from_heap_ptr(ptr::NonNull::new(response_ptr as *mut Region<Owned>).unwrap())
.into_vec()
};

from_json(&response).unwrap_or_else(|parsing_err| {
SystemResult::Err(SystemError::InvalidResponse {
Expand Down
Loading

0 comments on commit 2f8e126

Please sign in to comment.