From 0432fef8f5641e9b500c2f1621dc50caf7a78760 Mon Sep 17 00:00:00 2001 From: Matthew Bloch Date: Thu, 17 May 2012 20:14:22 +0100 Subject: [PATCH] Split code out into separate compilation units (first pass, anyway). --- Rakefile | 9 +- bitset.h | 39 +++ flexnbd.c | 690 +--------------------------------------------------- ioutil.c | 123 ++++++++++ ioutil.h | 14 ++ nbdtypes.h | 36 +++ params.h | 64 +++++ readwrite.c | 124 ++++++++++ serve.c | 283 +++++++++++++++++++++ util.c | 58 +++++ util.h | 33 +++ 11 files changed, 790 insertions(+), 683 deletions(-) create mode 100644 bitset.h create mode 100644 ioutil.c create mode 100644 ioutil.h create mode 100644 nbdtypes.h create mode 100644 params.h create mode 100644 readwrite.c create mode 100644 serve.c create mode 100644 util.c create mode 100644 util.h diff --git a/Rakefile b/Rakefile index b9c496f..a132933 100644 --- a/Rakefile +++ b/Rakefile @@ -1,5 +1,7 @@ DEBUG = true +SOURCES = %w( flexnbd ioutil readwrite serve util ) +OBJECTS = SOURCES.map { |s| "#{s}.o" } LIBS = %w( pthread ) CCFLAGS = %w( -Wall ) LDFLAGS = [] @@ -12,10 +14,11 @@ end rule 'default' => 'flexnbd' -rule 'flexnbd' => 'flexnbd.o' do |t| +rule 'flexnbd' => OBJECTS do |t| sh "gcc #{LDFLAGS.join(' ')} "+ LIBS.map { |l| "-l#{l}" }.join(" ")+ - " -o #{t.name} #{t.source}" + " -o #{t.name} "+ + t.sources.join(" ") end rule '.o' => '.c' do |t| @@ -23,5 +26,5 @@ rule '.o' => '.c' do |t| end rule 'clean' do - sh "rm -f flexnbd.o flexnbd" + sh "rm -f flexnbd "+OBJECTS.join(" ") end diff --git a/bitset.h b/bitset.h new file mode 100644 index 0000000..d136427 --- /dev/null +++ b/bitset.h @@ -0,0 +1,39 @@ +#ifndef __BITSET_H +#define __BITSET_H + +#include + +static inline char char_with_bit_set(int num) { + return 1<<(num%8); +} +static inline int bit_is_set(char* b, int idx) { + return (b[idx/8] & char_with_bit_set(idx)) != 0; +} +static inline int bit_is_clear(char* b, int idx) { + return !bit_is_set(b, idx); +} +static inline void bit_set(char* b, int idx) { + b[idx/8] &= char_with_bit_set(idx); +} +static inline void bit_clear(char* b, int idx) { + b[idx/8] &= ~char_with_bit_set(idx); +} +static inline void bit_set_range(char* b, int from, int len) { + for (; from%8 != 0 && len > 0; len--) + bit_set(b, from++); + if (len >= 8) + memset(b+(from/8), 255, len/8); + for (; len > 0; len--) + bit_set(b, from++); +} +static inline void bit_clear_range(char* b, int from, int len) { + for (; from%8 != 0 && len > 0; len--) + bit_clear(b, from++); + if (len >= 8) + memset(b+(from/8), 0, len/8); + for (; len > 0; len--) + bit_clear(b, from++); +} + +#endif + diff --git a/flexnbd.c b/flexnbd.c index 59852d5..c6a4574 100644 --- a/flexnbd.c +++ b/flexnbd.c @@ -1,57 +1,15 @@ -#define _LARGEFILE64_SOURCE -#define _GNU_SOURCE +#include "params.h" +#include "util.h" #include #include #include -#include -#include -#include -#include -#include -#include -#include #include -#include -#include -#include -#include +#include +#include +#include -#include -#include -#include - -/* http://linux.derkeiler.com/Mailing-Lists/Kernel/2003-09/2332.html */ -#define INIT_PASSWD "NBDMAGIC" -#define INIT_MAGIC 0x0000420281861253 -#define REQUEST_MAGIC 0x25609513 -#define REPLY_MAGIC 0x67446698 -#define REQUEST_READ 0 -#define REQUEST_WRITE 1 -#define REQUEST_DISCONNECT 2 - -#include -struct nbd_init { - char passwd[8]; - __be64 magic; - __be64 size; - char reserved[128]; -}; - -struct nbd_request { - __be32 magic; - __be32 type; /* == READ || == WRITE */ - char handle[8]; - __be64 from; - __be32 len; -} __attribute__((packed)); - -struct nbd_reply { - __be32 magic; - __be32 error; /* 0 = ok, else error */ - char handle[8]; /* handle you got from request */ -}; void syntax() { @@ -65,638 +23,6 @@ void syntax() exit(1); } -static pthread_t server_thread_id; - -void error(int consult_errno, int close_socket, const char* format, ...) -{ - va_list argptr; - - fprintf(stderr, "*** "); - - va_start(argptr, format); - vfprintf(stderr, format, argptr); - va_end(argptr); - - if (consult_errno) { - fprintf(stderr, " (errno=%d, %s)", errno, strerror(errno)); - } - - if (close_socket) - close(close_socket); - - fprintf(stderr, "\n"); - - if (pthread_equal(pthread_self(), server_thread_id)) - exit(1); - else - pthread_exit((void*) 1); -} - -#ifndef DEBUG -# define debug(msg, ...) -#else -# include -# define debug(msg, ...) fprintf(stderr, "%08x %4d: " msg "\n" , \ - (int) pthread_self(), (int) clock(), ##__VA_ARGS__) -#endif - -#define CLIENT_ERROR(msg, ...) \ - error(0, client->socket, msg, ##__VA_ARGS__) -#define CLIENT_ERROR_ON_FAILURE(test, msg, ...) \ - if (test < 0) { error(1, client->socket, msg, ##__VA_ARGS__); } -#define SERVER_ERROR(msg, ...) \ - error(0, 0, msg, ##__VA_ARGS__) -#define SERVER_ERROR_ON_FAILURE(test, msg, ...) \ - if (test < 0) { error(1, 0, msg, ##__VA_ARGS__); } - -void* xrealloc(void* ptr, size_t size) -{ - void* p = realloc(ptr, size); - if (p == NULL) - SERVER_ERROR("couldn't xrealloc %d bytes", size); - return p; -} - -void* xmalloc(size_t size) -{ - void* p = xrealloc(NULL, size); - memset(p, 0, size); - return p; -} - -union mysockaddr { - unsigned short family; - struct sockaddr generic; - struct sockaddr_in v4; - struct sockaddr_in6 v6; -}; - -struct ip_and_mask { - union mysockaddr ip; - int mask; -}; - -struct mode_serve_params { - union mysockaddr bind_to; - int acl_entries; - struct ip_and_mask** acl; - char* filename; - int tcp_backlog; - - int server; - int threads; - - pthread_mutex_t block_allocation_map_lock; - char* block_allocation_map; -}; - -struct mode_readwrite_params { - union mysockaddr connect_to; - off64_t from; - off64_t len; - int data_fd; - int client; -}; - -struct client_params { - int socket; - char* filename; - - int fileno; - off64_t size; - char* mapped; - - pthread_mutex_t block_allocation_map_lock; - char* block_allocation_map; -}}; - -union mode_params { - struct mode_serve_params serve; - struct mode_readwrite_params readwrite; -}; - -static inline int char_with_bit_set(int num) { - return 1<<(num%8); -} -static inline int bit_is_set(char* b, int idx) { - return (b[idx/8] & char_with_bit_set(idx)) != 0; -} -static inline int bit_is_clear(char* b, int idx) { - return !bit_is_set(b, idx); -} -static inline void bit_set(char* b, int idx) { - b[idx/8] &= char_with_bit_set(idx); -} -static inline void bit_clear(char* b, int idx) { - b[idx/8] &= ~char_with_bit_set(idx); -} -static inline void bit_set_range(char* b, int from, int len) { - for (; b%8 != 0 && len > 0; len--) - bit_set(b, from++, 1); - if (len >= 8) - memset(b+(from/8), 255, len/8); - for (; len > 0; len--) - bit_set(b, from++); -} -static inline void bit_clear_range(char* b, int from, int len) { - for (; b%8 != 0 && len > 0; len--) - bit_clear(b, from++, 1); - if (len >= 8) - memset(b+(from/8), 0, len/8); - for (; len > 0; len--) - bit_clear(b, from++); -} - - -char* build_allocation_map(int fd, off64_t size, int resolution) -{ - char *allocation_map = xmalloc((size+resolution)/resolution); - struct fiemap *fiemap; - - fiemap = (struct fiemap*) xmalloc(sizeof(struct fiemap)); - - fiemap->fm_start = from; - fiemap->fm_length = len; - fiemap->fm_flags = 0; - fiemap->fm_extent_count = 0; - fiemap->fm_mapped_extents = 0; - - /* Find out how many extents there are */ - if (ioctl(fd, FS_IOC_FIEMAP, fiemap) < 0) - return NULL; - - /* Resize fiemap to allow us to read in the extents */ - fiemap = (struct fiemap*)xrealloc( - fiemap, - sizeof(struct fiemap) + ( - sizeof(struct fiemap_extent) * - fiemap->fm_mapped_extents - ) - ); - - fiemap->fm_extent_count = fiemap->fm_mapped_extents; - fiemap->fm_mapped_extents = 0; - - if (ioctl(fd, FS_IOC_FIEMAP, fiemap) < -1) - return NULL; - - for (i=0;ifm_mapped_extents;i++) - bit_set_range( - allocation_map, - fiemap->fm_extents[i].fe_logical / resolution, - fiemap->fm_extents[i].fe_length / resolution - ); - - free(fiemap); - - return allocation_map; -} - - -int writeloop(int filedes, const void *buffer, size_t size) -{ - size_t written=0; - while (written < size) { - size_t result = write(filedes, buffer+written, size-written); - if (result == -1) - return -1; - written += result; - } - return 0; -} - -int readloop(int filedes, void *buffer, size_t size) -{ - size_t readden=0; - while (readden < size) { - size_t result = read(filedes, buffer+readden, size-readden); - if (result == 0 /* EOF */ || result == -1 /* error */) - return -1; - readden += result; - } - return 0; -} - -int sendfileloop(int out_fd, int in_fd, off64_t *offset, size_t count) -{ - size_t sent=0; - while (sent < count) { - size_t result = sendfile64(out_fd, in_fd, offset+sent, count-sent); - if (result == -1) - return -1; - sent += result; - } - return 0; -} - -int splice_via_pipe_loop(int fd_in, int fd_out, size_t len) -{ - int pipefd[2]; - size_t spliced=0; - - if (pipe(pipefd) == -1) - return -1; - - while (spliced < len) { - size_t r1,r2; - r1 = splice(fd_in, NULL, pipefd[1], NULL, len-spliced, 0); - if (r1 <= 0) - break; - r2 = splice(pipefd[0], NULL, fd_out, NULL, r1, 0); - if (r1 != r2) - break; - spliced += r1; - } - close(pipefd[0]); - close(pipefd[1]); - - return spliced < len ? -1 : 0; -} - -int client_serve_request(struct client_params* client) -{ - off64_t offset; - struct nbd_request request; - struct nbd_reply reply; - struct unallocated_block** unallocated; - - if (readloop(client->socket, &request, sizeof(request)) == -1) { - if (errno == 0) { - debug("EOF reading request"); - return 1; /* neat point to close the socket */ - } - else { - CLIENT_ERROR_ON_FAILURE(-1, "Error reading request"); - } - } - - reply.magic = htobe32(REPLY_MAGIC); - reply.error = htobe32(0); - memcpy(reply.handle, request.handle, 8); - - debug("request type %d", be32toh(request.type)); - - if (be32toh(request.magic) != REQUEST_MAGIC) - CLIENT_ERROR("Bad magic %08x", be32toh(request.magic)); - - switch (be32toh(request.type)) - { - case REQUEST_READ: - case REQUEST_WRITE: - /* check it's not out of range */ - if (be64toh(request.from) < 0 || - be64toh(request.from)+be32toh(request.len) > client->size) { - debug("request read %ld+%d out of range", - be64toh(request.from), - be32toh(request.len) - ); - reply.error = htobe32(1); - write(client->socket, &reply, sizeof(reply)); - return 0; - } - break; - - case REQUEST_DISCONNECT: - debug("request disconnect"); - return 1; - - default: - CLIENT_ERROR("Unknown request %08x", be32toh(request.type)); - } - - switch (be32toh(request.type)) - { - case REQUEST_READ: - debug("request read %ld+%d", be64toh(request.from), be32toh(request.len)); - write(client->socket, &reply, sizeof(reply)); - - offset = be64toh(request.from); - CLIENT_ERROR_ON_FAILURE( - sendfileloop( - client->socket, - client->fileno, - &offset, - be32toh(request.len) - ), - "sendfile failed from=%ld, len=%d", - offset, - be32toh(request.len) - ); - break; - - case REQUEST_WRITE: - debug("request write %ld+%d", be64toh(request.from), be32toh(request.len)); -#ifdef _LINUX_FIEMAP_H - unallocated = read_unallocated_blocks( - client->fileno, - be64toh(request.from), - be32toh(request.len) - ); - if (unallocated == NULL) - CLIENT_ERROR("Couldn't read unallocated blocks list"); - - CLIENT_ERROR_ON_FAILURE( - read_from_socket_avoiding_holes( - client->socket, - ); - free(fiemap); -#else - CLIENT_ERROR_ON_FAILURE( - readloop( - client->socket, - client->mapped + be64toh(request.from), - be32toh(request.len) - ), - "read failed from=%ld, len=%d", - be64toh(request.from), - be32toh(request.len) - ); -#endif - write(client->socket, &reply, sizeof(reply)); - - break; - } - return 0; -} - -void client_open_file(struct client_params* client) -{ - client->fileno = open(client->filename, O_RDWR|O_DIRECT|O_SYNC); - CLIENT_ERROR_ON_FAILURE(client->fileno, "Couldn't open %s", - client->filename); - client->size = lseek64(client->fileno, 0, SEEK_END); - CLIENT_ERROR_ON_FAILURE(client->fileno, "Couldn't seek to end of %s", - client->filename); - client->mapped = mmap64(NULL, client->size, PROT_READ|PROT_WRITE, - MAP_SHARED, client->fileno, 0); - CLIENT_ERROR_ON_FAILURE((long) client->mapped, "Couldn't map file %s", - client->filename); - debug("opened %s size %ld on fd %d @ %p", client->filename, client->size, client->fileno, client->mapped); -} - -void client_send_hello(struct client_params* client) -{ - struct nbd_init init; - - memcpy(init.passwd, INIT_PASSWD, sizeof(INIT_PASSWD)); - init.magic = htobe64(INIT_MAGIC); - init.size = htobe64(client->size); - memset(init.reserved, 0, 128); - CLIENT_ERROR_ON_FAILURE( - writeloop(client->socket, &init, sizeof(init)), - "Couldn't send hello" - ); -} - -void* client_serve(void* client_uncast) -{ - struct client_params* client = (struct client_params*) client_uncast; - - client_open_file(client); - client_send_hello(client); - - while (client_serve_request(client) == 0) - ; - - CLIENT_ERROR_ON_FAILURE( - close(client->socket), - "Couldn't close socket %d", - client->socket - ); - - free(client); - return NULL; -} - -static int testmasks[9] = { 0,128,192,224,240,248,252,254,255 }; - -int is_included_in_acl(int list_length, struct ip_and_mask** list, struct sockaddr* test) -{ - int i; - - for (i=0; i < list_length; i++) { - struct ip_and_mask *entry = list[i]; - int testbits; - char *raw_address1, *raw_address2; - - debug("checking acl entry %d", i); - - if (test->sa_family != entry->ip.family) - continue; - - if (test->sa_family == AF_INET) { - raw_address1 = (char*) - &((struct sockaddr_in*) test)->sin_addr; - raw_address2 = (char*) &entry->ip.v4.sin_addr; - } - else if (test->sa_family == AF_INET6) { - raw_address1 = (char*) - &((struct sockaddr_in6*) test)->sin6_addr; - raw_address2 = (char*) &entry->ip.v6.sin6_addr; - } - - for (testbits = entry->mask; testbits > 0; testbits -= 8) { - debug("testbits=%d, c1=%d, c2=%d", testbits, raw_address1[0], raw_address2[0]); - if (testbits >= 8) { - if (raw_address1[0] != raw_address2[0]) - goto no_match; - } - else { - if ((raw_address1[0] & testmasks[testbits%8]) != - (raw_address2[0] & testmasks[testbits%8]) ) - goto no_match; - } - - raw_address1++; - raw_address2++; - } - - return 1; - - no_match: ; - debug("no match"); - } - - return 0; -} - -void serve_open_socket(struct mode_serve_params* params) -{ - params->server = socket(PF_INET, SOCK_STREAM, 0); - - SERVER_ERROR_ON_FAILURE(params->server, - "Couldn't create server socket"); - - SERVER_ERROR_ON_FAILURE( - bind(params->server, ¶ms->bind_to.generic, - sizeof(params->bind_to.generic)), - "Couldn't bind server to IP address" - ); - - SERVER_ERROR_ON_FAILURE( - listen(params->server, params->tcp_backlog), - "Couldn't listen on server socket" - ); -} - -void serve_accept_loop(struct mode_serve_params* params) -{ - while (1) { - pthread_t client_thread; - struct sockaddr client_address; - struct client_params* client_params; - socklen_t socket_length=0; - - int client_socket = accept(params->server, &client_address, - &socket_length); - - SERVER_ERROR_ON_FAILURE(client_socket, "accept() failed"); - - if (params->acl && - !is_included_in_acl(params->acl_entries, params->acl, &client_address)) { - write(client_socket, "Access control error", 20); - close(client_socket); - continue; - } - - client_params = xmalloc(sizeof(struct client_params)); - client_params->socket = client_socket; - client_params->filename = params->filename; - client_params->block_allocation_map = - params->block_allocation_map; - client_params->block_allocation_map_lock = - params->block_allocation_map_lock; - - client_thread = pthread_create(&client_thread, NULL, - client_serve, client_params); - SERVER_ERROR_ON_FAILURE(client_thread, - "Failed to create client thread"); - /* FIXME: keep track of them? */ - /* FIXME: maybe shouldn't be fatal? */ - } -} - -void do_serve(struct mode_serve_params* params) -{ - serve_open_socket(params); - serve_accept_loop(params); -} - -int socket_connect(struct sockaddr* to) -{ - int fd = socket(PF_INET, SOCK_STREAM, 0); - SERVER_ERROR_ON_FAILURE(fd, "Couldn't create client socket"); - SERVER_ERROR_ON_FAILURE(connect(fd, to, sizeof(*to)), - "connect failed"); - return fd; -} - -off64_t socket_nbd_read_hello(int fd) -{ - struct nbd_init init; - SERVER_ERROR_ON_FAILURE(readloop(fd, &init, sizeof(init)), - "Couldn't read init"); - if (strncmp(init.passwd, INIT_PASSWD, 8) != 0) - SERVER_ERROR("wrong passwd"); - if (be64toh(init.magic) != INIT_MAGIC) - SERVER_ERROR("wrong magic (%x)", be64toh(init.magic)); - return be64toh(init.size); -} - -void fill_request(struct nbd_request *request, int type, int from, int len) -{ - request->magic = htobe32(REQUEST_MAGIC); - request->type = htobe32(type); - ((int*) request->handle)[0] = rand(); - ((int*) request->handle)[1] = rand(); - request->from = htobe64(from); - request->len = htobe32(len); -} - -void read_reply(int fd, struct nbd_request *request, struct nbd_reply *reply) -{ - SERVER_ERROR_ON_FAILURE(readloop(fd, reply, sizeof(*reply)), - "Couldn't read reply"); - if (be32toh(reply->magic) != REPLY_MAGIC) - SERVER_ERROR("Reply magic incorrect (%p)", reply->magic); - if (be32toh(reply->error) != 0) - SERVER_ERROR("Server replied with error %d", reply->error); - if (strncmp(request->handle, reply->handle, 8) != 0) - SERVER_ERROR("Did not reply with correct handle"); -} - -void socket_nbd_read(int fd, off64_t from, int len, int out_fd, void* out_buf) -{ - struct nbd_request request; - struct nbd_reply reply; - - fill_request(&request, REQUEST_READ, from, len); - SERVER_ERROR_ON_FAILURE(writeloop(fd, &request, sizeof(request)), - "Couldn't write request"); - read_reply(fd, &request, &reply); - - if (out_buf) { - SERVER_ERROR_ON_FAILURE(readloop(fd, out_buf, len), - "Read failed"); - } - else { - SERVER_ERROR_ON_FAILURE( - splice_via_pipe_loop(fd, out_fd, len), - "Splice failed" - ); - } -} - -void socket_nbd_write(int fd, off64_t from, int len, int in_fd, void* in_buf) -{ - struct nbd_request request; - struct nbd_reply reply; - - fill_request(&request, REQUEST_WRITE, from, len); - SERVER_ERROR_ON_FAILURE(writeloop(fd, &request, sizeof(request)), - "Couldn't write request"); - - if (in_buf) { - SERVER_ERROR_ON_FAILURE(writeloop(fd, in_buf, len), - "Write failed"); - } - else { - SERVER_ERROR_ON_FAILURE( - splice_via_pipe_loop(in_fd, fd, len), - "Splice failed" - ); - } - - read_reply(fd, &request, &reply); -} - -#define CHECK_RANGE(error_type) { \ - off64_t size = socket_nbd_read_hello(params->client); \ - if (params->from < 0 || (params->from + params->len) >= size) \ - SERVER_ERROR(error_type \ - " request %d+%d is out of range given size %d", \ - params->from, params->len, size\ - ); \ -} - -void do_read(struct mode_readwrite_params* params) -{ - params->client = socket_connect(¶ms->connect_to.generic); - CHECK_RANGE("read"); - socket_nbd_read(params->client, params->from, params->len, - params->data_fd, NULL); - close(params->client); -} - -void do_write(struct mode_readwrite_params* params) -{ - params->client = socket_connect(¶ms->connect_to.generic); - CHECK_RANGE("write"); - socket_nbd_write(params->client, params->from, params->len, - params->data_fd, NULL); - close(params->client); -} - #define IS_IP_VALID_CHAR(x) ( ((x) >= '0' && (x) <= '9' ) || \ ((x) >= 'a' && (x) <= 'f') || \ ((x) >= 'A' && (x) <= 'F' ) || \ @@ -862,6 +188,10 @@ void params_readwrite( } } +void do_serve(struct mode_serve_params* params); +void do_read(struct mode_readwrite_params* params); +void do_write(struct mode_readwrite_params* params); + void mode(char* mode, int argc, char **argv) { union mode_params params; @@ -901,7 +231,7 @@ void mode(char* mode, int argc, char **argv) int main(int argc, char** argv) { - server_thread_id = pthread_self(); + error_init(); if (argc < 2) syntax(); diff --git a/ioutil.c b/ioutil.c new file mode 100644 index 0000000..761b74d --- /dev/null +++ b/ioutil.c @@ -0,0 +1,123 @@ +#define _LARGEFILE64_SOURCE +#define _GNU_SOURCE + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "util.h" +#include "bitset.h" + +char* build_allocation_map(int fd, off64_t size, int resolution) +{ + int i; + char *allocation_map = xmalloc((size+resolution)/resolution); + struct fiemap *fiemap; + + fiemap = (struct fiemap*) xmalloc(sizeof(struct fiemap)); + + fiemap->fm_start = 0; + fiemap->fm_length = size; + fiemap->fm_flags = 0; + fiemap->fm_extent_count = 0; + fiemap->fm_mapped_extents = 0; + + /* Find out how many extents there are */ + if (ioctl(fd, FS_IOC_FIEMAP, fiemap) < 0) + return NULL; + + /* Resize fiemap to allow us to read in the extents */ + fiemap = (struct fiemap*)xrealloc( + fiemap, + sizeof(struct fiemap) + ( + sizeof(struct fiemap_extent) * + fiemap->fm_mapped_extents + ) + ); + + fiemap->fm_extent_count = fiemap->fm_mapped_extents; + fiemap->fm_mapped_extents = 0; + + if (ioctl(fd, FS_IOC_FIEMAP, fiemap) < -1) + return NULL; + + for (i=0;ifm_mapped_extents;i++) + bit_set_range( + allocation_map, + fiemap->fm_extents[i].fe_logical / resolution, + fiemap->fm_extents[i].fe_length / resolution + ); + + free(fiemap); + + return allocation_map; +} + + +int writeloop(int filedes, const void *buffer, size_t size) +{ + size_t written=0; + while (written < size) { + size_t result = write(filedes, buffer+written, size-written); + if (result == -1) + return -1; + written += result; + } + return 0; +} + +int readloop(int filedes, void *buffer, size_t size) +{ + size_t readden=0; + while (readden < size) { + size_t result = read(filedes, buffer+readden, size-readden); + if (result == 0 /* EOF */ || result == -1 /* error */) + return -1; + readden += result; + } + return 0; +} + +int sendfileloop(int out_fd, int in_fd, off64_t *offset, size_t count) +{ + size_t sent=0; + while (sent < count) { + size_t result = sendfile64(out_fd, in_fd, offset+sent, count-sent); + if (result == -1) + return -1; + sent += result; + } + return 0; +} + +int splice_via_pipe_loop(int fd_in, int fd_out, size_t len) +{ + int pipefd[2]; + size_t spliced=0; + + if (pipe(pipefd) == -1) + return -1; + + while (spliced < len) { + size_t r1,r2; + r1 = splice(fd_in, NULL, pipefd[1], NULL, len-spliced, 0); + if (r1 <= 0) + break; + r2 = splice(pipefd[0], NULL, fd_out, NULL, r1, 0); + if (r1 != r2) + break; + spliced += r1; + } + close(pipefd[0]); + close(pipefd[1]); + + return spliced < len ? -1 : 0; +} + + diff --git a/ioutil.h b/ioutil.h new file mode 100644 index 0000000..055a057 --- /dev/null +++ b/ioutil.h @@ -0,0 +1,14 @@ +#ifndef __IOUTIL_H +#define __IOUTIL_H + + +#include "params.h" + +char* build_allocation_map(int fd, off64_t size, int resolution); +int writeloop(int filedes, const void *buffer, size_t size); +int readloop(int filedes, void *buffer, size_t size); +int sendfileloop(int out_fd, int in_fd, off64_t *offset, size_t count); +int splice_via_pipe_loop(int fd_in, int fd_out, size_t len); + +#endif + diff --git a/nbdtypes.h b/nbdtypes.h new file mode 100644 index 0000000..499a4a1 --- /dev/null +++ b/nbdtypes.h @@ -0,0 +1,36 @@ +#ifndef __NBDTYPES_H +#define __NBDTYPES_H + +/* http://linux.derkeiler.com/Mailing-Lists/Kernel/2003-09/2332.html */ +#define INIT_PASSWD "NBDMAGIC" +#define INIT_MAGIC 0x0000420281861253 +#define REQUEST_MAGIC 0x25609513 +#define REPLY_MAGIC 0x67446698 +#define REQUEST_READ 0 +#define REQUEST_WRITE 1 +#define REQUEST_DISCONNECT 2 + +#include +struct nbd_init { + char passwd[8]; + __be64 magic; + __be64 size; + char reserved[128]; +}; + +struct nbd_request { + __be32 magic; + __be32 type; /* == READ || == WRITE */ + char handle[8]; + __be64 from; + __be32 len; +} __attribute__((packed)); + +struct nbd_reply { + __be32 magic; + __be32 error; /* 0 = ok, else error */ + char handle[8]; /* handle you got from request */ +}; + +#endif + diff --git a/params.h b/params.h new file mode 100644 index 0000000..ac396cc --- /dev/null +++ b/params.h @@ -0,0 +1,64 @@ +#ifndef __PARAMS_H +#define __PARAMS_H + +#define _GNU_SOURCE +#define _LARGEFILE64_SOURCE + +#include +#include +#include +#include + +union mysockaddr { + unsigned short family; + struct sockaddr generic; + struct sockaddr_in v4; + struct sockaddr_in6 v6; +}; + +struct ip_and_mask { + union mysockaddr ip; + int mask; +}; + +struct mode_serve_params { + union mysockaddr bind_to; + int acl_entries; + struct ip_and_mask** acl; + char* filename; + int tcp_backlog; + + int server; + int threads; + + pthread_mutex_t block_allocation_map_lock; + char* block_allocation_map; +}; + +struct mode_readwrite_params { + union mysockaddr connect_to; + off64_t from; + off64_t len; + int data_fd; + int client; +}; + +struct client_params { + int socket; + char* filename; + + int fileno; + off64_t size; + char* mapped; + + pthread_mutex_t block_allocation_map_lock; + char* block_allocation_map; +}; + +union mode_params { + struct mode_serve_params serve; + struct mode_readwrite_params readwrite; +}; + +#endif + diff --git a/readwrite.c b/readwrite.c new file mode 100644 index 0000000..3f571d3 --- /dev/null +++ b/readwrite.c @@ -0,0 +1,124 @@ +#include "nbdtypes.h" +#include "ioutil.h" +#include "util.h" +#include "params.h" + +#include +#include +#include + +int socket_connect(struct sockaddr* to) +{ + int fd = socket(PF_INET, SOCK_STREAM, 0); + SERVER_ERROR_ON_FAILURE(fd, "Couldn't create client socket"); + SERVER_ERROR_ON_FAILURE(connect(fd, to, sizeof(*to)), + "connect failed"); + return fd; +} + +off64_t socket_nbd_read_hello(int fd) +{ + struct nbd_init init; + SERVER_ERROR_ON_FAILURE(readloop(fd, &init, sizeof(init)), + "Couldn't read init"); + if (strncmp(init.passwd, INIT_PASSWD, 8) != 0) + SERVER_ERROR("wrong passwd"); + if (be64toh(init.magic) != INIT_MAGIC) + SERVER_ERROR("wrong magic (%x)", be64toh(init.magic)); + return be64toh(init.size); +} + +void fill_request(struct nbd_request *request, int type, int from, int len) +{ + request->magic = htobe32(REQUEST_MAGIC); + request->type = htobe32(type); + ((int*) request->handle)[0] = rand(); + ((int*) request->handle)[1] = rand(); + request->from = htobe64(from); + request->len = htobe32(len); +} + +void read_reply(int fd, struct nbd_request *request, struct nbd_reply *reply) +{ + SERVER_ERROR_ON_FAILURE(readloop(fd, reply, sizeof(*reply)), + "Couldn't read reply"); + if (be32toh(reply->magic) != REPLY_MAGIC) + SERVER_ERROR("Reply magic incorrect (%p)", reply->magic); + if (be32toh(reply->error) != 0) + SERVER_ERROR("Server replied with error %d", reply->error); + if (strncmp(request->handle, reply->handle, 8) != 0) + SERVER_ERROR("Did not reply with correct handle"); +} + +void socket_nbd_read(int fd, off64_t from, int len, int out_fd, void* out_buf) +{ + struct nbd_request request; + struct nbd_reply reply; + + fill_request(&request, REQUEST_READ, from, len); + SERVER_ERROR_ON_FAILURE(writeloop(fd, &request, sizeof(request)), + "Couldn't write request"); + read_reply(fd, &request, &reply); + + if (out_buf) { + SERVER_ERROR_ON_FAILURE(readloop(fd, out_buf, len), + "Read failed"); + } + else { + SERVER_ERROR_ON_FAILURE( + splice_via_pipe_loop(fd, out_fd, len), + "Splice failed" + ); + } +} + +void socket_nbd_write(int fd, off64_t from, int len, int in_fd, void* in_buf) +{ + struct nbd_request request; + struct nbd_reply reply; + + fill_request(&request, REQUEST_WRITE, from, len); + SERVER_ERROR_ON_FAILURE(writeloop(fd, &request, sizeof(request)), + "Couldn't write request"); + + if (in_buf) { + SERVER_ERROR_ON_FAILURE(writeloop(fd, in_buf, len), + "Write failed"); + } + else { + SERVER_ERROR_ON_FAILURE( + splice_via_pipe_loop(in_fd, fd, len), + "Splice failed" + ); + } + + read_reply(fd, &request, &reply); +} + +#define CHECK_RANGE(error_type) { \ + off64_t size = socket_nbd_read_hello(params->client); \ + if (params->from < 0 || (params->from + params->len) >= size) \ + SERVER_ERROR(error_type \ + " request %d+%d is out of range given size %d", \ + params->from, params->len, size\ + ); \ +} + +void do_read(struct mode_readwrite_params* params) +{ + params->client = socket_connect(¶ms->connect_to.generic); + CHECK_RANGE("read"); + socket_nbd_read(params->client, params->from, params->len, + params->data_fd, NULL); + close(params->client); +} + +void do_write(struct mode_readwrite_params* params) +{ + params->client = socket_connect(¶ms->connect_to.generic); + CHECK_RANGE("write"); + socket_nbd_write(params->client, params->from, params->len, + params->data_fd, NULL); + close(params->client); +} + diff --git a/serve.c b/serve.c new file mode 100644 index 0000000..10d2e61 --- /dev/null +++ b/serve.c @@ -0,0 +1,283 @@ +#include "params.h" +#include "nbdtypes.h" +#include "ioutil.h" +#include "util.h" + +#include +#include +#include +#include + +#include +#include +#include + +int client_serve_request(struct client_params* client) +{ + off64_t offset; + struct nbd_request request; + struct nbd_reply reply; +// struct unallocated_block** unallocated; + + if (readloop(client->socket, &request, sizeof(request)) == -1) { + if (errno == 0) { + debug("EOF reading request"); + return 1; /* neat point to close the socket */ + } + else { + CLIENT_ERROR_ON_FAILURE(-1, "Error reading request"); + } + } + + reply.magic = htobe32(REPLY_MAGIC); + reply.error = htobe32(0); + memcpy(reply.handle, request.handle, 8); + + debug("request type %d", be32toh(request.type)); + + if (be32toh(request.magic) != REQUEST_MAGIC) + CLIENT_ERROR("Bad magic %08x", be32toh(request.magic)); + + switch (be32toh(request.type)) + { + case REQUEST_READ: + case REQUEST_WRITE: + /* check it's not out of range */ + if (be64toh(request.from) < 0 || + be64toh(request.from)+be32toh(request.len) > client->size) { + debug("request read %ld+%d out of range", + be64toh(request.from), + be32toh(request.len) + ); + reply.error = htobe32(1); + write(client->socket, &reply, sizeof(reply)); + return 0; + } + break; + + case REQUEST_DISCONNECT: + debug("request disconnect"); + return 1; + + default: + CLIENT_ERROR("Unknown request %08x", be32toh(request.type)); + } + + switch (be32toh(request.type)) + { + case REQUEST_READ: + debug("request read %ld+%d", be64toh(request.from), be32toh(request.len)); + write(client->socket, &reply, sizeof(reply)); + + offset = be64toh(request.from); + CLIENT_ERROR_ON_FAILURE( + sendfileloop( + client->socket, + client->fileno, + &offset, + be32toh(request.len) + ), + "sendfile failed from=%ld, len=%d", + offset, + be32toh(request.len) + ); + break; + + case REQUEST_WRITE: + debug("request write %ld+%d", be64toh(request.from), be32toh(request.len)); +#ifdef _LINUX_FIEMAP_H + unallocated = read_unallocated_blocks( + client->fileno, + be64toh(request.from), + be32toh(request.len) + ); + if (unallocated == NULL) + CLIENT_ERROR("Couldn't read unallocated blocks list"); + + CLIENT_ERROR_ON_FAILURE( + read_from_socket_avoiding_holes( + client->socket, + ); + free(fiemap); +#else + CLIENT_ERROR_ON_FAILURE( + readloop( + client->socket, + client->mapped + be64toh(request.from), + be32toh(request.len) + ), + "read failed from=%ld, len=%d", + be64toh(request.from), + be32toh(request.len) + ); +#endif + write(client->socket, &reply, sizeof(reply)); + + break; + } + return 0; +} + +void client_open_file(struct client_params* client) +{ + client->fileno = open(client->filename, O_RDWR|O_DIRECT|O_SYNC); + CLIENT_ERROR_ON_FAILURE(client->fileno, "Couldn't open %s", + client->filename); + client->size = lseek64(client->fileno, 0, SEEK_END); + CLIENT_ERROR_ON_FAILURE(client->fileno, "Couldn't seek to end of %s", + client->filename); + client->mapped = mmap64(NULL, client->size, PROT_READ|PROT_WRITE, + MAP_SHARED, client->fileno, 0); + CLIENT_ERROR_ON_FAILURE((long) client->mapped, "Couldn't map file %s", + client->filename); + debug("opened %s size %ld on fd %d @ %p", client->filename, client->size, client->fileno, client->mapped); +} + +void client_send_hello(struct client_params* client) +{ + struct nbd_init init; + + memcpy(init.passwd, INIT_PASSWD, sizeof(INIT_PASSWD)); + init.magic = htobe64(INIT_MAGIC); + init.size = htobe64(client->size); + memset(init.reserved, 0, 128); + CLIENT_ERROR_ON_FAILURE( + writeloop(client->socket, &init, sizeof(init)), + "Couldn't send hello" + ); +} + +void* client_serve(void* client_uncast) +{ + struct client_params* client = (struct client_params*) client_uncast; + + client_open_file(client); + client_send_hello(client); + + while (client_serve_request(client) == 0) + ; + + CLIENT_ERROR_ON_FAILURE( + close(client->socket), + "Couldn't close socket %d", + client->socket + ); + + free(client); + return NULL; +} + +static int testmasks[9] = { 0,128,192,224,240,248,252,254,255 }; + +int is_included_in_acl(int list_length, struct ip_and_mask** list, struct sockaddr* test) +{ + int i; + + for (i=0; i < list_length; i++) { + struct ip_and_mask *entry = list[i]; + int testbits; + char *raw_address1, *raw_address2; + + debug("checking acl entry %d", i); + + if (test->sa_family != entry->ip.family) + continue; + + if (test->sa_family == AF_INET) { + raw_address1 = (char*) + &((struct sockaddr_in*) test)->sin_addr; + raw_address2 = (char*) &entry->ip.v4.sin_addr; + } + else if (test->sa_family == AF_INET6) { + raw_address1 = (char*) + &((struct sockaddr_in6*) test)->sin6_addr; + raw_address2 = (char*) &entry->ip.v6.sin6_addr; + } + + for (testbits = entry->mask; testbits > 0; testbits -= 8) { + debug("testbits=%d, c1=%d, c2=%d", testbits, raw_address1[0], raw_address2[0]); + if (testbits >= 8) { + if (raw_address1[0] != raw_address2[0]) + goto no_match; + } + else { + if ((raw_address1[0] & testmasks[testbits%8]) != + (raw_address2[0] & testmasks[testbits%8]) ) + goto no_match; + } + + raw_address1++; + raw_address2++; + } + + return 1; + + no_match: ; + debug("no match"); + } + + return 0; +} + +void serve_open_socket(struct mode_serve_params* params) +{ + params->server = socket(PF_INET, SOCK_STREAM, 0); + + SERVER_ERROR_ON_FAILURE(params->server, + "Couldn't create server socket"); + + SERVER_ERROR_ON_FAILURE( + bind(params->server, ¶ms->bind_to.generic, + sizeof(params->bind_to.generic)), + "Couldn't bind server to IP address" + ); + + SERVER_ERROR_ON_FAILURE( + listen(params->server, params->tcp_backlog), + "Couldn't listen on server socket" + ); +} + +void serve_accept_loop(struct mode_serve_params* params) +{ + while (1) { + pthread_t client_thread; + struct sockaddr client_address; + struct client_params* client_params; + socklen_t socket_length=0; + + int client_socket = accept(params->server, &client_address, + &socket_length); + + SERVER_ERROR_ON_FAILURE(client_socket, "accept() failed"); + + if (params->acl && + !is_included_in_acl(params->acl_entries, params->acl, &client_address)) { + write(client_socket, "Access control error", 20); + close(client_socket); + continue; + } + + client_params = xmalloc(sizeof(struct client_params)); + client_params->socket = client_socket; + client_params->filename = params->filename; + client_params->block_allocation_map = + params->block_allocation_map; + client_params->block_allocation_map_lock = + params->block_allocation_map_lock; + + client_thread = pthread_create(&client_thread, NULL, + client_serve, client_params); + SERVER_ERROR_ON_FAILURE(client_thread, + "Failed to create client thread"); + /* FIXME: keep track of them? */ + /* FIXME: maybe shouldn't be fatal? */ + } +} + +void do_serve(struct mode_serve_params* params) +{ + serve_open_socket(params); + serve_accept_loop(params); +} + diff --git a/util.c b/util.c new file mode 100644 index 0000000..fd6a1ac --- /dev/null +++ b/util.c @@ -0,0 +1,58 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "util.h" + +static pthread_t main_thread; + +void error_init() +{ + main_thread = pthread_self(); +} + +void error(int consult_errno, int close_socket, const char* format, ...) +{ + va_list argptr; + + fprintf(stderr, "*** "); + + va_start(argptr, format); + vfprintf(stderr, format, argptr); + va_end(argptr); + + if (consult_errno) { + fprintf(stderr, " (errno=%d, %s)", errno, strerror(errno)); + } + + if (close_socket) + close(close_socket); + + fprintf(stderr, "\n"); + + if (pthread_equal(pthread_self(), main_thread)) + exit(1); + else + pthread_exit((void*) 1); +} + +void* xrealloc(void* ptr, size_t size) +{ + void* p = realloc(ptr, size); + if (p == NULL) + SERVER_ERROR("couldn't xrealloc %d bytes", size); + return p; +} + +void* xmalloc(size_t size) +{ + void* p = xrealloc(NULL, size); + memset(p, 0, size); + return p; +} + diff --git a/util.h b/util.h new file mode 100644 index 0000000..b01f633 --- /dev/null +++ b/util.h @@ -0,0 +1,33 @@ +#ifndef __UTIL_H +#define __UTIL_H + +#include +#include + +void error_init(); + +void error(int consult_errno, int close_socket, const char* format, ...); + +void* xrealloc(void* ptr, size_t size); + +void* xmalloc(size_t size); + +#ifndef DEBUG +# define debug(msg, ...) +#else +# include +# define debug(msg, ...) fprintf(stderr, "%08x %4d: " msg "\n" , \ + (int) pthread_self(), (int) clock(), ##__VA_ARGS__) +#endif + +#define CLIENT_ERROR(msg, ...) \ + error(0, client->socket, msg, ##__VA_ARGS__) +#define CLIENT_ERROR_ON_FAILURE(test, msg, ...) \ + if (test < 0) { error(1, client->socket, msg, ##__VA_ARGS__); } +#define SERVER_ERROR(msg, ...) \ + error(0, 0, msg, ##__VA_ARGS__) +#define SERVER_ERROR_ON_FAILURE(test, msg, ...) \ + if (test < 0) { error(1, 0, msg, ##__VA_ARGS__); } + +#endif +