use std::{ collections::HashMap, hash::Hash, net::Ipv4Addr, ops::Deref, str::FromStr, time::Duration, }; use base64ct::{Base64, Encoding}; use dns_update::DnsUpdater; use log::{error, info, warn}; use rand::{Rng, SeedableRng, rngs::StdRng}; use serde::{Deserialize, Serialize, de}; use ssh_key::{PublicKey, SshSig}; use tokio::{ sync::{mpsc, oneshot}, task::JoinHandle, time::Instant, }; use crate::Config; // const MAX_HOSTNAME_LENGTH: usize = 15; // #[derive(Debug)] // pub struct Hostname(String); // impl Deref for Hostname { // type Target = String; // fn deref(&self) -> &Self::Target { // &self.0 // } // } // impl ser::Serialize for Hostname { // fn serialize(&self, serializer: S) -> Result // where // S: ser::Serializer, // { // serializer.serialize_str(&self.0) // } // } // impl<'de> de::Deserialize<'de> for Hostname { // fn deserialize(deserializer: D) -> Result // where // D: serde::Deserializer<'de>, // { // ::deserialize(deserializer).and_then(|inner| { // if inner.len() > MAX_HOSTNAME_LENGTH { // Err(de::Error::invalid_length( // inner.len(), // &"a shorter hostname", // )) // } else { // Ok(Self(inner)) // } // }) // } // } #[derive(Debug, Clone)] pub struct DnsAddress(dns_update::providers::rfc2136::DnsAddress); impl Deref for DnsAddress { type Target = dns_update::providers::rfc2136::DnsAddress; fn deref(&self) -> &Self::Target { &self.0 } } impl<'de> Deserialize<'de> for DnsAddress { fn deserialize(deserializer: D) -> Result where D: de::Deserializer<'de>, { ::deserialize(deserializer).and_then(|inner| { match dns_update::providers::rfc2136::DnsAddress::try_from(&inner) { Ok(addr) => Ok(DnsAddress(addr)), Err(_) => Err(de::Error::custom(format!("Invalid DNS address: {inner}"))), } }) } } #[derive(Deserialize, Serialize, Debug)] pub struct Registrations(HashMap); #[derive(Deserialize, Serialize, Debug)] pub struct Registration { pub ip: Ipv4Addr, pub public_key: PublicKey, } pub struct TsigAlgorithm(dns_update::TsigAlgorithm); impl Deref for TsigAlgorithm { type Target = dns_update::TsigAlgorithm; fn deref(&self) -> &Self::Target { &self.0 } } impl<'de> Deserialize<'de> for TsigAlgorithm { fn deserialize(deserializer: D) -> Result where D: de::Deserializer<'de>, { ::deserialize(deserializer).and_then(|inner| { match dns_update::TsigAlgorithm::from_str(&inner) { Ok(algorithm) => Ok(TsigAlgorithm(algorithm)), Err(_) => Err(de::Error::unknown_variant( &inner, &[ "hmac-md5", "gss", "hmac-sha1", "hmac-sha224", "hmac-sha256", "hmac-sha256-128", "hmac-sha384", "hmac-sha384-192", "hmac-sha512", "hmac-sha512-256", ], )), } }) } } impl std::fmt::Debug for TsigAlgorithm { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let algorithm_name = match self.0 { dns_update::TsigAlgorithm::HmacMd5 => "hmac-md5", dns_update::TsigAlgorithm::Gss => "gss", dns_update::TsigAlgorithm::HmacSha1 => "hmac-sha1", dns_update::TsigAlgorithm::HmacSha224 => "hmac-sha224", dns_update::TsigAlgorithm::HmacSha256 => "hmac-sha256", dns_update::TsigAlgorithm::HmacSha256_128 => "hmac-sha256-128", dns_update::TsigAlgorithm::HmacSha384 => "hmac-sha384", dns_update::TsigAlgorithm::HmacSha384_192 => "hmac-sha384-192", dns_update::TsigAlgorithm::HmacSha512 => "hmac-sha512", dns_update::TsigAlgorithm::HmacSha512_256 => "hmac-sha512-256", }; write!(f, "{algorithm_name}") } } impl Clone for TsigAlgorithm { fn clone(&self) -> Self { Self(match self.0 { dns_update::TsigAlgorithm::HmacMd5 => dns_update::TsigAlgorithm::HmacMd5, dns_update::TsigAlgorithm::Gss => dns_update::TsigAlgorithm::Gss, dns_update::TsigAlgorithm::HmacSha1 => dns_update::TsigAlgorithm::HmacSha1, dns_update::TsigAlgorithm::HmacSha224 => dns_update::TsigAlgorithm::HmacSha224, dns_update::TsigAlgorithm::HmacSha256 => dns_update::TsigAlgorithm::HmacSha256, dns_update::TsigAlgorithm::HmacSha256_128 => dns_update::TsigAlgorithm::HmacSha256_128, dns_update::TsigAlgorithm::HmacSha384 => dns_update::TsigAlgorithm::HmacSha384, dns_update::TsigAlgorithm::HmacSha384_192 => dns_update::TsigAlgorithm::HmacSha384_192, dns_update::TsigAlgorithm::HmacSha512 => dns_update::TsigAlgorithm::HmacSha512, dns_update::TsigAlgorithm::HmacSha512_256 => dns_update::TsigAlgorithm::HmacSha512_256, }) } } #[derive(Hash, PartialEq, Eq, Clone)] pub struct Identifier(pub [u8; 256]); impl Deref for Identifier { type Target = [u8; 256]; fn deref(&self) -> &Self::Target { &self.0 } } impl Identifier { pub fn to_base64(&self) -> String { Base64::encode_string(&self.0) } } pub enum Command { CreateRequest(Box), SignRequest(Box), } #[derive(Clone)] pub struct DNSRequest { pub hostname: String, pub time: Instant, pub ssh_key: PublicKey, // TODO: Handle Ipv6 pub requested_ip: Ipv4Addr, pub host_ip: Ipv4Addr, } pub struct CreateRequest { pub request: DNSRequest, pub response_channel: oneshot::Sender, } pub struct SignRequest { pub identifier: Identifier, pub signed: SshSig, pub time: Instant, pub host_ip: Ipv4Addr, pub response_channel: oneshot::Sender>, } #[derive(Debug, Clone)] pub enum RequestError { HostIpMismatch, RequestExpired, InvalidSignature, RequestedIpNotAllowed, UpdateFailed, AlreadyRegistered, TooManyRegistrations, } pub async fn start_client(rx: mpsc::Receiver, config: &Config) -> JoinHandle<()> { tokio::spawn(run_client(rx, config.clone())) } struct ClientState { config: Config, requests: HashMap, dns_updater: DnsUpdater, registrations: Registrations, } async fn run_client(rx: mpsc::Receiver, config: Config) { let mut rx = rx; let requests = HashMap::new(); let dns_updater = get_dns_updater(&config); let registrations = init_registrations(&config).await; init_dns_registrations(®istrations, &config, &dns_updater).await; let mut state = ClientState { config, requests, dns_updater, registrations, }; loop { match rx.recv().await { Some(cmd) => handle_cmd(cmd, &mut state).await, None => return, }; } } async fn init_registrations(config: &Config) -> Registrations { let path = config.data_dir.join("registrations.toml"); if !path.exists() { return Registrations(HashMap::new()); } let contents = tokio::fs::read_to_string(path).await.unwrap(); match toml::from_str(&contents) { Ok(registrations) => registrations, Err(e) => { panic!("Failed to parse registrations file: {e}"); } } } async fn init_dns_registrations( registrations: &Registrations, config: &Config, updater: &DnsUpdater, ) { for (hostname, registration) in registrations.0.iter() { execute_dns_update(hostname, ®istration.ip, config, updater) .await .unwrap(); } } async fn write_registrations(registrations: &Registrations, config: &Config) { let path = config.data_dir.join("registrations.toml"); let contents = toml::to_string(registrations).unwrap(); if !config.data_dir.exists() { tokio::fs::create_dir_all(&config.data_dir).await.unwrap(); } tokio::fs::write(path, contents).await.unwrap(); } fn get_dns_updater(config: &Config) -> DnsUpdater { let addr = format!("tcp://{}", config.dns_server); let key_name = &config.tsig_key_name; let key = match std::fs::read_to_string(&config.tsig_key_file) { Ok(contents) => Base64::decode_vec(&contents).unwrap(), Err(e) => panic!("Failed to read TSIG key file: {e}"), }; let algorithm = config.tsig_algorithm.clone().0; info!("Creating DNS client for {addr}"); dns_update::DnsUpdater::new_rfc2136_tsig(addr, key_name, key, algorithm) .expect("Failed to create DNS client") } async fn handle_cmd(cmd: Command, state: &mut ClientState) { match cmd { Command::CreateRequest(create_request) => { handle_create_request(*create_request, state).await } Command::SignRequest(sign_request) => handle_sign_request(*sign_request, state).await, } } async fn handle_create_request(request: CreateRequest, state: &mut ClientState) { let random_value: [u8; 256] = StdRng::from_os_rng().random(); let identifier = Identifier(random_value); state .requests .insert(identifier.clone(), request.request.clone()); let _ = request.response_channel.send(identifier); info!( "Registered host update request for {} to {} from {}.", &request.request.hostname, &request.request.requested_ip, &request.request.host_ip ) } async fn handle_sign_request(request: SignRequest, state: &mut ClientState) { let dns_request = match state.requests.remove(&request.identifier) { Some(req) => req, None => todo!(), }; // The request must be submitted and signed by the same host if dns_request.host_ip != request.host_ip { warn!( "Host IP mismatch for signed request for host {}: expected from {}, received from {}", dns_request.hostname, dns_request.host_ip, request.host_ip ); let _ = request .response_channel .send(Err(RequestError::HostIpMismatch)); return; } // The requested IP must be within the allowed networks if !state .config .networks .iter() .any(|net| net.contains(&dns_request.requested_ip)) { warn!( "Requested IP {} for host {} is not allowed", dns_request.requested_ip, dns_request.hostname ); let _ = request .response_channel .send(Err(RequestError::RequestedIpNotAllowed)); return; } // The signed request must be submitted within 10 seconds of request creation if request.time - dns_request.time > Duration::from_secs(10) { warn!("Request expired: signed request is older than 10 seconds"); let _ = request .response_channel .send(Err(RequestError::RequestExpired)); return; } // Verify the SSH signature let original_text = format!( "{} {} {}", request.identifier.to_base64(), dns_request.hostname, dns_request.requested_ip ); match dns_request .ssh_key .verify("madd", original_text.as_bytes(), &request.signed) { Ok(_) => (), Err(e) => { warn!( "Failed to verify SSH signature for host {}: {}", dns_request.hostname, e ); let _ = request .response_channel .send(Err(RequestError::InvalidSignature)); return; } } // Check with the registrations if the hostname has not already been registered if let Some(registration) = state.registrations.0.get(&dns_request.hostname) && registration.public_key.key_data() != dns_request.ssh_key.key_data() { warn!( "Host {} attempted to register {} which has already been registered by a different host", dns_request.host_ip, dns_request.hostname ); let _ = request .response_channel .send(Err(RequestError::AlreadyRegistered)); return; } // Restrict the number of self-registrations per host if !state.registrations.0.contains_key(&dns_request.hostname) && state .registrations .0 .iter() .filter(|(_, registration)| { registration.public_key.key_data() == dns_request.ssh_key.key_data() }) .count() >= state.config.registration_limit { warn!( "Host {} attempted to register {} but has already reached the registration limit", dns_request.host_ip, dns_request.hostname ); let _ = request .response_channel .send(Err(RequestError::TooManyRegistrations)); return; } // Create the registration state.registrations.0.insert( dns_request.hostname.clone(), Registration { ip: dns_request.requested_ip, public_key: dns_request.ssh_key, }, ); write_registrations(&state.registrations, &state.config).await; // Execute the DNS update request match execute_dns_update( &dns_request.hostname, &dns_request.requested_ip, &state.config, &state.dns_updater, ) .await { Ok(_) => (), Err(e) => { let _ = request.response_channel.send(Err(e)); return; } }; let _ = request.response_channel.send(Ok(())); info!( "Executed host update request for {} to {} from {}.", &dns_request.hostname, &dns_request.requested_ip, &dns_request.host_ip ) } async fn execute_dns_update( hostname: &String, ip: &Ipv4Addr, config: &Config, updater: &DnsUpdater, ) -> Result<(), RequestError> { let name = &format!("{}.{}", hostname, config.zone); let record = dns_update::DnsRecord::A { content: *ip }; let ttl = 60; let origin = &config.zone; match updater .delete(name, origin, dns_update::DnsRecordType::A) .await { Ok(_) => (), Err(e) => { error!("Failed to delete existing DNS record: {e}"); return Err(RequestError::UpdateFailed); } }; match updater.create(name, record, ttl, origin).await { Ok(_) => Ok(()), Err(e) => { error!("Failed to create DNS record: {e}"); Err(RequestError::UpdateFailed) } } }