diff --git a/curds/data.py b/curds/data.py index 2ff2708..005473d 100644 --- a/curds/data.py +++ b/curds/data.py @@ -12,30 +12,26 @@ WRITE = 'wb' emmental_data = xdg.BaseDirectory.save_data_path(__resource) -def data_file_path(filename): - return os.path.join(emmental_data, filename) - class DataFile: def __init__(self, path, mode): - self.path = data_file_path(path) + self.path = os.path.join(emmental_data, path) + self.temp = os.path.join(emmental_data, f".{path}.tmp") self.mode = mode self.file = None def __enter__(self): - path = self.path if self.mode == WRITE: - path = data_file_path(f".{os.path.basename(self.path)}.tmp") - if self.mode == WRITE or os.path.exists(path): - self.file = open(path, self.mode) + self.file = open(self.temp, self.mode) + elif self.mode == READ and os.path.exists(self.path): + self.file = open(self.path, self.mode) return self def __exit__(self, exp_type, exp_value, traceback): if self.file: self.file.close() self.file = None - if self.mode == WRITE: - path = data_file_path(f".{os.path.basename(self.path)}.tmp") - os.rename(path, self.path) + if self.mode == WRITE: + os.rename(self.temp, self.path) return True def pickle(self, obj): diff --git a/curds/test_data.py b/curds/test_data.py index 473bd80..63a436d 100644 --- a/curds/test_data.py +++ b/curds/test_data.py @@ -18,22 +18,19 @@ class TestDataModule(unittest.TestCase): self.assertEqual(data.emmental_data, testing_data) self.assertTrue(os.path.exists(testing_data)) self.assertTrue(os.path.isdir(testing_data)) - - def test_data_file_path(self): - path = os.path.join(testing_data, "test") - self.assertEqual(data.data_file_path("test"), path) - self.assertFalse(os.path.exists(path)) self.assertEqual(data.READ, 'rb') self.assertEqual(data.WRITE, 'wb') def test_data_file_init(self): f = data.DataFile("test.file", data.READ) self.assertEqual(f.path, testing_file) + self.assertEqual(f.temp, testing_temp) self.assertFalse(os.path.exists(testing_file)) self.assertEqual(f.mode, data.READ) self.assertIsNone(f.file) f = data.DataFile("test.file", data.WRITE) + self.assertEqual(f.temp, testing_temp) self.assertFalse(os.path.exists(testing_file)) self.assertEqual(f.mode, data.WRITE) self.assertIsNone(f.file)