1 stable release

1.0.0 May 18, 2021


MPL-2.0 license

25KB
316 lines

standard-dist

A derive macro for creating a Standard distribution for Rust types


lib.rs:

standard-dist is a library for automatically deriving a rand standard distribution for your types via a derive macro.

Usage examples

use rand::distributions::Uniform;
use standard_dist::StandardDist;

// Select heads or tails with equal probability
#[derive(Debug, Clone, Copy, PartialEq, Eq, StandardDist)]
enum Coin {
    Heads,
    Tails,
}

// Flip 3 coins, independently
#[derive(Debug, Clone, Copy, PartialEq, Eq, StandardDist)]
struct Coins {
    first: Coin,
    second: Coin,
    third: Coin,
}

// Use the `#[distribution]` attribute to customize the distribution used on
// a field
#[derive(Debug, Clone, Copy, PartialEq, Eq, StandardDist)]
struct Die {
    #[distribution(Uniform::from(1..=6))]
    value: u8
}

// Use the `#[weight]` attribute to customize the relative probabilities of
// enum variants
#[derive(Debug, Clone, Copy, PartialEq, Eq, StandardDist)]
enum D20 {
    #[weight(18)]
    Normal,

    Critical,
    CriticalFail,
}

rand generates typed random values via the Distribution trait, which uses a source of randomness to produce values of the given type. Of particular note is the Standard distribution, which is the stateless "default" way to produce random values of a particular type. For instance:

  • For ints, this randomly chooses from all possible values for that int type
  • For bools, it chooses true or false with 50/50 probability
  • For Option<T>, it chooses None or Some with 50/50 probability, and uses Standard to randomly populate the inner Some value.

Structs

When you derive StandardDist for one of your own structs, it generates an impl Distribution<YourStruct> for Standard, allowing you to create randomized instances of the struct via Rng::gen. This implementation in turn uses the Standard distribution to populate every field of your type.

use standard_dist::StandardDist;

#[derive(StandardDist)]
struct SimpleStruct {
    coin: bool,
    percent: f64,
}

let mut heads = 0;

for _ in 0..2000 {
    let s: SimpleStruct = rand::random();
    assert!(0.0 <= s.percent);
    assert!(s.percent < 1.0);
    if s.coin {
        heads += 1;
    }
}

assert!(900 < heads, "heads: {}", heads);
assert!(heads < 1100, "heads: {}", heads);
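To make the description of the derive more concrete, here is a dependency-free sketch of the kind of impl it describes. The Rng and Distribution traits, the Standard type, and the Lcg generator below are simplified toy stand-ins written for this illustration, not rand's real definitions, and the hand-written impl is an assumption about the general shape of the generated code rather than the macro's actual output.

```rust
// Toy stand-ins for rand's traits; simplified for illustration only.
pub trait Rng {
    fn next_u64(&mut self) -> u64;
}

pub trait Distribution<T> {
    fn sample<R: Rng>(&self, rng: &mut R) -> T;
}

/// Stand-in for the stateless "default" distribution.
pub struct Standard;

impl Distribution<bool> for Standard {
    fn sample<R: Rng>(&self, rng: &mut R) -> bool {
        // Use the top bit: the low bits of an LCG are weak.
        rng.next_u64() >> 63 == 1
    }
}

impl Distribution<f64> for Standard {
    fn sample<R: Rng>(&self, rng: &mut R) -> f64 {
        // 53 random bits scaled into the half-open range [0, 1).
        (rng.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }
}

pub struct SimpleStruct {
    pub coin: bool,
    pub percent: f64,
}

// Hypothetical hand-written equivalent of `#[derive(StandardDist)]`:
// every field is filled by sampling `Standard` for that field's type.
impl Distribution<SimpleStruct> for Standard {
    fn sample<R: Rng>(&self, rng: &mut R) -> SimpleStruct {
        SimpleStruct {
            coin: Standard.sample(rng),
            percent: Standard.sample(rng),
        }
    }
}

/// Small linear congruential generator, only so the sketch can run.
pub struct Lcg(pub u64);

impl Rng for Lcg {
    fn next_u64(&mut self) -> u64 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        self.0
    }
}
```

With these definitions, Standard.sample(&mut Lcg(42)) can produce a SimpleStruct whose percent lies in [0, 1), mirroring the assertions in the example above. With the real crates, the derive and rand::random() replace all of this boilerplate.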

Custom Distributions

You can customize the distribution used for any field with the #[distribution] attribute:

use std::collections::HashMap;
use standard_dist::StandardDist;
use rand::distributions::Uniform;

#[derive(StandardDist)]
struct Die {
    #[distribution(Uniform::from(1..=6))]
    value: u8
}

let mut counter: HashMap<u8, u32> = HashMap::new();

for _ in 0..6000 {
    let die: Die = rand::random();
    *counter.entry(die.value).or_insert(0) += 1;
}

assert_eq!(counter.len(), 6);

for i in 1..=6 {
    let count = counter[&i];
    assert!(900 < count, "{}: {}", i, count);
    assert!(count < 1100, "{}: {}", i, count);
}

Enums

When applied to an enum type, the implementation will randomly select a variant (where each variant has an equal probability) and then populate all the fields of that variant in the same manner as with a struct. Enum variant fields may have custom distributions applied via #[distribution], just like struct fields.

use standard_dist::StandardDist;

#[derive(PartialEq, Eq, StandardDist)]
enum Coin {
    Heads,
    Tails,
}

let mut heads = 0;

for _ in 0..2000 {
    let coin: Coin = rand::random();
    if coin == Coin::Heads {
        heads += 1;
    }
}

assert!(900 < heads, "heads: {}", heads);
assert!(heads < 1100, "heads: {}", heads);

Weights

Enum variants may be weighted with the #[weight] attribute to make them relatively more or less likely to be randomly selected. A weight of 0 means that the variant will never be selected. Any untagged variants will have a weight of 1.

use standard_dist::StandardDist;

#[derive(StandardDist)]
enum D20 {
    #[weight(18)]
    Normal,

    CriticalHit,
    CriticalMiss,
}

let mut crits = 0;

for _ in 0..20000 {
    let roll: D20 = rand::random();
    if matches!(roll, D20::CriticalHit) {
        crits += 1;
    }
}

assert!(900 < crits, "crits: {}", crits);
assert!(crits < 1100, "crits: {}", crits);
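As a worked check of the numbers above: the weights are 18, 1 and 1, for a total of 20, so CriticalHit is selected with probability 1/20, and 20000 rolls should produce about 1000 crits. That is why the assertions bracket 1000. The sketch below shows one common way to implement weighted selection, by walking the cumulative weights. The functions pick_variant and variant_probability are hypothetical helpers for this illustration; nothing here claims to be how the macro itself selects a variant.

```rust
/// Map a roll in `0..total_weight` to a variant index by walking the
/// cumulative weights. Returns `None` if the roll is out of range.
/// Hypothetical helper for illustration; not part of standard-dist.
pub fn pick_variant(weights: &[u32], roll: u32) -> Option<usize> {
    let mut upper = 0u32;
    for (index, &weight) in weights.iter().enumerate() {
        upper += weight;
        // A weight of 0 leaves `upper` unchanged, so that variant
        // can never be the first one whose bound exceeds the roll.
        if roll < upper {
            return Some(index);
        }
    }
    None
}

/// Probability of the variant at `index`: its weight over the total.
pub fn variant_probability(weights: &[u32], index: usize) -> f64 {
    let total: u32 = weights.iter().sum();
    f64::from(weights[index]) / f64::from(total)
}
```

For the D20 weights [18, 1, 1], rolls 0 through 17 map to index 0 (Normal), 18 maps to index 1 (CriticalHit), and 19 maps to index 2 (CriticalMiss).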

Advanced custom distributions

Distribution types

You may optionally specify an explicit type for a distribution, using the form Type = initializer; this can sometimes be necessary when using generic types.

use std::collections::HashMap;
use standard_dist::StandardDist;
use rand::distributions::Uniform;

#[derive(StandardDist)]
struct Die {
    #[distribution(Uniform<u8> = Uniform::from(1..=6))]
    value: u8
}

let mut counter: HashMap<u8, u32> = HashMap::new();

for _ in 0..6000 {
    let die: Die = rand::random();
    *counter.entry(die.value).or_insert(0) += 1;
}

assert_eq!(counter.len(), 6);

for i in 1..=6 {
    let count = counter[&i];
    assert!(900 < count, "{}: {}", i, count);
    assert!(count < 1100, "{}: {}", i, count);
}

Distribution caching

In some cases, you may wish to cache a Distribution instance for reuse. Many distributions perform some initial calculations when constructed, and it can help performance to reuse existing distributions rather than recreate them every time a value is generated. standard-dist provides two ways to cache distributions: static and once. A static distribution is stored as a global static variable; this is the preferable option, but it requires the initializer to be usable in a const context. A once distribution is stored in a once_cell::sync::OnceCell; it is initialized the first time it's used, and then reused on subsequent invocations.

In either case, a cache policy is specified by prefixing the type with once or static. The type must be specified in order to use a cache policy.

use std::collections::HashMap;
use std::time::{Instant, Duration};
use standard_dist::StandardDist;
use rand::prelude::*;
use rand::distributions::Uniform;

#[derive(StandardDist)]
struct Die {
    #[distribution(Uniform::from(1..=6))]
    value: u8
}

#[derive(StandardDist)]
struct CachedDie {
    #[distribution(once Uniform<u8> = Uniform::from(1..=6))]
    value: u8
}

fn timed<T>(task: impl FnOnce() -> T) -> (T, Duration) {
    let start = Instant::now();
    (task(), start.elapsed())
}

// Count the 6s
let mut rng = StdRng::from_entropy();

let (count, plain_die_duration) = timed(|| {
    (0..600000)
        .map(|_| rng.gen())
        .filter(|&Die { value }| value == 6)
        .count()
});

assert!(90000 < count);
assert!(count < 110000);

let (count, cache_die_duration) = timed(|| {
    (0..600000)
        .map(|_| rng.gen())
        .filter(|&CachedDie { value }| value == 6)
        .count()
});

assert!(90000 < count);
assert!(count < 110000);

assert!(
    cache_die_duration < plain_die_duration,
    "cache: {:?}, plain: {:?}",
    cache_die_duration,
    plain_die_duration,
);

Note that, unless you're generating a huge quantity of random objects, using once is likely a pessimization because of the upfront cost of initializing the cell. Make sure to benchmark your specific use case if performance is a concern.
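The once policy follows the familiar lazy-initialization pattern: build the value on first use, then hand out a shared reference afterwards. The sketch below illustrates that pattern with std::sync::OnceLock from the standard library (Rust 1.70 or later) standing in for once_cell::sync::OnceCell. ExpensiveTable, cached_table and the BUILD_COUNT counter are hypothetical names invented for this illustration, and the code does not reflect what the macro actually generates.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::OnceLock;

/// Counts how many times the expensive constructor actually runs.
pub static BUILD_COUNT: AtomicUsize = AtomicUsize::new(0);

/// Hypothetical stand-in for a distribution with costly setup.
pub struct ExpensiveTable {
    pub faces: Vec<u8>,
}

impl ExpensiveTable {
    pub fn new() -> Self {
        BUILD_COUNT.fetch_add(1, Ordering::SeqCst);
        ExpensiveTable {
            faces: (1..=6).collect(),
        }
    }
}

/// Returns the shared instance, constructing it on the first call only.
pub fn cached_table() -> &'static ExpensiveTable {
    static CELL: OnceLock<ExpensiveTable> = OnceLock::new();
    CELL.get_or_init(ExpensiveTable::new)
}
```

Every call to cached_table still checks whether the cell is initialized, and the first call also pays for construction and synchronization. That fixed overhead is the reason caching only pays off when construction is expensive relative to sampling, or when very many values are generated.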
