zoukankan      html  css  js  c++  java
  • pytorch-第一章基本操作-强大的hub模块

    使用torch的hub模块载入模型,输入数据进行模型的结果输出,对输出的结果做可视化处理 

    ## GITHUB https://github.com/pytorch/hub 
    
    import torch 
    
    model = torch.hub.load('pytorch/vision:v0.4.2', 'deeplabv3_resnet101', pretrained=True)
    model.eval() 
    
    print(torch.hub.list('pytorch/vision:v0.4.2'))
    
    #数据载入,获得图片 
    import urllib 
    url, filename = ("https://github.com/pytorch/hub/raw/master/dog.jpg", "dog.jpg")
    try:
        urllib.URLopener().retrieve(url, filename)
    except:
        urllib.request.urlretrieve(url, filename)
    
    from PIL import Image 
    from torchvision import transforms 
    
    input_image = Image.open(filename)
    #构建处理图片的函数 
    preprocess = transforms.Compose(
        [
            transforms.ToTensor(), 
            transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]), 
        ] 
    )
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0) # 产生一个样本 
    
    if torch.cuda.is_available():
        input_batch = input_batch.to("cuda")
        model.to("cuda")
    
    with torch.no_grad():
        output = model(input_batch)['out'][0]
    
    output_predictions = output.argmax(0)
    
    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    colors = torch.as_tensor([i for i in range(21)])[:, None] * colors 
    colors = (colors % 255).numpy().astype("uint8")
    
    r = Image.fromarray(output_predictions.bytes().cpu().numpy()).resize(input_image.size)
    r.putpalette(colors)
    
    import matplotlib.pyplot as plt 
    plt.show(r)
  • 相关阅读:
    C# .net页面乱码
    Spring Cloud 微服务三: API网关Spring cloud gateway
    Spring Cloud 微服务二:API网关spring cloud zuul
    Spring Cloud 微服务一:Consul注册中心
    Log4j2升级jar包冲突问题
    Log4j2配置
    opensearch空查询
    阿里云Opensearch数据类型
    Spring mybatis自动扫描dao
    【EDAS问题】轻量级EDAS部署hsf服务出现找不到类的解决方案
  • 原文地址:https://www.cnblogs.com/my-love-is-python/p/12650573.html
Copyright © 2011-2022 走看看