1__all__ = [
2    'LXMLTreeBuilderForXML',
3    'LXMLTreeBuilder',
4    ]
5
6from io import BytesIO
7from io import StringIO
8import collections
9from lxml import etree
10from bs4.element import (
11    Comment,
12    Doctype,
13    NamespacedAttribute,
14    ProcessingInstruction,
15)
16from bs4.builder import (
17    FAST,
18    HTML,
19    HTMLTreeBuilder,
20    PERMISSIVE,
21    ParserRejectedMarkup,
22    TreeBuilder,
23    XML)
24from bs4.dammit import EncodingDetector
25
26LXML = 'lxml'
27
28class LXMLTreeBuilderForXML(TreeBuilder):
29    DEFAULT_PARSER_CLASS = etree.XMLParser
30
31    is_xml = True
32
33    NAME = "lxml-xml"
34    ALTERNATE_NAMES = ["xml"]
35
36    # Well, it's permissive by XML parser standards.
37    features = [NAME, LXML, XML, FAST, PERMISSIVE]
38
39    CHUNK_SIZE = 512
40
41    # This namespace mapping is specified in the XML Namespace
42    # standard.
43    DEFAULT_NSMAPS = {'http://www.w3.org/XML/1998/namespace' : "xml"}
44
45    def default_parser(self, encoding):
46        # This can either return a parser object or a class, which
47        # will be instantiated with default arguments.
48        if self._default_parser is not None:
49            return self._default_parser
50        return etree.XMLParser(
51            target=self, strip_cdata=False, recover=True, encoding=encoding)
52
53    def parser_for(self, encoding):
54        # Use the default parser.
55        parser = self.default_parser(encoding)
56
57        if isinstance(parser, collections.Callable):
58            # Instantiate the parser with default arguments
59            parser = parser(target=self, strip_cdata=False, encoding=encoding)
60        return parser
61
62    def __init__(self, parser=None, empty_element_tags=None):
63        # TODO: Issue a warning if parser is present but not a
64        # callable, since that means there's no way to create new
65        # parsers for different encodings.
66        self._default_parser = parser
67        if empty_element_tags is not None:
68            self.empty_element_tags = set(empty_element_tags)
69        self.soup = None
70        self.nsmaps = [self.DEFAULT_NSMAPS]
71
72    def _getNsTag(self, tag):
73        # Split the namespace URL out of a fully-qualified lxml tag
74        # name. Copied from lxml's src/lxml/sax.py.
75        if tag[0] == '{':
76            return tuple(tag[1:].split('}', 1))
77        else:
78            return (None, tag)
79
80    def prepare_markup(self, markup, user_specified_encoding=None,
81                       exclude_encodings=None,
82                       document_declared_encoding=None):
83        """
84        :yield: A series of 4-tuples.
85         (markup, encoding, declared encoding,
86          has undergone character replacement)
87
88        Each 4-tuple represents a strategy for parsing the document.
89        """
90        if isinstance(markup, str):
91            # We were given Unicode. Maybe lxml can parse Unicode on
92            # this system?
93            yield markup, None, document_declared_encoding, False
94
95        if isinstance(markup, str):
96            # No, apparently not. Convert the Unicode to UTF-8 and
97            # tell lxml to parse it as UTF-8.
98            yield (markup.encode("utf8"), "utf8",
99                   document_declared_encoding, False)
100
101        # Instead of using UnicodeDammit to convert the bytestring to
102        # Unicode using different encodings, use EncodingDetector to
103        # iterate over the encodings, and tell lxml to try to parse
104        # the document as each one in turn.
105        is_html = not self.is_xml
106        try_encodings = [user_specified_encoding, document_declared_encoding]
107        detector = EncodingDetector(
108            markup, try_encodings, is_html, exclude_encodings)
109        for encoding in detector.encodings:
110            yield (detector.markup, encoding, document_declared_encoding, False)
111
112    def feed(self, markup):
113        if isinstance(markup, bytes):
114            markup = BytesIO(markup)
115        elif isinstance(markup, str):
116            markup = StringIO(markup)
117
118        # Call feed() at least once, even if the markup is empty,
119        # or the parser won't be initialized.
120        data = markup.read(self.CHUNK_SIZE)
121        try:
122            self.parser = self.parser_for(self.soup.original_encoding)
123            self.parser.feed(data)
124            while len(data) != 0:
125                # Now call feed() on the rest of the data, chunk by chunk.
126                data = markup.read(self.CHUNK_SIZE)
127                if len(data) != 0:
128                    self.parser.feed(data)
129            self.parser.close()
130        except (UnicodeDecodeError, LookupError, etree.ParserError) as e:
131            raise ParserRejectedMarkup(str(e))
132
133    def close(self):
134        self.nsmaps = [self.DEFAULT_NSMAPS]
135
136    def start(self, name, attrs, nsmap={}):
137        # Make sure attrs is a mutable dict--lxml may send an immutable dictproxy.
138        attrs = dict(attrs)
139        nsprefix = None
140        # Invert each namespace map as it comes in.
141        if len(self.nsmaps) > 1:
142            # There are no new namespaces for this tag, but
143            # non-default namespaces are in play, so we need a
144            # separate tag stack to know when they end.
145            self.nsmaps.append(None)
146        elif len(nsmap) > 0:
147            # A new namespace mapping has come into play.
148            inverted_nsmap = dict((value, key) for key, value in list(nsmap.items()))
149            self.nsmaps.append(inverted_nsmap)
150            # Also treat the namespace mapping as a set of attributes on the
151            # tag, so we can recreate it later.
152            attrs = attrs.copy()
153            for prefix, namespace in list(nsmap.items()):
154                attribute = NamespacedAttribute(
155                    "xmlns", prefix, "http://www.w3.org/2000/xmlns/")
156                attrs[attribute] = namespace
157
158        # Namespaces are in play. Find any attributes that came in
159        # from lxml with namespaces attached to their names, and
160        # turn then into NamespacedAttribute objects.
161        new_attrs = {}
162        for attr, value in list(attrs.items()):
163            namespace, attr = self._getNsTag(attr)
164            if namespace is None:
165                new_attrs[attr] = value
166            else:
167                nsprefix = self._prefix_for_namespace(namespace)
168                attr = NamespacedAttribute(nsprefix, attr, namespace)
169                new_attrs[attr] = value
170        attrs = new_attrs
171
172        namespace, name = self._getNsTag(name)
173        nsprefix = self._prefix_for_namespace(namespace)
174        self.soup.handle_starttag(name, namespace, nsprefix, attrs)
175
176    def _prefix_for_namespace(self, namespace):
177        """Find the currently active prefix for the given namespace."""
178        if namespace is None:
179            return None
180        for inverted_nsmap in reversed(self.nsmaps):
181            if inverted_nsmap is not None and namespace in inverted_nsmap:
182                return inverted_nsmap[namespace]
183        return None
184
185    def end(self, name):
186        self.soup.endData()
187        completed_tag = self.soup.tagStack[-1]
188        namespace, name = self._getNsTag(name)
189        nsprefix = None
190        if namespace is not None:
191            for inverted_nsmap in reversed(self.nsmaps):
192                if inverted_nsmap is not None and namespace in inverted_nsmap:
193                    nsprefix = inverted_nsmap[namespace]
194                    break
195        self.soup.handle_endtag(name, nsprefix)
196        if len(self.nsmaps) > 1:
197            # This tag, or one of its parents, introduced a namespace
198            # mapping, so pop it off the stack.
199            self.nsmaps.pop()
200
201    def pi(self, target, data):
202        self.soup.endData()
203        self.soup.handle_data(target + ' ' + data)
204        self.soup.endData(ProcessingInstruction)
205
206    def data(self, content):
207        self.soup.handle_data(content)
208
209    def doctype(self, name, pubid, system):
210        self.soup.endData()
211        doctype = Doctype.for_name_and_ids(name, pubid, system)
212        self.soup.object_was_parsed(doctype)
213
214    def comment(self, content):
215        "Handle comments as Comment objects."
216        self.soup.endData()
217        self.soup.handle_data(content)
218        self.soup.endData(Comment)
219
220    def test_fragment_to_document(self, fragment):
221        """See `TreeBuilder`."""
222        return '<?xml version="1.0" encoding="utf-8"?>\n%s' % fragment
223
224
225class LXMLTreeBuilder(HTMLTreeBuilder, LXMLTreeBuilderForXML):
226
227    NAME = LXML
228    ALTERNATE_NAMES = ["lxml-html"]
229
230    features = ALTERNATE_NAMES + [NAME, HTML, FAST, PERMISSIVE]
231    is_xml = False
232
233    def default_parser(self, encoding):
234        return etree.HTMLParser
235
236    def feed(self, markup):
237        encoding = self.soup.original_encoding
238        try:
239            self.parser = self.parser_for(encoding)
240            self.parser.feed(markup)
241            self.parser.close()
242        except (UnicodeDecodeError, LookupError, etree.ParserError) as e:
243            raise ParserRejectedMarkup(str(e))
244
245
246    def test_fragment_to_document(self, fragment):
247        """See `TreeBuilder`."""
248        return '<html><body>%s</body></html>' % fragment
249