Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-09 07:58:20

0001 #!/usr/bin/env python
0002 #
0003 # Licensed under the Apache License, Version 2.0 (the "License");
0004 # You may not use this file except in compliance with the License.
0005 # You may obtain a copy of the License at
0006 # http://www.apache.org/licenses/LICENSE-2.0
0007 #
0008 # Authors:
0009 # - Wen Guan, <wen.guan@cern.ch>, 2019 - 2022
0010 
0011 import json
0012 from traceback import format_exc
0013 
0014 from flask import Blueprint
0015 
0016 from idds.common import exceptions
0017 from idds.common.constants import HTTP_STATUS_CODE
0018 from idds.common.constants import CollectionRelationType, ContentStatus
0019 from idds.core import catalog
0020 from idds.rest.v1.controller import IDDSController
0021 
0022 
class HyperParameterOpt(IDDSController):
    """ get and update hyper parameters. """

    @staticmethod
    def _select_by_name(contents, name):
        """ Keep only the contents whose 'name' equals ``name`` (compared as strings).

        :param contents: iterable of content mappings, each having a 'name' key.
        :param name: the name to match; a falsy value disables the filter.

        :returns: ``contents`` unchanged when ``name`` is falsy, otherwise a new list.
        """
        if not name:
            return contents
        wanted = str(name)
        return [content for content in contents if str(content['name']) == wanted]

    def put(self, workload_id, request_id, id, loss):
        """ Update the loss for the hyper parameter.

        The first output content matching the request/workload (and ``id``, when given)
        has the loss in its JSON 'path' pair replaced and is marked Available.

        :param workload_id: The workload id, or the string 'null' for none.
        :param request_id: The id of the request, or the string 'null' for none.
        :param id: The content name identifying the hyper parameter point.
        :param loss: The new loss value; converted with float().

        HTTP Success:
            200 OK
        HTTP Error:
            400 Bad request
            404 Not Found
            500 Internal Error
        """
        try:
            # URL path segments cannot be empty, so clients send 'null' for "not set".
            if workload_id == 'null':
                workload_id = None
            if request_id == 'null':
                request_id = None

            if workload_id is None and request_id is None:
                error = "One of workload_id and request_id should not be None or empty"
                # NOTE(review): the docstring lists 400 for bad requests but this path answers
                # with InternalError; confirm whether HTTP_STATUS_CODE has a bad-request member.
                return self.generate_http_response(HTTP_STATUS_CODE.InternalError, exc_cls=exceptions.CoreException.__name__, exc_msg=error)
        except Exception as error:
            print(error)
            print(format_exc())
            return self.generate_http_response(HTTP_STATUS_CODE.InternalError, exc_cls=exceptions.CoreException.__name__, exc_msg=error)

        try:
            contents = catalog.get_contents(request_id=request_id, workload_id=workload_id,
                                            relation_type=CollectionRelationType.Output)
            contents = self._select_by_name(contents, id)

            if not contents:
                # Without this guard contents[0] raised IndexError, which surfaced as a
                # generic 500 instead of the documented 404.
                error = "No hyper parameter found for workload_id=%s, request_id=%s, id=%s" % (workload_id, request_id, id)
                return self.generate_http_response(HTTP_STATUS_CODE.NotFound, exc_cls=exceptions.NoObject.__name__, exc_msg=error)
            content = contents[0]

            loss = float(loss)
            content_id = content['content_id']
            # 'path' stores a JSON pair (parameters, loss); only the loss is replaced.
            param, _ = json.loads(content['path'])
            params = {'path': json.dumps((param, loss)),
                      'status': ContentStatus.Available,
                      'substatus': ContentStatus.Available}
            catalog.update_content(content_id, params)
        except exceptions.NoObject as error:
            return self.generate_http_response(HTTP_STATUS_CODE.NotFound, exc_cls=error.__class__.__name__, exc_msg=error)
        except exceptions.IDDSException as error:
            return self.generate_http_response(HTTP_STATUS_CODE.InternalError, exc_cls=error.__class__.__name__, exc_msg=error)
        except Exception as error:
            print(error)
            print(format_exc())
            return self.generate_http_response(HTTP_STATUS_CODE.InternalError, exc_cls=exceptions.CoreException.__name__, exc_msg=error)

        return self.generate_http_response(HTTP_STATUS_CODE.OK, data={'status': 0, 'message': 'update successfully'})

    def get(self, workload_id, request_id, id=None, status=None, limit=None):
        """ Get hyper parameters.
        :param workload_id: The workload id, or the string 'null' for none.
        :param request_id: The id of the request.
        :param id: Content name to select a single hyper parameter. None for all.
        :param status: status of the hyper parameters. None for all statuses.
        :param limit: Limit number of hyperparameters.

        The string 'null' is accepted for every parameter and treated as None.

        HTTP Success:
            200 OK
        HTTP Error:
            404 Not Found
            500 InternalError
        :returns: list of hyper parameters.
        """

        try:
            if workload_id == 'null':
                workload_id = None
            if request_id == 'null':
                request_id = None

            if status == 'null':
                status = None
            if limit == 'null':
                limit = None
            if id == 'null':
                id = None

            if limit is not None:
                # The URL rule delivers limit as a string; comparing it with len()
                # below raised TypeError before this conversion.
                limit = int(limit)

            contents = catalog.get_contents(request_id=request_id, workload_id=workload_id,
                                            status=status, relation_type=CollectionRelationType.Output)
            contents = self._select_by_name(contents, id)

            if contents and limit and len(contents) > limit:
                contents = contents[:limit]

            hyperparameters = []
            for content in contents:
                # 'path' stores a JSON pair (parameters, loss).
                parameter, loss = json.loads(content['path'])
                hyperparameters.append({'id': content['name'],
                                        'parameters': parameter,
                                        'loss': loss})
        except exceptions.NoObject as error:
            return self.generate_http_response(HTTP_STATUS_CODE.NotFound, exc_cls=error.__class__.__name__, exc_msg=error)
        except exceptions.IDDSException as error:
            return self.generate_http_response(HTTP_STATUS_CODE.InternalError, exc_cls=error.__class__.__name__, exc_msg=error)
        except Exception as error:
            print(error)
            print(format_exc())
            return self.generate_http_response(HTTP_STATUS_CODE.InternalError, exc_cls=exceptions.CoreException.__name__, exc_msg=error)

        return self.generate_http_response(HTTP_STATUS_CODE.OK, data=hyperparameters)

    def post_test(self):
        """ Debug helper: pretty-print the current request, its endpoint and its url rule. """
        import pprint
        pprint.pprint(self.get_request())
        pprint.pprint(self.get_request().endpoint)
        pprint.pprint(self.get_request().url_rule)
0145 
0146 
0147 """----------------------
0148    Web service url maps
0149 ----------------------"""
0150 
0151 
def get_blueprint():
    """ Build the 'hpo' blueprint and bind the HyperParameterOpt view to its url rules.

    :returns: the configured Blueprint.
    """
    blueprint = Blueprint('hpo', __name__)
    view = HyperParameterOpt.as_view('hpo')

    # (url rule, allowed methods), registered in this order.
    routes = (
        ('/hpo/<workload_id>/<request_id>/<id>/<loss>', ['put', ]),
        ('/hpo/<workload_id>/<request_id>/<id>/<status>/<limit>', ['get', ]),
    )
    for rule, methods in routes:
        blueprint.add_url_rule(rule, view_func=view, methods=methods)

    return blueprint