xref: /openbmc/qemu/tests/qemu-iotests/fat16.py (revision 7d87775f)
1# A simple FAT16 driver that is used to test the `vvfat` driver in QEMU.
2#
3# Copyright (C) 2024 Amjad Alsharafi <amjadsharafi10@gmail.com>
4#
5# This program is free software; you can redistribute it and/or modify
6# it under the terms of the GNU General Public License as published by
7# the Free Software Foundation; either version 2 of the License, or
8# (at your option) any later version.
9#
10# This program is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13# GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License
16# along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18from typing import Callable, List, Optional, Protocol, Set
19import string
20
21SECTOR_SIZE = 512
22DIRENTRY_SIZE = 32
23ALLOWED_FILE_CHARS = set(
24    "!#$%&'()-@^_`{}~" + string.digits + string.ascii_uppercase
25)
26
27
class MBR:
    """
    Parsed view of a 512-byte Master Boot Record sector.

    `partition_table` is a list with one dict per primary partition slot
    (four in total), kept in on-disk order.
    """

    def __init__(self, data: bytes):
        assert len(data) == 512
        table = []
        for slot in range(4):
            # The table starts at byte 446; every slot is 16 bytes long.
            base = 446 + slot * 16
            raw = data[base : base + 16]

            # In each CHS triple the low six bits of the second byte are
            # the sector, and its two top bits extend the cylinder byte.
            first_cylinder = ((raw[2] & 0xC0) << 2) | raw[3]
            last_cylinder = ((raw[6] & 0xC0) << 2) | raw[7]

            table.append(
                {
                    "status": raw[0],
                    "start_head": raw[1],
                    "start_sector": raw[2] & 0x3F,
                    "start_cylinder": first_cylinder,
                    "type": raw[4],
                    "end_head": raw[5],
                    "end_sector": raw[6] & 0x3F,
                    "end_cylinder": last_cylinder,
                    "start_lba": int.from_bytes(raw[8:12], "little"),
                    "size": int.from_bytes(raw[12:16], "little"),
                }
            )
        self.partition_table = table

    def __str__(self):
        lines = []
        for index, slot in enumerate(self.partition_table):
            lines.append(f"{index}: {slot}")
        return "\n".join(lines)
58
59
class FatBootSector:
    """
    Fields decoded from a 512-byte FAT boot sector, plus helpers that
    derive the layout (root directory, data region, cluster size) from them.
    """

    # pylint: disable=too-many-instance-attributes
    def __init__(self, data: bytes):
        assert len(data) == 512

        def le(begin: int, end: int) -> int:
            # Little-endian unsigned integer stored in data[begin:end].
            return int.from_bytes(data[begin:end], "little")

        def text(begin: int, end: int) -> str:
            # Space-padded ASCII field.
            return data[begin:end].decode("ascii").strip()

        self.bytes_per_sector = le(11, 13)
        self.sectors_per_cluster = data[13]
        self.reserved_sectors = le(14, 16)
        self.fat_count = data[16]
        self.root_entries = le(17, 19)
        small_total = le(19, 21)
        self.media_descriptor = data[21]
        self.sectors_per_fat = le(22, 24)
        self.sectors_per_track = le(24, 26)
        self.heads = le(26, 28)
        self.hidden_sectors = le(28, 32)
        large_total = le(32, 36)
        # Only one of the two "total sectors" fields may be in use.
        assert (
            small_total == 0 or large_total == 0
        ), "Both total sectors (16 and 32) fields are non-zero"
        self.total_sectors = small_total if small_total else large_total
        self.drive_number = data[36]
        self.volume_id = le(39, 43)
        self.volume_label = text(43, 54)
        self.fs_type = text(54, 62)

    def root_dir_start(self):
        """
        Return the sector at which the root directory begins
        (right after the reserved sectors and every FAT copy).
        """
        fat_sectors = self.fat_count * self.sectors_per_fat
        return self.reserved_sectors + fat_sectors

    def root_dir_size(self):
        """
        Return the number of sectors occupied by the root directory,
        rounded up to whole sectors.
        """
        root_bytes = self.root_entries * DIRENTRY_SIZE
        # Ceiling division without floats.
        return -(-root_bytes // self.bytes_per_sector)

    def data_sector_start(self):
        """
        Return the first sector of the data region, which follows the
        root directory.
        """
        return self.root_dir_start() + self.root_dir_size()

    def first_sector_of_cluster(self, cluster: int) -> int:
        """
        Return the first sector of `cluster`; cluster 2 is the first
        cluster of the data region.
        """
        sectors_into_data = (cluster - 2) * self.sectors_per_cluster
        return self.data_sector_start() + sectors_into_data

    def cluster_bytes(self):
        """
        Return the size of one cluster in bytes.
        """
        return self.bytes_per_sector * self.sectors_per_cluster

    def __str__(self):
        rows = (
            ("Bytes per sector", self.bytes_per_sector),
            ("Sectors per cluster", self.sectors_per_cluster),
            ("Reserved sectors", self.reserved_sectors),
            ("FAT count", self.fat_count),
            ("Root entries", self.root_entries),
            ("Total sectors", self.total_sectors),
            ("Media descriptor", self.media_descriptor),
            ("Sectors per FAT", self.sectors_per_fat),
            ("Sectors per track", self.sectors_per_track),
            ("Heads", self.heads),
            ("Hidden sectors", self.hidden_sectors),
            ("Drive number", self.drive_number),
            ("Volume ID", self.volume_id),
            ("Volume label", self.volume_label),
            ("FS type", self.fs_type),
        )
        return "".join(f"{label}: {value}\n" for label, value in rows)
137
138
class FatDirectoryEntry:
    """
    One 32-byte short-name directory entry, together with the sector and
    in-sector offset it was read from so it can be written back in place.
    """

    # pylint: disable=too-many-instance-attributes
    def __init__(self, data: bytes, sector: int, offset: int):
        def word(at: int) -> int:
            # 16-bit little-endian field starting at byte `at`.
            return int.from_bytes(data[at : at + 2], "little")

        self.name = data[0:8].decode("ascii").strip()
        self.ext = data[8:11].decode("ascii").strip()
        self.attributes = data[11]
        self.reserved = data[12]
        self.create_time_tenth = data[13]
        self.create_time = word(14)
        self.create_date = word(16)
        self.last_access_date = word(18)
        cluster_high = word(20)
        self.last_mod_time = word(22)
        self.last_mod_date = word(24)
        cluster_low = word(26)
        # The start cluster is split into two 16-bit halves on disk.
        self.cluster = (cluster_high << 16) | cluster_low
        self.size_bytes = int.from_bytes(data[28:32], "little")

        # Not part of the on-disk entry: where the entry lives.
        self.sector = sector
        self.offset = offset

    def as_bytes(self) -> bytes:
        """
        Serialise the entry back into its on-disk layout; name and
        extension are space-padded to 8 and 3 characters.
        """
        fields = [
            self.name.ljust(8, " ").encode("ascii"),
            self.ext.ljust(3, " ").encode("ascii"),
            self.attributes.to_bytes(1, "little"),
            self.reserved.to_bytes(1, "little"),
            self.create_time_tenth.to_bytes(1, "little"),
            self.create_time.to_bytes(2, "little"),
            self.create_date.to_bytes(2, "little"),
            self.last_access_date.to_bytes(2, "little"),
            (self.cluster >> 16).to_bytes(2, "little"),
            self.last_mod_time.to_bytes(2, "little"),
            self.last_mod_date.to_bytes(2, "little"),
            (self.cluster & 0xFFFF).to_bytes(2, "little"),
            self.size_bytes.to_bytes(4, "little"),
        ]
        return b"".join(fields)

    def whole_name(self):
        """
        Return "NAME.EXT", or just "NAME" when the extension is empty.
        """
        return f"{self.name}.{self.ext}" if self.ext else self.name

    def __str__(self):
        rows = (
            ("Name", self.name),
            ("Ext", self.ext),
            ("Attributes", self.attributes),
            ("Reserved", self.reserved),
            ("Create time tenth", self.create_time_tenth),
            ("Create time", self.create_time),
            ("Create date", self.create_date),
            ("Last access date", self.last_access_date),
            ("Last mod time", self.last_mod_time),
            ("Last mod date", self.last_mod_date),
            ("Cluster", self.cluster),
            ("Size", self.size_bytes),
        )
        return "".join(f"{label}: {value}\n" for label, value in rows)

    def __repr__(self):
        # Dump every attribute as a dict.
        return f"{self.__dict__}"
203
204
class SectorReader(Protocol):
    """
    Structural type of the sector-reading callback: it is called with a
    start sector and a sector count (one by default) and returns bytes.
    """

    def __call__(self, start_sector: int, num_sectors: int = 1) -> bytes:
        """Return the content of the requested sectors."""
        ...
207
208# pylint: disable=broad-exception-raised
209class Fat16:
210    def __init__(
211        self,
212        start_sector: int,
213        size: int,
214        sector_reader: SectorReader,
215        sector_writer: Callable[[int, bytes], None]
216    ):
217        self.start_sector = start_sector
218        self.size_in_sectors = size
219        self.sector_reader = sector_reader
220        self.sector_writer = sector_writer
221
222        self.boot_sector = FatBootSector(self.sector_reader(start_sector, 1))
223
224        fat_size_in_sectors = (
225            self.boot_sector.sectors_per_fat * self.boot_sector.fat_count
226        )
227        self.fats = self.read_sectors(
228            self.boot_sector.reserved_sectors, fat_size_in_sectors
229        )
230        self.fats_dirty_sectors: Set[int] = set()
231
232    def read_sectors(self, start_sector: int, num_sectors: int) -> bytes:
233        return self.sector_reader(start_sector + self.start_sector,
234                                  num_sectors)
235
236    def write_sectors(self, start_sector: int, data: bytes) -> None:
237        return self.sector_writer(start_sector + self.start_sector, data)
238
239    def directory_from_bytes(
240        self, data: bytes, start_sector: int
241    ) -> List[FatDirectoryEntry]:
242        """
243        Convert `bytes` into a list of `FatDirectoryEntry` objects.
244        Will ignore long file names.
245        Will stop when it encounters a 0x00 byte.
246        """
247
248        entries = []
249        for i in range(0, len(data), DIRENTRY_SIZE):
250            entry = data[i : i + DIRENTRY_SIZE]
251
252            current_sector = start_sector + (i // SECTOR_SIZE)
253            current_offset = i % SECTOR_SIZE
254
255            if entry[0] == 0:
256                break
257
258            if entry[0] == 0xE5:
259                # Deleted file
260                continue
261
262            if entry[11] & 0xF == 0xF:
263                # Long file name
264                continue
265
266            entries.append(
267                FatDirectoryEntry(entry, current_sector, current_offset)
268            )
269        return entries
270
271    def read_root_directory(self) -> List[FatDirectoryEntry]:
272        root_dir = self.read_sectors(
273            self.boot_sector.root_dir_start(), self.boot_sector.root_dir_size()
274        )
275        return self.directory_from_bytes(
276            root_dir, self.boot_sector.root_dir_start()
277        )
278
279    def read_fat_entry(self, cluster: int) -> int:
280        """
281        Read the FAT entry for the given cluster.
282        """
283        fat_offset = cluster * 2  # FAT16
284        return int.from_bytes(self.fats[fat_offset : fat_offset + 2], "little")
285
286    def write_fat_entry(self, cluster: int, value: int) -> None:
287        """
288        Write the FAT entry for the given cluster.
289        """
290        fat_offset = cluster * 2
291        self.fats = (
292            self.fats[:fat_offset]
293            + value.to_bytes(2, "little")
294            + self.fats[fat_offset + 2 :]
295        )
296        self.fats_dirty_sectors.add(fat_offset // SECTOR_SIZE)
297
298    def flush_fats(self) -> None:
299        """
300        Write the FATs back to the disk.
301        """
302        for sector in self.fats_dirty_sectors:
303            data = self.fats[sector * SECTOR_SIZE : (sector + 1) * SECTOR_SIZE]
304            sector = self.boot_sector.reserved_sectors + sector
305            self.write_sectors(sector, data)
306        self.fats_dirty_sectors = set()
307
308    def next_cluster(self, cluster: int) -> Optional[int]:
309        """
310        Get the next cluster in the chain.
311        If its `None`, then its the last cluster.
312        The function will crash if the next cluster
313        is `FREE` (unexpected) or invalid entry.
314        """
315        fat_entry = self.read_fat_entry(cluster)
316        if fat_entry == 0:
317            raise Exception("Unexpected: FREE cluster")
318        if fat_entry == 1:
319            raise Exception("Unexpected: RESERVED cluster")
320        if fat_entry >= 0xFFF8:
321            return None
322        if fat_entry >= 0xFFF7:
323            raise Exception("Invalid FAT entry")
324
325        return fat_entry
326
327    def next_free_cluster(self) -> int:
328        """
329        Find the next free cluster.
330        """
331        # simple linear search
332        for i in range(2, 0xFFFF):
333            if self.read_fat_entry(i) == 0:
334                return i
335        raise Exception("No free clusters")
336
337    def next_free_cluster_non_continuous(self) -> int:
338        """
339        Find the next free cluster, but makes sure
340        that the cluster before and after it are not allocated.
341        """
342        # simple linear search
343        before = False
344        for i in range(2, 0xFFFF):
345            if self.read_fat_entry(i) == 0:
346                if before and self.read_fat_entry(i + 1) == 0:
347                    return i
348                else:
349                    before = True
350            else:
351                before = False
352
353        raise Exception("No free clusters")
354
355    def read_cluster(self, cluster: int) -> bytes:
356        """
357        Read the cluster at the given cluster.
358        """
359        return self.read_sectors(
360            self.boot_sector.first_sector_of_cluster(cluster),
361            self.boot_sector.sectors_per_cluster,
362        )
363
364    def write_cluster(self, cluster: int, data: bytes) -> None:
365        """
366        Write the cluster at the given cluster.
367        """
368        assert len(data) == self.boot_sector.cluster_bytes()
369        self.write_sectors(
370            self.boot_sector.first_sector_of_cluster(cluster),
371            data,
372        )
373
374    def read_directory(
375        self, cluster: Optional[int]
376    ) -> List[FatDirectoryEntry]:
377        """
378        Read the directory at the given cluster.
379        """
380        entries = []
381        while cluster is not None:
382            data = self.read_cluster(cluster)
383            entries.extend(
384                self.directory_from_bytes(
385                    data, self.boot_sector.first_sector_of_cluster(cluster)
386                )
387            )
388            cluster = self.next_cluster(cluster)
389        return entries
390
391    def add_direntry(
392        self, cluster: Optional[int], name: str, ext: str, attributes: int
393    ) -> FatDirectoryEntry:
394        """
395        Add a new directory entry to the given cluster.
396        If the cluster is `None`, then it will be added to the root directory.
397        """
398
399        def find_free_entry(data: bytes) -> Optional[int]:
400            for i in range(0, len(data), DIRENTRY_SIZE):
401                entry = data[i : i + DIRENTRY_SIZE]
402                if entry[0] == 0 or entry[0] == 0xE5:
403                    return i
404            return None
405
406        assert len(name) <= 8, "Name must be 8 characters or less"
407        assert len(ext) <= 3, "Ext must be 3 characters or less"
408        assert attributes % 0x15 != 0x15, "Invalid attributes"
409
410        # initial dummy data
411        new_entry = FatDirectoryEntry(b"\0" * 32, 0, 0)
412        new_entry.name = name.ljust(8, " ")
413        new_entry.ext = ext.ljust(3, " ")
414        new_entry.attributes = attributes
415        new_entry.reserved = 0
416        new_entry.create_time_tenth = 0
417        new_entry.create_time = 0
418        new_entry.create_date = 0
419        new_entry.last_access_date = 0
420        new_entry.last_mod_time = 0
421        new_entry.last_mod_date = 0
422        new_entry.cluster = self.next_free_cluster()
423        new_entry.size_bytes = 0
424
425        # mark as EOF
426        self.write_fat_entry(new_entry.cluster, 0xFFFF)
427
428        if cluster is None:
429            for i in range(self.boot_sector.root_dir_size()):
430                sector_data = self.read_sectors(
431                    self.boot_sector.root_dir_start() + i, 1
432                )
433                offset = find_free_entry(sector_data)
434                if offset is not None:
435                    new_entry.sector = self.boot_sector.root_dir_start() + i
436                    new_entry.offset = offset
437                    self.update_direntry(new_entry)
438                    return new_entry
439        else:
440            while cluster is not None:
441                data = self.read_cluster(cluster)
442                offset = find_free_entry(data)
443                if offset is not None:
444                    new_entry.sector = (
445                        self.boot_sector.first_sector_of_cluster(cluster)
446                         + (offset // SECTOR_SIZE))
447                    new_entry.offset = offset % SECTOR_SIZE
448                    self.update_direntry(new_entry)
449                    return new_entry
450                cluster = self.next_cluster(cluster)
451
452        raise Exception("No free directory entries")
453
454    def update_direntry(self, entry: FatDirectoryEntry) -> None:
455        """
456        Write the directory entry back to the disk.
457        """
458        sector = self.read_sectors(entry.sector, 1)
459        sector = (
460            sector[: entry.offset]
461            + entry.as_bytes()
462            + sector[entry.offset + DIRENTRY_SIZE :]
463        )
464        self.write_sectors(entry.sector, sector)
465
466    def find_direntry(self, path: str) -> Optional[FatDirectoryEntry]:
467        """
468        Find the directory entry for the given path.
469        """
470        assert path[0] == "/", "Path must start with /"
471
472        path = path[1:]  # remove the leading /
473        parts = path.split("/")
474        directory = self.read_root_directory()
475
476        current_entry = None
477
478        for i, part in enumerate(parts):
479            is_last = i == len(parts) - 1
480
481            for entry in directory:
482                if entry.whole_name() == part:
483                    current_entry = entry
484                    break
485            if current_entry is None:
486                return None
487
488            if is_last:
489                return current_entry
490
491            if current_entry.attributes & 0x10 == 0:
492                raise Exception(
493                    f"{current_entry.whole_name()} is not a directory"
494                )
495
496            directory = self.read_directory(current_entry.cluster)
497
498        assert False, "Exited loop with is_last == False"
499
500    def read_file(self, entry: Optional[FatDirectoryEntry]) -> Optional[bytes]:
501        """
502        Read the content of the file at the given path.
503        """
504        if entry is None:
505            return None
506        if entry.attributes & 0x10 != 0:
507            raise Exception(f"{entry.whole_name()} is a directory")
508
509        data = b""
510        cluster: Optional[int] = entry.cluster
511        while cluster is not None and len(data) <= entry.size_bytes:
512            data += self.read_cluster(cluster)
513            cluster = self.next_cluster(cluster)
514        return data[: entry.size_bytes]
515
    def truncate_file(
        self,
        entry: FatDirectoryEntry,
        new_size: int,
        allocate_non_continuous: bool = False,
    ) -> None:
        """
        Resize the file described by `entry` to `new_size` bytes.

        The cluster chain is extended or shortened in the in-memory FAT,
        the FAT is flushed, the resulting chain length is verified, the
        size in the directory entry is updated on disk, and finally the
        first sector of every cluster of the file is re-written with its
        current content. Newly allocated clusters are not cleared.

        A file always keeps at least one cluster, even at size 0.

        `allocate_non_continuous` selects
        `next_free_cluster_non_continuous` instead of `next_free_cluster`
        for new clusters.

        Raises an Exception when `entry` is `None` or a directory, and
        AssertionError when the existing chain does not match the size
        recorded in `entry` or the final chain has the wrong length.
        """
        if entry is None:
            raise Exception("entry is None")
        # 0x10 is the directory attribute bit.
        if entry.attributes & 0x10 != 0:
            raise Exception(f"{entry.whole_name()} is a directory")

        def clusters_from_size(size: int) -> int:
            # Number of clusters needed for `size` bytes, rounded up.
            return (
                size + self.boot_sector.cluster_bytes() - 1
            ) // self.boot_sector.cluster_bytes()

        # First, allocate new FATs if we need to
        required_clusters = clusters_from_size(new_size)
        # The current chain length is derived from the recorded size,
        # not counted from the FAT.
        current_clusters = clusters_from_size(entry.size_bytes)

        # Clusters whose first sector is re-written at the end.
        affected_clusters = set()

        # Keep at least one cluster, easier to manage this way
        if required_clusters == 0:
            required_clusters = 1
        if current_clusters == 0:
            current_clusters = 1

        cluster: Optional[int]

        if required_clusters > current_clusters:
            # Allocate new clusters
            # Walk to the last cluster of the existing chain; afterwards
            # `to_add` is the number of clusters to allocate plus one.
            cluster = entry.cluster
            to_add = required_clusters
            for _ in range(current_clusters - 1):
                to_add -= 1
                assert cluster is not None, "Cluster is None"
                affected_clusters.add(cluster)
                cluster = self.next_cluster(cluster)
            assert required_clusters > 0, "No new clusters to allocate"
            assert cluster is not None, "Cluster is None"
            # The chain must really end where the recorded size says.
            assert (
                self.next_cluster(cluster) is None
            ), "Cluster is not the last cluster"

            # Allocate new clusters
            # Each new cluster is linked from the current tail and then
            # marked as the new end of the chain.
            for _ in range(to_add - 1):
                if allocate_non_continuous:
                    new_cluster = self.next_free_cluster_non_continuous()
                else:
                    new_cluster = self.next_free_cluster()
                self.write_fat_entry(cluster, new_cluster)
                self.write_fat_entry(new_cluster, 0xFFFF)
                cluster = new_cluster

        elif required_clusters < current_clusters:
            # Truncate the file
            # Walk to the cluster that becomes the new end of the chain.
            cluster = entry.cluster
            for _ in range(required_clusters - 1):
                assert cluster is not None, "Cluster is None"
                cluster = self.next_cluster(cluster)
            assert cluster is not None, "Cluster is None"

            # Remember the successor before the link is overwritten.
            next_cluster = self.next_cluster(cluster)
            # mark last as EOF
            self.write_fat_entry(cluster, 0xFFFF)
            # free the rest
            while next_cluster is not None:
                cluster = next_cluster
                next_cluster = self.next_cluster(next_cluster)
                self.write_fat_entry(cluster, 0)

        # Push the modified FAT sectors to the disk before the directory
        # entry is updated.
        self.flush_fats()

        # verify number of clusters
        # This walk also collects every cluster of the final chain.
        cluster = entry.cluster
        count = 0
        while cluster is not None:
            count += 1
            affected_clusters.add(cluster)
            cluster = self.next_cluster(cluster)
        assert (
            count == required_clusters
        ), f"Expected {required_clusters} clusters, got {count}"

        # update the size
        entry.size_bytes = new_size
        self.update_direntry(entry)

        # trigger every affected cluster
        # The first sector of each cluster is written back unchanged.
        # NOTE(review): presumably this makes the block driver under
        # test see a write to every cluster of the file - confirm.
        for cluster in affected_clusters:
            first_sector = self.boot_sector.first_sector_of_cluster(cluster)
            first_sector_data = self.read_sectors(first_sector, 1)
            self.write_sectors(first_sector, first_sector_data)
613
614    def write_file(self, entry: FatDirectoryEntry, data: bytes) -> None:
615        """
616        Write the content of the file at the given path.
617        """
618        if entry is None:
619            raise Exception("entry is None")
620        if entry.attributes & 0x10 != 0:
621            raise Exception(f"{entry.whole_name()} is a directory")
622
623        data_len = len(data)
624
625        self.truncate_file(entry, data_len)
626
627        cluster: Optional[int] = entry.cluster
628        while cluster is not None:
629            data_to_write = data[: self.boot_sector.cluster_bytes()]
630            if len(data_to_write) < self.boot_sector.cluster_bytes():
631                old_data = self.read_cluster(cluster)
632                data_to_write += old_data[len(data_to_write) :]
633
634            self.write_cluster(cluster, data_to_write)
635            data = data[self.boot_sector.cluster_bytes() :]
636            if len(data) == 0:
637                break
638            cluster = self.next_cluster(cluster)
639
640        assert (
641            len(data) == 0
642        ), "Data was not written completely, clusters missing"
643
644    def create_file(self, path: str) -> Optional[FatDirectoryEntry]:
645        """
646        Create a new file at the given path.
647        """
648        assert path[0] == "/", "Path must start with /"
649
650        path = path[1:]  # remove the leading /
651
652        parts = path.split("/")
653
654        directory_cluster = None
655        directory = self.read_root_directory()
656
657        parts, filename = parts[:-1], parts[-1]
658
659        for _, part in enumerate(parts):
660            current_entry = None
661            for entry in directory:
662                if entry.whole_name() == part:
663                    current_entry = entry
664                    break
665            if current_entry is None:
666                return None
667
668            if current_entry.attributes & 0x10 == 0:
669                raise Exception(
670                    f"{current_entry.whole_name()} is not a directory"
671                )
672
673            directory = self.read_directory(current_entry.cluster)
674            directory_cluster = current_entry.cluster
675
676        # add new entry to the directory
677
678        filename, ext = filename.split(".")
679
680        if len(ext) > 3:
681            raise Exception("Ext must be 3 characters or less")
682        if len(filename) > 8:
683            raise Exception("Name must be 8 characters or less")
684
685        for c in filename + ext:
686
687            if c not in ALLOWED_FILE_CHARS:
688                raise Exception("Invalid character in filename")
689
690        return self.add_direntry(directory_cluster, filename, ext, 0)
691