diff --git a/Rakefile b/Rakefile index 5f6fec2..e136848 100644 --- a/Rakefile +++ b/Rakefile @@ -14,13 +14,17 @@ if DEBUG CCFLAGS << ["-g -DDEBUG"] end +desc "Build flexnbd binary" rule 'default' => 'flexnbd' namespace "test" do + desc "Run all tests" task 'run' => ["unit", "scenarios"] + desc "Build C tests" task 'build' => TEST_MODULES.map { |n| "tests/check_#{n}" } + desc "Run C tests" task 'unit' => 'build' do TEST_MODULES.each do |n| ENV['EF_DISABLE_BANNER'] = '1' @@ -28,7 +32,8 @@ namespace "test" do end end - task 'scenarios' do + desc "Run NBD test scenarios" + task 'scenarios' => 'flexnbd' do sh "cd tests; ruby nbd_scenarios" end end @@ -52,6 +57,7 @@ rule '.o' => '.c' do |t| sh "gcc -I. -c #{CCFLAGS.join(' ')} -o #{t.name} #{t.source} " end +desc "Remove all build targets, binaries and temporary files" rule 'clean' do sh "rm -f *~ flexnbd " + ( OBJECTS + diff --git a/flexnbd.c b/flexnbd.c index 2548340..15874c3 100644 --- a/flexnbd.c +++ b/flexnbd.c @@ -9,7 +9,7 @@ #include #include #include - +#include void syntax() { @@ -175,6 +175,7 @@ void mode(char* mode, int argc, char **argv) int main(int argc, char** argv) { + signal(SIGPIPE, SIG_IGN); error_init(); if (argc < 2) diff --git a/ioutil.c b/ioutil.c index fac17e7..1486790 100644 --- a/ioutil.c +++ b/ioutil.c @@ -132,16 +132,48 @@ int sendfileloop(int out_fd, int in_fd, off64_t *offset, size_t count) { size_t sent=0; while (sent < count) { - size_t result = sendfile64(out_fd, in_fd, offset+sent, count-sent); + size_t result = sendfile64(out_fd, in_fd, offset, count-sent); + debug("sendfile64(out_fd=%d, in_fd=%d, offset=%p, count-sent=%ld) = %ld", out_fd, in_fd, offset, count-sent, result); + if (result == -1) return -1; sent += result; + debug("sent=%ld, count=%ld", sent, count); } + debug("exiting sendfileloop"); return 0; } +#include +ssize_t spliceloop(int fd_in, loff_t *off_in, int fd_out, loff_t *off_out, size_t len, unsigned int flags2) +{ + const unsigned int flags = SPLICE_F_MORE|SPLICE_F_MOVE|flags2; + size_t spliced=0; + + //debug("spliceloop(%d, %ld, %d, %ld, %ld)", fd_in, off_in ? *off_in : 0, fd_out, off_out ? *off_out : 0, len); + + while (spliced < len) { + ssize_t result = splice(fd_in, off_in, fd_out, off_out, len, flags); + if (result < 0) { + //debug("result=%ld (%s), spliced=%ld, len=%ld", result, strerror(errno), spliced, len); + if (errno == EAGAIN && (flags & SPLICE_F_NONBLOCK) ) { + return spliced; + } + else { + return -1; + } + } else { + spliced += result; + //debug("result=%ld (%s), spliced=%ld, len=%ld", result, strerror(errno), spliced, len); + } + } + + return spliced; +} + int splice_via_pipe_loop(int fd_in, int fd_out, size_t len) { + int pipefd[2]; /* read end, write end */ size_t spliced=0; @@ -149,18 +181,17 @@ int splice_via_pipe_loop(int fd_in, int fd_out, size_t len) return -1; while (spliced < len) { - size_t r1,r2; - size_t run = len-spliced; + ssize_t run = len-spliced; + ssize_t s2, s1 = spliceloop(fd_in, NULL, pipefd[1], NULL, run, SPLICE_F_NONBLOCK); /*if (run > 65535) run = 65535;*/ - r1 = splice(fd_in, NULL, pipefd[1], NULL, run, SPLICE_F_MORE|SPLICE_F_MOVE|SPLICE_F_NONBLOCK); - debug("%ld", r1); - if (r1 <= 0) + if (s1 < 0) break; - r2 = splice(pipefd[0], NULL, fd_out, NULL, r1, SPLICE_F_MORE|SPLICE_F_MOVE); - if (r1 != r2) + + s2 = spliceloop(pipefd[0], NULL, fd_out, NULL, s1, 0); + if (s2 < 0) break; - spliced += r1; + spliced += s2; } close(pipefd[0]); close(pipefd[1]); diff --git a/params.h b/params.h index 6a94ea1..dfcb653 100644 --- a/params.h +++ b/params.h @@ -30,6 +30,7 @@ struct control_params { struct mode_serve_params* serve; }; +#define MAX_NBD_CLIENTS 16 struct mode_serve_params { /* address/port to bind to */ union mysockaddr bind_to; @@ -60,6 +61,9 @@ struct mode_serve_params { int control; char* block_allocation_map; + + struct { pthread_t thread; struct sockaddr address; } + nbd_client[MAX_NBD_CLIENTS]; }; struct mode_readwrite_params { @@ -79,5 +83,18 @@ struct client_params { struct mode_serve_params* serve; /* FIXME: remove above duplication */ }; +/* FIXME: wrong place */ +static inline void* sockaddr_address_data(struct sockaddr* sockaddr) +{ + struct sockaddr_in* in = (struct sockaddr_in*) sockaddr; + struct sockaddr_in6* in6 = (struct sockaddr_in6*) sockaddr; + + if (sockaddr->sa_family == AF_INET) + return &in->sin_addr; + if (sockaddr->sa_family == AF_INET6) + return &in6->sin6_addr; + return NULL; +} + #endif diff --git a/readwrite.c b/readwrite.c index d469c76..ff1631b 100644 --- a/readwrite.c +++ b/readwrite.c @@ -9,9 +9,9 @@ int socket_connect(struct sockaddr* to) { - int fd = socket(PF_INET, SOCK_STREAM, 0); + int fd = socket(to->sa_family == AF_INET ? PF_INET : PF_INET6, SOCK_STREAM, 0); SERVER_ERROR_ON_FAILURE(fd, "Couldn't create client socket"); - SERVER_ERROR_ON_FAILURE(connect(fd, to, sizeof(*to)), + SERVER_ERROR_ON_FAILURE(connect(fd, to, sizeof(struct sockaddr_in6)), "connect failed"); return fd; } diff --git a/serve.c b/serve.c index a19c3b7..9099408 100644 --- a/serve.c +++ b/serve.c @@ -326,14 +326,15 @@ int is_included_in_acl(int list_length, struct ip_and_mask (*list)[], struct soc void serve_open_server_socket(struct mode_serve_params* params) { - params->server = socket(PF_INET, SOCK_STREAM, 0); + params->server = socket(params->bind_to.generic.sa_family == AF_INET ? + PF_INET : PF_INET6, SOCK_STREAM, 0); SERVER_ERROR_ON_FAILURE(params->server, "Couldn't create server socket"); - + SERVER_ERROR_ON_FAILURE( bind(params->server, ¶ms->bind_to.generic, - sizeof(params->bind_to.generic)), + sizeof(params->bind_to)), "Couldn't bind server to IP address" ); @@ -343,10 +344,54 @@ void serve_open_server_socket(struct mode_serve_params* params) ); } +int cleanup_and_find_client_slot(struct mode_serve_params* params) +{ + int slot=-1, i; + + for (i=0; i < MAX_NBD_CLIENTS; i++) { + void* status; + + if (params->nbd_client[i].thread != 0) { + char s_client_address[64]; + + memset(s_client_address, 0, 64); + strcpy(s_client_address, "???"); + inet_ntop( + params->nbd_client[i].address.sa_family, + sockaddr_address_data(¶ms->nbd_client[i].address), + s_client_address, + 64 + ); + + if (pthread_tryjoin_np(params->nbd_client[i].thread, &status) < 0) { + if (errno != EBUSY) + SERVER_ERROR_ON_FAILURE(-1, "Problem with joining thread"); + } + else { + uint64_t status1 = (uint64_t) status; + params->nbd_client[i].thread = 0; + debug("nbd thread %d exited (%s)", (int) params->nbd_client[i].thread, s_client_address); + } + } + + if (params->nbd_client[i].thread == 0 && slot == -1) + slot = i; + } + + return slot; +} + void accept_nbd_client(struct mode_serve_params* params, int client_fd, struct sockaddr* client_address) { - pthread_t client_thread; struct client_params* client_params; + int slot = cleanup_and_find_client_slot(params); + char s_client_address[64]; + + if (inet_ntop(client_address->sa_family, sockaddr_address_data(client_address), s_client_address, 64) < 0) { + write(client_fd, "Bad client_address", 18); + close(client_fd); + return; + } if (params->acl && !is_included_in_acl(params->acl_entries, params->acl, client_address)) { @@ -355,21 +400,27 @@ void accept_nbd_client(struct mode_serve_params* params, int client_fd, struct s return; } + if (slot < 0) { + write(client_fd, "Too many clients", 16); + close(client_fd); + return; + } + client_params = xmalloc(sizeof(struct client_params)); client_params->socket = client_fd; client_params->serve = params; - SERVER_ERROR_ON_FAILURE( - pthread_create( - &client_thread, - NULL, - client_serve, - client_params - ), - "Failed to create client thread" - ); - /* FIXME: keep track of them? */ - /* FIXME: maybe shouldn't be fatal? */ + if (pthread_create(¶ms->nbd_client[slot].thread, NULL, client_serve, client_params) < 0) { + write(client_fd, "Thread creation problem", 23); + free(client_params); + close(client_fd); + return; + } + + memcpy(¶ms->nbd_client[slot].address, client_address, + sizeof(struct sockaddr)); + + debug("nbd thread %d started (%s)", (int) params->nbd_client[slot].thread, s_client_address); } void serve_accept_loop(struct mode_serve_params* params) diff --git a/tests/nbd_scenarios b/tests/nbd_scenarios index a45fc70..faa4767 100644 --- a/tests/nbd_scenarios +++ b/tests/nbd_scenarios @@ -40,6 +40,8 @@ class NBDScenarios < Test::Unit::TestCase end end + # Check that we're not + # def test_writeread1 writefile1("0"*64) serve1 @@ -52,6 +54,23 @@ class NBDScenarios < Test::Unit::TestCase end end + # Check that we're not overstepping or understepping where our writes end + # up. + # + def test_writeread2 + writefile1("0"*1024) + serve1 + + d0 = "\0"*@blocksize + d1 = "X"*@blocksize + (0..63).each do |num| + @nbd1.write(num*@blocksize*2, d1) + end + (0..63).each do |num| + assert_equal(d0, @nbd1.read(((2*num)+1)*@blocksize, d0.size)) + end + end + protected def serve1(*acl) @nbd1.serve(@ip, @port1, @filename1, *acl)