diff --git a/Rakefile b/Rakefile index 34ccc32..41054f1 100644 --- a/Rakefile +++ b/Rakefile @@ -12,7 +12,9 @@ TEST_SOURCES = FileList['tests/*.c'] TEST_OBJECTS = TEST_SOURCES.pathmap( "%{^tests,build/tests}X.o" ) LIBS = %w( pthread ) -CCFLAGS = %w( -Wall +CCFLAGS = %w( + -D_GNU_SOURCE=1 + -Wall -Wextra -Werror-implicit-function-declaration -Wstrict-prototypes @@ -129,8 +131,23 @@ file check("serve") => gcc_link t.name, t.prerequisites + [LIBCHECK] end +file check("readwrite") => +%w{build/tests/check_readwrite.o + build/readwrite.o + build/client.o + build/self_pipe.o + build/serve.o + build/parse.o + build/acl.o + build/control.o + build/nbdtypes.o + build/ioutil.o + build/util.o} do |t| + gcc_link t.name, t.prerequisites + [LIBCHECK] +end -(TEST_MODULES- %w{acl client serve}).each do |m| + +(TEST_MODULES- %w{acl client serve readwrite}).each do |m| tgt = "build/tests/check_#{m}.o" deps = ["build/ioutil.o", "build/util.o"] maybe_obj_name = "build/#{m}.o" diff --git a/src/acl.c b/src/acl.c index a8d8518..acd66c0 100644 --- a/src/acl.c +++ b/src/acl.c @@ -55,13 +55,13 @@ static int is_included_in_acl(int list_length, struct ip_and_mask (*list)[], uni for (testbits = entry->mask; testbits > 0; testbits -= 8) { debug("testbits=%d, c1=%02x, c2=%02x", testbits, raw_address1[0], raw_address2[0]); if (testbits >= 8) { - if (raw_address1[0] != raw_address2[0]) - goto no_match; + if (raw_address1[0] != raw_address2[0]) { goto no_match; } } else { if ((raw_address1[0] & testmasks[testbits%8]) != - (raw_address2[0] & testmasks[testbits%8]) ) + (raw_address2[0] & testmasks[testbits%8]) ) { goto no_match; + } } raw_address1++; diff --git a/src/bitset.h b/src/bitset.h index 522d291..23f688f 100644 --- a/src/bitset.h +++ b/src/bitset.h @@ -20,10 +20,8 @@ static inline int bit_is_clear(char* b, int idx) { } /** Tests whether the bit at ''idx'' in array ''b'' has value ''value'' */ static inline int bit_has_value(char* b, int idx, int value) { - if (value) - return bit_is_set(b, idx); - else - return bit_is_clear(b, idx); + if (value) { return bit_is_set(b, idx); } + else { return bit_is_clear(b, idx); } } /** Sets the bit ''idx'' in array ''b'' */ static inline void bit_set(char* b, int idx) { @@ -37,21 +35,15 @@ static inline void bit_clear(char* b, int idx) { } /** Sets ''len'' bits in array ''b'' starting at offset ''from'' */ static inline void bit_set_range(char* b, int from, int len) { - for (; from%8 != 0 && len > 0; len--) - bit_set(b, from++); - if (len >= 8) - memset(b+(from/8), 255, len/8); - for (; len > 0; len--) - bit_set(b, from++); + for (; from%8 != 0 && len > 0; len--) { bit_set(b, from++); } + if (len >= 8) { memset(b+(from/8), 255, len/8); } + for (; len > 0; len--) { bit_set(b, from++); } } /** Clears ''len'' bits in array ''b'' starting at offset ''from'' */ static inline void bit_clear_range(char* b, int from, int len) { - for (; from%8 != 0 && len > 0; len--) - bit_clear(b, from++); - if (len >= 8) - memset(b+(from/8), 0, len/8); - for (; len > 0; len--) - bit_clear(b, from++); + for (; from%8 != 0 && len > 0; len--) { bit_clear(b, from++); } + if (len >= 8) { memset(b+(from/8), 0, len/8); } + for (; len > 0; len--) { bit_clear(b, from++); } } /** Counts the number of contiguous bits in array ''b'', starting at ''from'' diff --git a/src/client.c b/src/client.c index 1b9ba7a..3e3f95d 100644 --- a/src/client.c +++ b/src/client.c @@ -19,6 +19,7 @@ struct client *client_create( struct server *serve, int socket ) struct client *c; c = xmalloc( sizeof( struct server ) ); + c->stopped = 0; c->socket = socket; c->serve = serve; @@ -88,11 +89,9 @@ void write_not_zeroes(struct client* client, uint64_t from, int len) for (i=0; iserve->size; i+=map->resolution) { int here = (from >= i && from < i+map->resolution); - if (here) - fprintf(stderr, ">"); + if (here) { fprintf(stderr, ">"); } fprintf(stderr, bitset_is_set_at(map, i) ? "1" : "0"); - if (here) - fprintf(stderr, "<"); + if (here) { fprintf(stderr, "<"); } } fprintf(stderr, "\n"); } @@ -172,6 +171,7 @@ int client_read_request( struct client * client , struct nbd_request *out_reques "select() failed"); if ( self_pipe_fd_isset( client->stop_signal, &fds ) ){ + debug("Client received stop signal."); return 0; } @@ -181,7 +181,7 @@ int client_read_request( struct client * client , struct nbd_request *out_reques return 0; /* neat point to close the socket */ } else { - FATAL_IF_NEGATIVE(-1, "Error reading request"); + fatal("Error reading request"); } } @@ -239,8 +239,9 @@ int client_request_needs_reply( struct client * client, struct nbd_request reque { debug("request type %d", request.type); - if (request.magic != REQUEST_MAGIC) + if (request.magic != REQUEST_MAGIC) { fatal("Bad magic %08x", request.magic); + } switch (request.type) { @@ -376,12 +377,11 @@ void client_cleanup(struct client* client, { info("client cleanup"); - if (client->socket) - close(client->socket); - if (client->mapped) + if (client->socket) { close(client->socket); } + if (client->mapped) { munmap(client->mapped, client->serve->size); - if (client->fileno) - close(client->fileno); + } + if (client->fileno) { close(client->fileno); } } void* client_serve(void* client_uncast) @@ -390,7 +390,6 @@ void* client_serve(void* client_uncast) error_set_handler((cleanup_handler*) client_cleanup, client); - //client_open_file(client); FATAL_IF_NEGATIVE( open_and_mmap( client->serve->filename, @@ -404,6 +403,7 @@ void* client_serve(void* client_uncast) while (client_serve_request(client) == 0) ; + client->stopped = 1; FATAL_IF_NEGATIVE( close(client->socket), @@ -411,6 +411,7 @@ void* client_serve(void* client_uncast) client->socket ); + debug("Cleaning up normally in thread %p", pthread_self()); client_cleanup(client, 0); return NULL; diff --git a/src/client.h b/src/client.h index eccdcc1..6af19f8 100644 --- a/src/client.h +++ b/src/client.h @@ -3,6 +3,14 @@ struct client { + /* When we call pthread_join, if the thread is already dead + * we can get an ESRCH. Since we have no other way to tell + * if that ESRCH is from a dead thread or a thread that never + * existed, we use a `stopped` flag to indicate a thread which + * did exist, but went away. Only check this after a + * pthread_join call. + */ + int stopped; int socket; int fileno; diff --git a/src/control.c b/src/control.c index d104eef..3e6aadf 100644 --- a/src/control.c +++ b/src/control.c @@ -39,6 +39,51 @@ #include #include +struct mirror_status * mirror_status_create( + struct server * serve, + int fd, + int max_Bps, + int action_at_finish) +{ + /* FIXME: shouldn't map_fd get closed? */ + int map_fd; + off64_t size; + struct mirror_status * mirror; + + NULLCHECK( serve ); + + mirror = xmalloc(sizeof(struct mirror_status)); + mirror->client = fd; + mirror->max_bytes_per_second = max_Bps; + mirror->action_at_finish = action_at_finish; + + FATAL_IF_NEGATIVE( + open_and_mmap( + serve->filename, + &map_fd, + &size, + (void**) &mirror->mapped + ), + "Failed to open and mmap %s", + serve->filename + ); + + mirror->dirty_map = bitset_alloc(size, 4096); + bitset_set_range(mirror->dirty_map, 0, size); + + return mirror; +} + + +void mirror_status_destroy( struct mirror_status *mirror ) +{ + NULLCHECK( mirror ); + close(mirror->client); + free(mirror->dirty_map); + free(mirror); +} + + /** The mirror code will split NBD writes, making them this long as a maximum */ static const int mirror_longest_write = 8<<20; @@ -52,122 +97,136 @@ static const unsigned int mirror_last_pass_after_bytes_written = 100<<20; */ static const int mirror_maximum_passes = 7; +/* A single mirror pass over the disc, optionally locking IO around the + * transfer. + */ +int mirror_pass(struct server * serve, int should_lock, uint64_t *written) +{ + uint64_t current = 0; + int success = 1; + struct bitset_mapping *map = serve->mirror->dirty_map; + *written = 0; + + while (current < serve->size) { + int run = bitset_run_count(map, current, mirror_longest_write); + + debug("mirror current=%ld, run=%d", current, run); + + /* FIXME: we could avoid sending sparse areas of the + * disc here, and probably save a lot of bandwidth and + * time (if we know the destination starts off zeroed). + */ + if (bitset_is_set_at(map, current)) { + /* We've found a dirty area, send it */ + debug("^^^ writing"); + + /* We need to stop the main thread from working + * because it might corrupt the dirty map. This + * is likely to slow things down but will be + * safe. + */ + if (should_lock) { server_lock_io( serve ); } + { + /** FIXME: do something useful with bytes/second */ + + /** FIXME: error handling code here won't unlock */ + socket_nbd_write( serve->mirror->client, + current, + run, + 0, + serve->mirror->mapped + current); + + /* now mark it clean */ + bitset_clear_range(map, current, run); + } + if (should_lock) { server_unlock_io( serve ); } + + *written += run; + } + current += run; + + if (serve->mirror->signal_abandon) { + success = 0; + break; + } + } + + return success; +} + + +void mirror_on_exit( struct server * serve ) +{ + serve_signal_close( serve ); + /* We have to wait until the server is closed before unlocking + * IO. This is because the client threads check to see if the + * server is still open before reading or writing inside their + * own locks. If we don't wait for the close, there's no way to + * guarantee the server thread will win the race and we risk the + * clients seeing a "successful" write to a dead disc image. + */ + serve_wait_for_close( serve ); +} + /** Thread launched to drive mirror process */ void* mirror_runner(void* serve_params_uncast) { - const int last_pass = mirror_maximum_passes-1; int pass; struct server *serve = (struct server*) serve_params_uncast; - struct bitset_mapping *map = serve->mirror->dirty_map; + uint64_t written; + + NULLCHECK( serve ); + NULLCHECK( serve->mirror ); + NULLCHECK( serve->mirror->dirty_map ); + + debug("Starting mirror" ); - for (pass=0; pass < mirror_maximum_passes; pass++) { - uint64_t current = 0; - uint64_t written = 0; - + for (pass=0; pass < mirror_maximum_passes-1; pass++) { debug("mirror start pass=%d", pass); - if (pass == last_pass) { - server_lock_io( serve ); - } - - while (current < serve->size) { - int run; - - run = bitset_run_count(map, current, mirror_longest_write); - - debug("mirror current=%ld, run=%d", current, run); - - /* FIXME: we could avoid sending sparse areas of the - * disc here, and probably save a lot of bandwidth and - * time (if we know the destination starts off zeroed). - */ - if (bitset_is_set_at(map, current)) { - /* We've found a dirty area, send it */ - debug("^^^ writing"); - - /* We need to stop the main thread from working - * because it might corrupt the dirty map. This - * is likely to slow things down but will be - * safe. - */ - if (pass < last_pass) { - server_lock_io( serve ); - } - - /** FIXME: do something useful with bytes/second */ - - /** FIXME: error handling code here won't unlock */ - socket_nbd_write( - serve->mirror->client, - current, - run, - 0, - serve->mirror->mapped + current - ); - - /* now mark it clean */ - bitset_clear_range(map, current, run); - - if (pass < last_pass) { - server_unlock_io( serve ); - } - - written += run; - } - current += run; - - if (serve->mirror->signal_abandon) { - if (pass == last_pass) - server_unlock_io( serve ); - close(serve->mirror->client); - goto abandon_mirror; - } + if ( !mirror_pass( serve, 1, &written ) ){ + goto abandon_mirror; } /* if we've not written anything */ - if (written < mirror_last_pass_after_bytes_written) - pass = last_pass; + if (written < mirror_last_pass_after_bytes_written) { break; } } - - /* a successful finish ends here */ - switch (serve->mirror->action_at_finish) + + server_lock_io( serve ); { - case ACTION_PROXY: - debug("proxy!"); - serve->proxy_fd = serve->mirror->client; - /* don't close our file descriptor, we still need it! */ - break; - case ACTION_EXIT: - debug("exit!"); - close(serve->mirror->client); - serve_signal_close( serve ); - /* fall through */ - case ACTION_NOTHING: - debug("nothing!"); - close(serve->mirror->client); + if ( mirror_pass( serve, 0, &written ) && + ACTION_EXIT == serve->mirror->action_at_finish) { + debug("exit!"); + mirror_on_exit( serve ); + info("Server closed, quitting " + "after successful migration"); + } } server_unlock_io( serve ); - + abandon_mirror: - free(serve->mirror->dirty_map); - free(serve->mirror); + mirror_status_destroy( serve->mirror ); serve->mirror = NULL; /* and we're gone */ return NULL; } + #define write_socket(msg) write(client->socket, (msg "\n"), strlen((msg))+1) /** Command parser to start mirror process from socket input */ int control_mirror(struct control_params* client, int linesc, char** lines) { - off64_t size, remote_size; - int fd, map_fd; + NULLCHECK( client ); + + off64_t remote_size; + struct server * serve = client->serve; + int fd; struct mirror_status *mirror; union mysockaddr connect_to; union mysockaddr connect_from; int use_connect_from = 0; - uint64_t max_bytes_per_second; + uint64_t max_Bps; int action_at_finish; int raw_port; @@ -197,21 +256,21 @@ int control_mirror(struct control_params* client, int linesc, char** lines) use_connect_from = 1; } - max_bytes_per_second = 0; + max_Bps = 0; if (linesc > 3) { - max_bytes_per_second = atoi(lines[2]); + max_Bps = atoi(lines[2]); } - action_at_finish = ACTION_PROXY; + action_at_finish = ACTION_EXIT; if (linesc > 4) { - if (strcmp("proxy", lines[3]) == 0) - action_at_finish = ACTION_PROXY; - else if (strcmp("exit", lines[3]) == 0) + if (strcmp("exit", lines[3]) == 0) { action_at_finish = ACTION_EXIT; - else if (strcmp("nothing", lines[3]) == 0) + } + else if (strcmp("nothing", lines[3]) == 0) { action_at_finish = ACTION_NOTHING; + } else { - write_socket("1: action must be one of 'proxy', 'exit' or 'nothing'"); + write_socket("1: action must be 'exit' or 'nothing'"); return -1; } } @@ -222,42 +281,29 @@ int control_mirror(struct control_params* client, int linesc, char** lines) } /** I don't like use_connect_from but socket_connect doesn't take *mysockaddr :( */ - if (use_connect_from) - fd = socket_connect(&connect_to.generic, &connect_from.generic); - else - fd = socket_connect(&connect_to.generic, NULL); - + struct sockaddr *afrom = use_connect_from ? &connect_from.generic : NULL; + fd = socket_connect(&connect_to.generic, afrom); remote_size = socket_nbd_read_hello(fd); - remote_size = remote_size; // shush compiler + if( remote_size != (off64_t)serve->size ){ + warn("Remote size (%d) doesn't match local (%d)", remote_size, serve->size ); + write_socket( "1: remote size (%d) doesn't match local (%d)"); + close(fd); + return -1; + } - mirror = xmalloc(sizeof(struct mirror_status)); - mirror->client = fd; - mirror->max_bytes_per_second = max_bytes_per_second; - mirror->action_at_finish = action_at_finish; + mirror = mirror_status_create( serve, + fd, + max_Bps , + action_at_finish ); + serve->mirror = mirror; - FATAL_IF_NEGATIVE( - open_and_mmap( - client->serve->filename, - &map_fd, - &size, - (void**) &mirror->mapped - ), - "Failed to open and mmap %s", - client->serve->filename - ); - - mirror->dirty_map = bitset_alloc(size, 4096); - bitset_set_range(mirror->dirty_map, 0, size); - - client->serve->mirror = mirror; - - FATAL_IF_NEGATIVE( /* FIXME should free mirror on error */ - pthread_create( + FATAL_IF( /* FIXME should free mirror on error */ + 0 != pthread_create( &mirror->thread, NULL, mirror_runner, - client->serve + serve ), "Failed to create mirror thread" ); @@ -303,8 +349,7 @@ int control_status( void control_cleanup(struct control_params* client, int fatal __attribute__ ((unused)) ) { - if (client->socket) - close(client->socket); + if (client->socket) { close(client->socket); } free(client); } @@ -328,24 +373,28 @@ void* control_serve(void* client_uncast) /* ignore failure */ } else if (strcmp(lines[0], "acl") == 0) { - if (control_acl(client, linesc-1, lines+1) < 0) + if (control_acl(client, linesc-1, lines+1) < 0) { finished = 1; + } } else if (strcmp(lines[0], "mirror") == 0) { - if (control_mirror(client, linesc-1, lines+1) < 0) + if (control_mirror(client, linesc-1, lines+1) < 0) { finished = 1; + } } else if (strcmp(lines[0], "status") == 0) { - if (control_status(client, linesc-1, lines+1) < 0) + if (control_status(client, linesc-1, lines+1) < 0) { finished = 1; + } } else { write(client->socket, "10: unknown command\n", 23); finished = 1; } - for (i=0; isocket = client_fd; control_params->serve = params; - FATAL_IF_NEGATIVE( - pthread_create( + FATAL_IF( + 0 != pthread_create( &control_thread, NULL, control_serve, @@ -379,8 +428,7 @@ void serve_open_control_socket(struct server* params) { struct sockaddr_un bind_address; - if (!params->control_socket_name) - return; + if (!params->control_socket_name) { return; } params->control_fd = socket(AF_UNIX, SOCK_STREAM, 0); FATAL_IF_NEGATIVE(params->control_fd , diff --git a/src/flexnbd.c b/src/flexnbd.c index 5992c0b..7b59417 100644 --- a/src/flexnbd.c +++ b/src/flexnbd.c @@ -81,8 +81,10 @@ void params_readwrite( s_ip_address ); - if (s_bind_address != NULL && parse_ip_to_sockaddr(&out->connect_from.generic, s_bind_address) == 0) + if (s_bind_address != NULL && + parse_ip_to_sockaddr(&out->connect_from.generic, s_bind_address) == 0) { fatal("Couldn't parse bind address '%s'", s_bind_address); + } parse_port( s_port, &out->connect_to.v4 ); @@ -252,8 +254,7 @@ int mode_serve( int argc, char *argv[] ) while (1) { c = getopt_long(argc, argv, serve_short_options, serve_options, NULL); - if ( c == -1 ) - break; + if ( c == -1 ) { break; } read_serve_param( c, &ip_addr, &ip_port, &file, &sock, &default_deny ); } @@ -290,8 +291,7 @@ int mode_read( int argc, char *argv[] ) while (1){ c = getopt_long(argc, argv, read_short_options, read_options, NULL); - if ( c == -1 ) - break; + if ( c == -1 ) { break; } read_readwrite_param( c, &ip_addr, &ip_port, &bind_addr, &from, &size ); } @@ -326,8 +326,7 @@ int mode_write( int argc, char *argv[] ) while (1){ c = getopt_long(argc, argv, write_short_options, write_options, NULL); - if ( c == -1 ) - break; + if ( c == -1 ) { break; } read_readwrite_param( c, &ip_addr, &ip_port, &bind_addr, &from, &size ); } @@ -355,7 +354,7 @@ int mode_acl( int argc, char *argv[] ) while (1) { c = getopt_long( argc, argv, acl_short_options, acl_options, NULL ); - if ( c == -1 ) break; + if ( c == -1 ) { break; } read_acl_param( c, &sock ); } @@ -382,7 +381,7 @@ int mode_mirror( int argc, char *argv[] ) while (1) { c = getopt_long( argc, argv, mirror_short_options, mirror_options, NULL); - if ( -1 == c ) break; + if ( -1 == c ) { break; } read_mirror_param( c, &sock, &remote_argv[0], &remote_argv[1], &remote_argv[2] ); } @@ -396,10 +395,12 @@ int mode_mirror( int argc, char *argv[] ) } if ( err ) { exit_err( mirror_help_text ); } - if (argv[2] == NULL) + if (remote_argv[2] == NULL) { do_remote_command( "mirror", sock, 2, remote_argv ); - else + } + else { do_remote_command( "mirror", sock, 3, remote_argv ); + } return 0; } @@ -412,7 +413,7 @@ int mode_status( int argc, char *argv[] ) while (1) { c = getopt_long( argc, argv, status_short_options, status_options, NULL ); - if ( -1 == c ) break; + if ( -1 == c ) { break; } read_status_param( c, &sock ); } diff --git a/src/ioutil.c b/src/ioutil.c index e8f3af7..0d1883c 100644 --- a/src/ioutil.c +++ b/src/ioutil.c @@ -1,6 +1,3 @@ -#define _LARGEFILE64_SOURCE -#define _GNU_SOURCE - #include #include #include @@ -29,8 +26,9 @@ struct bitset_mapping* build_allocation_map(int fd, uint64_t size, int resolutio fiemap_count->fm_mapped_extents = 0; /* Find out how many extents there are */ - if (ioctl(fd, FS_IOC_FIEMAP, fiemap_count) < 0) + if (ioctl(fd, FS_IOC_FIEMAP, fiemap_count) < 0) { return NULL; + } /* Resize fiemap to allow us to read in the extents */ fiemap = (struct fiemap*)xmalloc( @@ -80,20 +78,24 @@ int open_and_mmap(char* filename, int* out_fd, off64_t *out_size, void **out_map off64_t size; *out_fd = open(filename, O_RDWR|O_DIRECT|O_SYNC); - if (*out_fd < 1) + if (*out_fd < 1) { return *out_fd; + } size = lseek64(*out_fd, 0, SEEK_END); - if (size < 0) + if (size < 0) { return size; - if (out_size) + } + if (out_size) { *out_size = size; + } if (out_map) { *out_map = mmap64(NULL, size, PROT_READ|PROT_WRITE, MAP_SHARED, *out_fd, 0); - if (((long) *out_map) == -1) + if (((long) *out_map) == -1) { return -1; + } } debug("opened %s size %ld on fd %d @ %p", filename, size, *out_fd, *out_map); @@ -173,20 +175,19 @@ int splice_via_pipe_loop(int fd_in, int fd_out, size_t len) int pipefd[2]; /* read end, write end */ size_t spliced=0; - if (pipe(pipefd) == -1) + if (pipe(pipefd) == -1) { return -1; + } while (spliced < len) { ssize_t run = len-spliced; ssize_t s2, s1 = spliceloop(fd_in, NULL, pipefd[1], NULL, run, SPLICE_F_NONBLOCK); /*if (run > 65535) run = 65535;*/ - if (s1 < 0) - break; + if (s1 < 0) { break; } s2 = spliceloop(pipefd[0], NULL, fd_out, NULL, s1, 0); - if (s2 < 0) - break; + if (s2 < 0) { break; } spliced += s2; } close(pipefd[0]); @@ -202,10 +203,8 @@ int read_until_newline(int fd, char* buf, int bufsize) for (cur=0; cur < bufsize; cur++) { int result = read(fd, buf+cur, 1); - if (result < 0) - return -1; - if (buf[cur] == 10) - break; + if (result < 0) { return -1; } + if (buf[cur] == 10) { break; } } buf[cur++] = 0; @@ -221,12 +220,14 @@ int read_lines_until_blankline(int fd, int max_line_length, char ***lines) memset(line, 0, max_line_length+1); while (1) { - if (read_until_newline(fd, line, max_line_length) < 0) + if (read_until_newline(fd, line, max_line_length) < 0) { return lines_count; + } *lines = xrealloc(*lines, (lines_count+1) * sizeof(char*)); (*lines)[lines_count] = strdup(line); - if ((*lines)[lines_count][0] == 0) + if ((*lines)[lines_count][0] == 0) { return lines_count; + } lines_count++; } } diff --git a/src/nbdtypes.h b/src/nbdtypes.h index f902aa1..5f3cfaf 100644 --- a/src/nbdtypes.h +++ b/src/nbdtypes.h @@ -10,10 +10,7 @@ #define REQUEST_READ 0 #define REQUEST_WRITE 1 #define REQUEST_DISCONNECT 2 - -#ifndef _LARGEFILE64_SOURCE -# define _LARGEFILE64_SOURCE -#endif +#define REQUEST_ENTRUST (1<<16) #include #include diff --git a/src/parse.c b/src/parse.c index 70f34d4..790b72f 100644 --- a/src/parse.c +++ b/src/parse.c @@ -18,10 +18,10 @@ int parse_ip_to_sockaddr(struct sockaddr* out, char* src) /* allow user to start with [ and end with any other invalid char */ { int i=0, j=0; - if (src[i] == '[') - i++; - for (; i<64 && IS_IP_VALID_CHAR(src[i]); i++) + if (src[i] == '[') { i++; } + for (; i<64 && IS_IP_VALID_CHAR(src[i]); i++) { temp[j++] = src[i]; + } temp[j] = 0; } @@ -73,8 +73,9 @@ int parse_acl(struct ip_and_mask (**out)[], int max, char **entries) if (entries[i][j] == '/') { outentry->mask = atoi(entries[i]+j+1); - if (outentry->mask < 1 || outentry->mask > MAX_MASK_BITS) + if (outentry->mask < 1 || outentry->mask > MAX_MASK_BITS) { return i; + } } else { outentry->mask = MAX_MASK_BITS; diff --git a/src/readwrite.c b/src/readwrite.c index c427742..8f1ad9b 100644 --- a/src/readwrite.c +++ b/src/readwrite.c @@ -12,11 +12,12 @@ int socket_connect(struct sockaddr* to, struct sockaddr* from) int fd = socket(to->sa_family == AF_INET ? PF_INET : PF_INET6, SOCK_STREAM, 0); FATAL_IF_NEGATIVE(fd, "Couldn't create client socket"); - if (NULL != from) + if (NULL != from) { FATAL_IF_NEGATIVE( bind(fd, from, sizeof(struct sockaddr_in6)), "bind() failed" ); + } FATAL_IF_NEGATIVE( connect(fd, to, sizeof(struct sockaddr_in6)),"connect failed" @@ -29,10 +30,12 @@ off64_t socket_nbd_read_hello(int fd) struct nbd_init init; FATAL_IF_NEGATIVE(readloop(fd, &init, sizeof(init)), "Couldn't read init"); - if (strncmp(init.passwd, INIT_PASSWD, 8) != 0) + if (strncmp(init.passwd, INIT_PASSWD, 8) != 0) { fatal("wrong passwd"); - if (be64toh(init.magic) != INIT_MAGIC) + } + if (be64toh(init.magic) != INIT_MAGIC) { fatal("wrong magic (%x)", be64toh(init.magic)); + } return be64toh(init.size); } @@ -50,12 +53,15 @@ void read_reply(int fd, struct nbd_request *request, struct nbd_reply *reply) { FATAL_IF_NEGATIVE(readloop(fd, reply, sizeof(*reply)), "Couldn't read reply"); - if (be32toh(reply->magic) != REPLY_MAGIC) + if (be32toh(reply->magic) != REPLY_MAGIC) { fatal("Reply magic incorrect (%p)", be32toh(reply->magic)); - if (be32toh(reply->error) != 0) + } + if (be32toh(reply->error) != 0) { fatal("Server replied with error %d", be32toh(reply->error)); - if (strncmp(request->handle, reply->handle, 8) != 0) + } + if (strncmp(request->handle, reply->handle, 8) != 0) { fatal("Did not reply with correct handle"); + } } void socket_nbd_read(int fd, off64_t from, int len, int out_fd, void* out_buf) @@ -105,11 +111,11 @@ void socket_nbd_write(int fd, off64_t from, int len, int in_fd, void* in_buf) #define CHECK_RANGE(error_type) { \ off64_t size = socket_nbd_read_hello(params->client); \ - if (params->from < 0 || (params->from + params->len) > size) \ + if (params->from < 0 || (params->from + params->len) > size) {\ fatal(error_type \ " request %d+%d is out of range given size %d", \ params->from, params->len, size\ - ); \ + ); }\ } void do_read(struct mode_readwrite_params* params) diff --git a/src/readwrite.h b/src/readwrite.h index 2dabbfc..22ce723 100644 --- a/src/readwrite.h +++ b/src/readwrite.h @@ -1,6 +1,6 @@ -#ifndef __READWRITE_H +#ifndef READWRITE_H -#define __READWRITE_H +#define READWRITE_H int socket_connect(struct sockaddr* to, struct sockaddr* from); off64_t socket_nbd_read_hello(int fd); diff --git a/src/remote.c b/src/remote.c index 29d805b..7ffbe48 100644 --- a/src/remote.c +++ b/src/remote.c @@ -30,7 +30,9 @@ void do_remote_command(char* command, char* socket_name, int argc, char** argv) write(remote, command, strlen(command)); write(remote, &newline, 1); for (i=0; i 0) + if (exit_status > 0) { fprintf(stderr, "%s\n", strchr(response, ':')+2); + } exit(atoi(response)); diff --git a/src/serve.c b/src/serve.c index e5bee9c..d30d8ae 100644 --- a/src/serve.c +++ b/src/serve.c @@ -27,10 +27,12 @@ static inline void* sockaddr_address_data(struct sockaddr* sockaddr) struct sockaddr_in* in = (struct sockaddr_in*) sockaddr; struct sockaddr_in6* in6 = (struct sockaddr_in6*) sockaddr; - if (sockaddr->sa_family == AF_INET) + if (sockaddr->sa_family == AF_INET) { return &in->sin_addr; - if (sockaddr->sa_family == AF_INET6) + } + if (sockaddr->sa_family == AF_INET6) { return &in6->sin6_addr; + } return NULL; } @@ -63,8 +65,9 @@ struct server * server_create ( out->control_socket_name = s_ctrl_sock; out->acl = acl_create( acl_entries, s_acl_entries, default_deny ); - if (out->acl && out->acl->len != acl_entries) + if (out->acl && out->acl->len != acl_entries) { fatal("Bad ACL entry '%s'", s_acl_entries[out->acl->len]); + } parse_port( s_port, &out->bind_to.v4 ); @@ -103,16 +106,17 @@ void server_dirty(struct server *serve, off64_t from, int len) { NULLCHECK( serve ); - if (serve->mirror) + if (serve->mirror) { bitset_set_range(serve->mirror->dirty_map, from, len); + } } #define SERVER_LOCK( s, f, msg ) \ - { NULLCHECK( s ); \ - FATAL_IF_NEGATIVE( pthread_mutex_lock( &s->f ), msg ); } + do { NULLCHECK( s ); \ + FATAL_IF( 0 != pthread_mutex_lock( &s->f ), msg ); } while (0) #define SERVER_UNLOCK( s, f, msg ) \ - { NULLCHECK( s ); \ - FATAL_IF_NEGATIVE( pthread_mutex_unlock( &s->f ), msg ); } + do { NULLCHECK( s ); \ + FATAL_IF( 0 != pthread_mutex_unlock( &s->f ), msg ); } while (0) void server_lock_io( struct server * serve) { @@ -197,6 +201,7 @@ int tryjoin_client_thread( struct client_tbl_entry *entry, int (*joinfunc)(pthre int was_closed = 0; void * status; + int join_errno; if (entry->thread != 0) { char s_client_address[64]; @@ -208,9 +213,14 @@ int tryjoin_client_thread( struct client_tbl_entry *entry, int (*joinfunc)(pthre s_client_address, 64 ); - if (joinfunc(entry->thread, &status) != 0) { - if (errno != EBUSY) - FATAL_IF_NEGATIVE(-1, "Problem with joining thread"); + join_errno = joinfunc(entry->thread, &status); + /* join_errno can legitimately be ESRCH if the thread is + * already dead, but the cluent still needs tidying up. */ + if (join_errno != 0 && !entry->client->stopped ) { + FATAL_UNLESS( join_errno == EBUSY, + "Problem with joining thread %p: %s", + entry->thread, + strerror(join_errno) ); } else { debug("nbd thread %p exited (%s) with status %ld", @@ -381,7 +391,7 @@ void accept_nbd_client( return; } - debug("nbd thread %d started (%s)", (int) params->nbd_client[slot].thread, s_client_address); + debug("nbd thread %p started (%s)", params->nbd_client[slot].thread, s_client_address); } @@ -433,7 +443,7 @@ void server_close_clients( struct server *params ) } } for( j = 0; j < MAX_NBD_CLIENTS; j++ ) { - join_client_thread( ¶ms->nbd_client[i] ); + join_client_thread( ¶ms->nbd_client[j] ); } } @@ -476,8 +486,9 @@ int server_accept( struct server * params ) FD_SET(params->server_fd, &fds); self_pipe_fd_set( params->close_signal, &fds ); self_pipe_fd_set( params->acl_updated_signal, &fds ); - if (params->control_socket_name) + if (params->control_socket_name) { FD_SET(params->control_fd, &fds); + } FATAL_IF_NEGATIVE(select(FD_SETSIZE, &fds, NULL, NULL, NULL), "select() failed"); @@ -548,6 +559,15 @@ void serve_signal_close( struct server * serve ) self_pipe_signal( serve->close_signal ); } +/* Block until the server closes the server_fd. + */ +void serve_wait_for_close( struct server * serve ) +{ + while( !fd_is_closed( serve->server_fd ) ){ + usleep(10000); + } +} + /** Closes sockets, frees memory and waits for all client threads to finish */ void serve_cleanup(struct server* params, @@ -562,7 +582,6 @@ void serve_cleanup(struct server* params, if (params->server_fd){ close(params->server_fd); } if (params->control_fd){ close(params->control_fd); } if (params->control_socket_name){ ; } - if (params->proxy_fd){ close(params->proxy_fd); } if (params->close_signal) { self_pipe_destroy( params->close_signal ); @@ -579,10 +598,11 @@ void serve_cleanup(struct server* params, for (i=0; i < MAX_NBD_CLIENTS; i++) { void* status; + pthread_t thread_id = params->nbd_client[i].thread; - if (params->nbd_client[i].thread != 0) { - debug("joining thread %d", i); - pthread_join(params->nbd_client[i].thread, &status); + if (thread_id != 0) { + debug("joining thread %p", thread_id); + pthread_join(thread_id, &status); } } } diff --git a/src/serve.h b/src/serve.h index 570c5f8..72f6c2c 100644 --- a/src/serve.h +++ b/src/serve.h @@ -1,10 +1,6 @@ #ifndef SERVE_H #define SERVE_H -#define _GNU_SOURCE - -#define _LARGEFILE64_SOURCE - #include #include @@ -15,7 +11,6 @@ static const int block_allocation_resolution = 4096;//128<<10; enum mirror_finish_action { - ACTION_PROXY, ACTION_EXIT, ACTION_NOTHING }; @@ -63,9 +58,6 @@ struct server { /** Claims around any I/O to this file */ pthread_mutex_t l_io; - /** set to non-zero to cause r/w requests to go via this fd */ - int proxy_fd; - /** to interrupt accept loop and clients, write() to close_signal[1] */ struct self_pipe * close_signal; @@ -94,6 +86,7 @@ void server_dirty(struct server *serve, off64_t from, int len); void server_lock_io( struct server * serve); void server_unlock_io( struct server* serve ); void serve_signal_close( struct server *serve ); +void serve_wait_for_close( struct server * serve ); void server_replace_acl( struct server *serve, struct acl * acl); diff --git a/src/util.c b/src/util.c index 95e2c69..2e2481d 100644 --- a/src/util.c +++ b/src/util.c @@ -11,7 +11,7 @@ pthread_key_t cleanup_handler_key; -int log_level = 1; +int log_level = 2; void error_init(void) { @@ -33,8 +33,7 @@ void mylog(int line_level, const char* format, ...) { va_list argptr; - if (line_level < log_level) - return; + if (line_level < log_level) { return; } va_start(argptr, format); vfprintf(stderr, format, argptr); diff --git a/src/util.h b/src/util.h index 253b894..52cec40 100644 --- a/src/util.h +++ b/src/util.h @@ -49,13 +49,13 @@ extern pthread_key_t cleanup_handler_key; switch (setjmp(context->jmp)) \ { \ case 0: /* setup time */ \ - if (old) \ - free(old); \ + if (old) { free(old); }\ pthread_setspecific(cleanup_handler_key, context); \ break; \ case 1: /* fatal error, terminate thread */ \ + debug( "Fatal error in thread %p", pthread_self() ); \ context->handler(context->data, 1); \ - pthread_exit((void*) 1); \ + /*pthread_exit((void*) 1);*/ \ abort(); \ case 2: /* non-fatal error, return to context of error handler setup */ \ context->handler(context->data, 0); \ @@ -84,25 +84,52 @@ void mylog(int line_level, const char* format, ...); #define warn(msg, ...) mylog(2, "%s:%d: " msg, __FILE__, __LINE__, ##__VA_ARGS__) /* mylog a message and invoke the error handler to recover */ -#define error(msg, ...) { \ +#define error(msg, ...) do { \ mylog(3, "*** %s:%d: " msg, __FILE__, __LINE__, ##__VA_ARGS__); \ error_handler(0); \ -} +} while(0) /* mylog a message and invoke the error handler to kill the current thread */ -#define fatal(msg, ...) { \ +#define fatal(msg, ...) do { \ mylog(4, "*** %s:%d: " msg, __FILE__, __LINE__, ##__VA_ARGS__); \ error_handler(1); \ -} +} while(0) + + +#define ERROR_IF( test, msg, ... ) do { if ((test)) { error(msg, ##__VA_ARGS__); } } while(0) +#define FATAL_IF( test, msg, ... ) do { if ((test)) { fatal(msg, ##__VA_ARGS__); } } while(0) + +#define ERROR_UNLESS( test, msg, ... ) ERROR_IF( !(test), msg, ##__VA_ARGS__ ) +#define FATAL_UNLESS( test, msg, ... ) FATAL_IF( !(test), msg, ##__VA_ARGS__ ) + + +#define ERROR_IF_NULL(value, msg, ...) \ + ERROR_IF( NULL == value, msg " (errno=%d, %s)", ##__VA_ARGS__, errno, strerror(errno) ) +#define FATAL_IF_NULL(value, msg, ...) \ + FATAL_IF( NULL == value, msg " (errno=%d, %s)", ##__VA_ARGS__, errno, strerror(errno) ) + +#define ERROR_IF_NEGATIVE( value, msg, ... ) ERROR_IF( value < 0, msg, ##__VA_ARGS__ ) +#define FATAL_IF_NEGATIVE( value, msg, ... ) FATAL_IF( value < 0, msg, ##__VA_ARGS__ ) + +#define ERROR_IF_ZERO( value, msg, ... ) ERROR_IF( 0 == value, msg, ##__VA_ARGS__ ) +#define FATAL_IF_ZERO( value, msg, ... ) FATAL_IF( 0 == value, msg, ##__VA_ARGS__ ) + + + +#define ERROR_UNLESS_NULL(value, msg, ...) \ + ERROR_UNLESS( NULL == value, msg " (errno=%d, %s)", ##__VA_ARGS__, errno, strerror(errno) ) +#define FATAL_UNLESS_NULL(value, msg, ...) \ + FATAL_UNLESS( NULL == value, msg " (errno=%d, %s)", ##__VA_ARGS__, errno, strerror(errno) ) + +#define ERROR_UNLESS_NEGATIVE( value, msg, ... ) ERROR_UNLESS( value < 0, msg, ##__VA_ARGS__ ) +#define FATAL_UNLESS_NEGATIVE( value, msg, ... ) FATAL_UNLESS( value < 0, msg, ##__VA_ARGS__ ) + +#define ERROR_UNLESS_ZERO( value, msg, ... ) ERROR_UNLESS( 0 == value, msg, ##__VA_ARGS__ ) +#define FATAL_UNLESS_ZERO( value, msg, ... ) FATAL_UNLESS( 0 == value, msg, ##__VA_ARGS__ ) -#define ERROR_IF_NULL(value, msg, ...) if (NULL == value) error(msg " (errno=%d, %s)", ##__VA_ARGS__, errno, strerror(errno)) -#define ERROR_IF_NEGATIVE(value, msg, ...) if (value < 0) error(msg, ##__VA_ARGS__) -#define ERROR_IF_ZERO(value, msg, ...) if (0 == value) error(msg, ##__VA_ARGS__) -#define FATAL_IF_NULL(value, msg, ...) if (NULL == value) fatal(msg, ##__VA_ARGS__) -#define FATAL_IF_NEGATIVE(value, msg, ...) if (value < 0) fatal(msg " (errno=%d, %s)", ##__VA_ARGS__, errno, strerror(errno)) -#define FATAL_IF_ZERO(value, msg, ...) if (0 == value) fatal(msg, ##__VA_ARGS__) #define NULLCHECK(value) FATAL_IF_NULL(value, "BUG: " #value " is null") + #endif diff --git a/tests/check_serve.c b/tests/check_serve.c index 0cd3011..de354c1 100644 --- a/tests/check_serve.c +++ b/tests/check_serve.c @@ -13,6 +13,19 @@ #include #include +#ifdef DEBUG +# define LOG_LEVEL 0 +#else +# define LOG_LEVEL 2 +#endif + + +/* Need these because libcheck is braindead and doesn't + * run teardown after a failing test + */ +#define myfail( msg ) do { teardown(); fail(msg); } while (0) +#define myfail_if( tst, msg ) do { if( tst ) { myfail( msg ); } } while (0) +#define myfail_unless( tst, msg ) myfail_if( !(tst), msg ) char * dummy_file; @@ -45,13 +58,6 @@ void teardown( void ) dummy_file = NULL; } -/* Need these because libcheck is braindead and doesn't - * run teardown after a failing test - */ -#define myfail( msg ) do { teardown(); fail(msg); } while (0) -#define myfail_if( tst, msg ) do { if( tst ) { myfail( msg ); } } while (0) -#define myfail_unless( tst, msg ) myfail_if( !(tst), msg ) - START_TEST( test_replaces_acl ) { @@ -80,13 +86,16 @@ START_TEST( test_signals_acl_updated ) END_TEST -int connect_client( char *addr, int actual_port ) +int connect_client( char *addr, int actual_port, char *source_addr ) { int client_fd; struct addrinfo hint; struct addrinfo *ailist, *aip; + + + memset( &hint, '\0', sizeof( struct addrinfo ) ); hint.ai_socktype = SOCK_STREAM; @@ -96,6 +105,16 @@ int connect_client( char *addr, int actual_port ) for( aip = ailist; aip; aip = aip->ai_next ) { ((struct sockaddr_in *)aip->ai_addr)->sin_port = htons( actual_port ); client_fd = socket( aip->ai_family, aip->ai_socktype, aip->ai_protocol ); + + if (source_addr) { + struct sockaddr src; + if( !parse_ip_to_sockaddr(&src, source_addr)) { + close(client_fd); + continue; + } + bind(client_fd, &src, sizeof(struct sockaddr_in6)); + } + if( client_fd == -1) { continue; } if( connect( client_fd, aip->ai_addr, aip->ai_addrlen) == 0 ) { connected = 1; @@ -135,7 +154,7 @@ START_TEST( test_acl_update_closes_bad_client ) serve_open_server_socket( s ); actual_port = server_port( s ); - client_fd = connect_client( "127.0.0.7", actual_port ); + client_fd = connect_client( "127.0.0.7", actual_port, "127.0.0.1" ); server_accept( s ); entry = &s->nbd_client[0]; c = entry->client; @@ -166,9 +185,8 @@ START_TEST( test_acl_update_leaves_good_client ) { struct server * s = server_create( "127.0.0.7", "0", dummy_file, NULL, 0, 0, NULL ); - // Client source address may be IPv4 or IPv6 localhost. Should be explicit - char *lines[] = {"127.0.0.1", "::1"}; - struct acl * new_acl = acl_create( 2, lines, 1 ); + char *lines[] = {"127.0.0.1"}; + struct acl * new_acl = acl_create( 1, lines, 1 ); struct client * c; struct client_tbl_entry * entry; @@ -176,12 +194,10 @@ START_TEST( test_acl_update_leaves_good_client ) int client_fd; int server_fd; - myfail_if(new_acl->len != 2, "sanity: new_acl length is not 2"); - serve_open_server_socket( s ); actual_port = server_port( s ); - client_fd = connect_client( "127.0.0.7", actual_port ); + client_fd = connect_client( "127.0.0.7", actual_port, "127.0.0.1" ); server_accept( s ); entry = &s->nbd_client[0]; c = entry->client; @@ -211,22 +227,22 @@ Suite* serve_suite(void) Suite *s = suite_create("serve"); TCase *tc_acl_update = tcase_create("acl_update"); - tcase_add_checked_fixture( tc_acl_update, setup, teardown ); + tcase_add_checked_fixture( tc_acl_update, setup, NULL ); + tcase_add_test(tc_acl_update, test_replaces_acl); tcase_add_test(tc_acl_update, test_signals_acl_updated); - tcase_add_test(tc_acl_update, test_acl_update_closes_bad_client); - tcase_add_test(tc_acl_update, test_acl_update_leaves_good_client); + tcase_add_exit_test(tc_acl_update, test_acl_update_closes_bad_client, 0); + tcase_add_exit_test(tc_acl_update, test_acl_update_leaves_good_client, 0); suite_add_tcase(s, tc_acl_update); return s; } - int main(void) { - log_level = 0; + log_level = LOG_LEVEL; int number_failed; Suite *s = serve_suite(); SRunner *sr = srunner_create(s); diff --git a/tests/flexnbd.rb b/tests/flexnbd.rb index df70bc6..64969a5 100644 --- a/tests/flexnbd.rb +++ b/tests/flexnbd.rb @@ -147,7 +147,7 @@ class FlexNBD def debug? - !@debug.empty? + !@debug.empty? || ENV['DEBUG'] end def debug( msg ) @@ -186,6 +186,14 @@ class FlexNBD end + def mirror_cmd(dest_ip, dest_port) + "#{@bin} mirror "\ + "--addr #{dest_ip} "\ + "--port #{dest_port} "\ + "--sock #{ctrl} "\ + "#{@debug} " + end + def serve(file, *acl) File.unlink(ctrl) if File.exists?(ctrl) cmd =serve_cmd( file, acl ) @@ -205,7 +213,10 @@ class FlexNBD def start_wait_thread( pid ) Thread.start do Process.waitpid2( pid ) - unless @kill + if @kill + fail "flexnbd quit with a bad status #{$?.exitstatus}" unless + $?.exitstatus == @kill + else $stderr.puts "flexnbd quit" fail "flexnbd quit early" end @@ -213,9 +224,18 @@ class FlexNBD end + def can_die(status=0) + @kill = status + end + def kill - @kill = true - Process.kill("INT", @pid) + can_die() + begin + Process.kill("INT", @pid) + rescue Errno::ESRCH => e + # already dead. Presumably this means it went away after a + # can_die() call. + end end def read(offset, length) @@ -240,8 +260,12 @@ class FlexNBD nil end - def mirror(bandwidth=nil, action=nil) - control_command("mirror", ip, port, ip, bandwidth, action) + def mirror(dest_ip, dest_port, bandwidth=nil, action=nil) + cmd = mirror_cmd( dest_ip, dest_port) + debug( cmd ) + system cmd + raise IOError.new( "Migrate command failed") unless $?.success? + nil end def acl(*acl) diff --git a/tests/nbd_scenarios b/tests/nbd_scenarios index be58cfe..8e44bf5 100644 --- a/tests/nbd_scenarios +++ b/tests/nbd_scenarios @@ -14,6 +14,7 @@ class NBDScenarios < Test::Unit::TestCase @port1 = @available_ports.shift @port2 = @available_ports.shift @nbd1 = FlexNBD.new("../build/flexnbd", @ip, @port1) + @nbd2 = FlexNBD.new("../build/flexnbd", @ip, @port2) end def teardown @@ -70,15 +71,43 @@ class NBDScenarios < Test::Unit::TestCase end end + + def test_mirror + writefile1( "f"*4 ) + serve1 + + writefile2( "0"*4 ) + serve2 + + @nbd1.can_die + mirror12 + assert_equal(@file1.read_original( 0, @blocksize ), + @file2.read( 0, @blocksize ) ) + end + protected def serve1(*acl) @nbd1.serve(@filename1, *acl) end + def serve2(*acl) + @nbd2.serve(@filename2, *acl) + end + + def mirror12 + @nbd1.mirror( @nbd2.ip, @nbd2.port ) + end + def writefile1(data) @file1 = TestFileWriter.new(@filename1, @blocksize).write(data) end + def writefile2(data) + @file2 = TestFileWriter.new(@filename2, @blocksize).write(data) + end + + + def listening_ports `netstat -ltn`. split("\n"). diff --git a/tests/test_file_writer.rb b/tests/test_file_writer.rb index 0ff6f5d..025e340 100644 --- a/tests/test_file_writer.rb +++ b/tests/test_file_writer.rb @@ -27,20 +27,15 @@ class TestFileWriter self end + # Returns what the data ought to be at the given offset and length # - def read_original(off, len) - r="" - current = 0 - @pattern.split("").each do |block| - if off >= current && (off+len) < current + blocksize - current += data(block, current)[ - current-off..(current+blocksize)-(off+len) - ] - end - current += @blocksize - end - r + def read_original( off, len ) + patterns = @pattern.split( "" ) + patterns.zip( (0...patterns.length).to_a ). + map { |blk, blk_off| + data(blk, blk_off) + }.join[off...(off+len)] end # Read what's actually in the file @@ -51,7 +46,7 @@ class TestFileWriter end def untouched?(offset, len) - read(off, len) == read_original(off, len) + read(offset, len) == read_original(offset, len) end def close @@ -81,3 +76,48 @@ class TestFileWriter end +if __FILE__==$0 + require 'tempfile' + require 'test/unit' + + class TestFileWriterTest < Test::Unit::TestCase + def test_read_original_zeros + Tempfile.open("test_read_original_zeros") do |tempfile| + tempfile.close + file = TestFileWriter.new( tempfile.path, 4096 ) + file.write( "0" ) + assert_equal file.read( 0, 4096 ), file.read_original( 0, 4096 ) + assert( file.untouched?(0,4096) , "Untouched file was touched." ) + end + end + + def test_read_original_offsets + Tempfile.open("test_read_original_offsets") do |tempfile| + tempfile.close + file = TestFileWriter.new( tempfile.path, 4096 ) + file.write( "f" ) + assert_equal file.read( 0, 4096 ), file.read_original( 0, 4096 ) + assert( file.untouched?(0,4096) , "Untouched file was touched." ) + end + end + + def test_file_size + Tempfile.open("test_file_size") do |tempfile| + tempfile.close + file = TestFileWriter.new( tempfile.path, 4096 ) + file.write( "f" ) + assert_equal 4096, File.stat( tempfile.path ).size + end + end + + def test_read_original_size + Tempfile.open("test_read_original_offsets") do |tempfile| + tempfile.close + file = TestFileWriter.new( tempfile.path, 4) + file.write( "f"*4 ) + assert_equal 4, file.read_original(0, 4).length + end + end + end +end +