diff --git a/Rakefile b/Rakefile
index b9c496f..a132933 100644
--- a/Rakefile
+++ b/Rakefile
@@ -1,5 +1,7 @@
 DEBUG = true
 
+SOURCES = %w( flexnbd ioutil readwrite serve util )
+OBJECTS = SOURCES.map { |s| "#{s}.o" }
 LIBS = %w( pthread )
 CCFLAGS = %w( -Wall )
 LDFLAGS = []
@@ -12,10 +14,11 @@ end
 
 rule 'default' => 'flexnbd'
 
-rule 'flexnbd' => 'flexnbd.o' do |t|
+rule 'flexnbd' => OBJECTS do |t|
   sh "gcc #{LDFLAGS.join(' ')} "+
     LIBS.map { |l| "-l#{l}" }.join(" ")+
-    " -o #{t.name} #{t.source}"
+    " -o #{t.name} "+
+    t.sources.join(" ")
 end
 
 rule '.o' => '.c' do |t|
@@ -23,5 +26,5 @@ rule '.o' => '.c' do |t|
 end
 
 rule 'clean' do
-  sh "rm -f flexnbd.o flexnbd"
+  sh "rm -f flexnbd "+OBJECTS.join(" ")
 end
diff --git a/bitset.h b/bitset.h
new file mode 100644
index 0000000..d136427
--- /dev/null
+++ b/bitset.h
@@ -0,0 +1,66 @@
+#ifndef __BITSET_H
+#define __BITSET_H
+
+#include <string.h>
+
+/*
+ * Minimal bitmap helpers. A bitmap is a caller-owned char array in which
+ * bit idx is stored in byte idx/8 at bit position idx%8 (LSB first).
+ * No bounds checking is done; indices must be non-negative.
+ */
+
+/* Returns a byte with only bit (num%8) set. */
+static inline char char_with_bit_set(int num) {
+	return 1<<(num%8);
+}
+/* Returns 1 if bit idx of b is set, 0 otherwise. */
+static inline int bit_is_set(char* b, int idx) {
+	return (b[idx/8] & char_with_bit_set(idx)) != 0;
+}
+/* Returns 1 if bit idx of b is clear, 0 otherwise. */
+static inline int bit_is_clear(char* b, int idx) {
+	return !bit_is_set(b, idx);
+}
+/* Sets bit idx of b, leaving every other bit untouched. */
+static inline void bit_set(char* b, int idx) {
+	/* OR, not AND: AND-ing with a one-bit mask can never turn a bit on
+	 * and wipes the other seven bits of the byte. */
+	b[idx/8] |= char_with_bit_set(idx);
+}
+/* Clears bit idx of b, leaving every other bit untouched. */
+static inline void bit_clear(char* b, int idx) {
+	b[idx/8] &= ~char_with_bit_set(idx);
+}
+/* Sets len consecutive bits of b, starting at bit from. */
+static inline void bit_set_range(char* b, int from, int len) {
+	/* leading bits up to the next byte boundary */
+	for (; from%8 != 0 && len > 0; len--)
+		bit_set(b, from++);
+	/* whole bytes in one go, then skip past them */
+	if (len >= 8) {
+		memset(b+(from/8), 255, len/8);
+		from += (len/8)*8;
+		len %= 8;
+	}
+	/* trailing bits that do not fill a byte */
+	for (; len > 0; len--)
+		bit_set(b, from++);
+}
+/* Clears len consecutive bits of b, starting at bit from. */
+static inline void bit_clear_range(char* b, int from, int len) {
+	/* leading bits up to the next byte boundary */
+	for (; from%8 != 0 && len > 0; len--)
+		bit_clear(b, from++);
+	/* whole bytes in one go, then skip past them */
+	if (len >= 8) {
+		memset(b+(from/8), 0, len/8);
+		from += (len/8)*8;
+		len %= 8;
+	}
+	/* trailing bits that do not fill a byte */
+	for (; len > 0; len--)
+		bit_clear(b, from++);
+}
+
+#endif
+
diff --git a/flexnbd.c b/flexnbd.c
index 59852d5..c6a4574 100644
--- a/flexnbd.c
+++ b/flexnbd.c
@@ -1,57 +1,15 @@
-#define _LARGEFILE64_SOURCE
-#define _GNU_SOURCE
+#include "params.h"
+#include "util.h"
#include #include #include -#include -#include -#include -#include -#include -#include -#include #include -#include -#include -#include -#include +#include +#include +#include -#include -#include -#include - -/* http://linux.derkeiler.com/Mailing-Lists/Kernel/2003-09/2332.html */ -#define INIT_PASSWD "NBDMAGIC" -#define INIT_MAGIC 0x0000420281861253 -#define REQUEST_MAGIC 0x25609513 -#define REPLY_MAGIC 0x67446698 -#define REQUEST_READ 0 -#define REQUEST_WRITE 1 -#define REQUEST_DISCONNECT 2 - -#include -struct nbd_init { - char passwd[8]; - __be64 magic; - __be64 size; - char reserved[128]; -}; - -struct nbd_request { - __be32 magic; - __be32 type; /* == READ || == WRITE */ - char handle[8]; - __be64 from; - __be32 len; -} __attribute__((packed)); - -struct nbd_reply { - __be32 magic; - __be32 error; /* 0 = ok, else error */ - char handle[8]; /* handle you got from request */ -}; void syntax() { @@ -65,638 +23,6 @@ void syntax() exit(1); } -static pthread_t server_thread_id; - -void error(int consult_errno, int close_socket, const char* format, ...) -{ - va_list argptr; - - fprintf(stderr, "*** "); - - va_start(argptr, format); - vfprintf(stderr, format, argptr); - va_end(argptr); - - if (consult_errno) { - fprintf(stderr, " (errno=%d, %s)", errno, strerror(errno)); - } - - if (close_socket) - close(close_socket); - - fprintf(stderr, "\n"); - - if (pthread_equal(pthread_self(), server_thread_id)) - exit(1); - else - pthread_exit((void*) 1); -} - -#ifndef DEBUG -# define debug(msg, ...) -#else -# include -# define debug(msg, ...) fprintf(stderr, "%08x %4d: " msg "\n" , \ - (int) pthread_self(), (int) clock(), ##__VA_ARGS__) -#endif - -#define CLIENT_ERROR(msg, ...) \ - error(0, client->socket, msg, ##__VA_ARGS__) -#define CLIENT_ERROR_ON_FAILURE(test, msg, ...) \ - if (test < 0) { error(1, client->socket, msg, ##__VA_ARGS__); } -#define SERVER_ERROR(msg, ...) \ - error(0, 0, msg, ##__VA_ARGS__) -#define SERVER_ERROR_ON_FAILURE(test, msg, ...) 
\ - if (test < 0) { error(1, 0, msg, ##__VA_ARGS__); } - -void* xrealloc(void* ptr, size_t size) -{ - void* p = realloc(ptr, size); - if (p == NULL) - SERVER_ERROR("couldn't xrealloc %d bytes", size); - return p; -} - -void* xmalloc(size_t size) -{ - void* p = xrealloc(NULL, size); - memset(p, 0, size); - return p; -} - -union mysockaddr { - unsigned short family; - struct sockaddr generic; - struct sockaddr_in v4; - struct sockaddr_in6 v6; -}; - -struct ip_and_mask { - union mysockaddr ip; - int mask; -}; - -struct mode_serve_params { - union mysockaddr bind_to; - int acl_entries; - struct ip_and_mask** acl; - char* filename; - int tcp_backlog; - - int server; - int threads; - - pthread_mutex_t block_allocation_map_lock; - char* block_allocation_map; -}; - -struct mode_readwrite_params { - union mysockaddr connect_to; - off64_t from; - off64_t len; - int data_fd; - int client; -}; - -struct client_params { - int socket; - char* filename; - - int fileno; - off64_t size; - char* mapped; - - pthread_mutex_t block_allocation_map_lock; - char* block_allocation_map; -}}; - -union mode_params { - struct mode_serve_params serve; - struct mode_readwrite_params readwrite; -}; - -static inline int char_with_bit_set(int num) { - return 1<<(num%8); -} -static inline int bit_is_set(char* b, int idx) { - return (b[idx/8] & char_with_bit_set(idx)) != 0; -} -static inline int bit_is_clear(char* b, int idx) { - return !bit_is_set(b, idx); -} -static inline void bit_set(char* b, int idx) { - b[idx/8] &= char_with_bit_set(idx); -} -static inline void bit_clear(char* b, int idx) { - b[idx/8] &= ~char_with_bit_set(idx); -} -static inline void bit_set_range(char* b, int from, int len) { - for (; b%8 != 0 && len > 0; len--) - bit_set(b, from++, 1); - if (len >= 8) - memset(b+(from/8), 255, len/8); - for (; len > 0; len--) - bit_set(b, from++); -} -static inline void bit_clear_range(char* b, int from, int len) { - for (; b%8 != 0 && len > 0; len--) - bit_clear(b, from++, 1); - if (len >= 
8) - memset(b+(from/8), 0, len/8); - for (; len > 0; len--) - bit_clear(b, from++); -} - - -char* build_allocation_map(int fd, off64_t size, int resolution) -{ - char *allocation_map = xmalloc((size+resolution)/resolution); - struct fiemap *fiemap; - - fiemap = (struct fiemap*) xmalloc(sizeof(struct fiemap)); - - fiemap->fm_start = from; - fiemap->fm_length = len; - fiemap->fm_flags = 0; - fiemap->fm_extent_count = 0; - fiemap->fm_mapped_extents = 0; - - /* Find out how many extents there are */ - if (ioctl(fd, FS_IOC_FIEMAP, fiemap) < 0) - return NULL; - - /* Resize fiemap to allow us to read in the extents */ - fiemap = (struct fiemap*)xrealloc( - fiemap, - sizeof(struct fiemap) + ( - sizeof(struct fiemap_extent) * - fiemap->fm_mapped_extents - ) - ); - - fiemap->fm_extent_count = fiemap->fm_mapped_extents; - fiemap->fm_mapped_extents = 0; - - if (ioctl(fd, FS_IOC_FIEMAP, fiemap) < -1) - return NULL; - - for (i=0;ifm_mapped_extents;i++) - bit_set_range( - allocation_map, - fiemap->fm_extents[i].fe_logical / resolution, - fiemap->fm_extents[i].fe_length / resolution - ); - - free(fiemap); - - return allocation_map; -} - - -int writeloop(int filedes, const void *buffer, size_t size) -{ - size_t written=0; - while (written < size) { - size_t result = write(filedes, buffer+written, size-written); - if (result == -1) - return -1; - written += result; - } - return 0; -} - -int readloop(int filedes, void *buffer, size_t size) -{ - size_t readden=0; - while (readden < size) { - size_t result = read(filedes, buffer+readden, size-readden); - if (result == 0 /* EOF */ || result == -1 /* error */) - return -1; - readden += result; - } - return 0; -} - -int sendfileloop(int out_fd, int in_fd, off64_t *offset, size_t count) -{ - size_t sent=0; - while (sent < count) { - size_t result = sendfile64(out_fd, in_fd, offset+sent, count-sent); - if (result == -1) - return -1; - sent += result; - } - return 0; -} - -int splice_via_pipe_loop(int fd_in, int fd_out, size_t len) -{ - int 
pipefd[2]; - size_t spliced=0; - - if (pipe(pipefd) == -1) - return -1; - - while (spliced < len) { - size_t r1,r2; - r1 = splice(fd_in, NULL, pipefd[1], NULL, len-spliced, 0); - if (r1 <= 0) - break; - r2 = splice(pipefd[0], NULL, fd_out, NULL, r1, 0); - if (r1 != r2) - break; - spliced += r1; - } - close(pipefd[0]); - close(pipefd[1]); - - return spliced < len ? -1 : 0; -} - -int client_serve_request(struct client_params* client) -{ - off64_t offset; - struct nbd_request request; - struct nbd_reply reply; - struct unallocated_block** unallocated; - - if (readloop(client->socket, &request, sizeof(request)) == -1) { - if (errno == 0) { - debug("EOF reading request"); - return 1; /* neat point to close the socket */ - } - else { - CLIENT_ERROR_ON_FAILURE(-1, "Error reading request"); - } - } - - reply.magic = htobe32(REPLY_MAGIC); - reply.error = htobe32(0); - memcpy(reply.handle, request.handle, 8); - - debug("request type %d", be32toh(request.type)); - - if (be32toh(request.magic) != REQUEST_MAGIC) - CLIENT_ERROR("Bad magic %08x", be32toh(request.magic)); - - switch (be32toh(request.type)) - { - case REQUEST_READ: - case REQUEST_WRITE: - /* check it's not out of range */ - if (be64toh(request.from) < 0 || - be64toh(request.from)+be32toh(request.len) > client->size) { - debug("request read %ld+%d out of range", - be64toh(request.from), - be32toh(request.len) - ); - reply.error = htobe32(1); - write(client->socket, &reply, sizeof(reply)); - return 0; - } - break; - - case REQUEST_DISCONNECT: - debug("request disconnect"); - return 1; - - default: - CLIENT_ERROR("Unknown request %08x", be32toh(request.type)); - } - - switch (be32toh(request.type)) - { - case REQUEST_READ: - debug("request read %ld+%d", be64toh(request.from), be32toh(request.len)); - write(client->socket, &reply, sizeof(reply)); - - offset = be64toh(request.from); - CLIENT_ERROR_ON_FAILURE( - sendfileloop( - client->socket, - client->fileno, - &offset, - be32toh(request.len) - ), - "sendfile failed 
from=%ld, len=%d", - offset, - be32toh(request.len) - ); - break; - - case REQUEST_WRITE: - debug("request write %ld+%d", be64toh(request.from), be32toh(request.len)); -#ifdef _LINUX_FIEMAP_H - unallocated = read_unallocated_blocks( - client->fileno, - be64toh(request.from), - be32toh(request.len) - ); - if (unallocated == NULL) - CLIENT_ERROR("Couldn't read unallocated blocks list"); - - CLIENT_ERROR_ON_FAILURE( - read_from_socket_avoiding_holes( - client->socket, - ); - free(fiemap); -#else - CLIENT_ERROR_ON_FAILURE( - readloop( - client->socket, - client->mapped + be64toh(request.from), - be32toh(request.len) - ), - "read failed from=%ld, len=%d", - be64toh(request.from), - be32toh(request.len) - ); -#endif - write(client->socket, &reply, sizeof(reply)); - - break; - } - return 0; -} - -void client_open_file(struct client_params* client) -{ - client->fileno = open(client->filename, O_RDWR|O_DIRECT|O_SYNC); - CLIENT_ERROR_ON_FAILURE(client->fileno, "Couldn't open %s", - client->filename); - client->size = lseek64(client->fileno, 0, SEEK_END); - CLIENT_ERROR_ON_FAILURE(client->fileno, "Couldn't seek to end of %s", - client->filename); - client->mapped = mmap64(NULL, client->size, PROT_READ|PROT_WRITE, - MAP_SHARED, client->fileno, 0); - CLIENT_ERROR_ON_FAILURE((long) client->mapped, "Couldn't map file %s", - client->filename); - debug("opened %s size %ld on fd %d @ %p", client->filename, client->size, client->fileno, client->mapped); -} - -void client_send_hello(struct client_params* client) -{ - struct nbd_init init; - - memcpy(init.passwd, INIT_PASSWD, sizeof(INIT_PASSWD)); - init.magic = htobe64(INIT_MAGIC); - init.size = htobe64(client->size); - memset(init.reserved, 0, 128); - CLIENT_ERROR_ON_FAILURE( - writeloop(client->socket, &init, sizeof(init)), - "Couldn't send hello" - ); -} - -void* client_serve(void* client_uncast) -{ - struct client_params* client = (struct client_params*) client_uncast; - - client_open_file(client); - client_send_hello(client); - - 
while (client_serve_request(client) == 0) - ; - - CLIENT_ERROR_ON_FAILURE( - close(client->socket), - "Couldn't close socket %d", - client->socket - ); - - free(client); - return NULL; -} - -static int testmasks[9] = { 0,128,192,224,240,248,252,254,255 }; - -int is_included_in_acl(int list_length, struct ip_and_mask** list, struct sockaddr* test) -{ - int i; - - for (i=0; i < list_length; i++) { - struct ip_and_mask *entry = list[i]; - int testbits; - char *raw_address1, *raw_address2; - - debug("checking acl entry %d", i); - - if (test->sa_family != entry->ip.family) - continue; - - if (test->sa_family == AF_INET) { - raw_address1 = (char*) - &((struct sockaddr_in*) test)->sin_addr; - raw_address2 = (char*) &entry->ip.v4.sin_addr; - } - else if (test->sa_family == AF_INET6) { - raw_address1 = (char*) - &((struct sockaddr_in6*) test)->sin6_addr; - raw_address2 = (char*) &entry->ip.v6.sin6_addr; - } - - for (testbits = entry->mask; testbits > 0; testbits -= 8) { - debug("testbits=%d, c1=%d, c2=%d", testbits, raw_address1[0], raw_address2[0]); - if (testbits >= 8) { - if (raw_address1[0] != raw_address2[0]) - goto no_match; - } - else { - if ((raw_address1[0] & testmasks[testbits%8]) != - (raw_address2[0] & testmasks[testbits%8]) ) - goto no_match; - } - - raw_address1++; - raw_address2++; - } - - return 1; - - no_match: ; - debug("no match"); - } - - return 0; -} - -void serve_open_socket(struct mode_serve_params* params) -{ - params->server = socket(PF_INET, SOCK_STREAM, 0); - - SERVER_ERROR_ON_FAILURE(params->server, - "Couldn't create server socket"); - - SERVER_ERROR_ON_FAILURE( - bind(params->server, ¶ms->bind_to.generic, - sizeof(params->bind_to.generic)), - "Couldn't bind server to IP address" - ); - - SERVER_ERROR_ON_FAILURE( - listen(params->server, params->tcp_backlog), - "Couldn't listen on server socket" - ); -} - -void serve_accept_loop(struct mode_serve_params* params) -{ - while (1) { - pthread_t client_thread; - struct sockaddr client_address; - 
struct client_params* client_params; - socklen_t socket_length=0; - - int client_socket = accept(params->server, &client_address, - &socket_length); - - SERVER_ERROR_ON_FAILURE(client_socket, "accept() failed"); - - if (params->acl && - !is_included_in_acl(params->acl_entries, params->acl, &client_address)) { - write(client_socket, "Access control error", 20); - close(client_socket); - continue; - } - - client_params = xmalloc(sizeof(struct client_params)); - client_params->socket = client_socket; - client_params->filename = params->filename; - client_params->block_allocation_map = - params->block_allocation_map; - client_params->block_allocation_map_lock = - params->block_allocation_map_lock; - - client_thread = pthread_create(&client_thread, NULL, - client_serve, client_params); - SERVER_ERROR_ON_FAILURE(client_thread, - "Failed to create client thread"); - /* FIXME: keep track of them? */ - /* FIXME: maybe shouldn't be fatal? */ - } -} - -void do_serve(struct mode_serve_params* params) -{ - serve_open_socket(params); - serve_accept_loop(params); -} - -int socket_connect(struct sockaddr* to) -{ - int fd = socket(PF_INET, SOCK_STREAM, 0); - SERVER_ERROR_ON_FAILURE(fd, "Couldn't create client socket"); - SERVER_ERROR_ON_FAILURE(connect(fd, to, sizeof(*to)), - "connect failed"); - return fd; -} - -off64_t socket_nbd_read_hello(int fd) -{ - struct nbd_init init; - SERVER_ERROR_ON_FAILURE(readloop(fd, &init, sizeof(init)), - "Couldn't read init"); - if (strncmp(init.passwd, INIT_PASSWD, 8) != 0) - SERVER_ERROR("wrong passwd"); - if (be64toh(init.magic) != INIT_MAGIC) - SERVER_ERROR("wrong magic (%x)", be64toh(init.magic)); - return be64toh(init.size); -} - -void fill_request(struct nbd_request *request, int type, int from, int len) -{ - request->magic = htobe32(REQUEST_MAGIC); - request->type = htobe32(type); - ((int*) request->handle)[0] = rand(); - ((int*) request->handle)[1] = rand(); - request->from = htobe64(from); - request->len = htobe32(len); -} - -void 
read_reply(int fd, struct nbd_request *request, struct nbd_reply *reply) -{ - SERVER_ERROR_ON_FAILURE(readloop(fd, reply, sizeof(*reply)), - "Couldn't read reply"); - if (be32toh(reply->magic) != REPLY_MAGIC) - SERVER_ERROR("Reply magic incorrect (%p)", reply->magic); - if (be32toh(reply->error) != 0) - SERVER_ERROR("Server replied with error %d", reply->error); - if (strncmp(request->handle, reply->handle, 8) != 0) - SERVER_ERROR("Did not reply with correct handle"); -} - -void socket_nbd_read(int fd, off64_t from, int len, int out_fd, void* out_buf) -{ - struct nbd_request request; - struct nbd_reply reply; - - fill_request(&request, REQUEST_READ, from, len); - SERVER_ERROR_ON_FAILURE(writeloop(fd, &request, sizeof(request)), - "Couldn't write request"); - read_reply(fd, &request, &reply); - - if (out_buf) { - SERVER_ERROR_ON_FAILURE(readloop(fd, out_buf, len), - "Read failed"); - } - else { - SERVER_ERROR_ON_FAILURE( - splice_via_pipe_loop(fd, out_fd, len), - "Splice failed" - ); - } -} - -void socket_nbd_write(int fd, off64_t from, int len, int in_fd, void* in_buf) -{ - struct nbd_request request; - struct nbd_reply reply; - - fill_request(&request, REQUEST_WRITE, from, len); - SERVER_ERROR_ON_FAILURE(writeloop(fd, &request, sizeof(request)), - "Couldn't write request"); - - if (in_buf) { - SERVER_ERROR_ON_FAILURE(writeloop(fd, in_buf, len), - "Write failed"); - } - else { - SERVER_ERROR_ON_FAILURE( - splice_via_pipe_loop(in_fd, fd, len), - "Splice failed" - ); - } - - read_reply(fd, &request, &reply); -} - -#define CHECK_RANGE(error_type) { \ - off64_t size = socket_nbd_read_hello(params->client); \ - if (params->from < 0 || (params->from + params->len) >= size) \ - SERVER_ERROR(error_type \ - " request %d+%d is out of range given size %d", \ - params->from, params->len, size\ - ); \ -} - -void do_read(struct mode_readwrite_params* params) -{ - params->client = socket_connect(¶ms->connect_to.generic); - CHECK_RANGE("read"); - socket_nbd_read(params->client, 
params->from, params->len, - params->data_fd, NULL); - close(params->client); -} - -void do_write(struct mode_readwrite_params* params) -{ - params->client = socket_connect(¶ms->connect_to.generic); - CHECK_RANGE("write"); - socket_nbd_write(params->client, params->from, params->len, - params->data_fd, NULL); - close(params->client); -} - #define IS_IP_VALID_CHAR(x) ( ((x) >= '0' && (x) <= '9' ) || \ ((x) >= 'a' && (x) <= 'f') || \ ((x) >= 'A' && (x) <= 'F' ) || \ @@ -862,6 +188,10 @@ void params_readwrite( } } +void do_serve(struct mode_serve_params* params); +void do_read(struct mode_readwrite_params* params); +void do_write(struct mode_readwrite_params* params); + void mode(char* mode, int argc, char **argv) { union mode_params params; @@ -901,7 +231,7 @@ void mode(char* mode, int argc, char **argv) int main(int argc, char** argv) { - server_thread_id = pthread_self(); + error_init(); if (argc < 2) syntax(); diff --git a/ioutil.c b/ioutil.c new file mode 100644 index 0000000..761b74d --- /dev/null +++ b/ioutil.c @@ -0,0 +1,123 @@ +#define _LARGEFILE64_SOURCE +#define _GNU_SOURCE + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "util.h" +#include "bitset.h" + +char* build_allocation_map(int fd, off64_t size, int resolution) +{ + int i; + char *allocation_map = xmalloc((size+resolution)/resolution); + struct fiemap *fiemap; + + fiemap = (struct fiemap*) xmalloc(sizeof(struct fiemap)); + + fiemap->fm_start = 0; + fiemap->fm_length = size; + fiemap->fm_flags = 0; + fiemap->fm_extent_count = 0; + fiemap->fm_mapped_extents = 0; + + /* Find out how many extents there are */ + if (ioctl(fd, FS_IOC_FIEMAP, fiemap) < 0) + return NULL; + + /* Resize fiemap to allow us to read in the extents */ + fiemap = (struct fiemap*)xrealloc( + fiemap, + sizeof(struct fiemap) + ( + sizeof(struct fiemap_extent) * + fiemap->fm_mapped_extents + ) + ); + + fiemap->fm_extent_count = fiemap->fm_mapped_extents; + fiemap->fm_mapped_extents 
= 0; + + if (ioctl(fd, FS_IOC_FIEMAP, fiemap) < -1) + return NULL; + + for (i=0;ifm_mapped_extents;i++) + bit_set_range( + allocation_map, + fiemap->fm_extents[i].fe_logical / resolution, + fiemap->fm_extents[i].fe_length / resolution + ); + + free(fiemap); + + return allocation_map; +} + + +int writeloop(int filedes, const void *buffer, size_t size) +{ + size_t written=0; + while (written < size) { + size_t result = write(filedes, buffer+written, size-written); + if (result == -1) + return -1; + written += result; + } + return 0; +} + +int readloop(int filedes, void *buffer, size_t size) +{ + size_t readden=0; + while (readden < size) { + size_t result = read(filedes, buffer+readden, size-readden); + if (result == 0 /* EOF */ || result == -1 /* error */) + return -1; + readden += result; + } + return 0; +} + +int sendfileloop(int out_fd, int in_fd, off64_t *offset, size_t count) +{ + size_t sent=0; + while (sent < count) { + size_t result = sendfile64(out_fd, in_fd, offset+sent, count-sent); + if (result == -1) + return -1; + sent += result; + } + return 0; +} + +int splice_via_pipe_loop(int fd_in, int fd_out, size_t len) +{ + int pipefd[2]; + size_t spliced=0; + + if (pipe(pipefd) == -1) + return -1; + + while (spliced < len) { + size_t r1,r2; + r1 = splice(fd_in, NULL, pipefd[1], NULL, len-spliced, 0); + if (r1 <= 0) + break; + r2 = splice(pipefd[0], NULL, fd_out, NULL, r1, 0); + if (r1 != r2) + break; + spliced += r1; + } + close(pipefd[0]); + close(pipefd[1]); + + return spliced < len ? 
-1 : 0; +} + + diff --git a/ioutil.h b/ioutil.h new file mode 100644 index 0000000..055a057 --- /dev/null +++ b/ioutil.h @@ -0,0 +1,14 @@ +#ifndef __IOUTIL_H +#define __IOUTIL_H + + +#include "params.h" + +char* build_allocation_map(int fd, off64_t size, int resolution); +int writeloop(int filedes, const void *buffer, size_t size); +int readloop(int filedes, void *buffer, size_t size); +int sendfileloop(int out_fd, int in_fd, off64_t *offset, size_t count); +int splice_via_pipe_loop(int fd_in, int fd_out, size_t len); + +#endif + diff --git a/nbdtypes.h b/nbdtypes.h new file mode 100644 index 0000000..499a4a1 --- /dev/null +++ b/nbdtypes.h @@ -0,0 +1,36 @@ +#ifndef __NBDTYPES_H +#define __NBDTYPES_H + +/* http://linux.derkeiler.com/Mailing-Lists/Kernel/2003-09/2332.html */ +#define INIT_PASSWD "NBDMAGIC" +#define INIT_MAGIC 0x0000420281861253 +#define REQUEST_MAGIC 0x25609513 +#define REPLY_MAGIC 0x67446698 +#define REQUEST_READ 0 +#define REQUEST_WRITE 1 +#define REQUEST_DISCONNECT 2 + +#include +struct nbd_init { + char passwd[8]; + __be64 magic; + __be64 size; + char reserved[128]; +}; + +struct nbd_request { + __be32 magic; + __be32 type; /* == READ || == WRITE */ + char handle[8]; + __be64 from; + __be32 len; +} __attribute__((packed)); + +struct nbd_reply { + __be32 magic; + __be32 error; /* 0 = ok, else error */ + char handle[8]; /* handle you got from request */ +}; + +#endif + diff --git a/params.h b/params.h new file mode 100644 index 0000000..ac396cc --- /dev/null +++ b/params.h @@ -0,0 +1,64 @@ +#ifndef __PARAMS_H +#define __PARAMS_H + +#define _GNU_SOURCE +#define _LARGEFILE64_SOURCE + +#include +#include +#include +#include + +union mysockaddr { + unsigned short family; + struct sockaddr generic; + struct sockaddr_in v4; + struct sockaddr_in6 v6; +}; + +struct ip_and_mask { + union mysockaddr ip; + int mask; +}; + +struct mode_serve_params { + union mysockaddr bind_to; + int acl_entries; + struct ip_and_mask** acl; + char* filename; + int tcp_backlog; 
+ + int server; + int threads; + + pthread_mutex_t block_allocation_map_lock; + char* block_allocation_map; +}; + +struct mode_readwrite_params { + union mysockaddr connect_to; + off64_t from; + off64_t len; + int data_fd; + int client; +}; + +struct client_params { + int socket; + char* filename; + + int fileno; + off64_t size; + char* mapped; + + pthread_mutex_t block_allocation_map_lock; + char* block_allocation_map; +}; + +union mode_params { + struct mode_serve_params serve; + struct mode_readwrite_params readwrite; +}; + +#endif + diff --git a/readwrite.c b/readwrite.c new file mode 100644 index 0000000..3f571d3 --- /dev/null +++ b/readwrite.c @@ -0,0 +1,124 @@ +#include "nbdtypes.h" +#include "ioutil.h" +#include "util.h" +#include "params.h" + +#include +#include +#include + +int socket_connect(struct sockaddr* to) +{ + int fd = socket(PF_INET, SOCK_STREAM, 0); + SERVER_ERROR_ON_FAILURE(fd, "Couldn't create client socket"); + SERVER_ERROR_ON_FAILURE(connect(fd, to, sizeof(*to)), + "connect failed"); + return fd; +} + +off64_t socket_nbd_read_hello(int fd) +{ + struct nbd_init init; + SERVER_ERROR_ON_FAILURE(readloop(fd, &init, sizeof(init)), + "Couldn't read init"); + if (strncmp(init.passwd, INIT_PASSWD, 8) != 0) + SERVER_ERROR("wrong passwd"); + if (be64toh(init.magic) != INIT_MAGIC) + SERVER_ERROR("wrong magic (%x)", be64toh(init.magic)); + return be64toh(init.size); +} + +void fill_request(struct nbd_request *request, int type, int from, int len) +{ + request->magic = htobe32(REQUEST_MAGIC); + request->type = htobe32(type); + ((int*) request->handle)[0] = rand(); + ((int*) request->handle)[1] = rand(); + request->from = htobe64(from); + request->len = htobe32(len); +} + +void read_reply(int fd, struct nbd_request *request, struct nbd_reply *reply) +{ + SERVER_ERROR_ON_FAILURE(readloop(fd, reply, sizeof(*reply)), + "Couldn't read reply"); + if (be32toh(reply->magic) != REPLY_MAGIC) + SERVER_ERROR("Reply magic incorrect (%p)", reply->magic); + if 
(be32toh(reply->error) != 0) + SERVER_ERROR("Server replied with error %d", reply->error); + if (strncmp(request->handle, reply->handle, 8) != 0) + SERVER_ERROR("Did not reply with correct handle"); +} + +void socket_nbd_read(int fd, off64_t from, int len, int out_fd, void* out_buf) +{ + struct nbd_request request; + struct nbd_reply reply; + + fill_request(&request, REQUEST_READ, from, len); + SERVER_ERROR_ON_FAILURE(writeloop(fd, &request, sizeof(request)), + "Couldn't write request"); + read_reply(fd, &request, &reply); + + if (out_buf) { + SERVER_ERROR_ON_FAILURE(readloop(fd, out_buf, len), + "Read failed"); + } + else { + SERVER_ERROR_ON_FAILURE( + splice_via_pipe_loop(fd, out_fd, len), + "Splice failed" + ); + } +} + +void socket_nbd_write(int fd, off64_t from, int len, int in_fd, void* in_buf) +{ + struct nbd_request request; + struct nbd_reply reply; + + fill_request(&request, REQUEST_WRITE, from, len); + SERVER_ERROR_ON_FAILURE(writeloop(fd, &request, sizeof(request)), + "Couldn't write request"); + + if (in_buf) { + SERVER_ERROR_ON_FAILURE(writeloop(fd, in_buf, len), + "Write failed"); + } + else { + SERVER_ERROR_ON_FAILURE( + splice_via_pipe_loop(in_fd, fd, len), + "Splice failed" + ); + } + + read_reply(fd, &request, &reply); +} + +#define CHECK_RANGE(error_type) { \ + off64_t size = socket_nbd_read_hello(params->client); \ + if (params->from < 0 || (params->from + params->len) >= size) \ + SERVER_ERROR(error_type \ + " request %d+%d is out of range given size %d", \ + params->from, params->len, size\ + ); \ +} + +void do_read(struct mode_readwrite_params* params) +{ + params->client = socket_connect(¶ms->connect_to.generic); + CHECK_RANGE("read"); + socket_nbd_read(params->client, params->from, params->len, + params->data_fd, NULL); + close(params->client); +} + +void do_write(struct mode_readwrite_params* params) +{ + params->client = socket_connect(¶ms->connect_to.generic); + CHECK_RANGE("write"); + socket_nbd_write(params->client, params->from, 
params->len, + params->data_fd, NULL); + close(params->client); +} + diff --git a/serve.c b/serve.c new file mode 100644 index 0000000..10d2e61 --- /dev/null +++ b/serve.c @@ -0,0 +1,283 @@ +#include "params.h" +#include "nbdtypes.h" +#include "ioutil.h" +#include "util.h" + +#include +#include +#include +#include + +#include +#include +#include + +int client_serve_request(struct client_params* client) +{ + off64_t offset; + struct nbd_request request; + struct nbd_reply reply; +// struct unallocated_block** unallocated; + + if (readloop(client->socket, &request, sizeof(request)) == -1) { + if (errno == 0) { + debug("EOF reading request"); + return 1; /* neat point to close the socket */ + } + else { + CLIENT_ERROR_ON_FAILURE(-1, "Error reading request"); + } + } + + reply.magic = htobe32(REPLY_MAGIC); + reply.error = htobe32(0); + memcpy(reply.handle, request.handle, 8); + + debug("request type %d", be32toh(request.type)); + + if (be32toh(request.magic) != REQUEST_MAGIC) + CLIENT_ERROR("Bad magic %08x", be32toh(request.magic)); + + switch (be32toh(request.type)) + { + case REQUEST_READ: + case REQUEST_WRITE: + /* check it's not out of range */ + if (be64toh(request.from) < 0 || + be64toh(request.from)+be32toh(request.len) > client->size) { + debug("request read %ld+%d out of range", + be64toh(request.from), + be32toh(request.len) + ); + reply.error = htobe32(1); + write(client->socket, &reply, sizeof(reply)); + return 0; + } + break; + + case REQUEST_DISCONNECT: + debug("request disconnect"); + return 1; + + default: + CLIENT_ERROR("Unknown request %08x", be32toh(request.type)); + } + + switch (be32toh(request.type)) + { + case REQUEST_READ: + debug("request read %ld+%d", be64toh(request.from), be32toh(request.len)); + write(client->socket, &reply, sizeof(reply)); + + offset = be64toh(request.from); + CLIENT_ERROR_ON_FAILURE( + sendfileloop( + client->socket, + client->fileno, + &offset, + be32toh(request.len) + ), + "sendfile failed from=%ld, len=%d", + offset, + 
be32toh(request.len) + ); + break; + + case REQUEST_WRITE: + debug("request write %ld+%d", be64toh(request.from), be32toh(request.len)); +#ifdef _LINUX_FIEMAP_H + unallocated = read_unallocated_blocks( + client->fileno, + be64toh(request.from), + be32toh(request.len) + ); + if (unallocated == NULL) + CLIENT_ERROR("Couldn't read unallocated blocks list"); + + CLIENT_ERROR_ON_FAILURE( + read_from_socket_avoiding_holes( + client->socket, + ); + free(fiemap); +#else + CLIENT_ERROR_ON_FAILURE( + readloop( + client->socket, + client->mapped + be64toh(request.from), + be32toh(request.len) + ), + "read failed from=%ld, len=%d", + be64toh(request.from), + be32toh(request.len) + ); +#endif + write(client->socket, &reply, sizeof(reply)); + + break; + } + return 0; +} + +void client_open_file(struct client_params* client) +{ + client->fileno = open(client->filename, O_RDWR|O_DIRECT|O_SYNC); + CLIENT_ERROR_ON_FAILURE(client->fileno, "Couldn't open %s", + client->filename); + client->size = lseek64(client->fileno, 0, SEEK_END); + CLIENT_ERROR_ON_FAILURE(client->fileno, "Couldn't seek to end of %s", + client->filename); + client->mapped = mmap64(NULL, client->size, PROT_READ|PROT_WRITE, + MAP_SHARED, client->fileno, 0); + CLIENT_ERROR_ON_FAILURE((long) client->mapped, "Couldn't map file %s", + client->filename); + debug("opened %s size %ld on fd %d @ %p", client->filename, client->size, client->fileno, client->mapped); +} + +void client_send_hello(struct client_params* client) +{ + struct nbd_init init; + + memcpy(init.passwd, INIT_PASSWD, sizeof(INIT_PASSWD)); + init.magic = htobe64(INIT_MAGIC); + init.size = htobe64(client->size); + memset(init.reserved, 0, 128); + CLIENT_ERROR_ON_FAILURE( + writeloop(client->socket, &init, sizeof(init)), + "Couldn't send hello" + ); +} + +void* client_serve(void* client_uncast) +{ + struct client_params* client = (struct client_params*) client_uncast; + + client_open_file(client); + client_send_hello(client); + + while 
(client_serve_request(client) == 0) + ; + + CLIENT_ERROR_ON_FAILURE( + close(client->socket), + "Couldn't close socket %d", + client->socket + ); + + free(client); + return NULL; +} + +static int testmasks[9] = { 0,128,192,224,240,248,252,254,255 }; + +int is_included_in_acl(int list_length, struct ip_and_mask** list, struct sockaddr* test) +{ + int i; + + for (i=0; i < list_length; i++) { + struct ip_and_mask *entry = list[i]; + int testbits; + char *raw_address1, *raw_address2; + + debug("checking acl entry %d", i); + + if (test->sa_family != entry->ip.family) + continue; + + if (test->sa_family == AF_INET) { + raw_address1 = (char*) + &((struct sockaddr_in*) test)->sin_addr; + raw_address2 = (char*) &entry->ip.v4.sin_addr; + } + else if (test->sa_family == AF_INET6) { + raw_address1 = (char*) + &((struct sockaddr_in6*) test)->sin6_addr; + raw_address2 = (char*) &entry->ip.v6.sin6_addr; + } + + for (testbits = entry->mask; testbits > 0; testbits -= 8) { + debug("testbits=%d, c1=%d, c2=%d", testbits, raw_address1[0], raw_address2[0]); + if (testbits >= 8) { + if (raw_address1[0] != raw_address2[0]) + goto no_match; + } + else { + if ((raw_address1[0] & testmasks[testbits%8]) != + (raw_address2[0] & testmasks[testbits%8]) ) + goto no_match; + } + + raw_address1++; + raw_address2++; + } + + return 1; + + no_match: ; + debug("no match"); + } + + return 0; +} + +void serve_open_socket(struct mode_serve_params* params) +{ + params->server = socket(PF_INET, SOCK_STREAM, 0); + + SERVER_ERROR_ON_FAILURE(params->server, + "Couldn't create server socket"); + + SERVER_ERROR_ON_FAILURE( + bind(params->server, ¶ms->bind_to.generic, + sizeof(params->bind_to.generic)), + "Couldn't bind server to IP address" + ); + + SERVER_ERROR_ON_FAILURE( + listen(params->server, params->tcp_backlog), + "Couldn't listen on server socket" + ); +} + +void serve_accept_loop(struct mode_serve_params* params) +{ + while (1) { + pthread_t client_thread; + struct sockaddr client_address; + struct 
client_params* client_params; + socklen_t socket_length=0; + + int client_socket = accept(params->server, &client_address, + &socket_length); + + SERVER_ERROR_ON_FAILURE(client_socket, "accept() failed"); + + if (params->acl && + !is_included_in_acl(params->acl_entries, params->acl, &client_address)) { + write(client_socket, "Access control error", 20); + close(client_socket); + continue; + } + + client_params = xmalloc(sizeof(struct client_params)); + client_params->socket = client_socket; + client_params->filename = params->filename; + client_params->block_allocation_map = + params->block_allocation_map; + client_params->block_allocation_map_lock = + params->block_allocation_map_lock; + + client_thread = pthread_create(&client_thread, NULL, + client_serve, client_params); + SERVER_ERROR_ON_FAILURE(client_thread, + "Failed to create client thread"); + /* FIXME: keep track of them? */ + /* FIXME: maybe shouldn't be fatal? */ + } +} + +void do_serve(struct mode_serve_params* params) +{ + serve_open_socket(params); + serve_accept_loop(params); +} + diff --git a/util.c b/util.c new file mode 100644 index 0000000..fd6a1ac --- /dev/null +++ b/util.c @@ -0,0 +1,58 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "util.h" + +static pthread_t main_thread; + +void error_init() +{ + main_thread = pthread_self(); +} + +void error(int consult_errno, int close_socket, const char* format, ...) 
+{ + va_list argptr; + + fprintf(stderr, "*** "); + + va_start(argptr, format); + vfprintf(stderr, format, argptr); + va_end(argptr); + + if (consult_errno) { + fprintf(stderr, " (errno=%d, %s)", errno, strerror(errno)); + } + + if (close_socket) + close(close_socket); + + fprintf(stderr, "\n"); + + if (pthread_equal(pthread_self(), main_thread)) + exit(1); + else + pthread_exit((void*) 1); +} + +void* xrealloc(void* ptr, size_t size) +{ + void* p = realloc(ptr, size); + if (p == NULL) + SERVER_ERROR("couldn't xrealloc %d bytes", size); + return p; +} + +void* xmalloc(size_t size) +{ + void* p = xrealloc(NULL, size); + memset(p, 0, size); + return p; +} + diff --git a/util.h b/util.h new file mode 100644 index 0000000..b01f633 --- /dev/null +++ b/util.h @@ -0,0 +1,33 @@ +#ifndef __UTIL_H +#define __UTIL_H + +#include +#include + +void error_init(); + +void error(int consult_errno, int close_socket, const char* format, ...); + +void* xrealloc(void* ptr, size_t size); + +void* xmalloc(size_t size); + +#ifndef DEBUG +# define debug(msg, ...) +#else +# include +# define debug(msg, ...) fprintf(stderr, "%08x %4d: " msg "\n" , \ + (int) pthread_self(), (int) clock(), ##__VA_ARGS__) +#endif + +#define CLIENT_ERROR(msg, ...) \ + error(0, client->socket, msg, ##__VA_ARGS__) +#define CLIENT_ERROR_ON_FAILURE(test, msg, ...) \ + if (test < 0) { error(1, client->socket, msg, ##__VA_ARGS__); } +#define SERVER_ERROR(msg, ...) \ + error(0, 0, msg, ##__VA_ARGS__) +#define SERVER_ERROR_ON_FAILURE(test, msg, ...) \ + if (test < 0) { error(1, 0, msg, ##__VA_ARGS__); } + +#endif +