MADD/src/dns.rs
2025-08-10 01:14:56 +02:00

476 lines
15 KiB
Rust

use std::{
collections::HashMap, hash::Hash, net::Ipv4Addr, ops::Deref, str::FromStr, time::Duration,
};
use base64ct::{Base64, Encoding};
use dns_update::DnsUpdater;
use log::{error, info, warn};
use rand::{Rng, SeedableRng, rngs::StdRng};
use serde::{Deserialize, Serialize, de};
use ssh_key::{PublicKey, SshSig};
use tokio::{
sync::{mpsc, oneshot},
task::JoinHandle,
time::Instant,
};
use crate::Config;
// const MAX_HOSTNAME_LENGTH: usize = 15;
// #[derive(Debug)]
// pub struct Hostname(String);
// impl Deref for Hostname {
// type Target = String;
// fn deref(&self) -> &Self::Target {
// &self.0
// }
// }
// impl ser::Serialize for Hostname {
// fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
// where
// S: ser::Serializer,
// {
// serializer.serialize_str(&self.0)
// }
// }
// impl<'de> de::Deserialize<'de> for Hostname {
// fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
// where
// D: serde::Deserializer<'de>,
// {
// <String as de::Deserialize>::deserialize(deserializer).and_then(|inner| {
// if inner.len() > MAX_HOSTNAME_LENGTH {
// Err(de::Error::invalid_length(
// inner.len(),
// &"a shorter hostname",
// ))
// } else {
// Ok(Self(inner))
// }
// })
// }
// }
#[derive(Debug, Clone)]
pub struct DnsAddress(dns_update::providers::rfc2136::DnsAddress);
impl Deref for DnsAddress {
type Target = dns_update::providers::rfc2136::DnsAddress;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'de> Deserialize<'de> for DnsAddress {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: de::Deserializer<'de>,
{
<String as de::Deserialize>::deserialize(deserializer).and_then(|inner| {
match dns_update::providers::rfc2136::DnsAddress::try_from(&inner) {
Ok(addr) => Ok(DnsAddress(addr)),
Err(_) => Err(de::Error::custom(format!("Invalid DNS address: {inner}"))),
}
})
}
}
/// All hostname registrations, keyed by hostname.
///
/// Loaded from and written to `registrations.toml` in the configured data
/// directory (see `init_registrations` / `write_registrations`).
#[derive(Deserialize, Serialize, Debug)]
pub struct Registrations(HashMap<String, Registration>);
/// A single hostname registration.
#[derive(Deserialize, Serialize, Debug)]
pub struct Registration {
/// Address published as the hostname's A record.
pub ip: Ipv4Addr,
/// SSH public key of the registering host; a different key may not take
/// over the hostname, and registrations per key are limited.
pub public_key: PublicKey,
}
pub struct TsigAlgorithm(dns_update::TsigAlgorithm);
impl Deref for TsigAlgorithm {
type Target = dns_update::TsigAlgorithm;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'de> Deserialize<'de> for TsigAlgorithm {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: de::Deserializer<'de>,
{
<String as de::Deserialize>::deserialize(deserializer).and_then(|inner| {
match dns_update::TsigAlgorithm::from_str(&inner) {
Ok(algorithm) => Ok(TsigAlgorithm(algorithm)),
Err(_) => Err(de::Error::unknown_variant(
&inner,
&[
"hmac-md5",
"gss",
"hmac-sha1",
"hmac-sha224",
"hmac-sha256",
"hmac-sha256-128",
"hmac-sha384",
"hmac-sha384-192",
"hmac-sha512",
"hmac-sha512-256",
],
)),
}
})
}
}
impl std::fmt::Debug for TsigAlgorithm {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let algorithm_name = match self.0 {
dns_update::TsigAlgorithm::HmacMd5 => "hmac-md5",
dns_update::TsigAlgorithm::Gss => "gss",
dns_update::TsigAlgorithm::HmacSha1 => "hmac-sha1",
dns_update::TsigAlgorithm::HmacSha224 => "hmac-sha224",
dns_update::TsigAlgorithm::HmacSha256 => "hmac-sha256",
dns_update::TsigAlgorithm::HmacSha256_128 => "hmac-sha256-128",
dns_update::TsigAlgorithm::HmacSha384 => "hmac-sha384",
dns_update::TsigAlgorithm::HmacSha384_192 => "hmac-sha384-192",
dns_update::TsigAlgorithm::HmacSha512 => "hmac-sha512",
dns_update::TsigAlgorithm::HmacSha512_256 => "hmac-sha512-256",
};
write!(f, "{algorithm_name}")
}
}
impl Clone for TsigAlgorithm {
fn clone(&self) -> Self {
Self(match self.0 {
dns_update::TsigAlgorithm::HmacMd5 => dns_update::TsigAlgorithm::HmacMd5,
dns_update::TsigAlgorithm::Gss => dns_update::TsigAlgorithm::Gss,
dns_update::TsigAlgorithm::HmacSha1 => dns_update::TsigAlgorithm::HmacSha1,
dns_update::TsigAlgorithm::HmacSha224 => dns_update::TsigAlgorithm::HmacSha224,
dns_update::TsigAlgorithm::HmacSha256 => dns_update::TsigAlgorithm::HmacSha256,
dns_update::TsigAlgorithm::HmacSha256_128 => dns_update::TsigAlgorithm::HmacSha256_128,
dns_update::TsigAlgorithm::HmacSha384 => dns_update::TsigAlgorithm::HmacSha384,
dns_update::TsigAlgorithm::HmacSha384_192 => dns_update::TsigAlgorithm::HmacSha384_192,
dns_update::TsigAlgorithm::HmacSha512 => dns_update::TsigAlgorithm::HmacSha512,
dns_update::TsigAlgorithm::HmacSha512_256 => dns_update::TsigAlgorithm::HmacSha512_256,
})
}
}
/// 256-byte token that identifies a pending host update request; the client
/// must sign its Base64 form to confirm the request.
#[derive(Hash, PartialEq, Eq, Clone)]
pub struct Identifier(pub [u8; 256]);
impl Deref for Identifier {
    type Target = [u8; 256];
    fn deref(&self) -> &Self::Target {
        let Identifier(bytes) = self;
        bytes
    }
}
impl Identifier {
    /// Encodes the raw identifier bytes with `base64ct`'s `Base64` encoding.
    pub fn to_base64(&self) -> String {
        let bytes: &[u8] = &self.0;
        Base64::encode_string(bytes)
    }
}
/// Message handled by the DNS client task (received over its `mpsc` channel).
/// Payloads are boxed — presumably to keep the enum small; confirm intent.
pub enum Command {
/// Start a new host update request; answered with an [`Identifier`].
CreateRequest(Box<CreateRequest>),
/// Confirm a previously created request with an SSH signature.
SignRequest(Box<SignRequest>),
}
/// A host update request waiting for its signed confirmation.
#[derive(Clone)]
pub struct DNSRequest {
/// Hostname to register; the record name is `<hostname>.<zone>`.
pub hostname: String,
/// When the request was created; the signed confirmation must follow
/// within 10 seconds.
pub time: Instant,
/// Key whose signature must confirm the request.
pub ssh_key: PublicKey,
// TODO: Handle Ipv6
/// Address the A record should point to; must lie in a configured network.
pub requested_ip: Ipv4Addr,
/// Address of the host that submitted the request; the confirmation must
/// come from the same address.
pub host_ip: Ipv4Addr,
}
/// Command payload that registers a new pending [`DNSRequest`].
pub struct CreateRequest {
pub request: DNSRequest,
/// Receives the identifier the client must sign to confirm the request.
pub response_channel: oneshot::Sender<Identifier>,
}
/// Command payload that confirms a pending request.
pub struct SignRequest {
/// Identifier returned when the request was created.
pub identifier: Identifier,
/// SSH signature (namespace `"madd"`) over
/// `"<identifier base64> <hostname> <requested_ip>"`.
pub signed: SshSig,
/// When the signed confirmation was received.
pub time: Instant,
/// Address of the submitting host; must match the creating host.
pub host_ip: Ipv4Addr,
/// Receives `Ok(())` on success or the reason for rejection.
pub response_channel: oneshot::Sender<Result<(), RequestError>>,
}
/// Reasons a signed host update request can be rejected or fail.
#[derive(Debug, Clone)]
pub enum RequestError {
/// The request was signed from a different address than it was created from.
HostIpMismatch,
/// The signed confirmation arrived more than 10 seconds after creation.
RequestExpired,
/// The SSH signature did not verify against the request's key.
InvalidSignature,
/// The requested address is outside every configured network.
RequestedIpNotAllowed,
/// Deleting or creating the DNS record failed.
UpdateFailed,
/// The hostname is registered to a different key.
AlreadyRegistered,
/// The key has reached the configured registration limit.
TooManyRegistrations,
}
/// Spawns the DNS client task, which processes [`Command`]s received on `rx`,
/// and returns its join handle. The task gets its own copy of the config.
pub async fn start_client(rx: mpsc::Receiver<Command>, config: &Config) -> JoinHandle<()> {
    let owned_config = config.clone();
    let task = run_client(rx, owned_config);
    tokio::spawn(task)
}
/// State owned by the DNS client task.
struct ClientState {
config: Config,
/// Pending requests awaiting a signed confirmation, keyed by identifier.
requests: HashMap<Identifier, DNSRequest>,
dns_updater: DnsUpdater,
/// In-memory copy of the persisted registrations.
registrations: Registrations,
}
/// Body of the DNS client task.
///
/// Builds the DNS updater, loads the saved registrations and re-applies them
/// to the DNS server, then handles commands one at a time until `recv`
/// yields `None`.
async fn run_client(mut rx: mpsc::Receiver<Command>, config: Config) {
    let dns_updater = get_dns_updater(&config);
    let registrations = init_registrations(&config).await;
    init_dns_registrations(&registrations, &config, &dns_updater).await;
    let mut state = ClientState {
        config,
        requests: HashMap::new(),
        dns_updater,
        registrations,
    };
    while let Some(cmd) = rx.recv().await {
        handle_cmd(cmd, &mut state).await;
    }
}
/// Loads `registrations.toml` from the data directory, or starts empty when
/// the file does not exist.
///
/// # Panics
/// Panics if the file exists but cannot be read or parsed.
async fn init_registrations(config: &Config) -> Registrations {
    let path = config.data_dir.join("registrations.toml");
    if path.exists() {
        let contents = tokio::fs::read_to_string(path).await.unwrap();
        toml::from_str(&contents)
            .unwrap_or_else(|e| panic!("Failed to parse registrations file: {e}"))
    } else {
        Registrations(HashMap::new())
    }
}
async fn init_dns_registrations(
registrations: &Registrations,
config: &Config,
updater: &DnsUpdater,
) {
for (hostname, registration) in registrations.0.iter() {
execute_dns_update(hostname, &registration.ip, config, updater)
.await
.unwrap();
}
}
async fn write_registrations(registrations: &Registrations, config: &Config) {
let path = config.data_dir.join("registrations.toml");
let contents = toml::to_string(registrations).unwrap();
if !config.data_dir.exists() {
tokio::fs::create_dir_all(&config.data_dir).await.unwrap();
}
tokio::fs::write(path, contents).await.unwrap();
}
/// Builds the TSIG-authenticated RFC 2136 client used for all DNS updates.
///
/// The server is reached at `tcp://<dns_server>`. The key file must contain
/// the Base64-encoded TSIG secret; surrounding whitespace (such as the usual
/// trailing newline) is ignored.
///
/// # Panics
/// Panics if the key file cannot be read or is not valid Base64, or if the
/// DNS client cannot be created.
fn get_dns_updater(config: &Config) -> DnsUpdater {
    let addr = format!("tcp://{}", config.dns_server);
    let key_name = &config.tsig_key_name;
    let contents = std::fs::read_to_string(&config.tsig_key_file)
        .unwrap_or_else(|e| panic!("Failed to read TSIG key file: {e}"));
    // The Base64 decoder does not skip whitespace, so a trailing newline in
    // the key file would otherwise make decoding fail.
    let key = Base64::decode_vec(contents.trim())
        .unwrap_or_else(|e| panic!("Failed to decode TSIG key file as Base64: {e}"));
    let algorithm = config.tsig_algorithm.clone().0;
    info!("Creating DNS client for {addr}");
    dns_update::DnsUpdater::new_rfc2136_tsig(addr, key_name, key, algorithm)
        .expect("Failed to create DNS client")
}
/// Unboxes a [`Command`] and dispatches it to its handler.
async fn handle_cmd(cmd: Command, state: &mut ClientState) {
    match cmd {
        Command::CreateRequest(boxed) => {
            let request = *boxed;
            handle_create_request(request, state).await;
        }
        Command::SignRequest(boxed) => {
            let request = *boxed;
            handle_sign_request(request, state).await;
        }
    }
}
async fn handle_create_request(request: CreateRequest, state: &mut ClientState) {
let random_value: [u8; 256] = StdRng::from_os_rng().random();
let identifier = Identifier(random_value);
state
.requests
.insert(identifier.clone(), request.request.clone());
let _ = request.response_channel.send(identifier);
info!(
"Registered host update request for {} to {} from {}.",
&request.request.hostname, &request.request.requested_ip, &request.request.host_ip
)
}
async fn handle_sign_request(request: SignRequest, state: &mut ClientState) {
let dns_request = match state.requests.remove(&request.identifier) {
Some(req) => req,
None => todo!(),
};
// The request must be submitted and signed by the same host
if dns_request.host_ip != request.host_ip {
warn!(
"Host IP mismatch for signed request for host {}: expected from {}, received from {}",
dns_request.hostname, dns_request.host_ip, request.host_ip
);
let _ = request
.response_channel
.send(Err(RequestError::HostIpMismatch));
return;
}
// The requested IP must be within the allowed networks
if !state
.config
.networks
.iter()
.any(|net| net.contains(&dns_request.requested_ip))
{
warn!(
"Requested IP {} for host {} is not allowed",
dns_request.requested_ip, dns_request.hostname
);
let _ = request
.response_channel
.send(Err(RequestError::RequestedIpNotAllowed));
return;
}
// The signed request must be submitted within 10 seconds of request creation
if request.time - dns_request.time > Duration::from_secs(10) {
warn!("Request expired: signed request is older than 10 seconds");
let _ = request
.response_channel
.send(Err(RequestError::RequestExpired));
return;
}
// Verify the SSH signature
let original_text = format!(
"{} {} {}",
request.identifier.to_base64(),
dns_request.hostname,
dns_request.requested_ip
);
match dns_request
.ssh_key
.verify("madd", original_text.as_bytes(), &request.signed)
{
Ok(_) => (),
Err(e) => {
warn!(
"Failed to verify SSH signature for host {}: {}",
dns_request.hostname, e
);
let _ = request
.response_channel
.send(Err(RequestError::InvalidSignature));
return;
}
}
// Check with the registrations if the hostname has not already been registered
if let Some(registration) = state.registrations.0.get(&dns_request.hostname)
&& registration.public_key.key_data() != dns_request.ssh_key.key_data()
{
warn!(
"Host {} attempted to register {} which has already been registered by a different host",
dns_request.host_ip, dns_request.hostname
);
let _ = request
.response_channel
.send(Err(RequestError::AlreadyRegistered));
return;
}
// Restrict the number of self-registrations per host
if !state.registrations.0.contains_key(&dns_request.hostname)
&& state
.registrations
.0
.iter()
.filter(|(_, registration)| {
registration.public_key.key_data() == dns_request.ssh_key.key_data()
})
.count()
>= state.config.registration_limit
{
warn!(
"Host {} attempted to register {} but has already reached the registration limit",
dns_request.host_ip, dns_request.hostname
);
let _ = request
.response_channel
.send(Err(RequestError::TooManyRegistrations));
return;
}
// Create the registration
state.registrations.0.insert(
dns_request.hostname.clone(),
Registration {
ip: dns_request.requested_ip,
public_key: dns_request.ssh_key,
},
);
write_registrations(&state.registrations, &state.config).await;
// Execute the DNS update request
match execute_dns_update(
&dns_request.hostname,
&dns_request.requested_ip,
&state.config,
&state.dns_updater,
)
.await
{
Ok(_) => (),
Err(e) => {
let _ = request.response_channel.send(Err(e));
return;
}
};
let _ = request.response_channel.send(Ok(()));
info!(
"Executed host update request for {} to {} from {}.",
&dns_request.hostname, &dns_request.requested_ip, &dns_request.host_ip
)
}
/// Points `<hostname>.<zone>` at `ip`. Any existing A records for the name
/// are deleted first, then a fresh A record with a 60 second TTL is created.
///
/// # Errors
/// Returns [`RequestError::UpdateFailed`] (after logging the underlying
/// error) if either the delete or the create update fails.
///
/// NOTE(review): delete and create are two separate updates, so the name
/// briefly has no A record. `hostname` is not validated here; confirm that
/// callers restrict it to a single label.
async fn execute_dns_update(
    hostname: &str,
    ip: &Ipv4Addr,
    config: &Config,
    updater: &DnsUpdater,
) -> Result<(), RequestError> {
    let name = &format!("{}.{}", hostname, config.zone);
    let origin = &config.zone;
    let ttl = 60;
    updater
        .delete(name, origin, dns_update::DnsRecordType::A)
        .await
        .map_err(|e| {
            error!("Failed to delete existing DNS record: {e}");
            RequestError::UpdateFailed
        })?;
    let record = dns_update::DnsRecord::A { content: *ip };
    updater
        .create(name, record, ttl, origin)
        .await
        .map(|_| ())
        .map_err(|e| {
            error!("Failed to create DNS record: {e}");
            RequestError::UpdateFailed
        })
}