diff --git a/Rakefile b/Rakefile index 8e5cb5a..5f6fec2 100644 --- a/Rakefile +++ b/Rakefile @@ -17,13 +17,20 @@ end rule 'default' => 'flexnbd' namespace "test" do + task 'run' => ["unit", "scenarios"] + task 'build' => TEST_MODULES.map { |n| "tests/check_#{n}" } - task 'run' => 'build' do + + task 'unit' => 'build' do TEST_MODULES.each do |n| ENV['EF_DISABLE_BANNER'] = '1' sh "./tests/check_#{n}" end end + + task 'scenarios' do + sh "cd tests; ruby nbd_scenarios" + end end def gcc_link(target, objects) @@ -46,7 +53,7 @@ rule '.o' => '.c' do |t| end rule 'clean' do - sh "rm -f flexnbd " + ( + sh "rm -f *~ flexnbd " + ( OBJECTS + TEST_MODULES.map { |n| ["tests/check_#{n}", "tests/check_#{n}.o"] }.flatten ). diff --git a/control.c b/control.c index 497a52f..1930e9c 100644 --- a/control.c +++ b/control.c @@ -8,6 +8,7 @@ #include #include #include +#include static const int longest_run = 8<<20; @@ -15,16 +16,20 @@ void* mirror_runner(void* serve_params_uncast) { struct mode_serve_params *serve = (struct mode_serve_params*) serve_params_uncast; + const int max_passes = 7; /* biblical */ int pass; struct bitset_mapping *map = serve->mirror->dirty_map; - for (pass=0; pass < 7 /* biblical */; pass++) { + for (pass=0; pass < max_passes; pass++) { uint64_t current = 0; + uint64_t written = 0; debug("mirror start pass=%d", pass); while (current < serve->size) { - int run = bitset_run_count(map, current, longest_run); + int run; + + run = bitset_run_count(map, current, longest_run); debug("mirror current=%ld, run=%d", current, run); @@ -41,9 +46,13 @@ void* mirror_runner(void* serve_params_uncast) ); bitset_clear_range(map, current, run); + written += run; } current += run; } + + if (written == 0) + pass = max_passes-1; } return NULL; @@ -58,9 +67,11 @@ int control_mirror(struct control_params* client, int linesc, char** lines) struct mirror_status *mirror; union mysockaddr connect_to; char s_ip_address[64], s_port[8]; + uint64_t max_bytes_per_second; + int action_at_finish; - if (linesc != 2) { - write_socket("1: mirror only takes two parameters"); + if (linesc < 2) { + write_socket("1: mirror takes at least two parameters"); return -1; } @@ -76,6 +87,30 @@ int control_mirror(struct control_params* client, int linesc, char** lines) } connect_to.v4.sin_port = htobe16(connect_to.v4.sin_port); + max_bytes_per_second = 0; + if (linesc > 2) { + max_bytes_per_second = atoi(lines[2]); + } + + action_at_finish = ACTION_PROXY; + if (linesc > 3) { + if (strcmp("proxy", lines[3]) == 0) + action_at_finish = ACTION_PROXY; + else if (strcmp("exit", lines[3]) == 0) + action_at_finish = ACTION_EXIT; + else if (strcmp("nothing", lines[3]) == 0) + action_at_finish = ACTION_NOTHING; + else { + write_socket("1: action must be one of 'proxy', 'exit' or 'nothing'"); + return -1; + } + } + + if (linesc > 4) { + write_socket("1: unrecognised parameters to mirror"); + return -1; + } + fd = socket_connect(&connect_to.generic); remote_size = socket_nbd_read_hello(fd); @@ -83,7 +118,8 @@ int control_mirror(struct control_params* client, int linesc, char** lines) mirror = xmalloc(sizeof(struct mirror_status)); mirror->client = fd; - mirror->max_bytes_per_second = 0; + mirror->max_bytes_per_second = max_bytes_per_second; + mirror->action_at_finish = action_at_finish; CLIENT_ERROR_ON_FAILURE( open_and_mmap( @@ -122,7 +158,7 @@ int control_acl(struct control_params* client, int linesc, char** lines) parsed = parse_acl(&acl, linesc, lines); if (parsed != linesc) { - write(client->socket, "3: bad spec ", 12); + write(client->socket, "1: bad spec ", 12); write(client->socket, s_acl_entry[parsed], strlen(s_acl_entry[parsed])); write(client->socket, "\n", 1); diff --git a/params.h b/params.h index 5c1583b..6a94ea1 100644 --- a/params.h +++ b/params.h @@ -8,11 +8,18 @@ #include +enum mirror_finish_action { + ACTION_PROXY, + ACTION_EXIT, + ACTION_NOTHING +}; + struct mirror_status { pthread_t thread; int client; char *filename; off64_t max_bytes_per_second; + enum mirror_finish_action action_at_finish; char *mapped; struct bitset_mapping *dirty_map; @@ -24,13 +31,29 @@ struct control_params { }; struct mode_serve_params { + /* address/port to bind to */ union mysockaddr bind_to; + /* number of entries in current access control list*/ int acl_entries; + /* pointer to access control list entries*/ struct ip_and_mask (*acl)[0]; + /* file name to serve */ char* filename; + /* TCP backlog for listen() */ int tcp_backlog; + /* file name of UNIX control socket (or NULL if none) */ char* control_socket_name; + /* size of file */ off64_t size; + /* if you want the main thread to pause, set this to an writeable + * file descriptor. The main thread will then write a byte once it + * promises to hang any further writes. + */ + int pause_fd; + /* the main thread will set this when writes will be paused */ + int paused; + /* set to non-zero to use given destination connection as proxy */ + int proxy_fd; struct mirror_status* mirror; int server; @@ -49,13 +72,10 @@ struct mode_readwrite_params { struct client_params { int socket; - char* filename; int fileno; - off64_t size; char* mapped; - char* block_allocation_map; struct mode_serve_params* serve; /* FIXME: remove above duplication */ }; diff --git a/readwrite.c b/readwrite.c index 3f571d3..d469c76 100644 --- a/readwrite.c +++ b/readwrite.c @@ -97,7 +97,7 @@ void socket_nbd_write(int fd, off64_t from, int len, int in_fd, void* in_buf) #define CHECK_RANGE(error_type) { \ off64_t size = socket_nbd_read_hello(params->client); \ - if (params->from < 0 || (params->from + params->len) >= size) \ + if (params->from < 0 || (params->from + params->len) > size) \ SERVER_ERROR(error_type \ " request %d+%d is out of range given size %d", \ params->from, params->len, size\ diff --git a/serve.c b/serve.c index df4c879..a19c3b7 100644 --- a/serve.c +++ b/serve.c @@ -37,7 +37,7 @@ static inline void dirty(struct mode_serve_params *serve, off64_t from, int len) */ void write_not_zeroes(struct client_params* client, off64_t from, int len) { - char *map = client->block_allocation_map; + char *map = client->serve->block_allocation_map; while (len > 0) { /* so we have to calculate how much of our input to consider @@ -152,7 +152,7 @@ int client_serve_request(struct client_params* client) case REQUEST_WRITE: /* check it's not out of range */ if (be64toh(request.from) < 0 || - be64toh(request.from)+be32toh(request.len) > client->size) { + be64toh(request.from)+be32toh(request.len) > client->serve->size) { debug("request read %ld+%d out of range", be64toh(request.from), be32toh(request.len) @@ -193,7 +193,7 @@ int client_serve_request(struct client_params* client) case REQUEST_WRITE: debug("request write %ld+%d", be64toh(request.from), be32toh(request.len)); - if (client->block_allocation_map) { + if (client->serve->block_allocation_map) { write_not_zeroes( client, be64toh(request.from), @@ -226,7 +226,7 @@ void client_send_hello(struct client_params* client) memcpy(init.passwd, INIT_PASSWD, sizeof(INIT_PASSWD)); init.magic = htobe64(INIT_MAGIC); - init.size = htobe64(client->size); + init.size = htobe64(client->serve->size); memset(init.reserved, 0, 128); CLIENT_ERROR_ON_FAILURE( writeloop(client->socket, &init, sizeof(init)), @@ -241,12 +241,12 @@ void* client_serve(void* client_uncast) //client_open_file(client); CLIENT_ERROR_ON_FAILURE( open_and_mmap( - client->filename, + client->serve->filename, &client->fileno, - &client->size, + NULL, (void**) &client->mapped ), - "Couldn't open/mmap file %s", client->filename + "Couldn't open/mmap file %s", client->serve->filename ); client_send_hello(client); @@ -262,7 +262,7 @@ void* client_serve(void* client_uncast) close(client->socket); close(client->fileno); - munmap(client->mapped, client->size); + munmap(client->mapped, client->serve->size); free(client); return NULL; @@ -357,9 +357,6 @@ void accept_nbd_client(struct mode_serve_params* params, int client_fd, struct s client_params = xmalloc(sizeof(struct client_params)); client_params->socket = client_fd; - client_params->filename = params->filename; - client_params->block_allocation_map = - params->block_allocation_map; client_params->serve = params; SERVER_ERROR_ON_FAILURE( diff --git a/tests/flexnbd.rb b/tests/flexnbd.rb new file mode 100644 index 0000000..58a014d --- /dev/null +++ b/tests/flexnbd.rb @@ -0,0 +1,63 @@ +require 'socket' + +# Noddy test class to exercise FlexNBD from the outside for testing. +# +class FlexNBD + attr_reader :bin, :ctrl, :pid, :ip, :port + + def initialize(bin, ip, port) + @bin = bin + raise "#{bin} not executable" unless File.executable?(bin) + @ctrl = "/tmp/.flexnbd.ctrl.#{Time.now.to_i}.#{rand}" + @ip = ip + @port = port + end + + def serve(ip, port, file, *acl) + @pid = fork do + exec("#{@bin} serve #{ip} #{port} #{file} #{ctrl} #{acl.join(' ')}") + end + end + + def kill + Process.kill("INT", @pid) + Process.wait(@pid) + end + + def read(offset, length) + IO.popen("#{@bin} read #{ip} #{port} #{offset} #{length}","r") do |fh| + return fh.read + end + raise "read failed" unless $?.success? + end + + def write(offset, data) + IO.popen("#{@bin} write #{ip} #{port} #{offset} #{data.length}","w") do |fh| + fh.write(data) + end + raise "write failed" unless $?.success? + nil + end + + def mirror(bandwidth=nil, action=nil) + control_command("mirror", ip, port, bandwidth, action) + end + + def acl(*acl) + control_command("acl", *acl) + end + + def status + end + + protected + def control_command(*args) + raise "Server not running" unless @pid + args = args.compact + UNIXSocket.open(@ctrl) do |u| + u.write(args.join("\n") + "\n") + code, message = u.readline.split(": ", 2) + return [code, message] + end + end +end diff --git a/tests/nbd_scenarios b/tests/nbd_scenarios new file mode 100644 index 0000000..a45fc70 --- /dev/null +++ b/tests/nbd_scenarios @@ -0,0 +1,71 @@ +#!/usr/bin/ruby + +require 'test/unit' +require 'flexnbd' +require 'test_file_writer' + +class NBDScenarios < Test::Unit::TestCase + def setup + @blocksize = 1024 + @filename1 = ".flexnbd.test.#{$$}.#{Time.now.to_i}.1" + @filename2 = ".flexnbd.test.#{$$}.#{Time.now.to_i}.2" + @ip = "127.0.0.1" + @available_ports = [*40000..41000] - listening_ports + @port1 = @available_ports.shift + @port2 = @available_ports.shift + @nbd1 = FlexNBD.new("../flexnbd", @ip, @port1) + end + + def teardown + @nbd1.kill rescue nil + [@filename1, @filename2].each do |f| + File.unlink(f) if File.exists?(f) + end + end + + def test_read1 + writefile1("f"*64) + serve1 + + [0, 12, 63].each do |num| + + assert_equal( + @nbd1.read(num*@blocksize, @blocksize), + @file1.read(num*@blocksize, @blocksize) + ) + end + + [124, 1200, 10028, 25488].each do |num| + assert_equal(@nbd1.read(num, 4), @file1.read(num, 4)) + end + end + + def test_writeread1 + writefile1("0"*64) + serve1 + + [0, 12, 63].each do |num| + data = "X"*@blocksize + @nbd1.write(num*@blocksize, data) + assert_equal(data, @file1.read(num*@blocksize, data.size)) + assert_equal(data, @nbd1.read(num*@blocksize, data.size)) + end + end + + protected + def serve1(*acl) + @nbd1.serve(@ip, @port1, @filename1, *acl) + end + + def writefile1(data) + @file1 = TestFileWriter.new(@filename1, @blocksize).write(data) + end + + def listening_ports + `netstat -ltn`. + split("\n"). + map { |x| x.split(/\s+/) }[2..-1]. + map { |l| l[3].split(":")[-1].to_i } + end +end + diff --git a/tests/test_file_writer.rb b/tests/test_file_writer.rb new file mode 100644 index 0000000..0ff6f5d --- /dev/null +++ b/tests/test_file_writer.rb @@ -0,0 +1,83 @@ +# Noddy test class for writing files to disc in predictable patterns +# in order to test FlexNBD. +# +class TestFileWriter + def initialize(filename, blocksize) + @fh = File.open(filename, "w+") + @blocksize = blocksize + @pattern = "" + end + + # We write in fixed block sizes, given by "blocksize" + # _ means skip a block + # 0 means write a block full of zeroes + # f means write a block with the file offset packed every 4 bytes + # + def write(data) + @pattern += data + + data.split("").each do |code| + if code == "_" + @fh.seek(@blocksize, IO::SEEK_CUR) + else + @fh.write(data(code)) + end + end + @fh.flush + self + end + + # Returns what the data ought to be at the given offset and length + # + def read_original(off, len) + r="" + current = 0 + @pattern.split("").each do |block| + if off >= current && (off+len) < current + blocksize + current += data(block, current)[ + current-off..(current+blocksize)-(off+len) + ] + end + current += @blocksize + end + r + end + + # Read what's actually in the file + # + def read(off, len) + @fh.seek(off, IO::SEEK_SET) + @fh.read(len) + end + + def untouched?(offset, len) + read(off, len) == read_original(off, len) + end + + def close + @fh.close + nil + end + + protected + + def data(code, at=@fh.tell) + case code + when "0", "_" + "\0" * @blocksize + when "X" + "X" * @blocksize + when "f" + r = "" + (@blocksize/4).times do + r += [at].pack("I") + at += 4 + end + r + else + raise "Unknown character '#{block}'" + end + end + +end +