sdk: refactor pda generation

This commit is contained in:
Trent Nelson
2021-06-18 01:03:58 -06:00
parent b14af989b8
commit fcabaa7eff
11 changed files with 110 additions and 22 deletions

View File

@@ -1,5 +1,9 @@
#![allow(clippy::integer_arithmetic)]
use crate::{decode_error::DecodeError, hash::hashv};
use crate::{
bpf_loader, bpf_loader_deprecated, config, decode_error::DecodeError, feature, hash::hashv,
secp256k1_program, stake, system_program, sysvar, vote,
};
use borsh::{BorshDeserialize, BorshSchema, BorshSerialize};
use num_derive::{FromPrimitive, ToPrimitive};
use std::{
@@ -18,6 +22,8 @@ pub const MAX_SEEDS: usize = 16;
/// Maximum string length of a base58 encoded pubkey
const MAX_BASE58_LEN: usize = 44;
const PDA_MARKER: &[u8; 21] = b"ProgramDerivedAddress";
#[derive(Error, Debug, Serialize, Clone, PartialEq, FromPrimitive, ToPrimitive)]
pub enum PubkeyError {
/// Length of the seed is too long for address generation
@@ -25,6 +31,8 @@ pub enum PubkeyError {
MaxSeedLengthExceeded,
#[error("Provided seeds do not result in a valid address")]
InvalidSeeds,
#[error("Provided owner is not allowed")]
IllegalOwner,
}
impl<T> DecodeError<T> for PubkeyError {
fn type_of() -> &'static str {
@@ -159,8 +167,16 @@ impl Pubkey {
return Err(PubkeyError::MaxSeedLengthExceeded);
}
let owner = owner.as_ref();
if owner.len() >= PDA_MARKER.len() {
let slice = &owner[owner.len() - PDA_MARKER.len()..];
if slice == PDA_MARKER {
return Err(PubkeyError::IllegalOwner);
}
}
Ok(Pubkey::new(
hashv(&[base.as_ref(), seed.as_ref(), owner.as_ref()]).as_ref(),
hashv(&[base.as_ref(), seed.as_ref(), owner]).as_ref(),
))
}
@@ -200,6 +216,10 @@ impl Pubkey {
}
}
if program_id.is_native_program_id() {
return Err(PubkeyError::IllegalOwner);
}
// Perform the calculation inline, calling this from within a program is
// not supported
#[cfg(not(target_arch = "bpf"))]
@@ -208,7 +228,7 @@ impl Pubkey {
for seed in seeds.iter() {
hasher.hash(seed);
}
hasher.hashv(&[program_id.as_ref(), "ProgramDerivedAddress".as_ref()]);
hasher.hashv(&[program_id.as_ref(), PDA_MARKER]);
let hash = hasher.result();
if bytes_are_curve_point(hash) {
@@ -289,9 +309,10 @@ impl Pubkey {
{
let mut seeds_with_bump = seeds.to_vec();
seeds_with_bump.push(&bump_seed);
if let Ok(address) = Self::create_program_address(&seeds_with_bump, program_id)
{
return Some((address, bump_seed[0]));
match Self::create_program_address(&seeds_with_bump, program_id) {
Ok(address) => return Some((address, bump_seed[0])),
Err(PubkeyError::InvalidSeeds) => (),
_ => break,
}
}
bump_seed[0] -= 1;
@@ -349,6 +370,22 @@ impl Pubkey {
#[cfg(not(target_arch = "bpf"))]
crate::program_stubs::sol_log(&self.to_string());
}
pub fn is_native_program_id(&self) -> bool {
let all_program_ids = [
bpf_loader::id(),
bpf_loader_deprecated::id(),
feature::id(),
config::program::id(),
stake::program::id(),
stake::config::id(),
vote::program::id(),
secp256k1_program::id(),
system_program::id(),
sysvar::id(),
];
all_program_ids.contains(self)
}
}
impl AsRef<[u8]> for Pubkey {
@@ -485,7 +522,7 @@ mod tests {
fn test_create_program_address() {
let exceeded_seed = &[127; MAX_SEED_LEN + 1];
let max_seed = &[0; MAX_SEED_LEN];
let program_id = Pubkey::from_str("BPFLoader1111111111111111111111111111111111").unwrap();
let program_id = Pubkey::from_str("BPFLoaderUpgradeab1e11111111111111111111111").unwrap();
let public_key = Pubkey::from_str("SeedPubey1111111111111111111111111111111111").unwrap();
assert_eq!(
@@ -499,25 +536,25 @@ mod tests {
assert!(Pubkey::create_program_address(&[max_seed], &program_id).is_ok());
assert_eq!(
Pubkey::create_program_address(&[b"", &[1]], &program_id),
Ok("3gF2KMe9KiC6FNVBmfg9i267aMPvK37FewCip4eGBFcT"
Ok("BwqrghZA2htAcqq8dzP1WDAhTXYTYWj7CHxF5j7TDBAe"
.parse()
.unwrap())
);
assert_eq!(
Pubkey::create_program_address(&["".as_ref()], &program_id),
Ok("7ytmC1nT1xY4RfxCV2ZgyA7UakC93do5ZdyhdF3EtPj7"
Pubkey::create_program_address(&["".as_ref(), &[0]], &program_id),
Ok("13yWmRpaTR4r5nAktwLqMpRNr28tnVUZw26rTvPSSB19"
.parse()
.unwrap())
);
assert_eq!(
Pubkey::create_program_address(&[b"Talking", b"Squirrels"], &program_id),
Ok("HwRVBufQ4haG5XSgpspwKtNd3PC9GM9m1196uJW36vds"
Ok("2fnQrngrQT4SeLcdToJAD96phoEjNL2man2kfRLCASVk"
.parse()
.unwrap())
);
assert_eq!(
Pubkey::create_program_address(&[public_key.as_ref()], &program_id),
Ok("GUs5qLUfsEHkcMB9T38vjr18ypEhRuNWiePW2LoK4E3K"
Pubkey::create_program_address(&[public_key.as_ref(), &[1]], &program_id),
Ok("976ymqVnfE32QFe6NfGDctSvVa36LWnvYxhU6G2232YL"
.parse()
.unwrap())
);
@@ -564,4 +601,18 @@ mod tests {
);
}
}
#[test]
fn test_is_native_program_id() {
assert!(bpf_loader::id().is_native_program_id());
assert!(bpf_loader_deprecated::id().is_native_program_id());
assert!(config::program::id().is_native_program_id());
assert!(feature::id().is_native_program_id());
assert!(secp256k1_program::id().is_native_program_id());
assert!(stake::program::id().is_native_program_id());
assert!(stake::config::id().is_native_program_id());
assert!(system_program::id().is_native_program_id());
assert!(sysvar::id().is_native_program_id());
assert!(vote::program::id().is_native_program_id());
}
}