476 lines
15 KiB
Rust
476 lines
15 KiB
Rust
use std::{
|
|
collections::HashMap, hash::Hash, net::Ipv4Addr, ops::Deref, str::FromStr, time::Duration,
|
|
};
|
|
|
|
use base64ct::{Base64, Encoding};
|
|
use dns_update::DnsUpdater;
|
|
use log::{error, info, warn};
|
|
use rand::{Rng, SeedableRng, rngs::StdRng};
|
|
use serde::{Deserialize, Serialize, de};
|
|
use ssh_key::{PublicKey, SshSig};
|
|
use tokio::{
|
|
sync::{mpsc, oneshot},
|
|
task::JoinHandle,
|
|
time::Instant,
|
|
};
|
|
|
|
use crate::Config;
|
|
|
|
// const MAX_HOSTNAME_LENGTH: usize = 15;
|
|
|
|
// #[derive(Debug)]
|
|
// pub struct Hostname(String);
|
|
// impl Deref for Hostname {
|
|
// type Target = String;
|
|
// fn deref(&self) -> &Self::Target {
|
|
// &self.0
|
|
// }
|
|
// }
|
|
// impl ser::Serialize for Hostname {
|
|
// fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
|
// where
|
|
// S: ser::Serializer,
|
|
// {
|
|
// serializer.serialize_str(&self.0)
|
|
// }
|
|
// }
|
|
// impl<'de> de::Deserialize<'de> for Hostname {
|
|
// fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
// where
|
|
// D: serde::Deserializer<'de>,
|
|
// {
|
|
// <String as de::Deserialize>::deserialize(deserializer).and_then(|inner| {
|
|
// if inner.len() > MAX_HOSTNAME_LENGTH {
|
|
// Err(de::Error::invalid_length(
|
|
// inner.len(),
|
|
// &"a shorter hostname",
|
|
// ))
|
|
// } else {
|
|
// Ok(Self(inner))
|
|
// }
|
|
// })
|
|
// }
|
|
// }
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct DnsAddress(dns_update::providers::rfc2136::DnsAddress);
|
|
impl Deref for DnsAddress {
|
|
type Target = dns_update::providers::rfc2136::DnsAddress;
|
|
fn deref(&self) -> &Self::Target {
|
|
&self.0
|
|
}
|
|
}
|
|
impl<'de> Deserialize<'de> for DnsAddress {
|
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
where
|
|
D: de::Deserializer<'de>,
|
|
{
|
|
<String as de::Deserialize>::deserialize(deserializer).and_then(|inner| {
|
|
match dns_update::providers::rfc2136::DnsAddress::try_from(&inner) {
|
|
Ok(addr) => Ok(DnsAddress(addr)),
|
|
Err(_) => Err(de::Error::custom(format!("Invalid DNS address: {inner}"))),
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
/// All hostnames registered so far, keyed by hostname.
///
/// This is the in-memory copy of `<data_dir>/registrations.toml`; it is loaded
/// at startup, replayed to the DNS server, and rewritten after every
/// successful registration.
#[derive(Deserialize, Serialize, Debug)]
pub struct Registrations(HashMap<String, Registration>);

/// A single hostname registration.
#[derive(Deserialize, Serialize, Debug)]
pub struct Registration {
    /// Address the hostname's A record points to.
    pub ip: Ipv4Addr,
    /// SSH public key of the owner; only a request signed with the same key
    /// data may update this hostname.
    pub public_key: PublicKey,
}
|
|
|
|
pub struct TsigAlgorithm(dns_update::TsigAlgorithm);
|
|
impl Deref for TsigAlgorithm {
|
|
type Target = dns_update::TsigAlgorithm;
|
|
fn deref(&self) -> &Self::Target {
|
|
&self.0
|
|
}
|
|
}
|
|
impl<'de> Deserialize<'de> for TsigAlgorithm {
|
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
where
|
|
D: de::Deserializer<'de>,
|
|
{
|
|
<String as de::Deserialize>::deserialize(deserializer).and_then(|inner| {
|
|
match dns_update::TsigAlgorithm::from_str(&inner) {
|
|
Ok(algorithm) => Ok(TsigAlgorithm(algorithm)),
|
|
Err(_) => Err(de::Error::unknown_variant(
|
|
&inner,
|
|
&[
|
|
"hmac-md5",
|
|
"gss",
|
|
"hmac-sha1",
|
|
"hmac-sha224",
|
|
"hmac-sha256",
|
|
"hmac-sha256-128",
|
|
"hmac-sha384",
|
|
"hmac-sha384-192",
|
|
"hmac-sha512",
|
|
"hmac-sha512-256",
|
|
],
|
|
)),
|
|
}
|
|
})
|
|
}
|
|
}
|
|
impl std::fmt::Debug for TsigAlgorithm {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
let algorithm_name = match self.0 {
|
|
dns_update::TsigAlgorithm::HmacMd5 => "hmac-md5",
|
|
dns_update::TsigAlgorithm::Gss => "gss",
|
|
dns_update::TsigAlgorithm::HmacSha1 => "hmac-sha1",
|
|
dns_update::TsigAlgorithm::HmacSha224 => "hmac-sha224",
|
|
dns_update::TsigAlgorithm::HmacSha256 => "hmac-sha256",
|
|
dns_update::TsigAlgorithm::HmacSha256_128 => "hmac-sha256-128",
|
|
dns_update::TsigAlgorithm::HmacSha384 => "hmac-sha384",
|
|
dns_update::TsigAlgorithm::HmacSha384_192 => "hmac-sha384-192",
|
|
dns_update::TsigAlgorithm::HmacSha512 => "hmac-sha512",
|
|
dns_update::TsigAlgorithm::HmacSha512_256 => "hmac-sha512-256",
|
|
};
|
|
write!(f, "{algorithm_name}")
|
|
}
|
|
}
|
|
impl Clone for TsigAlgorithm {
|
|
fn clone(&self) -> Self {
|
|
Self(match self.0 {
|
|
dns_update::TsigAlgorithm::HmacMd5 => dns_update::TsigAlgorithm::HmacMd5,
|
|
dns_update::TsigAlgorithm::Gss => dns_update::TsigAlgorithm::Gss,
|
|
dns_update::TsigAlgorithm::HmacSha1 => dns_update::TsigAlgorithm::HmacSha1,
|
|
dns_update::TsigAlgorithm::HmacSha224 => dns_update::TsigAlgorithm::HmacSha224,
|
|
dns_update::TsigAlgorithm::HmacSha256 => dns_update::TsigAlgorithm::HmacSha256,
|
|
dns_update::TsigAlgorithm::HmacSha256_128 => dns_update::TsigAlgorithm::HmacSha256_128,
|
|
dns_update::TsigAlgorithm::HmacSha384 => dns_update::TsigAlgorithm::HmacSha384,
|
|
dns_update::TsigAlgorithm::HmacSha384_192 => dns_update::TsigAlgorithm::HmacSha384_192,
|
|
dns_update::TsigAlgorithm::HmacSha512 => dns_update::TsigAlgorithm::HmacSha512,
|
|
dns_update::TsigAlgorithm::HmacSha512_256 => dns_update::TsigAlgorithm::HmacSha512_256,
|
|
})
|
|
}
|
|
}
|
|
|
|
#[derive(Hash, PartialEq, Eq, Clone)]
|
|
pub struct Identifier(pub [u8; 256]);
|
|
impl Deref for Identifier {
|
|
type Target = [u8; 256];
|
|
fn deref(&self) -> &Self::Target {
|
|
&self.0
|
|
}
|
|
}
|
|
impl Identifier {
|
|
pub fn to_base64(&self) -> String {
|
|
Base64::encode_string(&self.0)
|
|
}
|
|
}
|
|
|
|
/// Message sent to the DNS client task over its `mpsc` channel.
pub enum Command {
    /// Register a new pending update request and reply with its identifier.
    CreateRequest(Box<CreateRequest>),
    /// Confirm a pending request with an SSH signature and apply it.
    SignRequest(Box<SignRequest>),
}
|
|
|
|
/// A pending request to point `hostname` at `requested_ip`.
#[derive(Clone)]
pub struct DNSRequest {
    /// Label registered under the configured zone (`<hostname>.<zone>`).
    pub hostname: String,
    /// When the request was created; the signed confirmation must arrive
    /// within 10 seconds of this.
    pub time: Instant,
    /// Key that must sign the confirmation. It also identifies the owner of
    /// the resulting registration.
    pub ssh_key: PublicKey,
    // TODO: Handle Ipv6
    /// Address to put in the A record; must lie in one of the configured networks.
    pub requested_ip: Ipv4Addr,
    /// Address the request was submitted from; the confirmation must come
    /// from the same address.
    pub host_ip: Ipv4Addr,
}
|
|
|
|
/// Payload of [`Command::CreateRequest`].
pub struct CreateRequest {
    /// The request to store as pending.
    pub request: DNSRequest,
    /// Receives the freshly generated identifier for `request`.
    pub response_channel: oneshot::Sender<Identifier>,
}
|
|
|
|
/// Payload of [`Command::SignRequest`]: confirms a pending request.
pub struct SignRequest {
    /// Identifier returned when the request was created.
    pub identifier: Identifier,
    /// SSH signature (namespace `madd`) over
    /// `"<identifier base64> <hostname> <requested_ip>"`.
    pub signed: SshSig,
    /// When the confirmation was received; compared to the creation time.
    pub time: Instant,
    /// Address the confirmation was submitted from.
    pub host_ip: Ipv4Addr,
    /// Receives `Ok(())` once the DNS record is updated, or the reason for
    /// rejection.
    pub response_channel: oneshot::Sender<Result<(), RequestError>>,
}
|
|
|
|
/// Reason a signed request was rejected or could not be applied.
#[derive(Debug, Clone)]
pub enum RequestError {
    /// The confirmation came from a different address than the request.
    HostIpMismatch,
    /// The confirmation arrived more than 10 seconds after the request.
    RequestExpired,
    /// The SSH signature did not verify against the request's key.
    InvalidSignature,
    /// The requested address is outside every configured network.
    RequestedIpNotAllowed,
    /// Deleting or creating the DNS record failed.
    UpdateFailed,
    /// The hostname is owned by a different SSH key.
    AlreadyRegistered,
    /// The SSH key already owns the configured maximum number of hostnames.
    TooManyRegistrations,
}
|
|
|
|
pub async fn start_client(rx: mpsc::Receiver<Command>, config: &Config) -> JoinHandle<()> {
|
|
tokio::spawn(run_client(rx, config.clone()))
|
|
}
|
|
|
|
/// Mutable state owned by the DNS client task.
struct ClientState {
    /// Service configuration (networks, zone, limits, data directory, ...).
    config: Config,
    /// Pending requests awaiting a signed confirmation, keyed by identifier.
    requests: HashMap<Identifier, DNSRequest>,
    /// TSIG-authenticated RFC 2136 client used for all record changes.
    dns_updater: DnsUpdater,
    /// In-memory copy of the persisted registrations.
    registrations: Registrations,
}
|
|
|
|
async fn run_client(rx: mpsc::Receiver<Command>, config: Config) {
|
|
let mut rx = rx;
|
|
let requests = HashMap::new();
|
|
let dns_updater = get_dns_updater(&config);
|
|
let registrations = init_registrations(&config).await;
|
|
|
|
init_dns_registrations(®istrations, &config, &dns_updater).await;
|
|
|
|
let mut state = ClientState {
|
|
config,
|
|
requests,
|
|
dns_updater,
|
|
registrations,
|
|
};
|
|
loop {
|
|
match rx.recv().await {
|
|
Some(cmd) => handle_cmd(cmd, &mut state).await,
|
|
None => return,
|
|
};
|
|
}
|
|
}
|
|
|
|
async fn init_registrations(config: &Config) -> Registrations {
|
|
let path = config.data_dir.join("registrations.toml");
|
|
if !path.exists() {
|
|
return Registrations(HashMap::new());
|
|
}
|
|
|
|
let contents = tokio::fs::read_to_string(path).await.unwrap();
|
|
match toml::from_str(&contents) {
|
|
Ok(registrations) => registrations,
|
|
Err(e) => {
|
|
panic!("Failed to parse registrations file: {e}");
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn init_dns_registrations(
|
|
registrations: &Registrations,
|
|
config: &Config,
|
|
updater: &DnsUpdater,
|
|
) {
|
|
for (hostname, registration) in registrations.0.iter() {
|
|
execute_dns_update(hostname, ®istration.ip, config, updater)
|
|
.await
|
|
.unwrap();
|
|
}
|
|
}
|
|
|
|
async fn write_registrations(registrations: &Registrations, config: &Config) {
|
|
let path = config.data_dir.join("registrations.toml");
|
|
let contents = toml::to_string(registrations).unwrap();
|
|
if !config.data_dir.exists() {
|
|
tokio::fs::create_dir_all(&config.data_dir).await.unwrap();
|
|
}
|
|
tokio::fs::write(path, contents).await.unwrap();
|
|
}
|
|
|
|
fn get_dns_updater(config: &Config) -> DnsUpdater {
|
|
let addr = format!("tcp://{}", config.dns_server);
|
|
let key_name = &config.tsig_key_name;
|
|
let key = match std::fs::read_to_string(&config.tsig_key_file) {
|
|
Ok(contents) => Base64::decode_vec(&contents).unwrap(),
|
|
Err(e) => panic!("Failed to read TSIG key file: {e}"),
|
|
};
|
|
let algorithm = config.tsig_algorithm.clone().0;
|
|
info!("Creating DNS client for {addr}");
|
|
dns_update::DnsUpdater::new_rfc2136_tsig(addr, key_name, key, algorithm)
|
|
.expect("Failed to create DNS client")
|
|
}
|
|
|
|
async fn handle_cmd(cmd: Command, state: &mut ClientState) {
|
|
match cmd {
|
|
Command::CreateRequest(create_request) => {
|
|
handle_create_request(*create_request, state).await
|
|
}
|
|
Command::SignRequest(sign_request) => handle_sign_request(*sign_request, state).await,
|
|
}
|
|
}
|
|
|
|
async fn handle_create_request(request: CreateRequest, state: &mut ClientState) {
|
|
let random_value: [u8; 256] = StdRng::from_os_rng().random();
|
|
let identifier = Identifier(random_value);
|
|
state
|
|
.requests
|
|
.insert(identifier.clone(), request.request.clone());
|
|
let _ = request.response_channel.send(identifier);
|
|
info!(
|
|
"Registered host update request for {} to {} from {}.",
|
|
&request.request.hostname, &request.request.requested_ip, &request.request.host_ip
|
|
)
|
|
}
|
|
|
|
async fn handle_sign_request(request: SignRequest, state: &mut ClientState) {
|
|
let dns_request = match state.requests.remove(&request.identifier) {
|
|
Some(req) => req,
|
|
None => todo!(),
|
|
};
|
|
|
|
// The request must be submitted and signed by the same host
|
|
if dns_request.host_ip != request.host_ip {
|
|
warn!(
|
|
"Host IP mismatch for signed request for host {}: expected from {}, received from {}",
|
|
dns_request.hostname, dns_request.host_ip, request.host_ip
|
|
);
|
|
let _ = request
|
|
.response_channel
|
|
.send(Err(RequestError::HostIpMismatch));
|
|
return;
|
|
}
|
|
|
|
// The requested IP must be within the allowed networks
|
|
if !state
|
|
.config
|
|
.networks
|
|
.iter()
|
|
.any(|net| net.contains(&dns_request.requested_ip))
|
|
{
|
|
warn!(
|
|
"Requested IP {} for host {} is not allowed",
|
|
dns_request.requested_ip, dns_request.hostname
|
|
);
|
|
let _ = request
|
|
.response_channel
|
|
.send(Err(RequestError::RequestedIpNotAllowed));
|
|
return;
|
|
}
|
|
|
|
// The signed request must be submitted within 10 seconds of request creation
|
|
if request.time - dns_request.time > Duration::from_secs(10) {
|
|
warn!("Request expired: signed request is older than 10 seconds");
|
|
let _ = request
|
|
.response_channel
|
|
.send(Err(RequestError::RequestExpired));
|
|
return;
|
|
}
|
|
|
|
// Verify the SSH signature
|
|
let original_text = format!(
|
|
"{} {} {}",
|
|
request.identifier.to_base64(),
|
|
dns_request.hostname,
|
|
dns_request.requested_ip
|
|
);
|
|
match dns_request
|
|
.ssh_key
|
|
.verify("madd", original_text.as_bytes(), &request.signed)
|
|
{
|
|
Ok(_) => (),
|
|
Err(e) => {
|
|
warn!(
|
|
"Failed to verify SSH signature for host {}: {}",
|
|
dns_request.hostname, e
|
|
);
|
|
let _ = request
|
|
.response_channel
|
|
.send(Err(RequestError::InvalidSignature));
|
|
return;
|
|
}
|
|
}
|
|
|
|
// Check with the registrations if the hostname has not already been registered
|
|
if let Some(registration) = state.registrations.0.get(&dns_request.hostname)
|
|
&& registration.public_key.key_data() != dns_request.ssh_key.key_data()
|
|
{
|
|
warn!(
|
|
"Host {} attempted to register {} which has already been registered by a different host",
|
|
dns_request.host_ip, dns_request.hostname
|
|
);
|
|
let _ = request
|
|
.response_channel
|
|
.send(Err(RequestError::AlreadyRegistered));
|
|
return;
|
|
}
|
|
|
|
// Restrict the number of self-registrations per host
|
|
if !state.registrations.0.contains_key(&dns_request.hostname)
|
|
&& state
|
|
.registrations
|
|
.0
|
|
.iter()
|
|
.filter(|(_, registration)| {
|
|
registration.public_key.key_data() == dns_request.ssh_key.key_data()
|
|
})
|
|
.count()
|
|
>= state.config.registration_limit
|
|
{
|
|
warn!(
|
|
"Host {} attempted to register {} but has already reached the registration limit",
|
|
dns_request.host_ip, dns_request.hostname
|
|
);
|
|
let _ = request
|
|
.response_channel
|
|
.send(Err(RequestError::TooManyRegistrations));
|
|
return;
|
|
}
|
|
|
|
// Create the registration
|
|
state.registrations.0.insert(
|
|
dns_request.hostname.clone(),
|
|
Registration {
|
|
ip: dns_request.requested_ip,
|
|
public_key: dns_request.ssh_key,
|
|
},
|
|
);
|
|
write_registrations(&state.registrations, &state.config).await;
|
|
|
|
// Execute the DNS update request
|
|
match execute_dns_update(
|
|
&dns_request.hostname,
|
|
&dns_request.requested_ip,
|
|
&state.config,
|
|
&state.dns_updater,
|
|
)
|
|
.await
|
|
{
|
|
Ok(_) => (),
|
|
Err(e) => {
|
|
let _ = request.response_channel.send(Err(e));
|
|
return;
|
|
}
|
|
};
|
|
|
|
let _ = request.response_channel.send(Ok(()));
|
|
info!(
|
|
"Executed host update request for {} to {} from {}.",
|
|
&dns_request.hostname, &dns_request.requested_ip, &dns_request.host_ip
|
|
)
|
|
}
|
|
|
|
async fn execute_dns_update(
|
|
hostname: &String,
|
|
ip: &Ipv4Addr,
|
|
config: &Config,
|
|
updater: &DnsUpdater,
|
|
) -> Result<(), RequestError> {
|
|
let name = &format!("{}.{}", hostname, config.zone);
|
|
let record = dns_update::DnsRecord::A { content: *ip };
|
|
let ttl = 60;
|
|
let origin = &config.zone;
|
|
match updater
|
|
.delete(name, origin, dns_update::DnsRecordType::A)
|
|
.await
|
|
{
|
|
Ok(_) => (),
|
|
Err(e) => {
|
|
error!("Failed to delete existing DNS record: {e}");
|
|
return Err(RequestError::UpdateFailed);
|
|
}
|
|
};
|
|
match updater.create(name, record, ttl, origin).await {
|
|
Ok(_) => Ok(()),
|
|
Err(e) => {
|
|
error!("Failed to create DNS record: {e}");
|
|
Err(RequestError::UpdateFailed)
|
|
}
|
|
}
|
|
}
|