diff --git a/Rakefile b/Rakefile index e136848..a8d47cc 100644 --- a/Rakefile +++ b/Rakefile @@ -1,6 +1,6 @@ DEBUG = true -SOURCES = %w( flexnbd ioutil readwrite serve util parse control ) +SOURCES = %w( flexnbd ioutil readwrite serve util parse control remote ) OBJECTS = SOURCES.map { |s| "#{s}.o" } LIBS = %w( pthread ) CCFLAGS = %w( -Wall ) diff --git a/control.c b/control.c index 1930e9c..80a30a2 100644 --- a/control.c +++ b/control.c @@ -66,7 +66,6 @@ int control_mirror(struct control_params* client, int linesc, char** lines) int fd, map_fd; struct mirror_status *mirror; union mysockaddr connect_to; - char s_ip_address[64], s_port[8]; uint64_t max_bytes_per_second; int action_at_finish; @@ -75,12 +74,12 @@ int control_mirror(struct control_params* client, int linesc, char** lines) return -1; } - if (parse_ip_to_sockaddr(&connect_to.generic, s_ip_address) == 0) { + if (parse_ip_to_sockaddr(&connect_to.generic, lines[0]) == 0) { write_socket("1: bad IP address"); return -1; } - connect_to.v4.sin_port = atoi(s_port); + connect_to.v4.sin_port = atoi(lines[1]); if (connect_to.v4.sin_port < 0 || connect_to.v4.sin_port > 65535) { write_socket("1: bad IP port number"); return -1; @@ -153,14 +152,13 @@ int control_mirror(struct control_params* client, int linesc, char** lines) int control_acl(struct control_params* client, int linesc, char** lines) { int acl_entries = 0, parsed; - char** s_acl_entry = NULL; struct ip_and_mask (*acl)[], (*old_acl)[]; parsed = parse_acl(&acl, linesc, lines); if (parsed != linesc) { - write(client->socket, "1: bad spec ", 12); - write(client->socket, s_acl_entry[parsed], - strlen(s_acl_entry[parsed])); + write(client->socket, "1: bad spec: ", 13); + write(client->socket, lines[parsed], + strlen(lines[parsed])); write(client->socket, "\n", 1); free(acl); } @@ -256,7 +254,7 @@ void serve_open_control_socket(struct mode_serve_params* params) memset(&bind_address, 0, sizeof(bind_address)); bind_address.sun_family = AF_UNIX; - strcpy(bind_address.sun_path, params->control_socket_name); + strncpy(bind_address.sun_path, params->control_socket_name, sizeof(bind_address.sun_path)-1); unlink(params->control_socket_name); /* ignore failure */ diff --git a/flexnbd.c b/flexnbd.c index 15874c3..520119e 100644 --- a/flexnbd.c +++ b/flexnbd.c @@ -94,6 +94,7 @@ void params_readwrite( SERVER_ERROR("Couldn't parse connection address '%s'", s_ip_address); + /* FIXME: duplicated from above */ out->connect_to.v4.sin_port = atoi(s_port); if (out->connect_to.v4.sin_port < 0 || out->connect_to.v4.sin_port > 65535) SERVER_ERROR("Port number must be >= 0 and <= 65535"); @@ -129,6 +130,7 @@ void params_readwrite( void do_serve(struct mode_serve_params* params); void do_read(struct mode_readwrite_params* params); void do_write(struct mode_readwrite_params* params); +void do_remote_command(char* command, char* mode, int argc, char** argv); union mode_params { struct mode_serve_params serve; @@ -167,6 +169,14 @@ void mode(char* mode, int argc, char **argv) syntax(); } } + else if (strcmp(mode, "acl") == 0 || strcmp(mode, "mirror") == 0 || strcmp(mode, "status") == 0) { + if (argc >= 1) { + do_remote_command(mode, argv[0], argc-1, argv+1); + } + else { + syntax(); + } + } else { syntax(); } diff --git a/ioutil.c b/ioutil.c index 1486790..a77f959 100644 --- a/ioutil.c +++ b/ioutil.c @@ -208,7 +208,7 @@ int read_until_newline(int fd, char* buf, int bufsize) int result = read(fd, buf+cur, 1); if (result < 0) return -1; - if (buf[cur] == 10 || buf[cur] == 13) + if (buf[cur] == 10) break; } buf[cur++] = 0; diff --git a/params.h b/params.h index dfcb653..bcf2f6b 100644 --- a/params.h +++ b/params.h @@ -46,15 +46,13 @@ struct mode_serve_params { char* control_socket_name; /* size of file */ off64_t size; - /* if you want the main thread to pause, set this to an writeable - * file descriptor. The main thread will then write a byte once it - * promises to hang any further writes. - */ - int pause_fd; - /* the main thread will set this when writes will be paused */ - int paused; - /* set to non-zero to use given destination connection as proxy */ - int proxy_fd; + + /* NB dining philosophers if we ever mave more than one thread + * that might need to pause the whole server. At the moment we only + * have the one. + */ + pthread_mutex_t l_accept; /* accept connections lock */ + pthread_mutex_t l_io ; /* read/write request lock */ struct mirror_status* mirror; int server; diff --git a/parse.c b/parse.c index 5dcad9c..e5dab6c 100644 --- a/parse.c +++ b/parse.c @@ -8,6 +8,7 @@ int atoi(const char *nptr); ((x) >= 'A' && (x) <= 'F' ) || \ (x) == ':' || (x) == '.' \ ) +/* FIXME: should change this to return negative on error like everything else */ int parse_ip_to_sockaddr(struct sockaddr* out, char* src) { char temp[64]; diff --git a/remote.c b/remote.c new file mode 100644 index 0000000..8e56def --- /dev/null +++ b/remote.c @@ -0,0 +1,51 @@ +#include "ioutil.h" +#include "util.h" + +#include +#include + +static const int max_response=1024; + +void do_remote_command(char* command, char* socket_name, int argc, char** argv) +{ + char newline=10; + int i; + int exit_status; + int remote = socket(AF_UNIX, SOCK_STREAM, 0); + struct sockaddr_un address; + char response[max_response]; + + memset(&address, 0, sizeof(address)); + + SERVER_ERROR_ON_FAILURE(remote, "Couldn't create client socket"); + + address.sun_family = AF_UNIX; + strncpy(address.sun_path, socket_name, sizeof(address.sun_path)); + + SERVER_ERROR_ON_FAILURE( + connect(remote, (struct sockaddr*) &address, sizeof(address)), + "Couldn't connect to %s", socket_name + ); + + write(remote, command, strlen(command)); + write(remote, &newline, 1); + for (i=0; i 0) + fprintf(stderr, "%s\n", strchr(response, ':')+2); + + exit(atoi(response)); + + close(remote); +} + diff --git a/serve.c b/serve.c index 9099408..a0e910d 100644 --- a/serve.c +++ b/serve.c @@ -98,7 +98,7 @@ void write_not_zeroes(struct client_params* client, off64_t from, int len) * hand-optimized something specific. */ if (zerobuffer[0] != 0 || - memcmp(zerobuffer, zerobuffer + 1, blockrun)) { + memcmp(zerobuffer, zerobuffer + 1, blockrun - 1)) { memcpy(dst, zerobuffer, blockrun); bit_set(map, bit); dirty(client->serve, from, blockrun); @@ -171,6 +171,11 @@ int client_serve_request(struct client_params* client) CLIENT_ERROR("Unknown request %08x", be32toh(request.type)); } + CLIENT_ERROR_ON_FAILURE( + pthread_mutex_lock(&client->serve->l_io), + "Problem with I/O lock" + ); + switch (be32toh(request.type)) { case REQUEST_READ: @@ -217,6 +222,12 @@ int client_serve_request(struct client_params* client) break; } + + CLIENT_ERROR_ON_FAILURE( + pthread_mutex_unlock(&client->serve->l_io), + "Problem with I/O unlock" + ); + return 0; } @@ -326,12 +337,19 @@ int is_included_in_acl(int list_length, struct ip_and_mask (*list)[], struct soc void serve_open_server_socket(struct mode_serve_params* params) { + int optval=1; + params->server = socket(params->bind_to.generic.sa_family == AF_INET ? PF_INET : PF_INET6, SOCK_STREAM, 0); SERVER_ERROR_ON_FAILURE(params->server, "Couldn't create server socket"); + SERVER_ERROR_ON_FAILURE( + setsockopt(params->server, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)), + "Couldn't set SO_REUSEADDR" + ); + SERVER_ERROR_ON_FAILURE( bind(params->server, ¶ms->bind_to.generic, sizeof(params->bind_to)), @@ -370,7 +388,7 @@ int cleanup_and_find_client_slot(struct mode_serve_params* params) else { uint64_t status1 = (uint64_t) status; params->nbd_client[i].thread = 0; - debug("nbd thread %d exited (%s)", (int) params->nbd_client[i].thread, s_client_address); + debug("nbd thread %d exited (%s) with status %ld", (int) params->nbd_client[i].thread, s_client_address, status1); } } @@ -387,7 +405,7 @@ void accept_nbd_client(struct mode_serve_params* params, int client_fd, struct s int slot = cleanup_and_find_client_slot(params); char s_client_address[64]; - if (inet_ntop(client_address->sa_family, sockaddr_address_data(client_address), s_client_address, 64) < 0) { + if (inet_ntop(client_address->sa_family, sockaddr_address_data(client_address), s_client_address, 64) == NULL) { write(client_fd, "Bad client_address", 18); close(client_fd); return; @@ -446,10 +464,20 @@ void serve_accept_loop(struct mode_serve_params* params) client_fd = accept(activity_fd, &client_address, &socklen); SERVER_ERROR_ON_FAILURE(client_fd, "accept() failed"); + SERVER_ERROR_ON_FAILURE( + pthread_mutex_lock(¶ms->l_accept), + "Problem with accept lock" + ); + if (activity_fd == params->server) accept_nbd_client(params, client_fd, &client_address); if (activity_fd == params->control) accept_control_connection(params, client_fd, &client_address); + + SERVER_ERROR_ON_FAILURE( + pthread_mutex_unlock(¶ms->l_accept), + "Problem with accept unlock" + ); } } @@ -469,6 +497,9 @@ void serve_init_allocation_map(struct mode_serve_params* params) void do_serve(struct mode_serve_params* params) { + pthread_mutex_init(¶ms->l_accept, NULL); + pthread_mutex_init(¶ms->l_io, NULL); + serve_open_server_socket(params); serve_open_control_socket(params); serve_init_allocation_map(params); diff --git a/util.c b/util.c index fd6a1ac..0f7e2f1 100644 --- a/util.c +++ b/util.c @@ -16,7 +16,7 @@ void error_init() main_thread = pthread_self(); } -void error(int consult_errno, int close_socket, const char* format, ...) +void error(int consult_errno, int close_socket, pthread_mutex_t* unlock, const char* format, ...) { va_list argptr; @@ -33,6 +33,9 @@ void error(int consult_errno, int close_socket, const char* format, ...) if (close_socket) close(close_socket); + if (unlock) + pthread_mutex_unlock(unlock); + fprintf(stderr, "\n"); if (pthread_equal(pthread_self(), main_thread)) diff --git a/util.h b/util.h index e1c5888..932f97d 100644 --- a/util.h +++ b/util.h @@ -6,7 +6,7 @@ void error_init(); -void error(int consult_errno, int close_socket, const char* format, ...); +void error(int consult_errno, int close_socket, pthread_mutex_t* unlock, const char* format, ...); void* xrealloc(void* ptr, size_t size); @@ -21,14 +21,14 @@ void* xmalloc(size_t size); #endif #define CLIENT_ERROR(msg, ...) \ - error(0, client->socket, msg, ##__VA_ARGS__) + error(0, client->socket, &client->serve->l_io, msg, ##__VA_ARGS__) #define CLIENT_ERROR_ON_FAILURE(test, msg, ...) \ - if (test < 0) { error(1, client->socket, msg, ##__VA_ARGS__); } + if (test < 0) { error(1, client->socket, &client->serve->l_io, msg, ##__VA_ARGS__); } #define SERVER_ERROR(msg, ...) \ - error(0, 0, msg, ##__VA_ARGS__) + error(0, 0, NULL, msg, ##__VA_ARGS__) #define SERVER_ERROR_ON_FAILURE(test, msg, ...) \ - if (test < 0) { error(1, 0, msg, ##__VA_ARGS__); } + if (test < 0) { error(1, 0, NULL, msg, ##__VA_ARGS__); } #endif