Added another write/read test, fixed bugs in splice() usage and IPv6

socket handling.
This commit is contained in:
Matthew Bloch
2012-05-27 14:40:16 +01:00
parent 5a5041a751
commit c54d4a68ba
7 changed files with 153 additions and 28 deletions

View File

@@ -14,13 +14,17 @@ if DEBUG
CCFLAGS << ["-g -DDEBUG"] CCFLAGS << ["-g -DDEBUG"]
end end
desc "Build flexnbd binary"
rule 'default' => 'flexnbd' rule 'default' => 'flexnbd'
namespace "test" do namespace "test" do
desc "Run all tests"
task 'run' => ["unit", "scenarios"] task 'run' => ["unit", "scenarios"]
desc "Build C tests"
task 'build' => TEST_MODULES.map { |n| "tests/check_#{n}" } task 'build' => TEST_MODULES.map { |n| "tests/check_#{n}" }
desc "Run C tests"
task 'unit' => 'build' do task 'unit' => 'build' do
TEST_MODULES.each do |n| TEST_MODULES.each do |n|
ENV['EF_DISABLE_BANNER'] = '1' ENV['EF_DISABLE_BANNER'] = '1'
@@ -28,7 +32,8 @@ namespace "test" do
end end
end end
task 'scenarios' do desc "Run NBD test scenarios"
task 'scenarios' => 'flexnbd' do
sh "cd tests; ruby nbd_scenarios" sh "cd tests; ruby nbd_scenarios"
end end
end end
@@ -52,6 +57,7 @@ rule '.o' => '.c' do |t|
sh "gcc -I. -c #{CCFLAGS.join(' ')} -o #{t.name} #{t.source} " sh "gcc -I. -c #{CCFLAGS.join(' ')} -o #{t.name} #{t.source} "
end end
desc "Remove all build targets, binaries and temporary files"
rule 'clean' do rule 'clean' do
sh "rm -f *~ flexnbd " + ( sh "rm -f *~ flexnbd " + (
OBJECTS + OBJECTS +

View File

@@ -9,7 +9,7 @@
#include <sys/stat.h> #include <sys/stat.h>
#include <fcntl.h> #include <fcntl.h>
#include <unistd.h> #include <unistd.h>
#include <signal.h>
void syntax() void syntax()
{ {
@@ -175,6 +175,7 @@ void mode(char* mode, int argc, char **argv)
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
signal(SIGPIPE, SIG_IGN);
error_init(); error_init();
if (argc < 2) if (argc < 2)

View File

@@ -132,16 +132,48 @@ int sendfileloop(int out_fd, int in_fd, off64_t *offset, size_t count)
{ {
size_t sent=0; size_t sent=0;
while (sent < count) { while (sent < count) {
size_t result = sendfile64(out_fd, in_fd, offset+sent, count-sent); size_t result = sendfile64(out_fd, in_fd, offset, count-sent);
debug("sendfile64(out_fd=%d, in_fd=%d, offset=%p, count-sent=%ld) = %ld", out_fd, in_fd, offset, count-sent, result);
if (result == -1) if (result == -1)
return -1; return -1;
sent += result; sent += result;
debug("sent=%ld, count=%ld", sent, count);
} }
debug("exiting sendfileloop");
return 0; return 0;
} }
#include <errno.h>
ssize_t spliceloop(int fd_in, loff_t *off_in, int fd_out, loff_t *off_out, size_t len, unsigned int flags2)
{
const unsigned int flags = SPLICE_F_MORE|SPLICE_F_MOVE|flags2;
size_t spliced=0;
//debug("spliceloop(%d, %ld, %d, %ld, %ld)", fd_in, off_in ? *off_in : 0, fd_out, off_out ? *off_out : 0, len);
while (spliced < len) {
ssize_t result = splice(fd_in, off_in, fd_out, off_out, len, flags);
if (result < 0) {
//debug("result=%ld (%s), spliced=%ld, len=%ld", result, strerror(errno), spliced, len);
if (errno == EAGAIN && (flags & SPLICE_F_NONBLOCK) ) {
return spliced;
}
else {
return -1;
}
} else {
spliced += result;
//debug("result=%ld (%s), spliced=%ld, len=%ld", result, strerror(errno), spliced, len);
}
}
return spliced;
}
int splice_via_pipe_loop(int fd_in, int fd_out, size_t len) int splice_via_pipe_loop(int fd_in, int fd_out, size_t len)
{ {
int pipefd[2]; /* read end, write end */ int pipefd[2]; /* read end, write end */
size_t spliced=0; size_t spliced=0;
@@ -149,18 +181,17 @@ int splice_via_pipe_loop(int fd_in, int fd_out, size_t len)
return -1; return -1;
while (spliced < len) { while (spliced < len) {
size_t r1,r2; ssize_t run = len-spliced;
size_t run = len-spliced; ssize_t s2, s1 = spliceloop(fd_in, NULL, pipefd[1], NULL, run, SPLICE_F_NONBLOCK);
/*if (run > 65535) /*if (run > 65535)
run = 65535;*/ run = 65535;*/
r1 = splice(fd_in, NULL, pipefd[1], NULL, run, SPLICE_F_MORE|SPLICE_F_MOVE|SPLICE_F_NONBLOCK); if (s1 < 0)
debug("%ld", r1);
if (r1 <= 0)
break; break;
r2 = splice(pipefd[0], NULL, fd_out, NULL, r1, SPLICE_F_MORE|SPLICE_F_MOVE);
if (r1 != r2) s2 = spliceloop(pipefd[0], NULL, fd_out, NULL, s1, 0);
if (s2 < 0)
break; break;
spliced += r1; spliced += s2;
} }
close(pipefd[0]); close(pipefd[0]);
close(pipefd[1]); close(pipefd[1]);

View File

@@ -30,6 +30,7 @@ struct control_params {
struct mode_serve_params* serve; struct mode_serve_params* serve;
}; };
#define MAX_NBD_CLIENTS 16
struct mode_serve_params { struct mode_serve_params {
/* address/port to bind to */ /* address/port to bind to */
union mysockaddr bind_to; union mysockaddr bind_to;
@@ -60,6 +61,9 @@ struct mode_serve_params {
int control; int control;
char* block_allocation_map; char* block_allocation_map;
struct { pthread_t thread; struct sockaddr address; }
nbd_client[MAX_NBD_CLIENTS];
}; };
struct mode_readwrite_params { struct mode_readwrite_params {
@@ -79,5 +83,18 @@ struct client_params {
struct mode_serve_params* serve; /* FIXME: remove above duplication */ struct mode_serve_params* serve; /* FIXME: remove above duplication */
}; };
/* FIXME: wrong place */
static inline void* sockaddr_address_data(struct sockaddr* sockaddr)
{
struct sockaddr_in* in = (struct sockaddr_in*) sockaddr;
struct sockaddr_in6* in6 = (struct sockaddr_in6*) sockaddr;
if (sockaddr->sa_family == AF_INET)
return &in->sin_addr;
if (sockaddr->sa_family == AF_INET6)
return &in6->sin6_addr;
return NULL;
}
#endif #endif

View File

@@ -9,9 +9,9 @@
int socket_connect(struct sockaddr* to) int socket_connect(struct sockaddr* to)
{ {
int fd = socket(PF_INET, SOCK_STREAM, 0); int fd = socket(to->sa_family == AF_INET ? PF_INET : PF_INET6, SOCK_STREAM, 0);
SERVER_ERROR_ON_FAILURE(fd, "Couldn't create client socket"); SERVER_ERROR_ON_FAILURE(fd, "Couldn't create client socket");
SERVER_ERROR_ON_FAILURE(connect(fd, to, sizeof(*to)), SERVER_ERROR_ON_FAILURE(connect(fd, to, sizeof(struct sockaddr_in6)),
"connect failed"); "connect failed");
return fd; return fd;
} }

81
serve.c
View File

@@ -326,14 +326,15 @@ int is_included_in_acl(int list_length, struct ip_and_mask (*list)[], struct soc
void serve_open_server_socket(struct mode_serve_params* params) void serve_open_server_socket(struct mode_serve_params* params)
{ {
params->server = socket(PF_INET, SOCK_STREAM, 0); params->server = socket(params->bind_to.generic.sa_family == AF_INET ?
PF_INET : PF_INET6, SOCK_STREAM, 0);
SERVER_ERROR_ON_FAILURE(params->server, SERVER_ERROR_ON_FAILURE(params->server,
"Couldn't create server socket"); "Couldn't create server socket");
SERVER_ERROR_ON_FAILURE( SERVER_ERROR_ON_FAILURE(
bind(params->server, &params->bind_to.generic, bind(params->server, &params->bind_to.generic,
sizeof(params->bind_to.generic)), sizeof(params->bind_to)),
"Couldn't bind server to IP address" "Couldn't bind server to IP address"
); );
@@ -343,10 +344,54 @@ void serve_open_server_socket(struct mode_serve_params* params)
); );
} }
int cleanup_and_find_client_slot(struct mode_serve_params* params)
{
int slot=-1, i;
for (i=0; i < MAX_NBD_CLIENTS; i++) {
void* status;
if (params->nbd_client[i].thread != 0) {
char s_client_address[64];
memset(s_client_address, 0, 64);
strcpy(s_client_address, "???");
inet_ntop(
params->nbd_client[i].address.sa_family,
sockaddr_address_data(&params->nbd_client[i].address),
s_client_address,
64
);
if (pthread_tryjoin_np(params->nbd_client[i].thread, &status) < 0) {
if (errno != EBUSY)
SERVER_ERROR_ON_FAILURE(-1, "Problem with joining thread");
}
else {
uint64_t status1 = (uint64_t) status;
params->nbd_client[i].thread = 0;
debug("nbd thread %d exited (%s)", (int) params->nbd_client[i].thread, s_client_address);
}
}
if (params->nbd_client[i].thread == 0 && slot == -1)
slot = i;
}
return slot;
}
void accept_nbd_client(struct mode_serve_params* params, int client_fd, struct sockaddr* client_address) void accept_nbd_client(struct mode_serve_params* params, int client_fd, struct sockaddr* client_address)
{ {
pthread_t client_thread;
struct client_params* client_params; struct client_params* client_params;
int slot = cleanup_and_find_client_slot(params);
char s_client_address[64];
if (inet_ntop(client_address->sa_family, sockaddr_address_data(client_address), s_client_address, 64) < 0) {
write(client_fd, "Bad client_address", 18);
close(client_fd);
return;
}
if (params->acl && if (params->acl &&
!is_included_in_acl(params->acl_entries, params->acl, client_address)) { !is_included_in_acl(params->acl_entries, params->acl, client_address)) {
@@ -355,21 +400,27 @@ void accept_nbd_client(struct mode_serve_params* params, int client_fd, struct s
return; return;
} }
if (slot < 0) {
write(client_fd, "Too many clients", 16);
close(client_fd);
return;
}
client_params = xmalloc(sizeof(struct client_params)); client_params = xmalloc(sizeof(struct client_params));
client_params->socket = client_fd; client_params->socket = client_fd;
client_params->serve = params; client_params->serve = params;
SERVER_ERROR_ON_FAILURE( if (pthread_create(&params->nbd_client[slot].thread, NULL, client_serve, client_params) < 0) {
pthread_create( write(client_fd, "Thread creation problem", 23);
&client_thread, free(client_params);
NULL, close(client_fd);
client_serve, return;
client_params }
),
"Failed to create client thread" memcpy(&params->nbd_client[slot].address, client_address,
); sizeof(struct sockaddr));
/* FIXME: keep track of them? */
/* FIXME: maybe shouldn't be fatal? */ debug("nbd thread %d started (%s)", (int) params->nbd_client[slot].thread, s_client_address);
} }
void serve_accept_loop(struct mode_serve_params* params) void serve_accept_loop(struct mode_serve_params* params)

View File

@@ -40,6 +40,8 @@ class NBDScenarios < Test::Unit::TestCase
end end
end end
# Check that we're not
#
def test_writeread1 def test_writeread1
writefile1("0"*64) writefile1("0"*64)
serve1 serve1
@@ -52,6 +54,23 @@ class NBDScenarios < Test::Unit::TestCase
end end
end end
# Check that we're not overstepping or understepping where our writes end
# up.
#
def test_writeread2
writefile1("0"*1024)
serve1
d0 = "\0"*@blocksize
d1 = "X"*@blocksize
(0..63).each do |num|
@nbd1.write(num*@blocksize*2, d1)
end
(0..63).each do |num|
assert_equal(d0, @nbd1.read(((2*num)+1)*@blocksize, d0.size))
end
end
protected protected
def serve1(*acl) def serve1(*acl)
@nbd1.serve(@ip, @port1, @filename1, *acl) @nbd1.serve(@ip, @port1, @filename1, *acl)