diff --git a/curds/data.py b/curds/data.py index 1fd5a16..2ff2708 100644 --- a/curds/data.py +++ b/curds/data.py @@ -1,12 +1,47 @@ # Copyright 2019 (c) Anna Schumaker. import os +import pickle import xdg.BaseDirectory __resource = "emmental" if os.environ.get("EMMENTAL_TESTING"): __resource = "emmental-testing" +READ = 'rb' +WRITE = 'wb' + emmental_data = xdg.BaseDirectory.save_data_path(__resource) def data_file_path(filename): return os.path.join(emmental_data, filename) + +class DataFile: + def __init__(self, path, mode): + self.path = data_file_path(path) + self.mode = mode + self.file = None + + def __enter__(self): + path = self.path + if self.mode == WRITE: + path = data_file_path(f".{os.path.basename(self.path)}.tmp") + if self.mode == WRITE or os.path.exists(path): + self.file = open(path, self.mode) + return self + + def __exit__(self, exp_type, exp_value, traceback): + if self.file: + self.file.close() + self.file = None + if self.mode == WRITE: + path = data_file_path(f".{os.path.basename(self.path)}.tmp") + os.rename(path, self.path) + return True + + def pickle(self, obj): + if self.file: + pickle.dump(obj, self.file, pickle.HIGHEST_PROTOCOL) + + def unpickle(self): + if self.file: + return pickle.load(self.file) diff --git a/curds/test_data.py b/curds/test_data.py index e8e7dfa..473bd80 100644 --- a/curds/test_data.py +++ b/curds/test_data.py @@ -6,8 +6,14 @@ import xdg.BaseDirectory xdg_data_home = xdg.BaseDirectory.xdg_data_home testing_data = os.path.join(xdg_data_home, "emmental-testing") +testing_file = os.path.join(testing_data, "test.file") +testing_temp = os.path.join(testing_data, ".test.file.tmp") class TestDataModule(unittest.TestCase): + def setUp(self): + if os.path.exists(testing_file): + os.remove(testing_file) + def test_dir(self): self.assertEqual(data.emmental_data, testing_data) self.assertTrue(os.path.exists(testing_data)) @@ -17,3 +23,41 @@ class TestDataModule(unittest.TestCase): path = os.path.join(testing_data, "test") self.assertEqual(data.data_file_path("test"), path) self.assertFalse(os.path.exists(path)) + self.assertEqual(data.READ, 'rb') + self.assertEqual(data.WRITE, 'wb') + + def test_data_file_init(self): + f = data.DataFile("test.file", data.READ) + self.assertEqual(f.path, testing_file) + self.assertFalse(os.path.exists(testing_file)) + self.assertEqual(f.mode, data.READ) + self.assertIsNone(f.file) + + f = data.DataFile("test.file", data.WRITE) + self.assertFalse(os.path.exists(testing_file)) + self.assertEqual(f.mode, data.WRITE) + self.assertIsNone(f.file) + + def test_data_file_read_write(self): + test = [ "test", "saving", "a", "list" ] + with data.DataFile("test.file", data.READ) as f: + self.assertIsNone(f.file) + f.pickle(test) + self.assertIsNone(f.unpickle()) + + with data.DataFile("test.file", data.WRITE) as f: + self.assertIsNotNone(f.file) + self.assertEqual(f.file.name, testing_temp) + self.assertFalse(os.path.exists(testing_file)) + self.assertTrue(os.path.exists(testing_temp)) + f.pickle(test) + + self.assertIsNone(f.file) + self.assertFalse(os.path.exists(testing_temp)) + self.assertTrue(os.path.exists(testing_file)) + + with data.DataFile("test.file", data.READ) as f: + self.assertIsNotNone(f.file) + self.assertEqual(f.file.name, testing_file) + lst = f.unpickle() + self.assertEqual(test, lst)