Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-10 08:39:07

0001 import decimal
0002 import pickle
0003 import sys
0004 from copyreg import _reconstructor as map__reconstructor
0005 from io import BytesIO
0006 
0007 
class Common_Unpickler(pickle.Unpickler):
    """Unpickler whose ``find_global`` attribute is an alias for ``find_class``.

    Assigning to ``find_global`` on an instance stores the value under
    ``find_class`` instead; any other attribute is stored under its own name.
    """

    def __setattr__(self, key, value):
        # Only the alias is remapped; every other name is passed through as is.
        target = "find_class" if key == "find_global" else key
        pickle.Unpickler.__setattr__(self, target, value)
0014 
0015 
0016 # conversion for unserializable values
def conversion_func(item):
    """Recursively replace ``decimal.Decimal`` values with plain numbers.

    Lists are converted element by element and dicts value by value (dict
    keys are kept unchanged). A Decimal with an integral value becomes an
    ``int``; any other Decimal, including positive or negative infinity,
    becomes a ``float``. Objects of any other type are returned unchanged.

    :param item: value to convert
    :return: a new list/dict/number for lists, dicts and Decimals,
             otherwise ``item`` itself
    """
    if isinstance(item, list):
        return [conversion_func(element) for element in item]
    if isinstance(item, dict):
        return {key: conversion_func(value) for key, value in item.items()}
    if isinstance(item, decimal.Decimal):
        # Infinity compares equal to its own integral value, but int() cannot
        # represent it (OverflowError), so it has to be mapped to float first.
        if item.is_infinite():
            return float(item)
        if item == item.to_integral_value():
            return int(item)
        return float(item)
    return item
0028 
0029 
0030 # wrapper to avoid de-serializing unsafe objects
class WrappedPickle(object):
    """Pickle front end that only resolves an explicit allow-list of globals.

    NOTE(review): the allow-list restricts which globals a pickle stream may
    reference, but unpickling data from an untrusted source is still risky;
    confirm where callers obtain the data passed to ``loads``.
    """

    # allowed modules and classes
    allowedModClass = {
        "copy_reg": ["_reconstructor"],
        "__builtin__": ["object"],
        "datetime": ["datetime"],
        "taskbuffer.JobSpec": ["JobSpec"],
        "taskbuffer.FileSpec": ["FileSpec"],
        "pandaserver.taskbuffer.JobSpec": ["JobSpec"],
        "pandaserver.taskbuffer.FileSpec": ["FileSpec"],
    }
    # bare modules: module-name prefix -> package prefix prepended to it
    bareMods = {"taskbuffer.": "pandaserver."}
    # predefined class map: (module, name) pairs resolved without importing
    predefined_class = {
        ("copy_reg", "_reconstructor"): map__reconstructor,
        ("__builtin__", "object"): object,
    }

    # check module and class
    @classmethod
    def find_class(cls, module, name):
        """Resolve ``module``/``name`` to an object if the pair is allowed.

        :param module: module name as it appears in the pickle stream
        :param name: attribute name requested from that module
        :return: the predefined object for the pair, or the attribute of the
                 imported module
        :raises pickle.UnpicklingError: if the module or the name is not in
                 ``allowedModClass``
        """
        # append prefix to bare modules
        for bare_prefix, package_prefix in cls.bareMods.items():
            if module.startswith(bare_prefix):
                module = package_prefix + module
                break
        # check module
        if module not in cls.allowedModClass:
            raise pickle.UnpicklingError(f"Attempting to import disallowed module {module}")
        # return predefined class
        key = (module, name)
        if key in cls.predefined_class:
            return cls.predefined_class[key]
        # check class before importing, so that a disallowed name is rejected
        # without running the module's import (and without an ImportError
        # masking the rejection when the module is not importable)
        if name not in cls.allowedModClass[module]:
            raise pickle.UnpicklingError(f"Attempting to get disallowed class {name} in {module}")
        # import module
        __import__(module)
        return getattr(sys.modules[module], name)

    # loads
    @classmethod
    def loads(cls, pickle_string):
        """Unpickle ``pickle_string`` with globals resolved by ``find_class``.

        :param pickle_string: pickled data; a ``str`` is encoded to bytes with
                              the default ``str.encode()`` encoding first
        :return: the reconstructed object
        :raises pickle.UnpicklingError: if the stream references a module or
                 name that ``find_class`` rejects
        """
        if isinstance(pickle_string, str):
            pickle_string = pickle_string.encode()
        pickle_obj = Common_Unpickler(BytesIO(pickle_string))
        # Common_Unpickler stores this assignment as its find_class hook
        pickle_obj.find_global = cls.find_class
        return pickle_obj.load()

    # dumps
    @classmethod
    def dumps(cls, obj, convert_to_safe=False):
        """Pickle ``obj`` with protocol 0.

        :param obj: object to serialize
        :param convert_to_safe: if true, pass ``obj`` through
                                ``conversion_func`` before pickling
        :return: the pickled bytes
        """
        if convert_to_safe:
            obj = conversion_func(obj)
        return pickle.dumps(obj, protocol=0)