# coding: utf-8
import os
import struct
import requests
import traceback
import time
from logging import Logger
from supervisely.worker_api.retriers import retriers_from_cfg
from supervisely.io.network_exceptions import process_requests_exception, process_unhandled_request
class AgentAPI:
    """Client for communicating with the Supervisely agent service (image/data downloads, logging, etc.).

    Requests are sent as HTTP POST to ``<server_address>/agent/<api_method_name>``
    with protobuf-serialized bodies. Streamed payloads are framed as a 4-byte
    big-endian length followed by that many bytes of serialized message.
    """

    def __init__(self, token, server_address: str, ext_logger: Logger, cfg_path=None):
        """
        :param token: Auth token. When not None it is sent in the ``x-token`` header.
        :type token: str
        :param server_address: Agent server URL. ``http://`` is prepended when no scheme is given.
        :type server_address: str
        :param ext_logger: Logger.
        :type ext_logger: Logger
        :param cfg_path: Optional retrier config path.
        :type cfg_path: str
        """
        self.logger = ext_logger

        # NOTE: the attribute name keeps its historical spelling ("adress") so that
        # any external code reading it keeps working.
        self._base_server_adress = server_address
        if not self._base_server_adress.startswith(("http://", "https://")):
            self._base_server_adress = "http://" + self._base_server_adress
        self.server_address = self._join_url(self._base_server_adress, "agent")

        self.headers = {
            "Content-type": "application/octet-stream",
            "Accept-Encoding": "deflate",  # to override default 'Accept-Encoding': 'gzip, deflate'
        }
        if token is not None:
            self.headers["x-token"] = token

        self._retriers = retriers_from_cfg(cfg_path)
        # Only plain-HTTP addresses need the one-time "does the server redirect to HTTPS" probe.
        self._require_https_redirect_check = not self._base_server_adress.startswith("https://")

    @staticmethod
    def _join_url(base, *parts):
        """Join URL segments with exactly one forward slash between them.

        Unlike ``os.path.join`` this never uses the platform path separator and
        never drops ``base`` when a segment starts with ``/``.

        :param base: Left-hand part of the URL.
        :type base: str
        :param parts: Segments to append.
        :type parts: str
        :return: Joined URL.
        :rtype: str
        """
        url = base.rstrip("/")
        for part in parts:
            url = f"{url}/{part.lstrip('/')}"
        return url

    @staticmethod
    def _iter_frames(chunks):
        """Split a chunked byte stream into length-prefixed message payloads.

        Each frame is a 4-byte big-endian unsigned length followed by that many
        payload bytes. Frames may be split across chunks at any position.
        Zero-length frames are skipped.

        http://www.sureshjoshi.com/development/streaming-protocol-buffers/
        https://www.datadoghq.com/blog/engineering/protobuf-parsing-in-python/

        :param chunks: Iterable of ``bytes`` chunks (empty chunks are allowed).
        :return: Generator of ``bytes`` payloads, one per non-empty frame.
        :raises RuntimeError: ``MISSED_STREAM_CHUNKS`` if the stream ends in the
            middle of a frame (either inside the length header or inside the payload).
        """
        header_size = 4  # four bytes for message length
        pending = bytearray()
        for chunk in chunks:
            pending += chunk
            offset = 0
            while len(pending) - offset >= header_size:
                (msg_len,) = struct.unpack_from(">I", pending, offset)
                frame_end = offset + header_size + msg_len
                if frame_end > len(pending):
                    break  # payload not fully received yet
                if msg_len:
                    yield bytes(pending[offset + header_size : frame_end])
                offset = frame_end
            # Drop everything consumed from this chunk in one go.
            del pending[:offset]
        if pending:
            raise RuntimeError("MISSED_STREAM_CHUNKS")

    def _get_retrier(self, api_method_name, common_selector):
        """Return the retrier registered for ``api_method_name``, else the one under ``common_selector``.

        :raises KeyError: If neither key is present in the retrier mapping.
        """
        retrier = self._retriers.get(api_method_name)
        if retrier is None:
            retrier = self._retriers[common_selector]
        return retrier

    def add_to_metadata(self, key, value):
        """Set header ``key`` to ``value`` for all subsequent requests."""
        self.headers[key] = value

    def rm_from_metadata(self, key):
        """Remove header ``key`` from subsequent requests; a missing key is ignored."""
        self.headers.pop(key, None)

    def _send_request(self, api_method_name, request_data, timeout, in_stream, addit_headers):
        """POST ``request_data`` to the agent method and return the ``requests`` response.

        :param api_method_name: Method name appended to the agent URL.
        :param request_data: Request body (bytes or an iterable of bytes).
        :param timeout: Timeout forwarded to ``requests.post``.
        :param in_stream: Forwarded as ``stream=`` to ``requests.post``.
        :param addit_headers: Extra headers merged over the default ones; may be None.
        :return: The response object. Errors are routed to ``process_requests_exception``
            (called with ``swallow_exc=False``) or ``process_unhandled_request``.
        """
        self._check_https_redirect()
        url = self._join_url(self.server_address, api_method_name)
        cur_header = {**self.headers, **(addit_headers or {})}
        # Failures of the "Log" method itself are reported non-verbosely.
        not_log_request = api_method_name != "Log"
        server_reply = None
        try:
            server_reply = requests.post(
                url, headers=cur_header, data=request_data, stream=in_stream, timeout=timeout
            )
            server_reply.raise_for_status()
        except requests.RequestException as exc:
            process_requests_exception(
                self.logger,
                exc,
                api_method_name,
                url,
                verbose=not_log_request,
                swallow_exc=False,
                response=server_reply,
            )
        except Exception as exc:
            process_unhandled_request(self.logger, exc)
        return server_reply

    def _get_input_stream(
        self, api_method_name, res_proto_fn, request_data, timeout, addit_headers
    ):
        """Send a request and yield the messages decoded from the streamed reply.

        :param res_proto_fn: Zero-argument factory of the message object; each
            payload is loaded into a fresh instance via ``ParseFromString``.
        :return: Generator of parsed messages.
        :raises RuntimeError: On a chunked-encoding failure or a truncated stream.
        """
        # request package crashes on empty chunks
        try:
            with self._send_request(
                api_method_name, request_data, timeout, in_stream=True, addit_headers=addit_headers
            ) as reply:
                for payload in self._iter_frames(reply.iter_content(chunk_size=None)):
                    proto_msg = res_proto_fn()
                    proto_msg.ParseFromString(payload)
                    yield proto_msg
        except requests.exceptions.ChunkedEncodingError as exc:
            raise RuntimeError("Unknown error during stream. Please contact support.") from exc
        except requests.RequestException:
            raise
        except Exception as e:
            self.logger.error(traceback.format_exc(), exc_info=True, extra={"exc_str": str(e)})
            raise

    def _put_out_stream(
        self, api_method_name, res_proto_fn, chunk_generator, timeout, addit_headers
    ):
        """Upload messages as a length-prefixed stream and parse the single reply.

        :param chunk_generator: Iterable of messages exposing ``SerializeToString``.
        :return: Message created by ``res_proto_fn`` and filled from the response body.
        """

        def bindata_generator():
            for chunk in chunk_generator:
                payload = chunk.SerializeToString()
                # The length prefix is derived from the bytes actually sent.
                yield struct.pack(">I", len(payload)) + payload

        resp = self._send_request(
            api_method_name,
            bindata_generator(),
            timeout,
            in_stream=False,
            addit_headers=addit_headers,
        )
        res_proto = res_proto_fn()
        res_proto.ParseFromString(resp.content)
        return res_proto

    # will not log it now
    def simple_request(self, api_method_name, res_proto_fn, proto_request, addit_headers=None):
        """Send one message and return the parsed reply.

        Uses the retrier registered for the method, falling back to ``__simple_request``.

        :return: Parsed reply message, or None when the retrier returned no response.
        """
        data_to_send = proto_request.SerializeToString()
        retrier = self._get_retrier(api_method_name, "__simple_request")
        resp = retrier.request(
            self._send_request,
            api_method_name,
            data_to_send,
            in_stream=False,
            addit_headers=addit_headers,
        )
        if resp is None:
            return None  # swallowed exception
        res_proto = res_proto_fn()
        res_proto.ParseFromString(resp.content)
        return res_proto

    def get_stream_with_data(
        self, api_method_name, res_proto_fn, proto_request, addit_headers=None
    ):
        """Send one message and yield the messages of the streamed reply.

        Uses the retrier registered for the method, falling back to ``__data_stream_in``.
        """
        data_to_send = proto_request.SerializeToString()
        retrier = self._get_retrier(api_method_name, "__data_stream_in")
        yield from retrier.request(
            self._get_input_stream,
            api_method_name,
            res_proto_fn,
            data_to_send,
            addit_headers=addit_headers,
        )

    def get_endless_stream(
        self,
        api_method_name,
        res_proto_fn,
        proto_request,
        addit_headers=None,
        server_fail_limit=10,
        wait_server_sec=10,
    ):
        """Yield messages from a long-lived stream, reopening it each time it ends.

        :param server_fail_limit: Maximum number of times the stream is opened.
        :param wait_server_sec: Seconds to sleep between consecutive openings.
        """
        data_to_send = proto_request.SerializeToString()
        for attempt in range(server_fail_limit):
            retrier = self._get_retrier(api_method_name, "__endless_stream_in")
            yield from retrier.request(
                self._get_input_stream,
                api_method_name,
                res_proto_fn,
                data_to_send,
                addit_headers=addit_headers,
            )
            self.logger.warning("Endless input stream end", extra={"method": api_method_name})
            # No point in sleeping after the final attempt.
            if attempt != server_fail_limit - 1:
                time.sleep(wait_server_sec)

    def put_stream_with_data(
        self, api_method_name, res_proto_fn, chunk_generator, addit_headers=None
    ):
        """Upload a stream of messages and return the parsed reply.

        Uses the retrier registered for the method, falling back to ``__data_stream_out``.

        NOTE(review): ``chunk_generator`` is iterated directly on every attempt; if the
        retrier re-invokes the upload and a one-shot generator was passed, the repeated
        attempt would see it already consumed -- confirm against the retrier behavior.
        """
        retrier = self._get_retrier(api_method_name, "__data_stream_out")
        return retrier.request(
            self._put_out_stream,
            api_method_name,
            res_proto_fn,
            chunk_generator,
            addit_headers=addit_headers,
        )

    @classmethod
    def catch_http_err(cls, code_list, fn, *args, **kwargs):
        """Call ``fn`` and turn selected HTTP error codes into a return value.

        :param code_list: Status codes to intercept.
        :param fn: Callable invoked as ``fn(*args, **kwargs)``.
        :return: ``(result, None)`` on success, or ``(None, status_code)`` when an
            ``HTTPError`` with a status code from ``code_list`` was raised.
        :raises requests.exceptions.HTTPError: For any other status code, or when the
            error carries no response object.
        """
        try:
            return fn(*args, **kwargs), None
        except requests.exceptions.HTTPError as e:
            response = e.response
            # HTTPError may be raised without a response attached.
            if response is not None and response.status_code in code_list:
                return None, response.status_code
            raise

    def _check_https_redirect(self):
        """Probe the server once and switch the stored addresses to HTTPS if it redirects.

        The probe runs only while ``_require_https_redirect_check`` is True and the
        flag is cleared after the first completed probe.
        """
        if self._require_https_redirect_check is not True:
            return
        # NOTE(review): no timeout is passed to this probe request.
        response = requests.get(self._base_server_adress, allow_redirects=False)
        # NOTE(review): any 3xx answer triggers the switch, regardless of the redirect
        # target's scheme -- confirm that "or" (rather than "and") is intended.
        if (300 <= response.status_code < 400) or (
            response.headers.get("Location", "").startswith("https://")
        ):
            # Only the scheme prefix is rewritten (first occurrence).
            self._base_server_adress = self._base_server_adress.replace(
                "http://", "https://", 1
            )
            self.server_address = self.server_address.replace("http://", "https://", 1)
            msg = (
                "You're using HTTP server address on agent while the server requires HTTPS. "
                "Supervisely automatically changed the server address to HTTPS for you. "
                f"Consider updating your server address to {self._base_server_adress}"
            )
            self.logger.warning(msg)
        self._require_https_redirect_check = False