/*
 * flexnbd - a small NBD (network block device) server.
 *
 * "serve" mode listens on a TCP address and starts one thread per accepted
 * connection; each thread maps the backing file and answers READ, WRITE and
 * DISCONNECT requests until the client goes away.
 */
#ifndef _LARGEFILE64_SOURCE
#define _LARGEFILE64_SOURCE
#endif
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif

#include <arpa/inet.h>
#include <endian.h>
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <netinet/in.h>
#include <pthread.h>
#include <stdarg.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/sendfile.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

/* http://linux.derkeiler.com/Mailing-Lists/Kernel/2003-09/2332.html */
#define INIT_PASSWD "NBDMAGIC"
#define INIT_MAGIC 0x0000420281861253
#define REQUEST_MAGIC 0x25609513
#define REPLY_MAGIC 0x67446698

#define REQUEST_READ 0
#define REQUEST_WRITE 1
#define REQUEST_DISCONNECT 2

#include <linux/types.h> /* __be32 / __be64 */

/* Server -> client greeting, sent once right after a connection is accepted. */
struct nbd_init {
	char passwd[8];
	__be64 magic;
	__be64 size;
	char reserved[128];
};

/* Client -> server request header; a WRITE header is followed by `len` data bytes. */
struct nbd_request {
	__be32 magic;
	__be32 type;	/* == READ || == WRITE */
	char handle[8];
	__be64 from;
	__be32 len;
} __attribute__((packed));

/* Server -> client reply header; a successful READ reply is followed by `len` data bytes. */
struct nbd_reply {
	__be32 magic;
	__be32 error;	/* 0 = ok, else error */
	char handle[8];	/* handle you got from request */
};

/*
 * Print command-line usage to stderr and exit with status 1.
 * NOTE: mode() currently dispatches only "serve"; the other forms are listed
 * but not implemented here.
 */
void syntax(void)
{
	fprintf(stderr,
		"Syntax: flexnbd serve <ip address> <port> <file> [ip addresses ...]\n"
		" flexnbd read > data\n"
		" flexnbd write [length] < data\n"
		" flexnbd mirror \n"
	);
	exit(1);
}

/* Thread that runs main(); error() uses it to decide how much to tear down. */
static pthread_t server_thread_id;

void error(int consult_errno, int close_socket, const char* format, ...)
	__attribute__((format(printf, 3, 4), noreturn));

/*
 * Report a fatal error and stop.
 *
 * Writes "*** " + the formatted message (+ errno text when consult_errno is
 * non-zero) to stderr and closes close_socket when it is non-zero.
 * In the server (main) thread the whole process exits with status 1; in a
 * client thread only that thread ends, so one bad client cannot take the
 * server down.  Never returns.
 */
void error(int consult_errno, int close_socket, const char* format, ...)
{
	int saved_errno = errno; /* the stdio calls below may overwrite errno */
	va_list argptr;

	fprintf(stderr, "*** ");

	va_start(argptr, format);
	vfprintf(stderr, format, argptr);
	va_end(argptr);

	if (consult_errno) {
		fprintf(stderr, " (errno=%d, %s)", saved_errno, strerror(saved_errno));
	}

	if (close_socket) {
		close(close_socket);
	}

	fprintf(stderr, "\n");

	if (pthread_equal(pthread_self(), server_thread_id)) {
		exit(1);
	}
	/* client thread: runs its pthread_cleanup handlers, see client_serve() */
	pthread_exit((void*) 1);
}

#ifndef DEBUG
# define debug(msg, ...)
#else
# include <sys/times.h>
# define debug(msg, ...) fprintf(stderr, "% 4d % 4d: " msg "\n" , \
	(int) pthread_self(), (int) times(NULL), ##__VA_ARGS__)
#endif

/* The CLIENT_* macros expect a `struct client_params* client` in scope. */
#define CLIENT_ERROR(msg, ...) \
	error(0, client->socket, msg, ##__VA_ARGS__)
#define CLIENT_ERROR_ON_FAILURE(test, msg, ...) \
	do { if ((test) < 0) { error(1, client->socket, msg, ##__VA_ARGS__); } } while (0)

#define SERVER_ERROR(msg, ...) \
	error(0, 0, msg, ##__VA_ARGS__)
#define SERVER_ERROR_ON_FAILURE(test, msg, ...) \
	do { if ((test) < 0) { error(1, 0, msg, ##__VA_ARGS__); } } while (0)

/* malloc() that treats allocation failure as a fatal server error. */
void* xmalloc(size_t size)
{
	void* p = malloc(size);
	if (p == NULL) {
		SERVER_ERROR("couldn't malloc %zu bytes", size);
	}
	return p;
}

/* FIXME: access-control entries are not implemented yet; only pointers are used. */
struct ip_and_mask;

struct mode_serve_params {
	union {
		struct sockaddr generic;
		struct sockaddr_in v4;
		struct sockaddr_in6 v6;
	} bind_to;
	struct ip_and_mask** acl;
	char* filename;
	int tcp_backlog;

	int server;
	int threads;
};

struct client_params {
	int socket;

	char* filename;
	int fileno;
	off64_t size;
	char* mapped;
};

union mode_params {
	struct mode_serve_params serve;
};

/*
 * Write exactly `size` bytes from `buffer` to `filedes`, retrying short writes
 * and EINTR.  Returns 0 on success, -1 on error (errno set by write()).
 */
int writeloop(int filedes, const void *buffer, size_t size)
{
	const char *bytes = buffer;
	size_t written = 0;

	while (written < size) {
		ssize_t result = write(filedes, bytes + written, size - written);
		if (result == -1) {
			if (errno == EINTR) {
				continue;
			}
			return -1;
		}
		written += (size_t) result;
	}
	return 0;
}

/*
 * Read exactly `size` bytes from `filedes` into `buffer`, retrying short
 * reads and EINTR.  Returns 0 on success, -1 on failure.  On EOF errno is set
 * to 0 so callers can tell a clean hang-up from a real error.
 */
int readloop(int filedes, void *buffer, size_t size)
{
	char *bytes = buffer;
	size_t readden = 0;

	while (readden < size) {
		ssize_t result = read(filedes, bytes + readden, size - readden);
		if (result == 0) { /* EOF */
			errno = 0;
			return -1;
		}
		if (result == -1) {
			if (errno == EINTR) {
				continue;
			}
			return -1;
		}
		readden += (size_t) result;
	}
	return 0;
}

/*
 * Copy `count` bytes from in_fd (starting at *offset) to out_fd with
 * sendfile64().  sendfile64() itself advances *offset, so the same pointer is
 * passed on every iteration.  Returns 0 on success, -1 on error; an early end
 * of the input file is reported as EIO instead of looping forever.
 */
int sendfileloop(int out_fd, int in_fd, off64_t *offset, size_t count)
{
	size_t sent = 0;

	while (sent < count) {
		ssize_t result = sendfile64(out_fd, in_fd, offset, count - sent);
		if (result == -1) {
			if (errno == EINTR) {
				continue;
			}
			return -1;
		}
		if (result == 0) {
			errno = EIO;
			return -1;
		}
		sent += (size_t) result;
	}
	return 0;
}

/* Read and throw away `count` bytes from `filedes`; returns 0 or -1 like readloop(). */
static int discardloop(int filedes, size_t count)
{
	char scratch[4096];

	while (count > 0) {
		size_t chunk = count < sizeof(scratch) ? count : sizeof(scratch);
		if (readloop(filedes, scratch, chunk) == -1) {
			return -1;
		}
		count -= chunk;
	}
	return 0;
}

/*
 * Read one request from the client and act on it.
 *
 * Returns 0 to keep serving, 1 when the client is done (EOF or a DISCONNECT
 * request).  Protocol violations and I/O errors go through CLIENT_ERROR*,
 * which do not return.
 */
int client_serve_request(struct client_params* client)
{
	off64_t offset;
	struct nbd_request request;
	struct nbd_reply reply;
	uint32_t type;
	uint32_t len;
	uint64_t from;

	if (readloop(client->socket, &request, sizeof(request)) == -1) {
		if (errno == 0) {
			debug("EOF reading request");
			return 1; /* neat point to close the socket */
		}
		CLIENT_ERROR_ON_FAILURE(-1, "Error reading request");
	}

	reply.magic = htobe32(REPLY_MAGIC);
	reply.error = htobe32(0);
	memcpy(reply.handle, request.handle, sizeof(reply.handle));

	type = be32toh(request.type);
	from = be64toh(request.from);
	len = be32toh(request.len);

	debug("request type %" PRIu32, type);

	if (be32toh(request.magic) != REQUEST_MAGIC) {
		CLIENT_ERROR("Bad magic %08" PRIx32, be32toh(request.magic));
	}

	switch (type) {
	case REQUEST_READ:
	case REQUEST_WRITE:
		/* check it's not out of range; compared without adding so a huge
		 * `from` cannot wrap around and slip past the check */
		if (from > (uint64_t) client->size ||
		    len > (uint64_t) client->size - from) {
			debug("request %s %" PRIu64 "+%" PRIu32 " out of range",
			      type == REQUEST_WRITE ? "write" : "read", from, len);
			/* the WRITE payload is still on the wire: consume it so the
			 * next header is read from the right place */
			if (type == REQUEST_WRITE) {
				CLIENT_ERROR_ON_FAILURE(
					discardloop(client->socket, len),
					"Couldn't discard out-of-range write data"
				);
			}
			reply.error = htobe32(1);
			CLIENT_ERROR_ON_FAILURE(
				writeloop(client->socket, &reply, sizeof(reply)),
				"Couldn't send error reply"
			);
			return 0;
		}
		break;
	case REQUEST_DISCONNECT:
		debug("request disconnect");
		return 1;
	default:
		CLIENT_ERROR("Unknown request %08" PRIx32, type);
	}

	/* only in-range READ and WRITE requests get this far */
	switch (type) {
	case REQUEST_READ:
		debug("request read %" PRIu64 "+%" PRIu32, from, len);
		CLIENT_ERROR_ON_FAILURE(
			writeloop(client->socket, &reply, sizeof(reply)),
			"Couldn't send reply"
		);
		offset = (off64_t) from;
		CLIENT_ERROR_ON_FAILURE(
			sendfileloop(client->socket, client->fileno, &offset, len),
			"sendfile failed from=%" PRIu64 ", len=%" PRIu32, from, len
		);
		break;
	case REQUEST_WRITE:
		debug("request write %" PRIu64 "+%" PRIu32, from, len);
		CLIENT_ERROR_ON_FAILURE(
			readloop(client->socket, client->mapped + from, len),
			"read failed from=%" PRIu64 ", len=%" PRIu32, from, len
		);
		CLIENT_ERROR_ON_FAILURE(
			writeloop(client->socket, &reply, sizeof(reply)),
			"Couldn't send reply"
		);
		break;
	default:
		break;
	}
	return 0;
}

/*
 * Open client->filename read/write, record its size, and map it shared into
 * memory.  Fills client->fileno, client->size and client->mapped; any failure
 * ends the client via error().
 */
void client_open_file(struct client_params* client)
{
	client->fileno = open(client->filename, O_RDWR|O_DIRECT|O_SYNC);
	CLIENT_ERROR_ON_FAILURE(client->fileno, "Couldn't open %s", client->filename);

	client->size = lseek64(client->fileno, 0, SEEK_END);
	CLIENT_ERROR_ON_FAILURE(client->size, "Couldn't seek to end of %s", client->filename);

	client->mapped = mmap64(NULL, client->size, PROT_READ|PROT_WRITE,
		MAP_SHARED, client->fileno, 0);
	if (client->mapped == MAP_FAILED) {
		error(1, client->socket, "Couldn't map file %s", client->filename);
	}

	debug("opened %s size %lld on fd %d @ %p", client->filename,
	      (long long) client->size, client->fileno, (void*) client->mapped);
}

/* Send the greeting: password, magic and the exported size (big-endian). */
void client_send_hello(struct client_params* client)
{
	struct nbd_init init;

	/* exactly 8 bytes: the literal's terminating NUL must not be copied */
	memcpy(init.passwd, INIT_PASSWD, sizeof(init.passwd));
	init.magic = htobe64(INIT_MAGIC);
	init.size = htobe64((uint64_t) client->size);
	memset(init.reserved, 0, sizeof(init.reserved));

	CLIENT_ERROR_ON_FAILURE(
		writeloop(client->socket, &init, sizeof(init)),
		"Couldn't send hello"
	);
}

/*
 * Release the file mapping, the file descriptor and the client struct.
 * Runs both on normal return and when error() ends the thread.
 */
static void client_cleanup(void* client_uncast)
{
	struct client_params* client = client_uncast;

	if (client->mapped != NULL && client->mapped != MAP_FAILED) {
		munmap(client->mapped, (size_t) client->size);
	}
	if (client->fileno >= 0) {
		close(client->fileno);
	}
	free(client);
}

/*
 * Thread entry point for one accepted connection.  Takes ownership of
 * client_uncast (a struct client_params from xmalloc with socket and
 * filename set) and frees it, together with the file resources, on exit.
 */
void* client_serve(void* client_uncast)
{
	struct client_params* client = (struct client_params*) client_uncast;

	/* mark resources as "not yet acquired" for client_cleanup() */
	client->fileno = -1;
	client->mapped = NULL;
	client->size = 0;

	pthread_cleanup_push(client_cleanup, client);

	client_open_file(client);
	client_send_hello(client);

	while (client_serve_request(client) == 0)
		;

	CLIENT_ERROR_ON_FAILURE(
		close(client->socket),
		"Couldn't close socket %d",
		client->socket
	);

	pthread_cleanup_pop(1);
	return NULL;
}

/* FIXME: access control is not implemented; every address is accepted. */
int is_included_in_acl(struct ip_and_mask** list, struct sockaddr* test)
{
	(void) list;
	(void) test;
	return 1;
}

/* Create, bind and listen on the server socket for params->bind_to (IPv4 or IPv6). */
void serve_open_socket(struct mode_serve_params* params)
{
	int family = params->bind_to.generic.sa_family;
	socklen_t address_length = (family == AF_INET6)
		? (socklen_t) sizeof(params->bind_to.v6)
		: (socklen_t) sizeof(params->bind_to.v4);

	params->server = socket(family == AF_INET6 ? PF_INET6 : PF_INET, SOCK_STREAM, 0);
	SERVER_ERROR_ON_FAILURE(params->server, "Couldn't create server socket");

	SERVER_ERROR_ON_FAILURE(
		bind(params->server, &params->bind_to.generic, address_length),
		"Couldn't bind server to IP address"
	);

	SERVER_ERROR_ON_FAILURE(
		listen(params->server, params->tcp_backlog),
		"Couldn't listen on server socket"
	);
}

/* Accept connections forever, starting one detached client_serve thread per client. */
void serve_accept_loop(struct mode_serve_params* params)
{
	while (1) {
		pthread_t client_thread;
		struct sockaddr_storage client_address; /* large enough for IPv6 peers */
		struct client_params* client_params;
		socklen_t socket_length = sizeof(client_address);
		int thread_result;

		int client_socket = accept(params->server,
			(struct sockaddr*) &client_address, &socket_length);
		if (client_socket == -1 && errno == EINTR) {
			continue;
		}
		SERVER_ERROR_ON_FAILURE(client_socket, "accept() failed");

		if (params->acl &&
		    !is_included_in_acl(params->acl, (struct sockaddr*) &client_address)) {
			if (write(client_socket, "Access control error", 20) < 0) {
				/* best effort: the connection is dropped either way */
			}
			close(client_socket);
			continue;
		}

		client_params = xmalloc(sizeof(*client_params));
		client_params->socket = client_socket;
		client_params->filename = params->filename;

		/* pthread_create() returns an error number instead of setting errno */
		thread_result = pthread_create(&client_thread, NULL, client_serve, client_params);
		if (thread_result != 0) {
			errno = thread_result;
			error(1, client_socket, "Failed to create client thread");
		}
		/* nobody joins client threads, so let them release themselves */
		pthread_detach(client_thread);
		/* FIXME: keep track of them? */
		/* FIXME: maybe shouldn't be fatal? */
	}
}

/* Run serve mode: open the listening socket, then accept clients forever. */
void serve(struct mode_serve_params* params)
{
	serve_open_socket(params);
	serve_accept_loop(params);
}

/*
 * Fill `out` from the serve-mode command-line strings.
 * s_ip_address is "0" (any IPv4 address), a dotted IPv4 address or an IPv6
 * address; s_port must be a decimal number in 0..65535.  Invalid input is a
 * fatal server error.
 */
void params_serve(
	struct mode_serve_params* out,
	char* s_ip_address,
	char* s_port,
	char* s_file
)
{
	char* port_end;
	long port;

	out->tcp_backlog = 10; /* does this need to be settable? */
	out->acl = NULL; /* ignore for now */

	if (s_ip_address == NULL)
		SERVER_ERROR("No IP address supplied");
	if (s_port == NULL)
		SERVER_ERROR("No port number supplied");
	if (s_file == NULL)
		SERVER_ERROR("No filename supplied");

	/* zero everything (sin_zero, sin6_flowinfo, sin6_scope_id, ...) */
	memset(&out->bind_to, 0, sizeof(out->bind_to));

	/* inet_pton() returns 1 on success and writes only the address field */
	if (strcmp(s_ip_address, "0") == 0) {
		out->bind_to.v4.sin_family = AF_INET;
		out->bind_to.v4.sin_addr.s_addr = htonl(INADDR_ANY);
	} else if (inet_pton(AF_INET, s_ip_address, &out->bind_to.v4.sin_addr) == 1) {
		out->bind_to.v4.sin_family = AF_INET;
	} else if (inet_pton(AF_INET6, s_ip_address, &out->bind_to.v6.sin6_addr) == 1) {
		out->bind_to.v6.sin6_family = AF_INET6;
	} else {
		SERVER_ERROR("Couldn't understand address '%s' "
			"(use 0 if you don't care)", s_ip_address);
	}

	/* range-check before narrowing to the 16-bit port field */
	errno = 0;
	port = strtol(s_port, &port_end, 10);
	if (port_end == s_port || *port_end != '\0' || errno == ERANGE ||
	    port < 0 || port > 65535)
		SERVER_ERROR("Port number must be >= 0 and <= 65535");

	if (out->bind_to.generic.sa_family == AF_INET6) {
		out->bind_to.v6.sin6_port = htobe16((uint16_t) port);
	} else {
		out->bind_to.v4.sin_port = htobe16((uint16_t) port);
	}

	out->filename = s_file;
}

/* Dispatch on the mode word; argc/argv are the arguments after it. */
void mode(char* mode, int argc, char **argv)
{
	union mode_params params;

	memset(&params, 0, sizeof(params));

	if (strcmp(mode, "serve") == 0) {
		if (argc >= 3) {
			params_serve(&params.serve, argv[0], argv[1], argv[2]);
			serve(&params.serve);
		} else {
			syntax();
		}
	} else {
		syntax();
	}
	exit(0);
}

int main(int argc, char** argv)
{
	server_thread_id = pthread_self();

	if (argc < 2)
		syntax();
	mode(argv[1], argc-2, argv+2);

	return 0;
}