#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>

// UDP port the request is addressed to.
#define NTP_PORT 123
// Subtracted from the seconds part of an NTP timestamp to obtain Unix time.
// 2208988800 = 25567 days * 86400 s, the span from 1900-01-01 to 1970-01-01.
#define NTP_TIMESTAMP_DELTA 2208988800ull

// In-memory image of the packet this program sends to and receives from the
// server. The struct is passed to the socket calls as raw bytes, so the field
// order and widths are the wire layout and must not be changed.
// Most fields are never interpreted in this file; the per-field notes marked
// "presumably" are inferred from the field names only - TODO confirm against
// the NTP specification.
typedef struct {
    uint8_t li_vn_mode;      // Packed as (LI << 6) | (VN << 3) | Mode.
    uint8_t stratum;         // Presumably the server's stratum level.
    uint8_t poll;            // Presumably the poll interval.
    int8_t precision;        // Presumably the clock precision (signed).
    uint32_t rootDelay;      // Presumably the root delay.
    uint32_t rootDispersion; // Presumably the root dispersion.
    uint32_t refId;          // Presumably the reference identifier.
    uint32_t refTm_s;        // Presumably reference timestamp, seconds.
    uint32_t refTm_f;        // Presumably reference timestamp, fraction.
    uint32_t origTm_s;       // Presumably originate timestamp, seconds.
    uint32_t origTm_f;       // Presumably originate timestamp, fraction.
    uint32_t rxTm_s;         // Presumably receive timestamp, seconds.
    uint32_t rxTm_f;         // Presumably receive timestamp, fraction.
    uint32_t txTm_s;         // Transmit timestamp seconds; network byte order, read via ntohl().
    uint32_t txTm_f;         // Presumably transmit timestamp, fraction.
} ntp_packet;

int main(int argc, char *argv[]) {
    if (argc != 2) {
        fprintf(stderr, "Usage: %s <NTP Server IP>\n", argv[0]);
        return EXIT_FAILURE;
    }

    const char *server_ip = argv[1];
    int sockfd;
    struct sockaddr_in server_addr;
    ntp_packet packet;
    socklen_t addr_len = sizeof(server_addr);
    struct timeval start, end;

    // Create UDP socket
    sockfd = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
    if (sockfd < 0) {
        perror("Socket creation failed");
        return EXIT_FAILURE;
    }

    // Set 2 second receive timeout
    struct timeval timeout = {2, 0};
    setsockopt(sockfd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout));

    // Server address setup
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(NTP_PORT);
    if (inet_pton(AF_INET, server_ip, &server_addr.sin_addr) != 1) {
        perror("Invalid IP address");
        close(sockfd);
        return EXIT_FAILURE;
    }

    // Prepare SNTP request packet
    memset(&packet, 0, sizeof(ntp_packet));
    packet.li_vn_mode = (0 << 6) | (4 << 3) | 3; // LI = 0, VN = 4, Mode = 3 (Client)

    printf("NTP request sent to %s. Waiting for response...\n", server_ip);
    gettimeofday(&start, NULL);

    // Send packet
    if (sendto(sockfd, &packet, sizeof(ntp_packet), 0,
               (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) {
        perror("Send failed");
        close(sockfd);
        return EXIT_FAILURE;
    }

    // Receive response
    if (recvfrom(sockfd, &packet, sizeof(ntp_packet), 0,
                 (struct sockaddr *)&server_addr, &addr_len) < 0) {
        perror("recvfrom failed or timeout");
        close(sockfd);
        return EXIT_FAILURE;
    }

    gettimeofday(&end, NULL);

    // Calculate response time in milliseconds
    double rtt = (end.tv_sec - start.tv_sec) * 1000.0 +
                 (end.tv_usec - start.tv_usec) / 1000.0;

    // Extract and convert transmit timestamp
    uint32_t txTm_s = ntohl(packet.txTm_s);
    time_t tx_time = txTm_s - NTP_TIMESTAMP_DELTA;

    char time_buffer[32];
    strftime(time_buffer, sizeof(time_buffer), "%Y-%m-%d %H:%M:%S", localtime(&tx_time));

    printf("Response received in %.2f ms\n", rtt);
    printf("Server time: %s\n", time_buffer);

    close(sockfd);
    return EXIT_SUCCESS;
}

