/*
 * tinydns / tinydns.c
 * a809b71 4 months ago
 * 2 contributors
 * 278 lines | 9.252 KB
 */
#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <time.h>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>

/* Default server port; main() uses it unless -p supplies another one. */
#define DNS_PORT 53
/* Size in bytes of the single buffer main() uses for the outgoing query and
 * then reuses for the received response. */
#define BUF_SIZE 512

/*
 * Fixed-size header found at the start of every DNS message handled here.
 * All fields are stored in network byte order (the code converts with
 * htons()/ntohs() when writing or reading them).
 */
struct DNS_HEADER {
    uint16_t id;        /* transaction identifier chosen by the sender          */
    uint16_t flags;     /* flag bits; the low 4 bits are read as the reply code */
    uint16_t qdcount;   /* number of entries in the question section            */
    uint16_t ancount;   /* number of records in the answer section              */
    uint16_t nscount;   /* not read by this program; left zero in queries       */
    uint16_t arcount;   /* not read by this program; left zero in queries       */
};

/*
 * Fixed trailer of a question entry; it directly follows the encoded name.
 * Packed so it can be overlaid on the message at any byte offset.
 */
struct __attribute__((__packed__)) QUESTION {
    uint16_t qtype;     /* record type being asked for (network byte order) */
    uint16_t qclass;    /* query class; this program always sends 1         */
};

/*
 * Fixed part of a resource record, located between the record's owner name
 * and its variable-length data. Packed because it sits at arbitrary byte
 * offsets inside the message; all fields are in network byte order.
 */
struct __attribute__((__packed__)) R_DATA {
    uint16_t type;      /* record type (1 = A, 28 = AAAA, 5 = CNAME, ...)     */
    uint16_t _class;    /* record class; not inspected by this program        */
    uint32_t ttl;       /* time to live; not inspected by this program        */
    uint16_t data_len;  /* number of record-data bytes following this struct  */
};

/*
 * Encode a dotted host name as a sequence of length-prefixed labels followed
 * by a zero byte, e.g. "www.example.com" -> 3 'w' 'w' 'w' 7 'e' ... 3 'c' 'o' 'm' 0.
 *
 *   dns  : destination buffer. No capacity is passed in, so the caller must
 *          guarantee room for strlen(host) + 2 bytes.
 *   host : NUL-terminated name; labels are split on '.'.
 *
 * Label lengths are stored in a single byte without range checking.
 */
void fqdn_to_dns_name(unsigned char *dns, const char *host) {
    unsigned char *out = dns;
    const char *label = host;

    for (;;) {
        const char *sep = strchr(label, '.');
        size_t label_len = sep ? (size_t)(sep - label) : strlen(label);

        *out++ = (unsigned char)label_len;
        memcpy(out, label, label_len);
        out += label_len;

        if (sep == NULL) break;     // that was the last label
        label = sep + 1;
    }
    *out = '\0';                    // terminating zero-length label
}

/*
 * Map a record-type mnemonic to its numeric DNS type code.
 *
 *   s : type name, compared case-insensitively ("A", "AAAA", "CNAME", "MX", "NS").
 *
 * Returns the type code. Does not return for an unknown name: it prints a
 * message to stderr and terminates the process with exit status 1.
 *
 * strcasecmp() is a POSIX function declared in <strings.h>, not <string.h>.
 */
uint16_t qtype_from_string(const char *s) {
    static const struct {
        const char *name;
        uint16_t    code;
    } known_types[] = {
        { "A",     1  },
        { "AAAA",  28 },
        { "CNAME", 5  },
        { "MX",    15 },
        { "NS",    2  },
    };

    for (size_t i = 0; i < sizeof known_types / sizeof known_types[0]; i++) {
        if (strcasecmp(s, known_types[i].name) == 0) {
            return known_types[i].code;
        }
    }

    fprintf(stderr, "Unknown query type: %s\n", s);
    exit(1);
}

/*
 * Step over one encoded name without decoding it.
 *
 *   reader : points at the first length byte of the name.
 *
 * Returns the address of the first byte after the name: just past the zero
 * terminator, or just past a 2-byte compression pointer (a byte whose two
 * top bits are set), whichever comes first. No buffer limit is known here,
 * so the caller is responsible for checking the result.
 */
unsigned char* skip_dns_name(unsigned char *reader) {
    unsigned char *cur = reader;
    unsigned char len;

    for (len = *cur; len != 0; len = *cur) {
        if ((len & 0xC0) == 0xC0) {
            return cur + 2;         // a pointer always ends the name
        }
        cur += len + 1;             // length byte plus the label itself
    }
    return cur + 1;                 // step over the zero terminator
}

/*
 * Decode a (possibly compressed) name into dotted text.
 *
 *   reader : points at the first byte of the name inside the message.
 *   buffer : start of the message; compression pointers are offsets from it.
 *   out    : receives the NUL-terminated dotted name. Must have room for at
 *            least 256 bytes (every caller in this file passes char[256]);
 *            longer names are truncated at a label boundary.
 *
 * Returns the address of the first byte after the name as it appears at
 * `reader`: just past the zero terminator, or just past the FIRST compression
 * pointer met, even when labels precede that pointer.
 *
 * Pointer chains are followed at most MAX_POINTER_HOPS times so a message
 * whose pointers form a cycle cannot hang the program. The message length is
 * not known here, so offsets cannot be range-checked against it.
 */
unsigned char* read_dns_name(unsigned char *reader, unsigned char *buffer, char *out) {
    enum { OUT_CAPACITY = 256, MAX_POINTER_HOPS = 128 };
    unsigned char *after_name = NULL;   // where the caller resumes parsing
    int p = 0, hops = 0, truncated = 0;

    while (*reader != 0) {
        if ((*reader & 0xC0) == 0xC0) {
            // Only the first pointer determines where the name ends in place.
            if (after_name == NULL) after_name = reader + 2;
            if (++hops > MAX_POINTER_HOPS) break;   // cyclic or absurd chain
            reader = buffer + (((*reader & 0x3F) << 8) | reader[1]);
        } else {
            int len = *reader++;
            // Need room for the label plus the '.' (later turned into NUL).
            if (!truncated && p + len + 1 <= OUT_CAPACITY) {
                memcpy(out + p, reader, (size_t)len);
                p += len;
                out[p++] = '.';
            } else {
                truncated = 1;          // keep walking, stop copying
            }
            reader += len;
        }
    }

    if (p > 0) out[p - 1] = '\0'; else out[0] = '\0';
    if (after_name == NULL) after_name = reader + 1;    // no pointer: skip terminator
    return after_name;
}

/*
 * Print the command-line usage summary to stdout.
 * The text mirrors what main() accepts: -h and the fqdn are required, -p and
 * -q take an argument and are optional, and NS is a supported query type.
 */
void help(void) {
  fputs("Usage: tinydns -h <IP> [-p port] [-t|-u] [-q type] [--time] fqdn\n"
        "  -h <IP> : DNS server IPv4 address (required)\n"
        "  -p port : server port (optional, default is 53)\n"
        "  -u | -t : use UDP (default) or TCP\n"
        "  -q type : query type A (default), AAAA, CNAME, MX or NS\n"
        "  --time  : also print a timing CSV line\n"
        "  fqdn    : name to query\n",
        stdout);
}

int main(int argc, char *argv[]) {
    int opt, use_tcp = 0, time_mode = 0;
    char *dns_server = NULL;
    int dns_port = DNS_PORT;
    char *qtype_str = "A";
    char *fqdn = NULL;

    struct timeval start_tv, end_tv;
    time_t start_epoch, end_epoch = 0;
    double duration = 0.0;
    const char *status = "ERROR";

    while ((opt = getopt(argc, argv, "uth:p:q:-:")) != -1) {
        if (opt == '-') { // long option
            if (strcmp(optarg, "time") == 0) time_mode = 1;
            continue;
        }
        switch (opt) {
            case 'u': use_tcp = 0; break;
            case 't': use_tcp = 1; break;
            case 'h': dns_server = optarg; break;
            case 'p': dns_port = atoi(optarg); break;
            case 'q': qtype_str = optarg; break;
            default:
                //fprintf(stderr, "Usage: %s [-u|-t] -h server [-p port] -q type [--time] fqdn\n", argv[0]);
                help();
                exit(EXIT_FAILURE);
        }
    }

    //if (optind >= argc) { fprintf(stderr, "FQDN is required\n"); exit(EXIT_FAILURE); }
    if (optind >= argc) { help(); exit(EXIT_FAILURE); }
    fqdn = argv[optind];
    //if (!dns_server) { fprintf(stderr, "DNS server (-h) is required\n"); exit(EXIT_FAILURE); }
    if (!dns_server) { help(); exit(EXIT_FAILURE); }

    gettimeofday(&start_tv, NULL);
    start_epoch = start_tv.tv_sec;

    uint8_t buf[BUF_SIZE] = {0};
    struct DNS_HEADER *dns = (struct DNS_HEADER *) &buf;
    dns->id = htons(0x1234);
    dns->flags = htons(0x0100);
    dns->qdcount = htons(1);

    unsigned char *qname = (unsigned char*)&buf[sizeof(struct DNS_HEADER)];
    fqdn_to_dns_name(qname, fqdn);

    struct QUESTION *qinfo = (struct QUESTION *) &buf[sizeof(struct DNS_HEADER) + strlen((const char*)qname) + 1];
    qinfo->qtype = htons(qtype_from_string(qtype_str));
    qinfo->qclass = htons(1);

    int packet_len = sizeof(struct DNS_HEADER) + (strlen((const char*)qname) + 1) + sizeof(struct QUESTION);

    int sock = socket(AF_INET, use_tcp ? SOCK_STREAM : SOCK_DGRAM, 0);
    if (sock < 0) { perror("socket"); status = "ERROR"; goto output; }

    struct sockaddr_in dest;
    dest.sin_family = AF_INET;
    dest.sin_port = htons(dns_port);
    if (inet_pton(AF_INET, dns_server, &dest.sin_addr) != 1) {
        fprintf(stderr, "Invalid DNS server IP address\n");
        status = "ERROR";
        goto output;
    }

    struct timeval tv;
    tv.tv_sec = 5;
    tv.tv_usec = 0;
    setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));

    if (use_tcp) {
        if (connect(sock, (struct sockaddr*)&dest, sizeof(dest)) < 0) { perror("connect"); status = "ERROR"; goto output; }
        uint16_t netlen = htons(packet_len);
        if (send(sock, &netlen, 2, 0) < 0 || send(sock, buf, packet_len, 0) < 0) { perror("send"); status = "ERROR"; goto output; }
    } else {
        if (sendto(sock, buf, packet_len, 0, (struct sockaddr*)&dest, sizeof(dest)) < 0) { perror("sendto"); status = "ERROR"; goto output; }
    }

    int recv_len;
    if (use_tcp) {
        uint16_t resp_len;
        if (recv(sock, &resp_len, 2, MSG_WAITALL) <= 0) { status = "TIMEOUT"; goto output; }
        resp_len = ntohs(resp_len);
        recv_len = recv(sock, buf, resp_len, MSG_WAITALL);
    } else {
        socklen_t slen = sizeof(dest);
        recv_len = recvfrom(sock, buf, sizeof(buf), 0, (struct sockaddr*)&dest, &slen);
    }
    if (recv_len <= 0) { status = "TIMEOUT"; goto output; }

    gettimeofday(&end_tv, NULL);
    end_epoch = end_tv.tv_sec;
    duration = (end_tv.tv_sec - start_tv.tv_sec) + (end_tv.tv_usec - start_tv.tv_usec) / 1000000.0;

    dns = (struct DNS_HEADER*) buf;
    int rcode = ntohs(dns->flags) & 0x000F;
    if (rcode == 0) status = "NOERROR";
    else if (rcode == 3) status = "NXDOMAIN";
    else status = "ERROR";

    // Print answers human-readable if no timeout/error
    if (strcmp(status, "NOERROR") == 0) {
        unsigned char *reader = &buf[sizeof(struct DNS_HEADER)];

        for (int i = 0; i < ntohs(dns->qdcount); i++) {
            reader = skip_dns_name(reader);
            reader += sizeof(struct QUESTION);
        }

        printf("Server: %s\nQuery: %s (%s)\nStatus: %s\nDuration: %.6f s\n", dns_server, fqdn, qtype_str, status, duration);
        printf("Answers: %d\n", ntohs(dns->ancount));

        for (int i = 0; i < ntohs(dns->ancount); i++) {
            char name[256];
            reader = read_dns_name(reader, buf, name);

            struct R_DATA *res = (struct R_DATA*) reader;
            reader += sizeof(struct R_DATA);

            uint16_t type = ntohs(res->type);
            uint16_t data_len = ntohs(res->data_len);

            if (type == 1 && data_len == 4) { // A
                char ip[INET_ADDRSTRLEN];
                inet_ntop(AF_INET, reader, ip, sizeof(ip));
                printf("A: %s\n", ip);
            } else if (type == 28 && data_len == 16) { // AAAA
                char ip[INET6_ADDRSTRLEN];
                inet_ntop(AF_INET6, reader, ip, sizeof(ip));
                printf("AAAA: %s\n", ip);
            } else if (type == 5) { // CNAME
                char cname[256];
                read_dns_name(reader, buf, cname);
                printf("CNAME: %s\n", cname);
            } else if (type == 15) { // MX
                if (data_len < 3) { // sanity
                    printf("MX: <malformed>\n");
                } else {
                    uint16_t pref = ntohs(*(uint16_t*)reader);
                    char mx[256];
                    read_dns_name(reader + 2, buf, mx);
                    printf("MX: preference=%u, exchange=%s\n", pref, mx);
                }
            } else if (type == 2) { // NS
                char ns[256];
                read_dns_name(reader, buf, ns);
                printf("NS: %s\n", ns);
            } else {
                printf("Type %u: data_len=%u (not displayed)\n", type, data_len);
            }

            reader += data_len;
        }
    } else {
        // Just print brief info on error or timeout
        printf("Server: %s\nQuery: %s (%s)\nStatus: %s\nDuration: %.6f s\n", dns_server, fqdn, qtype_str, status, duration);
    }

output:
    // Print CSV timing line if requested (after all human output)
    if (time_mode) {
        printf("%s,%s,%s,%ld,%ld,%.6f,%s\n",
               dns_server ? dns_server : "",
               fqdn ? fqdn : "",
               qtype_str ? qtype_str : "",
               (long)start_epoch,
               (long)(end_epoch ? end_epoch : start_epoch),
               duration,
               status);
    }

    close(sock);
    return 0;
}