Add try_find_program_address syscall (#14118)
This commit is contained in:
@ -16,7 +16,7 @@ use solana_sdk::{
|
||||
entrypoint::{MAX_PERMITTED_DATA_INCREASE, SUCCESS},
|
||||
feature_set::{
|
||||
pubkey_log_syscall_enabled, ristretto_mul_syscall_enabled, sha256_syscall_enabled,
|
||||
sol_log_compute_units_syscall,
|
||||
sol_log_compute_units_syscall, try_find_program_address_syscall_enabled,
|
||||
},
|
||||
hash::{Hasher, HASH_BYTES},
|
||||
instruction::{AccountMeta, Instruction, InstructionError},
|
||||
@ -123,6 +123,12 @@ pub fn register_syscalls(
|
||||
b"sol_create_program_address",
|
||||
SyscallCreateProgramAddress::call,
|
||||
)?;
|
||||
if invoke_context.is_feature_active(&try_find_program_address_syscall_enabled::id()) {
|
||||
syscall_registry.register_syscall_by_name(
|
||||
b"sol_try_find_program_address",
|
||||
SyscallTryFindProgramAddress::call,
|
||||
)?;
|
||||
}
|
||||
syscall_registry
|
||||
.register_syscall_by_name(b"sol_invoke_signed_c", SyscallInvokeSignedC::call)?;
|
||||
syscall_registry
|
||||
@ -217,6 +223,17 @@ pub fn bind_syscall_context_objects<'a>(
|
||||
None,
|
||||
)?;
|
||||
|
||||
if invoke_context.is_feature_active(&try_find_program_address_syscall_enabled::id()) {
|
||||
vm.bind_syscall_context_object(
|
||||
Box::new(SyscallTryFindProgramAddress {
|
||||
cost: bpf_compute_budget.create_program_address_units,
|
||||
compute_meter: invoke_context.get_compute_meter(),
|
||||
loader_id,
|
||||
}),
|
||||
None,
|
||||
)?;
|
||||
}
|
||||
|
||||
// Cross-program invocation syscalls
|
||||
|
||||
let invoke_context = Rc::new(RefCell::new(invoke_context));
|
||||
@ -580,6 +597,33 @@ impl SyscallObject<BPFError> for SyscallAllocFree {
|
||||
}
|
||||
}
|
||||
|
||||
fn translate_program_address_inputs<'a>(
|
||||
seeds_addr: u64,
|
||||
seeds_len: u64,
|
||||
program_id_addr: u64,
|
||||
memory_mapping: &MemoryMapping,
|
||||
loader_id: &Pubkey,
|
||||
) -> Result<(Vec<&'a [u8]>, &'a Pubkey), EbpfError<BPFError>> {
|
||||
let untranslated_seeds =
|
||||
translate_slice::<&[&u8]>(memory_mapping, seeds_addr, seeds_len, loader_id)?;
|
||||
if untranslated_seeds.len() > MAX_SEEDS {
|
||||
return Err(SyscallError::BadSeeds(PubkeyError::MaxSeedLengthExceeded).into());
|
||||
}
|
||||
let seeds = untranslated_seeds
|
||||
.iter()
|
||||
.map(|untranslated_seed| {
|
||||
translate_slice::<u8>(
|
||||
memory_mapping,
|
||||
untranslated_seed.as_ptr() as *const _ as u64,
|
||||
untranslated_seed.len() as u64,
|
||||
loader_id,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, EbpfError<BPFError>>>()?;
|
||||
let program_id = translate_type::<Pubkey>(memory_mapping, program_id_addr, loader_id)?;
|
||||
Ok((seeds, program_id))
|
||||
}
|
||||
|
||||
/// Create a program address
|
||||
struct SyscallCreateProgramAddress<'a> {
|
||||
cost: u64,
|
||||
@ -597,35 +641,18 @@ impl<'a> SyscallObject<BPFError> for SyscallCreateProgramAddress<'a> {
|
||||
memory_mapping: &MemoryMapping,
|
||||
result: &mut Result<u64, EbpfError<BPFError>>,
|
||||
) {
|
||||
question_mark!(self.compute_meter.consume(self.cost), result);
|
||||
// TODO need ref?
|
||||
let untranslated_seeds = question_mark!(
|
||||
translate_slice::<&[&u8]>(memory_mapping, seeds_addr, seeds_len, self.loader_id),
|
||||
result
|
||||
);
|
||||
if untranslated_seeds.len() > MAX_SEEDS {
|
||||
*result = Ok(1);
|
||||
return;
|
||||
}
|
||||
let seeds = question_mark!(
|
||||
untranslated_seeds
|
||||
.iter()
|
||||
.map(|untranslated_seed| {
|
||||
translate_slice::<u8>(
|
||||
memory_mapping,
|
||||
untranslated_seed.as_ptr() as *const _ as u64,
|
||||
untranslated_seed.len() as u64,
|
||||
self.loader_id,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, EbpfError<BPFError>>>(),
|
||||
result
|
||||
);
|
||||
let program_id = question_mark!(
|
||||
translate_type::<Pubkey>(memory_mapping, program_id_addr, self.loader_id),
|
||||
let (seeds, program_id) = question_mark!(
|
||||
translate_program_address_inputs(
|
||||
seeds_addr,
|
||||
seeds_len,
|
||||
program_id_addr,
|
||||
memory_mapping,
|
||||
self.loader_id,
|
||||
),
|
||||
result
|
||||
);
|
||||
|
||||
question_mark!(self.compute_meter.consume(self.cost), result);
|
||||
let new_address = match Pubkey::create_program_address(&seeds, program_id) {
|
||||
Ok(address) => address,
|
||||
Err(_) => {
|
||||
@ -642,6 +669,64 @@ impl<'a> SyscallObject<BPFError> for SyscallCreateProgramAddress<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a program address
|
||||
struct SyscallTryFindProgramAddress<'a> {
|
||||
cost: u64,
|
||||
compute_meter: Rc<RefCell<dyn ComputeMeter>>,
|
||||
loader_id: &'a Pubkey,
|
||||
}
|
||||
impl<'a> SyscallObject<BPFError> for SyscallTryFindProgramAddress<'a> {
|
||||
fn call(
|
||||
&mut self,
|
||||
seeds_addr: u64,
|
||||
seeds_len: u64,
|
||||
program_id_addr: u64,
|
||||
address_addr: u64,
|
||||
bump_seed_addr: u64,
|
||||
memory_mapping: &MemoryMapping,
|
||||
result: &mut Result<u64, EbpfError<BPFError>>,
|
||||
) {
|
||||
let (seeds, program_id) = question_mark!(
|
||||
translate_program_address_inputs(
|
||||
seeds_addr,
|
||||
seeds_len,
|
||||
program_id_addr,
|
||||
memory_mapping,
|
||||
self.loader_id,
|
||||
),
|
||||
result
|
||||
);
|
||||
|
||||
let mut bump_seed = [std::u8::MAX];
|
||||
for _ in 0..std::u8::MAX {
|
||||
{
|
||||
let mut seeds_with_bump = seeds.to_vec();
|
||||
seeds_with_bump.push(&bump_seed);
|
||||
|
||||
question_mark!(self.compute_meter.consume(self.cost), result);
|
||||
if let Ok(new_address) =
|
||||
Pubkey::create_program_address(&seeds_with_bump, program_id)
|
||||
{
|
||||
let bump_seed_ref = question_mark!(
|
||||
translate_type_mut::<u8>(memory_mapping, bump_seed_addr, self.loader_id),
|
||||
result
|
||||
);
|
||||
let address = question_mark!(
|
||||
translate_slice_mut::<u8>(memory_mapping, address_addr, 32, self.loader_id),
|
||||
result
|
||||
);
|
||||
*bump_seed_ref = bump_seed[0];
|
||||
address.copy_from_slice(new_address.as_ref());
|
||||
*result = Ok(0);
|
||||
return;
|
||||
}
|
||||
}
|
||||
bump_seed[0] -= 1;
|
||||
}
|
||||
*result = Ok(1);
|
||||
}
|
||||
}
|
||||
|
||||
/// SHA256
|
||||
pub struct SyscallSha256<'a> {
|
||||
sha256_base_cost: u64,
|
||||
|
Reference in New Issue
Block a user