Client
- class deepinv.models.Client(endpoint, api_key='', return_metadata=False)
Bases: Reconstructor, Denoiser
DeepInverse model API Client.
Perform inference on models hosted in the cloud directly from DeepInverse.
This functionality allows contributors to develop APIs to disseminate their reconstruction models, without requiring the client user to host the model themselves or to accurately define their physics. As an API developer, all you have to do is:
1. Define your model to take tensors as input and output tensors (like deepinv.models.Reconstructor).
2. Create a simple API (see below for an example).
3. Deploy it to the cloud, and distribute the endpoint URL and API keys to anyone who might want to use it!
The user then only needs to instantiate this client with the endpoint URL and API key, and pass in an image as a tensor.
Warning
This feature is experimental. Its interface and behavior may change without notice in future releases. Use with caution in production workflows.
- Example:
import deepinv as dinv
import torch

y = torch.tensor([...])  # Your measurements

model = dinv.models.Client("<ENDPOINT>", "<API_KEY>")
x_hat = model(y, physics="denoising")
Create your own API: to develop an API compatible with this client:
Since we cannot pass objects via the API, physics are passed as strings with optional parameters and must be rebuilt in the API.
The API must accept the following input body:
{ "input": { "file": <b64 serialized file>, "metadata": { "param1": "such as a config str", "param2": <or a b64 serialized param>, ... }, } }
The API must return the following output response:
{ "output": { "file": "<b64 serialized file>", "metadata": { "other_outputs": "such as inference time", } } }
During the forward pass, the client passes the input tensor, serialized as base64, to the API, along with any optional parameters, which must be plain text, numbers, or serializable, depending on the API input requirements: e.g. a physics string, config, sigma, mask, etc.

The API can be developed and deployed on any platform you prefer, e.g. servers, containers, or serverless functions. See below for some simple examples.
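Since physics objects cannot cross the API boundary, the server typically rebuilds them from the metadata strings. A minimal sketch, assuming an illustrative metadata contract with "physics" and "sigma" keys (the key names and the set of supported physics are entirely up to your API):

import deepinv as dinv

def build_physics(metadata):
    # "physics" and "sigma" are illustrative metadata keys, not part of the
    # client's contract; define whatever contract suits your API
    if metadata.get("physics") == "denoising":
        sigma = float(metadata.get("sigma", 0.1))
        return dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma=sigma))
    raise ValueError(f"Unsupported physics: {metadata.get('physics')}")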
Note
Authentication is handled at the application level via the API key by default. However, you may also choose to enforce authentication or rate-limiting at an upstream layer (e.g. an nginx reverse proxy or API gateway) if preferred.
Warning
Security is critical when exposing models via Web APIs. Always use HTTPS, validate and sanitize inputs, and restrict access with strong API keys or authentication mechanisms. Consider rate-limiting and monitoring to reduce attack surface.
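As one illustration, application-level key checking can be enforced in Flask with a before-request hook. This is a hedged sketch: the Authorization header and bearer scheme are assumptions, so match it to however your client actually transmits the key:

import os
from flask import Flask, request, abort

app = Flask(__name__)
API_KEY = os.environ.get("API_KEY", "")  # load the key from the environment

@app.before_request
def check_api_key():
    # Reject requests without the expected key (header name is an assumption)
    if request.headers.get("Authorization") != f"Bearer {API_KEY}":
        abort(401)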
- Example:
Simple server using Flask
from flask import Flask, request, jsonify
from deepinv.models import Client

app = Flask(__name__)
model = ...  # Your DeepInverse model

@app.route("/", methods=["POST"])
def infer():
    inp = request.get_json()["input"]
    y = Client.deserialize(inp["file"])
    physics = ...  # Create physics depending on metadata
    x_hat = model(y, physics)  # Server-side inference
    return jsonify({"output": {"file": Client.serialize(x_hat)}})

if __name__ == "__main__":
    app.run()
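To smoke-test this server, one can post a serialized tensor directly. A sketch assuming the server above is running locally on Flask's default port (5000) and using the illustrative "physics" metadata key:

import requests
import torch
from deepinv.models import Client

y = torch.randn(1, 1, 32, 32)  # dummy measurements
body = {"input": {"file": Client.serialize(y),
                  "metadata": {"physics": "denoising"}}}
r = requests.post("http://127.0.0.1:5000/", json=body)
x_hat = Client.deserialize(r.json()["output"]["file"])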
Serverless container using RunPod
import runpod
from deepinv.models import Client

model = ...  # Your DeepInverse model

def handler(event):
    inp = event["input"]
    y = Client.deserialize(inp["file"])
    physics = ...  # Create physics depending on metadata
    x_hat = model(y, physics)  # Server-side inference
    return {"output": {"file": Client.serialize(x_hat)}}

if __name__ == "__main__":
    runpod.serverless.start({"handler": handler})
- Parameters:
endpoint (str) – URL of the API endpoint
api_key (str) – API key used to authenticate requests, if required by the API
return_metadata (bool) – if True, also return metadata from the API response
- static deserialize(data)
Helper function to deserialize client outputs.
The media type for the pickled documents is expected to be application/octet-stream.
- Parameters:
data (str) – input serialized using serialize()
- Returns:
deserialized tensor
- Return type:
torch.Tensor
- forward(y, **kwargs)
Client model forward pass.
- Parameters:
y (torch.Tensor) – input measurements tensor
kwargs – any optional parameters depending on the API input requirements, e.g. a physics string, config, sigma, mask, etc.
- Returns:
output reconstruction tensor
- Return type:
torch.Tensor
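For example, extra keyword arguments are forwarded alongside the serialized tensor. A sketch where "sigma" is an illustrative parameter name (valid names depend on the target API):

import torch
import deepinv as dinv

model = dinv.models.Client("<ENDPOINT>", "<API_KEY>")
y = torch.randn(1, 1, 32, 32)
# "sigma" is illustrative; pass whatever parameters your API accepts
x_hat = model(y, physics="denoising", sigma=0.05)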
- static serialize(tensor)
Helper function to serialize client inputs.
Instances of torch.Tensor are serialized by first pickling them using torch.save() and then returning a URI pointing to the pickle file. For now, only data URIs are supported, but in the future short-lived URLs may also be supported.
- Parameters:
tensor (torch.Tensor) – input tensor
- Returns:
tensor serialized as a base64 string
- Return type:
str
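A round-trip sketch of the two helpers; since serialization pickles the tensor with torch.save(), the round trip should recover the values exactly:

import torch
from deepinv.models import Client

x = torch.randn(3, 64, 64)
data = Client.serialize(x)         # base64 data URI string
x_back = Client.deserialize(data)  # recover the tensor
assert torch.equal(x, x_back)      # torch.save round trips are exact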