From f688d416a55f54e2e491176113d10e02a0ac493d Mon Sep 17 00:00:00 2001 From: Matthew Bloch Date: Wed, 16 May 2012 11:58:41 +0100 Subject: [PATCH] Added write mode. --- flexnbd.c | 121 +++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 93 insertions(+), 28 deletions(-) diff --git a/flexnbd.c b/flexnbd.c index ada99d4..7e76624 100644 --- a/flexnbd.c +++ b/flexnbd.c @@ -54,7 +54,8 @@ void syntax() fprintf(stderr, "Syntax: flexnbd serve [ip addresses ...]\n" " flexnbd read > data\n" - " flexnbd write [length] < data\n" + " flexnbd write < data\n" + " flexnbd write \n" " flexnbd mirror \n" ); exit(1); @@ -487,26 +488,37 @@ off64_t socket_nbd_read_hello(int fd) return be64toh(init.size); } +void fill_request(struct nbd_request *request, int type, int from, int len) +{ + request->magic = htobe32(REQUEST_MAGIC); + request->type = htobe32(type); + ((int*) request->handle)[0] = rand(); + ((int*) request->handle)[1] = rand(); + request->from = htobe64(from); + request->len = htobe32(len); +} + +void read_reply(int fd, struct nbd_request *request, struct nbd_reply *reply) +{ + SERVER_ERROR_ON_FAILURE(readloop(fd, reply, sizeof(*reply)), + "Couldn't read reply"); + if (be32toh(reply->magic) != REPLY_MAGIC) + SERVER_ERROR("Reply magic incorrect (%p)", reply->magic); + if (be32toh(reply->error) != 0) + SERVER_ERROR("Server replied with error %d", reply->error); + if (strncmp(request->handle, reply->handle, 8) != 0) + SERVER_ERROR("Did not reply with correct handle"); +} + void socket_nbd_read(int fd, off64_t from, int len, int out_fd, void* out_buf) { struct nbd_request request; struct nbd_reply reply; - request.magic = htobe32(REQUEST_MAGIC); - request.type = htobe32(REQUEST_READ); - ((int*) request.handle)[0] = rand(); - ((int*) request.handle)[1] = rand(); - request.from = htobe64(from); - request.len = htobe32(len); - + fill_request(&request, REQUEST_READ, from, len); SERVER_ERROR_ON_FAILURE(writeloop(fd, &request, sizeof(request)), "Couldn't write request"); - SERVER_ERROR_ON_FAILURE(readloop(fd, &reply, sizeof(reply)), - "Couldn't read reply"); - if (be32toh(reply.magic) != REPLY_MAGIC) - SERVER_ERROR("Reply magic incorrect (%p)", reply.magic); - if (be32toh(reply.error) != 0) - SERVER_ERROR("Server replied with error %d", reply.error); + read_reply(fd, &request, &reply); if (out_buf) { SERVER_ERROR_ON_FAILURE(readloop(fd, out_buf, len), @@ -520,14 +532,42 @@ void socket_nbd_read(int fd, off64_t from, int len, int out_fd, void* out_buf) } } +void socket_nbd_write(int fd, off64_t from, int len, int in_fd, void* in_buf) +{ + struct nbd_request request; + struct nbd_reply reply; + + fill_request(&request, REQUEST_WRITE, from, len); + SERVER_ERROR_ON_FAILURE(writeloop(fd, &request, sizeof(request)), + "Couldn't write request"); + + if (in_buf) { + SERVER_ERROR_ON_FAILURE(writeloop(fd, in_buf, len), + "Write failed"); + } + else { + SERVER_ERROR_ON_FAILURE( + splice_via_pipe_loop(in_fd, fd, len), + "Splice failed" + ); + } + + read_reply(fd, &request, &reply); +} + +#define CHECK_RANGE(error_type) { \ + off64_t size = socket_nbd_read_hello(params->client); \ + if (params->from < 0 || (params->from + params->len) >= size) \ + SERVER_ERROR(error_type \ + " request %d+%d is out of range given size %d", \ + params->from, params->len, size\ + ); \ +} + void do_read(struct mode_readwrite_params* params) { - off64_t size; params->client = socket_connect(¶ms->connect_to.generic); - size = socket_nbd_read_hello(params->client); - if (params->from < 0 || (params->from + params->len) >= size) - SERVER_ERROR("Read request %d+%d is out of range given size %d", - params->from, params->len, size); + CHECK_RANGE("read"); socket_nbd_read(params->client, params->from, params->len, params->data_fd, NULL); close(params->client); @@ -536,6 +576,9 @@ void do_read(struct mode_readwrite_params* params) void do_write(struct mode_readwrite_params* params) { params->client = socket_connect(¶ms->connect_to.generic); + CHECK_RANGE("write"); + socket_nbd_write(params->client, params->from, params->len, + params->data_fd, NULL); close(params->client); } @@ -653,11 +696,12 @@ void params_serve( } void params_readwrite( + int write_not_read, struct mode_readwrite_params* out, char* s_ip_address, char* s_port, char* s_from, - char* s_length + char* s_length_or_filename ) { if (s_ip_address == NULL) @@ -666,20 +710,43 @@ void params_readwrite( SERVER_ERROR("No port number supplied"); if (s_from == NULL) SERVER_ERROR("No from supplied"); - if (s_length == NULL) + if (s_length_or_filename == NULL) SERVER_ERROR("No length supplied"); if (parse_ip_to_sockaddr(&out->connect_to.generic, s_ip_address) == 0) - SERVER_ERROR("Couldn't parse connection address '%s'"); + SERVER_ERROR("Couldn't parse connection address '%s'", + s_ip_address); out->connect_to.v4.sin_port = atoi(s_port); if (out->connect_to.v4.sin_port < 0 || out->connect_to.v4.sin_port > 65535) SERVER_ERROR("Port number must be >= 0 and <= 65535"); out->connect_to.v4.sin_port = htobe16(out->connect_to.v4.sin_port); - out->from = atol(s_from); - out->len = atol(s_length); + + if (write_not_read) { + if (s_length_or_filename[0]-48 < 10) { + out->len = atol(s_length_or_filename); + out->data_fd = 0; + } + else { + out->data_fd = open( + s_length_or_filename, O_RDONLY); + SERVER_ERROR_ON_FAILURE(out->data_fd, + "Couldn't open %s", s_length_or_filename); + out->len = lseek64(out->data_fd, 0, SEEK_END); + SERVER_ERROR_ON_FAILURE(out->len, + "Couldn't find length of %s", s_length_or_filename); + SERVER_ERROR_ON_FAILURE( + lseek64(out->data_fd, 0, SEEK_SET), + "Couldn't rewind %s", s_length_or_filename + ); + } + } + else { + out->len = atol(s_length_or_filename); + out->data_fd = 1; + } } void mode(char* mode, int argc, char **argv) @@ -697,8 +764,7 @@ void mode(char* mode, int argc, char **argv) } else if (strcmp(mode, "read") == 0 ) { if (argc == 4) { - params_readwrite(¶ms.readwrite, argv[0], argv[1], argv[2], argv[3]); - params.readwrite.data_fd = 1; + params_readwrite(0, ¶ms.readwrite, argv[0], argv[1], argv[2], argv[3]); do_read(¶ms.readwrite); } else { @@ -707,8 +773,7 @@ void mode(char* mode, int argc, char **argv) } else if (strcmp(mode, "write") == 0 ) { if (argc == 4) { - params_readwrite(¶ms.readwrite, argv[0], argv[1], argv[2], argv[3]); - params.readwrite.data_fd = 0; + params_readwrite(1, ¶ms.readwrite, argv[0], argv[1], argv[2], argv[3]); do_write(¶ms.readwrite); } else {